First batch of types.

This commit is contained in:
Felix Fontein 2025-10-18 20:06:37 +02:00
parent c76ee1d1cc
commit 0ff66e8b24
12 changed files with 739 additions and 363 deletions

View File

@ -13,6 +13,7 @@ import platform
import re
import sys
import traceback
import typing as t
from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -79,6 +80,10 @@ except ImportError:
pass
if t.TYPE_CHECKING:
from collections.abc import Callable
MIN_DOCKER_VERSION = "2.0.0"
@ -96,7 +101,9 @@ if not HAS_DOCKER_PY:
pass
def _get_tls_config(fail_function, **kwargs):
def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion(
"7.0.0b1"
):
@ -111,17 +118,18 @@ def _get_tls_config(fail_function, **kwargs):
# Filter out all None parameters
kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)
try:
tls_config = TLSConfig(**kwargs)
return tls_config
return TLSConfig(**kwargs)
except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data):
def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function):
def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://"
@ -173,7 +181,11 @@ DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade.
class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_version=None, min_docker_api_version=None):
def __init__(
self,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
) -> None:
if min_docker_version is None:
min_docker_version = MIN_DOCKER_VERSION
@ -214,23 +226,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
)
def log(self, msg, pretty_print=False):
def log(self, msg: t.Any, pretty_print: bool = False):
pass
# if self.debug:
# from .util import log_debug
# log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass
def deprecate(self, msg, version=None, date=None, collection_name=None):
@abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass
@staticmethod
def _get_value(
param_name, param_value, env_variable, default_value, value_type="str"
):
param_name: str,
param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None:
# take module parameter value
if value_type == "bool":
@ -267,11 +290,11 @@ class AnsibleDockerClientBase(Client):
return default_value
@abc.abstractmethod
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
pass
@property
def auth_params(self):
def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults.
@ -356,7 +379,7 @@ class AnsibleDockerClientBase(Client):
return result
def _handle_ssl_error(self, error):
def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match:
hostname = self.auth_params["tls_hostname"]
@ -368,7 +391,7 @@ class AnsibleDockerClientBase(Client):
)
self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id):
def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try:
self.log(f"Inspecting container Id {container_id}")
result = self.inspect_container(container=container_id)
@ -379,7 +402,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None):
def get_container(self, name: str | None) -> dict[str, t.Any] | None:
"""
Lookup a container and return the inspection results.
"""
@ -416,7 +439,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None):
def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
"""
Lookup a network and return the inspection results.
"""
@ -455,7 +480,7 @@ class AnsibleDockerClientBase(Client):
return result
def find_image(self, name, tag):
def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None:
"""
Lookup an image (by name and tag) and return the inspection results.
"""
@ -507,7 +532,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.")
return None
def find_image_by_id(self, image_id, accept_missing_image=False):
def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
"""
Lookup an image (by ID) and return the inspection results.
"""
@ -526,7 +553,7 @@ class AnsibleDockerClientBase(Client):
self.fail(f"Error inspecting image ID {image_id} - {exc}")
return inspection
def _image_lookup(self, name, tag):
def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
"""
Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check
@ -549,7 +576,9 @@ class AnsibleDockerClientBase(Client):
break
return images
def pull_image(self, name, tag="latest", image_platform=None):
def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
"""
Pull an image
"""
@ -580,7 +609,7 @@ class AnsibleDockerClientBase(Client):
return new_tag, old_tag == new_tag
def inspect_distribution(self, image, **kwargs):
def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]:
"""
Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
since prior versions did not support accessing private repositories.
@ -594,7 +623,7 @@ class AnsibleDockerClientBase(Client):
self._url("/distribution/{0}/json", image),
headers={"X-Registry-Auth": header},
),
get_json=True,
json=True,
)
return super().inspect_distribution(image, **kwargs)
@ -603,18 +632,24 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(
self,
argument_spec=None,
supports_check_mode=False,
mutually_exclusive=None,
required_together=None,
required_if=None,
required_one_of=None,
required_by=None,
min_docker_version=None,
min_docker_api_version=None,
option_minimal_versions=None,
option_minimal_versions_ignore_params=None,
fail_results=None,
argument_spec: dict[str, t.Any] | None = None,
supports_check_mode: bool = False,
mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together: Sequence[Sequence[str]] | None = None,
required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
):
# Modules can put information in here which will always be returned
@ -627,12 +662,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec
mutually_exclusive_params = []
mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive
required_together_params = []
required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together:
required_together_params += required_together
@ -660,20 +695,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params
)
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name
)
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None):
self.option_minimal_versions = {}
def _get_minimal_versions(
self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec:
if ignore_params is not None:
if option in ignore_params:
@ -724,7 +769,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration."
self.fail(msg)
def report_warnings(self, result, warnings_key=None):
def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
"""
Checks result of client operation for warnings, and if present, outputs them.

View File

@ -11,6 +11,7 @@ from __future__ import annotations
import abc
import os
import re
import typing as t
from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -60,19 +61,26 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
def _get_tls_config(fail_function, **kwargs):
if t.TYPE_CHECKING:
from collections.abc import Callable
def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
try:
tls_config = TLSConfig(**kwargs)
return tls_config
return TLSConfig(**kwargs)
except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data):
def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function):
def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://"
@ -114,7 +122,7 @@ def get_connect_params(auth_data, fail_function):
class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_api_version=None):
def __init__(self, min_docker_api_version: str | None = None) -> None:
self._connect_params = get_connect_params(
self.auth_params, fail_function=self.fail
)
@ -138,23 +146,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
)
def log(self, msg, pretty_print=False):
def log(self, msg: t.Any, pretty_print: bool = False):
pass
# if self.debug:
# from .util import log_debug
# log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass
def deprecate(self, msg, version=None, date=None, collection_name=None):
@abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass
@staticmethod
def _get_value(
param_name, param_value, env_variable, default_value, value_type="str"
):
param_name: str,
param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None:
# take module parameter value
if value_type == "bool":
@ -191,11 +210,11 @@ class AnsibleDockerClientBase(Client):
return default_value
@abc.abstractmethod
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
pass
@property
def auth_params(self):
def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults.
@ -288,7 +307,7 @@ class AnsibleDockerClientBase(Client):
return result
def _handle_ssl_error(self, error):
def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match:
hostname = self.auth_params["tls_hostname"]
@ -300,7 +319,7 @@ class AnsibleDockerClientBase(Client):
)
self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id):
def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try:
self.log(f"Inspecting container Id {container_id}")
result = self.get_json("/containers/{0}/json", container_id)
@ -311,7 +330,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None):
def get_container(self, name: str | None) -> dict[str, t.Any] | None:
"""
Lookup a container and return the inspection results.
"""
@ -355,7 +374,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None):
def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
"""
Lookup a network and return the inspection results.
"""
@ -395,14 +416,14 @@ class AnsibleDockerClientBase(Client):
return result
def _image_lookup(self, name, tag):
def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
"""
Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check
if the tag exists.
"""
try:
params = {
params: dict[str, t.Any] = {
"only_ids": 0,
"all": 0,
}
@ -427,7 +448,7 @@ class AnsibleDockerClientBase(Client):
break
return images
def find_image(self, name, tag):
def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None:
"""
Lookup an image (by name and tag) and return the inspection results.
"""
@ -478,7 +499,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.")
return None
def find_image_by_id(self, image_id, accept_missing_image=False):
def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
"""
Lookup an image (by ID) and return the inspection results.
"""
@ -496,7 +519,9 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting image ID {image_id} - {exc}")
def pull_image(self, name, tag="latest", image_platform=None):
def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
"""
Pull an image
"""
@ -547,17 +572,23 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(
self,
argument_spec=None,
supports_check_mode=False,
mutually_exclusive=None,
required_together=None,
required_if=None,
required_one_of=None,
required_by=None,
min_docker_api_version=None,
option_minimal_versions=None,
option_minimal_versions_ignore_params=None,
fail_results=None,
argument_spec: dict[str, t.Any] | None = None,
supports_check_mode: bool = False,
mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together: Sequence[Sequence[str]] | None = None,
required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
):
# Modules can put information in here which will always be returned
@ -570,12 +601,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec
mutually_exclusive_params = []
mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive
required_together_params = []
required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together:
required_together_params += required_together
@ -600,20 +631,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params
)
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name
)
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None):
self.option_minimal_versions = {}
def _get_minimal_versions(
self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec:
if ignore_params is not None:
if option in ignore_params:
@ -654,7 +695,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration."
self.fail(msg)
def report_warnings(self, result, warnings_key=None):
def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
"""
Checks result of client operation for warnings, and if present, outputs them.

View File

@ -10,6 +10,7 @@ from __future__ import annotations
import abc
import json
import shlex
import typing as t
from ansible.module_utils.basic import AnsibleModule, env_fallback
from ansible.module_utils.common.process import get_bin_path
@ -31,6 +32,10 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
)
if t.TYPE_CHECKING:
from collections.abc import Mapping, Sequence
DOCKER_COMMON_ARGS = {
"docker_cli": {"type": "path"},
"docker_host": {
@ -72,10 +77,16 @@ class DockerException(Exception):
class AnsibleDockerClientBase:
docker_api_version_str: str | None
docker_api_version: LooseVersion | None
def __init__(
self, common_args, min_docker_api_version=None, needs_api_version=True
):
self._environment = {}
self,
common_args,
min_docker_api_version: str | None = None,
needs_api_version: bool = True,
) -> None:
self._environment: dict[str, str] = {}
if common_args["tls_hostname"]:
self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"]
if common_args["api_version"] and common_args["api_version"] != "auto":
@ -109,10 +120,10 @@ class AnsibleDockerClientBase:
self._cli_base.extend(["--context", common_args["cli_context"]])
# `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0
dummy, self._version, dummy = self.call_cli_json(
dummy, self._version, dummy2 = self.call_cli_json(
"version", "--format", "{{ json . }}", check_rc=True
)
self._info = None
self._info: dict[str, t.Any] | None = None
if needs_api_version:
if not isinstance(self._version.get("Server"), dict) or not isinstance(
@ -138,32 +149,47 @@ class AnsibleDockerClientBase:
"Internal error: cannot have needs_api_version=False with min_docker_api_version not None"
)
def log(self, msg, pretty_print=False):
def log(self, msg: str, pretty_print: bool = False):
pass
# if self.debug:
# from .util import log_debug
# log_debug(msg, pretty_print=pretty_print)
def get_cli(self):
def get_cli(self) -> str:
return self._cli
def get_version_info(self):
def get_version_info(self) -> str:
return self._version
def _compose_cmd(self, args):
def _compose_cmd(self, args: t.Sequence[str]) -> list[str]:
return self._cli_base + list(args)
def _compose_cmd_str(self, args):
def _compose_cmd_str(self, args: t.Sequence[str]) -> str:
return " ".join(shlex.quote(a) for a in self._compose_cmd(args))
@abc.abstractmethod
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None):
def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
pass
# def call_cli_json(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False):
def call_cli_json(self, *args, **kwargs):
warn_on_stderr = kwargs.pop("warn_on_stderr", False)
rc, stdout, stderr = self.call_cli(*args, **kwargs)
def call_cli_json(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, t.Any, bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr:
self.warn(to_native(stderr))
try:
@ -174,10 +200,18 @@ class AnsibleDockerClientBase:
)
return rc, data, stderr
# def call_cli_json_stream(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False):
def call_cli_json_stream(self, *args, **kwargs):
warn_on_stderr = kwargs.pop("warn_on_stderr", False)
rc, stdout, stderr = self.call_cli(*args, **kwargs)
def call_cli_json_stream(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, list[t.Any], bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr:
self.warn(to_native(stderr))
result = []
@ -193,25 +227,31 @@ class AnsibleDockerClientBase:
return rc, result, stderr
@abc.abstractmethod
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs) -> t.NoReturn:
pass
@abc.abstractmethod
def warn(self, msg):
def warn(self, msg: str) -> None:
pass
@abc.abstractmethod
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass
def get_cli_info(self):
def get_cli_info(self) -> dict[str, t.Any]:
if self._info is None:
dummy, self._info, dummy = self.call_cli_json(
dummy, self._info, dummy2 = self.call_cli_json(
"info", "--format", "{{ json . }}", check_rc=True
)
return self._info
def get_client_plugin_info(self, component):
def get_client_plugin_info(self, component: str) -> dict[str, t.Any] | None:
cli_info = self.get_cli_info()
if not isinstance(cli_info.get("ClientInfo"), dict):
self.fail(
@ -222,13 +262,13 @@ class AnsibleDockerClientBase:
return plugin
return None
def _image_lookup(self, name, tag):
def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
"""
Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check
if the tag exists.
"""
dummy, images, dummy = self.call_cli_json_stream(
dummy, images, dummy2 = self.call_cli_json_stream(
"image",
"ls",
"--format",
@ -247,7 +287,13 @@ class AnsibleDockerClientBase:
break
return images
def find_image(self, name, tag):
@t.overload
def find_image(self, name: None, tag: str) -> None: ...
@t.overload
def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: ...
def find_image(self, name: str | None, tag: str) -> dict[str, t.Any] | None:
"""
Lookup an image (by name and tag) and return the inspection results.
"""
@ -298,7 +344,19 @@ class AnsibleDockerClientBase:
self.log(f"Image {name}:{tag} not found.")
return None
def find_image_by_id(self, image_id, accept_missing_image=False):
@t.overload
def find_image_by_id(
self, image_id: None, accept_missing_image: bool = False
) -> None: ...
@t.overload
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None: ...
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
"""
Lookup an image (by ID) and return the inspection results.
"""
@ -320,17 +378,23 @@ class AnsibleDockerClientBase:
class AnsibleModuleDockerClient(AnsibleDockerClientBase):
def __init__(
self,
argument_spec=None,
supports_check_mode=False,
mutually_exclusive=None,
required_together=None,
required_if=None,
required_one_of=None,
required_by=None,
min_docker_api_version=None,
fail_results=None,
needs_api_version=True,
):
argument_spec: dict[str, t.Any] | None = None,
supports_check_mode: bool = False,
mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together: Sequence[Sequence[str]] | None = None,
required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: Mapping[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
fail_results: dict[str, t.Any] | None = None,
needs_api_version: bool = True,
) -> None:
# Modules can put information in here which will always be returned
# in case client.fail() is called.
@ -342,12 +406,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec
mutually_exclusive_params = [("docker_host", "cli_context")]
mutually_exclusive_params: list[Sequence[str]] = [
("docker_host", "cli_context")
]
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive
required_together_params = []
required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together:
required_together_params += required_together
@ -373,7 +439,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
needs_api_version=needs_api_version,
)
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None):
def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
environment = self._environment.copy()
if environ_update:
environment.update(environ_update)
@ -390,14 +463,20 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
)
return rc, stdout, stderr
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs) -> t.NoReturn:
self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def warn(self, msg):
def warn(self, msg: str) -> None:
self.module.warn(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name
)

View File

@ -51,6 +51,13 @@ else:
HAS_PYYAML = True
PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ansible_collections.community.docker.plugins.module_utils._common_cli import (
AnsibleModuleDockerClient as _Client,
)
DOCKER_COMPOSE_FILES = (
"compose.yaml",
@ -241,7 +248,9 @@ _RE_BUILD_PROGRESS_EVENT = re.compile(r"^\s*==>\s+(?P<msg>.*)$")
MINIMUM_COMPOSE_VERSION = "2.18.0"
def _extract_event(line, warn_function=None):
def _extract_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
match = _RE_RESOURCE_EVENT.match(line)
if match is not None:
status = match.group("status")
@ -324,7 +333,9 @@ def _extract_event(line, warn_function=None):
return None, False
def _extract_logfmt_event(line, warn_function=None):
def _extract_logfmt_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
try:
result = _parse_logfmt_line(line, logrus_mode=True)
except _InvalidLogFmt:
@ -339,7 +350,11 @@ def _extract_logfmt_event(line, warn_function=None):
return None, False
def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_function):
def _warn_missing_dry_run_prefix(
line: str,
warn_missing_dry_run_prefix: bool,
warn_function: Callable[[str], None] | None,
) -> None:
if warn_missing_dry_run_prefix and warn_function:
# This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-)
@ -350,7 +365,9 @@ def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_functio
)
def _warn_unparsable_line(line, warn_function):
def _warn_unparsable_line(
line: str, warn_function: Callable[[str], None] | None
) -> None:
# This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-)
if warn_function:
@ -361,14 +378,16 @@ def _warn_unparsable_line(line, warn_function):
)
def _find_last_event_for(events, resource_id):
def _find_last_event_for(
events: list[Event], resource_id: str
) -> tuple[int, Event] | None:
for index, event in enumerate(reversed(events)):
if event.resource_id == resource_id:
return len(events) - 1 - index, event
return None
def _concat_event_msg(event, append_msg):
def _concat_event_msg(event: Event, append_msg: str) -> Event:
return Event(
event.resource_type,
event.resource_id,
@ -383,7 +402,9 @@ _JSON_LEVEL_TO_STATUS_MAP = {
}
def parse_json_events(stderr, warn_function=None):
def parse_json_events(
stderr: bytes, warn_function: Callable[[str], None] | None = None
) -> list[Event]:
events = []
stderr_lines = stderr.splitlines()
if stderr_lines and stderr_lines[-1] == b"":
@ -524,7 +545,12 @@ def parse_json_events(stderr, warn_function=None):
return events
def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False):
def parse_events(
stderr: bytes,
dry_run: bool = False,
warn_function: Callable[[str], None] | None = None,
nonzero_rc: bool = False,
) -> list[Event]:
events = []
error_event = None
stderr_lines = stderr.splitlines()
@ -598,7 +624,11 @@ def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False):
return events
def has_changes(events, ignore_service_pull_events=False, ignore_build_events=False):
def has_changes(
events: Sequence[Event],
ignore_service_pull_events: bool = False,
ignore_build_events: bool = False,
) -> bool:
for event in events:
if event.status in DOCKER_STATUS_WORKING:
if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL:
@ -614,7 +644,7 @@ def has_changes(events, ignore_service_pull_events=False, ignore_build_events=Fa
return False
def extract_actions(events):
def extract_actions(events: Sequence[Event]) -> list[dict[str, t.Any]]:
actions = []
pull_actions = set()
for event in events:
@ -646,7 +676,9 @@ def extract_actions(events):
return actions
def emit_warnings(events, warn_function):
def emit_warnings(
events: Sequence[Event], warn_function: Callable[[str], None]
) -> None:
for event in events:
# If a message is present, assume it is a warning
if (
@ -657,13 +689,21 @@ def emit_warnings(events, warn_function):
)
def is_failed(events, rc):
def is_failed(events: Sequence[Event], rc: int) -> bool:
    """Return whether a compose CLI invocation failed.

    :param events: parsed compose events; accepted for interface symmetry
        with the other event helpers but not currently inspected.
    :param rc: exit code of the compose CLI call.
    :return: ``True`` if the exit code is nonzero.
    """
    # Any nonzero exit code counts as failure; events carry no extra signal here.
    return rc != 0
def update_failed(result, events, args, stdout, stderr, rc, cli):
def update_failed(
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: str | bytes,
rc: int,
cli: str,
) -> bool:
if not rc:
return False
errors = []
@ -697,7 +737,7 @@ def update_failed(result, events, args, stdout, stderr, rc, cli):
return True
def common_compose_argspec():
def common_compose_argspec() -> dict[str, t.Any]:
return {
"project_src": {"type": "path"},
"project_name": {"type": "str"},
@ -709,7 +749,7 @@ def common_compose_argspec():
}
def common_compose_argspec_ex():
def common_compose_argspec_ex() -> dict[str, t.Any]:
return {
"argspec": common_compose_argspec(),
"mutually_exclusive": [("definition", "project_src"), ("definition", "files")],
@ -722,16 +762,18 @@ def common_compose_argspec_ex():
}
def combine_binary_output(*outputs):
def combine_binary_output(*outputs: bytes | None) -> bytes:
    """Join all non-empty byte chunks with newlines, skipping ``None`` and ``b""``."""
    non_empty = [chunk for chunk in outputs if chunk]
    return b"\n".join(non_empty)
def combine_text_output(*outputs):
def combine_text_output(*outputs: str | None) -> str:
    """Join all non-empty strings with newlines, skipping ``None`` and ``""``."""
    non_empty = [chunk for chunk in outputs if chunk]
    return "\n".join(non_empty)
class BaseComposeManager(DockerBaseClass):
def __init__(self, client, min_version=MINIMUM_COMPOSE_VERSION):
def __init__(
self, client: _Client, min_version: str = MINIMUM_COMPOSE_VERSION
) -> None:
super().__init__()
self.client = client
self.check_mode = self.client.check_mode
@ -795,12 +837,12 @@ class BaseComposeManager(DockerBaseClass):
# more precisely in https://github.com/docker/compose/pull/11478
self.use_json_events = self.compose_version >= LooseVersion("2.29.0")
def get_compose_version(self):
def get_compose_version(self) -> str:
    """Return the docker compose version string.

    Prefers the version reported by the compose CLI itself; falls back to
    the Docker API's plugin metadata when the CLI lookup yields nothing.
    """
    return (
        self.get_compose_version_from_cli() or self.get_compose_version_from_api()
    )
def get_compose_version_from_cli(self):
def get_compose_version_from_cli(self) -> str | None:
rc, version_info, dummy_stderr = self.client.call_cli(
"compose", "version", "--format", "json"
)
@ -814,7 +856,7 @@ class BaseComposeManager(DockerBaseClass):
except Exception: # pylint: disable=broad-exception-caught
return None
def get_compose_version_from_api(self):
def get_compose_version_from_api(self) -> str:
compose = self.client.get_client_plugin_info("compose")
if compose is None:
self.fail(
@ -827,11 +869,11 @@ class BaseComposeManager(DockerBaseClass):
)
return compose["Version"].lstrip("v")
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
    """Run cleanup first, then fail the module via the client (never returns)."""
    self.cleanup()
    self.client.fail(msg, **kwargs)
def get_base_args(self, plain_progress=False):
def get_base_args(self, plain_progress: bool = False) -> list[str]:
args = ["compose", "--ansi", "never"]
if self.use_json_events and not plain_progress:
args.extend(["--progress", "json"])
@ -849,28 +891,33 @@ class BaseComposeManager(DockerBaseClass):
args.extend(["--profile", profile])
return args
def _handle_failed_cli_call(self, args, rc, stdout, stderr):
def _handle_failed_cli_call(
    self, args: list[str], rc: int, stdout: str | bytes, stderr: bytes
) -> t.NoReturn:
    """Convert a failed compose CLI call into a module failure and exit.

    :param args: CLI arguments that were run.
    :param rc: nonzero exit code of the call.
    :param stdout: captured standard output.
    :param stderr: captured standard error; JSON progress events on it carry
        the error details.
    """
    events = parse_json_events(stderr, warn_function=self.client.warn)
    # Removed a leftover duplicate "result = {}" assignment; only the
    # annotated initialization is needed.
    result: dict[str, t.Any] = {}
    self.update_failed(result, events, args, stdout, stderr, rc)
    self.client.module.exit_json(**result)
def list_containers_raw(self):
def list_containers_raw(self) -> list[dict[str, t.Any]]:
    """Return the raw ``compose ps --format json`` records for all containers.

    Fails the module via ``_handle_failed_cli_call`` when JSON progress events
    are in use and the CLI call returns a nonzero exit code.
    """
    args = self.get_base_args() + ["ps", "--format", "json", "--all"]
    if self.compose_version >= LooseVersion("2.23.0"):
        # https://github.com/docker/compose/pull/11038
        args.append("--no-trunc")
    # Removed leftover pre-refactor duplicate lines (an unused "kwargs" dict
    # and the old call sites that passed it via **kwargs).
    if self.compose_version >= LooseVersion("2.21.0"):
        # Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918
        rc, containers, stderr = self.client.call_cli_json_stream(
            *args, cwd=self.project_src, check_rc=not self.use_json_events
        )
    else:
        rc, containers, stderr = self.client.call_cli_json(
            *args, cwd=self.project_src, check_rc=not self.use_json_events
        )
    if self.use_json_events and rc != 0:
        self._handle_failed_cli_call(args, rc, json.dumps(containers), stderr)
    return containers
def list_containers(self):
def list_containers(self) -> list[dict[str, t.Any]]:
result = []
for container in self.list_containers_raw():
labels = {}
@ -887,10 +934,11 @@ class BaseComposeManager(DockerBaseClass):
result.append(container)
return result
def list_images(self):
def list_images(self) -> list[str]:
args = self.get_base_args() + ["images", "--format", "json"]
kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events}
rc, images, stderr = self.client.call_cli_json(*args, **kwargs)
rc, images, stderr = self.client.call_cli_json(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
if self.use_json_events and rc != 0:
self._handle_failed_cli_call(args, rc, images, stderr)
if isinstance(images, dict):
@ -900,7 +948,9 @@ class BaseComposeManager(DockerBaseClass):
images = list(images.values())
return images
def parse_events(self, stderr, dry_run=False, nonzero_rc=False):
def parse_events(
self, stderr: bytes, dry_run: bool = False, nonzero_rc: bool = False
) -> list[Event]:
if self.use_json_events:
return parse_json_events(stderr, warn_function=self.client.warn)
return parse_events(
@ -910,17 +960,17 @@ class BaseComposeManager(DockerBaseClass):
nonzero_rc=nonzero_rc,
)
def emit_warnings(self, events):
def emit_warnings(self, events: Sequence[Event]) -> None:
    """Forward warning-bearing events to Ansible via the module-level helper."""
    emit_warnings(events, warn_function=self.client.warn)
def update_result(
self,
result,
events,
stdout,
stderr,
ignore_service_pull_events=False,
ignore_build_events=False,
result: dict[str, t.Any],
events: Sequence[Event],
stdout: str,
stderr: str,
ignore_service_pull_events: bool = False,
ignore_build_events: bool = False,
):
result["changed"] = result.get("changed", False) or has_changes(
events,
@ -931,7 +981,15 @@ class BaseComposeManager(DockerBaseClass):
result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout))
result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr))
def update_failed(self, result, events, args, stdout, stderr, rc):
def update_failed(
self,
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: bytes,
rc: int,
):
return update_failed(
result,
events,
@ -942,14 +1000,14 @@ class BaseComposeManager(DockerBaseClass):
cli=self.client.get_cli(),
)
def cleanup_result(self, result):
def cleanup_result(self, result: dict[str, t.Any]) -> None:
    """On success (no ``failed`` key set), drop empty stdout/stderr entries in place."""
    if not result.get("failed"):
        # Only return stdout and stderr if it is not empty
        for res in ("stdout", "stderr"):
            if result.get(res) == "":
                result.pop(res)
def cleanup(self):
def cleanup(self) -> None:
for directory in self.cleanup_dirs:
try:
shutil.rmtree(directory, True)

View File

@ -16,6 +16,7 @@ import os.path
import shutil
import stat
import tarfile
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@ -25,6 +26,16 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
)
if t.TYPE_CHECKING:
from collections.abc import Callable
from _typeshed import WriteableBuffer
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class DockerFileCopyError(Exception):
pass
@ -37,7 +48,9 @@ class DockerFileNotFound(DockerFileCopyError):
pass
def _put_archive(client, container, path, data):
def _put_archive(
client: APIClient, container: str, path: str, data: bytes | t.Generator[bytes]
) -> bool:
# data can also be file object for streaming. This is because _put uses requests's put().
# See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads
url = client._url("/containers/{0}/archive", container)
@ -47,8 +60,14 @@ def _put_archive(client, container, path, data):
def _symlink_tar_creator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None
):
b_in_path: bytes,
file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> bytes:
if not stat.S_ISLNK(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a symlink")
bio = io.BytesIO()
@ -75,16 +94,28 @@ def _symlink_tar_creator(
def _symlink_tar_generator(
    b_in_path: bytes,
    file_stat: os.stat_result,
    out_file: str | bytes,
    user_id: int,
    group_id: int,
    mode: int | None = None,
    user_name: str | None = None,
) -> t.Generator[bytes]:
    """Yield a single tar archive (as one bytes chunk) containing the symlink.

    Thin generator wrapper around ``_symlink_tar_creator`` so the archive can
    be streamed. Removed leftover untyped duplicate parameter lines that
    shadowed the annotated signature.
    """
    yield _symlink_tar_creator(
        b_in_path, file_stat, out_file, user_id, group_id, mode, user_name
    )
def _regular_file_tar_generator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None
):
b_in_path: bytes,
file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> t.Generator[bytes]:
if not stat.S_ISREG(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a regular file")
tarinfo = tarfile.TarInfo()
@ -136,8 +167,13 @@ def _regular_file_tar_generator(
def _regular_content_tar_generator(
content, out_file, user_id, group_id, mode, user_name=None
):
content: bytes,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> t.Generator[bytes]:
tarinfo = tarfile.TarInfo()
tarinfo.name = (
os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/")
@ -175,16 +211,16 @@ def _regular_content_tar_generator(
def put_file(
client,
container,
in_path,
out_path,
user_id,
group_id,
mode=None,
user_name=None,
follow_links=False,
):
client: APIClient,
container: str,
in_path: str,
out_path: str,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
follow_links: bool = False,
) -> None:
"""Transfer a file from local to Docker container."""
if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")):
raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}")
@ -232,8 +268,15 @@ def put_file(
def put_file_content(
client, container, content, out_path, user_id, group_id, mode, user_name=None
):
client: APIClient,
container: str,
content: bytes,
out_path: str,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> None:
"""Transfer a file from local to Docker container."""
out_dir, out_file = os.path.split(out_path)
@ -248,7 +291,13 @@ def put_file_content(
)
def stat_file(client, container, in_path, follow_links=False, log=None):
def stat_file(
client: APIClient,
container: str,
in_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> tuple[str | bytes, dict[str, t.Any] | None, str | None]:
"""Fetch information on a file from a Docker container to local.
Return a tuple ``(path, stat_data, link_target)`` where:
@ -265,12 +314,12 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
while True:
if in_path in considered_in_paths:
raise DockerFileCopyError(
f'Found infinite symbolic link loop when trying to stating "{in_path}"'
f"Found infinite symbolic link loop when trying to stating {in_path!r}"
)
considered_in_paths.add(in_path)
if log:
log(f'FETCH: Stating "{in_path}"')
log(f"FETCH: Stating {in_path!r}")
response = client._head(
client._url("/containers/{0}/archive", container),
@ -299,24 +348,24 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
class _RawGeneratorFileobj(io.RawIOBase):
def __init__(self, stream):
def __init__(self, stream: t.Generator[bytes]):
    # Generator yielding byte chunks, plus a carry-over buffer for bytes of a
    # chunk that did not fit into the caller's read buffer.
    self._stream = stream
    self._buf = b""
def readable(self):
def readable(self) -> bool:
    # This raw stream only supports reading.
    return True
def _readinto_from_buf(self, b, index, length):
def _readinto_from_buf(self, b: WriteableBuffer, index: int, length: int) -> int:
    """Copy buffered carry-over bytes into ``b`` starting at ``index``.

    Copies at most ``length - index`` bytes from ``self._buf`` and returns
    the new write index. Removed a leftover duplicate of the slice-assignment
    line (the pre-refactor version without the type-ignore marker).
    """
    cpy = min(length - index, len(self._buf))
    if cpy:
        b[index : index + cpy] = self._buf[:cpy]  # type: ignore # TODO!
        self._buf = self._buf[cpy:]
        index += cpy
    return index
def readinto(self, b):
def readinto(self, b: WriteableBuffer) -> int:
index = 0
length = len(b)
length = len(b) # type: ignore # TODO!
index = self._readinto_from_buf(b, index, length)
if index == length:
@ -330,25 +379,28 @@ class _RawGeneratorFileobj(io.RawIOBase):
return self._readinto_from_buf(b, index, length)
def _stream_generator_to_fileobj(stream):
def _stream_generator_to_fileobj(stream: t.Generator[bytes]) -> io.BufferedReader:
    """Wrap a generator of byte chunks into a readable buffered file object."""
    return io.BufferedReader(_RawGeneratorFileobj(stream))
_T = t.TypeVar("_T")
def fetch_file_ex(
client,
container,
in_path,
process_none,
process_regular,
process_symlink,
process_other,
follow_links=False,
log=None,
):
client: APIClient,
container: str,
in_path: str,
process_none: Callable[[str], _T],
process_regular: Callable[[str, tarfile.TarFile, tarfile.TarInfo], _T],
process_symlink: Callable[[str, tarfile.TarInfo], _T],
process_other: Callable[[str, tarfile.TarInfo], _T],
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> _T:
"""Fetch a file (as a tar file entry) from a Docker container to local."""
considered_in_paths = set()
considered_in_paths: set[str] = set()
while True:
if in_path in considered_in_paths:
@ -372,8 +424,8 @@ def fetch_file_ex(
with tarfile.open(
fileobj=_stream_generator_to_fileobj(stream), mode="r|"
) as tar:
symlink_member = None
result = None
symlink_member: tarfile.TarInfo | None = None
result: _T | None = None
found = False
for member in tar:
if found:
@ -398,35 +450,46 @@ def fetch_file_ex(
log(f'FETCH: Following symbolic link to "{in_path}"')
continue
if found:
return result
return result # type: ignore
raise DockerUnexpectedError("Received tarfile is empty!")
def fetch_file(client, container, in_path, out_path, follow_links=False, log=None):
def fetch_file(
client: APIClient,
container: str,
in_path: str,
out_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> str:
b_out_path = to_bytes(out_path, errors="surrogate_or_strict")
def process_none(in_path):
def process_none(in_path: str) -> str:
    # No archive entry was returned: the file does not exist in the container.
    raise DockerFileNotFound(
        f"File {in_path} does not exist in container {container}"
    )
def process_regular(in_path, tar, member):
def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> str:
if not follow_links and os.path.exists(b_out_path):
os.unlink(b_out_path)
with tar.extractfile(member) as in_f:
with open(b_out_path, "wb") as out_f:
shutil.copyfileobj(in_f, out_f)
reader = tar.extractfile(member)
if reader:
with reader as in_f:
with open(b_out_path, "wb") as out_f:
shutil.copyfileobj(in_f, out_f)
return in_path
def process_symlink(in_path, member):
def process_symlink(in_path, member) -> str:
    # Replace whatever exists at the destination with a symlink pointing at
    # the same target as the link inside the container.
    if os.path.exists(b_out_path):
        os.unlink(b_out_path)
    os.symlink(member.linkname, b_out_path)
    return in_path
def process_other(in_path, member):
def process_other(in_path, member) -> str:
    # Only regular files and symlinks can be fetched; reject everything else.
    raise DockerFileCopyError(
        f'Remote file "{in_path}" is not a regular file or a symbolic link'
    )
@ -444,7 +507,13 @@ def fetch_file(client, container, in_path, out_path, follow_links=False, log=Non
)
def _execute_command(client, container, command, log=None, check_rc=False):
def _execute_command(
client: APIClient,
container: str,
command: list[str],
log: Callable[[str], None] | None = None,
check_rc: bool = False,
) -> tuple[int, bytes, bytes]:
if log:
log(f"Executing {command} in {container}")
@ -493,13 +562,15 @@ def _execute_command(client, container, command, log=None, check_rc=False):
if check_rc and rc != 0:
command_str = " ".join(command)
raise DockerUnexpectedError(
f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout}\nSTDERR: {stderr}'
f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout!r}\nSTDERR: {stderr!r}'
)
return rc, stdout, stderr
def determine_user_group(client, container, log=None):
def determine_user_group(
client: APIClient, container: str, log: Callable[[str], None] | None = None
) -> tuple[int, int]:
dummy_rc, stdout, dummy_stderr = _execute_command(
client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log
)
@ -507,7 +578,7 @@ def determine_user_group(client, container, log=None):
stdout_lines = stdout.splitlines()
if len(stdout_lines) != 2:
raise DockerUnexpectedError(
f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout}"
f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout!r}"
)
user_id, group_id = stdout_lines
@ -515,5 +586,5 @@ def determine_user_group(client, container, log=None):
return int(user_id), int(group_id)
except ValueError as exc:
raise DockerUnexpectedError(
f'Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got "{user_id}" and "{group_id}" instead'
f"Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got {user_id!r} and {group_id!r} instead"
) from exc

View File

@ -18,12 +18,10 @@ class ImageArchiveManifestSummary:
"docker image save some:tag > some.tar" command.
"""
def __init__(self, image_id, repo_tags):
def __init__(self, image_id: str, repo_tags: list[str]) -> None:
"""
:param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json
:type image_id: str
:param repo_tags Docker image names, e.g. ["hello-world:latest"]
:type repo_tags: list[str]
"""
self.image_id = image_id
@ -34,22 +32,21 @@ class ImageArchiveInvalidException(Exception):
pass
def api_image_id(archive_image_id):
def api_image_id(archive_image_id: str) -> str:
    """
    Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier
    that represents the same image hash, but in the format presented by the Docker Engine API.

    Type information lives in the annotations; the redundant reST ``:type:``/``:rtype:``
    fields were dropped.

    :param archive_image_id: plain image hash
    :returns: Prefixed hash used by REST api
    """
    return f"sha256:{archive_image_id}"
def load_archived_image_manifest(archive_path):
def load_archived_image_manifest(
archive_path: str,
) -> list[ImageArchiveManifestSummary] | None:
"""
Attempts to get image IDs and image names from metadata stored in the image
archive tar file.
@ -62,10 +59,7 @@ def load_archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects.
:rtype: ImageArchiveManifestSummary
"""
try:
@ -76,8 +70,15 @@ def load_archived_image_manifest(archive_path):
with tarfile.open(archive_path, "r") as tf:
try:
try:
with tf.extractfile("manifest.json") as ef:
reader = tf.extractfile("manifest.json")
if reader is None:
raise ImageArchiveInvalidException(
"Failed to read manifest.json"
)
with reader as ef:
manifest = json.load(ef)
except ImageArchiveInvalidException:
raise
except Exception as exc:
raise ImageArchiveInvalidException(
f"Failed to decode and deserialize manifest.json: {exc}"
@ -139,7 +140,7 @@ def load_archived_image_manifest(archive_path):
) from exc
def archived_image_manifest(archive_path):
def archived_image_manifest(archive_path: str) -> ImageArchiveManifestSummary | None:
"""
Attempts to get Image.Id and image name from metadata stored in the image
archive tar file.
@ -152,10 +153,7 @@ def archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix.
:rtype: ImageArchiveManifestSummary
"""
results = load_archived_image_manifest(archive_path)

View File

@ -13,6 +13,9 @@ See https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc for information on
from __future__ import annotations
import typing as t
from enum import Enum
# The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc
# (look for "EBNFish")
@ -22,7 +25,7 @@ class InvalidLogFmt(Exception):
pass
class _Mode:
class _Mode(Enum):
GARBAGE = 0
KEY = 1
EQUAL = 2
@ -68,29 +71,29 @@ _HEX_DICT = {
}
def _is_ident(cur):
def _is_ident(cur: str) -> bool:
return cur > " " and cur not in ('"', "=")
class _Parser:
def __init__(self, line):
def __init__(self, line: str) -> None:
    # The logfmt line being parsed, the cursor position, and the cached length.
    self.line = line
    self.index = 0
    self.length = len(line)
def done(self):
def done(self) -> bool:
    # True once the cursor has consumed the entire line.
    return self.index >= self.length
def cur(self):
def cur(self) -> str:
    # Character at the current cursor position (caller must check done() first).
    return self.line[self.index]
def next(self):
def next(self) -> None:
    # Advance the cursor by one character.
    self.index += 1
def prev(self):
def prev(self) -> None:
    # Move the cursor back one character.
    self.index -= 1
def parse_unicode_sequence(self):
def parse_unicode_sequence(self) -> str:
if self.index + 6 > self.length:
raise InvalidLogFmt("Not enough space for unicode escape")
if self.line[self.index : self.index + 2] != "\\u":
@ -108,27 +111,27 @@ class _Parser:
return chr(v)
def parse_line(line, logrus_mode=False):
result = {}
def parse_line(line: str, logrus_mode: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
parser = _Parser(line)
key = []
value = []
key: list[str] = []
value: list[str] = []
mode = _Mode.GARBAGE
def handle_kv(has_no_value=False):
def handle_kv(has_no_value: bool = False) -> None:
    # Flush the accumulated key/value characters into the result dict;
    # a key with no "=value" part is stored with value None.
    k = "".join(key)
    v = None if has_no_value else "".join(value)
    result[k] = v
    del key[:]
    del value[:]
def parse_garbage(cur):
def parse_garbage(cur: str) -> _Mode:
    # Skip characters until one that can start a key; do not consume that
    # character, so the KEY state sees it.
    if _is_ident(cur):
        return _Mode.KEY
    parser.next()
    return _Mode.GARBAGE
def parse_key(cur):
def parse_key(cur: str) -> _Mode:
if _is_ident(cur):
key.append(cur)
parser.next()
@ -142,7 +145,7 @@ def parse_line(line, logrus_mode=False):
parser.next()
return _Mode.GARBAGE
def parse_equal(cur):
def parse_equal(cur: str) -> _Mode:
if _is_ident(cur):
value.append(cur)
parser.next()
@ -154,7 +157,7 @@ def parse_line(line, logrus_mode=False):
parser.next()
return _Mode.GARBAGE
def parse_ident_value(cur):
def parse_ident_value(cur: str) -> _Mode:
if _is_ident(cur):
value.append(cur)
parser.next()
@ -163,7 +166,7 @@ def parse_line(line, logrus_mode=False):
parser.next()
return _Mode.GARBAGE
def parse_quoted_value(cur):
def parse_quoted_value(cur: str) -> _Mode:
if cur == "\\":
parser.next()
if parser.done():

View File

@ -14,12 +14,13 @@
from __future__ import annotations
import re
import typing as t
_VALID_STR = re.compile("^[A-Za-z0-9_-]+$")
def _validate_part(string, part, part_name):
def _validate_part(string: str, part: str, part_name: str) -> str:
if not part:
raise ValueError(f'Invalid platform string "{string}": {part_name} is empty')
if not _VALID_STR.match(part):
@ -79,7 +80,7 @@ _KNOWN_ARCH = (
)
def _normalize_os(os_str):
def _normalize_os(os_str: str) -> str:
# See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go
os_str = os_str.lower()
if os_str == "macos":
@ -112,7 +113,7 @@ _NORMALIZE_ARCH = {
}
def _normalize_arch(arch_str, variant_str):
def _normalize_arch(arch_str: str, variant_str: str) -> tuple[str, str]:
# See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go
arch_str = arch_str.lower()
variant_str = variant_str.lower()
@ -121,15 +122,16 @@ def _normalize_arch(arch_str, variant_str):
res = _NORMALIZE_ARCH.get((arch_str, None))
if res is None:
return arch_str, variant_str
if res is not None:
arch_str = res[0]
if res[1] is not None:
variant_str = res[1]
return arch_str, variant_str
arch_str = res[0]
if res[1] is not None:
variant_str = res[1]
return arch_str, variant_str
class _Platform:
def __init__(self, os=None, arch=None, variant=None):
def __init__(
self, os: str | None = None, arch: str | None = None, variant: str | None = None
) -> None:
self.os = os
self.arch = arch
self.variant = variant
@ -140,7 +142,12 @@ class _Platform:
raise ValueError("If variant is given, os must be given too")
@classmethod
def parse_platform_string(cls, string, daemon_os=None, daemon_arch=None):
def parse_platform_string(
cls,
string: str | None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> t.Self:
# See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go
if string is None:
return cls()
@ -182,6 +189,7 @@ class _Platform:
)
if variant is not None and not variant:
raise ValueError(f'Invalid platform string "{string}": variant is empty')
assert arch is not None # otherwise variant would be None as well
arch, variant = _normalize_arch(arch, variant or "")
if len(parts) == 2 and arch == "arm" and variant == "v7":
variant = None
@ -189,9 +197,12 @@ class _Platform:
variant = "v8"
return cls(os=_normalize_os(os), arch=arch, variant=variant or None)
def __str__(self):
def __str__(self) -> str:
if self.variant:
parts = [self.os, self.arch, self.variant]
assert (
self.os is not None and self.arch is not None
) # ensured in constructor
parts: list[str] = [self.os, self.arch, self.variant]
elif self.os:
if self.arch:
parts = [self.os, self.arch]
@ -203,12 +214,14 @@ class _Platform:
parts = []
return "/".join(parts)
def __repr__(self):
def __repr__(self) -> str:
    # Unambiguous debug representation mirroring the constructor arguments.
    return (
        f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})"
    )
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, _Platform):
return NotImplemented
return (
self.os == other.os
and self.arch == other.arch
@ -216,7 +229,9 @@ class _Platform:
)
def normalize_platform_string(string, daemon_os=None, daemon_arch=None):
def normalize_platform_string(
string: str, daemon_os: str | None = None, daemon_arch: str | None = None
) -> str:
return str(
_Platform.parse_platform_string(
string, daemon_os=daemon_os, daemon_arch=daemon_arch
@ -225,8 +240,12 @@ def normalize_platform_string(string, daemon_os=None, daemon_arch=None):
def compose_platform_string(
os=None, arch=None, variant=None, daemon_os=None, daemon_arch=None
):
os: str | None = None,
arch: str | None = None,
variant: str | None = None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> str:
if os is None and daemon_os is not None:
os = _normalize_os(daemon_os)
if arch is None and daemon_arch is not None:
@ -235,7 +254,7 @@ def compose_platform_string(
return str(_Platform(os=os, arch=arch, variant=variant or None))
def compare_platform_strings(string1, string2):
def compare_platform_strings(string1: str, string2: str) -> bool:
    """Return True if both platform strings parse to the same platform."""
    return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string(
        string2
    )

View File

@ -13,7 +13,7 @@ import random
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
def generate_insecure_key():
def generate_insecure_key() -> bytes:
"""Do NOT use this for cryptographic purposes!"""
while True:
# Generate a one-byte key. Right now the functions below do not use more
@ -24,23 +24,23 @@ def generate_insecure_key():
return key
def scramble(value, key):
def scramble(value: str, key: bytes) -> str:
    """Obfuscate ``value`` by XOR-ing each byte with the first key byte.

    Do NOT use this for cryptographic purposes!

    Removed leftover pre-refactor duplicate lines that reused the ``value``
    name for bytes; only the ``b_value`` variants remain.

    :param value: text to obfuscate
    :param key: key; only the first byte is used
    :return: base64-encoded result prefixed with ``=S=``
    :raises ValueError: if the key is empty
    """
    if len(key) < 1:
        raise ValueError("Key must be at least one byte")
    b_value = to_bytes(value)
    k = key[0]
    b_value = bytes([k ^ b for b in b_value])
    return f"=S={to_native(base64.b64encode(b_value))}"
def unscramble(value, key):
def unscramble(value: str, key: bytes) -> str:
    """Reverse :func:`scramble`: strip the ``=S=`` prefix, base64-decode, XOR.

    Do NOT use this for cryptographic purposes!

    Removed leftover pre-refactor duplicate lines that reused the ``value``
    name for bytes; only the ``b_value`` variants remain.

    :param value: scrambled text starting with ``=S=``
    :param key: key; only the first byte is used
    :return: the original text
    :raises ValueError: if the key is empty or the prefix is missing
    """
    if len(key) < 1:
        raise ValueError("Key must be at least one byte")
    if not value.startswith("=S="):
        raise ValueError("Value does not start with indicator")
    b_value = base64.b64decode(value[3:])
    k = key[0]
    b_value = bytes([k ^ b for b in b_value])
    return to_text(b_value)

View File

@ -9,6 +9,7 @@
from __future__ import annotations
import json
import typing as t
from time import sleep
@ -28,10 +29,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class AnsibleDockerSwarmClient(AnsibleDockerClient):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_swarm_node_id(self):
def get_swarm_node_id(self) -> str | None:
"""
Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID
of Docker host the module is executed on
@ -51,7 +49,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return swarm_info["Swarm"]["NodeID"]
return None
def check_if_swarm_node(self, node_id=None):
def check_if_swarm_node(self, node_id: str | None = None) -> bool | None:
"""
Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host
system information looking if specific key in output exists. If 'node_id' is provided then it tries to
@ -83,11 +81,11 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
try:
node_info = self.get_node_inspect(node_id=node_id)
except APIError:
return
return None
return node_info["ID"] is not None
def check_if_swarm_manager(self):
def check_if_swarm_manager(self) -> bool:
"""
Checks if node role is set as Manager in Swarm. The node is the docker host on which module action
is performed. The inspect_swarm() will fail if node is not a manager
@ -101,7 +99,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
except APIError:
return False
def fail_task_if_not_swarm_manager(self):
def fail_task_if_not_swarm_manager(self) -> None:
"""
If host is not a swarm manager then Ansible task on this host should end with 'failed' state
"""
@ -110,7 +108,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
"Error running docker swarm module: must run on swarm manager node"
)
def check_if_swarm_worker(self):
def check_if_swarm_worker(self) -> bool:
"""
Checks if node role is set as Worker in Swarm. The node is the docker host on which module action
is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node()
@ -122,7 +120,9 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True
return False
def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1):
def check_if_swarm_node_is_down(
self, node_id: str | None = None, repeat_check: int = 1
) -> bool:
"""
Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about
node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or
@ -147,7 +147,19 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True
return False
def get_node_inspect(self, node_id=None, skip_missing=False):
@t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: t.Literal[False] = False
) -> dict[str, t.Any]: ...
@t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None: ...
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None:
"""
Returns Swarm node info as in 'docker node inspect' command about single node
@ -195,7 +207,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info["Status"]["Addr"] = swarm_leader_ip
return node_info
def get_all_nodes_inspect(self):
def get_all_nodes_inspect(self) -> list[dict[str, t.Any]]:
"""
Returns Swarm node info as in 'docker node inspect' command about all registered nodes
@ -217,7 +229,17 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info = json.loads(json_str)
return node_info
@t.overload
def get_all_nodes_list(self, output: t.Literal["short"] = "short") -> list[str]: ...

@t.overload
def get_all_nodes_list(
    self, output: t.Literal["long"]
) -> list[dict[str, t.Any]]: ...

def get_all_nodes_list(
    self, output: t.Literal["short", "long"] = "short"
) -> list[str] | list[dict[str, t.Any]]:
    """
    Returns list of nodes registered in Swarm.

    :param output: Defines the format of returned data. With ``"short"`` the
        result is a list of node hostnames; with ``"long"`` it is a list of
        dicts containing the attributes as in the output of the command
        ``docker node ls``.
    :return: List of hostnames, or list of per-node attribute dicts; an empty
        list if ``output`` is neither recognized value.
    """
    nodes_list: list[str] = []

    nodes_inspect = self.get_all_nodes_inspect()

    if output == "short":
        for node in nodes_inspect:
            nodes_list.append(node["Description"]["Hostname"])
        return nodes_list
    if output == "long":
        nodes_info_list = []
        for node in nodes_inspect:
            node_property: dict[str, t.Any] = {}
            node_property["ID"] = node["ID"]
            node_property["Hostname"] = node["Description"]["Hostname"]
            node_property["Status"] = node["Status"]["State"]
            node_property["Availability"] = node["Spec"]["Availability"]
            if "ManagerStatus" in node:
                # "Leader" is only reported for the current swarm leader.
                if node["ManagerStatus"]["Leader"] is True:
                    node_property["Leader"] = True
                node_property["ManagerStatus"] = node["ManagerStatus"][
                    "Reachability"
                ]
            # NOTE(review): EngineVersion is recorded for every node, not only
            # managers — confirm nesting against the pre-flattening source.
            node_property["EngineVersion"] = node["Description"]["Engine"][
                "EngineVersion"
            ]
            nodes_info_list.append(node_property)
        return nodes_info_list
    return nodes_list
def get_node_name_by_id(self, nodeid: str) -> str:
    """
    Returns the hostname of the Swarm node with the given node ID.

    :param nodeid: ID of the node to look up.
    :return: Hostname reported in the node's ``Description``.
    """
    return self.get_node_inspect(nodeid)["Description"]["Hostname"]
def get_unlock_key(self) -> str | None:
    """
    Returns the swarm unlock key, or ``None`` when the installed Docker SDK
    for Python is older than 2.7.0 and does not support the call.
    """
    # The unlock-key API was added to docker-py in 2.7.0; avoid calling a
    # method that does not exist on older SDK versions.
    if self.docker_py_version < LooseVersion("2.7.0"):
        return None
    return super().get_unlock_key()
def get_service_inspect(self, service_id, skip_missing=False):
def get_service_inspect(
self, service_id: str, skip_missing: bool = False
) -> dict[str, t.Any] | None:
"""
Returns Swarm service info as in 'docker service inspect' command about single service

View File

@ -9,6 +9,7 @@ from __future__ import annotations
import json
import re
import typing as t
from datetime import timedelta
from urllib.parse import urlparse
@ -17,6 +18,12 @@ from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.text.converters import to_text
if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock"
DEFAULT_TLS = False
DEFAULT_TLS_VERIFY = False
@ -79,14 +86,14 @@ DEFAULT_DOCKER_REGISTRY = "https://index.docker.io/v1/"
BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"]
def is_image_name_id(name: str) -> bool:
    """Check whether the given image name is in fact an image ID (hash)."""
    # A content-addressed image ID is "sha256:" followed by 64 hex digits.
    return re.match("^sha256:[0-9a-fA-F]{64}$", name) is not None
def is_valid_tag(tag, allow_empty=False):
def is_valid_tag(tag: str, allow_empty: bool = False) -> bool:
"""Check whether the given string is a valid docker tag name."""
if not tag:
return allow_empty
@ -95,7 +102,7 @@ def is_valid_tag(tag, allow_empty=False):
return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag))
def sanitize_result(data):
def sanitize_result(data: t.Any) -> t.Any:
"""Sanitize data object for return to Ansible.
When the data object contains types such as docker.types.containers.HostConfig,
@ -112,7 +119,7 @@ def sanitize_result(data):
return data
def log_debug(msg, pretty_print=False):
def log_debug(msg: t.Any, pretty_print: bool = False):
"""Write a log message to docker.log.
If ``pretty_print=True``, the message will be pretty-printed as JSON.
@ -128,25 +135,28 @@ def log_debug(msg, pretty_print=False):
class DockerBaseClass:
    """Minimal base class offering an opt-in debug log hook to docker modules."""

    def __init__(self) -> None:
        # Debug logging is off by default; subclasses/callers may enable it.
        self.debug = False

    def log(self, msg: t.Any, pretty_print: bool = False) -> None:
        """No-op log hook, kept so callers can unconditionally call ``self.log``."""
        pass
        # if self.debug:
        #     log_debug(msg, pretty_print=pretty_print)
def update_tls_hostname(
    result: dict[str, t.Any],
    old_behavior: bool = False,
    deprecate_function: Callable[[str], None] | None = None,
    uses_tls: bool = True,
) -> None:
    """
    Derive a default ``tls_hostname`` from ``docker_host`` when none was given.

    :param result: Parameter dict with ``tls_hostname`` and ``docker_host``
        keys; updated in place.
    :param old_behavior: Retained for backward compatibility.
    :param deprecate_function: Retained for backward compatibility.
    :param uses_tls: Retained for backward compatibility.
    """
    # NOTE(review): only the tls_hostname defaulting is visible here; the
    # compat parameters appear unused — confirm against callers.
    if result["tls_hostname"] is None:
        # get default machine name from the url
        parsed_url = urlparse(result["docker_host"])
        result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0]
def compare_dict_allow_more_present(av, bv):
def compare_dict_allow_more_present(av: dict, bv: dict) -> bool:
"""
Compare two dictionaries for whether every entry of the first is in the second.
"""
@ -158,7 +168,12 @@ def compare_dict_allow_more_present(av, bv):
return True
def compare_generic(a, b, method, datatype):
def compare_generic(
a: t.Any,
b: t.Any,
method: t.Literal["ignore", "strict", "allow_more_present"],
datatype: t.Literal["value", "list", "set", "set(dict)", "dict"],
) -> bool:
"""
Compare values a and b as described by method and datatype.
@ -249,10 +264,10 @@ def compare_generic(a, b, method, datatype):
class DifferenceTracker:
    """Collects named parameter differences for later reporting."""

    def __init__(self) -> None:
        # Each entry is a dict produced by add(); it contains at least the
        # keys "name" and "parameter".
        self._diff: list[dict[str, t.Any]] = []
def add(self, name, parameter=None, active=None):
def add(self, name: str, parameter: t.Any = None, active: t.Any = None) -> None:
self._diff.append(
{
"name": name,
@ -261,14 +276,14 @@ class DifferenceTracker:
}
)
def merge(self, other_tracker: DifferenceTracker) -> None:
    """Append all differences recorded by ``other_tracker`` to this tracker."""
    self._diff.extend(other_tracker._diff)
@property
def empty(self) -> bool:
    """Whether no differences have been recorded."""
    return not self._diff
def get_before_after(self):
def get_before_after(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
"""
Return texts ``before`` and ``after``.
"""
@ -279,13 +294,13 @@ class DifferenceTracker:
after[item["name"]] = item["parameter"]
return before, after
def has_difference_for(self, name: str) -> bool:
    """
    Returns a boolean if a difference exists for name.
    """
    # any() on a generator short-circuits at the first match.
    return any(diff["name"] == name for diff in self._diff)
def get_legacy_docker_container_diffs(self):
def get_legacy_docker_container_diffs(self) -> list[dict[str, t.Any]]:
"""
Return differences in the docker_container legacy format.
"""
@ -299,7 +314,7 @@ class DifferenceTracker:
result.append(item)
return result
def get_legacy_docker_diffs(self):
def get_legacy_docker_diffs(self) -> list[str]:
"""
Return differences in the docker_container legacy format.
"""
@ -307,8 +322,13 @@ class DifferenceTracker:
return result
def sanitize_labels(labels, labels_field, client=None, module=None):
def fail(msg):
def sanitize_labels(
labels: dict[str, t.Any] | None,
labels_field: str,
client=None,
module: AnsibleModule | None = None,
) -> None:
def fail(msg: str) -> t.NoReturn:
if client is not None:
client.fail(msg)
if module is not None:
@ -327,7 +347,21 @@ def sanitize_labels(labels, labels_field, client=None, module=None):
labels[k] = to_text(v)
def clean_dict_booleans_for_docker_api(data, allow_sequences=False):
@t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: t.Literal[False] = False
) -> dict[str, str]: ...
@t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: bool
) -> dict[str, str | list[str]]: ...
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: bool = False
) -> dict[str, str] | dict[str, str | list[str]]:
"""
Go does not like Python booleans 'True' or 'False', while Ansible is just
fine with them in YAML. As such, they need to be converted in cases where
@ -355,7 +389,7 @@ def clean_dict_booleans_for_docker_api(data, allow_sequences=False):
return result
def convert_duration_to_nanosecond(time_str):
def convert_duration_to_nanosecond(time_str: str) -> int:
"""
Return time duration in nanosecond.
"""
@ -374,9 +408,9 @@ def convert_duration_to_nanosecond(time_str):
if not parts:
raise ValueError(f"Invalid time duration - {time_str}")
parts = parts.groupdict()
parts_dict = parts.groupdict()
time_params = {}
for name, value in parts.items():
for name, value in parts_dict.items():
if value:
time_params[name] = int(value)
@ -388,13 +422,15 @@ def convert_duration_to_nanosecond(time_str):
return time_in_nanoseconds
def normalize_healthcheck_test(test: t.Any) -> list[str]:
    """
    Normalize a healthcheck ``test`` value to the docker API list form.

    A list or tuple is coerced element-wise to strings; any other value is
    treated as a shell command and wrapped as ``["CMD-SHELL", ...]``.
    """
    if isinstance(test, (tuple, list)):
        return [str(e) for e in test]
    return ["CMD-SHELL", str(test)]
def normalize_healthcheck(healthcheck, normalize_test=False):
def normalize_healthcheck(
healthcheck: dict[str, t.Any], normalize_test: bool = False
) -> dict[str, t.Any]:
"""
Return dictionary of healthcheck parameters.
"""
@ -440,7 +476,9 @@ def normalize_healthcheck(healthcheck, normalize_test=False):
return result
def parse_healthcheck(healthcheck):
def parse_healthcheck(
healthcheck: dict[str, t.Any] | None,
) -> tuple[dict[str, t.Any] | None, bool | None]:
"""
Return dictionary of healthcheck parameters and boolean if
healthcheck defined in image was requested to be disabled.
@ -458,8 +496,8 @@ def parse_healthcheck(healthcheck):
return result, False
def omit_none_from_dict(d: dict[str, t.Any]) -> dict[str, t.Any]:
    """
    Return a copy of the dictionary with all keys with value None omitted.
    """
    return {k: v for k, v in d.items() if v is not None}

View File

@ -255,7 +255,7 @@ class DockerHostManager(DockerBaseClass):
returned_name = docker_object
filter_name = docker_object + "_filters"
filters = clean_dict_booleans_for_docker_api(
client.module.params.get(filter_name), True
client.module.params.get(filter_name), allow_sequences=True
)
self.results[returned_name] = self.get_docker_items_list(
docker_object, filters