diff --git a/plugins/module_utils/_common.py b/plugins/module_utils/_common.py index 79076c90..04588837 100644 --- a/plugins/module_utils/_common.py +++ b/plugins/module_utils/_common.py @@ -13,6 +13,7 @@ import platform import re import sys import traceback +import typing as t from collections.abc import Mapping, Sequence from ansible.module_utils.basic import AnsibleModule, missing_required_lib @@ -79,6 +80,10 @@ except ImportError: pass +if t.TYPE_CHECKING: + from collections.abc import Callable + + MIN_DOCKER_VERSION = "2.0.0" @@ -96,7 +101,9 @@ if not HAS_DOCKER_PY: pass -def _get_tls_config(fail_function, **kwargs): +def _get_tls_config( + fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any +) -> TLSConfig: if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion( "7.0.0b1" ): @@ -111,17 +118,18 @@ def _get_tls_config(fail_function, **kwargs): # Filter out all None parameters kwargs = dict((k, v) for k, v in kwargs.items() if v is not None) try: - tls_config = TLSConfig(**kwargs) - return tls_config + return TLSConfig(**kwargs) except TLSParameterError as exc: fail_function(f"TLS config error: {exc}") -def is_using_tls(auth_data): +def is_using_tls(auth_data: dict[str, t.Any]) -> bool: return auth_data["tls_verify"] or auth_data["tls"] -def get_connect_params(auth_data, fail_function): +def get_connect_params( + auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn] +) -> dict[str, t.Any]: if is_using_tls(auth_data): auth_data["docker_host"] = auth_data["docker_host"].replace( "tcp://", "https://" @@ -173,7 +181,11 @@ DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade. class AnsibleDockerClientBase(Client): - def __init__(self, min_docker_version=None, min_docker_api_version=None): + def __init__( + self, + min_docker_version: str | None = None, + min_docker_api_version: str | None = None, + ) -> None: if min_docker_version is None: min_docker_version = MIN_DOCKER_VERSION @@ -214,23 +226,34 @@ class AnsibleDockerClientBase(Client): f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." ) - def log(self, msg, pretty_print=False): + def log(self, msg: t.Any, pretty_print: bool = False): pass # if self.debug: # from .util import log_debug # log_debug(msg, pretty_print=pretty_print) @abc.abstractmethod - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: pass - def deprecate(self, msg, version=None, date=None, collection_name=None): + @abc.abstractmethod + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: pass @staticmethod def _get_value( - param_name, param_value, env_variable, default_value, value_type="str" - ): + param_name: str, + param_value: t.Any, + env_variable: str | None, + default_value: t.Any | None, + value_type: t.Literal["str", "bool", "int"] = "str", + ) -> t.Any: if param_value is not None: # take module parameter value if value_type == "bool": @@ -267,11 +290,11 @@ class AnsibleDockerClientBase(Client): return default_value @abc.abstractmethod - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: pass @property - def auth_params(self): + def auth_params(self) -> dict[str, t.Any]: # Get authentication credentials. # Precedence: module parameters-> environment variables-> defaults. @@ -356,7 +379,7 @@ class AnsibleDockerClientBase(Client): return result - def _handle_ssl_error(self, error): + def _handle_ssl_error(self, error: Exception) -> t.NoReturn: match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) if match: hostname = self.auth_params["tls_hostname"] @@ -368,7 +391,7 @@ class AnsibleDockerClientBase(Client): ) self.fail(f"SSL Exception: {error}") - def get_container_by_id(self, container_id): + def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None: try: self.log(f"Inspecting container Id {container_id}") result = self.inspect_container(container=container_id) @@ -379,7 +402,7 @@ class AnsibleDockerClientBase(Client): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error inspecting container: {exc}") - def get_container(self, name=None): + def get_container(self, name: str | None) -> dict[str, t.Any] | None: """ Lookup a container and return the inspection results. """ @@ -416,7 +439,9 @@ class AnsibleDockerClientBase(Client): return self.get_container_by_id(result["Id"]) - def get_network(self, name=None, network_id=None): + def get_network( + self, name: str | None = None, network_id: str | None = None + ) -> dict[str, t.Any] | None: """ Lookup a network and return the inspection results. """ @@ -455,7 +480,7 @@ class AnsibleDockerClientBase(Client): return result - def find_image(self, name, tag): + def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: """ Lookup an image (by name and tag) and return the inspection results. """ @@ -507,7 +532,9 @@ class AnsibleDockerClientBase(Client): self.log(f"Image {name}:{tag} not found.") return None - def find_image_by_id(self, image_id, accept_missing_image=False): + def find_image_by_id( + self, image_id: str, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: """ Lookup an image (by ID) and return the inspection results. """ @@ -526,7 +553,7 @@ class AnsibleDockerClientBase(Client): self.fail(f"Error inspecting image ID {image_id} - {exc}") return inspection - def _image_lookup(self, name, tag): + def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]: """ Including a tag in the name parameter sent to the Docker SDK for Python images method does not work consistently. Instead, get the result set for name and manually check @@ -549,7 +576,9 @@ class AnsibleDockerClientBase(Client): break return images - def pull_image(self, name, tag="latest", image_platform=None): + def pull_image( + self, name: str, tag: str = "latest", image_platform: str | None = None + ) -> tuple[dict[str, t.Any] | None, bool]: """ Pull an image """ @@ -580,7 +609,7 @@ class AnsibleDockerClientBase(Client): return new_tag, old_tag == new_tag - def inspect_distribution(self, image, **kwargs): + def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]: """ Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 since prior versions did not support accessing private repositories. @@ -594,7 +623,7 @@ class AnsibleDockerClientBase(Client): self._url("/distribution/{0}/json", image), headers={"X-Registry-Auth": header}, ), - get_json=True, + json=True, ) return super().inspect_distribution(image, **kwargs) @@ -603,18 +632,24 @@ class AnsibleDockerClient(AnsibleDockerClientBase): def __init__( self, - argument_spec=None, - supports_check_mode=False, - mutually_exclusive=None, - required_together=None, - required_if=None, - required_one_of=None, - required_by=None, - min_docker_version=None, - min_docker_api_version=None, - option_minimal_versions=None, - option_minimal_versions_ignore_params=None, - fail_results=None, + argument_spec: dict[str, t.Any] | None = None, + supports_check_mode: bool = False, + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_by: dict[str, Sequence[str]] | None = None, + min_docker_version: str | None = None, + min_docker_api_version: str | None = None, + option_minimal_versions: dict[str, t.Any] | None = None, + option_minimal_versions_ignore_params: Sequence[str] | None = None, + fail_results: dict[str, t.Any] | None = None, ): # Modules can put information in here which will always be returned @@ -627,12 +662,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase): merged_arg_spec.update(argument_spec) self.arg_spec = merged_arg_spec - mutually_exclusive_params = [] + mutually_exclusive_params: list[Sequence[str]] = [] mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE if mutually_exclusive: mutually_exclusive_params += mutually_exclusive - required_together_params = [] + required_together_params: list[Sequence[str]] = [] required_together_params += DOCKER_REQUIRED_TOGETHER if required_together: required_together_params += required_together @@ -660,20 +695,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase): option_minimal_versions, option_minimal_versions_ignore_params ) - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: self.fail_results.update(kwargs) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.module.deprecate( msg, version=version, date=date, collection_name=collection_name ) - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: return self.module.params - def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): - self.option_minimal_versions = {} + def _get_minimal_versions( + self, + option_minimal_versions: dict[str, t.Any], + ignore_params: Sequence[str] | None = None, + ) -> None: + self.option_minimal_versions: dict[str, dict[str, t.Any]] = {} for option in self.module.argument_spec: if ignore_params is not None: if option in ignore_params: @@ -724,7 +769,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase): msg = f"Cannot {usg} with your configuration." self.fail(msg) - def report_warnings(self, result, warnings_key=None): + def report_warnings( + self, result: t.Any, warnings_key: Sequence[str] | None = None + ) -> None: """ Checks result of client operation for warnings, and if present, outputs them. diff --git a/plugins/module_utils/_common_api.py b/plugins/module_utils/_common_api.py index d58cf65d..faca443f 100644 --- a/plugins/module_utils/_common_api.py +++ b/plugins/module_utils/_common_api.py @@ -11,6 +11,7 @@ from __future__ import annotations import abc import os import re +import typing as t from collections.abc import Mapping, Sequence from ansible.module_utils.basic import AnsibleModule, missing_required_lib @@ -60,19 +61,26 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) -def _get_tls_config(fail_function, **kwargs): +if t.TYPE_CHECKING: + from collections.abc import Callable + + +def _get_tls_config( + fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any +) -> TLSConfig: try: - tls_config = TLSConfig(**kwargs) - return tls_config + return TLSConfig(**kwargs) except TLSParameterError as exc: fail_function(f"TLS config error: {exc}") -def is_using_tls(auth_data): +def is_using_tls(auth_data: dict[str, t.Any]) -> bool: return auth_data["tls_verify"] or auth_data["tls"] -def get_connect_params(auth_data, fail_function): +def get_connect_params( + auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn] +) -> dict[str, t.Any]: if is_using_tls(auth_data): auth_data["docker_host"] = auth_data["docker_host"].replace( "tcp://", "https://" @@ -114,7 +122,7 @@ def get_connect_params(auth_data, fail_function): class AnsibleDockerClientBase(Client): - def __init__(self, min_docker_api_version=None): + def __init__(self, min_docker_api_version: str | None = None) -> None: self._connect_params = get_connect_params( self.auth_params, fail_function=self.fail ) @@ -138,23 +146,34 @@ class AnsibleDockerClientBase(Client): f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." ) - def log(self, msg, pretty_print=False): + def log(self, msg: t.Any, pretty_print: bool = False): pass # if self.debug: # from .util import log_debug # log_debug(msg, pretty_print=pretty_print) @abc.abstractmethod - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: pass - def deprecate(self, msg, version=None, date=None, collection_name=None): + @abc.abstractmethod + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: pass @staticmethod def _get_value( - param_name, param_value, env_variable, default_value, value_type="str" - ): + param_name: str, + param_value: t.Any, + env_variable: str | None, + default_value: t.Any | None, + value_type: t.Literal["str", "bool", "int"] = "str", + ) -> t.Any: if param_value is not None: # take module parameter value if value_type == "bool": @@ -191,11 +210,11 @@ class AnsibleDockerClientBase(Client): return default_value @abc.abstractmethod - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: pass @property - def auth_params(self): + def auth_params(self) -> dict[str, t.Any]: # Get authentication credentials. # Precedence: module parameters-> environment variables-> defaults. @@ -288,7 +307,7 @@ class AnsibleDockerClientBase(Client): return result - def _handle_ssl_error(self, error): + def _handle_ssl_error(self, error: Exception) -> t.NoReturn: match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) if match: hostname = self.auth_params["tls_hostname"] @@ -300,7 +319,7 @@ class AnsibleDockerClientBase(Client): ) self.fail(f"SSL Exception: {error}") - def get_container_by_id(self, container_id): + def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None: try: self.log(f"Inspecting container Id {container_id}") result = self.get_json("/containers/{0}/json", container_id) @@ -311,7 +330,7 @@ class AnsibleDockerClientBase(Client): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error inspecting container: {exc}") - def get_container(self, name=None): + def get_container(self, name: str | None) -> dict[str, t.Any] | None: """ Lookup a container and return the inspection results. """ @@ -355,7 +374,9 @@ class AnsibleDockerClientBase(Client): return self.get_container_by_id(result["Id"]) - def get_network(self, name=None, network_id=None): + def get_network( + self, name: str | None = None, network_id: str | None = None + ) -> dict[str, t.Any] | None: """ Lookup a network and return the inspection results. """ @@ -395,14 +416,14 @@ class AnsibleDockerClientBase(Client): return result - def _image_lookup(self, name, tag): + def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]: """ Including a tag in the name parameter sent to the Docker SDK for Python images method does not work consistently. Instead, get the result set for name and manually check if the tag exists. """ try: - params = { + params: dict[str, t.Any] = { "only_ids": 0, "all": 0, } @@ -427,7 +448,7 @@ class AnsibleDockerClientBase(Client): break return images - def find_image(self, name, tag): + def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: """ Lookup an image (by name and tag) and return the inspection results. """ @@ -478,7 +499,9 @@ class AnsibleDockerClientBase(Client): self.log(f"Image {name}:{tag} not found.") return None - def find_image_by_id(self, image_id, accept_missing_image=False): + def find_image_by_id( + self, image_id: str, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: """ Lookup an image (by ID) and return the inspection results. """ @@ -496,7 +519,9 @@ class AnsibleDockerClientBase(Client): except Exception as exc: # pylint: disable=broad-exception-caught self.fail(f"Error inspecting image ID {image_id} - {exc}") - def pull_image(self, name, tag="latest", image_platform=None): + def pull_image( + self, name: str, tag: str = "latest", image_platform: str | None = None + ) -> tuple[dict[str, t.Any] | None, bool]: """ Pull an image """ @@ -547,17 +572,23 @@ class AnsibleDockerClient(AnsibleDockerClientBase): def __init__( self, - argument_spec=None, - supports_check_mode=False, - mutually_exclusive=None, - required_together=None, - required_if=None, - required_one_of=None, - required_by=None, - min_docker_api_version=None, - option_minimal_versions=None, - option_minimal_versions_ignore_params=None, - fail_results=None, + argument_spec: dict[str, t.Any] | None = None, + supports_check_mode: bool = False, + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_by: dict[str, Sequence[str]] | None = None, + min_docker_api_version: str | None = None, + option_minimal_versions: dict[str, t.Any] | None = None, + option_minimal_versions_ignore_params: Sequence[str] | None = None, + fail_results: dict[str, t.Any] | None = None, ): # Modules can put information in here which will always be returned @@ -570,12 +601,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase): merged_arg_spec.update(argument_spec) self.arg_spec = merged_arg_spec - mutually_exclusive_params = [] + mutually_exclusive_params: list[Sequence[str]] = [] mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE if mutually_exclusive: mutually_exclusive_params += mutually_exclusive - required_together_params = [] + required_together_params: list[Sequence[str]] = [] required_together_params += DOCKER_REQUIRED_TOGETHER if required_together: required_together_params += required_together @@ -600,20 +631,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase): option_minimal_versions, option_minimal_versions_ignore_params ) - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: self.fail_results.update(kwargs) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.module.deprecate( msg, version=version, date=date, collection_name=collection_name ) - def _get_params(self): + def _get_params(self) -> dict[str, t.Any]: return self.module.params - def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): - self.option_minimal_versions = {} + def _get_minimal_versions( + self, + option_minimal_versions: dict[str, t.Any], + ignore_params: Sequence[str] | None = None, + ) -> None: + self.option_minimal_versions: dict[str, dict[str, t.Any]] = {} for option in self.module.argument_spec: if ignore_params is not None: if option in ignore_params: @@ -654,7 +695,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase): msg = f"Cannot {usg} with your configuration." self.fail(msg) - def report_warnings(self, result, warnings_key=None): + def report_warnings( + self, result: t.Any, warnings_key: Sequence[str] | None = None + ) -> None: """ Checks result of client operation for warnings, and if present, outputs them. diff --git a/plugins/module_utils/_common_cli.py b/plugins/module_utils/_common_cli.py index fb9a316b..166d0e41 100644 --- a/plugins/module_utils/_common_cli.py +++ b/plugins/module_utils/_common_cli.py @@ -10,6 +10,7 @@ from __future__ import annotations import abc import json import shlex +import typing as t from ansible.module_utils.basic import AnsibleModule, env_fallback from ansible.module_utils.common.process import get_bin_path @@ -31,6 +32,10 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( ) +if t.TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + DOCKER_COMMON_ARGS = { "docker_cli": {"type": "path"}, "docker_host": { @@ -72,10 +77,16 @@ class DockerException(Exception): class AnsibleDockerClientBase: + docker_api_version_str: str | None + docker_api_version: LooseVersion | None + def __init__( - self, common_args, min_docker_api_version=None, needs_api_version=True - ): - self._environment = {} + self, + common_args, + min_docker_api_version: str | None = None, + needs_api_version: bool = True, + ) -> None: + self._environment: dict[str, str] = {} if common_args["tls_hostname"]: self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"] if common_args["api_version"] and common_args["api_version"] != "auto": @@ -109,10 +120,10 @@ class AnsibleDockerClientBase: self._cli_base.extend(["--context", common_args["cli_context"]]) # `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0 - dummy, self._version, dummy = self.call_cli_json( + dummy, self._version, dummy2 = self.call_cli_json( "version", "--format", "{{ json . }}", check_rc=True ) - self._info = None + self._info: dict[str, t.Any] | None = None if needs_api_version: if not isinstance(self._version.get("Server"), dict) or not isinstance( @@ -138,32 +149,47 @@ class AnsibleDockerClientBase: "Internal error: cannot have needs_api_version=False with min_docker_api_version not None" ) - def log(self, msg, pretty_print=False): + def log(self, msg: str, pretty_print: bool = False): pass # if self.debug: # from .util import log_debug # log_debug(msg, pretty_print=pretty_print) - def get_cli(self): + def get_cli(self) -> str: return self._cli - def get_version_info(self): + def get_version_info(self) -> str: return self._version - def _compose_cmd(self, args): + def _compose_cmd(self, args: t.Sequence[str]) -> list[str]: return self._cli_base + list(args) - def _compose_cmd_str(self, args): + def _compose_cmd_str(self, args: t.Sequence[str]) -> str: return " ".join(shlex.quote(a) for a in self._compose_cmd(args)) @abc.abstractmethod - def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): + def call_cli( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + ) -> tuple[int, bytes, bytes]: pass - # def call_cli_json(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): - def call_cli_json(self, *args, **kwargs): - warn_on_stderr = kwargs.pop("warn_on_stderr", False) - rc, stdout, stderr = self.call_cli(*args, **kwargs) + def call_cli_json( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + warn_on_stderr: bool = False, + ) -> tuple[int, t.Any, bytes]: + rc, stdout, stderr = self.call_cli( + *args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update + ) if warn_on_stderr and stderr: self.warn(to_native(stderr)) try: @@ -174,10 +200,18 @@ class AnsibleDockerClientBase: ) return rc, data, stderr - # def call_cli_json_stream(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): - def call_cli_json_stream(self, *args, **kwargs): - warn_on_stderr = kwargs.pop("warn_on_stderr", False) - rc, stdout, stderr = self.call_cli(*args, **kwargs) + def call_cli_json_stream( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + warn_on_stderr: bool = False, + ) -> tuple[int, list[t.Any], bytes]: + rc, stdout, stderr = self.call_cli( + *args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update + ) if warn_on_stderr and stderr: self.warn(to_native(stderr)) result = [] @@ -193,25 +227,31 @@ class AnsibleDockerClientBase: return rc, result, stderr @abc.abstractmethod - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs) -> t.NoReturn: pass @abc.abstractmethod - def warn(self, msg): + def warn(self, msg: str) -> None: pass @abc.abstractmethod - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: pass - def get_cli_info(self): + def get_cli_info(self) -> dict[str, t.Any]: if self._info is None: - dummy, self._info, dummy = self.call_cli_json( + dummy, self._info, dummy2 = self.call_cli_json( "info", "--format", "{{ json . }}", check_rc=True ) return self._info - def get_client_plugin_info(self, component): + def get_client_plugin_info(self, component: str) -> dict[str, t.Any] | None: cli_info = self.get_cli_info() if not isinstance(cli_info.get("ClientInfo"), dict): self.fail( @@ -222,13 +262,13 @@ class AnsibleDockerClientBase: return plugin return None - def _image_lookup(self, name, tag): + def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]: """ Including a tag in the name parameter sent to the Docker SDK for Python images method does not work consistently. Instead, get the result set for name and manually check if the tag exists. """ - dummy, images, dummy = self.call_cli_json_stream( + dummy, images, dummy2 = self.call_cli_json_stream( "image", "ls", "--format", @@ -247,7 +287,13 @@ class AnsibleDockerClientBase: break return images - def find_image(self, name, tag): + @t.overload + def find_image(self, name: None, tag: str) -> None: ... + + @t.overload + def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: ... + + def find_image(self, name: str | None, tag: str) -> dict[str, t.Any] | None: """ Lookup an image (by name and tag) and return the inspection results. """ @@ -298,7 +344,19 @@ class AnsibleDockerClientBase: self.log(f"Image {name}:{tag} not found.") return None - def find_image_by_id(self, image_id, accept_missing_image=False): + @t.overload + def find_image_by_id( + self, image_id: None, accept_missing_image: bool = False + ) -> None: ... + + @t.overload + def find_image_by_id( + self, image_id: str | None, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: ... + + def find_image_by_id( + self, image_id: str | None, accept_missing_image: bool = False + ) -> dict[str, t.Any] | None: """ Lookup an image (by ID) and return the inspection results. """ @@ -320,17 +378,23 @@ class AnsibleDockerClientBase: class AnsibleModuleDockerClient(AnsibleDockerClientBase): def __init__( self, - argument_spec=None, - supports_check_mode=False, - mutually_exclusive=None, - required_together=None, - required_if=None, - required_one_of=None, - required_by=None, - min_docker_api_version=None, - fail_results=None, - needs_api_version=True, - ): + argument_spec: dict[str, t.Any] | None = None, + supports_check_mode: bool = False, + mutually_exclusive: Sequence[Sequence[str]] | None = None, + required_together: Sequence[Sequence[str]] | None = None, + required_if: ( + Sequence[ + tuple[str, t.Any, Sequence[str]] + | tuple[str, t.Any, Sequence[str], bool] + ] + | None + ) = None, + required_one_of: Sequence[Sequence[str]] | None = None, + required_by: Mapping[str, Sequence[str]] | None = None, + min_docker_api_version: str | None = None, + fail_results: dict[str, t.Any] | None = None, + needs_api_version: bool = True, + ) -> None: # Modules can put information in here which will always be returned # in case client.fail() is called. @@ -342,12 +406,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): merged_arg_spec.update(argument_spec) self.arg_spec = merged_arg_spec - mutually_exclusive_params = [("docker_host", "cli_context")] + mutually_exclusive_params: list[Sequence[str]] = [ + ("docker_host", "cli_context") + ] mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE if mutually_exclusive: mutually_exclusive_params += mutually_exclusive - required_together_params = [] + required_together_params: list[Sequence[str]] = [] required_together_params += DOCKER_REQUIRED_TOGETHER if required_together: required_together_params += required_together @@ -373,7 +439,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): needs_api_version=needs_api_version, ) - def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): + def call_cli( + self, + *args: str, + check_rc: bool = False, + data: bytes | None = None, + cwd: str | None = None, + environ_update: dict[str, str] | None = None, + ) -> tuple[int, bytes, bytes]: environment = self._environment.copy() if environ_update: environment.update(environ_update) @@ -390,14 +463,20 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): ) return rc, stdout, stderr - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs) -> t.NoReturn: self.fail_results.update(kwargs) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) - def warn(self, msg): + def warn(self, msg: str) -> None: self.module.warn(msg) - def deprecate(self, msg, version=None, date=None, collection_name=None): + def deprecate( + self, + msg: str, + version: str | None = None, + date: str | None = None, + collection_name: str | None = None, + ) -> None: self.module.deprecate( msg, version=version, date=date, collection_name=collection_name ) diff --git a/plugins/module_utils/_compose_v2.py b/plugins/module_utils/_compose_v2.py index 9bac4671..fc295d60 100644 --- a/plugins/module_utils/_compose_v2.py +++ b/plugins/module_utils/_compose_v2.py @@ -51,6 +51,13 @@ else: HAS_PYYAML = True PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name +if t.TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ansible_collections.community.docker.plugins.module_utils._common_cli import ( + AnsibleModuleDockerClient as _Client, + ) + DOCKER_COMPOSE_FILES = ( "compose.yaml", @@ -241,7 +248,9 @@ _RE_BUILD_PROGRESS_EVENT = re.compile(r"^\s*==>\s+(?P.*)$") MINIMUM_COMPOSE_VERSION = "2.18.0" -def _extract_event(line, warn_function=None): +def _extract_event( + line: str, warn_function: Callable[[str], None] | None = None +) -> tuple[Event | None, bool]: match = _RE_RESOURCE_EVENT.match(line) if match is not None: status = match.group("status") @@ -324,7 +333,9 @@ def _extract_event(line, warn_function=None): return None, False -def _extract_logfmt_event(line, warn_function=None): +def _extract_logfmt_event( + line: str, warn_function: Callable[[str], None] | None = None +) -> tuple[Event | None, bool]: try: result = _parse_logfmt_line(line, logrus_mode=True) except _InvalidLogFmt: @@ -339,7 +350,11 @@ def _extract_logfmt_event(line, warn_function=None): return None, False -def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_function): +def _warn_missing_dry_run_prefix( + line: str, + warn_missing_dry_run_prefix: bool, + warn_function: Callable[[str], None] | None, +) -> None: if warn_missing_dry_run_prefix and warn_function: # This could be a bug, a change of docker compose's output format, ... # Tell the user to report it to us :-) @@ -350,7 +365,9 @@ def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_functio ) -def _warn_unparsable_line(line, warn_function): +def _warn_unparsable_line( + line: str, warn_function: Callable[[str], None] | None +) -> None: # This could be a bug, a change of docker compose's output format, ... # Tell the user to report it to us :-) if warn_function: @@ -361,14 +378,16 @@ def _warn_unparsable_line(line, warn_function): ) -def _find_last_event_for(events, resource_id): +def _find_last_event_for( + events: list[Event], resource_id: str +) -> tuple[int, Event] | None: for index, event in enumerate(reversed(events)): if event.resource_id == resource_id: return len(events) - 1 - index, event return None -def _concat_event_msg(event, append_msg): +def _concat_event_msg(event: Event, append_msg: str) -> Event: return Event( event.resource_type, event.resource_id, @@ -383,7 +402,9 @@ _JSON_LEVEL_TO_STATUS_MAP = { } -def parse_json_events(stderr, warn_function=None): +def parse_json_events( + stderr: bytes, warn_function: Callable[[str], None] | None = None +) -> list[Event]: events = [] stderr_lines = stderr.splitlines() if stderr_lines and stderr_lines[-1] == b"": @@ -524,7 +545,12 @@ def parse_json_events(stderr, warn_function=None): return events -def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False): +def parse_events( + stderr: bytes, + dry_run: bool = False, + warn_function: Callable[[str], None] | None = None, + nonzero_rc: bool = False, +) -> list[Event]: events = [] error_event = None stderr_lines = stderr.splitlines() @@ -598,7 +624,11 @@ def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False): return events -def has_changes(events, ignore_service_pull_events=False, ignore_build_events=False): +def has_changes( + events: Sequence[Event], + ignore_service_pull_events: bool = False, + ignore_build_events: bool = False, +) -> bool: for event in events: if event.status in DOCKER_STATUS_WORKING: if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL: @@ -614,7 +644,7 @@ def has_changes(events, ignore_service_pull_events=False, ignore_build_events=Fa return False -def extract_actions(events): +def extract_actions(events: Sequence[Event]) -> list[dict[str, t.Any]]: actions = [] pull_actions = set() for event in events: @@ -646,7 +676,9 @@ def extract_actions(events): return actions -def emit_warnings(events, warn_function): +def emit_warnings( + events: Sequence[Event], warn_function: Callable[[str], None] +) -> None: for event in events: # If a message is present, assume it is a warning if ( @@ -657,13 +689,21 @@ def emit_warnings(events, warn_function): ) -def is_failed(events, rc): +def is_failed(events: Sequence[Event], rc: int) -> bool: if rc: return True return False -def update_failed(result, events, args, stdout, stderr, rc, cli): +def update_failed( + result: dict[str, t.Any], + events: Sequence[Event], + args: list[str], + stdout: str | bytes, + stderr: str | bytes, + rc: int, + cli: str, +) -> bool: if not rc: return False errors = [] @@ -697,7 +737,7 @@ def update_failed(result, events, args, stdout, stderr, rc, cli): return True -def common_compose_argspec(): +def common_compose_argspec() -> dict[str, t.Any]: return { "project_src": {"type": "path"}, "project_name": {"type": "str"}, @@ -709,7 +749,7 @@ def common_compose_argspec(): } -def common_compose_argspec_ex(): +def common_compose_argspec_ex() -> dict[str, t.Any]: return { "argspec": common_compose_argspec(), "mutually_exclusive": [("definition", "project_src"), ("definition", "files")], @@ -722,16 +762,18 @@ def common_compose_argspec_ex(): } -def combine_binary_output(*outputs): +def combine_binary_output(*outputs: bytes | None) -> bytes: return b"\n".join(out for out in outputs if out) -def combine_text_output(*outputs): +def combine_text_output(*outputs: str | None) -> str: return "\n".join(out for out in outputs if out) class BaseComposeManager(DockerBaseClass): - def __init__(self, client, min_version=MINIMUM_COMPOSE_VERSION): + def __init__( + self, client: _Client, min_version: str = MINIMUM_COMPOSE_VERSION + ) -> None: super().__init__() self.client = client self.check_mode = self.client.check_mode @@ -795,12 +837,12 @@ class BaseComposeManager(DockerBaseClass): # more precisely in https://github.com/docker/compose/pull/11478 self.use_json_events = self.compose_version >= LooseVersion("2.29.0") - def get_compose_version(self): + def get_compose_version(self) -> str: return ( self.get_compose_version_from_cli() or self.get_compose_version_from_api() ) - def get_compose_version_from_cli(self): + def get_compose_version_from_cli(self) -> str | None: rc, version_info, dummy_stderr = self.client.call_cli( "compose", "version", "--format", "json" ) @@ -814,7 +856,7 @@ class BaseComposeManager(DockerBaseClass): except Exception: # pylint: disable=broad-exception-caught return None - def get_compose_version_from_api(self): + def get_compose_version_from_api(self) -> str: compose = self.client.get_client_plugin_info("compose") if compose is None: self.fail( @@ -827,11 +869,11 @@ class BaseComposeManager(DockerBaseClass): ) return compose["Version"].lstrip("v") - def fail(self, msg, **kwargs): + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: self.cleanup() self.client.fail(msg, **kwargs) - def get_base_args(self, plain_progress=False): + def get_base_args(self, plain_progress: bool = False) -> list[str]: args = ["compose", "--ansi", "never"] if self.use_json_events and not plain_progress: args.extend(["--progress", "json"]) @@ -849,28 +891,33 @@ class BaseComposeManager(DockerBaseClass): args.extend(["--profile", profile]) return args - def _handle_failed_cli_call(self, args, rc, stdout, stderr): + def _handle_failed_cli_call( + self, args: list[str], rc: int, stdout: str | bytes, stderr: bytes + ) -> t.NoReturn: events = parse_json_events(stderr, warn_function=self.client.warn) - result = {} + result: dict[str, t.Any] = {} self.update_failed(result, events, args, stdout, stderr, rc) self.client.module.exit_json(**result) - def list_containers_raw(self): + def list_containers_raw(self) -> list[dict[str, t.Any]]: args = self.get_base_args() + ["ps", "--format", "json", "--all"] if self.compose_version >= LooseVersion("2.23.0"): # https://github.com/docker/compose/pull/11038 args.append("--no-trunc") - kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events} if self.compose_version >= LooseVersion("2.21.0"): # Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918 - rc, containers, stderr = self.client.call_cli_json_stream(*args, **kwargs) + rc, containers, stderr = self.client.call_cli_json_stream( + *args, cwd=self.project_src, check_rc=not self.use_json_events + ) else: - rc, containers, stderr = self.client.call_cli_json(*args, **kwargs) + rc, containers, stderr = self.client.call_cli_json( + *args, cwd=self.project_src, check_rc=not self.use_json_events + ) if self.use_json_events and rc != 0: - self._handle_failed_cli_call(args, rc, containers, stderr) + self._handle_failed_cli_call(args, rc, json.dumps(containers), stderr) return containers - def list_containers(self): + def list_containers(self) -> list[dict[str, t.Any]]: result = [] for container in self.list_containers_raw(): labels = {} @@ -887,10 +934,11 @@ class BaseComposeManager(DockerBaseClass): result.append(container) return result - def list_images(self): + def list_images(self) -> list[str]: args = self.get_base_args() + ["images", "--format", "json"] - kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events} - rc, images, stderr = self.client.call_cli_json(*args, **kwargs) + rc, images, stderr = self.client.call_cli_json( + *args, cwd=self.project_src, check_rc=not self.use_json_events + ) if self.use_json_events and rc != 0: self._handle_failed_cli_call(args, rc, images, stderr) if isinstance(images, dict): @@ -900,7 +948,9 @@ class BaseComposeManager(DockerBaseClass): images = list(images.values()) return images - def parse_events(self, stderr, dry_run=False, nonzero_rc=False): + def parse_events( + self, stderr: bytes, dry_run: bool = False, nonzero_rc: bool = False + ) -> list[Event]: if self.use_json_events: return parse_json_events(stderr, warn_function=self.client.warn) return parse_events( @@ -910,17 +960,17 @@ class BaseComposeManager(DockerBaseClass): nonzero_rc=nonzero_rc, ) - def emit_warnings(self, events): + def emit_warnings(self, events: Sequence[Event]) -> None: emit_warnings(events, warn_function=self.client.warn) def update_result( self, - result, - events, - stdout, - stderr, - ignore_service_pull_events=False, - ignore_build_events=False, + result: dict[str, t.Any], + events: Sequence[Event], + stdout: str, + stderr: str, + ignore_service_pull_events: bool = False, + ignore_build_events: bool = False, ): result["changed"] = result.get("changed", False) or has_changes( events, @@ -931,7 +981,15 @@ class BaseComposeManager(DockerBaseClass): result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout)) result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr)) - def update_failed(self, result, events, args, stdout, stderr, rc): + def update_failed( + self, + result: dict[str, t.Any], + events: Sequence[Event], + args: list[str], + stdout: str | bytes, + stderr: bytes, + rc: int, + ): return update_failed( result, events, @@ -942,14 +1000,14 @@ class BaseComposeManager(DockerBaseClass): cli=self.client.get_cli(), ) - def cleanup_result(self, result): + def cleanup_result(self, result: dict[str, t.Any]) -> None: if not result.get("failed"): # Only return stdout and stderr if it is not empty for res in ("stdout", "stderr"): if result.get(res) == "": result.pop(res) - def cleanup(self): + def cleanup(self) -> None: for directory in self.cleanup_dirs: try: shutil.rmtree(directory, True) diff --git a/plugins/module_utils/_copy.py b/plugins/module_utils/_copy.py index 6ee8479a..db1808d6 100644 --- a/plugins/module_utils/_copy.py +++ b/plugins/module_utils/_copy.py @@ -16,6 +16,7 @@ import os.path import shutil import stat import tarfile +import typing as t from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text @@ -25,6 +26,16 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor ) +if t.TYPE_CHECKING: + from collections.abc import Callable + + from _typeshed import WriteableBuffer + + from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( + APIClient, + ) + + class DockerFileCopyError(Exception): pass @@ -37,7 +48,9 @@ class DockerFileNotFound(DockerFileCopyError): pass -def _put_archive(client, container, path, data): +def _put_archive( + client: APIClient, container: str, path: str, data: bytes | t.Generator[bytes] +) -> bool: # data can also be file object for streaming. This is because _put uses requests's put(). # See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads url = client._url("/containers/{0}/archive", container) @@ -47,8 +60,14 @@ def _put_archive(client, container, path, data): def _symlink_tar_creator( - b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None -): + b_in_path: bytes, + file_stat: os.stat_result, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int | None = None, + user_name: str | None = None, +) -> bytes: if not stat.S_ISLNK(file_stat.st_mode): raise DockerUnexpectedError("stat information is not for a symlink") bio = io.BytesIO() @@ -75,16 +94,28 @@ def _symlink_tar_creator( def _symlink_tar_generator( - b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None -): + b_in_path: bytes, + file_stat: os.stat_result, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int | None = None, + user_name: str | None = None, +) -> t.Generator[bytes]: yield _symlink_tar_creator( b_in_path, file_stat, out_file, user_id, group_id, mode, user_name ) def _regular_file_tar_generator( - b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None -): + b_in_path: bytes, + file_stat: os.stat_result, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int | None = None, + user_name: str | None = None, +) -> t.Generator[bytes]: if not stat.S_ISREG(file_stat.st_mode): raise DockerUnexpectedError("stat information is not for a regular file") tarinfo = tarfile.TarInfo() @@ -136,8 +167,13 @@ def _regular_file_tar_generator( def _regular_content_tar_generator( - content, out_file, user_id, group_id, mode, user_name=None -): + content: bytes, + out_file: str | bytes, + user_id: int, + group_id: int, + mode: int, + user_name: str | None = None, +) -> t.Generator[bytes]: tarinfo = tarfile.TarInfo() tarinfo.name = ( os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/") @@ -175,16 +211,16 @@ def _regular_content_tar_generator( def put_file( - client, - container, - in_path, - out_path, - user_id, - group_id, - mode=None, - user_name=None, - follow_links=False, -): + client: APIClient, + container: str, + in_path: str, + out_path: str, + user_id: int, + group_id: int, + mode: int | None = None, + user_name: str | None = None, + follow_links: bool = False, +) -> None: """Transfer a file from local to Docker container.""" if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")): raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}") @@ -232,8 +268,15 @@ def put_file( def put_file_content( - client, container, content, out_path, user_id, group_id, mode, user_name=None -): + client: APIClient, + container: str, + content: bytes, + out_path: str, + user_id: int, + group_id: int, + mode: int, + user_name: str | None = None, +) -> None: """Transfer a file from local to Docker container.""" out_dir, out_file = os.path.split(out_path) @@ -248,7 +291,13 @@ def put_file_content( ) -def stat_file(client, container, in_path, follow_links=False, log=None): +def stat_file( + client: APIClient, + container: str, + in_path: str, + follow_links: bool = False, + log: Callable[[str], None] | None = None, +) -> tuple[str | bytes, dict[str, t.Any] | None, str | None]: """Fetch information on a file from a Docker container to local. Return a tuple ``(path, stat_data, link_target)`` where: @@ -265,12 +314,12 @@ def stat_file(client, container, in_path, follow_links=False, log=None): while True: if in_path in considered_in_paths: raise DockerFileCopyError( - f'Found infinite symbolic link loop when trying to stating "{in_path}"' + f"Found infinite symbolic link loop when trying to stating {in_path!r}" ) considered_in_paths.add(in_path) if log: - log(f'FETCH: Stating "{in_path}"') + log(f"FETCH: Stating {in_path!r}") response = client._head( client._url("/containers/{0}/archive", container), @@ -299,24 +348,24 @@ def stat_file(client, container, in_path, follow_links=False, log=None): class _RawGeneratorFileobj(io.RawIOBase): - def __init__(self, stream): + def __init__(self, stream: t.Generator[bytes]): self._stream = stream self._buf = b"" - def readable(self): + def readable(self) -> bool: return True - def _readinto_from_buf(self, b, index, length): + def _readinto_from_buf(self, b: WriteableBuffer, index: int, length: int) -> int: cpy = min(length - index, len(self._buf)) if cpy: - b[index : index + cpy] = self._buf[:cpy] + b[index : index + cpy] = self._buf[:cpy] # type: ignore # TODO! self._buf = self._buf[cpy:] index += cpy return index - def readinto(self, b): + def readinto(self, b: WriteableBuffer) -> int: index = 0 - length = len(b) + length = len(b) # type: ignore # TODO! index = self._readinto_from_buf(b, index, length) if index == length: @@ -330,25 +379,28 @@ class _RawGeneratorFileobj(io.RawIOBase): return self._readinto_from_buf(b, index, length) -def _stream_generator_to_fileobj(stream): +def _stream_generator_to_fileobj(stream: t.Generator[bytes]) -> io.BufferedReader: """Given a generator that generates chunks of bytes, create a readable buffered stream.""" raw = _RawGeneratorFileobj(stream) return io.BufferedReader(raw) +_T = t.TypeVar("_T") + + def fetch_file_ex( - client, - container, - in_path, - process_none, - process_regular, - process_symlink, - process_other, - follow_links=False, - log=None, -): + client: APIClient, + container: str, + in_path: str, + process_none: Callable[[str], _T], + process_regular: Callable[[str, tarfile.TarFile, tarfile.TarInfo], _T], + process_symlink: Callable[[str, tarfile.TarInfo], _T], + process_other: Callable[[str, tarfile.TarInfo], _T], + follow_links: bool = False, + log: Callable[[str], None] | None = None, +) -> _T: """Fetch a file (as a tar file entry) from a Docker container to local.""" - considered_in_paths = set() + considered_in_paths: set[str] = set() while True: if in_path in considered_in_paths: @@ -372,8 +424,8 @@ def fetch_file_ex( with tarfile.open( fileobj=_stream_generator_to_fileobj(stream), mode="r|" ) as tar: - symlink_member = None - result = None + symlink_member: tarfile.TarInfo | None = None + result: _T | None = None found = False for member in tar: if found: @@ -398,35 +450,46 @@ def fetch_file_ex( log(f'FETCH: Following symbolic link to "{in_path}"') continue if found: - return result + return result # type: ignore raise DockerUnexpectedError("Received tarfile is empty!") -def fetch_file(client, container, in_path, out_path, follow_links=False, log=None): +def fetch_file( + client: APIClient, + container: str, + in_path: str, + out_path: str, + follow_links: bool = False, + log: Callable[[str], None] | None = None, +) -> str: b_out_path = to_bytes(out_path, errors="surrogate_or_strict") - def process_none(in_path): + def process_none(in_path: str) -> str: raise DockerFileNotFound( f"File {in_path} does not exist in container {container}" ) - def process_regular(in_path, tar, member): + def process_regular( + in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo + ) -> str: if not follow_links and os.path.exists(b_out_path): os.unlink(b_out_path) - with tar.extractfile(member) as in_f: - with open(b_out_path, "wb") as out_f: - shutil.copyfileobj(in_f, out_f) + reader = tar.extractfile(member) + if reader: + with reader as in_f: + with open(b_out_path, "wb") as out_f: + shutil.copyfileobj(in_f, out_f) return in_path - def process_symlink(in_path, member): + def process_symlink(in_path, member) -> str: if os.path.exists(b_out_path): os.unlink(b_out_path) os.symlink(member.linkname, b_out_path) return in_path - def process_other(in_path, member): + def process_other(in_path, member) -> str: raise DockerFileCopyError( f'Remote file "{in_path}" is not a regular file or a symbolic link' ) @@ -444,7 +507,13 @@ def fetch_file(client, container, in_path, out_path, follow_links=False, log=Non ) -def _execute_command(client, container, command, log=None, check_rc=False): +def _execute_command( + client: APIClient, + container: str, + command: list[str], + log: Callable[[str], None] | None = None, + check_rc: bool = False, +) -> tuple[int, bytes, bytes]: if log: log(f"Executing {command} in {container}") @@ -493,13 +562,15 @@ def _execute_command(client, container, command, log=None, check_rc=False): if check_rc and rc != 0: command_str = " ".join(command) raise DockerUnexpectedError( - f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout}\nSTDERR: {stderr}' + f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout!r}\nSTDERR: {stderr!r}' ) return rc, stdout, stderr -def determine_user_group(client, container, log=None): +def determine_user_group( + client: APIClient, container: str, log: Callable[[str], None] | None = None +) -> tuple[int, int]: dummy_rc, stdout, dummy_stderr = _execute_command( client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log ) @@ -507,7 +578,7 @@ def determine_user_group(client, container, log=None): stdout_lines = stdout.splitlines() if len(stdout_lines) != 2: raise DockerUnexpectedError( - f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout}" + f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout!r}" ) user_id, group_id = stdout_lines @@ -515,5 +586,5 @@ def determine_user_group(client, container, log=None): return int(user_id), int(group_id) except ValueError as exc: raise DockerUnexpectedError( - f'Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got "{user_id}" and "{group_id}" instead' + f"Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got {user_id!r} and {group_id!r} instead" ) from exc diff --git a/plugins/module_utils/_image_archive.py b/plugins/module_utils/_image_archive.py index 15ee9b91..6156f209 100644 --- a/plugins/module_utils/_image_archive.py +++ b/plugins/module_utils/_image_archive.py @@ -18,12 +18,10 @@ class ImageArchiveManifestSummary: "docker image save some:tag > some.tar" command. """ - def __init__(self, image_id, repo_tags): + def __init__(self, image_id: str, repo_tags: list[str]) -> None: """ :param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json - :type image_id: str :param repo_tags Docker image names, e.g. ["hello-world:latest"] - :type repo_tags: list[str] """ self.image_id = image_id @@ -34,22 +32,21 @@ class ImageArchiveInvalidException(Exception): pass -def api_image_id(archive_image_id): +def api_image_id(archive_image_id: str) -> str: """ Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier that represents the same image hash, but in the format presented by the Docker Engine API. :param archive_image_id: plain image hash - :type archive_image_id: str - :returns: Prefixed hash used by REST api - :rtype: str """ return f"sha256:{archive_image_id}" -def load_archived_image_manifest(archive_path): +def load_archived_image_manifest( + archive_path: str, +) -> list[ImageArchiveManifestSummary] | None: """ Attempts to get image IDs and image names from metadata stored in the image archive tar file. @@ -62,10 +59,7 @@ def load_archived_image_manifest(archive_path): ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. :param archive_path: Tar file to read - :type archive_path: str - :return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects. - :rtype: ImageArchiveManifestSummary """ try: @@ -76,8 +70,15 @@ def load_archived_image_manifest(archive_path): with tarfile.open(archive_path, "r") as tf: try: try: - with tf.extractfile("manifest.json") as ef: + reader = tf.extractfile("manifest.json") + if reader is None: + raise ImageArchiveInvalidException( + "Failed to read manifest.json" + ) + with reader as ef: manifest = json.load(ef) + except ImageArchiveInvalidException: + raise except Exception as exc: raise ImageArchiveInvalidException( f"Failed to decode and deserialize manifest.json: {exc}" @@ -139,7 +140,7 @@ def load_archived_image_manifest(archive_path): ) from exc -def archived_image_manifest(archive_path): +def archived_image_manifest(archive_path: str) -> ImageArchiveManifestSummary | None: """ Attempts to get Image.Id and image name from metadata stored in the image archive tar file. @@ -152,10 +153,7 @@ def archived_image_manifest(archive_path): ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. :param archive_path: Tar file to read - :type archive_path: str - :return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix. - :rtype: ImageArchiveManifestSummary """ results = load_archived_image_manifest(archive_path) diff --git a/plugins/module_utils/_logfmt.py b/plugins/module_utils/_logfmt.py index c048966b..12fcdd08 100644 --- a/plugins/module_utils/_logfmt.py +++ b/plugins/module_utils/_logfmt.py @@ -13,6 +13,9 @@ See https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc for information on from __future__ import annotations +import typing as t +from enum import Enum + # The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc # (look for "EBNFish") @@ -22,7 +25,7 @@ class InvalidLogFmt(Exception): pass -class _Mode: +class _Mode(Enum): GARBAGE = 0 KEY = 1 EQUAL = 2 @@ -68,29 +71,29 @@ _HEX_DICT = { } -def _is_ident(cur): +def _is_ident(cur: str) -> bool: return cur > " " and cur not in ('"', "=") class _Parser: - def __init__(self, line): + def __init__(self, line: str) -> None: self.line = line self.index = 0 self.length = len(line) - def done(self): + def done(self) -> bool: return self.index >= self.length - def cur(self): + def cur(self) -> str: return self.line[self.index] - def next(self): + def next(self) -> None: self.index += 1 - def prev(self): + def prev(self) -> None: self.index -= 1 - def parse_unicode_sequence(self): + def parse_unicode_sequence(self) -> str: if self.index + 6 > self.length: raise InvalidLogFmt("Not enough space for unicode escape") if self.line[self.index : self.index + 2] != "\\u": @@ -108,27 +111,27 @@ class _Parser: return chr(v) -def parse_line(line, logrus_mode=False): - result = {} +def parse_line(line: str, logrus_mode: bool = False) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} parser = _Parser(line) - key = [] - value = [] + key: list[str] = [] + value: list[str] = [] mode = _Mode.GARBAGE - def handle_kv(has_no_value=False): + def handle_kv(has_no_value: bool = False) -> None: k = "".join(key) v = None if has_no_value else "".join(value) result[k] = v del key[:] del value[:] - def parse_garbage(cur): + def parse_garbage(cur: str) -> _Mode: if _is_ident(cur): return _Mode.KEY parser.next() return _Mode.GARBAGE - def parse_key(cur): + def parse_key(cur: str) -> _Mode: if _is_ident(cur): key.append(cur) parser.next() @@ -142,7 +145,7 @@ def parse_line(line, logrus_mode=False): parser.next() return _Mode.GARBAGE - def parse_equal(cur): + def parse_equal(cur: str) -> _Mode: if _is_ident(cur): value.append(cur) parser.next() @@ -154,7 +157,7 @@ def parse_line(line, logrus_mode=False): parser.next() return _Mode.GARBAGE - def parse_ident_value(cur): + def parse_ident_value(cur: str) -> _Mode: if _is_ident(cur): value.append(cur) parser.next() @@ -163,7 +166,7 @@ def parse_line(line, logrus_mode=False): parser.next() return _Mode.GARBAGE - def parse_quoted_value(cur): + def parse_quoted_value(cur: str) -> _Mode: if cur == "\\": parser.next() if parser.done(): diff --git a/plugins/module_utils/_platform.py b/plugins/module_utils/_platform.py index 567a4173..3807b6a0 100644 --- a/plugins/module_utils/_platform.py +++ b/plugins/module_utils/_platform.py @@ -14,12 +14,13 @@ from __future__ import annotations import re +import typing as t _VALID_STR = re.compile("^[A-Za-z0-9_-]+$") -def _validate_part(string, part, part_name): +def _validate_part(string: str, part: str, part_name: str) -> str: if not part: raise ValueError(f'Invalid platform string "{string}": {part_name} is empty') if not _VALID_STR.match(part): @@ -79,7 +80,7 @@ _KNOWN_ARCH = ( ) -def _normalize_os(os_str): +def _normalize_os(os_str: str) -> str: # See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go os_str = os_str.lower() if os_str == "macos": @@ -112,7 +113,7 @@ _NORMALIZE_ARCH = { } -def _normalize_arch(arch_str, variant_str): +def _normalize_arch(arch_str: str, variant_str: str) -> tuple[str, str]: # See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go arch_str = arch_str.lower() variant_str = variant_str.lower() @@ -121,15 +122,16 @@ def _normalize_arch(arch_str, variant_str): res = _NORMALIZE_ARCH.get((arch_str, None)) if res is None: return arch_str, variant_str - if res is not None: - arch_str = res[0] - if res[1] is not None: - variant_str = res[1] - return arch_str, variant_str + arch_str = res[0] + if res[1] is not None: + variant_str = res[1] + return arch_str, variant_str class _Platform: - def __init__(self, os=None, arch=None, variant=None): + def __init__( + self, os: str | None = None, arch: str | None = None, variant: str | None = None + ) -> None: self.os = os self.arch = arch self.variant = variant @@ -140,7 +142,12 @@ class _Platform: raise ValueError("If variant is given, os must be given too") @classmethod - def parse_platform_string(cls, string, daemon_os=None, daemon_arch=None): + def parse_platform_string( + cls, + string: str | None, + daemon_os: str | None = None, + daemon_arch: str | None = None, + ) -> t.Self: # See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go if string is None: return cls() @@ -182,6 +189,7 @@ class _Platform: ) if variant is not None and not variant: raise ValueError(f'Invalid platform string "{string}": variant is empty') + assert arch is not None # otherwise variant would be None as well arch, variant = _normalize_arch(arch, variant or "") if len(parts) == 2 and arch == "arm" and variant == "v7": variant = None @@ -189,9 +197,12 @@ class _Platform: variant = "v8" return cls(os=_normalize_os(os), arch=arch, variant=variant or None) - def __str__(self): + def __str__(self) -> str: if self.variant: - parts = [self.os, self.arch, self.variant] + assert ( + self.os is not None and self.arch is not None + ) # ensured in constructor + parts: list[str] = [self.os, self.arch, self.variant] elif self.os: if self.arch: parts = [self.os, self.arch] @@ -203,12 +214,14 @@ class _Platform: parts = [] return "/".join(parts) - def __repr__(self): + def __repr__(self) -> str: return ( f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Platform): + return NotImplemented return ( self.os == other.os and self.arch == other.arch @@ -216,7 +229,9 @@ class _Platform: ) -def normalize_platform_string(string, daemon_os=None, daemon_arch=None): +def normalize_platform_string( + string: str, daemon_os: str | None = None, daemon_arch: str | None = None +) -> str: return str( _Platform.parse_platform_string( string, daemon_os=daemon_os, daemon_arch=daemon_arch @@ -225,8 +240,12 @@ def normalize_platform_string(string, daemon_os=None, daemon_arch=None): def compose_platform_string( - os=None, arch=None, variant=None, daemon_os=None, daemon_arch=None -): + os: str | None = None, + arch: str | None = None, + variant: str | None = None, + daemon_os: str | None = None, + daemon_arch: str | None = None, +) -> str: if os is None and daemon_os is not None: os = _normalize_os(daemon_os) if arch is None and daemon_arch is not None: @@ -235,7 +254,7 @@ def compose_platform_string( return str(_Platform(os=os, arch=arch, variant=variant or None)) -def compare_platform_strings(string1, string2): +def compare_platform_strings(string1: str, string2: str) -> bool: return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string( string2 ) diff --git a/plugins/module_utils/_scramble.py b/plugins/module_utils/_scramble.py index 621f459f..fa4b6dc2 100644 --- a/plugins/module_utils/_scramble.py +++ b/plugins/module_utils/_scramble.py @@ -13,7 +13,7 @@ import random from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text -def generate_insecure_key(): +def generate_insecure_key() -> bytes: """Do NOT use this for cryptographic purposes!""" while True: # Generate a one-byte key. Right now the functions below do not use more @@ -24,23 +24,23 @@ def generate_insecure_key(): return key -def scramble(value, key): +def scramble(value: str, key: bytes) -> str: """Do NOT use this for cryptographic purposes!""" if len(key) < 1: raise ValueError("Key must be at least one byte") - value = to_bytes(value) + b_value = to_bytes(value) k = key[0] - value = bytes([k ^ b for b in value]) - return "=S=" + to_native(base64.b64encode(value)) + b_value = bytes([k ^ b for b in b_value]) + return f"=S={to_native(base64.b64encode(b_value))}" -def unscramble(value, key): +def unscramble(value: str, key: bytes) -> str: """Do NOT use this for cryptographic purposes!""" if len(key) < 1: raise ValueError("Key must be at least one byte") if not value.startswith("=S="): raise ValueError("Value does not start with indicator") - value = base64.b64decode(value[3:]) + b_value = base64.b64decode(value[3:]) k = key[0] - value = bytes([k ^ b for b in value]) - return to_text(value) + b_value = bytes([k ^ b for b in b_value]) + return to_text(b_value) diff --git a/plugins/module_utils/_swarm.py b/plugins/module_utils/_swarm.py index 93b08d5d..699f5821 100644 --- a/plugins/module_utils/_swarm.py +++ b/plugins/module_utils/_swarm.py @@ -9,6 +9,7 @@ from __future__ import annotations import json +import typing as t from time import sleep @@ -28,10 +29,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( class AnsibleDockerSwarmClient(AnsibleDockerClient): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def get_swarm_node_id(self): + def get_swarm_node_id(self) -> str | None: """ Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID of Docker host the module is executed on @@ -51,7 +49,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): return swarm_info["Swarm"]["NodeID"] return None - def check_if_swarm_node(self, node_id=None): + def check_if_swarm_node(self, node_id: str | None = None) -> bool | None: """ Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host system information looking if specific key in output exists. If 'node_id' is provided then it tries to @@ -83,11 +81,11 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): try: node_info = self.get_node_inspect(node_id=node_id) except APIError: - return + return None return node_info["ID"] is not None - def check_if_swarm_manager(self): + def check_if_swarm_manager(self) -> bool: """ Checks if node role is set as Manager in Swarm. The node is the docker host on which module action is performed. The inspect_swarm() will fail if node is not a manager @@ -101,7 +99,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): except APIError: return False - def fail_task_if_not_swarm_manager(self): + def fail_task_if_not_swarm_manager(self) -> None: """ If host is not a swarm manager then Ansible task on this host should end with 'failed' state """ @@ -110,7 +108,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): "Error running docker swarm module: must run on swarm manager node" ) - def check_if_swarm_worker(self): + def check_if_swarm_worker(self) -> bool: """ Checks if node role is set as Worker in Swarm. The node is the docker host on which module action is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node() @@ -122,7 +120,9 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): return True return False - def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1): + def check_if_swarm_node_is_down( + self, node_id: str | None = None, repeat_check: int = 1 + ) -> bool: """ Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or @@ -147,7 +147,19 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): return True return False - def get_node_inspect(self, node_id=None, skip_missing=False): + @t.overload + def get_node_inspect( + self, node_id: str | None = None, skip_missing: t.Literal[False] = False + ) -> dict[str, t.Any]: ... + + @t.overload + def get_node_inspect( + self, node_id: str | None = None, skip_missing: bool = False + ) -> dict[str, t.Any] | None: ... + + def get_node_inspect( + self, node_id: str | None = None, skip_missing: bool = False + ) -> dict[str, t.Any] | None: """ Returns Swarm node info as in 'docker node inspect' command about single node @@ -195,7 +207,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): node_info["Status"]["Addr"] = swarm_leader_ip return node_info - def get_all_nodes_inspect(self): + def get_all_nodes_inspect(self) -> list[dict[str, t.Any]]: """ Returns Swarm node info as in 'docker node inspect' command about all registered nodes @@ -217,7 +229,17 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): node_info = json.loads(json_str) return node_info - def get_all_nodes_list(self, output="short"): + @t.overload + def get_all_nodes_list(self, output: t.Literal["short"] = "short") -> list[str]: ... + + @t.overload + def get_all_nodes_list( + self, output: t.Literal["long"] + ) -> list[dict[str, t.Any]]: ... + + def get_all_nodes_list( + self, output: t.Literal["short", "long"] = "short" + ) -> list[str] | list[dict[str, t.Any]]: """ Returns list of nodes registered in Swarm @@ -227,48 +249,46 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): if 'output' is 'long' then returns data is list of dict containing the attributes as in output of command 'docker node ls' """ - nodes_list = [] - nodes_inspect = self.get_all_nodes_inspect() - if nodes_inspect is None: - return None if output == "short": + nodes_list = [] for node in nodes_inspect: nodes_list.append(node["Description"]["Hostname"]) - elif output == "long": + return nodes_list + if output == "long": + nodes_info_list = [] for node in nodes_inspect: - node_property = {} + node_property: dict[str, t.Any] = {} - node_property.update({"ID": node["ID"]}) - node_property.update({"Hostname": node["Description"]["Hostname"]}) - node_property.update({"Status": node["Status"]["State"]}) - node_property.update({"Availability": node["Spec"]["Availability"]}) + node_property["ID"] = node["ID"] + node_property["Hostname"] = node["Description"]["Hostname"] + node_property["Status"] = node["Status"]["State"] + node_property["Availability"] = node["Spec"]["Availability"] if "ManagerStatus" in node: if node["ManagerStatus"]["Leader"] is True: - node_property.update({"Leader": True}) - node_property.update( - {"ManagerStatus": node["ManagerStatus"]["Reachability"]} - ) - node_property.update( - {"EngineVersion": node["Description"]["Engine"]["EngineVersion"]} - ) + node_property["Leader"] = True + node_property["ManagerStatus"] = node["ManagerStatus"][ + "Reachability" + ] + node_property["EngineVersion"] = node["Description"]["Engine"][ + "EngineVersion" + ] - nodes_list.append(node_property) - else: - return None + nodes_info_list.append(node_property) + return nodes_info_list - return nodes_list - - def get_node_name_by_id(self, nodeid): + def get_node_name_by_id(self, nodeid: str) -> str: return self.get_node_inspect(nodeid)["Description"]["Hostname"] - def get_unlock_key(self): + def get_unlock_key(self) -> str | None: if self.docker_py_version < LooseVersion("2.7.0"): return None return super().get_unlock_key() - def get_service_inspect(self, service_id, skip_missing=False): + def get_service_inspect( + self, service_id: str, skip_missing: bool = False + ) -> dict[str, t.Any] | None: """ Returns Swarm service info as in 'docker service inspect' command about single service diff --git a/plugins/module_utils/_util.py b/plugins/module_utils/_util.py index 5171fae6..d9fb21b7 100644 --- a/plugins/module_utils/_util.py +++ b/plugins/module_utils/_util.py @@ -9,6 +9,7 @@ from __future__ import annotations import json import re +import typing as t from datetime import timedelta from urllib.parse import urlparse @@ -17,6 +18,12 @@ from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.text.converters import to_text +if t.TYPE_CHECKING: + from collections.abc import Callable + + from ansible.module_utils.basic import AnsibleModule + + DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock" DEFAULT_TLS = False DEFAULT_TLS_VERIFY = False @@ -79,14 +86,14 @@ DEFAULT_DOCKER_REGISTRY = "https://index.docker.io/v1/" BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"] -def is_image_name_id(name): +def is_image_name_id(name: str) -> bool: """Check whether the given image name is in fact an image ID (hash).""" if re.match("^sha256:[0-9a-fA-F]{64}$", name): return True return False -def is_valid_tag(tag, allow_empty=False): +def is_valid_tag(tag: str, allow_empty: bool = False) -> bool: """Check whether the given string is a valid docker tag name.""" if not tag: return allow_empty @@ -95,7 +102,7 @@ def is_valid_tag(tag, allow_empty=False): return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag)) -def sanitize_result(data): +def sanitize_result(data: t.Any) -> t.Any: """Sanitize data object for return to Ansible. When the data object contains types such as docker.types.containers.HostConfig, @@ -112,7 +119,7 @@ def sanitize_result(data): return data -def log_debug(msg, pretty_print=False): +def log_debug(msg: t.Any, pretty_print: bool = False): """Write a log message to docker.log. If ``pretty_print=True``, the message will be pretty-printed as JSON. @@ -128,25 +135,28 @@ def log_debug(msg, pretty_print=False): class DockerBaseClass: - def __init__(self): + def __init__(self) -> None: self.debug = False - def log(self, msg, pretty_print=False): + def log(self, msg: t.Any, pretty_print: bool = False) -> None: pass # if self.debug: # log_debug(msg, pretty_print=pretty_print) def update_tls_hostname( - result, old_behavior=False, deprecate_function=None, uses_tls=True -): + result: dict[str, t.Any], + old_behavior: bool = False, + deprecate_function: Callable[[str], None] | None = None, + uses_tls: bool = True, +) -> None: if result["tls_hostname"] is None: # get default machine name from the url parsed_url = urlparse(result["docker_host"]) result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0] -def compare_dict_allow_more_present(av, bv): +def compare_dict_allow_more_present(av: dict, bv: dict) -> bool: """ Compare two dictionaries for whether every entry of the first is in the second. """ @@ -158,7 +168,12 @@ def compare_dict_allow_more_present(av, bv): return True -def compare_generic(a, b, method, datatype): +def compare_generic( + a: t.Any, + b: t.Any, + method: t.Literal["ignore", "strict", "allow_more_present"], + datatype: t.Literal["value", "list", "set", "set(dict)", "dict"], +) -> bool: """ Compare values a and b as described by method and datatype. @@ -249,10 +264,10 @@ def compare_generic(a, b, method, datatype): class DifferenceTracker: - def __init__(self): - self._diff = [] + def __init__(self) -> None: + self._diff: list[dict[str, t.Any]] = [] - def add(self, name, parameter=None, active=None): + def add(self, name: str, parameter: t.Any = None, active: t.Any = None) -> None: self._diff.append( { "name": name, @@ -261,14 +276,14 @@ class DifferenceTracker: } ) - def merge(self, other_tracker): + def merge(self, other_tracker: DifferenceTracker) -> None: self._diff.extend(other_tracker._diff) @property - def empty(self): + def empty(self) -> bool: return len(self._diff) == 0 - def get_before_after(self): + def get_before_after(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: """ Return texts ``before`` and ``after``. """ @@ -279,13 +294,13 @@ class DifferenceTracker: after[item["name"]] = item["parameter"] return before, after - def has_difference_for(self, name): + def has_difference_for(self, name: str) -> bool: """ Returns a boolean if a difference exists for name """ return any(diff for diff in self._diff if diff["name"] == name) - def get_legacy_docker_container_diffs(self): + def get_legacy_docker_container_diffs(self) -> list[dict[str, t.Any]]: """ Return differences in the docker_container legacy format. """ @@ -299,7 +314,7 @@ class DifferenceTracker: result.append(item) return result - def get_legacy_docker_diffs(self): + def get_legacy_docker_diffs(self) -> list[str]: """ Return differences in the docker_container legacy format. """ @@ -307,8 +322,13 @@ class DifferenceTracker: return result -def sanitize_labels(labels, labels_field, client=None, module=None): - def fail(msg): +def sanitize_labels( + labels: dict[str, t.Any] | None, + labels_field: str, + client=None, + module: AnsibleModule | None = None, +) -> None: + def fail(msg: str) -> t.NoReturn: if client is not None: client.fail(msg) if module is not None: @@ -327,7 +347,21 @@ def sanitize_labels(labels, labels_field, client=None, module=None): labels[k] = to_text(v) -def clean_dict_booleans_for_docker_api(data, allow_sequences=False): +@t.overload +def clean_dict_booleans_for_docker_api( + data: dict[str, t.Any], *, allow_sequences: t.Literal[False] = False +) -> dict[str, str]: ... + + +@t.overload +def clean_dict_booleans_for_docker_api( + data: dict[str, t.Any], *, allow_sequences: bool +) -> dict[str, str | list[str]]: ... + + +def clean_dict_booleans_for_docker_api( + data: dict[str, t.Any], *, allow_sequences: bool = False +) -> dict[str, str] | dict[str, str | list[str]]: """ Go does not like Python booleans 'True' or 'False', while Ansible is just fine with them in YAML. As such, they need to be converted in cases where @@ -355,7 +389,7 @@ def clean_dict_booleans_for_docker_api(data, allow_sequences=False): return result -def convert_duration_to_nanosecond(time_str): +def convert_duration_to_nanosecond(time_str: str) -> int: """ Return time duration in nanosecond. """ @@ -374,9 +408,9 @@ def convert_duration_to_nanosecond(time_str): if not parts: raise ValueError(f"Invalid time duration - {time_str}") - parts = parts.groupdict() + parts_dict = parts.groupdict() time_params = {} - for name, value in parts.items(): + for name, value in parts_dict.items(): if value: time_params[name] = int(value) @@ -388,13 +422,15 @@ def convert_duration_to_nanosecond(time_str): return time_in_nanoseconds -def normalize_healthcheck_test(test): +def normalize_healthcheck_test(test: t.Any) -> list[str]: if isinstance(test, (tuple, list)): return [str(e) for e in test] return ["CMD-SHELL", str(test)] -def normalize_healthcheck(healthcheck, normalize_test=False): +def normalize_healthcheck( + healthcheck: dict[str, t.Any], normalize_test: bool = False +) -> dict[str, t.Any]: """ Return dictionary of healthcheck parameters. """ @@ -440,7 +476,9 @@ def normalize_healthcheck(healthcheck, normalize_test=False): return result -def parse_healthcheck(healthcheck): +def parse_healthcheck( + healthcheck: dict[str, t.Any] | None, +) -> tuple[dict[str, t.Any] | None, bool | None]: """ Return dictionary of healthcheck parameters and boolean if healthcheck defined in image was requested to be disabled. @@ -458,8 +496,8 @@ def parse_healthcheck(healthcheck): return result, False -def omit_none_from_dict(d): +def omit_none_from_dict(d: dict[str, t.Any]) -> dict[str, t.Any]: """ Return a copy of the dictionary with all keys with value None omitted. """ - return dict((k, v) for (k, v) in d.items() if v is not None) + return {k: v for (k, v) in d.items() if v is not None} diff --git a/plugins/modules/docker_host_info.py b/plugins/modules/docker_host_info.py index b71387cf..36c9699c 100644 --- a/plugins/modules/docker_host_info.py +++ b/plugins/modules/docker_host_info.py @@ -255,7 +255,7 @@ class DockerHostManager(DockerBaseClass): returned_name = docker_object filter_name = docker_object + "_filters" filters = clean_dict_booleans_for_docker_api( - client.module.params.get(filter_name), True + client.module.params.get(filter_name), allow_sequences=True ) self.results[returned_name] = self.get_docker_items_list( docker_object, filters