diff --git a/.mypy.ini b/.mypy.ini
index ef871684..95dc41a5 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -7,17 +7,25 @@
 # disallow_untyped_defs = True -- for later
 # strict = True -- only try to enable once everything (including dependencies!) is typed
-# strict_equality = True -- for later
-# strict_bytes = True -- for later
+strict_equality = True
+strict_bytes = True

-# warn_redundant_casts = True -- for later
+warn_redundant_casts = True
 # warn_return_any = True -- for later
-# warn_unreachable = True -- for later
+warn_unreachable = True

 [mypy-ansible.*]
 # ansible-core has partial typing information
 follow_untyped_imports = True

+[mypy-docker.*]
+# Docker SDK for Python has partial typing information
+follow_untyped_imports = True
+
 [mypy-ansible_collections.community.internal_test_tools.*]
 # community.internal_test_tools has no typing information
 ignore_missing_imports = True
+
+[mypy-jsondiff.*]
+# jsondiff has no typing information
+ignore_missing_imports = True
diff --git a/antsibull-nox.toml b/antsibull-nox.toml
index 33f52dbb..d37d52b7 100644
--- a/antsibull-nox.toml
+++ b/antsibull-nox.toml
@@ -27,7 +27,7 @@ run_yamllint = true
 yamllint_config = ".yamllint"
 yamllint_config_plugins = ".yamllint-docs"
 yamllint_config_plugins_examples = ".yamllint-examples"
-run_mypy = false
+run_mypy = true
 mypy_ansible_core_package = "ansible-core>=2.19.0"
 mypy_config = ".mypy.ini"
 mypy_extra_deps = [
@@ -35,7 +35,11 @@ mypy_extra_deps = [
     "paramiko",
     "urllib3",
     "requests",
+    "types-mock",
+    "types-paramiko",
+    "types-pywin32",
     "types-PyYAML",
+    "types-requests",
 ]

 [sessions.docs_check]
diff --git a/plugins/action/docker_container_copy_into.py b/plugins/action/docker_container_copy_into.py
index 96c6afbb..ca11d553 100644
--- a/plugins/action/docker_container_copy_into.py
+++ b/plugins/action/docker_container_copy_into.py
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import base64
+import typing as t

 from ansible import constants as C
 from ansible.plugins.action import ActionBase
@@ -19,14 +20,17 @@ class ActionModule(ActionBase):
     # Set to True when transferring files to the remote
     TRANSFERS_FILES = False

-    def run(self, tmp=None, task_vars=None):
+    def run(
+        self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None
+    ) -> dict[str, t.Any]:
         self._supports_check_mode = True
         self._supports_async = True

         result = super().run(tmp, task_vars)
         del tmp  # tmp no longer has any effect

-        self._task.args["_max_file_size_for_diff"] = C.MAX_FILE_SIZE_FOR_DIFF
+        max_file_size_for_diff: int = C.MAX_FILE_SIZE_FOR_DIFF  # type: ignore
+        self._task.args["_max_file_size_for_diff"] = max_file_size_for_diff

         result = merge_hash(
             result,
diff --git a/plugins/connection/docker.py b/plugins/connection/docker.py
index 1f264e6c..fd721388 100644
--- a/plugins/connection/docker.py
+++ b/plugins/connection/docker.py
@@ -118,6 +118,7 @@ import os.path
 import re
 import selectors
 import subprocess
+import typing as t
 from shlex import quote

 from ansible.errors import AnsibleConnectionFailure, AnsibleError, AnsibleFileNotFound
@@ -140,8 +141,8 @@ class Connection(ConnectionBase):
     transport = "community.docker.docker"
     has_pipelining = True

-    def __init__(self, play_context, new_stdin, *args, **kwargs):
-        super().__init__(play_context, new_stdin, *args, **kwargs)
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)

         # Note: docker supports running as non-root in some configurations.
         # (For instance, setting the UNIX socket file to be readable and
@@ -152,11 +153,11 @@
         # configured to be connected to by root and they are not running as
         # root.

-        self._docker_args = []
-        self._container_user_cache = {}
-        self._version = None
-        self.remote_user = None
-        self.timeout = None
+        self._docker_args: list[bytes | str] = []
+        self._container_user_cache: dict[str, str | None] = {}
+        self._version: str | None = None
+        self.remote_user: str | None = None
+        self.timeout: int | float | None = None

         # Windows uses Powershell modules
         if getattr(self._shell, "_IS_WINDOWS", False):
@@ -171,12 +172,12 @@
             raise AnsibleError("docker command not found in PATH") from exc

     @staticmethod
-    def _sanitize_version(version):
+    def _sanitize_version(version: str) -> str:
         version = re.sub("[^0-9a-zA-Z.]", "", version)
         version = re.sub("^v", "", version)
         return version

-    def _old_docker_version(self):
+    def _old_docker_version(self) -> tuple[list[str], str, bytes, int]:
         cmd_args = self._docker_args

         old_version_subcommand = ["version"]
@@ -189,7 +190,7 @@

         return old_docker_cmd, to_native(cmd_output), err, p.returncode

-    def _new_docker_version(self):
+    def _new_docker_version(self) -> tuple[list[str], str, bytes, int]:
         # no result yet, must be newer Docker version
         cmd_args = self._docker_args

@@ -202,8 +203,7 @@
         cmd_output, err = p.communicate()
         return new_docker_cmd, to_native(cmd_output), err, p.returncode

-    def _get_docker_version(self):
-
+    def _get_docker_version(self) -> str:
         cmd, cmd_output, err, returncode = self._old_docker_version()
         if returncode == 0:
             for line in to_text(cmd_output, errors="surrogate_or_strict").split("\n"):
@@ -218,7 +218,7 @@

         return self._sanitize_version(to_text(cmd_output, errors="surrogate_or_strict"))

-    def _get_docker_remote_user(self):
+    def _get_docker_remote_user(self) -> str | None:
         """Get the default user configured in the docker container"""
         container = self.get_option("remote_addr")
         if container in self._container_user_cache:
@@ -243,7 +243,7 @@
         self._container_user_cache[container] = user
         return user

-    def _build_exec_cmd(self, cmd):
+    def _build_exec_cmd(self, cmd: list[bytes | str]) -> list[bytes | str]:
         """Build the local docker exec command to run cmd on remote_host

         If remote_user is available and is supported by the docker
@@ -298,7 +298,7 @@

         return local_cmd

-    def _set_docker_args(self):
+    def _set_docker_args(self) -> None:
         # TODO: this is mostly for backwards compatibility, play_context is used as fallback for older versions
         # docker arguments
         del self._docker_args[:]
@@ -308,7 +308,7 @@
         if extra_args:
             self._docker_args += extra_args.split(" ")

-    def _set_conn_data(self):
+    def _set_conn_data(self) -> None:
         """initialize for the connection, cannot do only in init since all data is not ready at that point"""
         self._set_docker_args()

@@ -323,8 +323,7 @@
         self.timeout = self._play_context.timeout

     @property
-    def docker_version(self):
-
+    def docker_version(self) -> str:
         if not self._version:
             self._set_docker_args()

@@ -341,7 +340,7 @@
         )
         return self._version

-    def _get_actual_user(self):
+    def _get_actual_user(self) -> str | None:
         if self.remote_user is not None:
             # An explicit user is provided
             if self.docker_version == "dev" or LooseVersion(
@@ -353,7 +352,7 @@
             actual_user = self._get_docker_remote_user()
             if actual_user != self.get_option("remote_user"):
                 display.warning(
-                    f'docker {self.docker_version} does not support remote_user, using container default: {actual_user or "?"}'
+                    f"docker {self.docker_version} does not support remote_user, using container default: {actual_user or '?'}"
                 )
             return actual_user
         if self._display.verbosity > 2:
@@ -363,9 +362,9 @@
             return self._get_docker_remote_user()
         return None

-    def _connect(self, port=None):
+    def _connect(self) -> t.Self:
         """Connect to the container. Nothing to do"""
-        super()._connect()
+        super()._connect()  # type: ignore[safe-super]
         if not self._connected:
             self._set_conn_data()
             actual_user = self._get_actual_user()
@@ -374,13 +373,16 @@
                 host=self.get_option("remote_addr"),
             )
             self._connected = True
+        return self

-    def exec_command(self, cmd, in_data=None, sudoable=False):
+    def exec_command(
+        self, cmd: str, in_data: bytes | None = None, sudoable: bool = False
+    ) -> tuple[int, bytes, bytes]:
         """Run a command on the docker host"""

         self._set_conn_data()

-        super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
+        super().exec_command(cmd, in_data=in_data, sudoable=sudoable)  # type: ignore[safe-super]

         local_cmd = self._build_exec_cmd([self._play_context.executable, "-c", cmd])

@@ -395,6 +397,9 @@
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
         ) as p:
+            assert p.stdin is not None
+            assert p.stdout is not None
+            assert p.stderr is not None
             display.debug("done running command with Popen()")

             if self.become and self.become.expect_prompt() and sudoable:
@@ -489,10 +494,10 @@
             remote_path = os.path.join(os.path.sep, remote_path)
         return os.path.normpath(remote_path)

-    def put_file(self, in_path, out_path):
+    def put_file(self, in_path: str, out_path: str) -> None:
         """Transfer a file from local to docker container"""
         self._set_conn_data()
-        super().put_file(in_path, out_path)
+        super().put_file(in_path, out_path)  # type: ignore[safe-super]
         display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr"))

         out_path = self._prefix_login_path(out_path)
@@ -534,10 +539,10 @@
                 f"failed to transfer file {to_native(in_path)} to {to_native(out_path)}:\n{to_native(stdout)}\n{to_native(stderr)}"
             )

-    def fetch_file(self, in_path, out_path):
+    def fetch_file(self, in_path: str, out_path: str) -> None:
         """Fetch a file from container to local."""
         self._set_conn_data()
-        super().fetch_file(in_path, out_path)
+        super().fetch_file(in_path, out_path)  # type: ignore[safe-super]
         display.vvv(
             f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr")
         )
@@ -596,7 +601,7 @@

         if pp.returncode != 0:
             raise AnsibleError(
-                f"failed to fetch file {in_path} to {out_path}:\n{stdout}\n{stderr}"
+                f"failed to fetch file {in_path} to {out_path}:\n{stdout!r}\n{stderr!r}"
             )

         # Rename if needed
@@ -606,11 +611,11 @@
             to_bytes(out_path, errors="strict"),
         )

-    def close(self):
+    def close(self) -> None:
         """Terminate the connection. Nothing to do for Docker"""
-        super().close()
+        super().close()  # type: ignore[safe-super]
         self._connected = False

-    def reset(self):
+    def reset(self) -> None:
         # Clear container user cache
         self._container_user_cache = {}
diff --git a/plugins/connection/docker_api.py b/plugins/connection/docker_api.py
index f871549e..dd0cc479 100644
--- a/plugins/connection/docker_api.py
+++ b/plugins/connection/docker_api.py
@@ -107,6 +107,7 @@ options:

 import os
 import os.path
+import typing as t

 from ansible.errors import AnsibleConnectionFailure, AnsibleFileNotFound
 from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@@ -138,6 +139,12 @@ from ansible_collections.community.docker.plugins.plugin_utils._socket_handler i
 )


+if t.TYPE_CHECKING:
+    from collections.abc import Callable
+
+    _T = t.TypeVar("_T")
+
+
 MIN_DOCKER_API = None


@@ -150,10 +157,16 @@
     transport = "community.docker.docker_api"
     has_pipelining = True

-    def _call_client(self, f, not_found_can_be_resource=False):
+    def _call_client(
+        self,
+        f: Callable[[AnsibleDockerClient], _T],
+        not_found_can_be_resource: bool = False,
+    ) -> _T:
+        if self.client is None:
+            raise AssertionError("Client must be present")
         remote_addr = self.get_option("remote_addr")
         try:
-            return f()
+            return f(self.client)
         except NotFound as e:
             if not_found_can_be_resource:
                 raise AnsibleConnectionFailure(
@@ -179,21 +192,21 @@
                 f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}'
             )

-    def __init__(self, play_context, new_stdin, *args, **kwargs):
-        super().__init__(play_context, new_stdin, *args, **kwargs)
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)

-        self.client = None
-        self.ids = {}
+        self.client: AnsibleDockerClient | None = None
+        self.ids: dict[str | None, tuple[int, int]] = {}

         # Windows uses Powershell modules
         if getattr(self._shell, "_IS_WINDOWS", False):
             self.module_implementation_preferences = (".ps1", ".exe", "")

-        self.actual_user = None
+        self.actual_user: str | None = None

-    def _connect(self, port=None):
+    def _connect(self) -> Connection:
         """Connect to the container. Nothing to do"""
-        super()._connect()
+        super()._connect()  # type: ignore[safe-super]
         if not self._connected:
             self.actual_user = self.get_option("remote_user")
             display.vvv(
@@ -212,7 +225,7 @@
             # This saves overhead from calling into docker when we do not need to
             display.vvv("Trying to determine actual user")
             result = self._call_client(
-                lambda: self.client.get_json(
+                lambda client: client.get_json(
                     "/containers/{0}/json", self.get_option("remote_addr")
                 )
             )
@@ -221,12 +234,19 @@
             if self.actual_user is not None:
                 display.vvv(f"Actual user is '{self.actual_user}'")

-    def exec_command(self, cmd, in_data=None, sudoable=False):
+        return self
+
+    def exec_command(
+        self, cmd: str, in_data: bytes | None = None, sudoable: bool = False
+    ) -> tuple[int, bytes, bytes]:
         """Run a command on the docker host"""

-        super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
+        super().exec_command(cmd, in_data=in_data, sudoable=sudoable)  # type: ignore[safe-super]

-        command = [self._play_context.executable, "-c", to_text(cmd)]
+        if self.client is None:
+            raise AssertionError("Client must be present")
+
+        command = [self._play_context.executable, "-c", cmd]

         do_become = self.become and self.become.expect_prompt() and sudoable

@@ -277,7 +297,7 @@
         )

         exec_data = self._call_client(
-            lambda: self.client.post_json_to_json(
+            lambda client: client.post_json_to_json(
                 "/containers/{0}/exec", self.get_option("remote_addr"), data=data
             )
         )
@@ -286,7 +306,7 @@
         data = {"Tty": False, "Detach": False}
         if need_stdin:
             exec_socket = self._call_client(
-                lambda: self.client.post_json_to_stream_socket(
+                lambda client: client.post_json_to_stream_socket(
                     "/exec/{0}/start", exec_id, data=data
                 )
             )
@@ -295,6 +315,8 @@
                 display, exec_socket, container=self.get_option("remote_addr")
             ) as exec_socket_handler:
                 if do_become:
+                    assert self.become is not None
+
                     become_output = [b""]

                     def append_become_output(stream_id, data):
@@ -339,7 +361,7 @@
                 exec_socket.close()
         else:
             stdout, stderr = self._call_client(
-                lambda: self.client.post_json_to_stream(
+                lambda client: client.post_json_to_stream(
                     "/exec/{0}/start",
                     exec_id,
                     stream=False,
@@ -350,12 +372,12 @@
             )

         result = self._call_client(
-            lambda: self.client.get_json("/exec/{0}/json", exec_id)
+            lambda client: client.get_json("/exec/{0}/json", exec_id)
         )

         return result.get("ExitCode") or 0, stdout or b"", stderr or b""

-    def _prefix_login_path(self, remote_path):
+    def _prefix_login_path(self, remote_path: str) -> str:
         """Make sure that we put files into a standard path

         If a path is relative, then we need to choose where to put it.
@@ -373,19 +395,23 @@
             remote_path = os.path.join(os.path.sep, remote_path)
         return os.path.normpath(remote_path)

-    def put_file(self, in_path, out_path):
+    def put_file(self, in_path: str, out_path: str) -> None:
         """Transfer a file from local to docker container"""
-        super().put_file(in_path, out_path)
+        super().put_file(in_path, out_path)  # type: ignore[safe-super]
         display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr"))

+        if self.client is None:
+            raise AssertionError("Client must be present")
+
         out_path = self._prefix_login_path(out_path)

         if self.actual_user not in self.ids:
-            dummy, ids, dummy = self.exec_command(b"id -u && id -g")
+            dummy, ids, dummy2 = self.exec_command("id -u && id -g")
             remote_addr = self.get_option("remote_addr")
             try:
-                user_id, group_id = ids.splitlines()
-                self.ids[self.actual_user] = int(user_id), int(group_id)
+                b_user_id, b_group_id = ids.splitlines()
+                user_id, group_id = int(b_user_id), int(b_group_id)
+                self.ids[self.actual_user] = user_id, group_id
                 display.vvvv(
                     f'PUT: Determined uid={user_id} and gid={group_id} for user "{self.actual_user}"',
                     host=remote_addr,
@@ -398,8 +424,8 @@
         user_id, group_id = self.ids[self.actual_user]
         try:
             self._call_client(
-                lambda: put_file(
-                    self.client,
+                lambda client: put_file(
+                    client,
                     container=self.get_option("remote_addr"),
                     in_path=in_path,
                     out_path=out_path,
@@ -415,19 +441,22 @@
         except DockerFileCopyError as exc:
             raise AnsibleConnectionFailure(to_native(exc)) from exc

-    def fetch_file(self, in_path, out_path):
+    def fetch_file(self, in_path: str, out_path: str) -> None:
         """Fetch a file from container to local."""
-        super().fetch_file(in_path, out_path)
+        super().fetch_file(in_path, out_path)  # type: ignore[safe-super]
         display.vvv(
             f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr")
         )

+        if self.client is None:
+            raise AssertionError("Client must be present")
+
         in_path = self._prefix_login_path(in_path)

         try:
             self._call_client(
-                lambda: fetch_file(
-                    self.client,
+                lambda client: fetch_file(
+                    client,
                     container=self.get_option("remote_addr"),
                     in_path=in_path,
                     out_path=out_path,
@@ -443,10 +472,10 @@
         except DockerFileCopyError as exc:
             raise AnsibleConnectionFailure(to_native(exc)) from exc

-    def close(self):
+    def close(self) -> None:
         """Terminate the connection. Nothing to do for Docker"""
-        super().close()
+        super().close()  # type: ignore[safe-super]
         self._connected = False

-    def reset(self):
+    def reset(self) -> None:
         self.ids.clear()
diff --git a/plugins/connection/nsenter.py b/plugins/connection/nsenter.py
index 695ccdc9..b65803c3 100644
--- a/plugins/connection/nsenter.py
+++ b/plugins/connection/nsenter.py
@@ -44,7 +44,9 @@ import fcntl
 import os
 import pty
 import selectors
+import shlex
 import subprocess
+import typing as t

 import ansible.constants as C
 from ansible.errors import AnsibleError
@@ -63,12 +65,12 @@ class Connection(ConnectionBase):
     transport = "community.docker.nsenter"
     has_pipelining = False

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.cwd = None
         self._nsenter_pid = None

-    def _connect(self):
+    def _connect(self) -> t.Self:
         self._nsenter_pid = self.get_option("nsenter_pid")

         # Because nsenter requires very high privileges, our remote user
@@ -83,12 +85,15 @@
             self._connected = True
         return self

-    def exec_command(self, cmd, in_data=None, sudoable=True):
-        super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
+    def exec_command(
+        self, cmd: str, in_data: bytes | None = None, sudoable: bool = True
+    ) -> tuple[int, bytes, bytes]:
+        super().exec_command(cmd, in_data=in_data, sudoable=sudoable)  # type: ignore[safe-super]

         display.debug("in nsenter.exec_command()")

-        executable = C.DEFAULT_EXECUTABLE.split()[0] if C.DEFAULT_EXECUTABLE else None
+        def_executable: str | None = C.DEFAULT_EXECUTABLE  # type: ignore[attr-defined]
+        executable = def_executable.split()[0] if def_executable else None

         if not os.path.exists(to_bytes(executable, errors="surrogate_or_strict")):
             raise AnsibleError(
@@ -109,12 +114,8 @@
             "--",
         ]

-        if isinstance(cmd, (str, bytes)):
-            cmd_parts = nsenter_cmd_parts + [cmd]
-            cmd = to_bytes(" ".join(cmd_parts))
-        else:
-            cmd_parts = nsenter_cmd_parts + cmd
-            cmd = [to_bytes(arg) for arg in cmd_parts]
+        cmd_parts = nsenter_cmd_parts + [cmd]
+        cmd = to_bytes(" ".join(cmd_parts))

         display.vvv(f"EXEC {to_text(cmd)}", host=self._play_context.remote_addr)
         display.debug("opening command with Popen()")
@@ -143,6 +144,9 @@
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
         ) as p:
+            assert p.stderr is not None
+            assert p.stdin is not None
+            assert p.stdout is not None
             # if we created a master, we can close the other half of the pty now, otherwise master is stdin
             if master is not None:
                 os.close(stdin)
@@ -234,8 +238,8 @@
         display.debug("done with nsenter.exec_command()")
         return (p.returncode, stdout, stderr)

-    def put_file(self, in_path, out_path):
-        super().put_file(in_path, out_path)
+    def put_file(self, in_path: str, out_path: str) -> None:
+        super().put_file(in_path, out_path)  # type: ignore[safe-super]

         in_path = unfrackpath(in_path, basedir=self.cwd)
         out_path = unfrackpath(out_path, basedir=self.cwd)
@@ -245,26 +249,30 @@
             with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as in_file:
                 in_data = in_file.read()
                 rc, dummy_out, err = self.exec_command(
-                    cmd=["tee", out_path], in_data=in_data
+                    cmd=f"tee {shlex.quote(out_path)}", in_data=in_data
                 )
                 if rc != 0:
-                    raise AnsibleError(f"failed to transfer file to {out_path}: {err}")
+                    raise AnsibleError(
+                        f"failed to transfer file to {out_path}: {to_text(err)}"
+                    )
         except IOError as e:
             raise AnsibleError(f"failed to transfer file to {out_path}: {e}") from e

-    def fetch_file(self, in_path, out_path):
-        super().fetch_file(in_path, out_path)
+    def fetch_file(self, in_path: str, out_path: str) -> None:
+        super().fetch_file(in_path, out_path)  # type: ignore[safe-super]

         in_path = unfrackpath(in_path, basedir=self.cwd)
         out_path = unfrackpath(out_path, basedir=self.cwd)

         try:
-            rc, out, err = self.exec_command(cmd=["cat", in_path])
+            rc, out, err = self.exec_command(cmd=f"cat {shlex.quote(in_path)}")
             display.vvv(
                 f"FETCH {in_path} TO {out_path}", host=self._play_context.remote_addr
             )
             if rc != 0:
-                raise AnsibleError(f"failed to transfer file to {in_path}: {err}")
+                raise AnsibleError(
+                    f"failed to transfer file to {in_path}: {to_text(err)}"
+                )
             with open(
                 to_bytes(out_path, errors="surrogate_or_strict"), "wb"
             ) as out_file:
@@ -274,6 +282,6 @@
                 f"failed to transfer file to {to_native(out_path)}: {e}"
             ) from e

-    def close(self):
+    def close(self) -> None:
         """terminate the connection; nothing to do here"""
         self._connected = False
diff --git a/plugins/inventory/docker_containers.py b/plugins/inventory/docker_containers.py
index b5e2979c..dcdfe991 100644
--- a/plugins/inventory/docker_containers.py
+++ b/plugins/inventory/docker_containers.py
@@ -169,6 +169,7 @@ filters:
 """

 import re
+import typing as t

 from ansible.errors import AnsibleError
 from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
@@ -195,6 +196,11 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
 )


+if t.TYPE_CHECKING:
+    from ansible.inventory.data import InventoryData
+    from ansible.parsing.dataloader import DataLoader
+
+
 MIN_DOCKER_API = None


@@ -203,11 +209,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable):

     NAME = "community.docker.docker_containers"

-    def _slugify(self, value):
+    def _slugify(self, value: str) -> str:
         slug = re.sub(r"[^\w-]", "_", value).lower().lstrip("_")
         return f"docker_{slug}"

-    def _populate(self, client):
+    def _populate(self, client: AnsibleDockerClient) -> None:
         strict = self.get_option("strict")

         ssh_port = self.get_option("private_ssh_port")
@@ -217,6 +223,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
         connection_type = self.get_option("connection_type")
         add_legacy_groups = self.get_option("add_legacy_groups")

+        if self.inventory is None:
+            raise AssertionError("Inventory must be there")
+
         try:
             params = {
                 "limit": -1,
@@ -298,7 +307,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
                     # Lookup the public facing port Nat'ed to ssh port.
                    network_settings = inspect.get("NetworkSettings") or {}
                    port_settings = network_settings.get("Ports") or {}
-                    port = port_settings.get(f"{ssh_port}/tcp")[0]
+                    port = port_settings.get(f"{ssh_port}/tcp")[0]  # type: ignore[index]
                 except (IndexError, AttributeError, TypeError):
                     port = {}

@@ -387,16 +396,22 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
         else:
             self.inventory.add_host(name, group="stopped")

-    def verify_file(self, path):
+    def verify_file(self, path: str) -> bool:
         """Return the possibility of a file being consumable by this plugin."""
         return super().verify_file(path) and path.endswith(
             ("docker.yaml", "docker.yml")
         )

-    def _create_client(self):
+    def _create_client(self) -> AnsibleDockerClient:
         return AnsibleDockerClient(self, min_docker_api_version=MIN_DOCKER_API)

-    def parse(self, inventory, loader, path, cache=True):
+    def parse(
+        self,
+        inventory: InventoryData,
+        loader: DataLoader,
+        path: str,
+        cache: bool = True,
+    ) -> None:
         super().parse(inventory, loader, path, cache)
         self._read_config_data(path)
         client = self._create_client()
diff --git a/plugins/inventory/docker_machine.py b/plugins/inventory/docker_machine.py
index 87f44101..0fcd08e9 100644
--- a/plugins/inventory/docker_machine.py
+++ b/plugins/inventory/docker_machine.py
@@ -101,6 +101,7 @@ compose:

 import json
 import re
 import subprocess
+import typing as t

 from ansible.errors import AnsibleError
 from ansible.module_utils.common.process import get_bin_path
@@ -117,6 +118,15 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
 )


+if t.TYPE_CHECKING:
+    from ansible.inventory.data import InventoryData
+    from ansible.parsing.dataloader import DataLoader
+
+    DaemonEnv = t.Literal[
+        "require", "require-silently", "optional", "optional-silently", "skip"
+    ]
+
+
 display = Display()


@@ -125,9 +135,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):

     NAME = "community.docker.docker_machine"

-    docker_machine_path = None
+    docker_machine_path: str | None = None

-    def _run_command(self, args):
+    def _run_command(self, args: list[str]) -> str:
         if not self.docker_machine_path:
             try:
                 self.docker_machine_path = get_bin_path("docker-machine")
@@ -147,7 +157,7 @@

         return to_text(result).strip()

-    def _get_docker_daemon_variables(self, machine_name):
+    def _get_docker_daemon_variables(self, machine_name: str) -> list[tuple[str, str]]:
         """
         Capture settings from Docker Machine that would be needed to connect to the remote Docker daemon installed on
         the Docker Machine remote host. Note: passing '--shell=sh' is a workaround for 'Error: Unknown shell'.
@@ -180,7 +190,7 @@

         return env_vars

-    def _get_machine_names(self):
+    def _get_machine_names(self) -> list[str]:
        # Filter out machines that are not in the Running state as we probably cannot perform any useful actions
        # with them.
ls_command = ["ls", "-q"] @@ -194,7 +204,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): return ls_lines.splitlines() - def _inspect_docker_machine_host(self, node): + def _inspect_docker_machine_host(self, node: str) -> t.Any | None: try: inspect_lines = self._run_command(["inspect", node]) except subprocess.CalledProcessError: @@ -202,7 +212,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): return json.loads(inspect_lines) - def _ip_addr_docker_machine_host(self, node): + def _ip_addr_docker_machine_host(self, node: str) -> t.Any | None: try: ip_addr = self._run_command(["ip", node]) except subprocess.CalledProcessError: @@ -210,7 +220,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): return ip_addr - def _should_skip_host(self, machine_name, env_var_tuples, daemon_env): + def _should_skip_host( + self, machine_name: str, env_var_tuples, daemon_env: DaemonEnv + ) -> bool: if not env_var_tuples: warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}" if daemon_env in ("require", "require-silently"): @@ -224,8 +236,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): # daemon_env is 'optional-silently' return False - def _populate(self): - daemon_env = self.get_option("daemon_env") + def _populate(self) -> None: + if self.inventory is None: + raise AssertionError("Inventory must be there") + + daemon_env: DaemonEnv = self.get_option("daemon_env") filters = parse_filters(self.get_option("filters")) try: for node in self._get_machine_names(): @@ -325,13 +340,19 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): f"Unable to fetch hosts from Docker Machine, this was the original exception: {e}" ) from e - def verify_file(self, path): + def verify_file(self, path: str) -> bool: """Return the possibility of a file being consumable by this plugin.""" return super().verify_file(path) and path.endswith( ("docker_machine.yaml", "docker_machine.yml") ) - def parse(self, inventory, loader, path, cache=True): + def parse( + self, + inventory: InventoryData, + loader: DataLoader, + path: str, + cache: bool = True, + ) -> None: super().parse(inventory, loader, path, cache) self._read_config_data(path) self._populate() diff --git a/plugins/inventory/docker_swarm.py b/plugins/inventory/docker_swarm.py index a47ed9b1..f0a2e32b 100644 --- a/plugins/inventory/docker_swarm.py +++ b/plugins/inventory/docker_swarm.py @@ -148,6 +148,8 @@ keyed_groups: prefix: label """ +import typing as t + from ansible.errors import AnsibleError from ansible.parsing.utils.addresses import parse_address from ansible.plugins.inventory import BaseInventoryPlugin, Constructable @@ -167,6 +169,11 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import ( ) +if t.TYPE_CHECKING: + from ansible.inventory.data import InventoryData + from ansible.parsing.dataloader import DataLoader + + try: import docker @@ -180,10 +187,13 @@ class InventoryModule(BaseInventoryPlugin, Constructable): NAME = "community.docker.docker_swarm" - def _fail(self, msg): + def _fail(self, msg: str) -> t.NoReturn: raise AnsibleError(msg) - def _populate(self): + def _populate(self) -> None: + if self.inventory is None: + raise AssertionError("Inventory must be there") + raw_params = { "docker_host": self.get_option("docker_host"), "tls": self.get_option("tls"), @@ -307,13 +317,19 @@ class InventoryModule(BaseInventoryPlugin, Constructable): f"Unable to fetch hosts from 
Docker swarm API, this was the original exception: {e}" ) from e - def verify_file(self, path): + def verify_file(self, path: str) -> bool: """Return the possibly of a file being consumable by this plugin.""" return super().verify_file(path) and path.endswith( ("docker_swarm.yaml", "docker_swarm.yml") ) - def parse(self, inventory, loader, path, cache=True): + def parse( + self, + inventory: InventoryData, + loader: DataLoader, + path: str, + cache: bool = True, + ) -> None: if not HAS_DOCKER: raise AnsibleError( "The Docker swarm dynamic inventory plugin requires the Docker SDK for Python: " diff --git a/plugins/module_utils/_api/_import_helper.py b/plugins/module_utils/_api/_import_helper.py index 01f7f18e..b2f7bc38 100644 --- a/plugins/module_utils/_api/_import_helper.py +++ b/plugins/module_utils/_api/_import_helper.py @@ -12,8 +12,10 @@ from __future__ import annotations import traceback +import typing as t +REQUESTS_IMPORT_ERROR: str | None # pylint: disable=invalid-name try: from requests import Session # noqa: F401, pylint: disable=unused-import from requests.adapters import ( # noqa: F401, pylint: disable=unused-import @@ -26,28 +28,29 @@ try: except ImportError: REQUESTS_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name - class Session: - __attrs__ = [] + class Session: # type: ignore + __attrs__: list[t.Never] = [] - class HTTPAdapter: - __attrs__ = [] + class HTTPAdapter: # type: ignore + __attrs__: list[t.Never] = [] - class HTTPError(Exception): + class HTTPError(Exception): # type: ignore pass - class InvalidSchema(Exception): + class InvalidSchema(Exception): # type: ignore pass else: REQUESTS_IMPORT_ERROR = None # pylint: disable=invalid-name -URLLIB3_IMPORT_ERROR = None # pylint: disable=invalid-name +URLLIB3_IMPORT_ERROR: str | None = None # pylint: disable=invalid-name try: from requests.packages import urllib3 # pylint: disable=unused-import - # pylint: disable-next=unused-import - from requests.packages.urllib3 import connection as urllib3_connection + from requests.packages.urllib3 import ( # type: ignore # pylint: disable=unused-import # isort: skip + connection as urllib3_connection, + ) except ImportError: try: import urllib3 # pylint: disable=unused-import diff --git a/plugins/module_utils/_api/api/client.py b/plugins/module_utils/_api/api/client.py index af5d1db0..3393812c 100644 --- a/plugins/module_utils/_api/api/client.py +++ b/plugins/module_utils/_api/api/client.py @@ -13,8 +13,9 @@ from __future__ import annotations import json import logging +import os import struct -from functools import partial +import typing as t from urllib.parse import quote from .. import auth @@ -47,16 +48,21 @@ from ..transport.sshconn import PARAMIKO_IMPORT_ERROR, SSHHTTPAdapter from ..transport.ssladapter import SSLHTTPAdapter from ..transport.unixconn import UnixHTTPAdapter from ..utils import config, json_stream, utils -from ..utils.decorators import update_headers +from ..utils.decorators import minimum_version, update_headers from ..utils.proxy import ProxyConfig from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter -from .daemon import DaemonApiMixin + + +if t.TYPE_CHECKING: + from requests import Response + + from ..._socket_helper import SocketLike log = logging.getLogger(__name__) -class APIClient(_Session, DaemonApiMixin): +class APIClient(_Session): """ A low-level client for the Docker Engine API. 
@@ -105,16 +111,16 @@
     def __init__(
         self,
-        base_url=None,
-        version=None,
-        timeout=DEFAULT_TIMEOUT_SECONDS,
-        tls=False,
-        user_agent=DEFAULT_USER_AGENT,
-        num_pools=None,
-        credstore_env=None,
-        use_ssh_client=False,
-        max_pool_size=DEFAULT_MAX_POOL_SIZE,
-    ):
+        base_url: str | None = None,
+        version: str | None = None,
+        timeout: int | float = DEFAULT_TIMEOUT_SECONDS,
+        tls: bool | TLSConfig = False,
+        user_agent: str = DEFAULT_USER_AGENT,
+        num_pools: int | None = None,
+        credstore_env: dict[str, str] | None = None,
+        use_ssh_client: bool = False,
+        max_pool_size: int = DEFAULT_MAX_POOL_SIZE,
+    ) -> None:
         super().__init__()

         fail_on_missing_imports()
@@ -124,7 +130,6 @@ class APIClient(_Session, DaemonApiMixin):
                 "If using TLS, the base_url argument must be provided."
             )

-        self.base_url = base_url
         self.timeout = timeout
         self.headers["User-Agent"] = user_agent

@@ -145,6 +150,7 @@ class APIClient(_Session, DaemonApiMixin):
         self.credstore_env = credstore_env

         base_url = utils.parse_host(base_url, IS_WINDOWS_PLATFORM, tls=bool(tls))
+        self.base_url = base_url
         # SSH has a different default for num_pools to all other adapters
         num_pools = (
             num_pools or DEFAULT_NUM_POOLS_SSH
@@ -152,6 +158,9 @@
             else DEFAULT_NUM_POOLS
         )

+        self._custom_adapter: (
+            UnixHTTPAdapter | NpipeHTTPAdapter | SSHHTTPAdapter | SSLHTTPAdapter | None
+        ) = None
         if base_url.startswith("http+unix://"):
             self._custom_adapter = UnixHTTPAdapter(
                 base_url,
@@ -223,7 +232,7 @@
                 f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library."
             )

-    def _retrieve_server_version(self):
+    def _retrieve_server_version(self) -> str:
         try:
             version_result = self.version(api_version=False)
         except Exception as e:
@@ -242,54 +251,87 @@
                 f"Error while fetching server API version: {e}. Response seems to be broken."
             ) from e

-    def _set_request_timeout(self, kwargs):
+    def _set_request_timeout(self, kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
         """Prepare the kwargs for an HTTP request by inserting the timeout
         parameter, if not already present."""
         kwargs.setdefault("timeout", self.timeout)
         return kwargs

     @update_headers
-    def _post(self, url, **kwargs):
+    def _post(self, url: str, **kwargs):
         return self.post(url, **self._set_request_timeout(kwargs))

     @update_headers
-    def _get(self, url, **kwargs):
+    def _get(self, url: str, **kwargs):
         return self.get(url, **self._set_request_timeout(kwargs))

     @update_headers
-    def _head(self, url, **kwargs):
+    def _head(self, url: str, **kwargs):
         return self.head(url, **self._set_request_timeout(kwargs))

     @update_headers
-    def _put(self, url, **kwargs):
+    def _put(self, url: str, **kwargs):
         return self.put(url, **self._set_request_timeout(kwargs))

     @update_headers
-    def _delete(self, url, **kwargs):
+    def _delete(self, url: str, **kwargs):
         return self.delete(url, **self._set_request_timeout(kwargs))

-    def _url(self, pathfmt, *args, **kwargs):
+    def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
         for arg in args:
             if not isinstance(arg, str):
                 raise ValueError(
                     f"Expected a string but found {arg} ({type(arg)}) instead"
                 )

-        quote_f = partial(quote, safe="/:")
-        args = map(quote_f, args)
+        q_args = [quote(arg, safe="/:") for arg in args]

-        if kwargs.get("versioned_api", True):
-            return f"{self.base_url}/v{self._version}{pathfmt.format(*args)}"
-        return f"{self.base_url}{pathfmt.format(*args)}"
+        if versioned_api:
+            return f"{self.base_url}/v{self._version}{pathfmt.format(*q_args)}"
+        return f"{self.base_url}{pathfmt.format(*q_args)}"

-    def _raise_for_status(self, response):
+    def _raise_for_status(self, response: Response) -> None:
         """Raises stored :class:`APIError`, if one occurred."""
         try:
             response.raise_for_status()
         except _HTTPError as e:
             create_api_error_from_http_exception(e)

-    def _result(self, response, get_json=False, get_binary=False):
+    @t.overload
+    def _result(
+        self,
+        response: Response,
+        *,
+        get_json: t.Literal[False] = False,
+        get_binary: t.Literal[False] = False,
+    ) -> str: ...
+
+    @t.overload
+    def _result(
+        self,
+        response: Response,
+        *,
+        get_json: t.Literal[True],
+        get_binary: t.Literal[False] = False,
+    ) -> t.Any: ...
+
+    @t.overload
+    def _result(
+        self,
+        response: Response,
+        *,
+        get_json: t.Literal[False] = False,
+        get_binary: t.Literal[True],
+    ) -> bytes: ...
+
+    @t.overload
+    def _result(
+        self, response: Response, *, get_json: bool = False, get_binary: bool = False
+    ) -> t.Any | str | bytes: ...
+
+    def _result(
+        self, response: Response, *, get_json: bool = False, get_binary: bool = False
+    ) -> t.Any | str | bytes:
         if get_json and get_binary:
             raise AssertionError("json and binary must not be both True")
         self._raise_for_status(response)
@@ -300,10 +342,12 @@
             return response.content
         return response.text

-    def _post_json(self, url, data, **kwargs):
+    def _post_json(
+        self, url: str, data: dict[str, str | None] | t.Any, **kwargs
+    ) -> Response:
         # Go <1.1 cannot unserialize null to a string
         # so we do this disgusting thing here.
-        data2 = {}
+        data2: dict[str, t.Any] = {}
         if data is not None and isinstance(data, dict):
             for k, v in data.items():
                 if v is not None:
@@ -316,19 +360,19 @@
         kwargs["headers"]["Content-Type"] = "application/json"
         return self._post(url, data=json.dumps(data2), **kwargs)

-    def _attach_params(self, override=None):
+    def _attach_params(self, override: dict[str, int] | None = None) -> dict[str, int]:
         return override or {"stdout": 1, "stderr": 1, "stream": 1}

-    def _get_raw_response_socket(self, response):
+    def _get_raw_response_socket(self, response: Response) -> SocketLike:
         self._raise_for_status(response)
         if self.base_url == "http+docker://localnpipe":
-            sock = response.raw._fp.fp.raw.sock
+            sock = response.raw._fp.fp.raw.sock  # type: ignore[union-attr]
         elif self.base_url.startswith("http+docker://ssh"):
-            sock = response.raw._fp.fp.channel
+            sock = response.raw._fp.fp.channel  # type: ignore[union-attr]
         else:
-            sock = response.raw._fp.fp.raw
+            sock = response.raw._fp.fp.raw  # type: ignore[union-attr]
             if self.base_url.startswith("https://"):
-                sock = sock._sock
+                sock = sock._sock  # type: ignore[union-attr]
         try:
             # Keep a reference to the response to stop it being garbage
             # collected. If the response is garbage collected, it will
@@ -341,12 +385,26 @@

         return sock

-    def _stream_helper(self, response, decode=False):
+    @t.overload
+    def _stream_helper(
+        self, response: Response, *, decode: t.Literal[False] = False
+    ) -> t.Generator[bytes]: ...
+
+    @t.overload
+    def _stream_helper(
+        self, response: Response, *, decode: t.Literal[True]
+    ) -> t.Generator[t.Any]: ...
+
+    def _stream_helper(
+        self, response: Response, *, decode: bool = False
+    ) -> t.Generator[t.Any]:
         """Generator for data coming from a chunked-encoded HTTP response."""
-        if response.raw._fp.chunked:
+        if response.raw._fp.chunked:  # type: ignore[union-attr]
             if decode:
-                yield from json_stream.json_stream(self._stream_helper(response, False))
+                yield from json_stream.json_stream(
+                    self._stream_helper(response, decode=False)
+                )
             else:
                 reader = response.raw
                 while not reader.closed:
@@ -354,15 +412,15 @@
                     data = reader.read(1)
                     if not data:
                         break
-                    if reader._fp.chunk_left:
-                        data += reader.read(reader._fp.chunk_left)
+                    if reader._fp.chunk_left:  # type: ignore[union-attr]
+                        data += reader.read(reader._fp.chunk_left)  # type: ignore[union-attr]
                     yield data
         else:
             # Response is not chunked, meaning we probably
             # encountered an error immediately
             yield self._result(response, get_json=decode)

-    def _multiplexed_buffer_helper(self, response):
+    def _multiplexed_buffer_helper(self, response: Response) -> t.Generator[bytes]:
         """A generator of multiplexed data blocks read from a buffered
         response."""
         buf = self._result(response, get_binary=True)
@@ -378,7 +436,9 @@
             walker = end
             yield buf[start:end]

-    def _multiplexed_response_stream_helper(self, response):
+    def _multiplexed_response_stream_helper(
+        self, response: Response
+    ) -> t.Generator[bytes]:
         """A generator of multiplexed data blocks coming from a response
         stream."""

@@ -399,7 +459,19 @@
                 break
             yield data

-    def _stream_raw_result(self, response, chunk_size=1, decode=True):
+    @t.overload
+    def _stream_raw_result(
+        self, response: Response, *, chunk_size: int = 1, decode: t.Literal[True] = True
+    ) -> t.Generator[str]: ...
+
+    @t.overload
+    def _stream_raw_result(
+        self, response: Response, *, chunk_size: int = 1, decode: t.Literal[False]
+    ) -> t.Generator[bytes]: ...
+
+    def _stream_raw_result(
+        self, response: Response, *, chunk_size: int = 1, decode: bool = True
+    ) -> t.Generator[str | bytes]:
         """Stream result for TTY-enabled container and raw binary data"""
         self._raise_for_status(response)

@@ -410,14 +482,81 @@

         yield from response.iter_content(chunk_size, decode)

-    def _read_from_socket(self, response, stream, tty=True, demux=False):
+    @t.overload
+    def _read_from_socket(
+        self,
+        response: Response,
+        *,
+        stream: t.Literal[True],
+        tty: bool = True,
+        demux: t.Literal[False] = False,
+    ) -> t.Generator[bytes]: ...
+
+    @t.overload
+    def _read_from_socket(
+        self,
+        response: Response,
+        *,
+        stream: t.Literal[True],
+        tty: t.Literal[True] = True,
+        demux: t.Literal[True],
+    ) -> t.Generator[tuple[bytes, None]]: ...
+
+    @t.overload
+    def _read_from_socket(
+        self,
+        response: Response,
+        *,
+        stream: t.Literal[True],
+        tty: t.Literal[False],
+        demux: t.Literal[True],
+    ) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
+
+    @t.overload
+    def _read_from_socket(
+        self,
+        response: Response,
+        *,
+        stream: t.Literal[False],
+        tty: bool = True,
+        demux: t.Literal[False] = False,
+    ) -> bytes: ...
+
+    @t.overload
+    def _read_from_socket(
+        self,
+        response: Response,
+        *,
+        stream: t.Literal[False],
+        tty: t.Literal[True] = True,
+        demux: t.Literal[True],
+    ) -> tuple[bytes, None]: ...
+
+    @t.overload
+    def _read_from_socket(
+        self,
+        response: Response,
+        *,
+        stream: t.Literal[False],
+        tty: t.Literal[False],
+        demux: t.Literal[True],
+    ) -> tuple[bytes, bytes]: ...
+
+    @t.overload
+    def _read_from_socket(
+        self, response: Response, *, stream: bool, tty: bool = True, demux: bool = False
+    ) -> t.Any: ...
+
+    def _read_from_socket(
+        self, response: Response, *, stream: bool, tty: bool = True, demux: bool = False
+    ) -> t.Any:
         """Consume all data from the socket, close the response and return the
         data. If stream=True, then a generator is returned instead and the
         caller is responsible for closing the response.
         """
         socket = self._get_raw_response_socket(response)

-        gen = frames_iter(socket, tty)
+        gen: t.Generator = frames_iter(socket, tty)

         if demux:
             # The generator will output tuples (stdout, stderr)
@@ -434,7 +573,7 @@
         finally:
             response.close()

-    def _disable_socket_timeout(self, socket):
+    def _disable_socket_timeout(self, socket: SocketLike) -> None:
         """Depending on the combination of python version and whether we are
         connecting over http or https, we might need to access _sock, which
         may or may not exist; or we may need to just settimeout on socket
@@ -451,18 +590,38 @@
             if not hasattr(s, "settimeout"):
                 continue

-            timeout = -1
+            timeout: int | float | None = -1
             if hasattr(s, "gettimeout"):
-                timeout = s.gettimeout()
+                timeout = s.gettimeout()  # type: ignore[union-attr]

             # Do not change the timeout if it is already disabled.
             if timeout is None or timeout == 0.0:
                 continue

-            s.settimeout(None)
+            s.settimeout(None)  # type: ignore[union-attr]

-    def _get_result_tty(self, stream, res, is_tty):
+    @t.overload
+    def _get_result_tty(
+        self, stream: t.Literal[True], res: Response, is_tty: t.Literal[True]
+    ) -> t.Generator[str]: ...
+
+    @t.overload
+    def _get_result_tty(
+        self, stream: t.Literal[True], res: Response, is_tty: t.Literal[False]
+    ) -> t.Generator[bytes]: ...
+
+    @t.overload
+    def _get_result_tty(
+        self, stream: t.Literal[False], res: Response, is_tty: t.Literal[True]
+    ) -> bytes: ...
+
+    @t.overload
+    def _get_result_tty(
+        self, stream: t.Literal[False], res: Response, is_tty: t.Literal[False]
+    ) -> bytes: ...
+
+    def _get_result_tty(self, stream: bool, res: Response, is_tty: bool) -> t.Any:
         # We should also use raw streaming (without keep-alive)
         # if we are dealing with a tty-enabled container.
         if is_tty:
@@ -478,11 +637,11 @@
             return self._multiplexed_response_stream_helper(res)
         return sep.join(list(self._multiplexed_buffer_helper(res)))

-    def _unmount(self, *args):
+    def _unmount(self, *args) -> None:
         for proto in args:
             self.adapters.pop(proto)

-    def get_adapter(self, url):
+    def get_adapter(self, url: str):
         try:
             return super().get_adapter(url)
         except _InvalidSchema as e:
@@ -491,10 +650,10 @@
             raise e

     @property
-    def api_version(self):
+    def api_version(self) -> str:
         return self._version

-    def reload_config(self, dockercfg_path=None):
+    def reload_config(self, dockercfg_path: str | None = None) -> None:
         """
         Force a reload of the auth configuration

@@ -510,7 +669,7 @@
             dockercfg_path, credstore_env=self.credstore_env
         )

-    def _set_auth_headers(self, headers):
+    def _set_auth_headers(self, headers: dict[str, str | bytes]) -> None:
         log.debug("Looking for auth config")

         # If we do not have any auth data so far, try reloading the config
@@ -537,57 +696,62 @@
         else:
             log.debug("No auth config found")

-    def get_binary(self, pathfmt, *args, **kwargs):
+    def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes:
         return self._result(
             self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
             get_binary=True,
         )

-    def get_json(self, pathfmt, *args, **kwargs):
+    def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
         return self._result(
             self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
             get_json=True,
         )

-    def get_text(self, pathfmt, *args, **kwargs):
+    def get_text(self, pathfmt: str, *args: str, **kwargs) -> str:
         return self._result(
             self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
         )

-    def get_raw_stream(self, pathfmt, *args, **kwargs):
-        chunk_size = kwargs.pop("chunk_size", DEFAULT_DATA_CHUNK_SIZE)
+    def get_raw_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
+        **kwargs,
+    ) -> t.Generator[bytes]:
         res = self._get(
             self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
         )
         self._raise_for_status(res)
-        return self._stream_raw_result(res, chunk_size, False)
+        return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)

-    def delete_call(self, pathfmt, *args, **kwargs):
+    def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None:
         self._raise_for_status(
             self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
         )

-    def delete_json(self, pathfmt, *args, **kwargs):
+    def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
         return self._result(
             self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
             get_json=True,
         )

-    def post_call(self, pathfmt, *args, **kwargs):
+    def post_call(self, pathfmt: str, *args: str, **kwargs) -> None:
         self._raise_for_status(
             self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
         )

-    def post_json(self, pathfmt, *args, **kwargs):
-        data = kwargs.pop("data", None)
+    def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None:
         self._raise_for_status(
             self._post_json(
                 self._url(pathfmt, *args, versioned_api=True), data, **kwargs
             )
         )

-    def post_json_to_binary(self, pathfmt, *args, **kwargs):
-        data = kwargs.pop("data", None)
+    def post_json_to_binary(
+        self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
+    ) -> bytes:
         return self._result(
             self._post_json(
                 self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@@ -595,8 +759,9 @@
             get_binary=True,
         )

-    def post_json_to_json(self, pathfmt, *args, **kwargs):
-        data = kwargs.pop("data", None)
+    def post_json_to_json(
+        self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
+    ) -> t.Any:
         return self._result(
             self._post_json(
                 self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@@ -604,17 +769,24 @@
             get_json=True,
         )

-    def post_json_to_text(self, pathfmt, *args, **kwargs):
-        data = kwargs.pop("data", None)
+    def post_json_to_text(
+        self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
+    ) -> str:
         return self._result(
             self._post_json(
                 self._url(pathfmt, *args, versioned_api=True), data, **kwargs
             ),
         )

-    def post_json_to_stream_socket(self, pathfmt, *args, **kwargs):
-        data = kwargs.pop("data", None)
-        headers = (kwargs.pop("headers", None) or {}).copy()
+    def post_json_to_stream_socket(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        **kwargs,
+    ) -> SocketLike:
+        headers = headers.copy() if headers else {}
         headers.update(
             {
                 "Connection": "Upgrade",
@@ -631,18 +803,102 @@
             )
         )

-    def post_json_to_stream(self, pathfmt, *args, **kwargs):
-        data = kwargs.pop("data", None)
-        headers = (kwargs.pop("headers", None) or {}).copy()
+    @t.overload
+    def post_json_to_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        stream: t.Literal[True],
+        tty: bool = True,
+        demux: t.Literal[False] = False,
+        **kwargs,
+    ) -> t.Generator[bytes]: ...
+
+    @t.overload
+    def post_json_to_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        stream: t.Literal[True],
+        tty: t.Literal[True] = True,
+        demux: t.Literal[True],
+        **kwargs,
+    ) -> t.Generator[tuple[bytes, None]]: ...
+
+    @t.overload
+    def post_json_to_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        stream: t.Literal[True],
+        tty: t.Literal[False],
+        demux: t.Literal[True],
+        **kwargs,
+    ) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
+
+    @t.overload
+    def post_json_to_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        stream: t.Literal[False],
+        tty: bool = True,
+        demux: t.Literal[False] = False,
+        **kwargs,
+    ) -> bytes: ...
+
+    @t.overload
+    def post_json_to_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        stream: t.Literal[False],
+        tty: t.Literal[True] = True,
+        demux: t.Literal[True],
+        **kwargs,
+    ) -> tuple[bytes, None]: ...
+
+    @t.overload
+    def post_json_to_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        stream: t.Literal[False],
+        tty: t.Literal[False],
+        demux: t.Literal[True],
+        **kwargs,
+    ) -> tuple[bytes, bytes]: ...
+
+    def post_json_to_stream(
+        self,
+        pathfmt: str,
+        *args: str,
+        data: t.Any = None,
+        headers: dict[str, str] | None = None,
+        stream: bool = False,
+        demux: bool = False,
+        tty: bool = False,
+        **kwargs,
+    ) -> t.Any:
+        headers = headers.copy() if headers else {}
         headers.update(
             {
                 "Connection": "Upgrade",
                 "Upgrade": "tcp",
             }
         )
-        stream = kwargs.pop("stream", False)
-        demux = kwargs.pop("demux", False)
-        tty = kwargs.pop("tty", False)
         return self._read_from_socket(
             self._post_json(
                 self._url(pathfmt, *args, versioned_api=True),
@@ -651,13 +907,133 @@
                 stream=True,
                 **kwargs,
             ),
-            stream,
+            stream=stream,
             tty=tty,
             demux=demux,
         )

-    def post_to_json(self, pathfmt, *args, **kwargs):
+    def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
         return self._result(
             self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
             get_json=True,
         )
+
+    @minimum_version("1.25")
+    def df(self) -> dict[str, t.Any]:
+        """
+        Get data usage information.
+
+        Returns:
+            (dict): A dictionary representing different resource categories
+            and their respective data usage.
+
+        Raises:
+            :py:class:`docker.errors.APIError`
+                If the server returns an error.
+        """
+        url = self._url("/system/df")
+        return self._result(self._get(url), get_json=True)
+
+    def info(self) -> dict[str, t.Any]:
+        """
+        Display system-wide information. Identical to the ``docker info``
+        command.
+
+        Returns:
+            (dict): The info as a dict
+
+        Raises:
+            :py:class:`docker.errors.APIError`
+                If the server returns an error.
+        """
+        return self._result(self._get(self._url("/info")), get_json=True)
+
+    def login(
+        self,
+        username: str,
+        password: str | None = None,
+        email: str | None = None,
+        registry: str | None = None,
+        reauth: bool = False,
+        dockercfg_path: str | None = None,
+    ) -> dict[str, t.Any]:
+        """
+        Authenticate with a registry. Similar to the ``docker login`` command.
+
+        Args:
+            username (str): The registry username
+            password (str): The plaintext password
+            email (str): The email for the registry account
+            registry (str): URL to the registry. E.g.
+                ``https://index.docker.io/v1/``
+            reauth (bool): Whether or not to refresh existing authentication on
+                the Docker server.
+            dockercfg_path (str): Use a custom path for the Docker config file
+                (default ``$HOME/.docker/config.json`` if present,
+                otherwise ``$HOME/.dockercfg``)
+
+        Returns:
+            (dict): The response from the login request
+
+        Raises:
+            :py:class:`docker.errors.APIError`
+                If the server returns an error.
+        """
+
+        # If we do not have any auth data so far, try reloading the config file
+        # one more time in case anything showed up in there.
+        # If dockercfg_path is passed check to see if the config file exists,
+        # if so load that config.
+        if dockercfg_path and os.path.exists(dockercfg_path):
+            self._auth_configs = auth.load_config(
+                dockercfg_path, credstore_env=self.credstore_env
+            )
+        elif not self._auth_configs or self._auth_configs.is_empty:
+            self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
+
+        authcfg = self._auth_configs.resolve_authconfig(registry)
+        # If we found an existing auth config for this registry and username
+        # combination, we can return it immediately unless reauth is requested.
+        if authcfg and authcfg.get("username", None) == username and not reauth:
+            return authcfg
+
+        req_data = {
+            "username": username,
+            "password": password,
+            "email": email,
+            "serveraddress": registry,
+        }
+
+        response = self._post_json(self._url("/auth"), data=req_data)
+        if response.status_code == 200:
+            self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
+        return self._result(response, get_json=True)
+
+    def ping(self) -> bool:
+        """
+        Checks the server is responsive. An exception will be raised if it
+        is not responding.
+
+        Returns:
+            (bool) The response from the server.
+
+        Raises:
+            :py:class:`docker.errors.APIError`
+                If the server returns an error.
+        """
+        return self._result(self._get(self._url("/_ping"))) == "OK"
+
+    def version(self, api_version: bool = True) -> dict[str, t.Any]:
+        """
+        Returns version information from the server. Similar to the ``docker
+        version`` command.
+
+        Returns:
+            (dict): The server version information
+
+        Raises:
+            :py:class:`docker.errors.APIError`
+                If the server returns an error.
+        """
+        url = self._url("/version", versioned_api=api_version)
+        return self._result(self._get(url), get_json=True)
diff --git a/plugins/module_utils/_api/api/daemon.py b/plugins/module_utils/_api/api/daemon.py
deleted file mode 100644
index d16118d7..00000000
--- a/plugins/module_utils/_api/api/daemon.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# This code is part of the Ansible collection community.docker, but is an independent component.
-# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
-#
-# Copyright (c) 2016-2022 Docker, Inc.
-#
-# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
-# SPDX-License-Identifier: Apache-2.0
-
-# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
-# Do not use this from other collections or standalone plugins/modules!
-
-from __future__ import annotations
-
-import os
-
-from .. import auth
-from ..utils.decorators import minimum_version
-
-
-class DaemonApiMixin:
-    @minimum_version("1.25")
-    def df(self):
-        """
-        Get data usage information.
-
-        Returns:
-            (dict): A dictionary representing different resource categories
-            and their respective data usage.
-
-        Raises:
-            :py:class:`docker.errors.APIError`
-                If the server returns an error.
-        """
-        url = self._url("/system/df")
-        return self._result(self._get(url), get_json=True)
-
-    def info(self):
-        """
-        Display system-wide information. Identical to the ``docker info``
-        command.
-
-        Returns:
-            (dict): The info as a dict
-
-        Raises:
-            :py:class:`docker.errors.APIError`
-                If the server returns an error.
-        """
-        return self._result(self._get(self._url("/info")), get_json=True)
-
-    def login(
-        self,
-        username,
-        password=None,
-        email=None,
-        registry=None,
-        reauth=False,
-        dockercfg_path=None,
-    ):
-        """
-        Authenticate with a registry. Similar to the ``docker login`` command.
-
-        Args:
-            username (str): The registry username
-            password (str): The plaintext password
-            email (str): The email for the registry account
-            registry (str): URL to the registry. E.g.
-                ``https://index.docker.io/v1/``
-            reauth (bool): Whether or not to refresh existing authentication on
-                the Docker server.
- dockercfg_path (str): Use a custom path for the Docker config file - (default ``$HOME/.docker/config.json`` if present, - otherwise ``$HOME/.dockercfg``) - - Returns: - (dict): The response from the login request - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - - # If we do not have any auth data so far, try reloading the config file - # one more time in case anything showed up in there. - # If dockercfg_path is passed check to see if the config file exists, - # if so load that config. - if dockercfg_path and os.path.exists(dockercfg_path): - self._auth_configs = auth.load_config( - dockercfg_path, credstore_env=self.credstore_env - ) - elif not self._auth_configs or self._auth_configs.is_empty: - self._auth_configs = auth.load_config(credstore_env=self.credstore_env) - - authcfg = self._auth_configs.resolve_authconfig(registry) - # If we found an existing auth config for this registry and username - # combination, we can return it immediately unless reauth is requested. - if authcfg and authcfg.get("username", None) == username and not reauth: - return authcfg - - req_data = { - "username": username, - "password": password, - "email": email, - "serveraddress": registry, - } - - response = self._post_json(self._url("/auth"), data=req_data) - if response.status_code == 200: - self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data) - return self._result(response, get_json=True) - - def ping(self): - """ - Checks the server is responsive. An exception will be raised if it - is not responding. - - Returns: - (bool) The response from the server. - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - return self._result(self._get(self._url("/_ping"))) == "OK" - - def version(self, api_version=True): - """ - Returns version information from the server. Similar to the ``docker - version`` command. - - Returns: - (dict): The server version information - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - url = self._url("/version", versioned_api=api_version) - return self._result(self._get(url), get_json=True) diff --git a/plugins/module_utils/_api/auth.py b/plugins/module_utils/_api/auth.py index 317e6c77..0c6cff00 100644 --- a/plugins/module_utils/_api/auth.py +++ b/plugins/module_utils/_api/auth.py @@ -14,6 +14,7 @@ from __future__ import annotations import base64 import json import logging +import typing as t from . import errors from .credentials.errors import CredentialsNotFound, StoreError @@ -21,6 +22,12 @@ from .credentials.store import Store from .utils import config +if t.TYPE_CHECKING: + from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( + APIClient, + ) + + INDEX_NAME = "docker.io" INDEX_URL = f"https://index.{INDEX_NAME}/v1/" TOKEN_USERNAME = "" @@ -28,7 +35,7 @@ TOKEN_USERNAME = "" log = logging.getLogger(__name__) -def resolve_repository_name(repo_name): +def resolve_repository_name(repo_name: str) -> tuple[str, str]: if "://" in repo_name: raise errors.InvalidRepository( f"Repository name cannot contain a scheme ({repo_name})" @@ -42,14 +49,14 @@ def resolve_repository_name(repo_name): return resolve_index_name(index_name), remote_name -def resolve_index_name(index_name): +def resolve_index_name(index_name: str) -> str: index_name = convert_to_hostname(index_name) if index_name == "index." 
+ INDEX_NAME: index_name = INDEX_NAME return index_name -def get_config_header(client, registry): +def get_config_header(client: APIClient, registry: str) -> bytes | None: log.debug("Looking for auth config") if not client._auth_configs or client._auth_configs.is_empty: log.debug("No auth config in memory - loading from filesystem") @@ -69,32 +76,38 @@ def get_config_header(client, registry): return None -def split_repo_name(repo_name): +def split_repo_name(repo_name: str) -> tuple[str, str]: parts = repo_name.split("/", 1) if len(parts) == 1 or ( "." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost" ): # This is a docker index repo (ex: username/foobar or ubuntu) return INDEX_NAME, repo_name - return tuple(parts) + return tuple(parts) # type: ignore -def get_credential_store(authconfig, registry): +def get_credential_store( + authconfig: dict[str, t.Any] | AuthConfig, registry: str +) -> str | None: if not isinstance(authconfig, AuthConfig): authconfig = AuthConfig(authconfig) return authconfig.get_credential_store(registry) class AuthConfig(dict): - def __init__(self, dct, credstore_env=None): + def __init__( + self, dct: dict[str, t.Any], credstore_env: dict[str, str] | None = None + ): if "auths" not in dct: dct["auths"] = {} self.update(dct) self._credstore_env = credstore_env - self._stores = {} + self._stores: dict[str, Store] = {} @classmethod - def parse_auth(cls, entries, raise_on_error=False): + def parse_auth( + cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False + ) -> dict[str, dict[str, t.Any]]: """ Parses authentication entries @@ -107,10 +120,10 @@ class AuthConfig(dict): Authentication registry. """ - conf = {} + conf: dict[str, dict[str, t.Any]] = {} for registry, entry in entries.items(): if not isinstance(entry, dict): - log.debug("Config entry for key %s is not auth config", registry) + log.debug("Config entry for key %s is not auth config", registry) # type: ignore # We sometimes fall back to parsing the whole config as if it # was the auth config by itself, for legacy purposes. In that # case, we fail silently and return an empty conf if any of the @@ -150,7 +163,12 @@ class AuthConfig(dict): return conf @classmethod - def load_config(cls, config_path, config_dict, credstore_env=None): + def load_config( + cls, + config_path: str | None, + config_dict: dict[str, t.Any] | None, + credstore_env: dict[str, str] | None = None, + ) -> t.Self: """ Loads authentication data from a Docker configuration file in the given root directory or if config_path is passed use given path. @@ -196,22 +214,24 @@ class AuthConfig(dict): return cls({"auths": cls.parse_auth(config_dict)}, credstore_env) @property - def auths(self): + def auths(self) -> dict[str, dict[str, t.Any]]: return self.get("auths", {}) @property - def creds_store(self): + def creds_store(self) -> str | None: return self.get("credsStore", None) @property - def cred_helpers(self): + def cred_helpers(self) -> dict[str, t.Any]: return self.get("credHelpers", {}) @property - def is_empty(self): + def is_empty(self) -> bool: return not self.auths and not self.creds_store and not self.cred_helpers - def resolve_authconfig(self, registry=None): + def resolve_authconfig( + self, registry: str | None = None + ) -> dict[str, t.Any] | None: """ Returns the authentication data from the given auth configuration for a specific registry. 
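
`AuthConfig.load_config` now returns `t.Self`, which keeps subclass types intact through the classmethod. A self-contained sketch of the pattern with illustrative class names; `typing.Self` needs Python 3.11+ (or `typing_extensions.Self` on older interpreters, though the annotation is never evaluated at runtime here):

```python
from __future__ import annotations

import typing as t


class Config(dict):
    @classmethod
    def load(cls, data: dict[str, t.Any] | None = None) -> t.Self:
        # Constructing via cls(), not Config(), is what makes the
        # t.Self annotation truthful for subclasses.
        return cls(data or {})


class SubConfig(Config):
    pass


cfg = SubConfig.load({"auths": {}})
assert type(cfg) is SubConfig  # mypy likewise infers SubConfig, not Config
```
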
As with the Docker client, legacy entries in the @@ -244,7 +264,9 @@ class AuthConfig(dict): log.debug("No entry found") return None - def _resolve_authconfig_credstore(self, registry, credstore_name): + def _resolve_authconfig_credstore( + self, registry: str | None, credstore_name: str + ) -> dict[str, t.Any] | None: if not registry or registry == INDEX_NAME: # The ecosystem is a little schizophrenic with index.docker.io VS # docker.io - in that case, it seems the full URL is necessary. @@ -272,19 +294,19 @@ class AuthConfig(dict): except StoreError as e: raise errors.DockerException(f"Credentials store error: {e}") - def _get_store_instance(self, name): + def _get_store_instance(self, name: str): if name not in self._stores: self._stores[name] = Store(name, environment=self._credstore_env) return self._stores[name] - def get_credential_store(self, registry): + def get_credential_store(self, registry: str | None) -> str | None: if not registry or registry == INDEX_NAME: registry = INDEX_URL return self.cred_helpers.get(registry) or self.creds_store - def get_all_credentials(self): - auth_data = self.auths.copy() + def get_all_credentials(self) -> dict[str, dict[str, t.Any] | None]: + auth_data: dict[str, dict[str, t.Any] | None] = self.auths.copy() # type: ignore if self.creds_store: # Retrieve all credentials from the default store store = self._get_store_instance(self.creds_store) @@ -299,21 +321,23 @@ class AuthConfig(dict): return auth_data - def add_auth(self, reg, data): + def add_auth(self, reg: str, data: dict[str, t.Any]) -> None: self["auths"][reg] = data -def resolve_authconfig(authconfig, registry=None, credstore_env=None): +def resolve_authconfig( + authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None +): if not isinstance(authconfig, AuthConfig): authconfig = AuthConfig(authconfig, credstore_env) return authconfig.resolve_authconfig(registry) -def convert_to_hostname(url): +def convert_to_hostname(url: str) -> str: return url.replace("http://", "").replace("https://", "").split("/", 1)[0] -def decode_auth(auth): +def decode_auth(auth: str | bytes) -> tuple[str, str]: if isinstance(auth, str): auth = auth.encode("ascii") s = base64.b64decode(auth) @@ -321,12 +345,14 @@ def decode_auth(auth): return login.decode("utf8"), pwd.decode("utf8") -def encode_header(auth): +def encode_header(auth: dict[str, t.Any]) -> bytes: auth_json = json.dumps(auth).encode("ascii") return base64.urlsafe_b64encode(auth_json) -def parse_auth(entries, raise_on_error=False): +def parse_auth( + entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False +) -> dict[str, dict[str, t.Any]]: """ Parses authentication entries @@ -342,11 +368,15 @@ def parse_auth(entries, raise_on_error=False): return AuthConfig.parse_auth(entries, raise_on_error) -def load_config(config_path=None, config_dict=None, credstore_env=None): +def load_config( + config_path: str | None = None, + config_dict: dict[str, t.Any] | None = None, + credstore_env: dict[str, str] | None = None, +) -> AuthConfig: return AuthConfig.load_config(config_path, config_dict, credstore_env) -def _load_legacy_config(config_file): +def _load_legacy_config(config_file: str) -> dict[str, dict[str, t.Any]]: log.debug("Attempting to parse legacy auth file format") try: data = [] diff --git a/plugins/module_utils/_api/context/api.py b/plugins/module_utils/_api/context/api.py index 2b026ab7..133357d2 100644 --- a/plugins/module_utils/_api/context/api.py +++ b/plugins/module_utils/_api/context/api.py @@ -13,6 +13,7 @@ 
from __future__ import annotations import json import os +import typing as t from .. import errors from .config import ( @@ -24,7 +25,11 @@ from .config import ( from .context import Context -def create_default_context(): +if t.TYPE_CHECKING: + from ..tls import TLSConfig + + +def create_default_context() -> Context: host = None if os.environ.get("DOCKER_HOST"): host = os.environ.get("DOCKER_HOST") @@ -42,7 +47,7 @@ class ContextAPI: DEFAULT_CONTEXT = None @classmethod - def get_default_context(cls): + def get_default_context(cls) -> Context: context = cls.DEFAULT_CONTEXT if context is None: context = create_default_context() @@ -52,13 +57,13 @@ class ContextAPI: @classmethod def create_context( cls, - name, - orchestrator=None, - host=None, - tls_cfg=None, - default_namespace=None, - skip_tls_verify=False, - ): + name: str, + orchestrator: str | None = None, + host: str | None = None, + tls_cfg: TLSConfig | None = None, + default_namespace: str | None = None, + skip_tls_verify: bool = False, + ) -> Context: """Creates a new context. Returns: (Context): a Context object. @@ -108,7 +113,7 @@ class ContextAPI: return ctx @classmethod - def get_context(cls, name=None): + def get_context(cls, name: str | None = None) -> Context | None: """Retrieves a context object. Args: name (str): The name of the context @@ -136,7 +141,7 @@ class ContextAPI: return Context.load_context(name) @classmethod - def contexts(cls): + def contexts(cls) -> list[Context]: """Context list. Returns: (Context): List of context objects. @@ -170,7 +175,7 @@ class ContextAPI: return contexts @classmethod - def get_current_context(cls): + def get_current_context(cls) -> Context | None: """Get current context. Returns: (Context): current context object. @@ -178,7 +183,7 @@ class ContextAPI: return cls.get_context() @classmethod - def set_current_context(cls, name="default"): + def set_current_context(cls, name: str = "default") -> None: ctx = cls.get_context(name) if not ctx: raise errors.ContextNotFound(name) @@ -188,7 +193,7 @@ class ContextAPI: raise errors.ContextException(f"Failed to set current context: {err}") @classmethod - def remove_context(cls, name): + def remove_context(cls, name: str) -> None: """Remove a context. Similar to the ``docker context rm`` command. Args: @@ -220,7 +225,7 @@ class ContextAPI: ctx.remove() @classmethod - def inspect_context(cls, name="default"): + def inspect_context(cls, name: str = "default") -> dict[str, t.Any]: """Inspect a context. Similar to the ``docker context inspect`` command. 
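
`ContextAPI.get_context` and `get_current_context` now return `Context | None` rather than raising, so annotated callers are forced to narrow before use. A standalone sketch with stand-in types:

```python
from __future__ import annotations


class Context:
    def __init__(self, name: str) -> None:
        self.name = name


# Illustrative stand-in for ContextAPI.get_context(): None signals
# "no such context", which pushes the error handling to the caller.
def get_context(name: str | None = None) -> Context | None:
    known = {"default": Context("default")}
    return known.get(name or "default")


ctx = get_context("default")
if ctx is None:
    raise SystemExit("context not found")
print(ctx.name)  # past the guard, mypy has narrowed ctx to Context
```
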
Args: diff --git a/plugins/module_utils/_api/context/config.py b/plugins/module_utils/_api/context/config.py index b3ff0aa0..6ab07b0d 100644 --- a/plugins/module_utils/_api/context/config.py +++ b/plugins/module_utils/_api/context/config.py @@ -23,7 +23,7 @@ from ..utils.utils import parse_host METAFILE = "meta.json" -def get_current_context_name_with_source(): +def get_current_context_name_with_source() -> tuple[str, str]: if os.environ.get("DOCKER_HOST"): return "default", "DOCKER_HOST environment variable set" if os.environ.get("DOCKER_CONTEXT"): @@ -41,11 +41,11 @@ def get_current_context_name_with_source(): return "default", "fallback value" -def get_current_context_name(): +def get_current_context_name() -> str: return get_current_context_name_with_source()[0] -def write_context_name_to_docker_config(name=None): +def write_context_name_to_docker_config(name: str | None = None) -> Exception | None: if name == "default": name = None docker_cfg_path = find_config_file() @@ -62,44 +62,45 @@ def write_context_name_to_docker_config(name=None): elif name: config["currentContext"] = name else: - return + return None if not docker_cfg_path: docker_cfg_path = get_default_config_file() try: with open(docker_cfg_path, "wt", encoding="utf-8") as f: json.dump(config, f, indent=4) + return None except Exception as e: # pylint: disable=broad-exception-caught return e -def get_context_id(name): +def get_context_id(name: str) -> str: return hashlib.sha256(name.encode("utf-8")).hexdigest() -def get_context_dir(): +def get_context_dir() -> str: docker_cfg_path = find_config_file() or get_default_config_file() return os.path.join(os.path.dirname(docker_cfg_path), "contexts") -def get_meta_dir(name=None): +def get_meta_dir(name: str | None = None) -> str: meta_dir = os.path.join(get_context_dir(), "meta") if name: return os.path.join(meta_dir, get_context_id(name)) return meta_dir -def get_meta_file(name): +def get_meta_file(name) -> str: return os.path.join(get_meta_dir(name), METAFILE) -def get_tls_dir(name=None, endpoint=""): +def get_tls_dir(name: str | None = None, endpoint: str = "") -> str: context_dir = get_context_dir() if name: return os.path.join(context_dir, "tls", get_context_id(name), endpoint) return os.path.join(context_dir, "tls") -def get_context_host(path=None, tls=False): +def get_context_host(path: str | None = None, tls: bool = False) -> str: host = parse_host(path, IS_WINDOWS_PLATFORM, tls) if host == DEFAULT_UNIX_SOCKET: # remove http+ from default docker socket url diff --git a/plugins/module_utils/_api/context/context.py b/plugins/module_utils/_api/context/context.py index b8a43fb0..aaa3c280 100644 --- a/plugins/module_utils/_api/context/context.py +++ b/plugins/module_utils/_api/context/context.py @@ -13,6 +13,7 @@ from __future__ import annotations import json import os +import typing as t from shutil import copyfile, rmtree from ..errors import ContextException @@ -33,21 +34,21 @@ class Context: def __init__( self, - name, - orchestrator=None, - host=None, - endpoints=None, - skip_tls_verify=False, - tls=False, - description=None, - ): + name: str, + orchestrator: str | None = None, + host: str | None = None, + endpoints: dict[str, dict[str, t.Any]] | None = None, + skip_tls_verify: bool = False, + tls: bool = False, + description: str | None = None, + ) -> None: if not name: raise ValueError("Name not provided") self.name = name self.context_type = None self.orchestrator = orchestrator self.endpoints = {} - self.tls_cfg = {} + self.tls_cfg: dict[str, TLSConfig] = {} 
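
The `self.tls_cfg: dict[str, TLSConfig] = {}` annotation just shown is the empty-container pattern this PR applies throughout: pin the key and value types at the assignment so every later read and write is checked. A generic sketch with hypothetical names:

```python
import typing as t


class EndpointRegistry:
    def __init__(self) -> None:
        # Without these annotations mypy has nothing to infer from the
        # empty literals; with them, later accesses are checked against
        # the declared key/value types.
        self.endpoints: dict[str, dict[str, t.Any]] = {}
        self.versions: dict[str, str | None] = {}

    def register(self, name: str, config: dict[str, t.Any]) -> None:
        self.endpoints[name] = config
        self.versions[name] = config.get("ApiVersion")  # str | None is fine
```
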
self.meta_path = IN_MEMORY self.tls_path = IN_MEMORY self.description = description @@ -89,12 +90,12 @@ class Context: def set_endpoint( self, - name="docker", - host=None, - tls_cfg=None, - skip_tls_verify=False, - def_namespace=None, - ): + name: str = "docker", + host: str | None = None, + tls_cfg: TLSConfig | None = None, + skip_tls_verify: bool = False, + def_namespace: str | None = None, + ) -> None: self.endpoints[name] = { "Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None), "SkipTLSVerify": skip_tls_verify, @@ -105,11 +106,11 @@ class Context: if tls_cfg: self.tls_cfg[name] = tls_cfg - def inspect(self): + def inspect(self) -> dict[str, t.Any]: return self() @classmethod - def load_context(cls, name): + def load_context(cls, name: str) -> t.Self | None: meta = Context._load_meta(name) if meta: instance = cls( @@ -125,12 +126,12 @@ class Context: return None @classmethod - def _load_meta(cls, name): + def _load_meta(cls, name: str) -> dict[str, t.Any] | None: meta_file = get_meta_file(name) if not os.path.isfile(meta_file): return None - metadata = {} + metadata: dict[str, t.Any] = {} try: with open(meta_file, "rt", encoding="utf-8") as f: metadata = json.load(f) @@ -154,7 +155,7 @@ class Context: return metadata - def _load_certs(self): + def _load_certs(self) -> None: certs = {} tls_dir = get_tls_dir(self.name) for endpoint in self.endpoints: @@ -184,7 +185,7 @@ class Context: self.tls_cfg = certs self.tls_path = tls_dir - def save(self): + def save(self) -> None: meta_dir = get_meta_dir(self.name) if not os.path.isdir(meta_dir): os.makedirs(meta_dir) @@ -216,54 +217,54 @@ class Context: self.meta_path = get_meta_dir(self.name) self.tls_path = get_tls_dir(self.name) - def remove(self): + def remove(self) -> None: if os.path.isdir(self.meta_path): rmtree(self.meta_path) if os.path.isdir(self.tls_path): rmtree(self.tls_path) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: '{self.name}'>" - def __str__(self): + def __str__(self) -> str: return json.dumps(self.__call__(), indent=2) - def __call__(self): + def __call__(self) -> dict[str, t.Any]: result = self.Metadata result.update(self.TLSMaterial) result.update(self.Storage) return result - def is_docker_host(self): + def is_docker_host(self) -> bool: return self.context_type is None @property - def Name(self): # pylint: disable=invalid-name + def Name(self) -> str: # pylint: disable=invalid-name return self.name @property - def Host(self): # pylint: disable=invalid-name + def Host(self) -> str | None: # pylint: disable=invalid-name if not self.orchestrator or self.orchestrator == "swarm": endpoint = self.endpoints.get("docker", None) if endpoint: - return endpoint.get("Host", None) + return endpoint.get("Host", None) # type: ignore return None - return self.endpoints[self.orchestrator].get("Host", None) + return self.endpoints[self.orchestrator].get("Host", None) # type: ignore @property - def Orchestrator(self): # pylint: disable=invalid-name + def Orchestrator(self) -> str | None: # pylint: disable=invalid-name return self.orchestrator @property - def Metadata(self): # pylint: disable=invalid-name - meta = {} + def Metadata(self) -> dict[str, t.Any]: # pylint: disable=invalid-name + meta: dict[str, t.Any] = {} if self.orchestrator: meta = {"StackOrchestrator": self.orchestrator} return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints} @property - def TLSConfig(self): # pylint: disable=invalid-name + def TLSConfig(self) -> TLSConfig | None: # pylint: 
disable=invalid-name key = self.orchestrator if not key or key == "swarm": key = "docker" @@ -272,13 +273,15 @@ class Context: return None @property - def TLSMaterial(self): # pylint: disable=invalid-name - certs = {} + def TLSMaterial(self) -> dict[str, t.Any]: # pylint: disable=invalid-name + certs: dict[str, t.Any] = {} for endpoint, tls in self.tls_cfg.items(): - cert, key = tls.cert - certs[endpoint] = list(map(os.path.basename, [tls.ca_cert, cert, key])) + paths = [tls.ca_cert, *tls.cert] if tls.cert else [tls.ca_cert] + certs[endpoint] = [ + os.path.basename(path) if path else None for path in paths + ] return {"TLSMaterial": certs} @property - def Storage(self): # pylint: disable=invalid-name + def Storage(self) -> dict[str, t.Any]: # pylint: disable=invalid-name return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}} diff --git a/plugins/module_utils/_api/credentials/errors.py b/plugins/module_utils/_api/credentials/errors.py index 323f8f67..6faed91c 100644 --- a/plugins/module_utils/_api/credentials/errors.py +++ b/plugins/module_utils/_api/credentials/errors.py @@ -11,6 +11,12 @@ from __future__ import annotations +import typing as t + + +if t.TYPE_CHECKING: + from subprocess import CalledProcessError + class StoreError(RuntimeError): pass @@ -24,7 +30,7 @@ class InitializationError(StoreError): pass -def process_store_error(cpe, program): +def process_store_error(cpe: CalledProcessError, program: str) -> StoreError: message = cpe.output.decode("utf-8") if "credentials not found in native keychain" in message: return CredentialsNotFound(f"No matching credentials in {program}") diff --git a/plugins/module_utils/_api/credentials/store.py b/plugins/module_utils/_api/credentials/store.py index 5bf5fd28..1d560e91 100644 --- a/plugins/module_utils/_api/credentials/store.py +++ b/plugins/module_utils/_api/credentials/store.py @@ -14,13 +14,14 @@ from __future__ import annotations import errno import json import subprocess +import typing as t from . import constants, errors from .utils import create_environment_dict, find_executable class Store: - def __init__(self, program, environment=None): + def __init__(self, program: str, environment: dict[str, str] | None = None) -> None: """Create a store object that acts as an interface to perform the basic operations for storing, retrieving and erasing credentials using `program`. @@ -33,7 +34,7 @@ class Store: f"{self.program} not installed or not available in PATH" ) - def get(self, server): + def get(self, server: str | bytes) -> dict[str, t.Any]: """Retrieve credentials for `server`. If no credentials are found, a `StoreError` will be raised. """ @@ -53,7 +54,7 @@ class Store: return result - def store(self, server, username, secret): + def store(self, server: str, username: str, secret: str) -> bytes: """Store credentials for `server`. Raises a `StoreError` if an error occurs. """ @@ -62,7 +63,7 @@ class Store: ).encode("utf-8") return self._execute("store", data_input) - def erase(self, server): + def erase(self, server: str | bytes) -> None: """Erase credentials for `server`. Raises a `StoreError` if an error occurs. """ @@ -70,12 +71,16 @@ class Store: server = server.encode("utf-8") self._execute("erase", server) - def list(self): + def list(self) -> t.Any: """List stored credentials. 
Requires v0.4.0+ of the helper.""" data = self._execute("list", None) return json.loads(data.decode("utf-8")) - def _execute(self, subcmd, data_input): + def _execute(self, subcmd: str, data_input: bytes | None) -> bytes: + if self.exe is None: + raise errors.StoreError( + f"{self.program} not installed or not available in PATH" + ) output = None env = create_environment_dict(self.environment) try: diff --git a/plugins/module_utils/_api/credentials/utils.py b/plugins/module_utils/_api/credentials/utils.py index 7a82c34b..ff63d5df 100644 --- a/plugins/module_utils/_api/credentials/utils.py +++ b/plugins/module_utils/_api/credentials/utils.py @@ -15,7 +15,7 @@ import os from shutil import which -def find_executable(executable, path=None): +def find_executable(executable: str, path: str | None = None) -> str | None: """ As distutils.spawn.find_executable, but on Windows, look up every extension declared in PATHEXT instead of just `.exe` @@ -26,7 +26,7 @@ def find_executable(executable, path=None): return which(executable, path=path) -def create_environment_dict(overrides): +def create_environment_dict(overrides: dict[str, str] | None) -> dict[str, str]: """ Create and return a copy of os.environ with the specified overrides """ diff --git a/plugins/module_utils/_api/errors.py b/plugins/module_utils/_api/errors.py index 548b1d71..12b197cb 100644 --- a/plugins/module_utils/_api/errors.py +++ b/plugins/module_utils/_api/errors.py @@ -11,6 +11,8 @@ from __future__ import annotations +import typing as t + from ansible.module_utils.common.text.converters import to_native from ._import_helper import HTTPError as _HTTPError @@ -25,7 +27,7 @@ class DockerException(Exception): """ -def create_api_error_from_http_exception(e): +def create_api_error_from_http_exception(e: _HTTPError) -> t.NoReturn: """ Create a suitable APIError from requests.exceptions.HTTPError. """ @@ -52,14 +54,16 @@ class APIError(_HTTPError, DockerException): An HTTP error from the API. """ - def __init__(self, message, response=None, explanation=None): + def __init__( + self, message: str | Exception, response=None, explanation: str | None = None + ) -> None: # requests 1.2 supports response as a keyword argument, but # requests 1.1 does not super().__init__(message) self.response = response - self.explanation = explanation + self.explanation = explanation or "" - def __str__(self): + def __str__(self) -> str: message = super().__str__() if self.is_client_error(): @@ -74,19 +78,20 @@ class APIError(_HTTPError, DockerException): return message @property - def status_code(self): + def status_code(self) -> int | None: if self.response is not None: return self.response.status_code + return None - def is_error(self): + def is_error(self) -> bool: return self.is_client_error() or self.is_server_error() - def is_client_error(self): + def is_client_error(self) -> bool: if self.status_code is None: return False return 400 <= self.status_code < 500 - def is_server_error(self): + def is_server_error(self) -> bool: if self.status_code is None: return False return 500 <= self.status_code < 600 @@ -121,10 +126,10 @@ class DeprecatedMethod(DockerException): class TLSParameterError(DockerException): - def __init__(self, msg): + def __init__(self, msg: str) -> None: self.msg = msg - def __str__(self): + def __str__(self) -> str: return self.msg + ( ". TLS configurations should map the Docker CLI " "client configurations. 
See " @@ -142,7 +147,14 @@ class ContainerError(DockerException): Represents a container that has exited with a non-zero exit code. """ - def __init__(self, container, exit_status, command, image, stderr): + def __init__( + self, + container: str, + exit_status: int, + command: list[str], + image: str, + stderr: str | None, + ): self.container = container self.exit_status = exit_status self.command = command @@ -156,12 +168,12 @@ class ContainerError(DockerException): class StreamParseError(RuntimeError): - def __init__(self, reason): + def __init__(self, reason: Exception) -> None: self.msg = reason class BuildError(DockerException): - def __init__(self, reason, build_log): + def __init__(self, reason: str, build_log: str) -> None: super().__init__(reason) self.msg = reason self.build_log = build_log @@ -171,7 +183,7 @@ class ImageLoadError(DockerException): pass -def create_unexpected_kwargs_error(name, kwargs): +def create_unexpected_kwargs_error(name: str, kwargs: dict[str, t.Any]) -> TypeError: quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)] text = [f"{name}() "] if len(quoted_kwargs) == 1: @@ -183,42 +195,44 @@ def create_unexpected_kwargs_error(name, kwargs): class MissingContextParameter(DockerException): - def __init__(self, param): + def __init__(self, param: str) -> None: self.param = param - def __str__(self): + def __str__(self) -> str: return f"missing parameter: {self.param}" class ContextAlreadyExists(DockerException): - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - def __str__(self): + def __str__(self) -> str: return f"context {self.name} already exists" class ContextException(DockerException): - def __init__(self, msg): + def __init__(self, msg: str) -> None: self.msg = msg - def __str__(self): + def __str__(self) -> str: return self.msg class ContextNotFound(DockerException): - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - def __str__(self): + def __str__(self) -> str: return f"context '{self.name}' not found" class MissingRequirementException(DockerException): - def __init__(self, msg, requirement, import_exception): + def __init__( + self, msg: str, requirement: str, import_exception: ImportError | str + ) -> None: self.msg = msg self.requirement = requirement self.import_exception = import_exception - def __str__(self): + def __str__(self) -> str: return self.msg diff --git a/plugins/module_utils/_api/tls.py b/plugins/module_utils/_api/tls.py index 1b81c193..f2918200 100644 --- a/plugins/module_utils/_api/tls.py +++ b/plugins/module_utils/_api/tls.py @@ -12,12 +12,18 @@ from __future__ import annotations import os -import ssl +import typing as t from . import errors from .transport.ssladapter import SSLHTTPAdapter +if t.TYPE_CHECKING: + from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( + APIClient, + ) + + class TLSConfig: """ TLS configuration. @@ -27,25 +33,22 @@ class TLSConfig: ca_cert (str): Path to CA cert file. verify (bool or str): This can be ``False`` or a path to a CA cert file. - ssl_version (int): A valid `SSL version`_. assert_hostname (bool): Verify the hostname of the server. .. 
_`SSL version`: https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1 """ - cert = None - ca_cert = None - verify = None - ssl_version = None + cert: tuple[str, str] | None = None + ca_cert: str | None = None + verify: bool | None = None def __init__( self, - client_cert=None, - ca_cert=None, - verify=None, - ssl_version=None, - assert_hostname=None, + client_cert: tuple[str, str] | None = None, + ca_cert: str | None = None, + verify: bool | None = None, + assert_hostname: bool | None = None, ): # Argument compatibility/mapping with # https://docs.docker.com/engine/articles/https/ @@ -55,12 +58,6 @@ class TLSConfig: self.assert_hostname = assert_hostname - # If the user provides an SSL version, we should use their preference - if ssl_version: - self.ssl_version = ssl_version - else: - self.ssl_version = ssl.PROTOCOL_TLS_CLIENT - # "client_cert" must have both or neither cert/key files. In # either case, Alert the user when both are expected, but any are # missing. @@ -90,11 +87,10 @@ class TLSConfig: "Invalid CA certificate provided for `ca_cert`." ) - def configure_client(self, client): + def configure_client(self, client: APIClient) -> None: """ Configure a client with these TLS options. """ - client.ssl_version = self.ssl_version if self.verify and self.ca_cert: client.verify = self.ca_cert @@ -107,7 +103,6 @@ class TLSConfig: client.mount( "https://", SSLHTTPAdapter( - ssl_version=self.ssl_version, assert_hostname=self.assert_hostname, ), ) diff --git a/plugins/module_utils/_api/transport/basehttpadapter.py b/plugins/module_utils/_api/transport/basehttpadapter.py index 603ba3eb..90239199 100644 --- a/plugins/module_utils/_api/transport/basehttpadapter.py +++ b/plugins/module_utils/_api/transport/basehttpadapter.py @@ -15,7 +15,7 @@ from .._import_helper import HTTPAdapter as _HTTPAdapter class BaseHTTPAdapter(_HTTPAdapter): - def close(self): + def close(self) -> None: super().close() if hasattr(self, "pools"): self.pools.clear() @@ -24,10 +24,10 @@ class BaseHTTPAdapter(_HTTPAdapter): # https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356 # changes requests.adapters.HTTPAdapter to no longer call get_connection() from # send(), but instead call _get_connection(). 
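
The compat overrides that follow need `# type: ignore` because their signatures track whichever requests release is installed. A generic sketch of the pattern, using a stand-in base class rather than the real requests `HTTPAdapter`:

```python
import typing as t


class Adapter:
    # Stand-in for the library base class in its *newer* releases.
    def get_connection(
        self, url: str, proxies: dict[str, str] | None = None
    ) -> str:
        return url


class CompatAdapter(Adapter):
    # The hook we keep predates the upstream signature change, so the
    # override no longer matches the base class. Silencing this one line
    # beats maintaining a subclass per library version.
    def get_connection(self, request: t.Any, proxies: t.Any = None) -> str:  # type: ignore[override]
        return super().get_connection(str(request), None)
```
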
- def _get_connection(self, request, *args, **kwargs): + def _get_connection(self, request, *args, **kwargs): # type: ignore return self.get_connection(request.url, kwargs.get("proxies")) # Fix for requests 2.32.2+: # https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05 - def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): + def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): # type: ignore return self.get_connection(request.url, proxies) diff --git a/plugins/module_utils/_api/transport/npipeconn.py b/plugins/module_utils/_api/transport/npipeconn.py index 3f618b38..f50cb91b 100644 --- a/plugins/module_utils/_api/transport/npipeconn.py +++ b/plugins/module_utils/_api/transport/npipeconn.py @@ -23,12 +23,12 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class NpipeHTTPConnection(urllib3_connection.HTTPConnection): - def __init__(self, npipe_path, timeout=60): + def __init__(self, npipe_path: str, timeout: int | float = 60) -> None: super().__init__("localhost", timeout=timeout) self.npipe_path = npipe_path self.timeout = timeout - def connect(self): + def connect(self) -> None: sock = NpipeSocket() sock.settimeout(self.timeout) sock.connect(self.npipe_path) @@ -36,18 +36,20 @@ class NpipeHTTPConnection(urllib3_connection.HTTPConnection): class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): - def __init__(self, npipe_path, timeout=60, maxsize=10): + def __init__( + self, npipe_path: str, timeout: int | float = 60, maxsize: int = 10 + ) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) self.npipe_path = npipe_path self.timeout = timeout - def _new_conn(self): + def _new_conn(self) -> NpipeHTTPConnection: return NpipeHTTPConnection(self.npipe_path, self.timeout) # When re-using connections, urllib3 tries to call select() on our # NpipeSocket instance, causing a crash. To circumvent this, we override # _get_conn, where that check happens. - def _get_conn(self, timeout): + def _get_conn(self, timeout: int | float) -> NpipeHTTPConnection: conn = None try: conn = self.pool.get(block=self.block, timeout=timeout) @@ -67,7 +69,6 @@ class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class NpipeHTTPAdapter(BaseHTTPAdapter): - __attrs__ = HTTPAdapter.__attrs__ + [ "npipe_path", "pools", @@ -77,11 +78,11 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): def __init__( self, - base_url, - timeout=60, - pool_connections=constants.DEFAULT_NUM_POOLS, - max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, - ): + base_url: str, + timeout: int | float = 60, + pool_connections: int = constants.DEFAULT_NUM_POOLS, + max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE, + ) -> None: self.npipe_path = base_url.replace("npipe://", "") self.timeout = timeout self.max_pool_size = max_pool_size @@ -90,7 +91,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): ) super().__init__() - def get_connection(self, url, proxies=None): + def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool: with self.pools.lock: pool = self.pools.get(url) if pool: @@ -103,7 +104,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): return pool - def request_url(self, request, proxies): + def request_url(self, request, proxies) -> str: # The select_proxy utility in requests errors out when the provided URL # does not have a hostname, like is the case when using a UNIX socket. 
# Since proxies are an irrelevant notion in the case of UNIX sockets diff --git a/plugins/module_utils/_api/transport/npipesocket.py b/plugins/module_utils/_api/transport/npipesocket.py index 1f5fcb50..e4473f49 100644 --- a/plugins/module_utils/_api/transport/npipesocket.py +++ b/plugins/module_utils/_api/transport/npipesocket.py @@ -15,8 +15,10 @@ import functools import io import time import traceback +import typing as t +PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name try: import pywintypes import win32api @@ -28,6 +30,13 @@ except ImportError: else: PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name +if t.TYPE_CHECKING: + from collections.abc import Buffer, Callable + + _Self = t.TypeVar("_Self") + _P = t.ParamSpec("_P") + _R = t.TypeVar("_R") + ERROR_PIPE_BUSY = 0xE7 SECURITY_SQOS_PRESENT = 0x100000 @@ -36,10 +45,12 @@ SECURITY_ANONYMOUS = 0 MAXIMUM_RETRY_COUNT = 10 -def check_closed(f): +def check_closed( + f: Callable[t.Concatenate[_Self, _P], _R], +) -> Callable[t.Concatenate[_Self, _P], _R]: @functools.wraps(f) - def wrapped(self, *args, **kwargs): - if self._closed: + def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + if self._closed: # type: ignore raise RuntimeError("Can not reuse socket after connection was closed.") return f(self, *args, **kwargs) @@ -53,25 +64,25 @@ class NpipeSocket: implemented. """ - def __init__(self, handle=None): + def __init__(self, handle=None) -> None: self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT self._handle = handle - self._address = None + self._address: str | None = None self._closed = False - self.flags = None + self.flags: int | None = None - def accept(self): + def accept(self) -> t.NoReturn: raise NotImplementedError() - def bind(self, address): + def bind(self, address) -> t.NoReturn: raise NotImplementedError() - def close(self): + def close(self) -> None: self._handle.Close() self._closed = True @check_closed - def connect(self, address, retry_count=0): + def connect(self, address, retry_count: int = 0) -> None: try: handle = win32file.CreateFile( address, @@ -99,14 +110,14 @@ class NpipeSocket: return self.connect(address, retry_count) raise e - self.flags = win32pipe.GetNamedPipeInfo(handle)[0] + self.flags = win32pipe.GetNamedPipeInfo(handle)[0] # type: ignore self._handle = handle self._address = address @check_closed - def connect_ex(self, address): - return self.connect(address) + def connect_ex(self, address) -> None: + self.connect(address) @check_closed def detach(self): @@ -114,25 +125,25 @@ class NpipeSocket: return self._handle @check_closed - def dup(self): + def dup(self) -> NpipeSocket: return NpipeSocket(self._handle) - def getpeername(self): + def getpeername(self) -> str | None: return self._address - def getsockname(self): + def getsockname(self) -> str | None: return self._address - def getsockopt(self, level, optname, buflen=None): + def getsockopt(self, level, optname, buflen=None) -> t.NoReturn: raise NotImplementedError() - def ioctl(self, control, option): + def ioctl(self, control, option) -> t.NoReturn: raise NotImplementedError() - def listen(self, backlog): + def listen(self, backlog) -> t.NoReturn: raise NotImplementedError() - def makefile(self, mode=None, bufsize=None): + def makefile(self, mode: str, bufsize: int | None = None): if mode.strip("b") != "r": raise NotImplementedError() rawio = NpipeFileIOBase(self) @@ -141,30 +152,30 @@ class NpipeSocket: return io.BufferedReader(rawio, buffer_size=bufsize) @check_closed - def recv(self, bufsize, flags=0): + 
def recv(self, bufsize: int, flags: int = 0) -> bytes:
         dummy_err, data = win32file.ReadFile(self._handle, bufsize)
         return data
 
     @check_closed
-    def recvfrom(self, bufsize, flags=0):
+    def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, str | None]:
         data = self.recv(bufsize, flags)
         return (data, self._address)
 
     @check_closed
-    def recvfrom_into(self, buf, nbytes=0, flags=0):
-        return self.recv_into(buf, nbytes, flags), self._address
+    def recvfrom_into(
+        self, buf: Buffer, nbytes: int = 0, flags: int = 0
+    ) -> tuple[int, str | None]:
+        return self.recv_into(buf, nbytes), self._address
 
     @check_closed
-    def recv_into(self, buf, nbytes=0):
-        readbuf = buf
-        if not isinstance(buf, memoryview):
-            readbuf = memoryview(buf)
+    def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
+        readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
         event = win32event.CreateEvent(None, True, True, None)
         try:
             overlapped = pywintypes.OVERLAPPED()
             overlapped.hEvent = event
-            dummy_err, dummy_data = win32file.ReadFile(
+            dummy_err, dummy_data = win32file.ReadFile(  # type: ignore
                 self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped
             )
             wait_result = win32event.WaitForSingleObject(event, self._timeout)
@@ -176,12 +187,12 @@
             win32api.CloseHandle(event)
 
     @check_closed
-    def send(self, string, flags=0):
+    def send(self, string: Buffer, flags: int = 0) -> int:
         event = win32event.CreateEvent(None, True, True, None)
         try:
             overlapped = pywintypes.OVERLAPPED()
             overlapped.hEvent = event
-            win32file.WriteFile(self._handle, string, overlapped)
+            win32file.WriteFile(self._handle, string, overlapped)  # type: ignore
             wait_result = win32event.WaitForSingleObject(event, self._timeout)
             if wait_result == win32event.WAIT_TIMEOUT:
                 win32file.CancelIo(self._handle)
@@ -191,20 +202,20 @@
             win32api.CloseHandle(event)
 
     @check_closed
-    def sendall(self, string, flags=0):
+    def sendall(self, string: Buffer, flags: int = 0) -> int:
         return self.send(string, flags)
 
     @check_closed
-    def sendto(self, string, address):
+    def sendto(self, string: Buffer, address: str) -> int:
         self.connect(address)
         return self.send(string)
 
-    def setblocking(self, flag):
+    def setblocking(self, flag: bool) -> None:
         if flag:
             return self.settimeout(None)
         return self.settimeout(0)
 
-    def settimeout(self, value):
+    def settimeout(self, value: int | float | None) -> None:
         if value is None:
             # Blocking mode
             self._timeout = win32event.INFINITE
@@ -214,39 +225,39 @@
             # Timeout mode - Value converted to milliseconds
             self._timeout = int(value * 1000)
 
-    def gettimeout(self):
+    def gettimeout(self) -> int | float | None:
         return self._timeout
 
-    def setsockopt(self, level, optname, value):
+    def setsockopt(self, level, optname, value) -> t.NoReturn:
         raise NotImplementedError()
 
     @check_closed
-    def shutdown(self, how):
+    def shutdown(self, how) -> None:
         return self.close()
 
 
 class NpipeFileIOBase(io.RawIOBase):
-    def __init__(self, npipe_socket):
+    def __init__(self, npipe_socket) -> None:
         self.sock = npipe_socket
 
-    def close(self):
+    def close(self) -> None:
         super().close()
         self.sock = None
 
-    def fileno(self):
+    def fileno(self) -> int:
         return self.sock.fileno()
 
-    def isatty(self):
+    def isatty(self) -> bool:
         return False
 
-    def readable(self):
+    def readable(self) -> bool:
         return True
 
-    def readinto(self, buf):
+    def readinto(self, buf: Buffer) -> int:
         return self.sock.recv_into(buf)
 
-    def seekable(self):
+    def seekable(self) -> bool:
         return False
 
-    def writable(self):
+    def writable(self) -> bool:
         return False
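
The `check_closed` typing in this file is the subtlest piece of the diff: `ParamSpec` plus `Concatenate` lets the decorator preserve the wrapped method's exact signature while peeling off `self`. A self-contained rendering of the same pattern (Python 3.10+; the `Sock` class is a toy, not the collection's `NpipeSocket`):

```python
from __future__ import annotations

import functools
import typing as t

_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")


def check_closed(
    f: t.Callable[t.Concatenate[_Self, _P], _R],
) -> t.Callable[t.Concatenate[_Self, _P], _R]:
    @functools.wraps(f)
    def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        if getattr(self, "_closed", False):
            raise RuntimeError("Can not reuse socket after connection was closed.")
        return f(self, *args, **kwargs)

    return wrapped


class Sock:
    def __init__(self) -> None:
        self._closed = False

    @check_closed
    def send(self, data: bytes, flags: int = 0) -> int:
        return len(data)


s = Sock()
s.send(b"hi")  # mypy still sees (data: bytes, flags: int = 0) -> int
```

Using `getattr` here sidesteps the `# type: ignore` the diff needs on `self._closed`, since `_Self` is an unconstrained TypeVar with no known attributes; either way, call sites type-check identically.
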
diff --git a/plugins/module_utils/_api/transport/sshconn.py b/plugins/module_utils/_api/transport/sshconn.py index 028f8f16..6bafa06d 100644 --- a/plugins/module_utils/_api/transport/sshconn.py +++ b/plugins/module_utils/_api/transport/sshconn.py @@ -17,6 +17,7 @@ import signal import socket import subprocess import traceback +import typing as t from queue import Empty from urllib.parse import urlparse @@ -25,6 +26,7 @@ from .._import_helper import HTTPAdapter, urllib3, urllib3_connection from .basehttpadapter import BaseHTTPAdapter +PARAMIKO_IMPORT_ERROR: str | None # pylint: disable=invalid-name try: import paramiko except ImportError: @@ -32,12 +34,15 @@ except ImportError: else: PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name +if t.TYPE_CHECKING: + from collections.abc import Buffer + RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class SSHSocket(socket.socket): - def __init__(self, host): + def __init__(self, host: str) -> None: super().__init__(socket.AF_INET, socket.SOCK_STREAM) self.host = host self.port = None @@ -47,9 +52,9 @@ class SSHSocket(socket.socket): if "@" in self.host: self.user, self.host = self.host.split("@") - self.proc = None + self.proc: subprocess.Popen | None = None - def connect(self, **kwargs): + def connect(self, *args_: t.Any, **kwargs: t.Any) -> None: args = ["ssh"] if self.user: args = args + ["-l", self.user] @@ -81,37 +86,48 @@ class SSHSocket(socket.socket): preexec_fn=preexec_func, ) - def _write(self, data): - if not self.proc or self.proc.stdin.closed: + def _write(self, data: Buffer) -> int: + if not self.proc: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first." ) + assert self.proc.stdin is not None + if self.proc.stdin.closed: + raise RuntimeError( + "SSH subprocess not initiated. connect() must be called first after close()." + ) written = self.proc.stdin.write(data) self.proc.stdin.flush() return written - def sendall(self, data): + def sendall(self, data: Buffer, *args, **kwargs) -> None: self._write(data) - def send(self, data): + def send(self, data: Buffer, *args, **kwargs) -> int: return self._write(data) - def recv(self, n): + def recv(self, n: int, *args, **kwargs) -> bytes: if not self.proc: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first." 
) + assert self.proc.stdout is not None return self.proc.stdout.read(n) - def makefile(self, mode): + def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore if not self.proc: self.connect() - self.proc.stdout.channel = self + assert self.proc is not None + assert self.proc.stdout is not None + self.proc.stdout.channel = self # type: ignore return self.proc.stdout - def close(self): - if not self.proc or self.proc.stdin.closed: + def close(self) -> None: + if not self.proc: + return + assert self.proc.stdin is not None + if self.proc.stdin.closed: return self.proc.stdin.write(b"\n\n") self.proc.stdin.flush() @@ -119,13 +135,19 @@ class SSHSocket(socket.socket): class SSHConnection(urllib3_connection.HTTPConnection): - def __init__(self, ssh_transport=None, timeout=60, host=None): + def __init__( + self, + *, + ssh_transport=None, + timeout: int | float = 60, + host: str, + ) -> None: super().__init__("localhost", timeout=timeout) self.ssh_transport = ssh_transport self.timeout = timeout self.ssh_host = host - def connect(self): + def connect(self) -> None: if self.ssh_transport: sock = self.ssh_transport.open_session() sock.settimeout(self.timeout) @@ -141,7 +163,14 @@ class SSHConnection(urllib3_connection.HTTPConnection): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): scheme = "ssh" - def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None): + def __init__( + self, + *, + ssh_client: paramiko.SSHClient | None = None, + timeout: int | float = 60, + maxsize: int = 10, + host: str, + ) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) self.ssh_transport = None self.timeout = timeout @@ -149,13 +178,17 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): self.ssh_transport = ssh_client.get_transport() self.ssh_host = host - def _new_conn(self): - return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host) + def _new_conn(self) -> SSHConnection: + return SSHConnection( + ssh_transport=self.ssh_transport, + timeout=self.timeout, + host=self.ssh_host, + ) # When re-using connections, urllib3 calls fileno() on our # SSH channel instance, quickly overloading our fd limit. 
To avoid this, # we override _get_conn - def _get_conn(self, timeout): + def _get_conn(self, timeout: int | float) -> SSHConnection: conn = None try: conn = self.pool.get(block=self.block, timeout=timeout) @@ -175,7 +208,6 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHHTTPAdapter(BaseHTTPAdapter): - __attrs__ = HTTPAdapter.__attrs__ + [ "pools", "timeout", @@ -186,13 +218,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter): def __init__( self, - base_url, - timeout=60, - pool_connections=constants.DEFAULT_NUM_POOLS, - max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, - shell_out=False, - ): - self.ssh_client = None + base_url: str, + timeout: int | float = 60, + pool_connections: int = constants.DEFAULT_NUM_POOLS, + max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE, + shell_out: bool = False, + ) -> None: + self.ssh_client: paramiko.SSHClient | None = None if not shell_out: self._create_paramiko_client(base_url) self._connect() @@ -208,30 +240,31 @@ class SSHHTTPAdapter(BaseHTTPAdapter): ) super().__init__() - def _create_paramiko_client(self, base_url): + def _create_paramiko_client(self, base_url: str) -> None: logging.getLogger("paramiko").setLevel(logging.WARNING) self.ssh_client = paramiko.SSHClient() - base_url = urlparse(base_url) - self.ssh_params = { - "hostname": base_url.hostname, - "port": base_url.port, - "username": base_url.username, + base_url_p = urlparse(base_url) + assert base_url_p.hostname is not None + self.ssh_params: dict[str, t.Any] = { + "hostname": base_url_p.hostname, + "port": base_url_p.port, + "username": base_url_p.username, } ssh_config_file = os.path.expanduser("~/.ssh/config") if os.path.exists(ssh_config_file): conf = paramiko.SSHConfig() with open(ssh_config_file, "rt", encoding="utf-8") as f: conf.parse(f) - host_config = conf.lookup(base_url.hostname) + host_config = conf.lookup(base_url_p.hostname) if "proxycommand" in host_config: self.ssh_params["sock"] = paramiko.ProxyCommand( host_config["proxycommand"] ) if "hostname" in host_config: self.ssh_params["hostname"] = host_config["hostname"] - if base_url.port is None and "port" in host_config: + if base_url_p.port is None and "port" in host_config: self.ssh_params["port"] = host_config["port"] - if base_url.username is None and "user" in host_config: + if base_url_p.username is None and "user" in host_config: self.ssh_params["username"] = host_config["user"] if "identityfile" in host_config: self.ssh_params["key_filename"] = host_config["identityfile"] @@ -239,11 +272,11 @@ class SSHHTTPAdapter(BaseHTTPAdapter): self.ssh_client.load_system_host_keys() self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy()) - def _connect(self): + def _connect(self) -> None: if self.ssh_client: self.ssh_client.connect(**self.ssh_params) - def get_connection(self, url, proxies=None): + def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool: if not self.ssh_client: return SSHConnectionPool( ssh_client=self.ssh_client, @@ -270,7 +303,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter): return pool - def close(self): + def close(self) -> None: super().close() if self.ssh_client: self.ssh_client.close() diff --git a/plugins/module_utils/_api/transport/ssladapter.py b/plugins/module_utils/_api/transport/ssladapter.py index d0cb8f79..2cad6cea 100644 --- a/plugins/module_utils/_api/transport/ssladapter.py +++ b/plugins/module_utils/_api/transport/ssladapter.py @@ -11,9 +11,7 @@ from __future__ import annotations -from 
ansible_collections.community.docker.plugins.module_utils._version import ( - LooseVersion, -) +import typing as t from .._import_helper import HTTPAdapter, urllib3 from .basehttpadapter import BaseHTTPAdapter @@ -30,14 +28,19 @@ PoolManager = urllib3.poolmanager.PoolManager class SSLHTTPAdapter(BaseHTTPAdapter): """An HTTPS Transport Adapter that uses an arbitrary SSL version.""" - __attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname", "ssl_version"] + __attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname"] - def __init__(self, ssl_version=None, assert_hostname=None, **kwargs): - self.ssl_version = ssl_version + def __init__( + self, + assert_hostname: bool | None = None, + **kwargs, + ) -> None: self.assert_hostname = assert_hostname super().__init__(**kwargs) - def init_poolmanager(self, connections, maxsize, block=False): + def init_poolmanager( + self, connections: int, maxsize: int, block: bool = False, **kwargs: t.Any + ) -> None: kwargs = { "num_pools": connections, "maxsize": maxsize, @@ -45,12 +48,10 @@ class SSLHTTPAdapter(BaseHTTPAdapter): } if self.assert_hostname is not None: kwargs["assert_hostname"] = self.assert_hostname - if self.ssl_version and self.can_override_ssl_version(): - kwargs["ssl_version"] = self.ssl_version self.poolmanager = PoolManager(**kwargs) - def get_connection(self, *args, **kwargs): + def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool: """ Ensure assert_hostname is set correctly on our pool @@ -61,15 +62,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter): conn = super().get_connection(*args, **kwargs) if ( self.assert_hostname is not None - and conn.assert_hostname != self.assert_hostname + and conn.assert_hostname != self.assert_hostname # type: ignore ): - conn.assert_hostname = self.assert_hostname + conn.assert_hostname = self.assert_hostname # type: ignore return conn - - def can_override_ssl_version(self): - urllib_ver = urllib3.__version__.split("-")[0] - if urllib_ver is None: - return False - if urllib_ver == "dev": - return True - return LooseVersion(urllib_ver) > LooseVersion("1.5") diff --git a/plugins/module_utils/_api/transport/unixconn.py b/plugins/module_utils/_api/transport/unixconn.py index 2c615986..4d3b5679 100644 --- a/plugins/module_utils/_api/transport/unixconn.py +++ b/plugins/module_utils/_api/transport/unixconn.py @@ -12,6 +12,7 @@ from __future__ import annotations import socket +import typing as t from .. 
import constants from .._import_helper import HTTPAdapter, urllib3, urllib3_connection @@ -22,26 +23,27 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class UnixHTTPConnection(urllib3_connection.HTTPConnection): - - def __init__(self, base_url, unix_socket, timeout=60): + def __init__( + self, base_url: str | bytes, unix_socket, timeout: int | float = 60 + ) -> None: super().__init__("localhost", timeout=timeout) self.base_url = base_url self.unix_socket = unix_socket self.timeout = timeout self.disable_buffering = False - def connect(self): + def connect(self) -> None: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(self.timeout) sock.connect(self.unix_socket) self.sock = sock - def putheader(self, header, *values): + def putheader(self, header: str, *values: str) -> None: super().putheader(header, *values) if header == "Connection" and "Upgrade" in values: self.disable_buffering = True - def response_class(self, sock, *args, **kwargs): + def response_class(self, sock, *args, **kwargs) -> t.Any: # FIXME: We may need to disable buffering on Py3, # but there's no clear way to do it at the moment. See: # https://github.com/docker/docker-py/issues/1799 @@ -49,18 +51,23 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection): class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): - def __init__(self, base_url, socket_path, timeout=60, maxsize=10): + def __init__( + self, + base_url: str | bytes, + socket_path: str, + timeout: int | float = 60, + maxsize: int = 10, + ) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) self.base_url = base_url self.socket_path = socket_path self.timeout = timeout - def _new_conn(self): + def _new_conn(self) -> UnixHTTPConnection: return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout) class UnixHTTPAdapter(BaseHTTPAdapter): - __attrs__ = HTTPAdapter.__attrs__ + [ "pools", "socket_path", @@ -70,11 +77,11 @@ class UnixHTTPAdapter(BaseHTTPAdapter): def __init__( self, - socket_url, - timeout=60, - pool_connections=constants.DEFAULT_NUM_POOLS, - max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, - ): + socket_url: str, + timeout: int | float = 60, + pool_connections: int = constants.DEFAULT_NUM_POOLS, + max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE, + ) -> None: socket_path = socket_url.replace("http+unix://", "") if not socket_path.startswith("/"): socket_path = "/" + socket_path @@ -86,7 +93,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter): ) super().__init__() - def get_connection(self, url, proxies=None): + def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool: with self.pools.lock: pool = self.pools.get(url) if pool: @@ -99,7 +106,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter): return pool - def request_url(self, request, proxies): + def request_url(self, request, proxies) -> str: # The select_proxy utility in requests errors out when the provided URL # does not have a hostname, like is the case when using a UNIX socket. 
# Since proxies are an irrelevant notion in the case of UNIX sockets diff --git a/plugins/module_utils/_api/types/daemon.py b/plugins/module_utils/_api/types/daemon.py index 6defe9b2..eb386169 100644 --- a/plugins/module_utils/_api/types/daemon.py +++ b/plugins/module_utils/_api/types/daemon.py @@ -12,6 +12,7 @@ from __future__ import annotations import socket +import typing as t from .._import_helper import urllib3 from ..errors import DockerException @@ -29,11 +30,11 @@ class CancellableStream: >>> events.close() """ - def __init__(self, stream, response): + def __init__(self, stream, response) -> None: self._stream = stream self._response = response - def __iter__(self): + def __iter__(self) -> t.Self: return self def __next__(self): @@ -46,7 +47,7 @@ class CancellableStream: next = __next__ - def close(self): + def close(self) -> None: """ Closes the event streaming. """ diff --git a/plugins/module_utils/_api/utils/build.py b/plugins/module_utils/_api/utils/build.py index d15774be..f87204c1 100644 --- a/plugins/module_utils/_api/utils/build.py +++ b/plugins/module_utils/_api/utils/build.py @@ -17,26 +17,38 @@ import random import re import tarfile import tempfile +import typing as t from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX from . import fnmatch +if t.TYPE_CHECKING: + from collections.abc import Sequence + + _SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/") -def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False): +def tar( + path: str, + exclude: list[str] | None = None, + dockerfile: tuple[str, str | None] | tuple[None, None] | None = None, + fileobj: t.IO[bytes] | None = None, + gzip: bool = False, +) -> t.IO[bytes]: root = os.path.abspath(path) exclude = exclude or [] dockerfile = dockerfile or (None, None) - extra_files = [] + extra_files: list[tuple[str, str]] = [] if dockerfile[1] is not None: + assert dockerfile[0] is not None dockerignore_contents = "\n".join( (exclude or [".dockerignore"]) + [dockerfile[0]] ) extra_files = [ (".dockerignore", dockerignore_contents), - dockerfile, + dockerfile, # type: ignore ] return create_archive( files=sorted(exclude_paths(root, exclude, dockerfile=dockerfile[0])), @@ -47,7 +59,9 @@ def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False): ) -def exclude_paths(root, patterns, dockerfile=None): +def exclude_paths( + root: str, patterns: list[str], dockerfile: str | None = None +) -> set[str]: """ Given a root directory path and a list of .dockerignore patterns, return an iterator of all paths (both regular files and directories) in the root @@ -64,7 +78,7 @@ def exclude_paths(root, patterns, dockerfile=None): return set(pm.walk(root)) -def build_file_list(root): +def build_file_list(root: str) -> list[str]: files = [] for dirname, dirnames, fnames in os.walk(root): for filename in fnames + dirnames: @@ -74,7 +88,13 @@ def build_file_list(root): return files -def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None): +def create_archive( + root: str, + files: Sequence[str] | None = None, + fileobj: t.IO[bytes] | None = None, + gzip: bool = False, + extra_files: Sequence[tuple[str, str]] | None = None, +) -> t.IO[bytes]: extra_files = extra_files or [] if not fileobj: fileobj = tempfile.NamedTemporaryFile() @@ -92,7 +112,7 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None) if i is None: # This happens when we encounter a socket file. We can safely # ignore it and proceed. 
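
`tar` and `create_archive` above both traffic in `t.IO[bytes]`, which covers real temp files and in-memory buffers alike. A runnable sketch of producing and reading back such an archive; the helper name is hypothetical:

```python
from __future__ import annotations

import io
import tarfile
import typing as t


def bytes_archive(files: dict[str, str]) -> t.IO[bytes]:
    """Pack name -> text content into an in-memory tar, rewound to 0."""
    buf: t.IO[bytes] = io.BytesIO()
    with tarfile.open(mode="w", fileobj=buf) as archive:
        for name, text in files.items():
            data = text.encode("utf-8")
            info = tarfile.TarInfo(name)
            info.size = len(data)
            archive.addfile(info, io.BytesIO(data))
    buf.seek(0)
    return buf


result = bytes_archive({"Dockerfile": "FROM scratch\n"})
with tarfile.open(fileobj=result) as archive:
    print(archive.getnames())  # ['Dockerfile']
```
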
- continue + continue # type: ignore # Workaround https://bugs.python.org/issue32713 if i.mtime < 0 or i.mtime > 8**11 - 1: @@ -124,11 +144,11 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None) return fileobj -def mkbuildcontext(dockerfile): +def mkbuildcontext(dockerfile: io.BytesIO | t.IO[bytes]) -> t.IO[bytes]: f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with try: with tarfile.open(mode="w", fileobj=f) as t: - if isinstance(dockerfile, io.StringIO): + if isinstance(dockerfile, io.StringIO): # type: ignore raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles") if isinstance(dockerfile, io.BytesIO): dfinfo = tarfile.TarInfo("Dockerfile") @@ -144,17 +164,17 @@ def mkbuildcontext(dockerfile): return f -def split_path(p): +def split_path(p: str) -> list[str]: return [pt for pt in re.split(_SEP, p) if pt and pt != "."] -def normalize_slashes(p): +def normalize_slashes(p: str) -> str: if IS_WINDOWS_PLATFORM: return "/".join(split_path(p)) return p -def walk(root, patterns, default=True): +def walk(root: str, patterns: Sequence[str], default: bool = True) -> t.Generator[str]: pm = PatternMatcher(patterns) return pm.walk(root) @@ -162,11 +182,11 @@ def walk(root, patterns, default=True): # Heavily based on # https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go class PatternMatcher: - def __init__(self, patterns): + def __init__(self, patterns: Sequence[str]) -> None: self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns])) self.patterns.append(Pattern("!.dockerignore")) - def matches(self, filepath): + def matches(self, filepath: str) -> bool: matched = False parent_path = os.path.dirname(filepath) parent_path_dirs = split_path(parent_path) @@ -185,8 +205,8 @@ class PatternMatcher: return matched - def walk(self, root): - def rec_walk(current_dir): + def walk(self, root: str) -> t.Generator[str]: + def rec_walk(current_dir: str) -> t.Generator[str]: for f in os.listdir(current_dir): fpath = os.path.join(os.path.relpath(current_dir, root), f) if fpath.startswith("." 
+ os.path.sep): @@ -220,7 +240,7 @@ class PatternMatcher: class Pattern: - def __init__(self, pattern_str): + def __init__(self, pattern_str: str) -> None: self.exclusion = False if pattern_str.startswith("!"): self.exclusion = True @@ -230,8 +250,7 @@ class Pattern: self.cleaned_pattern = "/".join(self.dirs) @classmethod - def normalize(cls, p): - + def normalize(cls, p: str) -> list[str]: # Remove trailing spaces p = p.strip() @@ -256,11 +275,13 @@ class Pattern: i += 1 return split - def match(self, filepath): + def match(self, filepath: str) -> bool: return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern) -def process_dockerfile(dockerfile, path): +def process_dockerfile( + dockerfile: str | None, path: str +) -> tuple[str, str | None] | tuple[None, None]: if not dockerfile: return (None, None) @@ -268,7 +289,7 @@ def process_dockerfile(dockerfile, path): if not os.path.isabs(dockerfile): abs_dockerfile = os.path.join(path, dockerfile) if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX): - abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX):])}" + abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX) :])}" if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[ 0 ] or os.path.relpath(abs_dockerfile, path).startswith(".."): diff --git a/plugins/module_utils/_api/utils/config.py b/plugins/module_utils/_api/utils/config.py index 934f2dfc..eaa9542a 100644 --- a/plugins/module_utils/_api/utils/config.py +++ b/plugins/module_utils/_api/utils/config.py @@ -14,6 +14,7 @@ from __future__ import annotations import json import logging import os +import typing as t from ..constants import IS_WINDOWS_PLATFORM @@ -24,11 +25,11 @@ LEGACY_DOCKER_CONFIG_FILENAME = ".dockercfg" log = logging.getLogger(__name__) -def get_default_config_file(): +def get_default_config_file() -> str: return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME) -def find_config_file(config_path=None): +def find_config_file(config_path: str | None = None) -> str | None: homedir = home_dir() paths = list( filter( @@ -54,14 +55,14 @@ def find_config_file(config_path=None): return None -def config_path_from_environment(): +def config_path_from_environment() -> str | None: config_dir = os.environ.get("DOCKER_CONFIG") if not config_dir: return None return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME)) -def home_dir(): +def home_dir() -> str: """ Get the user's home directory, using the same logic as the Docker Engine client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX. @@ -71,7 +72,7 @@ def home_dir(): return os.path.expanduser("~") -def load_general_config(config_path=None): +def load_general_config(config_path: str | None = None) -> dict[str, t.Any]: config_file = find_config_file(config_path) if not config_file: diff --git a/plugins/module_utils/_api/utils/decorators.py b/plugins/module_utils/_api/utils/decorators.py index f046ebd3..59821aca 100644 --- a/plugins/module_utils/_api/utils/decorators.py +++ b/plugins/module_utils/_api/utils/decorators.py @@ -12,16 +12,37 @@ from __future__ import annotations import functools +import typing as t from .. import errors from . 
import utils -def minimum_version(version): - def decorator(f): +if t.TYPE_CHECKING: + from collections.abc import Callable + + from ..api.client import APIClient + + _Self = t.TypeVar("_Self") + _P = t.ParamSpec("_P") + _R = t.TypeVar("_R") + + +def minimum_version( + version: str, +) -> Callable[ + [Callable[t.Concatenate[_Self, _P], _R]], + Callable[t.Concatenate[_Self, _P], _R], +]: + def decorator( + f: Callable[t.Concatenate[_Self, _P], _R], + ) -> Callable[t.Concatenate[_Self, _P], _R]: @functools.wraps(f) - def wrapper(self, *args, **kwargs): - if utils.version_lt(self._version, version): + def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + # We use _Self instead of APIClient since this is used for mixins for APIClient. + # This unfortunately means that self._version does not exist in the mixin, + # it only exists after mixing in. This is why we ignore types here. + if utils.version_lt(self._version, version): # type: ignore raise errors.InvalidVersion( f"{f.__name__} is not available for version < {version}" ) @@ -32,13 +53,16 @@ def minimum_version(version): return decorator -def update_headers(f): - def inner(self, *args, **kwargs): +def update_headers( + f: Callable[t.Concatenate[APIClient, _P], _R], +) -> Callable[t.Concatenate[APIClient, _P], _R]: + def inner(self: APIClient, *args: _P.args, **kwargs: _P.kwargs) -> _R: if "HttpHeaders" in self._general_configs: if not kwargs.get("headers"): kwargs["headers"] = self._general_configs["HttpHeaders"] else: - kwargs["headers"].update(self._general_configs["HttpHeaders"]) + # We cannot (yet) model that kwargs["headers"] should be a dictionary + kwargs["headers"].update(self._general_configs["HttpHeaders"]) # type: ignore return f(self, *args, **kwargs) return inner diff --git a/plugins/module_utils/_api/utils/fnmatch.py b/plugins/module_utils/_api/utils/fnmatch.py index 7fefa4e7..525cf84a 100644 --- a/plugins/module_utils/_api/utils/fnmatch.py +++ b/plugins/module_utils/_api/utils/fnmatch.py @@ -28,16 +28,16 @@ import re __all__ = ["fnmatch", "fnmatchcase", "translate"] -_cache = {} +_cache: dict[str, re.Pattern] = {} _MAXCACHE = 100 -def _purge(): +def _purge() -> None: """Clear the pattern cache""" _cache.clear() -def fnmatch(name, pat): +def fnmatch(name: str, pat: str): """Test whether FILENAME matches PATTERN. Patterns are Unix shell style: @@ -58,7 +58,7 @@ def fnmatch(name, pat): return fnmatchcase(name, pat) -def fnmatchcase(name, pat): +def fnmatchcase(name: str, pat: str) -> bool: """Test whether FILENAME matches PATTERN, including case. This is a version of fnmatch() which does not case-normalize its arguments. @@ -74,7 +74,7 @@ def fnmatchcase(name, pat): return re_pat.match(name) is not None -def translate(pat): +def translate(pat: str) -> str: """Translate a shell PATTERN to a regular expression. There is no way to quote meta-characters. 
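The `minimum_version` decorator above combines `t.ParamSpec` with `t.Concatenate` so that the wrapped method keeps its exact signature for type checkers. A minimal standalone sketch of the same pattern (the names `traced`/`_Self`/`_P`/`_R` are illustrative, not part of this patch; note that using `ParamSpec` at runtime needs Python 3.10+, which the patch avoids by importing it only under `t.TYPE_CHECKING`):

```python
from __future__ import annotations

import functools
import typing as t
from collections.abc import Callable

_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")


def traced(
    f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
    # The wrapper advertises the same parameter list and return type as f,
    # so decorated methods stay fully type-checked at every call site.
    @functools.wraps(f)
    def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        print(f"calling {f.__name__}")
        return f(self, *args, **kwargs)

    return wrapper
```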
diff --git a/plugins/module_utils/_api/utils/json_stream.py b/plugins/module_utils/_api/utils/json_stream.py index dac3d0ca..ada8905e 100644 --- a/plugins/module_utils/_api/utils/json_stream.py +++ b/plugins/module_utils/_api/utils/json_stream.py @@ -13,14 +13,22 @@ from __future__ import annotations import json import json.decoder +import typing as t from ..errors import StreamParseError +if t.TYPE_CHECKING: + import re + from collections.abc import Callable + + _T = t.TypeVar("_T") + + json_decoder = json.JSONDecoder() -def stream_as_text(stream): +def stream_as_text(stream: t.Generator[bytes | str]) -> t.Generator[str]: """ Given a stream of bytes or text, if any of the items in the stream are bytes convert them to text. @@ -33,20 +41,22 @@ def stream_as_text(stream): yield data -def json_splitter(buffer): +def json_splitter(buffer: str) -> tuple[t.Any, str] | None: """Attempt to parse a json object from a buffer. If there is at least one object, return it and the rest of the buffer, otherwise return None. """ buffer = buffer.strip() try: obj, index = json_decoder.raw_decode(buffer) - rest = buffer[json.decoder.WHITESPACE.match(buffer, index).end() :] + ws: re.Pattern = json.decoder.WHITESPACE # type: ignore[attr-defined] + m = ws.match(buffer, index) + rest = buffer[m.end() :] if m else buffer[index:] return obj, rest except ValueError: return None -def json_stream(stream): +def json_stream(stream: t.Generator[str | bytes]) -> t.Generator[t.Any]: """Given a stream of text, return a stream of json objects. This handles streams which are inconsistently buffered (some entries may be newline delimited, and others are not). @@ -54,21 +64,24 @@ def json_stream(stream): return split_buffer(stream, json_splitter, json_decoder.decode) -def line_splitter(buffer, separator="\n"): +def line_splitter(buffer: str, separator: str = "\n") -> tuple[str, str] | None: index = buffer.find(str(separator)) if index == -1: return None return buffer[: index + 1], buffer[index + 1 :] -def split_buffer(stream, splitter=None, decoder=lambda a: a): +def split_buffer( + stream: t.Generator[str | bytes], + splitter: Callable[[str], tuple[_T, str] | None], + decoder: Callable[[str], _T], +) -> t.Generator[_T | str]: """Given a generator which yields strings and a splitter function, joins all input, splits on the separator and yields each chunk. Unlike string.split(), each chunk includes the trailing separator, except for the last one if none was found on the end of the input. 
""" - splitter = splitter or line_splitter buffered = "" for data in stream_as_text(stream): diff --git a/plugins/module_utils/_api/utils/ports.py b/plugins/module_utils/_api/utils/ports.py index 11a350e6..eab15bd0 100644 --- a/plugins/module_utils/_api/utils/ports.py +++ b/plugins/module_utils/_api/utils/ports.py @@ -12,6 +12,11 @@ from __future__ import annotations import re +import typing as t + + +if t.TYPE_CHECKING: + from collections.abc import Collection, Sequence PORT_SPEC = re.compile( @@ -26,32 +31,42 @@ PORT_SPEC = re.compile( ) -def add_port_mapping(port_bindings, internal_port, external): +def add_port_mapping( + port_bindings: dict[str, list[str | tuple[str, str | None] | None]], + internal_port: str, + external: str | tuple[str, str | None] | None, +) -> None: if internal_port in port_bindings: port_bindings[internal_port].append(external) else: port_bindings[internal_port] = [external] -def add_port(port_bindings, internal_port_range, external_range): +def add_port( + port_bindings: dict[str, list[str | tuple[str, str | None] | None]], + internal_port_range: list[str], + external_range: list[str] | list[tuple[str, str | None]] | None, +) -> None: if external_range is None: for internal_port in internal_port_range: add_port_mapping(port_bindings, internal_port, None) else: - ports = zip(internal_port_range, external_range) - for internal_port, external_port in ports: - add_port_mapping(port_bindings, internal_port, external_port) + for internal_port, external_port in zip(internal_port_range, external_range): + # mypy loses the exact type of eternal_port elements for some reason... + add_port_mapping(port_bindings, internal_port, external_port) # type: ignore -def build_port_bindings(ports): - port_bindings = {} +def build_port_bindings( + ports: Collection[str], +) -> dict[str, list[str | tuple[str, str | None] | None]]: + port_bindings: dict[str, list[str | tuple[str, str | None] | None]] = {} for port in ports: internal_port_range, external_range = split_port(port) add_port(port_bindings, internal_port_range, external_range) return port_bindings -def _raise_invalid_port(port): +def _raise_invalid_port(port: str) -> t.NoReturn: raise ValueError( f'Invalid port "{port}", should be ' "[[remote_ip:]remote_port[-remote_port]:]" @@ -59,39 +74,64 @@ def _raise_invalid_port(port): ) -def port_range(start, end, proto, randomly_available_port=False): - if not start: +@t.overload +def port_range( + start: str, + end: str | None, + proto: str, + randomly_available_port: bool = False, +) -> list[str]: ... + + +@t.overload +def port_range( + start: str | None, + end: str | None, + proto: str, + randomly_available_port: bool = False, +) -> list[str] | None: ... 
+ + +def port_range( + start: str | None, + end: str | None, + proto: str, + randomly_available_port: bool = False, +) -> list[str] | None: + if start is None: return start - if not end: + if end is None: return [f"{start}{proto}"] if randomly_available_port: return [f"{start}-{end}{proto}"] return [f"{port}{proto}" for port in range(int(start), int(end) + 1)] -def split_port(port): - if hasattr(port, "legacy_repr"): - # This is the worst hack, but it prevents a bug in Compose 1.14.0 - # https://github.com/docker/docker-py/issues/1668 - # TODO: remove once fixed in Compose stable - port = port.legacy_repr() +def split_port( + port: str, +) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]: port = str(port) match = PORT_SPEC.match(port) if match is None: _raise_invalid_port(port) parts = match.groupdict() - host = parts["host"] - proto = parts["proto"] or "" - internal = port_range(parts["int"], parts["int_end"], proto) - external = port_range(parts["ext"], parts["ext_end"], "", len(internal) == 1) + host: str | None = parts["host"] + proto: str = parts["proto"] or "" + int_p: str = parts["int"] + ext_p: str = parts["ext"] + internal: list[str] = port_range(int_p, parts["int_end"], proto) # type: ignore + external = port_range(ext_p or None, parts["ext_end"], "", len(internal) == 1) if host is None: - if external is not None and len(internal) != len(external): + if (external is not None and len(internal) != len(external)) or ext_p == "": raise ValueError("Port ranges don't match in length") return internal, external + external_or_none: Sequence[str | None] if not external: - external = [None] * len(internal) - elif len(internal) != len(external): - raise ValueError("Port ranges don't match in length") - return internal, [(host, ext_port) for ext_port in external] + external_or_none = [None] * len(internal) + else: + external_or_none = external + if len(internal) != len(external_or_none): + raise ValueError("Port ranges don't match in length") + return internal, [(host, ext_port) for ext_port in external_or_none] diff --git a/plugins/module_utils/_api/utils/proxy.py b/plugins/module_utils/_api/utils/proxy.py index af1fc064..0f5fa9f3 100644 --- a/plugins/module_utils/_api/utils/proxy.py +++ b/plugins/module_utils/_api/utils/proxy.py @@ -20,23 +20,23 @@ class ProxyConfig(dict): """ @property - def http(self): + def http(self) -> str | None: return self.get("http") @property - def https(self): + def https(self) -> str | None: return self.get("https") @property - def ftp(self): + def ftp(self) -> str | None: return self.get("ftp") @property - def no_proxy(self): + def no_proxy(self) -> str | None: return self.get("no_proxy") @staticmethod - def from_dict(config): + def from_dict(config: dict[str, str]) -> ProxyConfig: """ Instantiate a new ProxyConfig from a dictionary that represents a client configuration, as described in `the documentation`_. @@ -51,7 +51,7 @@ class ProxyConfig(dict): no_proxy=config.get("noProxy"), ) - def get_environment(self): + def get_environment(self) -> dict[str, str]: """ Return a dictionary representing the environment variables used to set the proxy settings. @@ -67,7 +67,7 @@ class ProxyConfig(dict): env["no_proxy"] = env["NO_PROXY"] = self.no_proxy return env - def inject_proxy_environment(self, environment): + def inject_proxy_environment(self, environment: list[str]) -> list[str]: """ Given a list of strings representing environment variables, prepend the environment variables corresponding to the proxy settings. 
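The `t.overload` stubs added to `port_range()` in the ports.py hunk above let mypy prove that `internal` inside `split_port()` is always a `list[str]`. For orientation, a small sketch of what `split_port()` returns, derived from the parsing logic shown above (not from new test output):

```python
from ansible_collections.community.docker.plugins.module_utils._api.utils.ports import (
    split_port,
)

# Host IP plus matching port ranges on both sides:
internal, external = split_port("127.0.0.1:8080-8082:80-82/tcp")
# internal == ["80/tcp", "81/tcp", "82/tcp"]
# external == [("127.0.0.1", "8080"), ("127.0.0.1", "8081"), ("127.0.0.1", "8082")]

# Container port only; there is no external binding to report:
internal, external = split_port("80/udp")
# internal == ["80/udp"], external is None
```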
@@ -82,5 +82,5 @@ class ProxyConfig(dict): # variables defined in "environment" to take precedence. return proxy_env + environment - def __str__(self): + def __str__(self) -> str: return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})" diff --git a/plugins/module_utils/_api/utils/socket.py b/plugins/module_utils/_api/utils/socket.py index 615018ad..6619e0ff 100644 --- a/plugins/module_utils/_api/utils/socket.py +++ b/plugins/module_utils/_api/utils/socket.py @@ -16,10 +16,15 @@ import os import select import socket as pysocket import struct +import typing as t from ..transport.npipesocket import NpipeSocket +if t.TYPE_CHECKING: + from collections.abc import Iterable + + STDOUT = 1 STDERR = 2 @@ -33,7 +38,7 @@ class SocketError(Exception): NPIPE_ENDED = 109 -def read(socket, n=4096): +def read(socket, n: int = 4096) -> bytes | None: """ Reads at most n bytes from socket """ @@ -58,6 +63,7 @@ def read(socket, n=4096): except EnvironmentError as e: if e.errno not in recoverable_errors: raise + return None # TODO ??? except Exception as e: is_pipe_ended = ( isinstance(socket, NpipeSocket) @@ -67,11 +73,11 @@ def read(socket, n=4096): if is_pipe_ended: # npipes do not support duplex sockets, so we interpret # a PIPE_ENDED error as a close operation (0-length read). - return "" + return b"" raise -def read_exactly(socket, n): +def read_exactly(socket, n: int) -> bytes: """ Reads exactly n bytes from socket Raises SocketError if there is not enough data @@ -85,7 +91,7 @@ def read_exactly(socket, n): return data -def next_frame_header(socket): +def next_frame_header(socket) -> tuple[int, int]: """ Returns the stream and size of the next frame of data waiting to be read from socket, according to the protocol defined here: @@ -101,7 +107,7 @@ def next_frame_header(socket): return (stream, actual) -def frames_iter(socket, tty): +def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]: """ Return a generator of frames read from socket. A frame is a tuple where the first item is the stream number and the second item is a chunk of data. @@ -114,7 +120,7 @@ def frames_iter(socket, tty): return frames_iter_no_tty(socket) -def frames_iter_no_tty(socket): +def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]: """ Returns a generator of data read from the socket when the tty setting is not enabled. @@ -135,20 +141,34 @@ def frames_iter_no_tty(socket): yield (stream, result) -def frames_iter_tty(socket): +def frames_iter_tty(socket) -> t.Generator[bytes]: """ Return a generator of data read from the socket when the tty setting is enabled. """ while True: result = read(socket) - if len(result) == 0: + if not result: # We have reached EOF return yield result -def consume_socket_output(frames, demux=False): +@t.overload +def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ... + + +@t.overload +def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ... + + +@t.overload +def consume_socket_output( + frames, demux: bool = False +) -> bytes | tuple[bytes, bytes]: ... + + +def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]: """ Iterate through frames read from the socket and return the result. 
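The `t.Literal` overloads for `consume_socket_output()` above give callers a precise return type depending on the `demux` flag, instead of the raw union. A rough sketch of the effect (the non-demuxed branch joining plain byte chunks follows the upstream docker-py behavior that this vendored copy mirrors):

```python
from ansible_collections.community.docker.plugins.module_utils._api.utils.socket import (
    consume_socket_output,
    demux_adaptor,
)

# demux=False (the default): frames are plain byte chunks, joined into one blob.
combined = consume_socket_output([b"hello ", b"world"])
# mypy infers bytes; combined == b"hello world"

# demux=True: frames are (stdout, stderr) tuples as produced by demux_adaptor().
frames = [demux_adaptor(1, b"out"), demux_adaptor(2, b"err")]
out, err = consume_socket_output(frames, demux=True)
# mypy infers tuple[bytes, bytes]; out == b"out", err == b"err"
```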
@@ -167,7 +187,7 @@ def consume_socket_output(frames, demux=False): # If the streams are demultiplexed, the generator yields tuples # (stdout, stderr) - out = [None, None] + out: list[bytes | None] = [None, None] for frame in frames: # It is guaranteed that for each frame, one and only one stream # is not None. @@ -183,10 +203,10 @@ def consume_socket_output(frames, demux=False): out[1] = frame[1] else: out[1] += frame[1] - return tuple(out) + return tuple(out) # type: ignore -def demux_adaptor(stream_id, data): +def demux_adaptor(stream_id: int, data: bytes) -> tuple[bytes | None, bytes | None]: """ Utility to demultiplex stdout and stderr when reading frames from the socket. diff --git a/plugins/module_utils/_api/utils/utils.py b/plugins/module_utils/_api/utils/utils.py index 31c9b39f..136b0f23 100644 --- a/plugins/module_utils/_api/utils/utils.py +++ b/plugins/module_utils/_api/utils/utils.py @@ -18,6 +18,7 @@ import os import os.path import shlex import string +import typing as t from urllib.parse import urlparse, urlunparse from ansible_collections.community.docker.plugins.module_utils._version import ( @@ -34,32 +35,23 @@ from ..constants import ( from ..tls import TLSConfig +if t.TYPE_CHECKING: + import ssl + from collections.abc import Mapping, Sequence + + URLComponents = collections.namedtuple( "URLComponents", "scheme netloc url params query fragment", ) -def create_ipam_pool(*args, **kwargs): - raise errors.DeprecatedMethod( - "utils.create_ipam_pool has been removed. Please use a " - "docker.types.IPAMPool object instead." - ) - - -def create_ipam_config(*args, **kwargs): - raise errors.DeprecatedMethod( - "utils.create_ipam_config has been removed. Please use a " - "docker.types.IPAMConfig object instead." - ) - - -def decode_json_header(header): +def decode_json_header(header: str) -> dict[str, t.Any]: data = base64.b64decode(header).decode("utf-8") return json.loads(data) -def compare_version(v1, v2): +def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]: """Compare docker versions >>> v1 = '1.9' @@ -80,43 +72,64 @@ def compare_version(v1, v2): return 1 -def version_lt(v1, v2): +def version_lt(v1: str, v2: str) -> bool: return compare_version(v1, v2) > 0 -def version_gte(v1, v2): +def version_gte(v1: str, v2: str) -> bool: return not version_lt(v1, v2) -def _convert_port_binding(binding): +def _convert_port_binding( + binding: ( + tuple[str, str | int | None] + | tuple[str | int | None] + | dict[str, str] + | str + | int + ), +) -> dict[str, str]: result = {"HostIp": "", "HostPort": ""} + host_port: str | int | None = "" if isinstance(binding, tuple): if len(binding) == 2: - result["HostPort"] = binding[1] + host_port = binding[1] # type: ignore result["HostIp"] = binding[0] elif isinstance(binding[0], str): result["HostIp"] = binding[0] else: - result["HostPort"] = binding[0] + host_port = binding[0] elif isinstance(binding, dict): if "HostPort" in binding: - result["HostPort"] = binding["HostPort"] + host_port = binding["HostPort"] if "HostIp" in binding: result["HostIp"] = binding["HostIp"] else: raise ValueError(binding) else: - result["HostPort"] = binding - - if result["HostPort"] is None: - result["HostPort"] = "" - else: - result["HostPort"] = str(result["HostPort"]) + host_port = binding + result["HostPort"] = str(host_port) if host_port is not None else "" return result -def convert_port_bindings(port_bindings): +def convert_port_bindings( + port_bindings: dict[ + str | int, + tuple[str, str | int | None] + | tuple[str | int | None] + | dict[str, str] + | 
str + | int + | list[ + tuple[str, str | int | None] + | tuple[str | int | None] + | dict[str, str] + | str + | int + ], + ], +) -> dict[str, list[dict[str, str]]]: result = {} for k, v in port_bindings.items(): key = str(k) @@ -129,9 +142,11 @@ def convert_port_bindings(port_bindings): return result -def convert_volume_binds(binds): +def convert_volume_binds( + binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int], +) -> list[str]: if isinstance(binds, list): - return binds + return binds # type: ignore result = [] for k, v in binds.items(): @@ -149,7 +164,7 @@ def convert_volume_binds(binds): if "ro" in v: mode = "ro" if v["ro"] else "rw" elif "mode" in v: - mode = v["mode"] + mode = v["mode"] # type: ignore # TODO else: mode = "rw" @@ -165,9 +180,9 @@ def convert_volume_binds(binds): ] if "propagation" in v and v["propagation"] in propagation_modes: if mode: - mode = ",".join([mode, v["propagation"]]) + mode = ",".join([mode, v["propagation"]]) # type: ignore # TODO else: - mode = v["propagation"] + mode = v["propagation"] # type: ignore # TODO result.append(f"{k}:{bind}:{mode}") else: @@ -177,7 +192,7 @@ def convert_volume_binds(binds): return result -def convert_tmpfs_mounts(tmpfs): +def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]: if isinstance(tmpfs, dict): return tmpfs @@ -204,9 +219,11 @@ def convert_tmpfs_mounts(tmpfs): return result -def convert_service_networks(networks): +def convert_service_networks( + networks: list[str | dict[str, str]], +) -> list[dict[str, str]]: if not networks: - return networks + return networks # type: ignore if not isinstance(networks, list): raise TypeError("networks parameter must be a list.") @@ -218,17 +235,17 @@ def convert_service_networks(networks): return result -def parse_repository_tag(repo_name): +def parse_repository_tag(repo_name: str) -> tuple[str, str | None]: parts = repo_name.rsplit("@", 1) if len(parts) == 2: - return tuple(parts) + return tuple(parts) # type: ignore parts = repo_name.rsplit(":", 1) if len(parts) == 2 and "/" not in parts[1]: - return tuple(parts) + return tuple(parts) # type: ignore return repo_name, None -def parse_host(addr, is_win32=False, tls=False): +def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str: # Sensible defaults if not addr and is_win32: return DEFAULT_NPIPE @@ -308,7 +325,7 @@ def parse_host(addr, is_win32=False, tls=False): ).rstrip("/") -def parse_devices(devices): +def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]: device_list = [] for device in devices: if isinstance(device, dict): @@ -337,7 +354,10 @@ def parse_devices(devices): return device_list -def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): +def kwargs_from_env( + assert_hostname: bool | None = None, + environment: Mapping[str, str] | None = None, +) -> dict[str, t.Any]: if not environment: environment = os.environ host = environment.get("DOCKER_HOST") @@ -347,14 +367,14 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): # empty string for tls verify counts as "false". # Any value or 'unset' counts as true. 
- tls_verify = environment.get("DOCKER_TLS_VERIFY") - if tls_verify == "": + tls_verify_str = environment.get("DOCKER_TLS_VERIFY") + if tls_verify_str == "": tls_verify = False else: - tls_verify = tls_verify is not None + tls_verify = tls_verify_str is not None enable_tls = cert_path or tls_verify - params = {} + params: dict[str, t.Any] = {} if host: params["base_url"] = host @@ -377,14 +397,13 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): ), ca_cert=os.path.join(cert_path, "ca.pem"), verify=tls_verify, - ssl_version=ssl_version, assert_hostname=assert_hostname, ) return params -def convert_filters(filters): +def convert_filters(filters: Mapping[str, bool | str | list[str]]) -> str: result = {} for k, v in filters.items(): if isinstance(v, bool): @@ -397,7 +416,7 @@ def convert_filters(filters): return json.dumps(result) -def parse_bytes(s): +def parse_bytes(s: int | float | str) -> int | float: if isinstance(s, (int, float)): return s if len(s) == 0: @@ -435,14 +454,16 @@ def parse_bytes(s): return s -def normalize_links(links): +def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]: if isinstance(links, dict): - links = links.items() + sorted_links = sorted(links.items()) + else: + sorted_links = sorted(links) - return [f"{k}:{v}" if v else k for k, v in sorted(links)] + return [f"{k}:{v}" if v else k for k, v in sorted_links] -def parse_env_file(env_file): +def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]: """ Reads a line-separated environment file. The format of each line should be "key=value". @@ -451,7 +472,6 @@ def parse_env_file(env_file): with open(env_file, "rt", encoding="utf-8") as f: for line in f: - if line[0] == "#": continue @@ -471,11 +491,11 @@ def parse_env_file(env_file): return environment -def split_command(command): +def split_command(command: str) -> list[str]: return shlex.split(command) -def format_environment(environment): +def format_environment(environment: Mapping[str, str | bytes]) -> list[str]: def format_env(key, value): if value is None: return key @@ -487,16 +507,9 @@ def format_environment(environment): return [format_env(*var) for var in environment.items()] -def format_extra_hosts(extra_hosts, task=False): +def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]: # Use format dictated by Swarm API if container is part of a task if task: return [f"{v} {k}" for k, v in sorted(extra_hosts.items())] return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())] - - -def create_host_config(self, *args, **kwargs): - raise errors.DeprecatedMethod( - "utils.create_host_config has been removed. Please use a " - "docker.types.HostConfig object instead." 
- ) diff --git a/plugins/module_utils/_common.py b/plugins/module_utils/_common.py index 80fb4ab7..04588837 100644 --- a/plugins/module_utils/_common.py +++ b/plugins/module_utils/_common.py @@ -13,6 +13,7 @@ import platform import re import sys import traceback +import typing as t from collections.abc import Mapping, Sequence from ansible.module_utils.basic import AnsibleModule, missing_required_lib @@ -36,6 +37,9 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( HAS_DOCKER_PY_2 = False # pylint: disable=invalid-name HAS_DOCKER_PY_3 = False # pylint: disable=invalid-name +HAS_DOCKER_ERROR: None | str # pylint: disable=invalid-name +HAS_DOCKER_TRACEBACK: None | str # pylint: disable=invalid-name +docker_version: str | None # pylint: disable=invalid-name try: from docker import __version__ as docker_version @@ -51,12 +55,13 @@ try: HAS_DOCKER_PY_2 = True # pylint: disable=invalid-name from docker import APIClient as Client else: - from docker import Client + from docker import Client # type: ignore except ImportError as exc: HAS_DOCKER_ERROR = str(exc) # pylint: disable=invalid-name HAS_DOCKER_TRACEBACK = traceback.format_exc() # pylint: disable=invalid-name HAS_DOCKER_PY = False # pylint: disable=invalid-name + docker_version = None # pylint: disable=invalid-name else: HAS_DOCKER_PY = True # pylint: disable=invalid-name HAS_DOCKER_ERROR = None # pylint: disable=invalid-name @@ -71,30 +76,34 @@ except ImportError: # Either Docker SDK for Python is no longer using requests, or Docker SDK for Python is not around either, # or Docker SDK for Python's dependency requests is missing. In any case, define an exception # class RequestException so that our code does not break. - class RequestException(Exception): + class RequestException(Exception): # type: ignore pass +if t.TYPE_CHECKING: + from collections.abc import Callable + + MIN_DOCKER_VERSION = "2.0.0" if not HAS_DOCKER_PY: - docker_version = None # pylint: disable=invalid-name - # No Docker SDK for Python. 
Create a place holder client to allow # instantiation of AnsibleModule and proper error handing - class Client: # noqa: F811, pylint: disable=function-redefined + class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined def __init__(self, **kwargs): pass - class APIError(Exception): # noqa: F811, pylint: disable=function-redefined + class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined pass - class NotFound(Exception): # noqa: F811, pylint: disable=function-redefined + class NotFound(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined pass -def _get_tls_config(fail_function, **kwargs): +def _get_tls_config( + fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any +) -> TLSConfig: if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion( "7.0.0b1" ): @@ -109,17 +118,18 @@ def _get_tls_config(fail_function, **kwargs): # Filter out all None parameters kwargs = dict((k, v) for k, v in kwargs.items() if v is not None) try: - tls_config = TLSConfig(**kwargs) - return tls_config + return TLSConfig(**kwargs) except TLSParameterError as exc: fail_function(f"TLS config error: {exc}") -def is_using_tls(auth_data): +def is_using_tls(auth_data: dict[str, t.Any]) -> bool: return auth_data["tls_verify"] or auth_data["tls"] -def get_connect_params(auth_data, fail_function): +def get_connect_params( + auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn] +) -> dict[str, t.Any]: if is_using_tls(auth_data): auth_data["docker_host"] = auth_data["docker_host"].replace( "tcp://", "https://" @@ -171,7 +181,11 @@ DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade. class AnsibleDockerClientBase(Client): - def __init__(self, min_docker_version=None, min_docker_api_version=None): + def __init__( + self, + min_docker_version: str | None = None, + min_docker_api_version: str | None = None, + ) -> None: if min_docker_version is None: min_docker_version = MIN_DOCKER_VERSION @@ -212,23 +226,34 @@ class AnsibleDockerClientBase(Client): f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." ) - def log(self, msg, pretty_print=False): + def log(self, msg: t.Any, pretty_print: bool = False): pass # if self.debug: # from .util import log_debug # log_debug(msg, pretty_print=pretty_print) @abc.abstractmethod - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: pass - def deprecate(self, msg, version=None, date=None, collection_name=None): + @abc.abstractmethod + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: pass @staticmethod def _get_value( - param_name, param_value, env_variable, default_value, value_type="str" - ): + param_name: str, + param_value: t.Any, + env_variable: str | None, + default_value: t.Any | None, + value_type: t.Literal["str", "bool", "int"] = "str", + ) -> t.Any: if param_value is not None: # take module parameter value if value_type == "bool": @@ -265,11 +290,11 @@ class AnsibleDockerClientBase(Client): return default_value @abc.abstractmethod - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: pass @property - def auth_params(self): + def auth_params(self) -> dict[str, t.Any]: # Get authentication credentials. # Precedence: module parameters-> environment variables-> defaults. 
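Typing `fail_function` as `Callable[[str], t.NoReturn]` (and the abstract `fail()` as `-> t.NoReturn`) is what allows `_get_tls_config()` above to end in a `fail_function(...)` call without mypy demanding a return value. A standalone illustration of the narrowing this enables (all names here are made up for the example):

```python
from __future__ import annotations

import typing as t


def fail(msg: str) -> t.NoReturn:
    raise SystemExit(msg)


def require_host(host: str | None) -> str:
    if host is None:
        fail("docker_host is required")
    # fail() never returns, so mypy narrows host to str here.
    return host
```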
@@ -354,7 +379,7 @@ class AnsibleDockerClientBase(Client): return result - def _handle_ssl_error(self, error): + def _handle_ssl_error(self, error: Exception) -> t.NoReturn: match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) if match: hostname = self.auth_params["tls_hostname"] @@ -366,7 +391,7 @@ class AnsibleDockerClientBase(Client): ) self.fail(f"SSL Exception: {error}") - def get_container_by_id(self, container_id): + def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None: try: self.log(f"Inspecting container Id {container_id}") result = self.inspect_container(container=container_id) @@ -377,7 +402,7 @@ class AnsibleDockerClientBase(Client): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error inspecting container: {exc}") - def get_container(self, name=None): + def get_container(self, name: str | None) -> dict[str, t.Any] | None: """ Lookup a container and return the inspection results. """ @@ -414,7 +439,9 @@ class AnsibleDockerClientBase(Client): return self.get_container_by_id(result["Id"]) - def get_network(self, name=None, network_id=None): + def get_network( + self, name: str | None = None, network_id: str | None = None + ) -> dict[str, t.Any] | None: """ Lookup a network and return the inspection results. """ @@ -453,7 +480,7 @@ class AnsibleDockerClientBase(Client): return result - def find_image(self, name, tag): + def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: """ Lookup an image (by name and tag) and return the inspection results. """ @@ -505,7 +532,9 @@ class AnsibleDockerClientBase(Client): self.log(f"Image {name}:{tag} not found.") return None - def find_image_by_id(self, image_id, accept_missing_image=False): + def find_image_by_id( + self, image_id: str, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: """ Lookup an image (by ID) and return the inspection results. """ @@ -524,7 +553,7 @@ class AnsibleDockerClientBase(Client): self.fail(f"Error inspecting image ID {image_id} - {exc}") return inspection - def _image_lookup(self, name, tag): + def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]: """ Including a tag in the name parameter sent to the Docker SDK for Python images method does not work consistently. Instead, get the result set for name and manually check @@ -547,7 +576,9 @@ class AnsibleDockerClientBase(Client): break return images - def pull_image(self, name, tag="latest", image_platform=None): + def pull_image( + self, name: str, tag: str = "latest", image_platform: str | None = None + ) -> tuple[dict[str, t.Any] | None, bool]: """ Pull an image """ @@ -578,7 +609,7 @@ class AnsibleDockerClientBase(Client): return new_tag, old_tag == new_tag - def inspect_distribution(self, image, **kwargs): + def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]: """ Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 since prior versions did not support accessing private repositories. 
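With lookups such as `get_container()` and `find_image()` now annotated to return `dict[str, t.Any] | None`, mypy forces every caller to deal with the miss case before indexing into the result. A hypothetical helper showing the intended calling pattern (the function itself is not part of this patch):

```python
from ansible_collections.community.docker.plugins.module_utils._common import (
    AnsibleDockerClient,
)


def get_image_id(client: AnsibleDockerClient, name: str, tag: str) -> str:
    image = client.find_image(name, tag)
    if image is None:
        client.fail(f"Image {name}:{tag} not found")  # typed t.NoReturn
    # After the NoReturn branch, image is narrowed to dict[str, t.Any].
    return image["Id"]
```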
@@ -592,7 +623,7 @@ class AnsibleDockerClientBase(Client): self._url("/distribution/{0}/json", image), headers={"X-Registry-Auth": header}, ), - get_json=True, + json=True, ) return super().inspect_distribution(image, **kwargs) @@ -601,18 +632,24 @@ class AnsibleDockerClient(AnsibleDockerClientBase): def __init__( self, - argument_spec=None, - supports_check_mode=False, - mutually_exclusive=None, - required_together=None, - required_if=None, - required_one_of=None, - required_by=None, - min_docker_version=None, - min_docker_api_version=None, - option_minimal_versions=None, - option_minimal_versions_ignore_params=None, - fail_results=None, + argument_spec: dict[str, t.Any] | None = None, + supports_check_mode: bool = False, + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_by: dict[str, Sequence[str]] | None = None, + min_docker_version: str | None = None, + min_docker_api_version: str | None = None, + option_minimal_versions: dict[str, t.Any] | None = None, + option_minimal_versions_ignore_params: Sequence[str] | None = None, + fail_results: dict[str, t.Any] | None = None, ): # Modules can put information in here which will always be returned @@ -625,12 +662,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase): merged_arg_spec.update(argument_spec) self.arg_spec = merged_arg_spec - mutually_exclusive_params = [] + mutually_exclusive_params: list[Sequence[str]] = [] mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE if mutually_exclusive: mutually_exclusive_params += mutually_exclusive - required_together_params = [] + required_together_params: list[Sequence[str]] = [] required_together_params += DOCKER_REQUIRED_TOGETHER if required_together: required_together_params += required_together @@ -658,20 +695,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase): option_minimal_versions, option_minimal_versions_ignore_params ) - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: self.fail_results.update(kwargs) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.module.deprecate( msg, version=version, date=date, collection_name=collection_name ) - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: return self.module.params - def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): - self.option_minimal_versions = {} + def _get_minimal_versions( + self, + option_minimal_versions: dict[str, t.Any], + ignore_params: Sequence[str] | None = None, + ) -> None: + self.option_minimal_versions: dict[str, dict[str, t.Any]] = {} for option in self.module.argument_spec: if ignore_params is not None: if option in ignore_params: @@ -722,7 +769,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase): msg = f"Cannot {usg} with your configuration." self.fail(msg) - def report_warnings(self, result, warnings_key=None): + def report_warnings( + self, result: t.Any, warnings_key: Sequence[str] | None = None + ) -> None: """ Checks result of client operation for warnings, and if present, outputs them. 
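The fully annotated `AnsibleDockerClient.__init__()` means that option shapes like `mutually_exclusive` and `required_if` are now validated by mypy in every module that constructs a client. A sketch of such a call under the new annotations (option names and versions are placeholders; in practice this runs inside an Ansible module, since the client builds an `AnsibleModule` from the spec):

```python
from ansible_collections.community.docker.plugins.module_utils._common import (
    AnsibleDockerClient,
)

client = AnsibleDockerClient(
    argument_spec={
        "name": {"type": "str", "required": True},
    },
    supports_check_mode=True,
    mutually_exclusive=[("name", "image_id")],     # Sequence[Sequence[str]]
    required_if=[("state", "present", ["name"])],  # 3-tuples, or 4-tuples with a bool
    min_docker_version="5.0.0",
)
```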
diff --git a/plugins/module_utils/_common_api.py b/plugins/module_utils/_common_api.py index 067f1c84..7617d157 100644 --- a/plugins/module_utils/_common_api.py +++ b/plugins/module_utils/_common_api.py @@ -11,6 +11,7 @@ from __future__ import annotations import abc import os import re +import typing as t from collections.abc import Mapping, Sequence from ansible.module_utils.basic import AnsibleModule, missing_required_lib @@ -28,7 +29,7 @@ try: ) except ImportError: # Define an exception class RequestException so that our code does not break. - class RequestException(Exception): + class RequestException(Exception): # type: ignore pass @@ -60,19 +61,26 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) -def _get_tls_config(fail_function, **kwargs): +if t.TYPE_CHECKING: + from collections.abc import Callable + + +def _get_tls_config( + fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any +) -> TLSConfig: try: - tls_config = TLSConfig(**kwargs) - return tls_config + return TLSConfig(**kwargs) except TLSParameterError as exc: fail_function(f"TLS config error: {exc}") -def is_using_tls(auth_data): +def is_using_tls(auth_data: dict[str, t.Any]) -> bool: return auth_data["tls_verify"] or auth_data["tls"] -def get_connect_params(auth_data, fail_function): +def get_connect_params( + auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn] +) -> dict[str, t.Any]: if is_using_tls(auth_data): auth_data["docker_host"] = auth_data["docker_host"].replace( "tcp://", "https://" @@ -114,7 +122,7 @@ def get_connect_params(auth_data, fail_function): class AnsibleDockerClientBase(Client): - def __init__(self, min_docker_api_version=None): + def __init__(self, min_docker_api_version: str | None = None) -> None: self._connect_params = get_connect_params( self.auth_params, fail_function=self.fail ) @@ -138,23 +146,34 @@ class AnsibleDockerClientBase(Client): f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." ) - def log(self, msg, pretty_print=False): + def log(self, msg: t.Any, pretty_print: bool = False): pass # if self.debug: # from .util import log_debug # log_debug(msg, pretty_print=pretty_print) @abc.abstractmethod - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: pass - def deprecate(self, msg, version=None, date=None, collection_name=None): + @abc.abstractmethod + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: pass @staticmethod def _get_value( - param_name, param_value, env_variable, default_value, value_type="str" - ): + param_name: str, + param_value: t.Any, + env_variable: str | None, + default_value: t.Any | None, + value_type: t.Literal["str", "bool", "int"] = "str", + ) -> t.Any: if param_value is not None: # take module parameter value if value_type == "bool": @@ -191,11 +210,11 @@ class AnsibleDockerClientBase(Client): return default_value @abc.abstractmethod - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: pass @property - def auth_params(self): + def auth_params(self) -> dict[str, t.Any]: # Get authentication credentials. # Precedence: module parameters-> environment variables-> defaults. 
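Like `_common.py`, this file defers the `Callable` import behind `if t.TYPE_CHECKING:`. Together with `from __future__ import annotations` (which turns annotations into strings per PEP 563), such imports exist only for the type checker and cost nothing at runtime. A minimal self-contained illustration:

```python
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    # Evaluated by mypy only; never imported at runtime.
    from collections.abc import Callable


def apply(fn: Callable[[int], int], value: int) -> int:
    # The annotation is a plain string at runtime, so the missing
    # runtime import of Callable is harmless.
    return fn(value)


print(apply(lambda x: x + 1, 41))  # prints 42
```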
@@ -288,7 +307,7 @@ class AnsibleDockerClientBase(Client): return result - def _handle_ssl_error(self, error): + def _handle_ssl_error(self, error: Exception) -> t.NoReturn: match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) if match: hostname = self.auth_params["tls_hostname"] @@ -300,7 +319,7 @@ class AnsibleDockerClientBase(Client): ) self.fail(f"SSL Exception: {error}") - def get_container_by_id(self, container_id): + def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None: try: self.log(f"Inspecting container Id {container_id}") result = self.get_json("/containers/{0}/json", container_id) @@ -311,7 +330,7 @@ class AnsibleDockerClientBase(Client): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error inspecting container: {exc}") - def get_container(self, name=None): + def get_container(self, name: str | None) -> dict[str, t.Any] | None: """ Lookup a container and return the inspection results. """ @@ -355,7 +374,9 @@ class AnsibleDockerClientBase(Client): return self.get_container_by_id(result["Id"]) - def get_network(self, name=None, network_id=None): + def get_network( + self, name: str | None = None, network_id: str | None = None + ) -> dict[str, t.Any] | None: """ Lookup a network and return the inspection results. """ @@ -395,14 +416,14 @@ class AnsibleDockerClientBase(Client): return result - def _image_lookup(self, name, tag): + def _image_lookup(self, name: str, tag: str | None) -> list[dict[str, t.Any]]: """ Including a tag in the name parameter sent to the Docker SDK for Python images method does not work consistently. Instead, get the result set for name and manually check if the tag exists. """ try: - params = { + params: dict[str, t.Any] = { "only_ids": 0, "all": 0, } @@ -427,7 +448,7 @@ class AnsibleDockerClientBase(Client): break return images - def find_image(self, name, tag): + def find_image(self, name: str, tag: str | None) -> dict[str, t.Any] | None: """ Lookup an image (by name and tag) and return the inspection results. """ @@ -478,7 +499,9 @@ class AnsibleDockerClientBase(Client): self.log(f"Image {name}:{tag} not found.") return None - def find_image_by_id(self, image_id, accept_missing_image=False): + def find_image_by_id( + self, image_id: str, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: """ Lookup an image (by ID) and return the inspection results. 
""" @@ -496,7 +519,9 @@ class AnsibleDockerClientBase(Client): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error inspecting image ID {image_id} - {exc}") - def pull_image(self, name, tag="latest", image_platform=None): + def pull_image( + self, name: str, tag: str = "latest", image_platform: str | None = None + ) -> tuple[dict[str, t.Any] | None, bool]: """ Pull an image """ @@ -544,22 +569,26 @@ class AnsibleDockerClientBase(Client): class AnsibleDockerClient(AnsibleDockerClientBase): - def __init__( self, - argument_spec=None, - supports_check_mode=False, - mutually_exclusive=None, - required_together=None, - required_if=None, - required_one_of=None, - required_by=None, - min_docker_api_version=None, - option_minimal_versions=None, - option_minimal_versions_ignore_params=None, - fail_results=None, + argument_spec: dict[str, t.Any] | None = None, + supports_check_mode: bool = False, + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_by: dict[str, Sequence[str]] | None = None, + min_docker_api_version: str | None = None, + option_minimal_versions: dict[str, t.Any] | None = None, + option_minimal_versions_ignore_params: Sequence[str] | None = None, + fail_results: dict[str, t.Any] | None = None, ): - # Modules can put information in here which will always be returned # in case client.fail() is called. self.fail_results = fail_results or {} @@ -570,12 +599,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase): merged_arg_spec.update(argument_spec) self.arg_spec = merged_arg_spec - mutually_exclusive_params = [] + mutually_exclusive_params: list[Sequence[str]] = [] mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE if mutually_exclusive: mutually_exclusive_params += mutually_exclusive - required_together_params = [] + required_together_params: list[Sequence[str]] = [] required_together_params += DOCKER_REQUIRED_TOGETHER if required_together: required_together_params += required_together @@ -600,20 +629,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase): option_minimal_versions, option_minimal_versions_ignore_params ) - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: self.fail_results.update(kwargs) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.module.deprecate( msg, version=version, date=date, collection_name=collection_name ) - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: return self.module.params - def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): - self.option_minimal_versions = {} + def _get_minimal_versions( + self, + option_minimal_versions: dict[str, t.Any], + ignore_params: Sequence[str] | None = None, + ) -> None: + self.option_minimal_versions: dict[str, dict[str, t.Any]] = {} for option in self.module.argument_spec: if ignore_params is not None: if option in ignore_params: @@ -654,7 +693,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase): msg = f"Cannot {usg} with your configuration." 
self.fail(msg) - def report_warnings(self, result, warnings_key=None): + def report_warnings( + self, result: t.Any, warnings_key: Sequence[str] | None = None + ) -> None: """ Checks result of client operation for warnings, and if present, outputs them. diff --git a/plugins/module_utils/_common_cli.py b/plugins/module_utils/_common_cli.py index fb9a316b..166d0e41 100644 --- a/plugins/module_utils/_common_cli.py +++ b/plugins/module_utils/_common_cli.py @@ -10,6 +10,7 @@ from __future__ import annotations import abc import json import shlex +import typing as t from ansible.module_utils.basic import AnsibleModule, env_fallback from ansible.module_utils.common.process import get_bin_path @@ -31,6 +32,10 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( ) +if t.TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + DOCKER_COMMON_ARGS = { "docker_cli": {"type": "path"}, "docker_host": { @@ -72,10 +77,16 @@ class DockerException(Exception): class AnsibleDockerClientBase: + docker_api_version_str: str | None + docker_api_version: LooseVersion | None + def __init__( - self, common_args, min_docker_api_version=None, needs_api_version=True - ): - self._environment = {} + self, + common_args, + min_docker_api_version: str | None = None, + needs_api_version: bool = True, + ) -> None: + self._environment: dict[str, str] = {} if common_args["tls_hostname"]: self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"] if common_args["api_version"] and common_args["api_version"] != "auto": @@ -109,10 +120,10 @@ class AnsibleDockerClientBase: self._cli_base.extend(["--context", common_args["cli_context"]]) # `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0 - dummy, self._version, dummy = self.call_cli_json( + dummy, self._version, dummy2 = self.call_cli_json( "version", "--format", "{{ json . 
}}", check_rc=True ) - self._info = None + self._info: dict[str, t.Any] | None = None if needs_api_version: if not isinstance(self._version.get("Server"), dict) or not isinstance( @@ -138,32 +149,47 @@ class AnsibleDockerClientBase: "Internal error: cannot have needs_api_version=False with min_docker_api_version not None" ) - def log(self, msg, pretty_print=False): + def log(self, msg: str, pretty_print: bool = False): pass # if self.debug: # from .util import log_debug # log_debug(msg, pretty_print=pretty_print) - def get_cli(self): + def get_cli(self) -> str: return self._cli - def get_version_info(self): + def get_version_info(self) -> str: return self._version - def _compose_cmd(self, args): + def _compose_cmd(self, args: t.Sequence[str]) -> list[str]: return self._cli_base + list(args) - def _compose_cmd_str(self, args): + def _compose_cmd_str(self, args: t.Sequence[str]) -> str: return " ".join(shlex.quote(a) for a in self._compose_cmd(args)) @abc.abstractmethod - def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): + def call_cli( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + ) -> tuple[int, bytes, bytes]: pass - # def call_cli_json(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): - def call_cli_json(self, *args, **kwargs): - warn_on_stderr = kwargs.pop("warn_on_stderr", False) - rc, stdout, stderr = self.call_cli(*args, **kwargs) + def call_cli_json( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + warn_on_stderr: bool = False, + ) -> tuple[int, t.Any, bytes]: + rc, stdout, stderr = self.call_cli( + *args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update + ) if warn_on_stderr and stderr: self.warn(to_native(stderr)) try: @@ -174,10 +200,18 @@ class AnsibleDockerClientBase: ) return rc, data, stderr - # def call_cli_json_stream(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): - def call_cli_json_stream(self, *args, **kwargs): - warn_on_stderr = kwargs.pop("warn_on_stderr", False) - rc, stdout, stderr = self.call_cli(*args, **kwargs) + def call_cli_json_stream( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + warn_on_stderr: bool = False, + ) -> tuple[int, list[t.Any], bytes]: + rc, stdout, stderr = self.call_cli( + *args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update + ) if warn_on_stderr and stderr: self.warn(to_native(stderr)) result = [] @@ -193,25 +227,31 @@ class AnsibleDockerClientBase: return rc, result, stderr @abc.abstractmethod - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs) -> t.NoReturn: pass @abc.abstractmethod - def warn(self, msg): + def warn(self, msg: str) -> None: pass @abc.abstractmethod - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: pass - def get_cli_info(self): + def get_cli_info(self) -> dict[str, t.Any]: if self._info is None: - dummy, self._info, dummy = self.call_cli_json( + dummy, self._info, dummy2 = self.call_cli_json( "info", "--format", "{{ json . 
}}", check_rc=True ) return self._info - def get_client_plugin_info(self, component): + def get_client_plugin_info(self, component: str) -> dict[str, t.Any] | None: cli_info = self.get_cli_info() if not isinstance(cli_info.get("ClientInfo"), dict): self.fail( @@ -222,13 +262,13 @@ class AnsibleDockerClientBase: return plugin return None - def _image_lookup(self, name, tag): + def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]: """ Including a tag in the name parameter sent to the Docker SDK for Python images method does not work consistently. Instead, get the result set for name and manually check if the tag exists. """ - dummy, images, dummy = self.call_cli_json_stream( + dummy, images, dummy2 = self.call_cli_json_stream( "image", "ls", "--format", @@ -247,7 +287,13 @@ class AnsibleDockerClientBase: break return images - def find_image(self, name, tag): + @t.overload + def find_image(self, name: None, tag: str) -> None: ... + + @t.overload + def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: ... + + def find_image(self, name: str | None, tag: str) -> dict[str, t.Any] | None: """ Lookup an image (by name and tag) and return the inspection results. """ @@ -298,7 +344,19 @@ class AnsibleDockerClientBase: self.log(f"Image {name}:{tag} not found.") return None - def find_image_by_id(self, image_id, accept_missing_image=False): + @t.overload + def find_image_by_id( + self, image_id: None, accept_missing_image: bool = False + ) -> None: ... + + @t.overload + def find_image_by_id( + self, image_id: str | None, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: ... + + def find_image_by_id( + self, image_id: str | None, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: """ Lookup an image (by ID) and return the inspection results. """ @@ -320,17 +378,23 @@ class AnsibleDockerClientBase: class AnsibleModuleDockerClient(AnsibleDockerClientBase): def __init__( self, - argument_spec=None, - supports_check_mode=False, - mutually_exclusive=None, - required_together=None, - required_if=None, - required_one_of=None, - required_by=None, - min_docker_api_version=None, - fail_results=None, - needs_api_version=True, - ): + argument_spec: dict[str, t.Any] | None = None, + supports_check_mode: bool = False, + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_by: Mapping[str, Sequence[str]] | None = None, + min_docker_api_version: str | None = None, + fail_results: dict[str, t.Any] | None = None, + needs_api_version: bool = True, + ) -> None: # Modules can put information in here which will always be returned # in case client.fail() is called. 
@@ -342,12 +406,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): merged_arg_spec.update(argument_spec) self.arg_spec = merged_arg_spec - mutually_exclusive_params = [("docker_host", "cli_context")] + mutually_exclusive_params: list[Sequence[str]] = [ + ("docker_host", "cli_context") + ] mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE if mutually_exclusive: mutually_exclusive_params += mutually_exclusive - required_together_params = [] + required_together_params: list[Sequence[str]] = [] required_together_params += DOCKER_REQUIRED_TOGETHER if required_together: required_together_params += required_together @@ -373,7 +439,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): needs_api_version=needs_api_version, ) - def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): + def call_cli( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + ) -> tuple[int, bytes, bytes]: environment = self._environment.copy() if environ_update: environment.update(environ_update) @@ -390,14 +463,20 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): ) return rc, stdout, stderr - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs) -> t.NoReturn: self.fail_results.update(kwargs) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) - def warn(self, msg): + def warn(self, msg: str) -> None: self.module.warn(msg) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.module.deprecate( msg, version=version, date=date, collection_name=collection_name ) diff --git a/plugins/module_utils/_compose_v2.py b/plugins/module_utils/_compose_v2.py index 6a37ec28..253e9db9 100644 --- a/plugins/module_utils/_compose_v2.py +++ b/plugins/module_utils/_compose_v2.py @@ -14,6 +14,7 @@ import re import shutil import tempfile import traceback +import typing as t from collections import namedtuple from shlex import quote @@ -34,6 +35,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( ) +PYYAML_IMPORT_ERROR: None | str # pylint: disable=invalid-name try: import yaml @@ -41,7 +43,7 @@ try: # use C version if possible for speedup from yaml import CSafeDumper as _SafeDumper except ImportError: - from yaml import SafeDumper as _SafeDumper + from yaml import SafeDumper as _SafeDumper # type: ignore except ImportError: HAS_PYYAML = False PYYAML_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name @@ -49,6 +51,13 @@ else: HAS_PYYAML = True PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name +if t.TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ansible_collections.community.docker.plugins.module_utils._common_cli import ( + AnsibleModuleDockerClient as _Client, + ) + DOCKER_COMPOSE_FILES = ( "compose.yaml", @@ -144,8 +153,7 @@ class ResourceType: SERVICE = "service" @classmethod - def from_docker_compose_event(cls, resource_type): - # type: (Type[ResourceType], Text) -> Any + def from_docker_compose_event(cls, resource_type: str) -> t.Any: return { "Network": cls.NETWORK, "Image": cls.IMAGE, @@ -240,7 +248,9 @@ _RE_BUILD_PROGRESS_EVENT = re.compile(r"^\s*==>\s+(?P.*)$") MINIMUM_COMPOSE_VERSION = "2.18.0" -def _extract_event(line, warn_function=None): +def _extract_event( + line: str, warn_function: 
Callable[[str], None] | None = None +) -> tuple[Event | None, bool]: match = _RE_RESOURCE_EVENT.match(line) if match is not None: status = match.group("status") @@ -323,7 +333,9 @@ def _extract_event(line, warn_function=None): return None, False -def _extract_logfmt_event(line, warn_function=None): +def _extract_logfmt_event( + line: str, warn_function: Callable[[str], None] | None = None +) -> tuple[Event | None, bool]: try: result = _parse_logfmt_line(line, logrus_mode=True) except _InvalidLogFmt: @@ -338,7 +350,11 @@ def _extract_logfmt_event(line, warn_function=None): return None, False -def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_function): +def _warn_missing_dry_run_prefix( + line: str, + warn_missing_dry_run_prefix: bool, + warn_function: Callable[[str], None] | None, +) -> None: if warn_missing_dry_run_prefix and warn_function: # This could be a bug, a change of docker compose's output format, ... # Tell the user to report it to us :-) @@ -349,7 +365,9 @@ def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_functio ) -def _warn_unparsable_line(line, warn_function): +def _warn_unparsable_line( + line: str, warn_function: Callable[[str], None] | None +) -> None: # This could be a bug, a change of docker compose's output format, ... # Tell the user to report it to us :-) if warn_function: @@ -360,14 +378,16 @@ def _warn_unparsable_line(line, warn_function): ) -def _find_last_event_for(events, resource_id): +def _find_last_event_for( + events: list[Event], resource_id: str +) -> tuple[int, Event] | None: for index, event in enumerate(reversed(events)): if event.resource_id == resource_id: return len(events) - 1 - index, event return None -def _concat_event_msg(event, append_msg): +def _concat_event_msg(event: Event, append_msg: str) -> Event: return Event( event.resource_type, event.resource_id, @@ -382,7 +402,9 @@ _JSON_LEVEL_TO_STATUS_MAP = { } -def parse_json_events(stderr, warn_function=None): +def parse_json_events( + stderr: bytes, warn_function: Callable[[str], None] | None = None +) -> list[Event]: events = [] stderr_lines = stderr.splitlines() if stderr_lines and stderr_lines[-1] == b"": @@ -523,7 +545,12 @@ def parse_json_events(stderr, warn_function=None): return events -def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False): +def parse_events( + stderr: bytes, + dry_run: bool = False, + warn_function: Callable[[str], None] | None = None, + nonzero_rc: bool = False, +) -> list[Event]: events = [] error_event = None stderr_lines = stderr.splitlines() @@ -597,7 +624,11 @@ def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False): return events -def has_changes(events, ignore_service_pull_events=False, ignore_build_events=False): +def has_changes( + events: Sequence[Event], + ignore_service_pull_events: bool = False, + ignore_build_events: bool = False, +) -> bool: for event in events: if event.status in DOCKER_STATUS_WORKING: if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL: @@ -613,7 +644,7 @@ def has_changes(events, ignore_service_pull_events=False, ignore_build_events=Fa return False -def extract_actions(events): +def extract_actions(events: Sequence[Event]) -> list[dict[str, t.Any]]: actions = [] pull_actions = set() for event in events: @@ -645,7 +676,9 @@ def extract_actions(events): return actions -def emit_warnings(events, warn_function): +def emit_warnings( + events: Sequence[Event], warn_function: Callable[[str], None] +) -> None: for event in events: # If 
a message is present, assume it is a warning if ( @@ -656,13 +689,21 @@ def emit_warnings(events, warn_function): ) -def is_failed(events, rc): +def is_failed(events: Sequence[Event], rc: int) -> bool: if rc: return True return False -def update_failed(result, events, args, stdout, stderr, rc, cli): +def update_failed( + result: dict[str, t.Any], + events: Sequence[Event], + args: list[str], + stdout: str | bytes, + stderr: str | bytes, + rc: int, + cli: str, +) -> bool: if not rc: return False errors = [] @@ -696,7 +737,7 @@ def update_failed(result, events, args, stdout, stderr, rc, cli): return True -def common_compose_argspec(): +def common_compose_argspec() -> dict[str, t.Any]: return { "project_src": {"type": "path"}, "project_name": {"type": "str"}, @@ -708,7 +749,7 @@ def common_compose_argspec(): } -def common_compose_argspec_ex(): +def common_compose_argspec_ex() -> dict[str, t.Any]: return { "argspec": common_compose_argspec(), "mutually_exclusive": [("definition", "project_src"), ("definition", "files")], @@ -721,16 +762,18 @@ def common_compose_argspec_ex(): } -def combine_binary_output(*outputs): +def combine_binary_output(*outputs: bytes | None) -> bytes: return b"\n".join(out for out in outputs if out) -def combine_text_output(*outputs): +def combine_text_output(*outputs: str | None) -> str: return "\n".join(out for out in outputs if out) class BaseComposeManager(DockerBaseClass): - def __init__(self, client, min_version=MINIMUM_COMPOSE_VERSION): + def __init__( + self, client: _Client, min_version: str = MINIMUM_COMPOSE_VERSION + ) -> None: super().__init__() self.client = client self.check_mode = self.client.check_mode @@ -794,12 +837,12 @@ class BaseComposeManager(DockerBaseClass): # more precisely in https://github.com/docker/compose/pull/11478 self.use_json_events = self.compose_version >= LooseVersion("2.29.0") - def get_compose_version(self): + def get_compose_version(self) -> str: return ( self.get_compose_version_from_cli() or self.get_compose_version_from_api() ) - def get_compose_version_from_cli(self): + def get_compose_version_from_cli(self) -> str | None: rc, version_info, dummy_stderr = self.client.call_cli( "compose", "version", "--format", "json" ) @@ -813,7 +856,7 @@ class BaseComposeManager(DockerBaseClass): except Exception: # pylint: disable=broad-exception-caught return None - def get_compose_version_from_api(self): + def get_compose_version_from_api(self) -> str: compose = self.client.get_client_plugin_info("compose") if compose is None: self.fail( @@ -826,11 +869,11 @@ class BaseComposeManager(DockerBaseClass): ) return compose["Version"].lstrip("v") - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: self.cleanup() self.client.fail(msg, **kwargs) - def get_base_args(self, plain_progress=False): + def get_base_args(self, plain_progress: bool = False) -> list[str]: args = ["compose", "--ansi", "never"] if self.use_json_events and not plain_progress: args.extend(["--progress", "json"]) @@ -848,28 +891,33 @@ class BaseComposeManager(DockerBaseClass): args.extend(["--profile", profile]) return args - def _handle_failed_cli_call(self, args, rc, stdout, stderr): + def _handle_failed_cli_call( + self, args: list[str], rc: int, stdout: str | bytes, stderr: bytes + ) -> t.NoReturn: events = parse_json_events(stderr, warn_function=self.client.warn) - result = {} + result: dict[str, t.Any] = {} self.update_failed(result, events, args, stdout, stderr, rc) self.client.module.exit_json(**result) - def 
list_containers_raw(self): + def list_containers_raw(self) -> list[dict[str, t.Any]]: args = self.get_base_args() + ["ps", "--format", "json", "--all"] if self.compose_version >= LooseVersion("2.23.0"): # https://github.com/docker/compose/pull/11038 args.append("--no-trunc") - kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events} if self.compose_version >= LooseVersion("2.21.0"): # Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918 - rc, containers, stderr = self.client.call_cli_json_stream(*args, **kwargs) + rc, containers, stderr = self.client.call_cli_json_stream( + *args, cwd=self.project_src, check_rc=not self.use_json_events + ) else: - rc, containers, stderr = self.client.call_cli_json(*args, **kwargs) + rc, containers, stderr = self.client.call_cli_json( + *args, cwd=self.project_src, check_rc=not self.use_json_events + ) if self.use_json_events and rc != 0: - self._handle_failed_cli_call(args, rc, containers, stderr) + self._handle_failed_cli_call(args, rc, json.dumps(containers), stderr) return containers - def list_containers(self): + def list_containers(self) -> list[dict[str, t.Any]]: result = [] for container in self.list_containers_raw(): labels = {} @@ -886,10 +934,11 @@ class BaseComposeManager(DockerBaseClass): result.append(container) return result - def list_images(self): + def list_images(self) -> list[str]: args = self.get_base_args() + ["images", "--format", "json"] - kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events} - rc, images, stderr = self.client.call_cli_json(*args, **kwargs) + rc, images, stderr = self.client.call_cli_json( + *args, cwd=self.project_src, check_rc=not self.use_json_events + ) if self.use_json_events and rc != 0: self._handle_failed_cli_call(args, rc, images, stderr) if isinstance(images, dict): @@ -899,7 +948,9 @@ class BaseComposeManager(DockerBaseClass): images = list(images.values()) return images - def parse_events(self, stderr, dry_run=False, nonzero_rc=False): + def parse_events( + self, stderr: bytes, dry_run: bool = False, nonzero_rc: bool = False + ) -> list[Event]: if self.use_json_events: return parse_json_events(stderr, warn_function=self.client.warn) return parse_events( @@ -909,17 +960,17 @@ class BaseComposeManager(DockerBaseClass): nonzero_rc=nonzero_rc, ) - def emit_warnings(self, events): + def emit_warnings(self, events: Sequence[Event]) -> None: emit_warnings(events, warn_function=self.client.warn) def update_result( self, - result, - events, - stdout, - stderr, - ignore_service_pull_events=False, - ignore_build_events=False, + result: dict[str, t.Any], + events: Sequence[Event], + stdout: str | bytes, + stderr: str | bytes, + ignore_service_pull_events: bool = False, + ignore_build_events: bool = False, ): result["changed"] = result.get("changed", False) or has_changes( events, @@ -930,7 +981,15 @@ class BaseComposeManager(DockerBaseClass): result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout)) result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr)) - def update_failed(self, result, events, args, stdout, stderr, rc): + def update_failed( + self, + result: dict[str, t.Any], + events: Sequence[Event], + args: list[str], + stdout: str | bytes, + stderr: bytes, + rc: int, + ): return update_failed( result, events, @@ -941,14 +1000,14 @@ class BaseComposeManager(DockerBaseClass): cli=self.client.get_cli(), ) - def cleanup_result(self, result): + def cleanup_result(self, result: dict[str, t.Any]) -> None: if not 
result.get("failed"): # Only return stdout and stderr if it is not empty for res in ("stdout", "stderr"): if result.get(res) == "": result.pop(res) - def cleanup(self): + def cleanup(self) -> None: for directory in self.cleanup_dirs: try: shutil.rmtree(directory, True) diff --git a/plugins/module_utils/_copy.py b/plugins/module_utils/_copy.py index 6ee8479a..11df3403 100644 --- a/plugins/module_utils/_copy.py +++ b/plugins/module_utils/_copy.py @@ -16,6 +16,7 @@ import os.path import shutil import stat import tarfile +import typing as t from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text @@ -25,6 +26,16 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor ) +if t.TYPE_CHECKING: + from collections.abc import Callable + + from _typeshed import WriteableBuffer + + from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( + APIClient, + ) + + class DockerFileCopyError(Exception): pass @@ -37,7 +48,9 @@ class DockerFileNotFound(DockerFileCopyError): pass -def _put_archive(client, container, path, data): +def _put_archive( + client: APIClient, container: str, path: str, data: bytes | t.Generator[bytes] +) -> bool: # data can also be file object for streaming. This is because _put uses requests's put(). # See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads url = client._url("/containers/{0}/archive", container) @@ -47,8 +60,14 @@ def _put_archive(client, container, path, data): def _symlink_tar_creator( - b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None -): + b_in_path: bytes, + file_stat: os.stat_result, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int | None = None, + user_name: str | None = None, +) -> bytes: if not stat.S_ISLNK(file_stat.st_mode): raise DockerUnexpectedError("stat information is not for a symlink") bio = io.BytesIO() @@ -75,16 +94,28 @@ def _symlink_tar_creator( def _symlink_tar_generator( - b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None -): + b_in_path: bytes, + file_stat: os.stat_result, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int | None = None, + user_name: str | None = None, +) -> t.Generator[bytes]: yield _symlink_tar_creator( b_in_path, file_stat, out_file, user_id, group_id, mode, user_name ) def _regular_file_tar_generator( - b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None -): + b_in_path: bytes, + file_stat: os.stat_result, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int | None = None, + user_name: str | None = None, +) -> t.Generator[bytes]: if not stat.S_ISREG(file_stat.st_mode): raise DockerUnexpectedError("stat information is not for a regular file") tarinfo = tarfile.TarInfo() @@ -136,8 +167,13 @@ def _regular_file_tar_generator( def _regular_content_tar_generator( - content, out_file, user_id, group_id, mode, user_name=None -): + content: bytes, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int, + user_name: str | None = None, +) -> t.Generator[bytes]: tarinfo = tarfile.TarInfo() tarinfo.name = ( os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/") @@ -175,16 +211,16 @@ def _regular_content_tar_generator( def put_file( - client, - container, - in_path, - out_path, - user_id, - group_id, - mode=None, - user_name=None, - follow_links=False, -): + client: APIClient, + container: str, + in_path: str, + out_path: str, + user_id: int, + group_id: int, + 
mode: int | None = None, + user_name: str | None = None, + follow_links: bool = False, +) -> None: """Transfer a file from local to Docker container.""" if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")): raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}") @@ -232,8 +268,15 @@ def put_file( def put_file_content( - client, container, content, out_path, user_id, group_id, mode, user_name=None -): + client: APIClient, + container: str, + content: bytes, + out_path: str, + user_id: int, + group_id: int, + mode: int, + user_name: str | None = None, +) -> None: """Transfer a file from local to Docker container.""" out_dir, out_file = os.path.split(out_path) @@ -248,7 +291,13 @@ def put_file_content( ) -def stat_file(client, container, in_path, follow_links=False, log=None): +def stat_file( + client: APIClient, + container: str, + in_path: str, + follow_links: bool = False, + log: Callable[[str], None] | None = None, +) -> tuple[str, dict[str, t.Any] | None, str | None]: """Fetch information on a file from a Docker container to local. Return a tuple ``(path, stat_data, link_target)`` where: @@ -265,12 +314,12 @@ def stat_file(client, container, in_path, follow_links=False, log=None): while True: if in_path in considered_in_paths: raise DockerFileCopyError( - f'Found infinite symbolic link loop when trying to stating "{in_path}"' + f"Found infinite symbolic link loop when trying to stat {in_path!r}" ) considered_in_paths.add(in_path) if log: - log(f'FETCH: Stating "{in_path}"') + log(f"FETCH: Stating {in_path!r}") response = client._head( client._url("/containers/{0}/archive", container), @@ -299,24 +348,24 @@ def stat_file(client, container, in_path, follow_links=False, log=None): class _RawGeneratorFileobj(io.RawIOBase): - def __init__(self, stream): + def __init__(self, stream: t.Generator[bytes]): self._stream = stream self._buf = b"" - def readable(self): + def readable(self) -> bool: return True - def _readinto_from_buf(self, b, index, length): + def _readinto_from_buf(self, b: WriteableBuffer, index: int, length: int) -> int: cpy = min(length - index, len(self._buf)) if cpy: - b[index : index + cpy] = self._buf[:cpy] + b[index : index + cpy] = self._buf[:cpy] # type: ignore # TODO! self._buf = self._buf[cpy:] index += cpy return index - def readinto(self, b): + def readinto(self, b: WriteableBuffer) -> int: index = 0 - length = len(b) + length = len(b) # type: ignore # TODO!
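# Drain any bytes left over from the previous generator chunk first;
# _readinto_from_buf() copies at most what fits into b and keeps the
# surplus in self._buf, so fresh chunks are only pulled from
# self._stream when the caller still needs more data.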
index = self._readinto_from_buf(b, index, length) if index == length: @@ -330,25 +379,28 @@ class _RawGeneratorFileobj(io.RawIOBase): return self._readinto_from_buf(b, index, length) -def _stream_generator_to_fileobj(stream): +def _stream_generator_to_fileobj(stream: t.Generator[bytes]) -> io.BufferedReader: """Given a generator that generates chunks of bytes, create a readable buffered stream.""" raw = _RawGeneratorFileobj(stream) return io.BufferedReader(raw) +_T = t.TypeVar("_T") + + def fetch_file_ex( - client, - container, - in_path, - process_none, - process_regular, - process_symlink, - process_other, - follow_links=False, - log=None, -): + client: APIClient, + container: str, + in_path: str, + process_none: Callable[[str], _T], + process_regular: Callable[[str, tarfile.TarFile, tarfile.TarInfo], _T], + process_symlink: Callable[[str, tarfile.TarInfo], _T], + process_other: Callable[[str, tarfile.TarInfo], _T], + follow_links: bool = False, + log: Callable[[str], None] | None = None, +) -> _T: """Fetch a file (as a tar file entry) from a Docker container to local.""" - considered_in_paths = set() + considered_in_paths: set[str] = set() while True: if in_path in considered_in_paths: @@ -372,8 +424,8 @@ def fetch_file_ex( with tarfile.open( fileobj=_stream_generator_to_fileobj(stream), mode="r|" ) as tar: - symlink_member = None - result = None + symlink_member: tarfile.TarInfo | None = None + result: _T | None = None found = False for member in tar: if found: @@ -398,35 +450,46 @@ def fetch_file_ex( log(f'FETCH: Following symbolic link to "{in_path}"') continue if found: - return result + return result # type: ignore raise DockerUnexpectedError("Received tarfile is empty!") -def fetch_file(client, container, in_path, out_path, follow_links=False, log=None): +def fetch_file( + client: APIClient, + container: str, + in_path: str, + out_path: str, + follow_links: bool = False, + log: Callable[[str], None] | None = None, +) -> str: b_out_path = to_bytes(out_path, errors="surrogate_or_strict") - def process_none(in_path): + def process_none(in_path: str) -> str: raise DockerFileNotFound( f"File {in_path} does not exist in container {container}" ) - def process_regular(in_path, tar, member): + def process_regular( + in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo + ) -> str: if not follow_links and os.path.exists(b_out_path): os.unlink(b_out_path) - with tar.extractfile(member) as in_f: - with open(b_out_path, "wb") as out_f: - shutil.copyfileobj(in_f, out_f) + reader = tar.extractfile(member) + if reader: + with reader as in_f: + with open(b_out_path, "wb") as out_f: + shutil.copyfileobj(in_f, out_f) return in_path - def process_symlink(in_path, member): + def process_symlink(in_path, member) -> str: if os.path.exists(b_out_path): os.unlink(b_out_path) os.symlink(member.linkname, b_out_path) return in_path - def process_other(in_path, member): + def process_other(in_path, member) -> str: raise DockerFileCopyError( f'Remote file "{in_path}" is not a regular file or a symbolic link' ) @@ -444,7 +507,13 @@ def fetch_file(client, container, in_path, out_path, follow_links=False, log=Non ) -def _execute_command(client, container, command, log=None, check_rc=False): +def _execute_command( + client: APIClient, + container: str, + command: list[str], + log: Callable[[str], None] | None = None, + check_rc: bool = False, +) -> tuple[int, bytes, bytes]: if log: log(f"Executing {command} in {container}") @@ -483,7 +552,7 @@ def _execute_command(client, container, command, log=None, 
check_rc=False): result = client.get_json("/exec/{0}/json", exec_id) - rc = result.get("ExitCode") or 0 + rc: int = result.get("ExitCode") or 0 stdout = stdout or b"" stderr = stderr or b"" @@ -493,13 +562,15 @@ def _execute_command(client, container, command, log=None, check_rc=False): if check_rc and rc != 0: command_str = " ".join(command) raise DockerUnexpectedError( - f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout}\nSTDERR: {stderr}' + f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout!r}\nSTDERR: {stderr!r}' ) return rc, stdout, stderr -def determine_user_group(client, container, log=None): +def determine_user_group( + client: APIClient, container: str, log: Callable[[str], None] | None = None +) -> tuple[int, int]: dummy_rc, stdout, dummy_stderr = _execute_command( client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log ) @@ -507,7 +578,7 @@ def determine_user_group(client, container, log=None): stdout_lines = stdout.splitlines() if len(stdout_lines) != 2: raise DockerUnexpectedError( - f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout}" + f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout!r}" ) user_id, group_id = stdout_lines @@ -515,5 +586,5 @@ def determine_user_group(client, container, log=None): return int(user_id), int(group_id) except ValueError as exc: raise DockerUnexpectedError( - f'Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got "{user_id}" and "{group_id}" instead' + f"Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got {user_id!r} and {group_id!r} instead" ) from exc diff --git a/plugins/module_utils/_image_archive.py b/plugins/module_utils/_image_archive.py index 15ee9b91..6156f209 100644 --- a/plugins/module_utils/_image_archive.py +++ b/plugins/module_utils/_image_archive.py @@ -18,12 +18,10 @@ class ImageArchiveManifestSummary: "docker image save some:tag > some.tar" command. """ - def __init__(self, image_id, repo_tags): + def __init__(self, image_id: str, repo_tags: list[str]) -> None: """ :param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json - :type image_id: str :param repo_tags Docker image names, e.g. ["hello-world:latest"] - :type repo_tags: list[str] """ self.image_id = image_id @@ -34,22 +32,21 @@ class ImageArchiveInvalidException(Exception): pass -def api_image_id(archive_image_id): +def api_image_id(archive_image_id: str) -> str: """ Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier that represents the same image hash, but in the format presented by the Docker Engine API. :param archive_image_id: plain image hash - :type archive_image_id: str - :returns: Prefixed hash used by REST api - :rtype: str """ return f"sha256:{archive_image_id}" -def load_archived_image_manifest(archive_path): +def load_archived_image_manifest( + archive_path: str, +) -> list[ImageArchiveManifestSummary] | None: """ Attempts to get image IDs and image names from metadata stored in the image archive tar file. @@ -62,10 +59,7 @@ def load_archived_image_manifest(archive_path): ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. 
:param archive_path: Tar file to read - :type archive_path: str - :return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects. - :rtype: ImageArchiveManifestSummary """ try: @@ -76,8 +70,15 @@ def load_archived_image_manifest(archive_path): with tarfile.open(archive_path, "r") as tf: try: try: - with tf.extractfile("manifest.json") as ef: + reader = tf.extractfile("manifest.json") + if reader is None: + raise ImageArchiveInvalidException( + "Failed to read manifest.json" + ) + with reader as ef: manifest = json.load(ef) + except ImageArchiveInvalidException: + raise except Exception as exc: raise ImageArchiveInvalidException( f"Failed to decode and deserialize manifest.json: {exc}" @@ -139,7 +140,7 @@ def load_archived_image_manifest(archive_path): ) from exc -def archived_image_manifest(archive_path): +def archived_image_manifest(archive_path: str) -> ImageArchiveManifestSummary | None: """ Attempts to get Image.Id and image name from metadata stored in the image archive tar file. @@ -152,10 +153,7 @@ def archived_image_manifest(archive_path): ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. :param archive_path: Tar file to read - :type archive_path: str - :return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix. - :rtype: ImageArchiveManifestSummary """ results = load_archived_image_manifest(archive_path) diff --git a/plugins/module_utils/_logfmt.py b/plugins/module_utils/_logfmt.py index c048966b..12fcdd08 100644 --- a/plugins/module_utils/_logfmt.py +++ b/plugins/module_utils/_logfmt.py @@ -13,6 +13,9 @@ See https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc for information on from __future__ import annotations +import typing as t +from enum import Enum + # The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc # (look for "EBNFish") @@ -22,7 +25,7 @@ class InvalidLogFmt(Exception): pass -class _Mode: +class _Mode(Enum): GARBAGE = 0 KEY = 1 EQUAL = 2 @@ -68,29 +71,29 @@ _HEX_DICT = { } -def _is_ident(cur): +def _is_ident(cur: str) -> bool: return cur > " " and cur not in ('"', "=") class _Parser: - def __init__(self, line): + def __init__(self, line: str) -> None: self.line = line self.index = 0 self.length = len(line) - def done(self): + def done(self) -> bool: return self.index >= self.length - def cur(self): + def cur(self) -> str: return self.line[self.index] - def next(self): + def next(self) -> None: self.index += 1 - def prev(self): + def prev(self) -> None: self.index -= 1 - def parse_unicode_sequence(self): + def parse_unicode_sequence(self) -> str: if self.index + 6 > self.length: raise InvalidLogFmt("Not enough space for unicode escape") if self.line[self.index : self.index + 2] != "\\u": @@ -108,27 +111,27 @@ class _Parser: return chr(v) -def parse_line(line, logrus_mode=False): - result = {} +def parse_line(line: str, logrus_mode: bool = False) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} parser = _Parser(line) - key = [] - value = [] + key: list[str] = [] + value: list[str] = [] mode = _Mode.GARBAGE - def handle_kv(has_no_value=False): + def handle_kv(has_no_value: bool = False) -> None: k = "".join(key) v = None if has_no_value else "".join(value) result[k] = v del key[:] del value[:] - def parse_garbage(cur): + def parse_garbage(cur: str) -> _Mode: if _is_ident(cur): return _Mode.KEY parser.next() return _Mode.GARBAGE - def parse_key(cur): + def parse_key(cur: str) -> _Mode: if 
_is_ident(cur): key.append(cur) parser.next() @@ -142,7 +145,7 @@ def parse_line(line, logrus_mode=False): parser.next() return _Mode.GARBAGE - def parse_equal(cur): + def parse_equal(cur: str) -> _Mode: if _is_ident(cur): value.append(cur) parser.next() @@ -154,7 +157,7 @@ def parse_line(line, logrus_mode=False): parser.next() return _Mode.GARBAGE - def parse_ident_value(cur): + def parse_ident_value(cur: str) -> _Mode: if _is_ident(cur): value.append(cur) parser.next() @@ -163,7 +166,7 @@ def parse_line(line, logrus_mode=False): parser.next() return _Mode.GARBAGE - def parse_quoted_value(cur): + def parse_quoted_value(cur: str) -> _Mode: if cur == "\\": parser.next() if parser.done(): diff --git a/plugins/module_utils/_module_container/base.py b/plugins/module_utils/_module_container/base.py index b68ebac1..c0d1906b 100644 --- a/plugins/module_utils/_module_container/base.py +++ b/plugins/module_utils/_module_container/base.py @@ -12,6 +12,7 @@ import abc import os import re import shlex +import typing as t from functools import partial from ansible.module_utils.common.text.converters import to_text @@ -32,6 +33,23 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) +if t.TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ansible.module_utils.basic import AnsibleModule + + from ansible_collections.community.docker.plugins.module_utils._version import ( + LooseVersion, + ) + + ValueType = t.Literal["set", "list", "dict", "bool", "int", "float", "str"] + AnsibleType = t.Literal["list", "dict", "bool", "int", "float", "str"] + ComparisonMode = t.Literal["ignore", "strict", "allow_more_present"] + ComparisonType = t.Literal["set", "set(dict)", "list", "dict", "value"] + +Client = t.TypeVar("Client") + + _DEFAULT_IP_REPLACEMENT_STRING = ( "[[DEFAULT_IP:iewahhaeB4Sae6Aen8IeShairoh4zeph7xaekoh8Geingunaesaeweiy3ooleiwi]]" ) @@ -54,7 +72,9 @@ _MOUNT_OPTION_TYPES = { } -def _get_ansible_type(value_type): +def _get_ansible_type( + value_type: ValueType, +) -> AnsibleType: if value_type == "set": return "list" if value_type not in ("list", "dict", "bool", "int", "float", "str"): @@ -65,21 +85,22 @@ def _get_ansible_type(value_type): class Option: def __init__( self, - name, - value_type, - owner, - ansible_type=None, - elements=None, - ansible_elements=None, - ansible_suboptions=None, - ansible_aliases=None, - ansible_choices=None, - needs_no_suboptions=False, - default_comparison=None, - not_a_container_option=False, - not_an_ansible_option=False, - copy_comparison_from=None, - compare=None, + name: str, + *, + value_type: ValueType, + owner: OptionGroup, + ansible_type: AnsibleType | None = None, + elements: ValueType | None = None, + ansible_elements: AnsibleType | None = None, + ansible_suboptions: dict[str, t.Any] | None = None, + ansible_aliases: Sequence[str] | None = None, + ansible_choices: Sequence[str] | None = None, + needs_no_suboptions: bool = False, + default_comparison: ComparisonMode | None = None, + not_a_container_option: bool = False, + not_an_ansible_option: bool = False, + copy_comparison_from: str | None = None, + compare: Callable[[Option, t.Any, t.Any], bool] | None = None, ): self.name = name self.value_type = value_type @@ -95,8 +116,8 @@ class Option: if (elements is None and ansible_elements is None) and needs_ansible_elements: raise ValueError("Ansible elements required for Ansible lists") self.elements = elements if needs_elements else None - self.ansible_elements = ( - (ansible_elements or 
_get_ansible_type(elements)) + self.ansible_elements: AnsibleType | None = ( + (ansible_elements or _get_ansible_type(elements or "str")) if needs_ansible_elements else None ) @@ -119,10 +140,12 @@ class Option: self.ansible_suboptions = ansible_suboptions if needs_suboptions else None self.ansible_aliases = ansible_aliases or [] self.ansible_choices = ansible_choices - comparison_type = self.value_type - if comparison_type == "set" and self.elements == "dict": + comparison_type: ComparisonType + if self.value_type == "set" and self.elements == "dict": comparison_type = "set(dict)" - elif comparison_type not in ("set", "list", "dict"): + elif self.value_type in ("set", "list", "dict"): + comparison_type = self.value_type # type: ignore + else: comparison_type = "value" self.comparison_type = comparison_type if default_comparison is not None: @@ -152,36 +175,45 @@ class Option: class OptionGroup: def __init__( self, - preprocess=None, - ansible_mutually_exclusive=None, - ansible_required_together=None, - ansible_required_one_of=None, - ansible_required_if=None, - ansible_required_by=None, - ): + *, + preprocess: ( + Callable[[AnsibleModule, dict[str, t.Any]], dict[str, t.Any]] | None + ) = None, + ansible_mutually_exclusive: Sequence[Sequence[str]] | None = None, + ansible_required_together: Sequence[Sequence[str]] | None = None, + ansible_required_one_of: Sequence[Sequence[str]] | None = None, + ansible_required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + ansible_required_by: dict[str, Sequence[str]] | None = None, + ) -> None: if preprocess is None: def preprocess(module, values): return values self.preprocess = preprocess - self.options = [] - self.all_options = [] - self.engines = {} + self.options: list[Option] = [] + self.all_options: list[Option] = [] + self.engines: dict[str, Engine] = {} self.ansible_mutually_exclusive = ansible_mutually_exclusive or [] self.ansible_required_together = ansible_required_together or [] self.ansible_required_one_of = ansible_required_one_of or [] self.ansible_required_if = ansible_required_if or [] self.ansible_required_by = ansible_required_by or {} - self.argument_spec = {} + self.argument_spec: dict[str, t.Any] = {} - def add_option(self, *args, **kwargs): + def add_option(self, *args, **kwargs) -> OptionGroup: option = Option(*args, owner=self, **kwargs) if not option.not_a_container_option: self.options.append(option) self.all_options.append(option) if not option.not_an_ansible_option: - ansible_option = { + ansible_option: dict[str, t.Any] = { "type": option.ansible_type, } if option.ansible_elements is not None: @@ -195,213 +227,297 @@ class OptionGroup: self.argument_spec[option.name] = ansible_option return self - def supports_engine(self, engine_name): + def supports_engine(self, engine_name: str) -> bool: return engine_name in self.engines - def get_engine(self, engine_name): + def get_engine(self, engine_name: str) -> Engine: return self.engines[engine_name] - def add_engine(self, engine_name, engine): + def add_engine(self, engine_name: str, engine: Engine) -> OptionGroup: self.engines[engine_name] = engine return self -class Engine: - min_api_version = None # string or None - min_api_version_obj = None # LooseVersion object or None - extra_option_minimal_versions = None # dict[str, dict[str, Any]] or None +class Engine(t.Generic[Client]): + min_api_version: str | None = None + min_api_version_obj: LooseVersion | None = None + extra_option_minimal_versions: 
dict[str, dict[str, t.Any]] | None = None @abc.abstractmethod - def get_value(self, module, container, api_version, options, image, host_info): + def get_value( + self, + module: AnsibleModule, + container: dict[str, t.Any], + api_version: LooseVersion, + options: list[Option], + image: dict[str, t.Any] | None, + host_info: dict[str, t.Any] | None, + ) -> dict[str, t.Any]: pass - def compare_value(self, option, param_value, container_value): + def compare_value( + self, option: Option, param_value: t.Any, container_value: t.Any + ) -> bool: return option.compare(param_value, container_value) @abc.abstractmethod - def set_value(self, module, data, api_version, options, values): + def set_value( + self, + module: AnsibleModule, + data: dict[str, t.Any], + api_version: LooseVersion, + options: list[Option], + values: dict[str, t.Any], + ) -> None: pass @abc.abstractmethod def get_expected_values( - self, module, client, api_version, options, image, values, host_info - ): + self, + module: AnsibleModule, + client: Client, + api_version: LooseVersion, + options: list[Option], + image: dict[str, t.Any] | None, + values: dict[str, t.Any], + host_info: dict[str, t.Any] | None, + ) -> dict[str, t.Any]: pass @abc.abstractmethod def ignore_mismatching_result( self, - module, - client, - api_version, - option, - image, - container_value, - expected_value, - host_info, - ): + module: AnsibleModule, + client: Client, + api_version: LooseVersion, + option: Option, + image: dict[str, t.Any] | None, + container_value: t.Any, + expected_value: t.Any, + host_info: dict[str, t.Any] | None, + ) -> bool: pass @abc.abstractmethod - def preprocess_value(self, module, client, api_version, options, values): + def preprocess_value( + self, + module: AnsibleModule, + client: Client, + api_version: LooseVersion, + options: list[Option], + values: dict[str, t.Any], + ) -> dict[str, t.Any]: pass @abc.abstractmethod - def update_value(self, module, data, api_version, options, values): + def update_value( + self, + module: AnsibleModule, + data: dict[str, t.Any], + api_version: LooseVersion, + options: list[Option], + values: dict[str, t.Any], + ) -> None: pass @abc.abstractmethod - def can_set_value(self, api_version): + def can_set_value(self, api_version: LooseVersion) -> bool: pass @abc.abstractmethod - def can_update_value(self, api_version): + def can_update_value(self, api_version: LooseVersion) -> bool: pass @abc.abstractmethod - def needs_container_image(self, values): + def needs_container_image(self, values: dict[str, t.Any]) -> bool: pass @abc.abstractmethod - def needs_host_info(self, values): + def needs_host_info(self, values: dict[str, t.Any]) -> bool: pass -class EngineDriver: - name = None # string +class EngineDriver(t.Generic[Client]): + name: str @abc.abstractmethod def setup( self, - argument_spec, - mutually_exclusive=None, - required_together=None, - required_one_of=None, - required_if=None, - required_by=None, - ): - # Return (module, active_options, client) + argument_spec: dict[str, t.Any], + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_by: dict[str, Sequence[str]] | None = None, + ) -> tuple[AnsibleModule, list[OptionGroup], Client]: pass @abc.abstractmethod - def get_host_info(self, client): + def get_host_info(self, 
client: Client) -> dict[str, t.Any]: pass @abc.abstractmethod - def get_api_version(self, client): + def get_api_version(self, client: Client) -> LooseVersion: pass @abc.abstractmethod - def get_container_id(self, container): + def get_container_id(self, container: dict[str, t.Any]) -> str: pass @abc.abstractmethod - def get_image_from_container(self, container): + def get_image_from_container(self, container: dict[str, t.Any]) -> str: pass @abc.abstractmethod - def get_image_name_from_container(self, container): + def get_image_name_from_container(self, container: dict[str, t.Any]) -> str | None: pass @abc.abstractmethod - def is_container_removing(self, container): + def is_container_removing(self, container: dict[str, t.Any]) -> bool: pass @abc.abstractmethod - def is_container_running(self, container): + def is_container_running(self, container: dict[str, t.Any]) -> bool: pass @abc.abstractmethod - def is_container_paused(self, container): + def is_container_paused(self, container: dict[str, t.Any]) -> bool: pass @abc.abstractmethod - def inspect_container_by_name(self, client, container_name): + def inspect_container_by_name( + self, client: Client, container_name: str + ) -> dict[str, t.Any] | None: pass @abc.abstractmethod - def inspect_container_by_id(self, client, container_id): + def inspect_container_by_id( + self, client: Client, container_id: str + ) -> dict[str, t.Any] | None: pass @abc.abstractmethod - def inspect_image_by_id(self, client, image_id): + def inspect_image_by_id( + self, client: Client, image_id: str + ) -> dict[str, t.Any] | None: pass @abc.abstractmethod - def inspect_image_by_name(self, client, repository, tag): + def inspect_image_by_name( + self, client: Client, repository: str, tag: str + ) -> dict[str, t.Any] | None: pass @abc.abstractmethod - def pull_image(self, client, repository, tag, image_platform=None): + def pull_image( + self, + client: Client, + repository: str, + tag: str, + image_platform: str | None = None, + ) -> tuple[dict[str, t.Any] | None, bool]: pass @abc.abstractmethod - def pause_container(self, client, container_id): + def pause_container(self, client: Client, container_id: str) -> None: pass @abc.abstractmethod - def unpause_container(self, client, container_id): + def unpause_container(self, client: Client, container_id: str) -> None: pass @abc.abstractmethod - def disconnect_container_from_network(self, client, container_id, network_id): + def disconnect_container_from_network( + self, client: Client, container_id: str, network_id: str + ) -> None: pass @abc.abstractmethod def connect_container_to_network( - self, client, container_id, network_id, parameters=None - ): + self, + client: Client, + container_id: str, + network_id: str, + parameters: dict[str, t.Any] | None = None, + ) -> None: pass - def create_container_supports_more_than_one_network(self, client): + def create_container_supports_more_than_one_network(self, client: Client) -> bool: return False @abc.abstractmethod def create_container( - self, client, container_name, create_parameters, networks=None - ): + self, + client: Client, + container_name: str, + create_parameters: dict[str, t.Any], + networks: dict[str, dict[str, t.Any]] | None = None, + ) -> str: pass @abc.abstractmethod - def start_container(self, client, container_id): + def start_container(self, client: Client, container_id: str) -> None: pass @abc.abstractmethod - def wait_for_container(self, client, container_id, timeout=None): + def wait_for_container( + self, client: Client, container_id: str, 
timeout: int | float | None = None + ) -> int | None: pass @abc.abstractmethod - def get_container_output(self, client, container_id): + def get_container_output( + self, client: Client, container_id: str + ) -> tuple[bytes, t.Literal[True]] | tuple[str, t.Literal[False]]: pass @abc.abstractmethod - def update_container(self, client, container_id, update_parameters): + def update_container( + self, client: Client, container_id: str, update_parameters: dict[str, t.Any] + ) -> None: pass @abc.abstractmethod - def restart_container(self, client, container_id, timeout=None): + def restart_container( + self, client: Client, container_id: str, timeout: int | float | None = None + ) -> None: pass @abc.abstractmethod - def kill_container(self, client, container_id, kill_signal=None): + def kill_container( + self, client: Client, container_id: str, kill_signal: str | None = None + ) -> None: pass @abc.abstractmethod - def stop_container(self, client, container_id, timeout=None): + def stop_container( + self, client: Client, container_id: str, timeout: int | float | None = None + ) -> None: pass @abc.abstractmethod def remove_container( - self, client, container_id, remove_volumes=False, link=False, force=False - ): + self, + client: Client, + container_id: str, + remove_volumes: bool = False, + link: bool = False, + force: bool = False, + ) -> None: pass @abc.abstractmethod - def run(self, runner, client): + def run(self, runner: Callable[[], None], client: Client) -> None: pass -def _is_volume_permissions(mode): +def _is_volume_permissions(mode: str) -> bool: for part in mode.split(","): if part not in ( "rw", @@ -423,7 +539,7 @@ def _is_volume_permissions(mode): return True -def _parse_port_range(range_or_port, module): +def _parse_port_range(range_or_port: str, module: AnsibleModule) -> list[int]: """ Parses a string containing either a single port or a range of ports. @@ -443,7 +559,7 @@ def _parse_port_range(range_or_port, module): module.fail_json(msg=f'Invalid port: "{range_or_port}"') -def _split_colon_ipv6(text, module): +def _split_colon_ipv6(text: str, module: AnsibleModule) -> list[str]: """ Split string by ':', while keeping IPv6 addresses in square brackets in one component. 
""" @@ -475,7 +591,9 @@ def _split_colon_ipv6(text, module): return result -def _preprocess_command(module, values): +def _preprocess_command( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if "command" not in values: return values value = values["command"] @@ -502,7 +620,9 @@ def _preprocess_command(module, values): } -def _preprocess_entrypoint(module, values): +def _preprocess_entrypoint( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if "entrypoint" not in values: return values value = values["entrypoint"] @@ -522,7 +642,9 @@ def _preprocess_entrypoint(module, values): } -def _preprocess_env(module, values): +def _preprocess_env( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if not values: return {} final_env = {} @@ -546,7 +668,9 @@ def _preprocess_env(module, values): } -def _preprocess_healthcheck(module, values): +def _preprocess_healthcheck( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if not values: return {} return { @@ -556,7 +680,12 @@ def _preprocess_healthcheck(module, values): } -def _preprocess_convert_to_bytes(module, values, name, unlimited_value=None): +def _preprocess_convert_to_bytes( + module: AnsibleModule, + values: dict[str, t.Any], + name: str, + unlimited_value: int | None = None, +) -> dict[str, t.Any]: if name not in values: return values try: @@ -571,7 +700,9 @@ def _preprocess_convert_to_bytes(module, values, name, unlimited_value=None): module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}") -def _preprocess_mac_address(module, values): +def _preprocess_mac_address( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if "mac_address" not in values: return values return { @@ -579,7 +710,9 @@ def _preprocess_mac_address(module, values): } -def _preprocess_networks(module, values): +def _preprocess_networks( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if ( module.params["networks_cli_compatible"] is True and values.get("networks") @@ -605,14 +738,18 @@ def _preprocess_networks(module, values): return values -def _preprocess_sysctls(module, values): +def _preprocess_sysctls( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if "sysctls" in values: for key, value in values["sysctls"].items(): values["sysctls"][key] = to_text(value, errors="surrogate_or_strict") return values -def _preprocess_tmpfs(module, values): +def _preprocess_tmpfs( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if "tmpfs" not in values: return values result = {} @@ -625,7 +762,9 @@ def _preprocess_tmpfs(module, values): return {"tmpfs": result} -def _preprocess_ulimits(module, values): +def _preprocess_ulimits( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if "ulimits" not in values: return values result = [] @@ -644,8 +783,10 @@ def _preprocess_ulimits(module, values): } -def _preprocess_mounts(module, values): - last = {} +def _preprocess_mounts( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: + last: dict[str, str] = {} def check_collision(t, name): if t in last: @@ -776,7 +917,9 @@ def _preprocess_mounts(module, values): return values -def _preprocess_labels(module, values): +def _preprocess_labels( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: result = {} if "labels" in values: labels = values["labels"] @@ -787,13 +930,15 @@ def _preprocess_labels(module, values): return result -def 
_preprocess_log(module, values): - result = {} +def _preprocess_log( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} if "log_driver" not in values: return result result["log_driver"] = values["log_driver"] if "log_options" in values: - options = {} + options: dict[str, str] = {} for k, v in values["log_options"].items(): if not isinstance(v, str): value = to_text(v, errors="surrogate_or_strict") @@ -807,7 +952,9 @@ def _preprocess_log(module, values): return result -def _preprocess_ports(module, values): +def _preprocess_ports( + module: AnsibleModule, values: dict[str, t.Any] +) -> dict[str, t.Any]: if "published_ports" in values: if "all" in values["published_ports"]: module.fail_json( @@ -815,7 +962,12 @@ def _preprocess_ports(module, values): "to randomly assign port mappings for those not specified by published_ports." ) - binds = {} + binds: dict[ + str | int, + tuple[str] + | tuple[str, str | int] + | list[tuple[str] | tuple[str, str | int]], + ] = {} for port in values["published_ports"]: parts = _split_colon_ipv6( to_text(port, errors="surrogate_or_strict"), module @@ -827,6 +979,7 @@ def _preprocess_ports(module, values): container_ports = _parse_port_range(container_port, module) p_len = len(parts) + port_binds: Sequence[tuple[str] | tuple[str, str | int]] if p_len == 1: port_binds = len(container_ports) * [(_DEFAULT_IP_REPLACEMENT_STRING,)] elif p_len == 2: @@ -865,8 +1018,12 @@ def _preprocess_ports(module, values): "Maybe you forgot to use square brackets ([...]) around an IPv6 address?" ) - for bind, container_port in zip(port_binds, container_ports): - idx = f"{container_port}/{protocol}" if protocol else container_port + for bind, container_port_val in zip(port_binds, container_ports): + idx = ( + f"{container_port_val}/{protocol}" + if protocol + else container_port_val + ) if idx in binds: old_bind = binds[idx] if isinstance(old_bind, list): @@ -882,9 +1039,9 @@ def _preprocess_ports(module, values): for port in values["exposed_ports"]: port = to_text(port, errors="surrogate_or_strict").strip() protocol = "tcp" - match = re.search(r"(/.+$)", port) - if match: - protocol = match.group(1).replace("/", "") + matcher = re.search(r"(/.+$)", port) + if matcher: + protocol = matcher.group(1).replace("/", "") port = re.sub(r"/.+$", "", port) exposed.append((port, protocol)) if "published_ports" in values: @@ -912,7 +1069,7 @@ def _preprocess_ports(module, values): return values -def _compare_platform(option, param_value, container_value): +def _compare_platform(option: Option, param_value: t.Any, container_value: t.Any): if option.comparison == "ignore": return True try: diff --git a/plugins/module_utils/_module_container/docker_api.py b/plugins/module_utils/_module_container/docker_api.py index f07a6214..28776607 100644 --- a/plugins/module_utils/_module_container/docker_api.py +++ b/plugins/module_utils/_module_container/docker_api.py @@ -10,6 +10,7 @@ from __future__ import annotations import json import traceback +import typing as t from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.formatters import human_to_bytes @@ -101,6 +102,7 @@ from ansible_collections.community.docker.plugins.module_utils._module_container OPTIONS, Engine, EngineDriver, + _is_volume_permissions, ) from ansible_collections.community.docker.plugins.module_utils._platform import ( compose_platform_string, @@ -115,40 +117,48 @@ from 
ansible_collections.community.docker.plugins.module_utils._version import ( ) +if t.TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ansible.module_utils.basic import AnsibleModule + + from .base import Option, OptionGroup + + Sentry = object + + _DEFAULT_IP_REPLACEMENT_STRING = ( "[[DEFAULT_IP:iewahhaeB4Sae6Aen8IeShairoh4zeph7xaekoh8Geingunaesaeweiy3ooleiwi]]" ) -def _get_ansible_type(our_type): - if our_type == "set": - return "list" - if our_type not in ("list", "dict", "bool", "int", "float", "str"): - raise ValueError(f'Invalid type "{our_type}"') - return our_type +_SENTRY: Sentry = object() -_SENTRY = object() - - -class DockerAPIEngineDriver(EngineDriver): +class DockerAPIEngineDriver(EngineDriver[AnsibleDockerClient]): name = "docker_api" def setup( self, - argument_spec, - mutually_exclusive=None, - required_together=None, - required_one_of=None, - required_if=None, - required_by=None, - ): - argument_spec = argument_spec or {} - mutually_exclusive = mutually_exclusive or [] - required_together = required_together or [] - required_one_of = required_one_of or [] - required_if = required_if or [] - required_by = required_by or {} + argument_spec: dict[str, t.Any], + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_by: dict[str, Sequence[str]] | None = None, + ) -> tuple[AnsibleModule, list[OptionGroup], AnsibleDockerClient]: + argument_spec = dict(argument_spec or {}) + mutually_exclusive = list(mutually_exclusive or []) + required_together = list(required_together or []) + required_one_of = list(required_one_of or []) + required_if = list(required_if or []) + required_by = dict(required_by or {}) active_options = [] option_minimal_versions = {} @@ -188,27 +198,27 @@ class DockerAPIEngineDriver(EngineDriver): return client.module, active_options, client - def get_host_info(self, client): + def get_host_info(self, client: AnsibleDockerClient) -> dict[str, t.Any]: return client.info() - def get_api_version(self, client): + def get_api_version(self, client: AnsibleDockerClient) -> LooseVersion: return client.docker_api_version - def get_container_id(self, container): + def get_container_id(self, container: dict[str, t.Any]) -> str: return container["Id"] - def get_image_from_container(self, container): + def get_image_from_container(self, container: dict[str, t.Any]) -> str: return container["Image"] - def get_image_name_from_container(self, container): + def get_image_name_from_container(self, container: dict[str, t.Any]) -> str | None: return container["Config"].get("Image") - def is_container_removing(self, container): + def is_container_removing(self, container: dict[str, t.Any]) -> bool: if container.get("State"): return container["State"].get("Status") == "removing" return False - def is_container_running(self, container): + def is_container_running(self, container: dict[str, t.Any]) -> bool: if container.get("State"): if container["State"].get("Running") and not container["State"].get( "Ghost", False @@ -216,38 +226,54 @@ class DockerAPIEngineDriver(EngineDriver): return True return False - def is_container_paused(self, container): + def is_container_paused(self, container: dict[str, t.Any]) -> bool: if container.get("State"): return container["State"].get("Paused", False) return False - 
-    def inspect_container_by_name(self, client, container_name):
+    def inspect_container_by_name(
+        self, client: AnsibleDockerClient, container_name: str
+    ) -> dict[str, t.Any] | None:
         return client.get_container(container_name)

-    def inspect_container_by_id(self, client, container_id):
+    def inspect_container_by_id(
+        self, client: AnsibleDockerClient, container_id: str
+    ) -> dict[str, t.Any] | None:
         return client.get_container_by_id(container_id)

-    def inspect_image_by_id(self, client, image_id):
+    def inspect_image_by_id(
+        self, client: AnsibleDockerClient, image_id: str
+    ) -> dict[str, t.Any] | None:
         return client.find_image_by_id(image_id, accept_missing_image=True)

-    def inspect_image_by_name(self, client, repository, tag):
+    def inspect_image_by_name(
+        self, client: AnsibleDockerClient, repository: str, tag: str
+    ) -> dict[str, t.Any] | None:
         return client.find_image(repository, tag)

-    def pull_image(self, client, repository, tag, image_platform=None):
+    def pull_image(
+        self,
+        client: AnsibleDockerClient,
+        repository: str,
+        tag: str,
+        image_platform: str | None = None,
+    ) -> tuple[dict[str, t.Any] | None, bool]:
         return client.pull_image(repository, tag, image_platform=image_platform)

-    def pause_container(self, client, container_id):
+    def pause_container(self, client: AnsibleDockerClient, container_id: str) -> None:
         client.post_call("/containers/{0}/pause", container_id)

-    def unpause_container(self, client, container_id):
+    def unpause_container(self, client: AnsibleDockerClient, container_id: str) -> None:
         client.post_call("/containers/{0}/unpause", container_id)

-    def disconnect_container_from_network(self, client, container_id, network_id):
+    def disconnect_container_from_network(
+        self, client: AnsibleDockerClient, container_id: str, network_id: str
+    ) -> None:
         client.post_json(
             "/networks/{0}/disconnect", network_id, data={"Container": container_id}
         )

-    def _create_endpoint_config(self, parameters):
+    def _create_endpoint_config(self, parameters: dict[str, t.Any]) -> dict[str, t.Any]:
         parameters = parameters.copy()
         params = {}
         for para, dest_para in {
@@ -290,8 +316,12 @@ class DockerAPIEngineDriver(EngineDriver):
         return params

     def connect_container_to_network(
-        self, client, container_id, network_id, parameters=None
-    ):
+        self,
+        client: AnsibleDockerClient,
+        container_id: str,
+        network_id: str,
+        parameters: dict[str, t.Any] | None = None,
+    ) -> None:
         parameters = (parameters or {}).copy()
         params = self._create_endpoint_config(parameters or {})
         data = {
@@ -300,12 +330,18 @@ class DockerAPIEngineDriver(EngineDriver):
         }
         client.post_json("/networks/{0}/connect", network_id, data=data)

-    def create_container_supports_more_than_one_network(self, client):
+    def create_container_supports_more_than_one_network(
+        self, client: AnsibleDockerClient
+    ) -> bool:
         return client.docker_api_version >= LooseVersion("1.44")

     def create_container(
-        self, client, container_name, create_parameters, networks=None
-    ):
+        self,
+        client: AnsibleDockerClient,
+        container_name: str,
+        create_parameters: dict[str, t.Any],
+        networks: dict[str, dict[str, t.Any]] | None = None,
+    ) -> str:
         params = {"name": container_name}
         if "platform" in create_parameters:
             params["platform"] = create_parameters.pop("platform")
@@ -323,15 +359,22 @@ class DockerAPIEngineDriver(EngineDriver):
         client.report_warnings(new_container)
         return new_container["Id"]

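+    # The container lifecycle methods below are thin wrappers around the
+    # corresponding /containers/{id}/... Engine API endpoints.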
-    def start_container(self, client, container_id):
+    def start_container(self, client: AnsibleDockerClient, container_id: str) -> None:
         client.post_json("/containers/{0}/start", container_id)

-    def wait_for_container(self, client, container_id, timeout=None):
+    def wait_for_container(
+        self,
+        client: AnsibleDockerClient,
+        container_id: str,
+        timeout: int | float | None = None,
+    ) -> int | None:
         return client.post_json_to_json(
             "/containers/{0}/wait", container_id, timeout=timeout
         )["StatusCode"]

-    def get_container_output(self, client, container_id):
+    def get_container_output(
+        self, client: AnsibleDockerClient, container_id: str
+    ) -> tuple[bytes, t.Literal[True]] | tuple[str, t.Literal[False]]:
         config = client.get_json("/containers/{0}/json", container_id)
         logging_driver = config["HostConfig"]["LogConfig"]["Type"]
         if logging_driver in ("json-file", "journald", "local"):
@@ -349,14 +392,24 @@ class DockerAPIEngineDriver(EngineDriver):
             return output, True
         return f"Result logged using `{logging_driver}` driver", False

-    def update_container(self, client, container_id, update_parameters):
+    def update_container(
+        self,
+        client: AnsibleDockerClient,
+        container_id: str,
+        update_parameters: dict[str, t.Any],
+    ) -> None:
         result = client.post_json_to_json(
             "/containers/{0}/update", container_id, data=update_parameters
         )
         client.report_warnings(result)

-    def restart_container(self, client, container_id, timeout=None):
-        client_timeout = client.timeout
+    def restart_container(
+        self,
+        client: AnsibleDockerClient,
+        container_id: str,
+        timeout: int | float | None = None,
+    ) -> None:
+        client_timeout: int | float | None = client.timeout
         if client_timeout is not None:
             client_timeout += timeout or 10
         client.post_call(
@@ -366,13 +419,23 @@ class DockerAPIEngineDriver(EngineDriver):
             timeout=client_timeout,
         )

-    def kill_container(self, client, container_id, kill_signal=None):
+    def kill_container(
+        self,
+        client: AnsibleDockerClient,
+        container_id: str,
+        kill_signal: str | None = None,
+    ) -> None:
         params = {}
         if kill_signal is not None:
             params["signal"] = kill_signal
         client.post_call("/containers/{0}/kill", container_id, params=params)

-    def stop_container(self, client, container_id, timeout=None):
+    def stop_container(
+        self,
+        client: AnsibleDockerClient,
+        container_id: str,
+        timeout: int | float | None = None,
+    ) -> None:
         if timeout:
             params = {"t": timeout}
         else:
@@ -414,8 +477,13 @@ class DockerAPIEngineDriver(EngineDriver):
             break

     def remove_container(
-        self, client, container_id, remove_volumes=False, link=False, force=False
-    ):
+        self,
+        client: AnsibleDockerClient,
+        container_id: str,
+        remove_volumes: bool = False,
+        link: bool = False,
+        force: bool = False,
+    ) -> None:
         params = {"v": remove_volumes, "link": link, "force": force}
         count = 0
         while True:
@@ -452,7 +520,7 @@ class DockerAPIEngineDriver(EngineDriver):
             # We only loop when explicitly requested by 'continue'
             break

-    def run(self, runner, client):
+    def run(self, runner: Callable[[], None], client: AnsibleDockerClient) -> None:
         try:
             runner()
         except DockerException as e:
@@ -467,24 +535,48 @@ class DockerAPIEngineDriver(EngineDriver):
             )


-class DockerAPIEngine(Engine):
-    def get_value(self, module, container, api_version, options, image, host_info):
+class DockerAPIEngine(Engine[AnsibleDockerClient]):
+    def get_value(
+        self,
+        module: AnsibleModule,
+        container: dict[str, t.Any],
+        api_version: LooseVersion,
+        options: list[Option],
+        image: dict[str, t.Any] | None,
+        host_info: dict[str, t.Any] | None,
+    ) -> dict[str, t.Any]:
         return self._get_value(
             module, container, api_version, options, image, host_info
         )

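+    # Each hook below delegates to an optional callback passed to __init__
+    # (self._compare_value, self._set_value, ...) and falls back to a default
+    # when that callback is None.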
-    def compare_value(self, option, param_value, container_value):
+    def compare_value(
+        self, option: Option, param_value: t.Any, container_value: t.Any
+    ) -> bool:
         if self._compare_value is not None:
             return self._compare_value(option, param_value, container_value)
         return super().compare_value(option, param_value, container_value)

-    def set_value(self, module, data, api_version, options, values):
+    def set_value(
+        self,
+        module: AnsibleModule,
+        data: dict[str, t.Any],
+        api_version: LooseVersion,
+        options: list[Option],
+        values: dict[str, t.Any],
+    ) -> None:
         if self._set_value is not None:
             self._set_value(module, data, api_version, options, values)

     def get_expected_values(
-        self, module, client, api_version, options, image, values, host_info
-    ):
+        self,
+        module: AnsibleModule,
+        client: AnsibleDockerClient,
+        api_version: LooseVersion,
+        options: list[Option],
+        image: dict[str, t.Any] | None,
+        values: dict[str, t.Any],
+        host_info: dict[str, t.Any] | None,
+    ) -> dict[str, t.Any]:
         if self._get_expected_values is None:
             return values
         return self._get_expected_values(
@@ -493,15 +585,15 @@ class DockerAPIEngine(Engine):

     def ignore_mismatching_result(
         self,
-        module,
-        client,
-        api_version,
-        option,
-        image,
-        container_value,
-        expected_value,
-        host_info,
-    ):
+        module: AnsibleModule,
+        client: AnsibleDockerClient,
+        api_version: LooseVersion,
+        option: Option,
+        image: dict[str, t.Any] | None,
+        container_value: t.Any,
+        expected_value: t.Any,
+        host_info: dict[str, t.Any] | None,
+    ) -> bool:
         if self._ignore_mismatching_result is None:
             return False
         return self._ignore_mismatching_result(
@@ -515,50 +607,139 @@ class DockerAPIEngine(Engine):
             module,
             client,
             api_version,
             option,
             image,
             container_value,
             expected_value,
             host_info,
         )

-    def preprocess_value(self, module, client, api_version, options, values):
+    def preprocess_value(
+        self,
+        module: AnsibleModule,
+        client: AnsibleDockerClient,
+        api_version: LooseVersion,
+        options: list[Option],
+        values: dict[str, t.Any],
+    ) -> dict[str, t.Any]:
         if self._preprocess_value is None:
             return values
         return self._preprocess_value(module, client, api_version, options, values)

-    def update_value(self, module, data, api_version, options, values):
+    def update_value(
+        self,
+        module: AnsibleModule,
+        data: dict[str, t.Any],
+        api_version: LooseVersion,
+        options: list[Option],
+        values: dict[str, t.Any],
+    ) -> None:
         if self._update_value is not None:
             self._update_value(module, data, api_version, options, values)

-    def can_set_value(self, api_version):
+    def can_set_value(self, api_version: LooseVersion) -> bool:
         if self._can_set_value is None:
             return self._set_value is not None
         return self._can_set_value(api_version)

-    def can_update_value(self, api_version):
+    def can_update_value(self, api_version: LooseVersion) -> bool:
         if self._can_update_value is None:
             return self._update_value is not None
         return self._can_update_value(api_version)

-    def needs_container_image(self, values):
+    def needs_container_image(self, values: dict[str, t.Any]) -> bool:
         if self._needs_container_image is None:
             return False
         return self._needs_container_image(values)

-    def needs_host_info(self, values):
+    def needs_host_info(self, values: dict[str, t.Any]) -> bool:
         if self._needs_host_info is None:
             return False
         return self._needs_host_info(values)

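+    # The engine's behaviour is supplied as a set of optional callbacks; the
+    # long Callable[...] annotations below only spell out their signatures.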
     def __init__(
         self,
-        get_value,
-        preprocess_value=None,
-        get_expected_values=None,
-        ignore_mismatching_result=None,
-        set_value=None,
-        update_value=None,
-        can_set_value=None,
-        can_update_value=None,
-        min_api_version=None,
-        compare_value=None,
-        needs_container_image=None,
-        needs_host_info=None,
-        extra_option_minimal_versions=None,
+        get_value: Callable[
+            [
+                AnsibleModule,
+                dict[str, t.Any],
+                LooseVersion,
+                list[Option],
+                dict[str, t.Any] | None,
+                dict[str, t.Any] | None,
+            ],
+            dict[str, t.Any],
+        ],
+        preprocess_value: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    list[Option],
+                    dict[str, t.Any],
+                ],
+                dict[str, t.Any],
+            ]
+            | None
+        ) = None,
+        get_expected_values: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    list[Option],
+                    dict[str, t.Any] | None,
+                    dict[str, t.Any],
+                    dict[str, t.Any] | None,
+                ],
+                dict[str, t.Any],
+            ]
+            | None
+        ) = None,
+        ignore_mismatching_result: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    Option,
+                    dict[str, t.Any] | None,
+                    t.Any,
+                    t.Any,
+                    dict[str, t.Any] | None,
+                ],
+                bool,
+            ]
+            | None
+        ) = None,
+        set_value: (
+            Callable[
+                [
+                    AnsibleModule,
+                    dict[str, t.Any],
+                    LooseVersion,
+                    list[Option],
+                    dict[str, t.Any],
+                ],
+                None,
+            ]
+            | None
+        ) = None,
+        update_value: (
+            Callable[
+                [
+                    AnsibleModule,
+                    dict[str, t.Any],
+                    LooseVersion,
+                    list[Option],
+                    dict[str, t.Any],
+                ],
+                None,
+            ]
+            | None
+        ) = None,
+        can_set_value: Callable[[LooseVersion], bool] | None = None,
+        can_update_value: Callable[[LooseVersion], bool] | None = None,
+        min_api_version: str | None = None,
+        compare_value: Callable[[Option, t.Any, t.Any], bool] | None = None,
+        needs_container_image: Callable[[dict[str, t.Any]], bool] | None = None,
+        needs_host_info: Callable[[dict[str, t.Any]], bool] | None = None,
+        extra_option_minimal_versions: dict[str, dict[str, t.Any]] | None = None,
     ):
         self.min_api_version = min_api_version
         self.min_api_version_obj = (
@@ -566,31 +747,73 @@ class DockerAPIEngine(Engine):
         )
         self.extra_option_minimal_versions = extra_option_minimal_versions
         self._get_value = get_value
-        self._compare_value = compare_value  # can be None
-        self._set_value = set_value  # can be None
-        self._get_expected_values = get_expected_values  # can be None
-        self._ignore_mismatching_result = ignore_mismatching_result  # can be None
-        self._preprocess_value = preprocess_value  # can be None
-        self._update_value = update_value  # can be None
-        self._can_set_value = can_set_value  # can be None
-        self._can_update_value = can_update_value  # can be None
-        self._needs_container_image = needs_container_image  # can be None
-        self._needs_host_info = needs_host_info  # can be None
+        self._compare_value = compare_value
+        self._set_value = set_value
+        self._get_expected_values = get_expected_values
+        self._ignore_mismatching_result = ignore_mismatching_result
+        self._preprocess_value = preprocess_value
+        self._update_value = update_value
+        self._can_set_value = can_set_value
+        self._can_update_value = can_update_value
+        self._needs_container_image = needs_container_image
+        self._needs_host_info = needs_host_info

     @classmethod
     def config_value(
         cls,
-        config_name,
-        postprocess_for_get=None,
-        preprocess_for_set=None,
-        get_expected_value=None,
-        ignore_mismatching_result=None,
-        min_api_version=None,
-        preprocess_value=None,
-        update_parameter=None,
-        extra_option_minimal_versions=None,
-    ):
-        def preprocess_value_(module, client, api_version, options, values):
+        config_name: str,
+        postprocess_for_get: (
+            Callable[[AnsibleModule, LooseVersion, t.Any, Sentry], t.Any | Sentry]
+            | None
+        ) = None,
+        preprocess_for_set: (
+            Callable[[AnsibleModule, LooseVersion, t.Any], t.Any] | None
+        ) = None,
+        get_expected_value: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    dict[str, t.Any] | None,
+                    t.Any,
+                    Sentry,
+                ],
+                t.Any | Sentry,
+            ]
+            | None
+        ) = None,
+        ignore_mismatching_result: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    Option,
+                    dict[str, t.Any] | None,
+                    t.Any,
+                    t.Any,
+                    dict[str, t.Any] | None,
+                ],
+                bool,
+            ]
+            | None
+        ) = None,
+        min_api_version: str | None = None,
+        preprocess_value: (
+            Callable[[AnsibleModule, AnsibleDockerClient, LooseVersion, t.Any], t.Any]
+            | None
+        ) = None,
+        update_parameter: str | None = None,
+        extra_option_minimal_versions: dict[str, dict[str, t.Any]] | None = None,
+    ) -> DockerAPIEngine:
+        def preprocess_value_(
+            module: AnsibleModule,
+            client: AnsibleDockerClient,
+            api_version: LooseVersion,
+            options: list[Option],
+            values: dict[str, t.Any],
+        ) -> dict[str, t.Any]:
             if len(options) != 1:
                 raise AssertionError(
                     "config_value can only be used for a single option"
@@ -605,7 +828,14 @@ class DockerAPIEngine(Engine):
                 values[options[0].name] = value
             return values

-        def get_value(module, container, api_version, options, image, host_info):
+        def get_value(
+            module: AnsibleModule,
+            container: dict[str, t.Any],
+            api_version: LooseVersion,
+            options: list[Option],
+            image: dict[str, t.Any] | None,
+            host_info: dict[str, t.Any] | None,
+        ) -> dict[str, t.Any]:
             if len(options) != 1:
                 raise AssertionError(
                     "config_value can only be used for a single option"
@@ -617,10 +847,31 @@ class DockerAPIEngine(Engine):
                 return {}
             return {options[0].name: value}

+        get_expected_values_: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    list[Option],
+                    dict[str, t.Any] | None,
+                    dict[str, t.Any],
+                    dict[str, t.Any] | None,
+                ],
+                dict[str, t.Any],
+            ]
+            | None
+        ) = None
         if get_expected_value:

-            def get_expected_values_(
-                module, client, api_version, options, image, values, host_info
+            def get_expected_values_(  # noqa: F811
+                module: AnsibleModule,
+                client: AnsibleDockerClient,
+                api_version: LooseVersion,
+                options: list[Option],
+                image: dict[str, t.Any] | None,
+                values: dict[str, t.Any],
+                host_info: dict[str, t.Any] | None,
             ):
                 if len(options) != 1:
                     raise AssertionError(
@@ -634,10 +885,13 @@ class DockerAPIEngine(Engine):
                     return values
                 return {options[0].name: value}

-        else:
-            get_expected_values_ = None
-
-        def set_value(module, data, api_version, options, values):
+        def set_value(
+            module: AnsibleModule,
+            data: dict[str, t.Any],
+            api_version: LooseVersion,
+            options: list[Option],
+            values: dict[str, t.Any],
+        ) -> None:
             if len(options) != 1:
                 raise AssertionError(
                     "config_value can only be used for a single option"
                 )
@@ -649,9 +903,28 @@ class DockerAPIEngine(Engine):
                 value = preprocess_for_set(module, api_version, value)
             data[config_name] = value

+        update_value: (
+            Callable[
+                [
+                    AnsibleModule,
+                    dict[str, t.Any],
+                    LooseVersion,
+                    list[Option],
+                    dict[str, t.Any],
+                ],
+                None,
+            ]
+            | None
+        ) = None
         if update_parameter:

-            def update_value(module, data, api_version, options, values):
+            def update_value(  # noqa: F811
+                module: AnsibleModule,
+                data: dict[str, t.Any],
+                api_version: LooseVersion,
+                options: list[Option],
+                values: dict[str, t.Any],
+            ) -> None:
                 if len(options) != 1:
                     raise AssertionError(
                         "update_parameter can only be used for a single option"
                     )
@@ -663,9 +936,6 @@ class DockerAPIEngine(Engine):
                     value = preprocess_for_set(module, api_version, value)
                 data[update_parameter] = value

-        else:
-            update_value = None
-
         return cls(
             get_value=get_value,
             preprocess_value=preprocess_value_,
@@ -680,17 +950,59 @@ class DockerAPIEngine(Engine):
     @classmethod
     def host_config_value(
         cls,
-        host_config_name,
-        postprocess_for_get=None,
-        preprocess_for_set=None,
-        get_expected_value=None,
-        ignore_mismatching_result=None,
-        min_api_version=None,
-        preprocess_value=None,
-        update_parameter=None,
-        extra_option_minimal_versions=None,
-    ):
-        def preprocess_value_(module, client, api_version, options, values):
+        host_config_name: str,
+        postprocess_for_get: (
+            Callable[[AnsibleModule, LooseVersion, t.Any, Sentry], t.Any | Sentry]
+            | None
+        ) = None,
+        preprocess_for_set: (
+            Callable[[AnsibleModule, LooseVersion, t.Any], t.Any] | None
+        ) = None,
+        get_expected_value: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    dict[str, t.Any] | None,
+                    t.Any,
+                    Sentry,
+                ],
+                t.Any | Sentry,
+            ]
+            | None
+        ) = None,
+        ignore_mismatching_result: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    Option,
+                    dict[str, t.Any] | None,
+                    t.Any,
+                    t.Any,
+                    dict[str, t.Any] | None,
+                ],
+                bool,
+            ]
+            | None
+        ) = None,
+        min_api_version: str | None = None,
+        preprocess_value: (
+            Callable[[AnsibleModule, AnsibleDockerClient, LooseVersion, t.Any], t.Any]
+            | None
+        ) = None,
+        update_parameter: str | None = None,
+        extra_option_minimal_versions: dict[str, dict[str, t.Any]] | None = None,
+    ) -> DockerAPIEngine:
+        def preprocess_value_(
+            module: AnsibleModule,
+            client: AnsibleDockerClient,
+            api_version: LooseVersion,
+            options: list[Option],
+            values: dict[str, t.Any],
+        ) -> dict[str, t.Any]:
             if len(options) != 1:
                 raise AssertionError(
                     "host_config_value can only be used for a single option"
@@ -705,7 +1017,14 @@ class DockerAPIEngine(Engine):
                 values[options[0].name] = value
             return values

-        def get_value(module, container, api_version, options, get_value, host_info):
+        def get_value(
+            module: AnsibleModule,
+            container: dict[str, t.Any],
+            api_version: LooseVersion,
+            options: list[Option],
+            image: dict[str, t.Any] | None,
+            host_info: dict[str, t.Any] | None,
+        ) -> dict[str, t.Any]:
             if len(options) != 1:
                 raise AssertionError(
                     "host_config_value can only be used for a single option"
@@ -717,11 +1036,32 @@ class DockerAPIEngine(Engine):
                 return {}
             return {options[0].name: value}

+        get_expected_values_: (
+            Callable[
+                [
+                    AnsibleModule,
+                    AnsibleDockerClient,
+                    LooseVersion,
+                    list[Option],
+                    dict[str, t.Any] | None,
+                    dict[str, t.Any],
+                    dict[str, t.Any] | None,
+                ],
+                dict[str, t.Any],
+            ]
+            | None
+        ) = None
         if get_expected_value:

-            def get_expected_values_(
-                module, client, api_version, options, image, values, host_info
-            ):
+            def get_expected_values_(  # noqa: F811
+                module: AnsibleModule,
+                client: AnsibleDockerClient,
+                api_version: LooseVersion,
+                options: list[Option],
+                image: dict[str, t.Any] | None,
+                values: dict[str, t.Any],
+                host_info: dict[str, t.Any] | None,
+            ) -> dict[str, t.Any]:
                 if len(options) != 1:
                     raise AssertionError(
                         "host_config_value can only be used for a single option"
@@ -734,10 +1074,13 @@ class DockerAPIEngine(Engine):
                     return values
                 return {options[0].name: value}

-        else:
-            get_expected_values_ = None
-
-        def set_value(module, data, api_version, options, values):
+        def set_value(
+            module: AnsibleModule,
+            data: dict[str, t.Any],
+            api_version: LooseVersion,
+            options: list[Option],
+            values: dict[str, t.Any],
+        ) -> None:
             if len(options) != 1:
                 raise AssertionError(
                     "host_config_value can only be used for a single option"
                 )
@@ -751,9 +1094,28 @@
                 value = preprocess_for_set(module, api_version, value)
data["HostConfig"][host_config_name] = value + update_value: ( + Callable[ + [ + AnsibleModule, + dict[str, t.Any], + LooseVersion, + list[Option], + dict[str, t.Any], + ], + None, + ] + | None + ) = None if update_parameter: - def update_value(module, data, api_version, options, values): + def update_value( # noqa: F811 + module: AnsibleModule, + data: dict[str, t.Any], + api_version: LooseVersion, + options: list[Option], + values: dict[str, t.Any], + ) -> None: if len(options) != 1: raise AssertionError( "update_parameter can only be used for a single option" @@ -765,9 +1127,6 @@ class DockerAPIEngine(Engine): value = preprocess_for_set(module, api_version, value) data[update_parameter] = value - else: - update_value = None - return cls( get_value=get_value, preprocess_value=preprocess_value_, @@ -780,35 +1139,13 @@ class DockerAPIEngine(Engine): ) -def _is_volume_permissions(mode): - for part in mode.split(","): - if part not in ( - "rw", - "ro", - "z", - "Z", - "consistent", - "delegated", - "cached", - "rprivate", - "private", - "rshared", - "shared", - "rslave", - "slave", - "nocopy", - ): - return False - return True - - -def _normalize_port(port): +def _normalize_port(port: str) -> str: if "/" not in port: return port + "/tcp" return port -def _get_default_host_ip(module, client): +def _get_default_host_ip(module: AnsibleModule, client: AnsibleDockerClient) -> str: if module.params["default_host_ip"] is not None: return module.params["default_host_ip"] ip = "0.0.0.0" @@ -828,8 +1165,13 @@ def _get_default_host_ip(module, client): def _get_value_detach_interactive( - module, container, api_version, options, image, host_info -): + module: AnsibleModule, + container: dict[str, t.Any], + api_version: LooseVersion, + options: list[Option], + image: dict[str, t.Any] | None, + host_info: dict[str, t.Any] | None, +) -> dict[str, t.Any]: attach_stdin = container["Config"].get("OpenStdin") attach_stderr = container["Config"].get("AttachStderr") attach_stdout = container["Config"].get("AttachStdout") @@ -839,7 +1181,13 @@ def _get_value_detach_interactive( } -def _set_value_detach_interactive(module, data, api_version, options, values): +def _set_value_detach_interactive( + module: AnsibleModule, + data: dict[str, t.Any], + api_version: LooseVersion, + options: list[Option], + values: dict[str, t.Any], +) -> None: interactive = values.get("interactive") detach = values.get("detach") @@ -856,7 +1204,14 @@ def _set_value_detach_interactive(module, data, api_version, options, values): data["StdinOnce"] = True -def _get_expected_env_value(module, client, api_version, image, value, sentry): +def _get_expected_env_value( + module: AnsibleModule, + client: AnsibleDockerClient, + api_version: LooseVersion, + image: dict[str, t.Any] | None, + value: t.Any, + sentry: Sentry, +) -> t.Any | Sentry: expected_env = {} if image and image["Config"].get("Env"): for env_var in image["Config"]["Env"]: @@ -872,13 +1227,23 @@ def _get_expected_env_value(module, client, api_version, image, value, sentry): return param_env -def _preprocess_cpus(module, client, api_version, value): +def _preprocess_cpus( + module: AnsibleModule, + client: AnsibleDockerClient, + api_version: LooseVersion, + value: t.Any, +) -> t.Any: if value is not None: value = int(round(value * 1e9)) return value -def _preprocess_devices(module, client, api_version, value): +def _preprocess_devices( + module: AnsibleModule, + client: AnsibleDockerClient, + api_version: LooseVersion, + value: t.Any, +) -> t.Any: if not value: return value 
     expected_devices = []
@@ -912,7 +1277,12 @@
     return expected_devices


-def _preprocess_rate_bps(module, client, api_version, value):
+def _preprocess_rate_bps(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    value: t.Any,
+) -> t.Any:
     if not value:
         return value
     devices = []
@@ -926,7 +1296,12 @@
     return devices


-def _preprocess_rate_iops(module, client, api_version, value):
+def _preprocess_rate_iops(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    value: t.Any,
+) -> t.Any:
     if not value:
         return value
     devices = []
@@ -940,7 +1315,12 @@
     return devices


-def _preprocess_device_requests(module, client, api_version, value):
+def _preprocess_device_requests(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    value: t.Any,
+) -> t.Any:
     if not value:
         return value
     device_requests = []
@@ -957,7 +1337,12 @@
     return device_requests


-def _preprocess_etc_hosts(module, client, api_version, value):
+def _preprocess_etc_hosts(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    value: t.Any,
+) -> t.Any:
     if value is None:
         return value
     results = []
@@ -966,7 +1351,12 @@
     return results


-def _preprocess_healthcheck(module, client, api_version, value):
+def _preprocess_healthcheck(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    value: t.Any,
+) -> t.Any:
     if value is None:
         return value
     if not value or not (
@@ -988,13 +1378,20 @@
     )


-def _postprocess_healthcheck_get_value(module, api_version, value, sentry):
+def _postprocess_healthcheck_get_value(
+    module: AnsibleModule, api_version: LooseVersion, value: t.Any, sentry: Sentry
+) -> t.Any | Sentry:
     if value is None or value is sentry or value.get("Test") == ["NONE"]:
         return {"Test": ["NONE"]}
     return value


-def _preprocess_convert_to_bytes(module, values, name, unlimited_value=None):
+def _preprocess_convert_to_bytes(
+    module: AnsibleModule,
+    values: dict[str, t.Any],
+    name: str,
+    unlimited_value: int | None = None,
+) -> dict[str, t.Any]:
     if name not in values:
         return values
     try:
@@ -1009,7 +1406,7 @@
         module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")


-def _get_image_labels(image):
+def _get_image_labels(image: dict[str, t.Any] | None) -> dict[str, str]:
     if not image:
         return {}

@@ -1017,7 +1414,14 @@
     return image["Config"].get("Labels") or {}


-def _get_expected_labels_value(module, client, api_version, image, value, sentry):
+def _get_expected_labels_value(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    image: dict[str, t.Any] | None,
+    value: dict[str, t.Any],
+    sentry: Sentry,
+) -> dict[str, t.Any] | Sentry:
     if value is sentry:
         return sentry
     expected_labels = {}
@@ -1027,7 +1431,12 @@
     return expected_labels


-def _preprocess_links(module, client, api_version, value):
+def _preprocess_links(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    value: t.Any,
+) -> t.Any:
     if value is None:
         return None

@@ -1044,15 +1453,15 @@

 def _ignore_mismatching_label_result(
-    module,
-    client,
-    api_version,
-    option,
-    image,
-    container_value,
-    expected_value,
-    host_info,
-):
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    option: Option,
+    image: dict[str, t.Any] | None,
+    container_value: t.Any,
+    expected_value: t.Any,
+    host_info: dict[str, t.Any] | None,
+) -> bool:
     if (
         option.comparison == "strict"
         and module.params["image_label_mismatch"] == "fail"
@@ -1076,20 +1485,20 @@
     return False


-def _needs_host_info_network(values):
+def _needs_host_info_network(values: dict[str, t.Any]) -> bool:
     return values.get("network_mode") == "default"


 def _ignore_mismatching_network_result(
-    module,
-    client,
-    api_version,
-    option,
-    image,
-    container_value,
-    expected_value,
-    host_info,
-):
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    option: Option,
+    image: dict[str, t.Any] | None,
+    container_value: t.Any,
+    expected_value: t.Any,
+    host_info: dict[str, t.Any] | None,
+) -> bool:
     # 'networks' is handled out-of-band
     if option.name == "networks":
         return True
@@ -1102,7 +1511,13 @@
     return False


-def _preprocess_network_values(module, client, api_version, options, values):
+def _preprocess_network_values(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> dict[str, t.Any]:
     if "networks" in values:
         for network in values["networks"]:
             network["id"] = _get_network_id(module, client, network["name"])
@@ -1119,27 +1534,40 @@
     return values


-def _get_network_id(module, client, network_name):
+def _get_network_id(
+    module: AnsibleModule, client: AnsibleDockerClient, network_name: str
+) -> str | None:
     try:
-        network_id = None
         params = {"filters": json.dumps({"name": [network_name]})}
         for network in client.get_json("/networks", params=params):
             if network["Name"] == network_name:
-                network_id = network["Id"]
-                break
-        return network_id
+                return network["Id"]
+        return None
     except Exception as exc:  # pylint: disable=broad-exception-caught
         client.fail(f"Error getting network id for {network_name} - {exc}")


-def _get_values_network(module, container, api_version, options, image, host_info):
+def _get_values_network(
+    module: AnsibleModule,
+    container: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     value = container["HostConfig"].get("NetworkMode", _SENTRY)
     if value is _SENTRY:
         return {}
     return {"network_mode": value}


-def _set_values_network(module, data, api_version, options, values):
+def _set_values_network(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "network_mode" not in values:
         return
     if "HostConfig" not in data:
@@ -1148,7 +1576,14 @@
     data["HostConfig"]["NetworkMode"] = value


-def _get_values_mounts(module, container, api_version, options, image, host_info):
+def _get_values_mounts(
+    module: AnsibleModule,
+    container: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     volumes = container["Config"].get("Volumes")
     binds = container["HostConfig"].get("Binds")
     # According to https://github.com/moby/moby/, support for HostConfig.Mounts
@@ -1157,10 +1592,10 @@
     # HostConfig.Mounts. I have no idea what about API 1.25...
     mounts = container["HostConfig"].get("Mounts")
     if mounts is not None:
-        result = []
-        empty_dict = {}
+        mounts_list = []
+        empty_dict: dict[str, t.Any] = {}
         for mount in mounts:
-            result.append(
+            mounts_list.append(
                 {
                     "type": mount.get("Type"),
                     "source": mount.get("Source"),
@@ -1207,7 +1642,7 @@
                     ),
                 }
             )
-        mounts = result
+        mounts = mounts_list
     result = {}
     if volumes is not None:
         result["volumes"] = volumes
@@ -1218,7 +1653,7 @@
     return result


-def _get_bind_from_dict(volume_dict):
+def _get_bind_from_dict(volume_dict: dict[str, t.Any] | None) -> list[str]:
     results = []
     if volume_dict:
         for host_path, config in volume_dict.items():
@@ -1229,7 +1664,7 @@
     return results


-def _get_image_binds(volumes):
+def _get_image_binds(volumes: dict[str, t.Any] | list[dict[str, t.Any]]) -> list[str]:
     """
     Convert array of binds to array of strings with format
        host_path:container_path:mode
@@ -1246,8 +1681,14 @@


 def _get_expected_values_mounts(
-    module, client, api_version, options, image, values, host_info
-):
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    values: dict[str, t.Any],
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     expected_values = {}

     # binds
@@ -1284,11 +1725,17 @@
     return expected_values


-def _set_values_mounts(module, data, api_version, options, values):
+def _set_values_mounts(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "mounts" in values:
         if "HostConfig" not in data:
             data["HostConfig"] = {}
-        mounts = []
+        mounts: list[dict[str, t.Any]] = []
         for mount in values["mounts"]:
             mount_type = mount.get("type")
             mount_res = {
@@ -1300,7 +1747,7 @@
             if "consistency" in mount:
                 mount_res["Consistency"] = mount["consistency"]
             if mount_type == "bind":
-                bind_opts = {}
+                bind_opts: dict[str, t.Any] = {}
                 if "propagation" in mount:
                     bind_opts["Propagation"] = mount["propagation"]
                 if "non_recursive" in mount:
@@ -1316,13 +1763,13 @@
                 if bind_opts:
                     mount_res["BindOptions"] = bind_opts
             if mount_type == "volume":
-                volume_opts = {}
+                volume_opts: dict[str, t.Any] = {}
                 if mount.get("no_copy"):
                     volume_opts["NoCopy"] = True
                 if mount.get("labels"):
                     volume_opts["Labels"] = mount.get("labels")
                 if mount.get("volume_driver"):
-                    driver_config = {
+                    driver_config: dict[str, t.Any] = {
                         "Name": mount.get("volume_driver"),
                     }
                     if mount.get("volume_options"):
@@ -1333,7 +1780,7 @@
                 if volume_opts:
                     mount_res["VolumeOptions"] = volume_opts
             if mount_type == "tmpfs":
-                tmpfs_opts = {}
+                tmpfs_opts: dict[str, t.Any] = {}
                 if mount.get("tmpfs_mode"):
                     tmpfs_opts["Mode"] = mount.get("tmpfs_mode")
                 if mount.get("tmpfs_size"):
@@ -1343,7 +1790,7 @@
                 if tmpfs_opts:
                     mount_res["TmpfsOptions"] = tmpfs_opts
             if mount_type == "image":
-                image_opts = {}
+                image_opts: dict[str, t.Any] = {}
                 if "subpath" in mount:
                     image_opts["Subpath"] = mount["subpath"]
                 if image_opts:
@@ -1351,7 +1798,7 @@
             mounts.append(mount_res)
         data["HostConfig"]["Mounts"] = mounts
     if "volumes" in values:
-        volumes = {}
+        volumes: dict[str, t.Any] = {}
         for volume in values["volumes"]:
             # Only pass anonymous volumes to create container
             if ":" in volume:
@@ -1369,7 +1816,14 @@
         data["HostConfig"]["Binds"] = values["volume_binds"]


-def _get_values_log(module, container, api_version, options, image, host_info):
+def _get_values_log(
+    module: AnsibleModule,
+    container: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     log_config = container["HostConfig"].get("LogConfig") or {}
     return {
         "log_driver": log_config.get("Type"),
@@ -1377,7 +1831,13 @@
     }


-def _set_values_log(module, data, api_version, options, values):
+def _set_values_log(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "log_driver" not in values:
         return
     log_config = {
@@ -1389,7 +1849,14 @@
     data["HostConfig"]["LogConfig"] = log_config


-def _get_values_platform(module, container, api_version, options, image, host_info):
+def _get_values_platform(
+    module: AnsibleModule,
+    container: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     if image and (image.get("Os") or image.get("Architecture") or image.get("Variant")):
         return {
             "platform": compose_platform_string(
@@ -1406,8 +1873,14 @@


 def _get_expected_values_platform(
-    module, client, api_version, options, image, values, host_info
-):
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    values: dict[str, t.Any],
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     expected_values = {}
     if "platform" in values:
         try:
@@ -1421,20 +1894,33 @@
     return expected_values


-def _set_values_platform(module, data, api_version, options, values):
+def _set_values_platform(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "platform" in values:
         data["platform"] = values["platform"]


-def _needs_container_image_platform(values):
+def _needs_container_image_platform(values: dict[str, t.Any]) -> bool:
     return "platform" in values


-def _needs_host_info_platform(values):
+def _needs_host_info_platform(values: dict[str, t.Any]) -> bool:
     return "platform" in values

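+# restart_policy and restart_retries are read from and written back to
+# HostConfig.RestartPolicy (Name / MaximumRetryCount).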
-def _get_values_restart(module, container, api_version, options, image, host_info):
+def _get_values_restart(
+    module: AnsibleModule,
+    container: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     restart_policy = container["HostConfig"].get("RestartPolicy") or {}
     return {
         "restart_policy": restart_policy.get("Name"),
@@ -1442,7 +1928,13 @@
     }


-def _set_values_restart(module, data, api_version, options, values):
+def _set_values_restart(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "restart_policy" not in values:
         return
     restart_policy = {
@@ -1454,7 +1946,13 @@
     data["HostConfig"]["RestartPolicy"] = restart_policy


-def _update_value_restart(module, data, api_version, options, values):
+def _update_value_restart(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "restart_policy" not in values:
         return
     data["RestartPolicy"] = {
@@ -1481,9 +1979,15 @@ def _get_values_ports(module, container, api_version, options, image, host_info)

 def _get_expected_values_ports(
-    module, client, api_version, options, image, values, host_info
-):
-    expected_values = {}
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    values: dict[str, t.Any],
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
+    expected_values: dict[str, t.Any] = {}

     if "published_ports" in values:
         expected_bound_ports = {}
@@ -1538,9 +2042,15 @@
     return expected_values


-def _set_values_ports(module, data, api_version, options, values):
+def _set_values_ports(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "ports" in values:
-        exposed_ports = {}
+        exposed_ports: dict[str, dict[str, t.Any]] = {}
         for port_definition in values["ports"]:
             port = port_definition
             proto = "tcp"
@@ -1562,7 +2072,13 @@
         data["HostConfig"]["PublishAllPorts"] = values["publish_all_ports"]


-def _preprocess_value_ports(module, client, api_version, options, values):
+def _preprocess_value_ports(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> dict[str, t.Any]:
     if "published_ports" not in values:
         return values
     found = False
@@ -1579,7 +2095,12 @@
     return values


-def _preprocess_container_names(module, client, api_version, value):
+def _preprocess_container_names(
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    value: t.Any,
+) -> t.Any:
     if value is None or not value.startswith("container:"):
         return value
     container_name = value[len("container:") :]
@@ -1594,14 +2115,27 @@
     return f"container:{container['Id']}"


-def _get_value_command(module, container, api_version, options, image, host_info):
+def _get_value_command(
+    module: AnsibleModule,
+    container: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     value = container["Config"].get("Cmd", _SENTRY)
     if value is _SENTRY:
         return {}
     return {"command": value}


-def _set_value_command(module, data, api_version, options, values):
+def _set_value_command(
+    module: AnsibleModule,
+    data: dict[str, t.Any],
+    api_version: LooseVersion,
+    options: list[Option],
+    values: dict[str, t.Any],
+) -> None:
     if "command" not in values:
         return
     value = values["command"]
@@ -1609,8 +2143,14 @@


 def _get_expected_values_command(
-    module, client, api_version, options, image, values, host_info
-):
+    module: AnsibleModule,
+    client: AnsibleDockerClient,
+    api_version: LooseVersion,
+    options: list[Option],
+    image: dict[str, t.Any] | None,
+    values: dict[str, t.Any],
+    host_info: dict[str, t.Any] | None,
+) -> dict[str, t.Any]:
     expected_values = {}
     if "command" in values:
         command = values["command"]
@@ -1620,7 +2160,7 @@
     return expected_values


-def _needs_container_image_command(values):
+def _needs_container_image_command(values: dict[str, t.Any]) -> bool:
     return values.get("command") == []
diff --git a/plugins/module_utils/_module_container/module.py b/plugins/module_utils/_module_container/module.py
index 486dfdc9..c234ed55 100644
--- a/plugins/module_utils/_module_container/module.py
+++ b/plugins/module_utils/_module_container/module.py
@@ -9,6 +9,7 @@
 from __future__ import annotations

 import re
+import typing as t
 from time import sleep

 from ansible.module_utils.common.text.converters import to_text
@@ -16,6 +17,9 @@ from ansible.module_utils.common.text.converters import to_text
 from ansible_collections.community.docker.plugins.module_utils._api.utils.utils import (
     parse_repository_tag,
 )
+from ansible_collections.community.docker.plugins.module_utils._module_container.base import (
+    Client,
+)
 from ansible_collections.community.docker.plugins.module_utils._util import (
     DifferenceTracker,
     DockerBaseClass,
@@ -25,13 +29,23 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
 )


+if t.TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from ansible.module_utils.basic import AnsibleModule
+
+    from .base import Engine, EngineDriver, Option, OptionGroup
+
+
 class Container(DockerBaseClass):
-    def __init__(self, container, engine_driver):
+    def __init__(
+        self, container: dict[str, t.Any] | None, engine_driver: EngineDriver
+    ) -> None:
         super().__init__()
         self.raw = container
-        self.id = None
-        self.image = None
-        self.image_name = None
+        self.id: str | None = None
+        self.image: str | None = None
+        self.image_name: str | None = None
         self.container = container
         self.engine_driver = engine_driver
         if container:
@@ -41,11 +55,11 @@ class Container(DockerBaseClass):
             self.log(self.container, pretty_print=True)

     @property
-    def exists(self):
+    def exists(self) -> bool:
         return bool(self.container)

     @property
-    def removing(self):
+    def removing(self) -> bool:
         return (
             self.engine_driver.is_container_removing(self.container)
             if self.container
@@ -53,7 +67,7 @@ class Container(DockerBaseClass):
         )

     @property
-    def running(self):
+    def running(self) -> bool:
         return (
             self.engine_driver.is_container_running(self.container)
             if self.container
@@ -61,7 +75,7 @@ class Container(DockerBaseClass):
         )

     @property
-    def paused(self):
+    def paused(self) -> bool:
         return (
             self.engine_driver.is_container_paused(self.container)
             if self.container
@@ -69,8 +83,14 @@ class Container(DockerBaseClass):
         )


-class ContainerManager(DockerBaseClass):
-    def __init__(self, module, engine_driver, client, active_options):
+class ContainerManager(DockerBaseClass, t.Generic[Client]):
+    def __init__(
+        self,
+        module: AnsibleModule,
+        engine_driver: EngineDriver,
+        client: Client,
+        active_options: list[OptionGroup],
+    ) -> None:
         super().__init__()
         self.module = module
         self.engine_driver = engine_driver
@@ -78,46 +98,64 @@ class ContainerManager(DockerBaseClass):
         self.options = active_options
         self.all_options = self._collect_all_options(active_options)
         self.check_mode = self.module.check_mode
-        self.param_cleanup = self.module.params["cleanup"]
-        self.param_container_default_behavior = self.module.params[
-            "container_default_behavior"
-        ]
-        self.param_default_host_ip = self.module.params["default_host_ip"]
-        self.param_debug = self.module.params["debug"]
-        self.param_force_kill = self.module.params["force_kill"]
-        self.param_image = self.module.params["image"]
-        self.param_image_comparison = self.module.params["image_comparison"]
-        self.param_image_label_mismatch = self.module.params["image_label_mismatch"]
-        self.param_image_name_mismatch = self.module.params["image_name_mismatch"]
-        self.param_keep_volumes = self.module.params["keep_volumes"]
-        self.param_kill_signal = self.module.params["kill_signal"]
-        self.param_name = self.module.params["name"]
-        self.param_networks_cli_compatible = self.module.params[
+        self.param_cleanup: bool = self.module.params["cleanup"]
+        self.param_container_default_behavior: t.Literal[
+            "compatibility", "no_defaults"
+        ] = self.module.params["container_default_behavior"]
+        self.param_default_host_ip: str | None = self.module.params["default_host_ip"]
+        self.param_debug: bool = self.module.params["debug"]
+        self.param_force_kill: bool = self.module.params["force_kill"]
+        self.param_image: str | None = self.module.params["image"]
+        self.param_image_comparison: t.Literal["desired-image", "current-image"] = (
+            self.module.params["image_comparison"]
+        )
+        self.param_image_label_mismatch: t.Literal["ignore", "fail"] = (
+            self.module.params["image_label_mismatch"]
+        )
+        self.param_image_name_mismatch: t.Literal["ignore", "recreate"] = (
+            self.module.params["image_name_mismatch"]
+        )
+        self.param_keep_volumes: bool = self.module.params["keep_volumes"]
+        self.param_kill_signal: str | None = self.module.params["kill_signal"]
+        self.param_name: str = self.module.params["name"]
+        self.param_networks_cli_compatible: bool = self.module.params[
             "networks_cli_compatible"
         ]
-        self.param_output_logs = self.module.params["output_logs"]
-        self.param_paused = self.module.params["paused"]
-        self.param_pull = self.module.params["pull"]
-        if self.param_pull is True:
-            self.param_pull = "always"
-        if self.param_pull is False:
-            self.param_pull = "missing"
-        self.param_pull_check_mode_behavior = self.module.params[
-            "pull_check_mode_behavior"
+        self.param_output_logs: bool = self.module.params["output_logs"]
+        self.param_paused: bool | None = self.module.params["paused"]
+        param_pull: t.Literal["never", "missing", "always", True, False] = (
+            self.module.params["pull"]
+        )
+        if param_pull is True:
+            param_pull = "always"
+        if param_pull is False:
+            param_pull = "missing"
+        self.param_pull: t.Literal["never", "missing", "always"] = param_pull
+        self.param_pull_check_mode_behavior: t.Literal[
+            "image_not_present", "always"
+        ] = self.module.params["pull_check_mode_behavior"]
+        self.param_recreate: bool = self.module.params["recreate"]
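+        # Mirror the module params onto typed attributes; everything after
+        # __init__ can then rely on these precise types.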
+        self.param_removal_wait_timeout: int | float | None = self.module.params[
+            "removal_wait_timeout"
         ]
-        self.param_recreate = self.module.params["recreate"]
-        self.param_removal_wait_timeout = self.module.params["removal_wait_timeout"]
-        self.param_healthy_wait_timeout = self.module.params["healthy_wait_timeout"]
-        if self.param_healthy_wait_timeout <= 0:
+        self.param_healthy_wait_timeout: int | float | None = self.module.params[
+            "healthy_wait_timeout"
+        ]
+        if (
+            self.param_healthy_wait_timeout is not None
+            and self.param_healthy_wait_timeout <= 0
+        ):
             self.param_healthy_wait_timeout = None
-        self.param_restart = self.module.params["restart"]
-        self.param_state = self.module.params["state"]
+        self.param_restart: bool = self.module.params["restart"]
+        self.param_state: t.Literal[
+            "absent", "present", "healthy", "started", "stopped"
+        ] = self.module.params["state"]
         self._parse_comparisons()
         self._update_params()
         self.results = {"changed": False, "actions": []}
-        self.diff = {}
+        self.diff: dict[str, t.Any] = {}
         self.diff_tracker = DifferenceTracker()
-        self.facts = {}
+        self.facts: dict[str, t.Any] | None = {}
         if self.param_default_host_ip:
             valid_ip = False
             if re.match(
@@ -134,16 +172,22 @@ class ContainerManager(DockerBaseClass):
                 "The value of default_host_ip must be an empty string, an IPv4 address, "
                 f'or an IPv6 address. Got "{self.param_default_host_ip}" instead.'
             )
-        self.parameters = None
+        self.parameters: list[tuple[OptionGroup, dict[str, t.Any]]] | None = None

-    def _collect_all_options(self, active_options):
+    def _add_action(self, action: dict[str, t.Any]) -> None:
+        actions: list[dict[str, t.Any]] = self.results["actions"]  # type: ignore
+        actions.append(action)
+
+    def _collect_all_options(
+        self, active_options: list[OptionGroup]
+    ) -> dict[str, Option]:
         all_options = {}
         for options in active_options:
             for option in options.options:
                 all_options[option.name] = option
         return all_options

-    def _collect_all_module_params(self):
+    def _collect_all_module_params(self) -> set[str]:
         all_module_options = set()
         for option, data in self.module.argument_spec.items():
             all_module_options.add(option)
@@ -152,7 +196,7 @@ class ContainerManager(DockerBaseClass):
                 all_module_options.add(alias)
         return all_module_options

-    def _parse_comparisons(self):
+    def _parse_comparisons(self) -> None:
         # Keep track of all module params and all option aliases
         all_module_options = self._collect_all_module_params()
         comp_aliases = {}
@@ -163,10 +207,11 @@ class ContainerManager(DockerBaseClass):
             for alias in option.ansible_aliases:
                 comp_aliases[alias] = option_name
         # Process comparisons specified by user
-        if self.module.params.get("comparisons"):
+        comparisons: dict[str, t.Any] | None = self.module.params.get("comparisons")
+        if comparisons:
             # If '*' appears in comparisons, process it first
-            if "*" in self.module.params["comparisons"]:
-                value = self.module.params["comparisons"]["*"]
+            if "*" in comparisons:
+                value = comparisons["*"]
                 if value not in ("strict", "ignore"):
                     self.fail(
                         "The wildcard can only be used with comparison modes 'strict' and 'ignore'!"
@@ -179,8 +224,8 @@ class ContainerManager(DockerBaseClass):
                         continue
                     option.comparison = value
             # Now process all other comparisons.
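+            # Each key is resolved through the alias map so that an option
+            # alias configures the same comparison as the option's main name.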
-            comp_aliases_used = {}
-            for key, value in self.module.params["comparisons"].items():
+            comp_aliases_used: dict[str, str] = {}
+            for key, value in comparisons.items():
                 if key == "*":
                     continue
                 # Find main key
@@ -220,7 +265,7 @@ class ContainerManager(DockerBaseClass):
                         option.copy_comparison_from
                     ].comparison

-    def _update_params(self):
+    def _update_params(self) -> None:
         if (
             self.param_networks_cli_compatible is True
             and self.module.params["networks"]
@@ -247,12 +292,14 @@ class ContainerManager(DockerBaseClass):
             if self.module.params[param] is None:
                 self.module.params[param] = value

-    def fail(self, *args, **kwargs):
-        self.client.fail(*args, **kwargs)
+    def fail(self, *args, **kwargs) -> t.NoReturn:
+        # mypy doesn't know that Client has fail() method
+        raise self.client.fail(*args, **kwargs)  # type: ignore

-    def run(self):
+    def run(self) -> None:
         if self.param_state in ("stopped", "started", "present", "healthy"):
-            self.present(self.param_state)
+            # mypy doesn't get that self.param_state has only one of the above values
+            self.present(self.param_state)  # type: ignore
         elif self.param_state == "absent":
             self.absent()
@@ -270,15 +317,16 @@ class ContainerManager(DockerBaseClass):

     def wait_for_state(
         self,
-        container_id,
-        complete_states=None,
-        wait_states=None,
-        accept_removal=False,
-        max_wait=None,
-        health_state=False,
-    ):
+        container_id: str,
+        *,
+        complete_states: Sequence[str | None] | None = None,
+        wait_states: Sequence[str | None] | None = None,
+        accept_removal: bool = False,
+        max_wait: int | float | None = None,
+        health_state: bool = False,
+    ) -> dict[str, t.Any] | None:
         delay = 1.0
-        total_wait = 0
+        total_wait = 0.0
         while True:
             # Inspect container
             result = self.engine_driver.inspect_container_by_id(
@@ -314,7 +362,9 @@ class ContainerManager(DockerBaseClass):
            # code will have slept for ~1.5 minutes.)
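+            # Exponential backoff: the polling delay grows by 10% per
+            # iteration, capped at 10 seconds.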
             delay = min(delay * 1.1, 10)

-    def _collect_params(self, active_options):
+    def _collect_params(
+        self, active_options: list[OptionGroup]
+    ) -> list[tuple[OptionGroup, dict[str, t.Any]]]:
         parameters = []
         for options in active_options:
             values = {}
@@ -336,21 +386,25 @@ class ContainerManager(DockerBaseClass):
             parameters.append((options, values))
         return parameters

-    def _needs_container_image(self):
+    def _needs_container_image(self) -> bool:
+        assert self.parameters is not None
         for options, values in self.parameters:
             engine = options.get_engine(self.engine_driver.name)
             if engine.needs_container_image(values):
                 return True
         return False

-    def _needs_host_info(self):
+    def _needs_host_info(self) -> bool:
+        assert self.parameters is not None
         for options, values in self.parameters:
             engine = options.get_engine(self.engine_driver.name)
             if engine.needs_host_info(values):
                 return True
         return False

-    def present(self, state):
+    def present(
+        self, state: t.Literal["stopped", "started", "present", "healthy"]
+    ) -> None:
         self.parameters = self._collect_params(self.options)
         container = self._get_container(self.param_name)
         was_running = container.running
@@ -382,6 +436,7 @@ class ContainerManager(DockerBaseClass):
             self.diff_tracker.add("exists", parameter=True, active=False)
             if container.removing and not self.check_mode:
                 # Wait for container to be removed before trying to create it
+                assert container.id is not None
                 self.wait_for_state(
                     container.id,
                     wait_states=["removing"],
@@ -394,6 +449,7 @@ class ContainerManager(DockerBaseClass):
             container_created = True
         else:
             # Existing container
+            assert container.id is not None
             different, differences = self.has_different_configuration(
                 container, container_image, comparison_image, host_info
             )
@@ -453,13 +509,16 @@ class ContainerManager(DockerBaseClass):

         if state in ("started", "healthy") and not container.running:
             self.diff_tracker.add("running", parameter=True, active=was_running)
+            assert container.id is not None
             container = self.container_start(container.id)
         elif state in ("started", "healthy") and self.param_restart:
             self.diff_tracker.add("running", parameter=True, active=was_running)
             self.diff_tracker.add("restarted", parameter=True, active=False)
+            assert container.id is not None
             container = self.container_restart(container.id)
         elif state == "stopped" and container.running:
             self.diff_tracker.add("running", parameter=False, active=was_running)
+            assert container.id is not None
             self.container_stop(container.id)
             container = self._get_container(container.id)
@@ -472,6 +531,7 @@ class ContainerManager(DockerBaseClass):
                 "paused", parameter=self.param_paused, active=was_paused
             )
             if not self.check_mode:
+                assert container.id is not None
                 try:
                     if self.param_paused:
                         self.engine_driver.pause_container(
@@ -487,12 +547,13 @@ class ContainerManager(DockerBaseClass):
                     )
                 container = self._get_container(container.id)
             self.results["changed"] = True
-            self.results["actions"].append({"set_paused": self.param_paused})
+            self._add_action({"set_paused": self.param_paused})

         self.facts = container.raw

         if state == "healthy" and not self.check_mode:
             # `None` means that no health check enabled; simply treat this as 'healthy'
+            assert container.id is not None
             inspect_result = self.wait_for_state(
                 container.id,
                 wait_states=["starting", "unhealthy"],
@@ -504,41 +565,51 @@ class ContainerManager(DockerBaseClass):
             # Return the latest inspection results retrieved
             self.facts = inspect_result

-    def absent(self):
+    def absent(self) -> None:
         container = self._get_container(self.param_name)
container.exists: + assert container.id is not None if container.running: self.diff_tracker.add("running", parameter=False, active=True) self.container_stop(container.id) self.diff_tracker.add("exists", parameter=False, active=True) self.container_remove(container.id) - def _output_logs(self, msg): + def _output_logs(self, msg: str | bytes) -> None: self.module.log(msg=msg) - def _get_container(self, container): + def _get_container(self, container: str) -> Container: """ Expects container ID or Name. Returns a container object """ - container = self.engine_driver.inspect_container_by_name(self.client, container) - return Container(container, self.engine_driver) + container_data = self.engine_driver.inspect_container_by_name( + self.client, container + ) + return Container(container_data, self.engine_driver) - def _get_container_image(self, container, fallback=None): + def _get_container_image( + self, container: Container, fallback: dict[str, t.Any] | None = None + ) -> dict[str, t.Any] | None: if not container.exists or container.removing: return fallback image = container.image + assert image is not None if is_image_name_id(image): - image = self.engine_driver.inspect_image_by_id(self.client, image) + image_data = self.engine_driver.inspect_image_by_id(self.client, image) else: repository, tag = parse_repository_tag(image) if not tag: tag = "latest" - image = self.engine_driver.inspect_image_by_name( + image_data = self.engine_driver.inspect_image_by_name( self.client, repository, tag ) - return image or fallback + return image_data or fallback - def _get_image(self, container, needs_container_image=False): + def _get_image( + self, container: Container, needs_container_image: bool = False + ) -> tuple[ + dict[str, t.Any] | None, dict[str, t.Any] | None, dict[str, t.Any] | None + ]: image_parameter = self.param_image get_container_image = needs_container_image or not image_parameter container_image = ( @@ -553,7 +624,7 @@ class ContainerManager(DockerBaseClass): if is_image_name_id(image_parameter): image = self.engine_driver.inspect_image_by_id(self.client, image_parameter) if image is None: - self.client.fail(f"Cannot find image with ID {image_parameter}") + self.fail(f"Cannot find image with ID {image_parameter}") else: repository, tag = parse_repository_tag(image_parameter) if not tag: @@ -562,7 +633,7 @@ class ContainerManager(DockerBaseClass): self.client, repository, tag ) if not image and self.param_pull == "never": - self.client.fail( + self.fail( f"Cannot find image with name {repository}:{tag}, and pull=never" ) if not image or self.param_pull == "always": @@ -576,12 +647,12 @@ class ContainerManager(DockerBaseClass): ) if already_to_latest: self.results["changed"] = False - self.results["actions"].append( + self._add_action( {"pulled_image": f"{repository}:{tag}", "changed": False} ) else: self.results["changed"] = True - self.results["actions"].append( + self._add_action( {"pulled_image": f"{repository}:{tag}", "changed": True} ) elif not image or self.param_pull_check_mode_behavior == "always": @@ -589,10 +660,10 @@ class ContainerManager(DockerBaseClass): # pull. (Implicitly: if the image is there, claim it already was latest unless # pull_check_mode_behavior == 'always'.) 
self.results["changed"] = True - action = {"pulled_image": f"{repository}:{tag}"} + action: dict[str, t.Any] = {"pulled_image": f"{repository}:{tag}"} if not image: action["changed"] = True - self.results["actions"].append(action) + self._add_action(action) self.log("image") self.log(image, pretty_print=True) @@ -605,7 +676,9 @@ class ContainerManager(DockerBaseClass): return image, container_image, comparison_image - def _image_is_different(self, image, container): + def _image_is_different( + self, image: dict[str, t.Any] | None, container: Container + ) -> bool: if image and image.get("Id"): if container and container.image: if image.get("Id") != container.image: @@ -615,8 +688,9 @@ class ContainerManager(DockerBaseClass): return True return False - def _compose_create_parameters(self, image): - params = {} + def _compose_create_parameters(self, image: str) -> dict[str, t.Any]: + params: dict[str, t.Any] = {} + assert self.parameters is not None for options, values in self.parameters: engine = options.get_engine(self.engine_driver.name) if engine.can_set_value(self.engine_driver.get_api_version(self.client)): @@ -632,15 +706,16 @@ class ContainerManager(DockerBaseClass): def _record_differences( self, - differences, - options, - param_values, - engine, - container, - container_image, - image, - host_info, + differences: DifferenceTracker, + options: OptionGroup, + param_values: dict[str, t.Any], + engine: Engine, + container: Container, + container_image: dict[str, t.Any] | None, + image: dict[str, t.Any] | None, + host_info: dict[str, t.Any] | None, ): + assert container.raw is not None container_values = engine.get_value( self.module, container.raw, @@ -709,9 +784,16 @@ class ContainerManager(DockerBaseClass): c = sorted(c, key=sort_key_fn) differences.add(option.name, parameter=p, active=c) - def has_different_configuration(self, container, container_image, image, host_info): + def has_different_configuration( + self, + container: Container, + container_image: dict[str, t.Any] | None, + image: dict[str, t.Any] | None, + host_info: dict[str, t.Any] | None, + ) -> tuple[bool, DifferenceTracker]: differences = DifferenceTracker() update_differences = DifferenceTracker() + assert self.parameters is not None for options, param_values in self.parameters: engine = options.get_engine(self.engine_driver.name) if engine.can_update_value(self.engine_driver.get_api_version(self.client)): @@ -743,9 +825,14 @@ class ContainerManager(DockerBaseClass): return has_differences, differences def has_different_resource_limits( - self, container, container_image, image, host_info - ): + self, + container: Container, + container_image: dict[str, t.Any] | None, + image: dict[str, t.Any] | None, + host_info: dict[str, t.Any] | None, + ) -> tuple[bool, DifferenceTracker]: differences = DifferenceTracker() + assert self.parameters is not None for options, param_values in self.parameters: engine = options.get_engine(self.engine_driver.name) if not engine.can_update_value( @@ -765,8 +852,9 @@ class ContainerManager(DockerBaseClass): has_differences = not differences.empty return has_differences, differences - def _compose_update_parameters(self): - result = {} + def _compose_update_parameters(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} + assert self.parameters is not None for options, values in self.parameters: engine = options.get_engine(self.engine_driver.name) if not engine.can_update_value( @@ -782,7 +870,13 @@ class ContainerManager(DockerBaseClass): ) return result - def update_limits(self, 
container, container_image, image, host_info): + def update_limits( + self, + container: Container, + container_image: dict[str, t.Any] | None, + image: dict[str, t.Any] | None, + host_info: dict[str, t.Any] | None, + ) -> Container: limits_differ, different_limits = self.has_different_resource_limits( container, container_image, image, host_info ) @@ -793,20 +887,24 @@ class ContainerManager(DockerBaseClass): ) self.diff_tracker.merge(different_limits) if limits_differ and not self.check_mode: + assert container.id is not None self.container_update(container.id, self._compose_update_parameters()) return self._get_container(container.id) return container - def has_network_differences(self, container): + def has_network_differences( + self, container: Container + ) -> tuple[bool, list[dict[str, t.Any]]]: """ Check if the container is connected to requested networks with expected options: links, aliases, ipv4, ipv6 """ different = False - differences = [] + differences: list[dict[str, t.Any]] = [] if not self.module.params["networks"]: return different, differences + assert container.container is not None if not container.container.get("NetworkSettings"): self.fail( "has_missing_networks: Error parsing container properties. NetworkSettings missing." @@ -869,13 +967,16 @@ class ContainerManager(DockerBaseClass): ) return different, differences - def has_extra_networks(self, container): + def has_extra_networks( + self, container: Container + ) -> tuple[bool, list[dict[str, t.Any]]]: """ Check if the container is connected to non-requested networks """ - extra_networks = [] + extra_networks: list[dict[str, t.Any]] = [] extra = False + assert container.container is not None if not container.container.get("NetworkSettings"): self.fail( "has_extra_networks: Error parsing container properties. NetworkSettings missing." 
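
A note on the many `assert container.id is not None` / `assert self.parameters is not None` lines added in the hunks above: these are type-narrowing hints rather than runtime sanity checks. mypy cannot deduce from `container.exists` or `container.running` that `container.id` is non-`None`, so an explicit `assert` narrows the attribute from `str | None` to `str` before it reaches APIs that take a plain `str`. A minimal, self-contained sketch of the pattern — the `Container` class and `stop()` helper below are illustrative stand-ins, not the collection's real types:

```python
# Illustrative sketch only; `Container` and `stop()` are stand-ins.
from __future__ import annotations


class Container:
    def __init__(self, container_id: str | None) -> None:
        self.id = container_id  # inferred as str | None

    @property
    def exists(self) -> bool:
        return self.id is not None


def stop(container_id: str) -> None:  # requires a non-optional id
    print(f"stopping {container_id}")


def stop_if_exists(container: Container) -> None:
    if container.exists:
        # mypy cannot see through the `exists` property, so narrow explicitly;
        # without this assert it flags the call as incompatible
        # (`str | None` passed where `str` is expected).
        assert container.id is not None
        stop(container.id)


stop_if_exists(Container("abc123"))
```
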
@@ -896,7 +997,9 @@ class ContainerManager(DockerBaseClass): ) return extra, extra_networks - def update_networks(self, container, container_created): + def update_networks( + self, container: Container, container_created: bool + ) -> Container: updated_container = container if self.all_options["networks"].comparison != "ignore" or container_created: has_network_differences, network_differences = self.has_network_differences( @@ -939,13 +1042,14 @@ class ContainerManager(DockerBaseClass): updated_container = self._purge_networks(container, extra_networks) return updated_container - def _add_networks(self, container, differences): + def _add_networks( + self, container: Container, differences: list[dict[str, t.Any]] + ) -> Container: + assert container.id is not None for diff in differences: # remove the container from the network, if connected if diff.get("container"): - self.results["actions"].append( - {"removed_from_network": diff["parameter"]["name"]} - ) + self._add_action({"removed_from_network": diff["parameter"]["name"]}) if not self.check_mode: try: self.engine_driver.disconnect_container_from_network( @@ -956,7 +1060,7 @@ class ContainerManager(DockerBaseClass): f"Error disconnecting container from network {diff['parameter']['name']} - {exc}" ) # connect to the network - self.results["actions"].append( + self._add_action( { "added_to_network": diff["parameter"]["name"], "network_parameters": diff["parameter"], @@ -982,9 +1086,12 @@ class ContainerManager(DockerBaseClass): ) return self._get_container(container.id) - def _purge_networks(self, container, networks): + def _purge_networks( + self, container: Container, networks: list[dict[str, t.Any]] + ) -> Container: + assert container.id is not None for network in networks: - self.results["actions"].append({"removed_from_network": network["name"]}) + self._add_action({"removed_from_network": network["name"]}) if not self.check_mode: try: self.engine_driver.disconnect_container_from_network( @@ -996,7 +1103,7 @@ class ContainerManager(DockerBaseClass): ) return self._get_container(container.id) - def container_create(self, image): + def container_create(self, image: str) -> Container | None: create_parameters = self._compose_create_parameters(image) self.log("create container") self.log(f"image: {image} parameters:") @@ -1014,7 +1121,7 @@ class ContainerManager(DockerBaseClass): for key, value in network.items() if key not in ("name", "id") } - self.results["actions"].append( + self._add_action( { "created": "Created container", "create_parameters": create_parameters, @@ -1022,7 +1129,6 @@ class ContainerManager(DockerBaseClass): } ) self.results["changed"] = True - new_container = None if not self.check_mode: try: container_id = self.engine_driver.create_container( @@ -1031,11 +1137,11 @@ class ContainerManager(DockerBaseClass): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error creating container: {exc}") return self._get_container(container_id) - return new_container + return None - def container_start(self, container_id): + def container_start(self, container_id: str) -> Container: self.log(f"start container {container_id}") - self.results["actions"].append({"started": container_id}) + self._add_action({"started": container_id}) self.results["changed"] = True if not self.check_mode: try: @@ -1047,9 +1153,11 @@ class ContainerManager(DockerBaseClass): status = self.engine_driver.wait_for_container( self.client, container_id ) - self.client.fail_results["status"] = status + # mypy doesn't know 
that Client has fail_results property + self.client.fail_results["status"] = status # type: ignore self.results["status"] = status + output: str | bytes if self.module.params["auto_remove"]: output = "Cannot retrieve result as auto_remove is enabled" if self.param_output_logs: @@ -1077,12 +1185,14 @@ class ContainerManager(DockerBaseClass): return insp return self._get_container(container_id) - def container_remove(self, container_id, link=False, force=False): + def container_remove( + self, container_id: str, link: bool = False, force: bool = False + ) -> None: volume_state = not self.param_keep_volumes self.log( f"remove container container:{container_id} v:{volume_state} link:{link} force{force}" ) - self.results["actions"].append( + self._add_action( { "removed": container_id, "volume_state": volume_state, @@ -1101,13 +1211,15 @@ class ContainerManager(DockerBaseClass): force=force, ) except Exception as exc: # pylint: disable=broad-exception-caught - self.client.fail(f"Error removing container {container_id}: {exc}") + self.fail(f"Error removing container {container_id}: {exc}") - def container_update(self, container_id, update_parameters): + def container_update( + self, container_id: str, update_parameters: dict[str, t.Any] + ) -> Container: if update_parameters: self.log(f"update container {container_id}") self.log(update_parameters, pretty_print=True) - self.results["actions"].append( + self._add_action( {"updated": container_id, "update_parameters": update_parameters} ) self.results["changed"] = True @@ -1120,10 +1232,8 @@ class ContainerManager(DockerBaseClass): self.fail(f"Error updating container {container_id}: {exc}") return self._get_container(container_id) - def container_kill(self, container_id): - self.results["actions"].append( - {"killed": container_id, "signal": self.param_kill_signal} - ) + def container_kill(self, container_id: str) -> None: + self._add_action({"killed": container_id, "signal": self.param_kill_signal}) self.results["changed"] = True if not self.check_mode: try: @@ -1133,8 +1243,8 @@ class ContainerManager(DockerBaseClass): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error killing container {container_id}: {exc}") - def container_restart(self, container_id): - self.results["actions"].append( + def container_restart(self, container_id: str) -> Container: + self._add_action( {"restarted": container_id, "timeout": self.module.params["stop_timeout"]} ) self.results["changed"] = True @@ -1147,11 +1257,11 @@ class ContainerManager(DockerBaseClass): self.fail(f"Error restarting container {container_id}: {exc}") return self._get_container(container_id) - def container_stop(self, container_id): + def container_stop(self, container_id: str) -> None: if self.param_force_kill: self.container_kill(container_id) return - self.results["actions"].append( + self._add_action( {"stopped": container_id, "timeout": self.module.params["stop_timeout"]} ) self.results["changed"] = True @@ -1164,7 +1274,7 @@ class ContainerManager(DockerBaseClass): self.fail(f"Error stopping container {container_id}: {exc}") -def run_module(engine_driver): +def run_module(engine_driver: EngineDriver) -> None: module, active_options, client = engine_driver.setup( argument_spec={ "cleanup": {"type": "bool", "default": False}, @@ -1228,7 +1338,7 @@ def run_module(engine_driver): ], ) - def execute(): + def execute() -> t.NoReturn: cm = ContainerManager(module, engine_driver, client, active_options) cm.run() module.exit_json(**sanitize_result(cm.results)) diff 
--git a/plugins/module_utils/_platform.py b/plugins/module_utils/_platform.py index 567a4173..3807b6a0 100644 --- a/plugins/module_utils/_platform.py +++ b/plugins/module_utils/_platform.py @@ -14,12 +14,13 @@ from __future__ import annotations import re +import typing as t _VALID_STR = re.compile("^[A-Za-z0-9_-]+$") -def _validate_part(string, part, part_name): +def _validate_part(string: str, part: str, part_name: str) -> str: if not part: raise ValueError(f'Invalid platform string "{string}": {part_name} is empty') if not _VALID_STR.match(part): @@ -79,7 +80,7 @@ _KNOWN_ARCH = ( ) -def _normalize_os(os_str): +def _normalize_os(os_str: str) -> str: # See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go os_str = os_str.lower() if os_str == "macos": @@ -112,7 +113,7 @@ _NORMALIZE_ARCH = { } -def _normalize_arch(arch_str, variant_str): +def _normalize_arch(arch_str: str, variant_str: str) -> tuple[str, str]: # See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go arch_str = arch_str.lower() variant_str = variant_str.lower() @@ -121,15 +122,16 @@ def _normalize_arch(arch_str, variant_str): res = _NORMALIZE_ARCH.get((arch_str, None)) if res is None: return arch_str, variant_str - if res is not None: - arch_str = res[0] - if res[1] is not None: - variant_str = res[1] - return arch_str, variant_str + arch_str = res[0] + if res[1] is not None: + variant_str = res[1] + return arch_str, variant_str class _Platform: - def __init__(self, os=None, arch=None, variant=None): + def __init__( + self, os: str | None = None, arch: str | None = None, variant: str | None = None + ) -> None: self.os = os self.arch = arch self.variant = variant @@ -140,7 +142,12 @@ class _Platform: raise ValueError("If variant is given, os must be given too") @classmethod - def parse_platform_string(cls, string, daemon_os=None, daemon_arch=None): + def parse_platform_string( + cls, + string: str | None, + daemon_os: str | None = None, + daemon_arch: str | None = None, + ) -> t.Self: # See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go if string is None: return cls() @@ -182,6 +189,7 @@ class _Platform: ) if variant is not None and not variant: raise ValueError(f'Invalid platform string "{string}": variant is empty') + assert arch is not None # otherwise variant would be None as well arch, variant = _normalize_arch(arch, variant or "") if len(parts) == 2 and arch == "arm" and variant == "v7": variant = None @@ -189,9 +197,12 @@ class _Platform: variant = "v8" return cls(os=_normalize_os(os), arch=arch, variant=variant or None) - def __str__(self): + def __str__(self) -> str: if self.variant: - parts = [self.os, self.arch, self.variant] + assert ( + self.os is not None and self.arch is not None + ) # ensured in constructor + parts: list[str] = [self.os, self.arch, self.variant] elif self.os: if self.arch: parts = [self.os, self.arch] @@ -203,12 +214,14 @@ class _Platform: parts = [] return "/".join(parts) - def __repr__(self): + def __repr__(self) -> str: return ( f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Platform): + return NotImplemented return ( self.os == other.os and self.arch == other.arch @@ -216,7 +229,9 @@ class _Platform: ) -def normalize_platform_string(string, daemon_os=None, daemon_arch=None): +def normalize_platform_string( + string: str, daemon_os: str | None = 
None, daemon_arch: str | None = None +) -> str: return str( _Platform.parse_platform_string( string, daemon_os=daemon_os, daemon_arch=daemon_arch @@ -225,8 +240,12 @@ def normalize_platform_string(string, daemon_os=None, daemon_arch=None): def compose_platform_string( - os=None, arch=None, variant=None, daemon_os=None, daemon_arch=None -): + os: str | None = None, + arch: str | None = None, + variant: str | None = None, + daemon_os: str | None = None, + daemon_arch: str | None = None, +) -> str: if os is None and daemon_os is not None: os = _normalize_os(daemon_os) if arch is None and daemon_arch is not None: @@ -235,7 +254,7 @@ def compose_platform_string( return str(_Platform(os=os, arch=arch, variant=variant or None)) -def compare_platform_strings(string1, string2): +def compare_platform_strings(string1: str, string2: str) -> bool: return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string( string2 ) diff --git a/plugins/module_utils/_scramble.py b/plugins/module_utils/_scramble.py index 621f459f..fa4b6dc2 100644 --- a/plugins/module_utils/_scramble.py +++ b/plugins/module_utils/_scramble.py @@ -13,7 +13,7 @@ import random from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text -def generate_insecure_key(): +def generate_insecure_key() -> bytes: """Do NOT use this for cryptographic purposes!""" while True: # Generate a one-byte key. Right now the functions below do not use more @@ -24,23 +24,23 @@ def generate_insecure_key(): return key -def scramble(value, key): +def scramble(value: str, key: bytes) -> str: """Do NOT use this for cryptographic purposes!""" if len(key) < 1: raise ValueError("Key must be at least one byte") - value = to_bytes(value) + b_value = to_bytes(value) k = key[0] - value = bytes([k ^ b for b in value]) - return "=S=" + to_native(base64.b64encode(value)) + b_value = bytes([k ^ b for b in b_value]) + return f"=S={to_native(base64.b64encode(b_value))}" -def unscramble(value, key): +def unscramble(value: str, key: bytes) -> str: """Do NOT use this for cryptographic purposes!""" if len(key) < 1: raise ValueError("Key must be at least one byte") if not value.startswith("=S="): raise ValueError("Value does not start with indicator") - value = base64.b64decode(value[3:]) + b_value = base64.b64decode(value[3:]) k = key[0] - value = bytes([k ^ b for b in value]) - return to_text(value) + b_value = bytes([k ^ b for b in b_value]) + return to_text(b_value) diff --git a/plugins/module_utils/_socket_handler.py b/plugins/module_utils/_socket_handler.py index f2a4aa7d..67e0afd4 100644 --- a/plugins/module_utils/_socket_handler.py +++ b/plugins/module_utils/_socket_handler.py @@ -12,6 +12,7 @@ import os.path import selectors import socket as pysocket import struct +import typing as t from ansible_collections.community.docker.plugins.module_utils._api.utils import ( socket as docker_socket, @@ -23,58 +24,74 @@ from ansible_collections.community.docker.plugins.module_utils._socket_helper im ) +if t.TYPE_CHECKING: + from collections.abc import Callable + + from ansible.module_utils.basic import AnsibleModule + + from ansible_collections.community.docker.plugins.module_utils._socket_helper import ( + SocketLike, + ) + + PARAMIKO_POLL_TIMEOUT = 0.01 # 10 milliseconds +def _empty_writer(msg: str) -> None: + pass + + class DockerSocketHandlerBase: - def __init__(self, sock, log=None): + def __init__( + self, sock: SocketLike, log: Callable[[str], None] | None = None + ) -> None: make_unblocking(sock) - if log is not None: - self._log 
= log - else: - self._log = lambda msg: True + self._log = log or _empty_writer self._paramiko_read_workaround = hasattr( sock, "send_ready" ) and "paramiko" in str(type(sock)) self._sock = sock - self._block_done_callback = None - self._block_buffer = [] + self._block_done_callback: Callable[[int, bytes], None] | None = None + self._block_buffer: list[tuple[int, bytes]] = [] self._eof = False self._read_buffer = b"" self._write_buffer = b"" self._end_of_writing = False - self._current_stream = None + self._current_stream: int | None = None self._current_missing = 0 self._current_buffer = b"" self._selector = selectors.DefaultSelector() self._selector.register(self._sock, selectors.EVENT_READ) - def __enter__(self): + def __enter__(self) -> t.Self: return self - def __exit__(self, type_, value, tb): + def __exit__(self, type_, value, tb) -> None: self._selector.close() - def set_block_done_callback(self, block_done_callback): + def set_block_done_callback( + self, block_done_callback: Callable[[int, bytes], None] + ) -> None: self._block_done_callback = block_done_callback if self._block_done_callback is not None: while self._block_buffer: elt = self._block_buffer.pop(0) self._block_done_callback(*elt) - def _add_block(self, stream_id, data): + def _add_block(self, stream_id: int, data: bytes) -> None: if self._block_done_callback is not None: self._block_done_callback(stream_id, data) else: self._block_buffer.append((stream_id, data)) - def _read(self): + def _read(self) -> None: if self._eof: return + data: bytes | None if hasattr(self._sock, "recv"): try: data = self._sock.recv(262144) @@ -86,13 +103,13 @@ class DockerSocketHandlerBase: self._eof = True return raise - elif isinstance(self._sock, getattr(pysocket, "SocketIO")): - data = self._sock.read() + elif isinstance(self._sock, pysocket.SocketIO): # type: ignore[unreachable] + data = self._sock.read() # type: ignore else: - data = os.read(self._sock.fileno()) + data = os.read(self._sock.fileno()) # type: ignore # TODO does this really work?! 
if data is None: # no data available - return + return # type: ignore[unreachable] self._log(f"read {len(data)} bytes") if len(data) == 0: # Stream EOF @@ -106,6 +123,7 @@ class DockerSocketHandlerBase: self._read_buffer = self._read_buffer[n:] self._current_missing -= n if self._current_missing == 0: + assert self._current_stream is not None self._add_block(self._current_stream, self._current_buffer) self._current_buffer = b"" if len(self._read_buffer) < 8: @@ -119,13 +137,13 @@ class DockerSocketHandlerBase: self._eof = True break - def _handle_end_of_writing(self): + def _handle_end_of_writing(self) -> None: if self._end_of_writing and len(self._write_buffer) == 0: self._end_of_writing = False self._log("Shutting socket down for writing") shutdown_writing(self._sock, self._log) - def _write(self): + def _write(self) -> None: if len(self._write_buffer) > 0: written = write_to_socket(self._sock, self._write_buffer) self._write_buffer = self._write_buffer[written:] @@ -138,7 +156,9 @@ class DockerSocketHandlerBase: self._selector.modify(self._sock, selectors.EVENT_READ) self._handle_end_of_writing() - def select(self, timeout=None, _internal_recursion=False): + def select( + self, timeout: int | float | None = None, _internal_recursion: bool = False + ) -> bool: if ( not _internal_recursion and self._paramiko_read_workaround @@ -147,12 +167,14 @@ class DockerSocketHandlerBase: # When the SSH transport is used, Docker SDK for Python internally uses Paramiko, whose # Channel object supports select(), but only for reading # (https://github.com/paramiko/paramiko/issues/695). - if self._sock.send_ready(): + if self._sock.send_ready(): # type: ignore self._write() return True while timeout is None or timeout > PARAMIKO_POLL_TIMEOUT: - result = self.select(PARAMIKO_POLL_TIMEOUT, _internal_recursion=True) - if self._sock.send_ready(): + result = int( + self.select(PARAMIKO_POLL_TIMEOUT, _internal_recursion=True) + ) + if self._sock.send_ready(): # type: ignore self._read() result += 1 if result > 0: @@ -172,19 +194,19 @@ class DockerSocketHandlerBase: self._write() result = len(events) if self._paramiko_read_workaround and len(self._write_buffer) > 0: - if self._sock.send_ready(): + if self._sock.send_ready(): # type: ignore self._write() result += 1 return result > 0 - def is_eof(self): + def is_eof(self) -> bool: return self._eof - def end_of_writing(self): + def end_of_writing(self) -> None: self._end_of_writing = True self._handle_end_of_writing() - def consume(self): + def consume(self) -> tuple[bytes, bytes]: stdout = [] stderr = [] @@ -203,12 +225,12 @@ class DockerSocketHandlerBase: self.select() return b"".join(stdout), b"".join(stderr) - def write(self, str_to_write): + def write(self, str_to_write: bytes) -> None: self._write_buffer += str_to_write if len(self._write_buffer) == len(str_to_write): self._write() class DockerSocketHandlerModule(DockerSocketHandlerBase): - def __init__(self, sock, module): + def __init__(self, sock: SocketLike, module: AnsibleModule) -> None: super().__init__(sock, module.debug) diff --git a/plugins/module_utils/_socket_helper.py b/plugins/module_utils/_socket_helper.py index 6158f4e5..9927c3df 100644 --- a/plugins/module_utils/_socket_helper.py +++ b/plugins/module_utils/_socket_helper.py @@ -12,9 +12,14 @@ import os import os.path import socket as pysocket import typing as t +from collections.abc import Callable -def make_file_unblocking(file) -> None: +if t.TYPE_CHECKING: + SocketLike = pysocket.socket + + +def make_file_unblocking(file: SocketLike) -> 
None: fcntl.fcntl( file.fileno(), fcntl.F_SETFL, @@ -22,7 +27,7 @@ def make_file_unblocking(file) -> None: ) -def make_file_blocking(file) -> None: +def make_file_blocking(file: SocketLike) -> None: fcntl.fcntl( file.fileno(), fcntl.F_SETFL, @@ -30,11 +35,11 @@ def make_file_blocking(file) -> None: ) -def make_unblocking(sock) -> None: +def make_unblocking(sock: SocketLike) -> None: if hasattr(sock, "_sock"): sock._sock.setblocking(0) elif hasattr(sock, "setblocking"): - sock.setblocking(0) + sock.setblocking(0) # type: ignore # TODO: CHECK! else: make_file_unblocking(sock) @@ -43,7 +48,9 @@ def _empty_writer(msg: str) -> None: pass -def shutdown_writing(sock, log: t.Callable[[str], None] = _empty_writer) -> None: +def shutdown_writing( + sock: SocketLike, log: Callable[[str], None] = _empty_writer +) -> None: # FIXME: This does **not work with SSLSocket**! Apparently SSLSocket does not allow to send # a close_notify TLS alert without completely shutting down the connection. # Calling sock.shutdown(pysocket.SHUT_WR) simply turns of TLS encryption and from that @@ -56,14 +63,14 @@ def shutdown_writing(sock, log: t.Callable[[str], None] = _empty_writer) -> None except TypeError as e: # probably: "TypeError: shutdown() takes 1 positional argument but 2 were given" log(f"Shutting down for writing not possible; trying shutdown instead: {e}") - sock.shutdown() + sock.shutdown() # type: ignore elif isinstance(sock, getattr(pysocket, "SocketIO")): sock._sock.shutdown(pysocket.SHUT_WR) else: log("No idea how to signal end of writing") -def write_to_socket(sock, data: bytes) -> None: +def write_to_socket(sock: SocketLike, data: bytes) -> int: if hasattr(sock, "_send_until_done"): # WrappedSocket (urllib3/contrib/pyopenssl) does not have `send`, but # only `sendall`, which uses `_send_until_done` under the hood. diff --git a/plugins/module_utils/_swarm.py b/plugins/module_utils/_swarm.py index 93b08d5d..699f5821 100644 --- a/plugins/module_utils/_swarm.py +++ b/plugins/module_utils/_swarm.py @@ -9,6 +9,7 @@ from __future__ import annotations import json +import typing as t from time import sleep @@ -28,10 +29,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( class AnsibleDockerSwarmClient(AnsibleDockerClient): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def get_swarm_node_id(self): + def get_swarm_node_id(self) -> str | None: """ Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID of Docker host the module is executed on @@ -51,7 +49,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): return swarm_info["Swarm"]["NodeID"] return None - def check_if_swarm_node(self, node_id=None): + def check_if_swarm_node(self, node_id: str | None = None) -> bool | None: """ Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host system information looking if specific key in output exists. If 'node_id' is provided then it tries to @@ -83,11 +81,11 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): try: node_info = self.get_node_inspect(node_id=node_id) except APIError: - return + return None return node_info["ID"] is not None - def check_if_swarm_manager(self): + def check_if_swarm_manager(self) -> bool: """ Checks if node role is set as Manager in Swarm. The node is the docker host on which module action is performed. 
The inspect_swarm() will fail if node is not a manager @@ -101,7 +99,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): except APIError: return False - def fail_task_if_not_swarm_manager(self): + def fail_task_if_not_swarm_manager(self) -> None: """ If host is not a swarm manager then Ansible task on this host should end with 'failed' state """ @@ -110,7 +108,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): "Error running docker swarm module: must run on swarm manager node" ) - def check_if_swarm_worker(self): + def check_if_swarm_worker(self) -> bool: """ Checks if node role is set as Worker in Swarm. The node is the docker host on which module action is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node() @@ -122,7 +120,9 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): return True return False - def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1): + def check_if_swarm_node_is_down( + self, node_id: str | None = None, repeat_check: int = 1 + ) -> bool: """ Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or @@ -147,7 +147,19 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): return True return False - def get_node_inspect(self, node_id=None, skip_missing=False): + @t.overload + def get_node_inspect( + self, node_id: str | None = None, skip_missing: t.Literal[False] = False + ) -> dict[str, t.Any]: ... + + @t.overload + def get_node_inspect( + self, node_id: str | None = None, skip_missing: bool = False + ) -> dict[str, t.Any] | None: ... + + def get_node_inspect( + self, node_id: str | None = None, skip_missing: bool = False + ) -> dict[str, t.Any] | None: """ Returns Swarm node info as in 'docker node inspect' command about single node @@ -195,7 +207,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): node_info["Status"]["Addr"] = swarm_leader_ip return node_info - def get_all_nodes_inspect(self): + def get_all_nodes_inspect(self) -> list[dict[str, t.Any]]: """ Returns Swarm node info as in 'docker node inspect' command about all registered nodes @@ -217,7 +229,17 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): node_info = json.loads(json_str) return node_info - def get_all_nodes_list(self, output="short"): + @t.overload + def get_all_nodes_list(self, output: t.Literal["short"] = "short") -> list[str]: ... + + @t.overload + def get_all_nodes_list( + self, output: t.Literal["long"] + ) -> list[dict[str, t.Any]]: ... 
+ + def get_all_nodes_list( + self, output: t.Literal["short", "long"] = "short" + ) -> list[str] | list[dict[str, t.Any]]: """ Returns list of nodes registered in Swarm @@ -227,48 +249,46 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): if 'output' is 'long' then returns data is list of dict containing the attributes as in output of command 'docker node ls' """ - nodes_list = [] - nodes_inspect = self.get_all_nodes_inspect() - if nodes_inspect is None: - return None if output == "short": + nodes_list = [] for node in nodes_inspect: nodes_list.append(node["Description"]["Hostname"]) - elif output == "long": + return nodes_list + if output == "long": + nodes_info_list = [] for node in nodes_inspect: - node_property = {} + node_property: dict[str, t.Any] = {} - node_property.update({"ID": node["ID"]}) - node_property.update({"Hostname": node["Description"]["Hostname"]}) - node_property.update({"Status": node["Status"]["State"]}) - node_property.update({"Availability": node["Spec"]["Availability"]}) + node_property["ID"] = node["ID"] + node_property["Hostname"] = node["Description"]["Hostname"] + node_property["Status"] = node["Status"]["State"] + node_property["Availability"] = node["Spec"]["Availability"] if "ManagerStatus" in node: if node["ManagerStatus"]["Leader"] is True: - node_property.update({"Leader": True}) - node_property.update( - {"ManagerStatus": node["ManagerStatus"]["Reachability"]} - ) - node_property.update( - {"EngineVersion": node["Description"]["Engine"]["EngineVersion"]} - ) + node_property["Leader"] = True + node_property["ManagerStatus"] = node["ManagerStatus"][ + "Reachability" + ] + node_property["EngineVersion"] = node["Description"]["Engine"][ + "EngineVersion" + ] - nodes_list.append(node_property) - else: - return None + nodes_info_list.append(node_property) + return nodes_info_list - return nodes_list - - def get_node_name_by_id(self, nodeid): + def get_node_name_by_id(self, nodeid: str) -> str: return self.get_node_inspect(nodeid)["Description"]["Hostname"] - def get_unlock_key(self): + def get_unlock_key(self) -> str | None: if self.docker_py_version < LooseVersion("2.7.0"): return None return super().get_unlock_key() - def get_service_inspect(self, service_id, skip_missing=False): + def get_service_inspect( + self, service_id: str, skip_missing: bool = False + ) -> dict[str, t.Any] | None: """ Returns Swarm service info as in 'docker service inspect' command about single service diff --git a/plugins/module_utils/_util.py b/plugins/module_utils/_util.py index 5022c582..171796ea 100644 --- a/plugins/module_utils/_util.py +++ b/plugins/module_utils/_util.py @@ -9,6 +9,7 @@ from __future__ import annotations import json import re +import typing as t from datetime import timedelta from urllib.parse import urlparse @@ -17,6 +18,12 @@ from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.text.converters import to_text +if t.TYPE_CHECKING: + from collections.abc import Callable + + from ansible.module_utils.basic import AnsibleModule + + DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock" DEFAULT_TLS = False DEFAULT_TLS_VERIFY = False @@ -69,22 +76,24 @@ DOCKER_COMMON_ARGS_VARS = { if option_name != "debug" } -DOCKER_MUTUALLY_EXCLUSIVE = [] +DOCKER_MUTUALLY_EXCLUSIVE: list[tuple[str, ...] | list[str]] = [] -DOCKER_REQUIRED_TOGETHER = [["client_cert", "client_key"]] +DOCKER_REQUIRED_TOGETHER: list[tuple[str, ...] 
| list[str]] = [ + ["client_cert", "client_key"] +] DEFAULT_DOCKER_REGISTRY = "https://index.docker.io/v1/" BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"] -def is_image_name_id(name): +def is_image_name_id(name: str) -> bool: """Check whether the given image name is in fact an image ID (hash).""" if re.match("^sha256:[0-9a-fA-F]{64}$", name): return True return False -def is_valid_tag(tag, allow_empty=False): +def is_valid_tag(tag: str, allow_empty: bool = False) -> bool: """Check whether the given string is a valid docker tag name.""" if not tag: return allow_empty @@ -93,7 +102,7 @@ def is_valid_tag(tag, allow_empty=False): return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag)) -def sanitize_result(data): +def sanitize_result(data: t.Any) -> t.Any: """Sanitize data object for return to Ansible. When the data object contains types such as docker.types.containers.HostConfig, @@ -110,7 +119,7 @@ def sanitize_result(data): return data -def log_debug(msg, pretty_print=False): +def log_debug(msg: t.Any, pretty_print: bool = False): """Write a log message to docker.log. If ``pretty_print=True``, the message will be pretty-printed as JSON. @@ -126,25 +135,28 @@ def log_debug(msg, pretty_print=False): class DockerBaseClass: - def __init__(self): + def __init__(self) -> None: self.debug = False - def log(self, msg, pretty_print=False): + def log(self, msg: t.Any, pretty_print: bool = False) -> None: pass # if self.debug: # log_debug(msg, pretty_print=pretty_print) def update_tls_hostname( - result, old_behavior=False, deprecate_function=None, uses_tls=True -): + result: dict[str, t.Any], + old_behavior: bool = False, + deprecate_function: Callable[[str], None] | None = None, + uses_tls: bool = True, +) -> None: if result["tls_hostname"] is None: # get default machine name from the url parsed_url = urlparse(result["docker_host"]) result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0] -def compare_dict_allow_more_present(av, bv): +def compare_dict_allow_more_present(av: dict, bv: dict) -> bool: """ Compare two dictionaries for whether every entry of the first is in the second. """ @@ -156,7 +168,12 @@ def compare_dict_allow_more_present(av, bv): return True -def compare_generic(a, b, method, datatype): +def compare_generic( + a: t.Any, + b: t.Any, + method: t.Literal["ignore", "strict", "allow_more_present"], + datatype: t.Literal["value", "list", "set", "set(dict)", "dict"], +) -> bool: """ Compare values a and b as described by method and datatype. @@ -247,10 +264,10 @@ def compare_generic(a, b, method, datatype): class DifferenceTracker: - def __init__(self): - self._diff = [] + def __init__(self) -> None: + self._diff: list[dict[str, t.Any]] = [] - def add(self, name, parameter=None, active=None): + def add(self, name: str, parameter: t.Any = None, active: t.Any = None) -> None: self._diff.append( { "name": name, @@ -259,14 +276,14 @@ class DifferenceTracker: } ) - def merge(self, other_tracker): + def merge(self, other_tracker: DifferenceTracker) -> None: self._diff.extend(other_tracker._diff) @property - def empty(self): + def empty(self) -> bool: return len(self._diff) == 0 - def get_before_after(self): + def get_before_after(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: """ Return texts ``before`` and ``after``. 
""" @@ -277,13 +294,13 @@ class DifferenceTracker: after[item["name"]] = item["parameter"] return before, after - def has_difference_for(self, name): + def has_difference_for(self, name: str) -> bool: """ Returns a boolean if a difference exists for name """ return any(diff for diff in self._diff if diff["name"] == name) - def get_legacy_docker_container_diffs(self): + def get_legacy_docker_container_diffs(self) -> list[dict[str, t.Any]]: """ Return differences in the docker_container legacy format. """ @@ -297,7 +314,7 @@ class DifferenceTracker: result.append(item) return result - def get_legacy_docker_diffs(self): + def get_legacy_docker_diffs(self) -> list[str]: """ Return differences in the docker_container legacy format. """ @@ -305,8 +322,13 @@ class DifferenceTracker: return result -def sanitize_labels(labels, labels_field, client=None, module=None): - def fail(msg): +def sanitize_labels( + labels: dict[str, t.Any] | None, + labels_field: str, + client=None, + module: AnsibleModule | None = None, +) -> None: + def fail(msg: str) -> t.NoReturn: if client is not None: client.fail(msg) if module is not None: @@ -325,7 +347,21 @@ def sanitize_labels(labels, labels_field, client=None, module=None): labels[k] = to_text(v) -def clean_dict_booleans_for_docker_api(data, allow_sequences=False): +@t.overload +def clean_dict_booleans_for_docker_api( + data: dict[str, t.Any], *, allow_sequences: t.Literal[False] = False +) -> dict[str, str]: ... + + +@t.overload +def clean_dict_booleans_for_docker_api( + data: dict[str, t.Any], *, allow_sequences: bool +) -> dict[str, str | list[str]]: ... + + +def clean_dict_booleans_for_docker_api( + data: dict[str, t.Any] | None, *, allow_sequences: bool = False +) -> dict[str, str] | dict[str, str | list[str]]: """ Go does not like Python booleans 'True' or 'False', while Ansible is just fine with them in YAML. As such, they need to be converted in cases where @@ -353,7 +389,7 @@ def clean_dict_booleans_for_docker_api(data, allow_sequences=False): return result -def convert_duration_to_nanosecond(time_str): +def convert_duration_to_nanosecond(time_str: str) -> int: """ Return time duration in nanosecond. """ @@ -372,9 +408,9 @@ def convert_duration_to_nanosecond(time_str): if not parts: raise ValueError(f"Invalid time duration - {time_str}") - parts = parts.groupdict() + parts_dict = parts.groupdict() time_params = {} - for name, value in parts.items(): + for name, value in parts_dict.items(): if value: time_params[name] = int(value) @@ -386,13 +422,15 @@ def convert_duration_to_nanosecond(time_str): return time_in_nanoseconds -def normalize_healthcheck_test(test): +def normalize_healthcheck_test(test: t.Any) -> list[str]: if isinstance(test, (tuple, list)): return [str(e) for e in test] return ["CMD-SHELL", str(test)] -def normalize_healthcheck(healthcheck, normalize_test=False): +def normalize_healthcheck( + healthcheck: dict[str, t.Any], normalize_test: bool = False +) -> dict[str, t.Any]: """ Return dictionary of healthcheck parameters. """ @@ -438,7 +476,9 @@ def normalize_healthcheck(healthcheck, normalize_test=False): return result -def parse_healthcheck(healthcheck): +def parse_healthcheck( + healthcheck: dict[str, t.Any] | None, +) -> tuple[dict[str, t.Any] | None, bool | None]: """ Return dictionary of healthcheck parameters and boolean if healthcheck defined in image was requested to be disabled. 
@@ -456,8 +496,8 @@ def parse_healthcheck(healthcheck): return result, False -def omit_none_from_dict(d): +def omit_none_from_dict(d: dict[str, t.Any]) -> dict[str, t.Any]: """ Return a copy of the dictionary with all keys with value None omitted. """ - return dict((k, v) for (k, v) in d.items() if v is not None) + return {k: v for (k, v) in d.items() if v is not None} diff --git a/plugins/modules/current_container_facts.py b/plugins/modules/current_container_facts.py index 4aedfec1..9220dd2f 100644 --- a/plugins/modules/current_container_facts.py +++ b/plugins/modules/current_container_facts.py @@ -80,7 +80,7 @@ import re from ansible.module_utils.basic import AnsibleModule -def main(): +def main() -> None: module = AnsibleModule({}, supports_check_mode=True) cpuset_path = "/proc/self/cpuset" diff --git a/plugins/modules/docker_compose_v2.py b/plugins/modules/docker_compose_v2.py index 3cd1e09f..54b99163 100644 --- a/plugins/modules/docker_compose_v2.py +++ b/plugins/modules/docker_compose_v2.py @@ -437,6 +437,7 @@ actions: """ import traceback +import typing as t from ansible.module_utils.common.validation import check_type_int @@ -455,26 +456,32 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( class ServicesManager(BaseComposeManager): - def __init__(self, client): + def __init__(self, client: AnsibleModuleDockerClient) -> None: super().__init__(client) parameters = self.client.module.params - self.state = parameters["state"] - self.dependencies = parameters["dependencies"] - self.pull = parameters["pull"] - self.build = parameters["build"] - self.ignore_build_events = parameters["ignore_build_events"] - self.recreate = parameters["recreate"] - self.remove_images = parameters["remove_images"] - self.remove_volumes = parameters["remove_volumes"] - self.remove_orphans = parameters["remove_orphans"] - self.renew_anon_volumes = parameters["renew_anon_volumes"] - self.timeout = parameters["timeout"] - self.services = parameters["services"] or [] - self.scale = parameters["scale"] or {} - self.wait = parameters["wait"] - self.wait_timeout = parameters["wait_timeout"] - self.yes = parameters["assume_yes"] + self.state: t.Literal["absent", "present", "stopped", "restarted"] = parameters[ + "state" + ] + self.dependencies: bool = parameters["dependencies"] + self.pull: t.Literal["always", "missing", "never", "policy"] = parameters[ + "pull" + ] + self.build: t.Literal["always", "never", "policy"] = parameters["build"] + self.ignore_build_events: bool = parameters["ignore_build_events"] + self.recreate: t.Literal["always", "never", "auto"] = parameters["recreate"] + self.remove_images: t.Literal["all", "local"] | None = parameters[ + "remove_images" + ] + self.remove_volumes: bool = parameters["remove_volumes"] + self.remove_orphans: bool = parameters["remove_orphans"] + self.renew_anon_volumes: bool = parameters["renew_anon_volumes"] + self.timeout: int | None = parameters["timeout"] + self.services: list[str] = parameters["services"] or [] + self.scale: dict[str, t.Any] = parameters["scale"] or {} + self.wait: bool = parameters["wait"] + self.wait_timeout: int | None = parameters["wait_timeout"] + self.yes: bool = parameters["assume_yes"] if self.compose_version < LooseVersion("2.32.0") and self.yes: self.fail( f"assume_yes=true needs Docker Compose 2.32.0 or newer, not version {self.compose_version}" @@ -491,7 +498,7 @@ class ServicesManager(BaseComposeManager): self.fail(f"The value {value!r} for `scale[{key!r}]` is negative") self.scale[key] = value - def 
run(self): + def run(self) -> dict[str, t.Any]: if self.state == "present": result = self.cmd_up() elif self.state == "stopped": @@ -508,7 +515,7 @@ class ServicesManager(BaseComposeManager): self.cleanup_result(result) return result - def get_up_cmd(self, dry_run, no_start=False): + def get_up_cmd(self, dry_run: bool, no_start: bool = False) -> list[str]: args = self.get_base_args() + ["up", "--detach", "--no-color", "--quiet-pull"] if self.pull != "policy": args.extend(["--pull", self.pull]) @@ -549,8 +556,8 @@ class ServicesManager(BaseComposeManager): args.append(service) return args - def cmd_up(self): - result = {} + def cmd_up(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} args = self.get_up_cmd(self.check_mode) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) @@ -566,7 +573,7 @@ class ServicesManager(BaseComposeManager): self.update_failed(result, events, args, stdout, stderr, rc) return result - def get_stop_cmd(self, dry_run): + def get_stop_cmd(self, dry_run: bool) -> list[str]: args = self.get_base_args() + ["stop"] if self.timeout is not None: args.extend(["--timeout", f"{self.timeout}"]) @@ -577,17 +584,17 @@ class ServicesManager(BaseComposeManager): args.append(service) return args - def _are_containers_stopped(self): + def _are_containers_stopped(self) -> bool: for container in self.list_containers_raw(): if container["State"] not in ("created", "exited", "stopped", "killed"): return False return True - def cmd_stop(self): + def cmd_stop(self) -> dict[str, t.Any]: # Since 'docker compose stop' **always** claims it is stopping containers, even if they are already # stopped, we have to do this a bit more complicated. - result = {} + result: dict[str, t.Any] = {} # Make sure all containers are created args_1 = self.get_up_cmd(self.check_mode, no_start=True) rc_1, stdout_1, stderr_1 = self.client.call_cli(*args_1, cwd=self.project_src) @@ -630,7 +637,7 @@ class ServicesManager(BaseComposeManager): ) return result - def get_restart_cmd(self, dry_run): + def get_restart_cmd(self, dry_run: bool) -> list[str]: args = self.get_base_args() + ["restart"] if not self.dependencies: args.append("--no-deps") @@ -643,8 +650,8 @@ class ServicesManager(BaseComposeManager): args.append(service) return args - def cmd_restart(self): - result = {} + def cmd_restart(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} args = self.get_restart_cmd(self.check_mode) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) @@ -653,7 +660,7 @@ class ServicesManager(BaseComposeManager): self.update_failed(result, events, args, stdout, stderr, rc) return result - def get_down_cmd(self, dry_run): + def get_down_cmd(self, dry_run: bool) -> list[str]: args = self.get_base_args() + ["down"] if self.remove_orphans: args.append("--remove-orphans") @@ -670,8 +677,8 @@ class ServicesManager(BaseComposeManager): args.append(service) return args - def cmd_down(self): - result = {} + def cmd_down(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} args = self.get_down_cmd(self.check_mode) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) @@ -681,7 +688,7 @@ class ServicesManager(BaseComposeManager): return result -def main(): +def main() -> None: argument_spec = { "state": { "type": "str", diff 
--git a/plugins/modules/docker_compose_v2_exec.py b/plugins/modules/docker_compose_v2_exec.py index e7836bd0..1149bb91 100644 --- a/plugins/modules/docker_compose_v2_exec.py +++ b/plugins/modules/docker_compose_v2_exec.py @@ -166,6 +166,7 @@ rc: import shlex import traceback +import typing as t from ansible.module_utils.common.text.converters import to_text @@ -180,29 +181,32 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor class ExecManager(BaseComposeManager): - def __init__(self, client): + def __init__(self, client: AnsibleModuleDockerClient) -> None: super().__init__(client) parameters = self.client.module.params - self.service = parameters["service"] - self.index = parameters["index"] - self.chdir = parameters["chdir"] - self.detach = parameters["detach"] - self.user = parameters["user"] - self.stdin = parameters["stdin"] - self.strip_empty_ends = parameters["strip_empty_ends"] - self.privileged = parameters["privileged"] - self.tty = parameters["tty"] - self.env = parameters["env"] + self.service: str = parameters["service"] + self.index: int | None = parameters["index"] + self.chdir: str | None = parameters["chdir"] + self.detach: bool = parameters["detach"] + self.user: str | None = parameters["user"] + self.stdin: str | None = parameters["stdin"] + self.strip_empty_ends: bool = parameters["strip_empty_ends"] + self.privileged: bool = parameters["privileged"] + self.tty: bool = parameters["tty"] + self.env: dict[str, t.Any] = parameters["env"] - self.argv = parameters["argv"] + self.argv: list[str] if parameters["command"] is not None: self.argv = shlex.split(parameters["command"]) + else: + self.argv = parameters["argv"] if self.detach and self.stdin is not None: self.fail("If detach=true, stdin cannot be provided.") - if self.stdin is not None and parameters["stdin_add_newline"]: + stdin_add_newline: bool = parameters["stdin_add_newline"] + if self.stdin is not None and stdin_add_newline: self.stdin += "\n" if self.env is not None: @@ -214,7 +218,7 @@ class ExecManager(BaseComposeManager): ) self.env[name] = to_text(value, errors="surrogate_or_strict") - def get_exec_cmd(self, dry_run, no_start=False): + def get_exec_cmd(self, dry_run: bool) -> list[str]: args = self.get_base_args(plain_progress=True) + ["exec"] if self.index is not None: args.extend(["--index", str(self.index)]) @@ -237,9 +241,9 @@ class ExecManager(BaseComposeManager): args.extend(self.argv) return args - def run(self): + def run(self) -> dict[str, t.Any]: args = self.get_exec_cmd(self.check_mode) - kwargs = { + kwargs: dict[str, t.Any] = { "cwd": self.project_src, } if self.stdin is not None: @@ -262,7 +266,7 @@ class ExecManager(BaseComposeManager): } -def main(): +def main() -> None: argument_spec = { "service": {"type": "str", "required": True}, "index": {"type": "int"}, diff --git a/plugins/modules/docker_compose_v2_pull.py b/plugins/modules/docker_compose_v2_pull.py index 1f438464..1e00af40 100644 --- a/plugins/modules/docker_compose_v2_pull.py +++ b/plugins/modules/docker_compose_v2_pull.py @@ -111,6 +111,7 @@ actions: """ import traceback +import typing as t from ansible_collections.community.docker.plugins.module_utils._common_cli import ( AnsibleModuleDockerClient, @@ -126,14 +127,14 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( class PullManager(BaseComposeManager): - def __init__(self, client): + def __init__(self, client: AnsibleModuleDockerClient) -> None: super().__init__(client) parameters = self.client.module.params - 
self.policy = parameters["policy"] - self.ignore_buildable = parameters["ignore_buildable"] - self.include_deps = parameters["include_deps"] - self.services = parameters["services"] or [] + self.policy: t.Literal["always", "missing"] = parameters["policy"] + self.ignore_buildable: bool = parameters["ignore_buildable"] + self.include_deps: bool = parameters["include_deps"] + self.services: list[str] = parameters["services"] or [] if self.policy != "always" and self.compose_version < LooseVersion("2.22.0"): # https://github.com/docker/compose/pull/10981 - 2.22.0 @@ -146,7 +147,7 @@ class PullManager(BaseComposeManager): f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}" ) - def get_pull_cmd(self, dry_run, no_start=False): + def get_pull_cmd(self, dry_run: bool): args = self.get_base_args() + ["pull"] if self.policy != "always": args.extend(["--policy", self.policy]) @@ -161,8 +162,8 @@ class PullManager(BaseComposeManager): args.append(service) return args - def run(self): - result = {} + def run(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} args = self.get_pull_cmd(self.check_mode) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) @@ -179,7 +180,7 @@ class PullManager(BaseComposeManager): return result -def main(): +def main() -> None: argument_spec = { "policy": { "type": "str", diff --git a/plugins/modules/docker_compose_v2_run.py b/plugins/modules/docker_compose_v2_run.py index 8f9040e2..39cc5e38 100644 --- a/plugins/modules/docker_compose_v2_run.py +++ b/plugins/modules/docker_compose_v2_run.py @@ -239,6 +239,7 @@ rc: import shlex import traceback +import typing as t from ansible.module_utils.common.text.converters import to_text @@ -253,42 +254,45 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor class ExecManager(BaseComposeManager): - def __init__(self, client): + def __init__(self, client: AnsibleModuleDockerClient) -> None: super().__init__(client) parameters = self.client.module.params - self.service = parameters["service"] - self.build = parameters["build"] - self.cap_add = parameters["cap_add"] - self.cap_drop = parameters["cap_drop"] - self.entrypoint = parameters["entrypoint"] - self.interactive = parameters["interactive"] - self.labels = parameters["labels"] - self.name = parameters["name"] - self.no_deps = parameters["no_deps"] - self.publish = parameters["publish"] - self.quiet_pull = parameters["quiet_pull"] - self.remove_orphans = parameters["remove_orphans"] - self.do_cleanup = parameters["cleanup"] - self.service_ports = parameters["service_ports"] - self.use_aliases = parameters["use_aliases"] - self.volumes = parameters["volumes"] - self.chdir = parameters["chdir"] - self.detach = parameters["detach"] - self.user = parameters["user"] - self.stdin = parameters["stdin"] - self.strip_empty_ends = parameters["strip_empty_ends"] - self.tty = parameters["tty"] - self.env = parameters["env"] + self.service: str = parameters["service"] + self.build: bool = parameters["build"] + self.cap_add: list[str] | None = parameters["cap_add"] + self.cap_drop: list[str] | None = parameters["cap_drop"] + self.entrypoint: str | None = parameters["entrypoint"] + self.interactive: bool = parameters["interactive"] + self.labels: list[str] | None = parameters["labels"] + self.name: str | None = parameters["name"] + self.no_deps: bool = parameters["no_deps"] + self.publish: 
diff --git a/plugins/modules/docker_compose_v2_run.py b/plugins/modules/docker_compose_v2_run.py
index 8f9040e2..39cc5e38 100644
--- a/plugins/modules/docker_compose_v2_run.py
+++ b/plugins/modules/docker_compose_v2_run.py
@@ -239,6 +239,7 @@ rc:

 import shlex
 import traceback
+import typing as t

 from ansible.module_utils.common.text.converters import to_text

@@ -253,42 +254,45 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor


 class ExecManager(BaseComposeManager):
-    def __init__(self, client):
+    def __init__(self, client: AnsibleModuleDockerClient) -> None:
         super().__init__(client)
         parameters = self.client.module.params

-        self.service = parameters["service"]
-        self.build = parameters["build"]
-        self.cap_add = parameters["cap_add"]
-        self.cap_drop = parameters["cap_drop"]
-        self.entrypoint = parameters["entrypoint"]
-        self.interactive = parameters["interactive"]
-        self.labels = parameters["labels"]
-        self.name = parameters["name"]
-        self.no_deps = parameters["no_deps"]
-        self.publish = parameters["publish"]
-        self.quiet_pull = parameters["quiet_pull"]
-        self.remove_orphans = parameters["remove_orphans"]
-        self.do_cleanup = parameters["cleanup"]
-        self.service_ports = parameters["service_ports"]
-        self.use_aliases = parameters["use_aliases"]
-        self.volumes = parameters["volumes"]
-        self.chdir = parameters["chdir"]
-        self.detach = parameters["detach"]
-        self.user = parameters["user"]
-        self.stdin = parameters["stdin"]
-        self.strip_empty_ends = parameters["strip_empty_ends"]
-        self.tty = parameters["tty"]
-        self.env = parameters["env"]
+        self.service: str = parameters["service"]
+        self.build: bool = parameters["build"]
+        self.cap_add: list[str] | None = parameters["cap_add"]
+        self.cap_drop: list[str] | None = parameters["cap_drop"]
+        self.entrypoint: str | None = parameters["entrypoint"]
+        self.interactive: bool = parameters["interactive"]
+        self.labels: list[str] | None = parameters["labels"]
+        self.name: str | None = parameters["name"]
+        self.no_deps: bool = parameters["no_deps"]
+        self.publish: list[str] | None = parameters["publish"]
+        self.quiet_pull: bool = parameters["quiet_pull"]
+        self.remove_orphans: bool = parameters["remove_orphans"]
+        self.do_cleanup: bool = parameters["cleanup"]
+        self.service_ports: bool = parameters["service_ports"]
+        self.use_aliases: bool = parameters["use_aliases"]
+        self.volumes: list[str] | None = parameters["volumes"]
+        self.chdir: str | None = parameters["chdir"]
+        self.detach: bool = parameters["detach"]
+        self.user: str | None = parameters["user"]
+        self.stdin: str | None = parameters["stdin"]
+        self.strip_empty_ends: bool = parameters["strip_empty_ends"]
+        self.tty: bool = parameters["tty"]
+        self.env: dict[str, t.Any] | None = parameters["env"]

-        self.argv = parameters["argv"]
+        self.argv: list[str]
         if parameters["command"] is not None:
             self.argv = shlex.split(parameters["command"])
+        else:
+            self.argv = parameters["argv"]

         if self.detach and self.stdin is not None:
             self.fail("If detach=true, stdin cannot be provided.")

-        if self.stdin is not None and parameters["stdin_add_newline"]:
+        stdin_add_newline: bool = parameters["stdin_add_newline"]
+        if self.stdin is not None and stdin_add_newline:
             self.stdin += "\n"

         if self.env is not None:
@@ -300,7 +304,7 @@ class ExecManager(BaseComposeManager):
                 )
             self.env[name] = to_text(value, errors="surrogate_or_strict")

-    def get_run_cmd(self, dry_run, no_start=False):
+    def get_run_cmd(self, dry_run: bool) -> list[str]:
         args = self.get_base_args(plain_progress=True) + ["run"]
         if self.build:
             args.append("--build")
@@ -355,9 +359,9 @@ class ExecManager(BaseComposeManager):
             args.extend(self.argv)
         return args

-    def run(self):
+    def run(self) -> dict[str, t.Any]:
         args = self.get_run_cmd(self.check_mode)
-        kwargs = {
+        kwargs: dict[str, t.Any] = {
             "cwd": self.project_src,
         }
         if self.stdin is not None:
@@ -382,7 +386,7 @@ class ExecManager(BaseComposeManager):
     }


-def main():
+def main() -> None:
     argument_spec = {
         "service": {"type": "str", "required": True},
         "argv": {"type": "list", "elements": "str"},
diff --git a/plugins/modules/docker_container.py b/plugins/modules/docker_container.py
index 72c8f1bd..e8cdb06f 100644
--- a/plugins/modules/docker_container.py
+++ b/plugins/modules/docker_container.py
@@ -1355,7 +1355,7 @@ from ansible_collections.community.docker.plugins.module_utils._module_container
 )


-def main():
+def main() -> None:
     engine_driver = DockerAPIEngineDriver()
     run_module(engine_driver)
diff --git a/plugins/modules/docker_container_copy_into.py b/plugins/modules/docker_container_copy_into.py
index 9bf6d965..9c7575ba 100644
--- a/plugins/modules/docker_container_copy_into.py
+++ b/plugins/modules/docker_container_copy_into.py
@@ -169,6 +169,7 @@ import io
 import os
 import stat
 import traceback
+import typing as t

 from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
 from ansible.module_utils.common.validation import check_type_int
@@ -198,23 +199,29 @@ from ansible_collections.community.docker.plugins.module_utils._scramble import
 )


-def are_fileobjs_equal(f1, f2):
+if t.TYPE_CHECKING:
+    import tarfile
+
+
+def are_fileobjs_equal(f1: t.IO[bytes], f2: t.IO[bytes]) -> bool:
     """Given two (buffered) file objects, compare their contents."""
+    f1on: t.IO[bytes] | None = f1
+    f2on: t.IO[bytes] | None = f2
     blocksize = 65536
     b1buf = b""
     b2buf = b""
     while True:
-        if f1 and len(b1buf) < blocksize:
-            f1b = f1.read(blocksize)
+        if f1on and len(b1buf) < blocksize:
+            f1b = f1on.read(blocksize)
             if not f1b:
                 # f1 is EOF, so stop reading from it
-                f1 = None
+                f1on = None
             b1buf += f1b
-        if f2 and len(b2buf) < blocksize:
-            f2b = f2.read(blocksize)
+        if f2on and len(b2buf) < blocksize:
+            f2b = f2on.read(blocksize)
             if not f2b:
                 # f2 is EOF, so stop reading from it
-                f2 = None
+                f2on = None
             b2buf += f2b
         if not b1buf or not b2buf:
             # At least one of f1 and f2 is EOF and all its data has
@@ -229,29 +236,33 @@ def are_fileobjs_equal(f1, f2):
         b2buf = b2buf[buflen:]


-def are_fileobjs_equal_read_first(f1, f2):
+def are_fileobjs_equal_read_first(
+    f1: t.IO[bytes], f2: t.IO[bytes]
+) -> tuple[bool, bytes]:
     """Given two (buffered) file objects, compare their contents.

     Returns a tuple (is_equal, content_of_f1), where the first element
     indicates whether the two file objects have the same content, and the
     second element is the content of the first file object."""
+    f1on: t.IO[bytes] | None = f1
+    f2on: t.IO[bytes] | None = f2
     blocksize = 65536
     b1buf = b""
     b2buf = b""
     is_equal = True
     content = []
     while True:
-        if f1 and len(b1buf) < blocksize:
-            f1b = f1.read(blocksize)
+        if f1on and len(b1buf) < blocksize:
+            f1b = f1on.read(blocksize)
             if not f1b:
                 # f1 is EOF, so stop reading from it
-                f1 = None
+                f1on = None
             b1buf += f1b
-        if f2 and len(b2buf) < blocksize:
-            f2b = f2.read(blocksize)
+        if f2on and len(b2buf) < blocksize:
+            f2b = f2on.read(blocksize)
             if not f2b:
                 # f2 is EOF, so stop reading from it
-                f2 = None
+                f2on = None
             b2buf += f2b
         if not b1buf or not b2buf:
             # At least one of f1 and f2 is EOF and all its data has
@@ -269,13 +280,13 @@ def are_fileobjs_equal_read_first(f1, f2):
         b2buf = b2buf[buflen:]

     content.append(b1buf)
-    if f1:
-        content.append(f1.read())
+    if f1on:
+        content.append(f1on.read())

     return is_equal, b"".join(content)


-def is_container_file_not_regular_file(container_stat):
+def is_container_file_not_regular_file(container_stat: dict[str, t.Any]) -> bool:
     for bit in (
         # https://pkg.go.dev/io/fs#FileMode
         32 - 1,  # ModeDir
@@ -292,7 +303,7 @@ def is_container_file_not_regular_file(container_stat):
     return False


-def get_container_file_mode(container_stat):
+def get_container_file_mode(container_stat: dict[str, t.Any]) -> int:
     mode = container_stat["mode"] & 0xFFF
     if container_stat["mode"] & (1 << (32 - 9)) != 0:  # ModeSetuid
         mode |= stat.S_ISUID  # set UID bit
@@ -303,7 +314,9 @@ def get_container_file_mode(container_stat):
     return mode


-def add_other_diff(diff, in_path, member):
+def add_other_diff(
+    diff: dict[str, t.Any] | None, in_path: str, member: tarfile.TarInfo
+) -> None:
     if diff is None:
         return
     diff["before_header"] = in_path
@@ -326,14 +339,14 @@ def add_other_diff(diff, in_path, member):


 def retrieve_diff(
-    client,
-    container,
-    container_path,
-    follow_links,
-    diff,
-    max_file_size_for_diff,
-    regular_stat=None,
-    link_target=None,
+    client: AnsibleDockerClient,
+    container: str,
+    container_path: str,
+    follow_links: bool,
+    diff: dict[str, t.Any] | None,
+    max_file_size_for_diff: int,
+    regular_stat: dict[str, t.Any] | None = None,
+    link_target: str | None = None,
 ):
     if diff is None:
         return
@@ -377,19 +390,21 @@ def retrieve_diff(
         return

     # We need to get hold of the content
-    def process_none(in_path):
+    def process_none(in_path: str) -> None:
         diff["before"] = ""

-    def process_regular(in_path, tar, member):
+    def process_regular(
+        in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
+    ) -> None:
         add_diff_dst_from_regular_member(
             diff, max_file_size_for_diff, in_path, tar, member
         )

-    def process_symlink(in_path, member):
+    def process_symlink(in_path: str, member: tarfile.TarInfo) -> None:
         diff["before_header"] = in_path
         diff["before"] = member.linkname
-    def process_other(in_path, member):
+    def process_other(in_path: str, member: tarfile.TarInfo) -> None:
         add_other_diff(diff, in_path, member)

     fetch_file_ex(
@@ -404,7 +419,7 @@ def retrieve_diff(
     )


-def is_binary(content):
+def is_binary(content: bytes) -> bool:
     if b"\x00" in content:
         return True
     # TODO: better detection
@@ -413,8 +428,13 @@ def is_binary(content):


 def are_fileobjs_equal_with_diff_of_first(
-    f1, f2, size, diff, max_file_size_for_diff, container_path
-):
+    f1: t.IO[bytes],
+    f2: t.IO[bytes],
+    size: int,
+    diff: dict[str, t.Any] | None,
+    max_file_size_for_diff: int,
+    container_path: str,
+) -> bool:
     if diff is None:
         return are_fileobjs_equal(f1, f2)
     if size > max_file_size_for_diff > 0:
@@ -430,15 +450,22 @@ def are_fileobjs_equal_with_diff_of_first(


 def add_diff_dst_from_regular_member(
-    diff, max_file_size_for_diff, container_path, tar, member
-):
+    diff: dict[str, t.Any] | None,
+    max_file_size_for_diff: int,
+    container_path: str,
+    tar: tarfile.TarFile,
+    member: tarfile.TarInfo,
+) -> None:
     if diff is None:
         return
     if member.size > max_file_size_for_diff > 0:
         diff["dst_larger"] = max_file_size_for_diff
         return

-    with tar.extractfile(member) as tar_f:
+    mf = tar.extractfile(member)
+    if not mf:
+        raise AssertionError("Member should be present for regular file")
+    with mf as tar_f:
         content = tar_f.read()

     if is_binary(content):
@@ -448,35 +475,35 @@ def add_diff_dst_from_regular_member(
         diff["before"] = to_text(content)


-def copy_dst_to_src(diff):
+def copy_dst_to_src(diff: dict[str, t.Any] | None) -> None:
     if diff is None:
         return
-    for f, t in [
+    for frm, to in [
         ("dst_size", "src_size"),
         ("dst_binary", "src_binary"),
         ("before_header", "after_header"),
         ("before", "after"),
     ]:
-        if f in diff:
-            diff[t] = diff[f]
-        elif t in diff:
-            diff.pop(t)
+        if frm in diff:
+            diff[to] = diff[frm]
+        elif to in diff:
+            diff.pop(to)


 def is_file_idempotent(
-    client,
-    container,
-    managed_path,
-    container_path,
-    follow_links,
-    local_follow_links,
+    client: AnsibleDockerClient,
+    container: str,
+    managed_path: str,
+    container_path: str,
+    follow_links: bool,
+    local_follow_links: bool,
     owner_id,
     group_id,
     mode,
-    force=False,
-    diff=None,
-    max_file_size_for_diff=1,
-):
+    force: bool | None = False,
+    diff: dict[str, t.Any] | None = None,
+    max_file_size_for_diff: int = 1,
+) -> tuple[str, int, bool]:
     # Retrieve information of local file
     try:
         file_stat = (
@@ -644,10 +671,12 @@ def is_file_idempotent(
         return container_path, mode, False

     # Fetch file from container
-    def process_none(in_path):
+    def process_none(in_path: str) -> tuple[str, int, bool]:
         return container_path, mode, False

-    def process_regular(in_path, tar, member):
+    def process_regular(
+        in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
+    ) -> tuple[str, int, bool]:
         # Check things like user/group ID and mode
         if any(
             [
@@ -663,14 +692,17 @@ def is_file_idempotent(
             )
             return container_path, mode, False

-        with tar.extractfile(member) as tar_f:
+        mf = tar.extractfile(member)
+        if mf is None:
+            raise AssertionError("Member should be present for regular file")
+        with mf as tar_f:
             with open(managed_path, "rb") as local_f:
                 is_equal = are_fileobjs_equal_with_diff_of_first(
                     tar_f, local_f, member.size, diff, max_file_size_for_diff, in_path
                 )
         return container_path, mode, is_equal

-    def process_symlink(in_path, member):
+    def process_symlink(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
         if diff is not None:
             diff["before_header"] = in_path
             diff["before"] = member.linkname
@@ -689,7 +721,7 @@ def is_file_idempotent(
+721,7 @@ def is_file_idempotent( local_link_target = os.readlink(managed_path) return container_path, mode, member.linkname == local_link_target - def process_other(in_path, member): + def process_other(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]: add_other_diff(diff, in_path, member) return container_path, mode, False @@ -706,23 +738,21 @@ def is_file_idempotent( def copy_file_into_container( - client, - container, - managed_path, - container_path, - follow_links, - local_follow_links, + client: AnsibleDockerClient, + container: str, + managed_path: str, + container_path: str, + follow_links: bool, + local_follow_links: bool, owner_id, group_id, mode, - force=False, - diff=False, - max_file_size_for_diff=1, -): - if diff: - diff = {} - else: - diff = None + force: bool | None = False, + do_diff: bool = False, + max_file_size_for_diff: int = 1, +) -> t.NoReturn: + diff: dict[str, t.Any] | None + diff = {} if do_diff else None container_path, mode, idempotent = is_file_idempotent( client, @@ -762,18 +792,18 @@ def copy_file_into_container( def is_content_idempotent( - client, - container, - content, - container_path, - follow_links, + client: AnsibleDockerClient, + container: str, + content: bytes, + container_path: str, + follow_links: bool, owner_id, group_id, mode, - force=False, - diff=None, - max_file_size_for_diff=1, -): + force: bool | None = False, + diff: dict[str, t.Any] | None = None, + max_file_size_for_diff: int = 1, +) -> tuple[str, int, bool]: if diff is not None: if len(content) > max_file_size_for_diff > 0: diff["src_larger"] = max_file_size_for_diff @@ -894,12 +924,14 @@ def is_content_idempotent( return container_path, mode, False # Fetch file from container - def process_none(in_path): + def process_none(in_path: str) -> tuple[str, int, bool]: if diff is not None: diff["before"] = "" return container_path, mode, False - def process_regular(in_path, tar, member): + def process_regular( + in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo + ) -> tuple[str, int, bool]: # Check things like user/group ID and mode if any( [ @@ -914,7 +946,10 @@ def is_content_idempotent( ) return container_path, mode, False - with tar.extractfile(member) as tar_f: + mf = tar.extractfile(member) + if mf is None: + raise AssertionError("Member should be present for regular file") + with mf as tar_f: is_equal = are_fileobjs_equal_with_diff_of_first( tar_f, io.BytesIO(content), @@ -925,14 +960,14 @@ def is_content_idempotent( ) return container_path, mode, is_equal - def process_symlink(in_path, member): + def process_symlink(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]: if diff is not None: diff["before_header"] = in_path diff["before"] = member.linkname return container_path, mode, False - def process_other(in_path, member): + def process_other(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]: add_other_diff(diff, in_path, member) return container_path, mode, False @@ -949,22 +984,19 @@ def is_content_idempotent( def copy_content_into_container( - client, - container, - content, - container_path, - follow_links, + client: AnsibleDockerClient, + container: str, + content: bytes, + container_path: str, + follow_links: bool, owner_id, group_id, mode, - force=False, - diff=False, - max_file_size_for_diff=1, -): - if diff: - diff = {} - else: - diff = None + force: bool | None = False, + do_diff: bool = False, + max_file_size_for_diff: int = 1, +) -> t.NoReturn: + diff: dict[str, t.Any] | None = {} if do_diff else None container_path, 
diff --git a/plugins/modules/docker_container_exec.py b/plugins/modules/docker_container_exec.py
index b8bf8049..99651f4e 100644
--- a/plugins/modules/docker_container_exec.py
+++ b/plugins/modules/docker_container_exec.py
@@ -165,6 +165,7 @@ exec_id:

 import shlex
 import traceback
+import typing as t

 from ansible.module_utils.common.text.converters import to_bytes, to_text
@@ -185,7 +186,7 @@ from ansible_collections.community.docker.plugins.module_utils._socket_handler i
 )


-def main():
+def main() -> None:
     argument_spec = {
         "container": {"type": "str", "required": True},
         "argv": {"type": "list", "elements": "str"},
@@ -211,16 +212,16 @@ def main():
         required_one_of=[("argv", "command")],
     )

-    container = client.module.params["container"]
-    argv = client.module.params["argv"]
-    command = client.module.params["command"]
-    chdir = client.module.params["chdir"]
-    detach = client.module.params["detach"]
-    user = client.module.params["user"]
-    stdin = client.module.params["stdin"]
-    strip_empty_ends = client.module.params["strip_empty_ends"]
-    tty = client.module.params["tty"]
-    env = client.module.params["env"]
+    container: str = client.module.params["container"]
+    argv: list[str] | None = client.module.params["argv"]
+    command: str | None = client.module.params["command"]
+    chdir: str | None = client.module.params["chdir"]
+    detach: bool = client.module.params["detach"]
+    user: str | None = client.module.params["user"]
+    stdin: str | None = client.module.params["stdin"]
+    strip_empty_ends: bool = client.module.params["strip_empty_ends"]
+    tty: bool = client.module.params["tty"]
+    env: dict[str, t.Any] | None = client.module.params["env"]

     if env is not None:
         for name, value in list(env.items()):
@@ -233,6 +234,7 @@ def main():

     if command is not None:
         argv = shlex.split(command)
+    assert argv is not None

     if detach and stdin is not None:
         client.module.fail_json(msg="If detach=true, stdin cannot be provided.")
@@ -258,7 +260,7 @@ def main():
         exec_data = client.post_json_to_json(
             "/containers/{0}/exec", container, data=data
         )
-        exec_id = exec_data["Id"]
+        exec_id: str = exec_data["Id"]

         data = {
             "Tty": tty,
@@ -269,6 +271,8 @@ def main():

             client.module.exit_json(changed=True, exec_id=exec_id)
         else:
+            stdout: bytes | None
+            stderr: bytes | None
             if stdin and not detach:
                 exec_socket = client.post_json_to_stream_socket(
                     "/exec/{0}/start", exec_id, data=data
@@ -283,28 +287,37 @@ def main():
                     stdout, stderr = exec_socket_handler.consume()
                 finally:
                     exec_socket.close()
+            elif tty:
+                stdout, stderr = client.post_json_to_stream(
+                    "/exec/{0}/start",
+                    exec_id,
+                    data=data,
+                    stream=False,
+                    tty=True,
+                    demux=True,
+                )
             else:
                 stdout, stderr = client.post_json_to_stream(
                     "/exec/{0}/start",
                     exec_id,
                     data=data,
                     stream=False,
-                    tty=tty,
+                    tty=False,
                     demux=True,
                 )

             result = client.get_json("/exec/{0}/json", exec_id)

-            stdout = to_text(stdout or b"")
-            stderr = to_text(stderr or b"")
+            stdout_t = to_text(stdout or b"")
+            stderr_t = to_text(stderr or b"")
             if strip_empty_ends:
-                stdout = stdout.rstrip("\r\n")
-                stderr = stderr.rstrip("\r\n")
+                stdout_t = stdout_t.rstrip("\r\n")
+                stderr_t = stderr_t.rstrip("\r\n")

             client.module.exit_json(
                 changed=True,
-                stdout=stdout,
-                stderr=stderr,
+                stdout=stdout_t,
+                stderr=stderr_t,
                 rc=result.get("ExitCode") or 0,
             )
     except NotFound:
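The `stdout_t`/`stderr_t` renaming above follows from the same constraint: mypy gives a variable one type, so a name that held `bytes | None` cannot later hold the decoded `str`. A tiny sketch of the pattern, with invented values:

# stdout starts as raw bytes from the exec call (or None when demuxed away).
stdout: bytes | None = b"output\r\n"

# The decoded text gets a new name instead of rebinding the bytes variable.
stdout_t: str = (stdout or b"").decode("utf-8")
stdout_t = stdout_t.rstrip("\r\n")
print(stdout_t)

Splitting the single `tty=tty` call into explicit `tty=True` / `tty=False` branches appears to serve a similar purpose: each call site can then match a precise signature or overload rather than passing a plain `bool`.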
diff --git a/plugins/modules/docker_container_info.py b/plugins/modules/docker_container_info.py
index 6a97fd13..174d2ac7 100644
--- a/plugins/modules/docker_container_info.py
+++ b/plugins/modules/docker_container_info.py
@@ -86,7 +86,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor
 )


-def main():
+def main() -> None:
     argument_spec = {
         "name": {"type": "str", "required": True},
     }
@@ -96,8 +96,9 @@ def main():
         supports_check_mode=True,
     )

+    container_id: str = client.module.params["name"]
     try:
-        container = client.get_container(client.module.params["name"])
+        container = client.get_container(container_id)

         client.module.exit_json(
             changed=False,
diff --git a/plugins/modules/docker_context_info.py b/plugins/modules/docker_context_info.py
index 7a5bfd78..0528a872 100644
--- a/plugins/modules/docker_context_info.py
+++ b/plugins/modules/docker_context_info.py
@@ -173,6 +173,7 @@ current_context_name:
 """

 import traceback
+import typing as t

 from ansible.module_utils.basic import AnsibleModule
 from ansible.module_utils.common.text.converters import to_text
@@ -185,6 +186,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.context.conf
 )
 from ansible_collections.community.docker.plugins.module_utils._api.context.context import (
     IN_MEMORY,
+    Context,
 )
 from ansible_collections.community.docker.plugins.module_utils._api.errors import (
     ContextException,
@@ -192,7 +194,13 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
 )


-def tls_context_to_json(context):
+if t.TYPE_CHECKING:
+    from ansible_collections.community.docker.plugins.module_utils._api.tls import (
+        TLSConfig,
+    )
+
+
+def tls_context_to_json(context: TLSConfig | None) -> dict[str, t.Any] | None:
     if context is None:
         return None
     return {
@@ -204,8 +212,8 @@ def tls_context_to_json(context):
     }


-def context_to_json(context, current):
-    module_config = {}
+def context_to_json(context: Context, current: bool) -> dict[str, t.Any]:
+    module_config: dict[str, t.Any] = {}
     if "docker" in context.endpoints:
         endpoint = context.endpoints["docker"]
         if isinstance(endpoint.get("Host"), str):
@@ -247,7 +255,7 @@ def context_to_json(context, current):
     }


-def main():
+def main() -> None:
     argument_spec = {
         "only_current": {"type": "bool", "default": False},
         "name": {"type": "str"},
@@ -262,28 +270,31 @@ def main():
         ],
     )

+    only_current: bool = module.params["only_current"]
+    name: str | None = module.params["name"]
+    cli_context: str | None = module.params["cli_context"]
     try:
-        if module.params["cli_context"]:
+        if cli_context:
             current_context_name, current_context_source = (
-                module.params["cli_context"],
+                cli_context,
                 "cli_context module option",
             )
         else:
             current_context_name, current_context_source = (
                 get_current_context_name_with_source()
             )
-        if module.params["name"]:
-            contexts = [ContextAPI.get_context(module.params["name"])]
-            if not contexts[0]:
-                module.fail_json(
-                    msg=f"There is no context of name {module.params['name']!r}"
-                )
-        elif module.params["only_current"]:
-            contexts = [ContextAPI.get_context(current_context_name)]
-            if not contexts[0]:
+        if name:
+            context_or_none = ContextAPI.get_context(name)
+            if not context_or_none:
+                module.fail_json(msg=f"There is no context of name {name!r}")
+            contexts = [context_or_none]
+        elif only_current:
+            context_or_none = ContextAPI.get_context(current_context_name)
+            if not context_or_none:
                 module.fail_json(
                     msg=f"There is no context of name {current_context_name!r}, which is configured as the default context ({current_context_source})",
                 )
+            contexts = [context_or_none]
         else:
             contexts = ContextAPI.contexts()
diff --git a/plugins/modules/docker_host_info.py b/plugins/modules/docker_host_info.py
index b71387cf..02fa5ed5 100644
--- a/plugins/modules/docker_host_info.py
+++ b/plugins/modules/docker_host_info.py
@@ -212,6 +212,7 @@ disk_usage:
 """

 import traceback
+import typing as t

 from ansible_collections.community.docker.plugins.module_utils._api.errors import (
     APIError,
@@ -231,9 +232,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (

 class DockerHostManager(DockerBaseClass):
-
-    def __init__(self, client, results):
-
+    def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
         super().__init__()

         self.client = client
@@ -253,21 +252,21 @@ class DockerHostManager(DockerBaseClass):
         for docker_object in listed_objects:
             if self.client.module.params[docker_object]:
                 returned_name = docker_object
-                filter_name = docker_object + "_filters"
+                filter_name = f"{docker_object}_filters"
                 filters = clean_dict_booleans_for_docker_api(
-                    client.module.params.get(filter_name), True
+                    client.module.params.get(filter_name), allow_sequences=True
                 )
                 self.results[returned_name] = self.get_docker_items_list(
                     docker_object, filters
                 )

-    def get_docker_host_info(self):
+    def get_docker_host_info(self) -> dict[str, t.Any]:
         try:
             return self.client.info()
         except APIError as exc:
             self.client.fail(f"Error inspecting docker host: {exc}")

-    def get_docker_disk_usage_facts(self):
+    def get_docker_disk_usage_facts(self) -> dict[str, t.Any]:
         try:
             if self.verbose_output:
                 return self.client.df()
@@ -275,9 +274,13 @@ class DockerHostManager(DockerBaseClass):
         except APIError as exc:
             self.client.fail(f"Error inspecting docker host: {exc}")

-    def get_docker_items_list(self, docker_object=None, filters=None, verbose=False):
-        items = None
-        items_list = []
+    def get_docker_items_list(
+        self,
+        docker_object: str,
+        filters: dict[str, t.Any] | None = None,
+        verbose: bool = False,
+    ) -> list[dict[str, t.Any]]:
+        items = []

         header_containers = [
             "Id",
@@ -329,6 +332,7 @@ class DockerHostManager(DockerBaseClass):
         if self.verbose_output:
             return items

+        items_list = []
         for item in items:
             item_record = {}

@@ -349,7 +353,7 @@ class DockerHostManager(DockerBaseClass):
         return items_list


-def main():
+def main() -> None:
     argument_spec = {
         "containers": {"type": "bool", "default": False},
         "containers_all": {"type": "bool", "default": False},
diff --git a/plugins/modules/docker_image.py b/plugins/modules/docker_image.py
index f1e9e0d1..4605ba46 100644
--- a/plugins/modules/docker_image.py
+++ b/plugins/modules/docker_image.py
@@ -367,6 +367,7 @@ import errno
 import json
 import os
 import traceback
+import typing as t

 from ansible.module_utils.common.text.converters import to_native
 from ansible.module_utils.common.text.formatters import human_to_bytes
@@ -411,7 +412,18 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
 )


-def convert_to_bytes(value, module, name, unlimited_value=None):
+if t.TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from ansible.module_utils.basic import AnsibleModule
+
+
+def convert_to_bytes(
+    value: str | None,
+    module: AnsibleModule,
+    name: str,
+    unlimited_value: int | None = None,
+) -> int | None:
     if value is None:
         return value
     try:
@@ -423,8 +435,7 @@ def convert_to_bytes(value, module, name, unlimited_value=None):


 class ImageManager(DockerBaseClass):
-
-    def __init__(self, client, results):
+    def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
         """
         Configure a docker_image task.
@@ -441,12 +452,14 @@ class ImageManager(DockerBaseClass):
         parameters = self.client.module.params
         self.check_mode = self.client.check_mode

-        self.source = parameters["source"]
-        build = parameters["build"] or {}
-        pull = parameters["pull"] or {}
-        self.archive_path = parameters["archive_path"]
-        self.cache_from = build.get("cache_from")
-        self.container_limits = build.get("container_limits")
+        self.source: t.Literal["build", "load", "pull", "local"] | None = parameters[
+            "source"
+        ]
+        build: dict[str, t.Any] = parameters["build"] or {}
+        pull: dict[str, t.Any] = parameters["pull"] or {}
+        self.archive_path: str | None = parameters["archive_path"]
+        self.cache_from: list[str] | None = build.get("cache_from")
+        self.container_limits: dict[str, t.Any] | None = build.get("container_limits")
         if self.container_limits and "memory" in self.container_limits:
             self.container_limits["memory"] = convert_to_bytes(
                 self.container_limits["memory"],
@@ -460,32 +473,36 @@ class ImageManager(DockerBaseClass):
                 "build.container_limits.memswap",
                 unlimited_value=-1,
             )
-        self.dockerfile = build.get("dockerfile")
-        self.force_source = parameters["force_source"]
-        self.force_absent = parameters["force_absent"]
-        self.force_tag = parameters["force_tag"]
-        self.load_path = parameters["load_path"]
-        self.name = parameters["name"]
-        self.network = build.get("network")
-        self.extra_hosts = clean_dict_booleans_for_docker_api(build.get("etc_hosts"))
-        self.nocache = build.get("nocache", False)
-        self.build_path = build.get("path")
-        self.pull = build.get("pull")
-        self.target = build.get("target")
-        self.repository = parameters["repository"]
-        self.rm = build.get("rm", True)
-        self.state = parameters["state"]
-        self.tag = parameters["tag"]
-        self.http_timeout = build.get("http_timeout")
-        self.pull_platform = pull.get("platform")
-        self.push = parameters["push"]
-        self.buildargs = build.get("args")
-        self.build_platform = build.get("platform")
-        self.use_config_proxy = build.get("use_config_proxy")
-        self.shm_size = convert_to_bytes(
+        self.dockerfile: str | None = build.get("dockerfile")
+        self.force_source: bool = parameters["force_source"]
+        self.force_absent: bool = parameters["force_absent"]
+        self.force_tag: bool = parameters["force_tag"]
+        self.load_path: str | None = parameters["load_path"]
+        self.name: str = parameters["name"]
+        self.network: str | None = build.get("network")
+        self.extra_hosts: dict[str, str] = clean_dict_booleans_for_docker_api(
+            build.get("etc_hosts")  # type: ignore
+        )
+        self.nocache: bool = build.get("nocache", False)
+        self.build_path: str | None = build.get("path")
+        self.pull: bool | None = build.get("pull")
+        self.target: str | None = build.get("target")
+        self.repository: str | None = parameters["repository"]
+        self.rm: bool = build.get("rm", True)
+        self.state: t.Literal["absent", "present"] = parameters["state"]
+        self.tag: str = parameters["tag"]
+        self.http_timeout: int | None = build.get("http_timeout")
+        self.pull_platform: str | None = pull.get("platform")
+        self.push: bool = parameters["push"]
+        self.buildargs: dict[str, t.Any] | None = build.get("args")
+        self.build_platform: str | None = build.get("platform")
+        self.use_config_proxy: bool | None = build.get("use_config_proxy")
+        self.shm_size: int | None = convert_to_bytes(
             build.get("shm_size"), self.client.module, "build.shm_size"
         )
-        self.labels = clean_dict_booleans_for_docker_api(build.get("labels"))
+        self.labels: dict[str, str] = clean_dict_booleans_for_docker_api(
+            build.get("labels")  # type: ignore
+        )

         # If name contains a tag, it takes precedence over tag parameter.
         if not is_image_name_id(self.name):
@@ -507,10 +524,10 @@ class ImageManager(DockerBaseClass):
         elif self.state == "absent":
             self.absent()

-    def fail(self, msg):
+    def fail(self, msg: str) -> t.NoReturn:
         self.client.fail(msg)

-    def present(self):
+    def present(self) -> None:
         """
         Handles state = 'present', which includes building, loading or pulling an image,
         depending on user provided parameters.
@@ -530,6 +547,7 @@ class ImageManager(DockerBaseClass):
             )

             # Build the image
+            assert self.build_path is not None
             if not os.path.isdir(self.build_path):
                 self.fail(
                     f"Requested build path {self.build_path} could not be found or you do not have access."
@@ -546,6 +564,7 @@ class ImageManager(DockerBaseClass):
                 self.results.update(self.build_image())

         elif self.source == "load":
+            assert self.load_path is not None
             # Load the image from an archive
             if not os.path.isfile(self.load_path):
                 self.fail(
@@ -596,7 +615,7 @@ class ImageManager(DockerBaseClass):
         elif self.repository:
             self.tag_image(self.name, self.tag, self.repository, push=self.push)

-    def absent(self):
+    def absent(self) -> None:
         """
         Handles state = 'absent', which removes an image.

@@ -627,8 +646,11 @@ class ImageManager(DockerBaseClass):

     @staticmethod
     def archived_image_action(
-        failure_logger, archive_path, current_image_name, current_image_id
-    ):
+        failure_logger: Callable[[str], None],
+        archive_path: str,
+        current_image_name: str,
+        current_image_id: str,
+    ) -> str | None:
         """
         If the archive is missing or requires replacement, return an action message.

@@ -667,7 +689,7 @@ class ImageManager(DockerBaseClass):
                 f"overwriting archive with image {archived.image_id} named {name}"
             )

-    def archive_image(self, name, tag):
+    def archive_image(self, name: str, tag: str | None) -> None:
         """
         Archive an image to a .tar file. Called when archive_path is passed.

@@ -676,6 +698,7 @@ class ImageManager(DockerBaseClass):
         :param tag: Optional image tag; assumed to be "latest" if None
         :type tag: str | None
         """
+        assert self.archive_path is not None
         if not tag:
             tag = "latest"

@@ -710,8 +733,8 @@ class ImageManager(DockerBaseClass):
                 self.client._get(
                     self.client._url("/images/{0}/get", image_name), stream=True
                 ),
-                DEFAULT_DATA_CHUNK_SIZE,
-                False,
+                chunk_size=DEFAULT_DATA_CHUNK_SIZE,
+                decode=False,
             )
         except Exception as exc:  # pylint: disable=broad-exception-caught
             self.fail(f"Error getting image {image_name} - {exc}")
@@ -725,7 +748,7 @@ class ImageManager(DockerBaseClass):

         self.results["image"] = image

-    def push_image(self, name, tag=None):
+    def push_image(self, name: str, tag: str | None = None) -> None:
         """
         If the name of the image contains a repository path, then push the image.

@@ -799,7 +822,9 @@ class ImageManager(DockerBaseClass):
             self.results["image"] = {}
         self.results["image"]["push_status"] = status

-    def tag_image(self, name, tag, repository, push=False):
+    def tag_image(
+        self, name: str, tag: str, repository: str, push: bool = False
+    ) -> None:
         """
         Tag an image into a repository.

@@ -852,7 +877,7 @@ class ImageManager(DockerBaseClass):
                 self.push_image(repo, repo_tag)

     @staticmethod
-    def _extract_output_line(line, output):
+    def _extract_output_line(line: dict[str, t.Any], output: list[str]) -> None:
         """
         Extract text line from stream output and, if found, adds it to output.
""" @@ -862,14 +887,15 @@ class ImageManager(DockerBaseClass): text_line = line.get("stream") or line.get("status") or "" output.extend(text_line.splitlines()) - def build_image(self): + def build_image(self) -> dict[str, t.Any]: """ Build an image :return: image dict """ + assert self.build_path is not None remote = context = None - headers = {} + headers: dict[str, str | bytes] = {} buildargs = {} if self.buildargs: for key, value in self.buildargs.items(): @@ -898,12 +924,12 @@ class ImageManager(DockerBaseClass): [line.strip() for line in f.read().splitlines()], ) ) - dockerfile = process_dockerfile(dockerfile, self.build_path) + dockerfile_data = process_dockerfile(dockerfile, self.build_path) context = tar( - self.build_path, exclude=exclude, dockerfile=dockerfile, gzip=False + self.build_path, exclude=exclude, dockerfile=dockerfile_data, gzip=False ) - params = { + params: dict[str, t.Any] = { "t": f"{self.name}:{self.tag}" if self.tag else self.name, "remote": remote, "q": False, @@ -960,7 +986,7 @@ class ImageManager(DockerBaseClass): if context is not None: context.close() - build_output = [] + build_output: list[str] = [] for line in self.client._stream_helper(response, decode=True): # line = json.loads(line) self.log(line, pretty_print=True) @@ -982,14 +1008,15 @@ class ImageManager(DockerBaseClass): "image": self.client.find_image(name=self.name, tag=self.tag), } - def load_image(self): + def load_image(self) -> dict[str, t.Any] | None: """ Load an image from a .tar archive :return: image dict """ # Load image(s) from file - load_output = [] + assert self.load_path is not None + load_output: list[str] = [] has_output = False try: self.log(f"Opening image {self.load_path}") @@ -1078,7 +1105,7 @@ class ImageManager(DockerBaseClass): return self.client.find_image(self.name, self.tag) -def main(): +def main() -> None: argument_spec = { "source": {"type": "str", "choices": ["build", "load", "pull", "local"]}, "build": { diff --git a/plugins/modules/docker_image_build.py b/plugins/modules/docker_image_build.py index b4b765cf..ee1f64c1 100644 --- a/plugins/modules/docker_image_build.py +++ b/plugins/modules/docker_image_build.py @@ -282,6 +282,7 @@ command: import base64 import os import traceback +import typing as t from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.formatters import human_to_bytes @@ -304,7 +305,16 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( ) -def convert_to_bytes(value, module, name, unlimited_value=None): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + +def convert_to_bytes( + value: str | None, + module: AnsibleModule, + name: str, + unlimited_value: int | None = None, +) -> int | None: if value is None: return value try: @@ -315,11 +325,11 @@ def convert_to_bytes(value, module, name, unlimited_value=None): module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}") -def dict_to_list(dictionary, concat="="): +def dict_to_list(dictionary: dict[str, t.Any], concat: str = "=") -> list[str]: return [f"{k}{concat}{v}" for k, v in sorted(dictionary.items())] -def _quote_csv(text): +def _quote_csv(text: str) -> str: if text.strip() == text and all(i not in text for i in '",\r\n'): return text text = text.replace('"', '""') @@ -327,7 +337,7 @@ def _quote_csv(text): class ImageBuilder(DockerBaseClass): - def __init__(self, client): + def __init__(self, client: AnsibleModuleDockerClient) -> None: super().__init__() self.client = client 
         self.check_mode = self.client.check_mode
@@ -420,14 +430,14 @@ class ImageBuilder(DockerBaseClass):
                 f" buildx plugin has version {buildx_version} which only supports one output."
             )

-    def fail(self, msg, **kwargs):
+    def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
         self.client.fail(msg, **kwargs)

-    def add_list_arg(self, args, option, values):
+    def add_list_arg(self, args: list[str], option: str, values: list[str]) -> None:
         for value in values:
             args.extend([option, value])

-    def add_args(self, args):
+    def add_args(self, args: list[str]) -> dict[str, t.Any]:
         environ_update = {}
         if not self.outputs:
             args.extend(["--tag", f"{self.name}:{self.tag}"])
@@ -512,9 +522,9 @@ class ImageBuilder(DockerBaseClass):
             )
         return environ_update

-    def build_image(self):
+    def build_image(self) -> dict[str, t.Any]:
         image = self.client.find_image(self.name, self.tag)
-        results = {
+        results: dict[str, t.Any] = {
             "changed": False,
             "actions": [],
             "image": image or {},
@@ -547,7 +557,7 @@ class ImageBuilder(DockerBaseClass):
         return results


-def main():
+def main() -> None:
     argument_spec = {
         "name": {"type": "str", "required": True},
         "tag": {"type": "str", "default": "latest"},
diff --git a/plugins/modules/docker_image_export.py b/plugins/modules/docker_image_export.py
index 0d107631..51c1af97 100644
--- a/plugins/modules/docker_image_export.py
+++ b/plugins/modules/docker_image_export.py
@@ -94,6 +94,7 @@ images:
 """

 import traceback
+import typing as t

 from ansible_collections.community.docker.plugins.module_utils._api.constants import (
     DEFAULT_DATA_CHUNK_SIZE,
@@ -121,7 +122,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (


 class ImageExportManager(DockerBaseClass):
-    def __init__(self, client):
+    def __init__(self, client: AnsibleDockerClient) -> None:
         super().__init__()

         self.client = client
@@ -151,10 +152,10 @@ class ImageExportManager(DockerBaseClass):
         if not self.names:
             self.fail("At least one image name must be specified")

-    def fail(self, msg):
+    def fail(self, msg: str) -> t.NoReturn:
         self.client.fail(msg)

-    def get_export_reason(self):
+    def get_export_reason(self) -> str | None:
         if self.force:
             return "Exporting since force=true"

@@ -178,13 +179,13 @@ class ImageExportManager(DockerBaseClass):
                     found = True
                     break
             if not found:
-                return f'Overwriting archive since it contains unexpected image {archived_image.image_id} named {", ".join(archived_image.repo_tags)}'
+                return f"Overwriting archive since it contains unexpected image {archived_image.image_id} named {', '.join(archived_image.repo_tags)}"

         if left_names:
             return f"Overwriting archive since it is missing image(s) {', '.join([name['joined'] for name in left_names])}"

         return None

-    def write_chunks(self, chunks):
+    def write_chunks(self, chunks: t.Generator[bytes]) -> None:
         try:
             with open(self.path, "wb") as fd:
                 for chunk in chunks:
@@ -192,7 +193,7 @@ class ImageExportManager(DockerBaseClass):
         except Exception as exc:  # pylint: disable=broad-exception-caught
             self.fail(f"Error writing image archive {self.path} - {exc}")

-    def export_images(self):
+    def export_images(self) -> None:
         image_names = [name["joined"] for name in self.names]
         image_names_str = ", ".join(image_names)
         if len(image_names) == 1:
@@ -202,8 +203,8 @@ class ImageExportManager(DockerBaseClass):
                     self.client._get(
                         self.client._url("/images/{0}/get", image_names[0]), stream=True
                     ),
-                    DEFAULT_DATA_CHUNK_SIZE,
-                    False,
+                    chunk_size=DEFAULT_DATA_CHUNK_SIZE,
+                    decode=False,
                 )
             except Exception as exc:  # pylint: disable=broad-exception-caught
                 self.fail(f"Error getting image {image_names[0]} - {exc}")
@@ -216,15 +217,15 @@ class ImageExportManager(DockerBaseClass):
                         stream=True,
                         params={"names": image_names},
                     ),
-                    DEFAULT_DATA_CHUNK_SIZE,
-                    False,
+                    chunk_size=DEFAULT_DATA_CHUNK_SIZE,
+                    decode=False,
                 )
             except Exception as exc:  # pylint: disable=broad-exception-caught
                 self.fail(f"Error getting images {image_names_str} - {exc}")

         self.write_chunks(chunks)

-    def run(self):
+    def run(self) -> dict[str, t.Any]:
         tag = self.tag
         if not tag:
             tag = "latest"
@@ -260,7 +261,7 @@ class ImageExportManager(DockerBaseClass):
         return results


-def main():
+def main() -> None:
     argument_spec = {
         "path": {"type": "path"},
         "force": {"type": "bool", "default": False},
diff --git a/plugins/modules/docker_image_info.py b/plugins/modules/docker_image_info.py
index 050f60d7..2c6d24db 100644
--- a/plugins/modules/docker_image_info.py
+++ b/plugins/modules/docker_image_info.py
@@ -136,6 +136,7 @@ images:
 """

 import traceback
+import typing as t

 from ansible_collections.community.docker.plugins.module_utils._api.errors import (
     DockerException,
@@ -155,9 +156,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (


 class ImageManager(DockerBaseClass):
-
-    def __init__(self, client, results):
-
+    def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
         super().__init__()

         self.client = client
@@ -170,10 +169,10 @@ class ImageManager(DockerBaseClass):
         else:
             self.results["images"] = self.get_all_images()

-    def fail(self, msg):
+    def fail(self, msg: str) -> t.NoReturn:
         self.client.fail(msg)

-    def get_facts(self):
+    def get_facts(self) -> list[dict[str, t.Any]]:
         """
         Lookup and inspect each image name found in the names parameter.

@@ -200,7 +199,7 @@ class ImageManager(DockerBaseClass):
             results.append(image)
         return results

-    def get_all_images(self):
+    def get_all_images(self) -> list[dict[str, t.Any]]:
         results = []
         params = {
             "only_ids": 0,
@@ -218,7 +217,7 @@ class ImageManager(DockerBaseClass):
         return results


-def main():
+def main() -> None:
     argument_spec = {
         "name": {"type": "list", "elements": "str"},
     }
diff --git a/plugins/modules/docker_image_load.py b/plugins/modules/docker_image_load.py
index e7e10431..46c97de5 100644
--- a/plugins/modules/docker_image_load.py
+++ b/plugins/modules/docker_image_load.py
@@ -80,6 +80,7 @@ images:

 import errno
 import traceback
+import typing as t

 from ansible_collections.community.docker.plugins.module_utils._api.errors import (
""" @@ -118,12 +119,12 @@ class ImageManager(DockerBaseClass): text_line = line.get("stream") or line.get("status") or "" output.extend(text_line.splitlines()) - def load_images(self): + def load_images(self) -> None: """ Load images from a .tar archive """ # Load image(s) from file - load_output = [] + load_output: list[str] = [] try: self.log(f"Opening image {self.path}") with open(self.path, "rb") as image_tar: @@ -179,7 +180,7 @@ class ImageManager(DockerBaseClass): self.results["stdout"] = "\n".join(load_output) -def main(): +def main() -> None: client = AnsibleDockerClient( argument_spec={ "path": {"type": "path", "required": True}, @@ -188,7 +189,7 @@ def main(): ) try: - results = { + results: dict[str, t.Any] = { "image_names": [], "images": [], } diff --git a/plugins/modules/docker_image_pull.py b/plugins/modules/docker_image_pull.py index 02042734..a443f2f8 100644 --- a/plugins/modules/docker_image_pull.py +++ b/plugins/modules/docker_image_pull.py @@ -91,6 +91,7 @@ image: """ import traceback +import typing as t from ansible_collections.community.docker.plugins.module_utils._api.errors import ( DockerException, @@ -114,7 +115,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) -def image_info(image): +def image_info(image: dict[str, t.Any] | None) -> dict[str, t.Any]: result = {} if image: result["id"] = image["Id"] @@ -124,17 +125,17 @@ def image_info(image): class ImagePuller(DockerBaseClass): - def __init__(self, client): + def __init__(self, client: AnsibleDockerClient) -> None: super().__init__() self.client = client self.check_mode = self.client.check_mode parameters = self.client.module.params - self.name = parameters["name"] - self.tag = parameters["tag"] - self.platform = parameters["platform"] - self.pull_mode = parameters["pull"] + self.name: str = parameters["name"] + self.tag: str = parameters["tag"] + self.platform: str | None = parameters["platform"] + self.pull_mode: t.Literal["always", "not_present"] = parameters["pull"] if is_image_name_id(self.name): self.client.fail("Cannot pull an image by ID") @@ -147,13 +148,15 @@ class ImagePuller(DockerBaseClass): self.name = repo self.tag = repo_tag - def pull(self): + def pull(self) -> dict[str, t.Any]: image = self.client.find_image(name=self.name, tag=self.tag) + actions: list[str] = [] + diff = {"before": image_info(image), "after": image_info(image)} results = { "changed": False, - "actions": [], + "actions": actions, "image": image or {}, - "diff": {"before": image_info(image), "after": image_info(image)}, + "diff": diff, } if image and self.pull_mode == "not_present": @@ -175,21 +178,22 @@ class ImagePuller(DockerBaseClass): if compare_platform_strings(wanted_platform, image_platform): return results - results["actions"].append(f"Pulled image {self.name}:{self.tag}") + actions.append(f"Pulled image {self.name}:{self.tag}") if self.check_mode: results["changed"] = True - results["diff"]["after"] = image_info({"Id": "unknown"}) + diff["after"] = image_info({"Id": "unknown"}) else: - results["image"], not_changed = self.client.pull_image( + image, not_changed = self.client.pull_image( self.name, tag=self.tag, image_platform=self.platform ) + results["image"] = image results["changed"] = not not_changed - results["diff"]["after"] = image_info(results["image"]) + diff["after"] = image_info(image) return results -def main(): +def main() -> None: argument_spec = { "name": {"type": "str", "required": True}, "tag": {"type": "str", "default": "latest"}, diff --git 
diff --git a/plugins/modules/docker_image_push.py b/plugins/modules/docker_image_push.py
index a350df34..8d8100dd 100644
--- a/plugins/modules/docker_image_push.py
+++ b/plugins/modules/docker_image_push.py
@@ -73,6 +73,7 @@ image:

 import base64
 import traceback
+import typing as t

 from ansible_collections.community.docker.plugins.module_utils._api.auth import (
     get_config_header,
@@ -96,15 +97,15 @@ from ansible_collections.community.docker.plugins.module_utils._util import (


 class ImagePusher(DockerBaseClass):
-    def __init__(self, client):
+    def __init__(self, client: AnsibleDockerClient) -> None:
         super().__init__()

         self.client = client
         self.check_mode = self.client.check_mode

         parameters = self.client.module.params
-        self.name = parameters["name"]
-        self.tag = parameters["tag"]
+        self.name: str = parameters["name"]
+        self.tag: str = parameters["tag"]

         if is_image_name_id(self.name):
             self.client.fail("Cannot push an image by ID")
@@ -122,20 +123,21 @@ class ImagePusher(DockerBaseClass):
         if not is_valid_tag(self.tag, allow_empty=False):
             self.client.fail(f'"{self.tag}" is not a valid docker tag!')

-    def push(self):
+    def push(self) -> dict[str, t.Any]:
         image = self.client.find_image(name=self.name, tag=self.tag)
         if not image:
             self.client.fail(f"Cannot find image {self.name}:{self.tag}")

-        results = {
+        actions: list[str] = []
+        results: dict[str, t.Any] = {
             "changed": False,
-            "actions": [],
+            "actions": actions,
             "image": image,
         }

         push_registry, push_repo = resolve_repository_name(self.name)
         try:
-            results["actions"].append(f"Pushed image {self.name}:{self.tag}")
+            actions.append(f"Pushed image {self.name}:{self.tag}")

             headers = {}
             header = get_config_header(self.client, push_registry)
@@ -174,7 +176,7 @@
         return results


-def main():
+def main() -> None:
     argument_spec = {
         "name": {"type": "str", "required": True},
         "tag": {"type": "str", "default": "latest"},
diff --git a/plugins/modules/docker_image_remove.py b/plugins/modules/docker_image_remove.py
index 5bb1276f..2b3f0e16 100644
--- a/plugins/modules/docker_image_remove.py
+++ b/plugins/modules/docker_image_remove.py
@@ -98,6 +98,7 @@ untagged:
 """

 import traceback
+import typing as t

 from ansible_collections.community.docker.plugins.module_utils._api.errors import (
     DockerException,
@@ -118,8 +119,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (


 class ImageRemover(DockerBaseClass):
-
-    def __init__(self, client):
+    def __init__(self, client: AnsibleDockerClient) -> None:
         super().__init__()

         self.client = client
@@ -142,10 +142,10 @@ class ImageRemover(DockerBaseClass):
             self.name = repo
             self.tag = repo_tag

-    def fail(self, msg):
+    def fail(self, msg: str) -> t.NoReturn:
         self.client.fail(msg)

-    def get_diff_state(self, image):
+    def get_diff_state(self, image: dict[str, t.Any] | None) -> dict[str, t.Any]:
         if not image:
             return {"exists": False}
         return {
@@ -155,13 +155,16 @@ class ImageRemover(DockerBaseClass):
             "digests": sorted(image.get("RepoDigests") or []),
         }

-    def absent(self):
-        results = {
+    def absent(self) -> dict[str, t.Any]:
+        actions: list[str] = []
+        deleted: list[str] = []
+        untagged: list[str] = []
+        results: dict[str, t.Any] = {
             "changed": False,
-            "actions": [],
+            "actions": actions,
             "image": {},
-            "deleted": [],
-            "untagged": [],
+            "deleted": deleted,
+            "untagged": untagged,
         }

         name = self.name
@@ -172,16 +175,18 @@ class ImageRemover(DockerBaseClass):
         if self.tag:
             name = f"{self.name}:{self.tag}"

+        diff: dict[str, t.Any] = {}
         if self.diff:
-            results["diff"] = {"before": self.get_diff_state(image)}
+            results["diff"] = diff
+            diff["before"] = self.get_diff_state(image)

         if not image:
             if self.diff:
-                results["diff"]["after"] = self.get_diff_state(image)
+                diff["after"] = self.get_diff_state(image)
             return results

         results["changed"] = True
-        results["actions"].append(f"Removed image {name}")
+        actions.append(f"Removed image {name}")
         results["image"] = image

         if not self.check_mode:
@@ -199,22 +204,22 @@ class ImageRemover(DockerBaseClass):

             for entry in res:
                 if entry.get("Untagged"):
-                    results["untagged"].append(entry["Untagged"])
+                    untagged.append(entry["Untagged"])
                 if entry.get("Deleted"):
-                    results["deleted"].append(entry["Deleted"])
+                    deleted.append(entry["Deleted"])

-            results["untagged"] = sorted(results["untagged"])
-            results["deleted"] = sorted(results["deleted"])
+            untagged[:] = sorted(untagged)
+            deleted[:] = sorted(deleted)

             if self.diff:
                 image_after = self.client.find_image_by_id(
                     image["Id"], accept_missing_image=True
                 )
-                results["diff"]["after"] = self.get_diff_state(image_after)
+                diff["after"] = self.get_diff_state(image_after)

         elif is_image_name_id(name):
-            results["deleted"].append(image["Id"])
-            results["untagged"] = sorted(
+            deleted.append(image["Id"])
+            untagged[:] = sorted(
                 (image.get("RepoTags") or []) + (image.get("RepoDigests") or [])
             )
             if not self.force and results["untagged"]:
@@ -222,40 +227,40 @@ class ImageRemover(DockerBaseClass):
                     "Cannot delete image by ID that is still in use - use force=true"
                 )
             if self.diff:
-                results["diff"]["after"] = self.get_diff_state({})
+                diff["after"] = self.get_diff_state({})

         elif is_image_name_id(self.tag):
-            results["untagged"].append(name)
+            untagged.append(name)
             if (
                 len(image.get("RepoTags") or []) < 1
                 and len(image.get("RepoDigests") or []) < 2
             ):
-                results["deleted"].append(image["Id"])
+                deleted.append(image["Id"])
             if self.diff:
-                results["diff"]["after"] = self.get_diff_state(image)
+                diff["after"] = self.get_diff_state(image)
                 try:
-                    results["diff"]["after"]["digests"].remove(name)
+                    diff["after"]["digests"].remove(name)
                 except ValueError:
                     pass

         else:
-            results["untagged"].append(name)
+            untagged.append(name)
             if (
                 len(image.get("RepoTags") or []) < 2
                 and len(image.get("RepoDigests") or []) < 1
             ):
-                results["deleted"].append(image["Id"])
+                deleted.append(image["Id"])
             if self.diff:
-                results["diff"]["after"] = self.get_diff_state(image)
+                diff["after"] = self.get_diff_state(image)
                 try:
-                    results["diff"]["after"]["tags"].remove(name)
+                    diff["after"]["tags"].remove(name)
                 except ValueError:
                     pass

         return results


-def main():
+def main() -> None:
     argument_spec = {
         "name": {"type": "str", "required": True},
         "tag": {"type": "str", "default": "latest"},
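Several of the `fail()` helpers in this patch, including the one in docker_image_remove.py above, are annotated `-> t.NoReturn`, which tells mypy that execution never continues past a call (the helpers ultimately raise or exit). A simplified sketch of why that matters for type narrowing — illustrative only, not the collection's implementation:

import sys
import typing as t


def fail(msg: str) -> t.NoReturn:
    print(msg, file=sys.stderr)
    sys.exit(1)


def require(value: str | None) -> str:
    if value is None:
        fail("value is required")
    return value  # mypy narrows value to str: fail() never returns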
diff --git a/plugins/modules/docker_image_tag.py b/plugins/modules/docker_image_tag.py
index 65dc7642..1fff33bc 100644
--- a/plugins/modules/docker_image_tag.py
+++ b/plugins/modules/docker_image_tag.py
@@ -101,6 +101,7 @@ tagged_images:
 """

 import traceback
+import typing as t

 from ansible.module_utils.common.text.formatters import human_to_bytes

@@ -121,7 +122,16 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
 )


-def convert_to_bytes(value, module, name, unlimited_value=None):
+if t.TYPE_CHECKING:
+    from ansible.module_utils.basic import AnsibleModule
+
+
+def convert_to_bytes(
+    value: str | None,
+    module: AnsibleModule,
+    name: str,
+    unlimited_value: int | None = None,
+) -> int | None:
     if value is None:
         return value
     try:
@@ -132,8 +142,8 @@
     module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")


-def image_info(name, tag, image):
-    result = {"name": name, "tag": tag}
+def image_info(name: str, tag: str, image: dict[str, t.Any] | None) -> dict[str, t.Any]:
+    result: dict[str, t.Any] = {"name": name, "tag": tag}
     if image:
         result["id"] = image["Id"]
     else:
@@ -142,7 +152,7 @@ def image_info(name, tag, image):


 class ImageTagger(DockerBaseClass):
-    def __init__(self, client):
+    def __init__(self, client: AnsibleDockerClient) -> None:
         super().__init__()

         self.client = client
@@ -179,10 +189,12 @@ class ImageTagger(DockerBaseClass):
             )
             self.repositories.append((repo, repo_tag))

-    def fail(self, msg):
+    def fail(self, msg: str) -> t.NoReturn:
         self.client.fail(msg)

-    def tag_image(self, image, name, tag):
+    def tag_image(
+        self, image: dict[str, t.Any], name: str, tag: str
+    ) -> tuple[bool, str, dict[str, t.Any] | None]:
         tagged_image = self.client.find_image(name=name, tag=tag)
         if tagged_image:
             # Idempotency checks
@@ -220,20 +232,22 @@ class ImageTagger(DockerBaseClass):

         return True, msg, tagged_image

-    def tag_images(self):
+    def tag_images(self) -> dict[str, t.Any]:
         if is_image_name_id(self.name):
             image = self.client.find_image_by_id(self.name, accept_missing_image=False)
         else:
             image = self.client.find_image(name=self.name, tag=self.tag)
             if not image:
                 self.fail(f"Cannot find image {self.name}:{self.tag}")
+        assert image is not None

-        before = []
-        after = []
-        tagged_images = []
-        results = {
+        before: list[dict[str, t.Any]] = []
+        after: list[dict[str, t.Any]] = []
+        tagged_images: list[str] = []
+        actions: list[str] = []
+        results: dict[str, t.Any] = {
             "changed": False,
-            "actions": [],
+            "actions": actions,
             "image": image,
             "tagged_images": tagged_images,
             "diff": {"before": {"images": before}, "after": {"images": after}},
@@ -244,19 +258,19 @@ class ImageTagger(DockerBaseClass):
             after.append(image_info(repository, tag, image if tagged else old_image))
             if tagged:
                 results["changed"] = True
-                results["actions"].append(
+                actions.append(
                     f"Tagged image {image['Id']} as {repository}:{tag}: {msg}"
                 )
                 tagged_images.append(f"{repository}:{tag}")
             else:
-                results["actions"].append(
+                actions.append(
                     f"Not tagged image {image['Id']} as {repository}:{tag}: {msg}"
                 )

         return results


-def main():
+def main() -> None:
     argument_spec = {
         "name": {"type": "str", "required": True},
         "tag": {"type": "str", "default": "latest"},
diff --git a/plugins/modules/docker_login.py b/plugins/modules/docker_login.py
index f90ac6a9..9a87be1d 100644
--- a/plugins/modules/docker_login.py
+++ b/plugins/modules/docker_login.py
@@ -120,6 +120,7 @@ import base64
 import json
 import os
 import traceback
+import typing as t

 from ansible.module_utils.common.text.converters import to_bytes, to_text

@@ -154,11 +155,11 @@ class DockerFileStore:

     program = ""

-    def __init__(self, config_path):
+    def __init__(self, config_path: str) -> None:
         self._config_path = config_path

         # Make sure we have a minimal config if none is available.
-        self._config = {"auths": {}}
+        self._config: dict[str, t.Any] = {"auths": {}}

         try:
             # Attempt to read the existing config.
@@ -172,14 +173,14 @@ class DockerFileStore:
             self._config.update(config)

     @property
-    def config_path(self):
+    def config_path(self) -> str:
         """
         Return the config path configured in this DockerFileStore instance.
         """

         return self._config_path

-    def get(self, server):
+    def get(self, server: str) -> dict[str, t.Any]:
         """
         Retrieve credentials for `server` if there are any in the config file.
Otherwise raise a `StoreError` @@ -193,7 +194,7 @@ class DockerFileStore: return {"Username": username, "Secret": password} - def _write(self): + def _write(self) -> None: """ Write config back out to disk. """ @@ -209,7 +210,7 @@ class DockerFileStore: finally: os.close(f) - def store(self, server, username, password): + def store(self, server: str, username: str, password: str) -> None: """ Add a credentials for `server` to the current configuration. """ @@ -225,7 +226,7 @@ class DockerFileStore: self._write() - def erase(self, server): + def erase(self, server: str) -> None: """ Remove credentials for the given server from the configuration. """ @@ -236,9 +237,7 @@ class DockerFileStore: class LoginManager(DockerBaseClass): - - def __init__(self, client, results): - + def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None: super().__init__() self.client = client @@ -246,14 +245,14 @@ class LoginManager(DockerBaseClass): parameters = self.client.module.params self.check_mode = self.client.check_mode - self.registry_url = parameters.get("registry_url") - self.username = parameters.get("username") - self.password = parameters.get("password") - self.reauthorize = parameters.get("reauthorize") - self.config_path = parameters.get("config_path") - self.state = parameters.get("state") + self.registry_url: str = parameters.get("registry_url") + self.username: str | None = parameters.get("username") + self.password: str | None = parameters.get("password") + self.reauthorize: bool = parameters.get("reauthorize") + self.config_path: str = parameters.get("config_path") + self.state: t.Literal["present", "absent"] = parameters.get("state") - def run(self): + def run(self) -> None: """ Do the actual work of this task here. This allows instantiation for partial testing. @@ -264,10 +263,10 @@ class LoginManager(DockerBaseClass): else: self.logout() - def fail(self, msg): + def fail(self, msg: str) -> t.NoReturn: self.client.fail(msg) - def _login(self, reauth): + def _login(self, reauth: bool) -> dict[str, t.Any]: if self.config_path and os.path.exists(self.config_path): self.client._auth_configs = auth.load_config( self.config_path, credstore_env=self.client.credstore_env @@ -297,7 +296,7 @@ class LoginManager(DockerBaseClass): ) return self.client._result(response, get_json=True) - def login(self): + def login(self) -> None: """ Log into the registry with provided username/password. On success update the config file with the new authorization. @@ -331,7 +330,7 @@ class LoginManager(DockerBaseClass): self.update_credentials() - def logout(self): + def logout(self) -> None: """ Log out of the registry. On success update the config file. @@ -353,13 +352,16 @@ class LoginManager(DockerBaseClass): store.erase(self.registry_url) self.results["changed"] = True - def update_credentials(self): + def update_credentials(self) -> None: """ If the authorization is not stored attempt to store authorization values via the appropriate credential helper or to the config file. :return: None """ + # This is only called from login() + assert self.username is not None + assert self.password is not None # Check to see if credentials already exist. 
store = self.get_credential_store_instance(self.registry_url, self.config_path) @@ -385,7 +387,9 @@ class LoginManager(DockerBaseClass): ) self.results["changed"] = True - def get_credential_store_instance(self, registry, dockercfg_path): + def get_credential_store_instance( + self, registry: str, dockercfg_path: str + ) -> Store | DockerFileStore: """ Return an instance of docker.credentials.Store used by the given registry. @@ -408,8 +412,7 @@ class LoginManager(DockerBaseClass): return DockerFileStore(dockercfg_path) -def main(): - +def main() -> None: argument_spec = { "registry_url": { "type": "str", diff --git a/plugins/modules/docker_network.py b/plugins/modules/docker_network.py index 6c539aea..2386e0ef 100644 --- a/plugins/modules/docker_network.py +++ b/plugins/modules/docker_network.py @@ -284,6 +284,7 @@ network: import re import time import traceback +import typing as t from ansible.module_utils.common.text.converters import to_native @@ -303,29 +304,31 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class TaskParameters(DockerBaseClass): - def __init__(self, client): + name: str + + def __init__(self, client: AnsibleDockerClient) -> None: super().__init__() self.client = client - self.name = None - self.connected = None - self.config_from = None - self.config_only = None - self.driver = None - self.driver_options = None - self.ipam_driver = None - self.ipam_driver_options = None - self.ipam_config = None - self.appends = None - self.force = None - self.internal = None - self.labels = None - self.debug = None - self.enable_ipv4 = None - self.enable_ipv6 = None - self.scope = None - self.attachable = None - self.ingress = None + self.connected: list[str] = [] + self.config_from: str | None = None + self.config_only: bool | None = None + self.driver: str = "bridge" + self.driver_options: dict[str, t.Any] = {} + self.ipam_driver: str | None = None + self.ipam_driver_options: dict[str, t.Any] | None = None + self.ipam_config: list[dict[str, t.Any]] | None = None + self.appends: bool = False + self.force: bool = False + self.internal: bool | None = None + self.labels: dict[str, t.Any] = {} + self.debug: bool = False + self.enable_ipv4: bool | None = None + self.enable_ipv6: bool | None = None + self.scope: t.Literal["local", "global", "swarm"] | None = None + self.attachable: bool | None = None + self.ingress: bool | None = None + self.state: t.Literal["present", "absent"] = "present" for key, value in client.module.params.items(): setattr(self, key, value) @@ -333,10 +336,10 @@ class TaskParameters(DockerBaseClass): # config_only sets driver to 'null' (and scope to 'local') so force that here. Otherwise we get # diffs of 'null' --> 'bridge' given that the driver option defaults to 'bridge'. if self.config_only: - self.driver = "null" + self.driver = "null" # type: ignore[unreachable] -def container_names_in_network(network): +def container_names_in_network(network: dict[str, t.Any]) -> list[str]: return ( [c["Name"] for c in network["Containers"].values()] if network["Containers"] @@ -348,7 +351,7 @@ CIDR_IPV4 = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}/([0-9]|[1-2][0-9]|3[0-2])$ CIDR_IPV6 = re.compile(r"^[0-9a-fA-F:]+/([0-9]|[1-9][0-9]|1[0-2][0-9])$") -def validate_cidr(cidr): +def validate_cidr(cidr: str) -> t.Literal["ipv4", "ipv6"]: """Validate CIDR. Return IP version of a CIDR string on success. 
:param cidr: Valid CIDR @@ -364,7 +367,7 @@ def validate_cidr(cidr): raise ValueError(f'"{cidr}" is not a valid CIDR') -def normalize_ipam_config_key(key): +def normalize_ipam_config_key(key: str) -> str: """Normalizes IPAM config keys returned by Docker API to match Ansible keys. :param key: Docker API key @@ -376,7 +379,7 @@ def normalize_ipam_config_key(key): return special_cases.get(key, key.lower()) -def dicts_are_essentially_equal(a, b): +def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]) -> bool: """Make sure that a is a subset of b, where None entries of a are ignored.""" for k, v in a.items(): if v is None: continue @@ -387,15 +390,15 @@ class DockerNetworkManager: - - def __init__(self, client): + def __init__(self, client: AnsibleDockerClient) -> None: self.client = client self.parameters = TaskParameters(client) self.check_mode = self.client.check_mode - self.results = {"changed": False, "actions": []} + self.actions: list[str] = [] + self.results: dict[str, t.Any] = {"changed": False, "actions": self.actions} self.diff = self.client.module._diff self.diff_tracker = DifferenceTracker() - self.diff_result = {} + self.diff_result: dict[str, t.Any] = {} self.existing_network = self.get_existing_network() @@ -429,10 +432,12 @@ class DockerNetworkManager: ) self.results["diff"] = self.diff_result - def get_existing_network(self): + def get_existing_network(self) -> dict[str, t.Any] | None: return self.client.get_network(name=self.parameters.name) - def has_different_config(self, net): + def has_different_config( + self, net: dict[str, t.Any] + ) -> tuple[bool, DifferenceTracker]: """ Evaluates an existing network and returns a tuple containing a boolean indicating if the configuration is different and a list of differences.
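A note on the pattern visible in DockerNetworkManager.__init__ above, which the patch repeats in docker_image.py, docker_image_tag.py, and docker_volume.py: binding a precisely annotated local such as `self.actions: list[str]` and storing a reference to it inside the loosely typed results dict lets mypy check every append, whereas indexing `results["actions"]` only ever yields `t.Any`. It also explains why `results["untagged"] = sorted(...)` becomes the in-place `untagged[:] = sorted(untagged)`. A minimal sketch of the idea, using hypothetical names:

import typing as t


class Manager:
    def __init__(self) -> None:
        # Precisely typed handle that mypy can check...
        self.actions: list[str] = []
        # ...aliased into the loosely typed results mapping: both names
        # refer to the same list object, so mutations show up in both.
        self.results: dict[str, t.Any] = {"changed": False, "actions": self.actions}

    def record(self, name: str) -> None:
        self.actions.append(f"Removed {name}")  # checked against list[str]
        self.results["changed"] = True

    def finalize(self) -> None:
        # Slice assignment mutates the list in place, so the alias stored in
        # results still sees the sorted contents; rebinding would break the link.
        self.actions[:] = sorted(self.actions)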
@@ -601,9 +606,9 @@ class DockerNetworkManager: return not differences.empty, differences - def create_network(self): + def create_network(self) -> None: if not self.existing_network: - data = { + data: dict[str, t.Any] = { "Name": self.parameters.name, "Driver": self.parameters.driver, "Options": self.parameters.driver_options, @@ -661,12 +666,12 @@ class DockerNetworkManager: resp = self.client.post_json_to_json("/networks/create", data=data) self.client.report_warnings(resp, ["Warning"]) self.existing_network = self.client.get_network(network_id=resp["Id"]) - self.results["actions"].append( + self.actions.append( f"Created network {self.parameters.name} with driver {self.parameters.driver}" ) self.results["changed"] = True - def remove_network(self): + def remove_network(self) -> None: if self.existing_network: self.disconnect_all_containers() if not self.check_mode: @@ -674,15 +679,15 @@ class DockerNetworkManager: if self.existing_network.get("Scope", "local") == "swarm": while self.get_existing_network(): time.sleep(0.1) - self.results["actions"].append(f"Removed network {self.parameters.name}") + self.actions.append(f"Removed network {self.parameters.name}") self.results["changed"] = True - def is_container_connected(self, container_name): + def is_container_connected(self, container_name: str) -> bool: if not self.existing_network: return False return container_name in container_names_in_network(self.existing_network) - def is_container_exist(self, container_name): + def is_container_exist(self, container_name: str) -> bool: try: container = self.client.get_container(container_name) return bool(container) @@ -698,7 +703,7 @@ class DockerNetworkManager: exception=traceback.format_exc(), ) - def connect_containers(self): + def connect_containers(self) -> None: for name in self.parameters.connected: if not self.is_container_connected(name) and self.is_container_exist(name): if not self.check_mode: @@ -709,11 +714,11 @@ class DockerNetworkManager: self.client.post_json( "/networks/{0}/connect", self.parameters.name, data=data ) - self.results["actions"].append(f"Connected container {name}") + self.actions.append(f"Connected container {name}") self.results["changed"] = True self.diff_tracker.add(f"connected.{name}", parameter=True, active=False) - def disconnect_missing(self): + def disconnect_missing(self) -> None: if not self.existing_network: return containers = self.existing_network["Containers"] @@ -724,26 +729,29 @@ class DockerNetworkManager: if name not in self.parameters.connected: self.disconnect_container(name) - def disconnect_all_containers(self): - containers = self.client.get_network(name=self.parameters.name)["Containers"] + def disconnect_all_containers(self) -> None: + network = self.client.get_network(name=self.parameters.name) + if not network: + return + containers = network["Containers"] if not containers: return for cont in containers.values(): self.disconnect_container(cont["Name"]) - def disconnect_container(self, container_name): + def disconnect_container(self, container_name: str) -> None: if not self.check_mode: data = {"Container": container_name, "Force": True} self.client.post_json( "/networks/{0}/disconnect", self.parameters.name, data=data ) - self.results["actions"].append(f"Disconnected container {container_name}") + self.actions.append(f"Disconnected container {container_name}") self.results["changed"] = True self.diff_tracker.add( f"connected.{container_name}", parameter=False, active=True ) - def present(self): + def present(self) -> None: 
different = False differences = DifferenceTracker() if self.existing_network: @@ -771,14 +779,14 @@ class DockerNetworkManager: network_facts = self.get_existing_network() self.results["network"] = network_facts - def absent(self): + def absent(self) -> None: self.diff_tracker.add( "exists", parameter=False, active=self.existing_network is not None ) self.remove_network() -def main(): +def main() -> None: argument_spec = { "name": {"type": "str", "required": True, "aliases": ["network_name"]}, "config_from": {"type": "str"}, diff --git a/plugins/modules/docker_network_info.py b/plugins/modules/docker_network_info.py index 03536119..b6ab29ae 100644 --- a/plugins/modules/docker_network_info.py +++ b/plugins/modules/docker_network_info.py @@ -107,7 +107,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor ) -def main(): +def main() -> None: argument_spec = { "name": {"type": "str", "required": True}, } diff --git a/plugins/modules/docker_plugin.py b/plugins/modules/docker_plugin.py index d9953953..024cd282 100644 --- a/plugins/modules/docker_plugin.py +++ b/plugins/modules/docker_plugin.py @@ -129,6 +129,7 @@ actions: """ import traceback +import typing as t from ansible.module_utils.common.text.converters import to_native @@ -149,35 +150,36 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class TaskParameters(DockerBaseClass): - def __init__(self, client): + plugin_name: str + + def __init__(self, client: AnsibleDockerClient) -> None: super().__init__() self.client = client - self.plugin_name = None - self.alias = None - self.plugin_options = None - self.debug = None - self.force_remove = None - self.enable_timeout = None + self.alias: str | None = None + self.plugin_options: dict[str, t.Any] = {} + self.debug: bool = False + self.force_remove: bool = False + self.enable_timeout: int = 0 + self.state: t.Literal["present", "absent", "enable", "disable"] = "present" for key, value in client.module.params.items(): setattr(self, key, value) -def prepare_options(options): +def prepare_options(options: dict[str, t.Any] | None) -> list[str]: return ( - [f'{k}={v if v is not None else ""}' for k, v in options.items()] + [f"{k}={v if v is not None else ''}" for k, v in options.items()] if options else [] ) -def parse_options(options_list): +def parse_options(options_list: list[str] | None) -> dict[str, str]: return dict(x.split("=", 1) for x in options_list) if options_list else {} class DockerPluginManager: - - def __init__(self, client): + def __init__(self, client: AnsibleDockerClient) -> None: self.client = client self.parameters = TaskParameters(client) @@ -185,9 +187,9 @@ class DockerPluginManager: self.check_mode = self.client.check_mode self.diff = self.client.module._diff self.diff_tracker = DifferenceTracker() - self.diff_result = {} + self.diff_result: dict[str, t.Any] = {} - self.actions = [] + self.actions: list[str] = [] self.changed = False self.existing_plugin = self.get_existing_plugin() @@ -209,7 +211,7 @@ class DockerPluginManager: ) self.diff = self.diff_result - def get_existing_plugin(self): + def get_existing_plugin(self) -> dict[str, t.Any] | None: try: return self.client.get_json("/plugins/{0}/json", self.preferred_name) except NotFound: @@ -217,12 +219,13 @@ class DockerPluginManager: except APIError as e: self.client.fail(to_native(e)) - def has_different_config(self): + def has_different_config(self) -> DifferenceTracker: """ Return the list of differences between the current parameters and the existing 
plugin. :return: list of options that differ """ + assert self.existing_plugin is not None differences = DifferenceTracker() if self.parameters.plugin_options: settings = self.existing_plugin.get("Settings") @@ -249,7 +252,7 @@ class DockerPluginManager: return differences - def install_plugin(self): + def install_plugin(self) -> None: if not self.existing_plugin: if not self.check_mode: try: @@ -297,7 +300,7 @@ class DockerPluginManager: self.actions.append(f"Installed plugin {self.preferred_name}") self.changed = True - def remove_plugin(self): + def remove_plugin(self) -> None: force = self.parameters.force_remove if self.existing_plugin: if not self.check_mode: @@ -311,7 +314,7 @@ class DockerPluginManager: self.actions.append(f"Removed plugin {self.preferred_name}") self.changed = True - def update_plugin(self): + def update_plugin(self) -> None: if self.existing_plugin: differences = self.has_different_config() if not differences.empty: @@ -328,7 +331,7 @@ class DockerPluginManager: else: self.client.fail("Cannot update the plugin: Plugin does not exist") - def present(self): + def present(self) -> None: differences = DifferenceTracker() if self.existing_plugin: differences = self.has_different_config() @@ -345,13 +348,10 @@ class DockerPluginManager: if self.diff or self.check_mode or self.parameters.debug: self.diff_tracker.merge(differences) - if not self.check_mode and not self.parameters.debug: - self.actions = None - - def absent(self): + def absent(self) -> None: self.remove_plugin() - def enable(self): + def enable(self) -> None: timeout = self.parameters.enable_timeout if self.existing_plugin: if not self.existing_plugin.get("Enabled"): @@ -380,7 +380,7 @@ class DockerPluginManager: self.actions.append(f"Enabled plugin {self.preferred_name}") self.changed = True - def disable(self): + def disable(self) -> None: if self.existing_plugin: if self.existing_plugin.get("Enabled"): if not self.check_mode: @@ -396,7 +396,7 @@ class DockerPluginManager: self.client.fail("Plugin not found: Plugin does not exist.") @property - def result(self): + def result(self) -> dict[str, t.Any]: plugin_data = {} if self.parameters.state != "absent": try: @@ -406,16 +406,22 @@ class DockerPluginManager: except NotFound: # This can happen in check mode pass - result = { + result: dict[str, t.Any] = { "actions": self.actions, "changed": self.changed, "diff": self.diff, "plugin": plugin_data, } - return dict((k, v) for k, v in result.items() if v is not None) + if ( + self.parameters.state == "present" + and not self.check_mode + and not self.parameters.debug + ): + result["actions"] = None + return {k: v for k, v in result.items() if v is not None} -def main(): +def main() -> None: argument_spec = { "alias": {"type": "str"}, "plugin_name": {"type": "str", "required": True}, diff --git a/plugins/modules/docker_prune.py b/plugins/modules/docker_prune.py index a2d3f80e..825b796c 100644 --- a/plugins/modules/docker_prune.py +++ b/plugins/modules/docker_prune.py @@ -247,7 +247,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) -def main(): +def main() -> None: argument_spec = { "containers": {"type": "bool", "default": False}, "containers_filters": {"type": "dict"}, diff --git a/plugins/modules/docker_volume.py b/plugins/modules/docker_volume.py index 5c59fed9..0c6238ea 100644 --- a/plugins/modules/docker_volume.py +++ b/plugins/modules/docker_volume.py @@ -118,6 +118,7 @@ volume: """ import traceback +import typing as t from ansible.module_utils.common.text.converters 
import to_native @@ -137,31 +138,33 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class TaskParameters(DockerBaseClass): - def __init__(self, client): + volume_name: str + + def __init__(self, client: AnsibleDockerClient) -> None: super().__init__() self.client = client - self.volume_name = None - self.driver = None - self.driver_options = None - self.labels = None - self.recreate = None - self.debug = None + self.driver: str = "local" + self.driver_options: dict[str, t.Any] = {} + self.labels: dict[str, t.Any] | None = None + self.recreate: t.Literal["always", "never", "options-changed"] = "never" + self.debug: bool = False + self.state: t.Literal["present", "absent"] = "present" for key, value in client.module.params.items(): setattr(self, key, value) class DockerVolumeManager: - - def __init__(self, client): + def __init__(self, client: AnsibleDockerClient) -> None: self.client = client self.parameters = TaskParameters(client) self.check_mode = self.client.check_mode - self.results = {"changed": False, "actions": []} + self.actions: list[str] = [] + self.results: dict[str, t.Any] = {"changed": False, "actions": self.actions} self.diff = self.client.module._diff self.diff_tracker = DifferenceTracker() - self.diff_result = {} + self.diff_result: dict[str, t.Any] = {} self.existing_volume = self.get_existing_volume() @@ -178,7 +181,7 @@ class DockerVolumeManager: ) self.results["diff"] = self.diff_result - def get_existing_volume(self): + def get_existing_volume(self) -> dict[str, t.Any] | None: try: volumes = self.client.get_json("/volumes") except APIError as e: @@ -193,12 +196,13 @@ class DockerVolumeManager: return None - def has_different_config(self): + def has_different_config(self) -> DifferenceTracker: """ Return the list of differences between the current parameters and the existing volume. 
:return: list of options that differ """ + assert self.existing_volume is not None differences = DifferenceTracker() if ( self.parameters.driver @@ -239,7 +243,7 @@ class DockerVolumeManager: return differences - def create_volume(self): + def create_volume(self) -> None: if not self.existing_volume: if not self.check_mode: try: @@ -257,12 +261,12 @@ class DockerVolumeManager: except APIError as e: self.client.fail(to_native(e)) - self.results["actions"].append( + self.actions.append( f"Created volume {self.parameters.volume_name} with driver {self.parameters.driver}" ) self.results["changed"] = True - def remove_volume(self): + def remove_volume(self) -> None: if self.existing_volume: if not self.check_mode: try: @@ -270,12 +274,10 @@ class DockerVolumeManager: except APIError as e: self.client.fail(to_native(e)) - self.results["actions"].append( - f"Removed volume {self.parameters.volume_name}" - ) + self.actions.append(f"Removed volume {self.parameters.volume_name}") self.results["changed"] = True - def present(self): + def present(self) -> None: differences = DifferenceTracker() if self.existing_volume: differences = self.has_different_config() @@ -301,14 +303,14 @@ class DockerVolumeManager: volume_facts = self.get_existing_volume() self.results["volume"] = volume_facts - def absent(self): + def absent(self) -> None: self.diff_tracker.add( "exists", parameter=False, active=self.existing_volume is not None ) self.remove_volume() -def main(): +def main() -> None: argument_spec = { "volume_name": {"type": "str", "required": True, "aliases": ["name"]}, "state": { diff --git a/plugins/modules/docker_volume_info.py b/plugins/modules/docker_volume_info.py index 76d29d96..75ee2340 100644 --- a/plugins/modules/docker_volume_info.py +++ b/plugins/modules/docker_volume_info.py @@ -71,6 +71,7 @@ volume: """ import traceback +import typing as t from ansible_collections.community.docker.plugins.module_utils._api.errors import ( DockerException, @@ -82,7 +83,9 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor ) -def get_existing_volume(client, volume_name): +def get_existing_volume( + client: AnsibleDockerClient, volume_name: str +) -> dict[str, t.Any] | None: try: return client.get_json("/volumes/{0}", volume_name) except NotFound: @@ -91,7 +94,7 @@ def get_existing_volume(client, volume_name): client.fail(f"Error inspecting volume: {exc}") -def main(): +def main() -> None: argument_spec = { "name": {"type": "str", "required": True, "aliases": ["volume_name"]}, } diff --git a/plugins/plugin_utils/_common.py b/plugins/plugin_utils/_common.py index 1803df67..250ff436 100644 --- a/plugins/plugin_utils/_common.py +++ b/plugins/plugin_utils/_common.py @@ -7,6 +7,8 @@ from __future__ import annotations +import typing as t + from ansible.errors import AnsibleConnectionFailure from ansible.utils.display import Display @@ -18,8 +20,17 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) +if t.TYPE_CHECKING: + from ansible.plugins import AnsiblePlugin + + class AnsibleDockerClient(AnsibleDockerClientBase): - def __init__(self, plugin, min_docker_version=None, min_docker_api_version=None): + def __init__( + self, + plugin: AnsiblePlugin, + min_docker_version: str | None = None, + min_docker_api_version: str | None = None, + ) -> None: self.plugin = plugin self.display = Display() super().__init__( @@ -27,17 +38,23 @@ class AnsibleDockerClient(AnsibleDockerClientBase): min_docker_api_version=min_docker_api_version, ) - def fail(self, msg, 
**kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: if kwargs: msg += "\nContext:\n" + "\n".join( f" {k} = {v!r}" for (k, v) in kwargs.items() ) raise AnsibleConnectionFailure(msg) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.display.deprecated( msg, version=version, date=date, collection_name=collection_name ) - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS} diff --git a/plugins/plugin_utils/_common_api.py b/plugins/plugin_utils/_common_api.py index 340a7565..b878c3ba 100644 --- a/plugins/plugin_utils/_common_api.py +++ b/plugins/plugin_utils/_common_api.py @@ -7,6 +7,8 @@ from __future__ import annotations +import typing as t + from ansible.errors import AnsibleConnectionFailure from ansible.utils.display import Display @@ -18,23 +20,35 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) +if t.TYPE_CHECKING: + from ansible.plugins import AnsiblePlugin + + class AnsibleDockerClient(AnsibleDockerClientBase): - def __init__(self, plugin, min_docker_api_version=None): + def __init__( + self, plugin: AnsiblePlugin, min_docker_api_version: str | None = None + ) -> None: self.plugin = plugin self.display = Display() super().__init__(min_docker_api_version=min_docker_api_version) - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: if kwargs: msg += "\nContext:\n" + "\n".join( f" {k} = {v!r}" for (k, v) in kwargs.items() ) raise AnsibleConnectionFailure(msg) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.display.deprecated( msg, version=version, date=date, collection_name=collection_name ) - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS} diff --git a/plugins/plugin_utils/_socket_handler.py b/plugins/plugin_utils/_socket_handler.py index 2bdf8817..fbe89c09 100644 --- a/plugins/plugin_utils/_socket_handler.py +++ b/plugins/plugin_utils/_socket_handler.py @@ -7,11 +7,23 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.docker.plugins.module_utils._socket_handler import ( DockerSocketHandlerBase, ) +if t.TYPE_CHECKING: + from ansible.utils.display import Display + + from ansible_collections.community.docker.plugins.module_utils._socket_helper import ( + SocketLike, + ) + + class DockerSocketHandler(DockerSocketHandlerBase): - def __init__(self, display, sock, log=None, container=None): + def __init__( + self, display: Display, sock: SocketLike, container: str | None = None + ) -> None: super().__init__(sock, log=lambda msg: display.vvvv(msg, host=container)) diff --git a/plugins/plugin_utils/_unsafe.py b/plugins/plugin_utils/_unsafe.py index 29576dce..408b1f28 100644 --- a/plugins/plugin_utils/_unsafe.py +++ b/plugins/plugin_utils/_unsafe.py @@ -8,6 +8,7 @@ from __future__ import annotations import re +import typing as t from collections.abc import Mapping, Set from ansible.module_utils.common.collections import is_sequence @@ -21,7 +22,7 @@ _RE_TEMPLATE_CHARS = re.compile("[{}]") _RE_TEMPLATE_CHARS_BYTES = 
re.compile(b"[{}]") -def make_unsafe(value): +def make_unsafe(value: t.Any) -> t.Any: if value is None or isinstance(value, AnsibleUnsafe): return value diff --git a/tests/sanity/ignore-2.17.txt b/tests/sanity/ignore-2.17.txt index 12e0b26f..b2c9b3ff 100644 --- a/tests/sanity/ignore-2.17.txt +++ b/tests/sanity/ignore-2.17.txt @@ -1 +1,22 @@ +plugins/connection/docker.py no-assert +plugins/connection/docker_api.py no-assert +plugins/connection/nsenter.py no-assert +plugins/module_utils/_api/api/client.py pep8:E704 +plugins/module_utils/_api/transport/sshconn.py no-assert +plugins/module_utils/_api/utils/build.py no-assert +plugins/module_utils/_api/utils/ports.py pep8:E704 +plugins/module_utils/_api/utils/socket.py pep8:E704 +plugins/module_utils/_common_cli.py pep8:E704 +plugins/module_utils/_module_container/module.py no-assert +plugins/module_utils/_platform.py no-assert +plugins/module_utils/_socket_handler.py no-assert +plugins/module_utils/_swarm.py pep8:E704 +plugins/module_utils/_util.py pep8:E704 plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin +plugins/modules/docker_container_exec.py no-assert +plugins/modules/docker_container_exec.py pylint:unpacking-non-sequence +plugins/modules/docker_image.py no-assert +plugins/modules/docker_image_tag.py no-assert +plugins/modules/docker_login.py no-assert +plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.18.txt b/tests/sanity/ignore-2.18.txt index 12e0b26f..65be094d 100644 --- a/tests/sanity/ignore-2.18.txt +++ b/tests/sanity/ignore-2.18.txt @@ -1 +1,21 @@ +plugins/connection/docker.py no-assert +plugins/connection/docker_api.py no-assert +plugins/connection/nsenter.py no-assert +plugins/module_utils/_api/api/client.py pep8:E704 +plugins/module_utils/_api/transport/sshconn.py no-assert +plugins/module_utils/_api/utils/build.py no-assert +plugins/module_utils/_api/utils/ports.py pep8:E704 +plugins/module_utils/_api/utils/socket.py pep8:E704 +plugins/module_utils/_common_cli.py pep8:E704 +plugins/module_utils/_module_container/module.py no-assert +plugins/module_utils/_platform.py no-assert +plugins/module_utils/_socket_handler.py no-assert +plugins/module_utils/_swarm.py pep8:E704 +plugins/module_utils/_util.py pep8:E704 plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin +plugins/modules/docker_container_exec.py no-assert +plugins/modules/docker_image.py no-assert +plugins/modules/docker_image_tag.py no-assert +plugins/modules/docker_login.py no-assert +plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.19.txt b/tests/sanity/ignore-2.19.txt index 12e0b26f..b47a6747 100644 --- a/tests/sanity/ignore-2.19.txt +++ b/tests/sanity/ignore-2.19.txt @@ -1 +1,15 @@ +plugins/connection/docker.py no-assert +plugins/connection/docker_api.py no-assert +plugins/connection/nsenter.py no-assert +plugins/module_utils/_api/transport/sshconn.py no-assert +plugins/module_utils/_api/utils/build.py no-assert +plugins/module_utils/_module_container/module.py no-assert +plugins/module_utils/_platform.py no-assert +plugins/module_utils/_socket_handler.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin +plugins/modules/docker_container_exec.py 
no-assert +plugins/modules/docker_image.py no-assert +plugins/modules/docker_image_tag.py no-assert +plugins/modules/docker_login.py no-assert +plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.20.txt b/tests/sanity/ignore-2.20.txt index 12e0b26f..b47a6747 100644 --- a/tests/sanity/ignore-2.20.txt +++ b/tests/sanity/ignore-2.20.txt @@ -1 +1,15 @@ +plugins/connection/docker.py no-assert +plugins/connection/docker_api.py no-assert +plugins/connection/nsenter.py no-assert +plugins/module_utils/_api/transport/sshconn.py no-assert +plugins/module_utils/_api/utils/build.py no-assert +plugins/module_utils/_module_container/module.py no-assert +plugins/module_utils/_platform.py no-assert +plugins/module_utils/_socket_handler.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin +plugins/modules/docker_container_exec.py no-assert +plugins/modules/docker_image.py no-assert +plugins/modules/docker_image_tag.py no-assert +plugins/modules/docker_login.py no-assert +plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.21.txt b/tests/sanity/ignore-2.21.txt index 12e0b26f..b47a6747 100644 --- a/tests/sanity/ignore-2.21.txt +++ b/tests/sanity/ignore-2.21.txt @@ -1 +1,15 @@ +plugins/connection/docker.py no-assert +plugins/connection/docker_api.py no-assert +plugins/connection/nsenter.py no-assert +plugins/module_utils/_api/transport/sshconn.py no-assert +plugins/module_utils/_api/utils/build.py no-assert +plugins/module_utils/_module_container/module.py no-assert +plugins/module_utils/_platform.py no-assert +plugins/module_utils/_socket_handler.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin +plugins/modules/docker_container_exec.py no-assert +plugins/modules/docker_image.py no-assert +plugins/modules/docker_image_tag.py no-assert +plugins/modules/docker_login.py no-assert +plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_volume.py no-assert diff --git a/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py b/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py index e2b71763..f1ea023f 100644 --- a/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py +++ b/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py @@ -19,7 +19,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.transport im try: - from ssl import CertificateError, match_hostname + from ssl import CertificateError, match_hostname # type: ignore except ImportError: HAS_MATCH_HOSTNAME = False # pylint: disable=invalid-name else: diff --git a/tests/unit/plugins/module_utils/_api/utils/test_config.py b/tests/unit/plugins/module_utils/_api/utils/test_config.py index 93378402..cbe8bf0d 100644 --- a/tests/unit/plugins/module_utils/_api/utils/test_config.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_config.py @@ -12,8 +12,8 @@ import json import os import shutil import tempfile -import typing as t import unittest +from collections.abc import Callable from unittest import mock from pytest import fixture, mark @@ -22,7 +22,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils import class FindConfigFileTest(unittest.TestCase): - mkdir: t.Callable[[str], os.PathLike[str]] + mkdir: 
Callable[[str], os.PathLike[str]] @fixture(autouse=True) def tmpdir(self, tmpdir): diff --git a/tests/unit/plugins/module_utils/compose_v2_test_cases.py b/tests/unit/plugins/module_utils/compose_v2_test_cases.py index c9da99a6..be41ebd9 100644 --- a/tests/unit/plugins/module_utils/compose_v2_test_cases.py +++ b/tests/unit/plugins/module_utils/compose_v2_test_cases.py @@ -11,7 +11,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor ) -EVENT_TEST_CASES = [ +EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[Event]]] = [ # ####################################################################################################################### # ## Docker Compose 2.18.1 ############################################################################################## # #######################################################################################################################
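A note on the `if t.TYPE_CHECKING:` blocks introduced in docker_image_tag.py, plugins/plugin_utils/_common.py, and plugins/plugin_utils/_socket_handler.py: names imported under this guard are visible to mypy but never imported at runtime, so they add no import cost and cannot create import cycles. This is safe as long as the annotations that use them are never evaluated at runtime, which `from __future__ import annotations` guarantees by storing all annotations as strings. A minimal self-contained sketch of the pattern, with the Decimal import as a stand-in:

from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    # Only imported while type checking; no runtime dependency is created.
    from decimal import Decimal


def double(value: Decimal) -> Decimal:
    # The annotation is a lazily evaluated string, so Decimal need not
    # exist in this module at runtime.
    return value * 2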
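A note on the `-> t.NoReturn` return types given to the fail() helpers in docker_image_tag.py, docker_login.py, and the plugin_utils clients: NoReturn tells mypy that execution never continues past the call, so an Optional tested just before a fail() is narrowed afterwards. Where the checker cannot follow that narrowing on its own, the patch falls back to explicit asserts such as `assert image is not None`, which is also why the `no-assert` sanity ignore entries were added. A rough sketch, with hypothetical names:

import typing as t


def fail(msg: str) -> t.NoReturn:
    # Annotated NoReturn: mypy knows execution never resumes after this call.
    raise SystemExit(msg)


def image_id(image: dict[str, t.Any] | None) -> str:
    if image is None:
        fail("Cannot find image")
    # image is narrowed to dict[str, t.Any] here thanks to NoReturn.
    return image["Id"]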
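A note on the `t.Literal[...]` annotations applied to enumerated options such as `state`, `scope`, and `recreate` in docker_network.py, docker_plugin.py, and docker_volume.py: pinning an option to its valid values lets mypy reject impossible comparisons and dead branches that a plain `str` would silently allow. A small sketch:

import typing as t

State = t.Literal["present", "absent"]


def plan(state: State) -> str:
    if state == "present":
        return "create"
    # mypy narrows state to Literal["absent"] in this branch.
    return "delete"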
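A note on the `# type: ignore` added to the ssl import in test_ssladapter.py: `ssl.match_hostname` was deprecated in Python 3.7 and removed in Python 3.12, so current typeshed stubs no longer declare it and mypy rejects the import even though the existing try/except already copes with its absence at runtime. The general shape of the pattern:

try:
    from ssl import match_hostname  # type: ignore  # removed in Python 3.12
except ImportError:
    HAS_MATCH_HOSTNAME = False  # pylint: disable=invalid-name
else:
    HAS_MATCH_HOSTNAME = True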
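Finally, a note on the explicit annotation given to `EVENT_TEST_CASES` in compose_v2_test_cases.py: without it, mypy infers the element type of a large literal list as a join of whatever it finds, which can hide malformed entries; declaring the tuple shape up front checks every test case against it. A sketch using a simplified, hypothetical Event stand-in:

# Hypothetical stand-in for the collection's real Event type.
Event = tuple[str, str, str]

# The explicit annotation makes mypy validate each entry against the
# declared shape instead of inferring a wide union from the literals.
CASES: list[tuple[str, bool, list[Event]]] = [
    ("2.18.1", True, [("container", "create", "web")]),
]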