First batch of types.

This commit is contained in:
Felix Fontein 2025-10-18 20:06:37 +02:00
parent c76ee1d1cc
commit 0ff66e8b24
12 changed files with 739 additions and 363 deletions

View File

@ -13,6 +13,7 @@ import platform
import re import re
import sys import sys
import traceback import traceback
import typing as t
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -79,6 +80,10 @@ except ImportError:
pass pass
if t.TYPE_CHECKING:
from collections.abc import Callable
MIN_DOCKER_VERSION = "2.0.0" MIN_DOCKER_VERSION = "2.0.0"
@ -96,7 +101,9 @@ if not HAS_DOCKER_PY:
pass pass
def _get_tls_config(fail_function, **kwargs): def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion( if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion(
"7.0.0b1" "7.0.0b1"
): ):
@ -111,17 +118,18 @@ def _get_tls_config(fail_function, **kwargs):
# Filter out all None parameters # Filter out all None parameters
kwargs = dict((k, v) for k, v in kwargs.items() if v is not None) kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)
try: try:
tls_config = TLSConfig(**kwargs) return TLSConfig(**kwargs)
return tls_config
except TLSParameterError as exc: except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}") fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data): def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"] return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function): def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data): if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace( auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://" "tcp://", "https://"
@ -173,7 +181,11 @@ DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade.
class AnsibleDockerClientBase(Client): class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_version=None, min_docker_api_version=None): def __init__(
self,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
) -> None:
if min_docker_version is None: if min_docker_version is None:
min_docker_version = MIN_DOCKER_VERSION min_docker_version = MIN_DOCKER_VERSION
@ -214,23 +226,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg, pretty_print=False): def log(self, msg: t.Any, pretty_print: bool = False):
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass pass
def deprecate(self, msg, version=None, date=None, collection_name=None): @abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass pass
@staticmethod @staticmethod
def _get_value( def _get_value(
param_name, param_value, env_variable, default_value, value_type="str" param_name: str,
): param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None: if param_value is not None:
# take module parameter value # take module parameter value
if value_type == "bool": if value_type == "bool":
@ -267,11 +290,11 @@ class AnsibleDockerClientBase(Client):
return default_value return default_value
@abc.abstractmethod @abc.abstractmethod
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
pass pass
@property @property
def auth_params(self): def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials. # Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults. # Precedence: module parameters-> environment variables-> defaults.
@ -356,7 +379,7 @@ class AnsibleDockerClientBase(Client):
return result return result
def _handle_ssl_error(self, error): def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match: if match:
hostname = self.auth_params["tls_hostname"] hostname = self.auth_params["tls_hostname"]
@ -368,7 +391,7 @@ class AnsibleDockerClientBase(Client):
) )
self.fail(f"SSL Exception: {error}") self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id): def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try: try:
self.log(f"Inspecting container Id {container_id}") self.log(f"Inspecting container Id {container_id}")
result = self.inspect_container(container=container_id) result = self.inspect_container(container=container_id)
@ -379,7 +402,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}") self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None): def get_container(self, name: str | None) -> dict[str, t.Any] | None:
""" """
Lookup a container and return the inspection results. Lookup a container and return the inspection results.
""" """
@ -416,7 +439,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"]) return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None): def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
""" """
Lookup a network and return the inspection results. Lookup a network and return the inspection results.
""" """
@ -455,7 +480,7 @@ class AnsibleDockerClientBase(Client):
return result return result
def find_image(self, name, tag): def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None:
""" """
Lookup an image (by name and tag) and return the inspection results. Lookup an image (by name and tag) and return the inspection results.
""" """
@ -507,7 +532,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.") self.log(f"Image {name}:{tag} not found.")
return None return None
def find_image_by_id(self, image_id, accept_missing_image=False): def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
""" """
Lookup an image (by ID) and return the inspection results. Lookup an image (by ID) and return the inspection results.
""" """
@ -526,7 +553,7 @@ class AnsibleDockerClientBase(Client):
self.fail(f"Error inspecting image ID {image_id} - {exc}") self.fail(f"Error inspecting image ID {image_id} - {exc}")
return inspection return inspection
def _image_lookup(self, name, tag): def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
""" """
Including a tag in the name parameter sent to the Docker SDK for Python images method Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check does not work consistently. Instead, get the result set for name and manually check
@ -549,7 +576,9 @@ class AnsibleDockerClientBase(Client):
break break
return images return images
def pull_image(self, name, tag="latest", image_platform=None): def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
""" """
Pull an image Pull an image
""" """
@ -580,7 +609,7 @@ class AnsibleDockerClientBase(Client):
return new_tag, old_tag == new_tag return new_tag, old_tag == new_tag
def inspect_distribution(self, image, **kwargs): def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]:
""" """
Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
since prior versions did not support accessing private repositories. since prior versions did not support accessing private repositories.
@ -594,7 +623,7 @@ class AnsibleDockerClientBase(Client):
self._url("/distribution/{0}/json", image), self._url("/distribution/{0}/json", image),
headers={"X-Registry-Auth": header}, headers={"X-Registry-Auth": header},
), ),
get_json=True, json=True,
) )
return super().inspect_distribution(image, **kwargs) return super().inspect_distribution(image, **kwargs)
@ -603,18 +632,24 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec=None, argument_spec: dict[str, t.Any] | None = None,
supports_check_mode=False, supports_check_mode: bool = False,
mutually_exclusive=None, mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together=None, required_together: Sequence[Sequence[str]] | None = None,
required_if=None, required_if: (
required_one_of=None, Sequence[
required_by=None, tuple[str, t.Any, Sequence[str]]
min_docker_version=None, | tuple[str, t.Any, Sequence[str], bool]
min_docker_api_version=None, ]
option_minimal_versions=None, | None
option_minimal_versions_ignore_params=None, ) = None,
fail_results=None, required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
): ):
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
@ -627,12 +662,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec) merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec self.arg_spec = merged_arg_spec
mutually_exclusive_params = [] mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive: if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive mutually_exclusive_params += mutually_exclusive
required_together_params = [] required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together: if required_together:
required_together_params += required_together required_together_params += required_together
@ -660,20 +695,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params option_minimal_versions, option_minimal_versions_ignore_params
) )
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate( self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
return self.module.params return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): def _get_minimal_versions(
self.option_minimal_versions = {} self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec: for option in self.module.argument_spec:
if ignore_params is not None: if ignore_params is not None:
if option in ignore_params: if option in ignore_params:
@ -724,7 +769,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration." msg = f"Cannot {usg} with your configuration."
self.fail(msg) self.fail(msg)
def report_warnings(self, result, warnings_key=None): def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
""" """
Checks result of client operation for warnings, and if present, outputs them. Checks result of client operation for warnings, and if present, outputs them.

View File

@ -11,6 +11,7 @@ from __future__ import annotations
import abc import abc
import os import os
import re import re
import typing as t
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -60,19 +61,26 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
def _get_tls_config(fail_function, **kwargs): if t.TYPE_CHECKING:
from collections.abc import Callable
def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
try: try:
tls_config = TLSConfig(**kwargs) return TLSConfig(**kwargs)
return tls_config
except TLSParameterError as exc: except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}") fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data): def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"] return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function): def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data): if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace( auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://" "tcp://", "https://"
@ -114,7 +122,7 @@ def get_connect_params(auth_data, fail_function):
class AnsibleDockerClientBase(Client): class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_api_version=None): def __init__(self, min_docker_api_version: str | None = None) -> None:
self._connect_params = get_connect_params( self._connect_params = get_connect_params(
self.auth_params, fail_function=self.fail self.auth_params, fail_function=self.fail
) )
@ -138,23 +146,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg, pretty_print=False): def log(self, msg: t.Any, pretty_print: bool = False):
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass pass
def deprecate(self, msg, version=None, date=None, collection_name=None): @abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass pass
@staticmethod @staticmethod
def _get_value( def _get_value(
param_name, param_value, env_variable, default_value, value_type="str" param_name: str,
): param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None: if param_value is not None:
# take module parameter value # take module parameter value
if value_type == "bool": if value_type == "bool":
@ -191,11 +210,11 @@ class AnsibleDockerClientBase(Client):
return default_value return default_value
@abc.abstractmethod @abc.abstractmethod
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
pass pass
@property @property
def auth_params(self): def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials. # Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults. # Precedence: module parameters-> environment variables-> defaults.
@ -288,7 +307,7 @@ class AnsibleDockerClientBase(Client):
return result return result
def _handle_ssl_error(self, error): def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match: if match:
hostname = self.auth_params["tls_hostname"] hostname = self.auth_params["tls_hostname"]
@ -300,7 +319,7 @@ class AnsibleDockerClientBase(Client):
) )
self.fail(f"SSL Exception: {error}") self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id): def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try: try:
self.log(f"Inspecting container Id {container_id}") self.log(f"Inspecting container Id {container_id}")
result = self.get_json("/containers/{0}/json", container_id) result = self.get_json("/containers/{0}/json", container_id)
@ -311,7 +330,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}") self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None): def get_container(self, name: str | None) -> dict[str, t.Any] | None:
""" """
Lookup a container and return the inspection results. Lookup a container and return the inspection results.
""" """
@ -355,7 +374,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"]) return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None): def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
""" """
Lookup a network and return the inspection results. Lookup a network and return the inspection results.
""" """
@ -395,14 +416,14 @@ class AnsibleDockerClientBase(Client):
return result return result
def _image_lookup(self, name, tag): def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
""" """
Including a tag in the name parameter sent to the Docker SDK for Python images method Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check does not work consistently. Instead, get the result set for name and manually check
if the tag exists. if the tag exists.
""" """
try: try:
params = { params: dict[str, t.Any] = {
"only_ids": 0, "only_ids": 0,
"all": 0, "all": 0,
} }
@ -427,7 +448,7 @@ class AnsibleDockerClientBase(Client):
break break
return images return images
def find_image(self, name, tag): def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None:
""" """
Lookup an image (by name and tag) and return the inspection results. Lookup an image (by name and tag) and return the inspection results.
""" """
@ -478,7 +499,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.") self.log(f"Image {name}:{tag} not found.")
return None return None
def find_image_by_id(self, image_id, accept_missing_image=False): def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
""" """
Lookup an image (by ID) and return the inspection results. Lookup an image (by ID) and return the inspection results.
""" """
@ -496,7 +519,9 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting image ID {image_id} - {exc}") self.fail(f"Error inspecting image ID {image_id} - {exc}")
def pull_image(self, name, tag="latest", image_platform=None): def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
""" """
Pull an image Pull an image
""" """
@ -547,17 +572,23 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec=None, argument_spec: dict[str, t.Any] | None = None,
supports_check_mode=False, supports_check_mode: bool = False,
mutually_exclusive=None, mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together=None, required_together: Sequence[Sequence[str]] | None = None,
required_if=None, required_if: (
required_one_of=None, Sequence[
required_by=None, tuple[str, t.Any, Sequence[str]]
min_docker_api_version=None, | tuple[str, t.Any, Sequence[str], bool]
option_minimal_versions=None, ]
option_minimal_versions_ignore_params=None, | None
fail_results=None, ) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
): ):
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
@ -570,12 +601,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec) merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec self.arg_spec = merged_arg_spec
mutually_exclusive_params = [] mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive: if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive mutually_exclusive_params += mutually_exclusive
required_together_params = [] required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together: if required_together:
required_together_params += required_together required_together_params += required_together
@ -600,20 +631,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params option_minimal_versions, option_minimal_versions_ignore_params
) )
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate( self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
return self.module.params return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): def _get_minimal_versions(
self.option_minimal_versions = {} self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec: for option in self.module.argument_spec:
if ignore_params is not None: if ignore_params is not None:
if option in ignore_params: if option in ignore_params:
@ -654,7 +695,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration." msg = f"Cannot {usg} with your configuration."
self.fail(msg) self.fail(msg)
def report_warnings(self, result, warnings_key=None): def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
""" """
Checks result of client operation for warnings, and if present, outputs them. Checks result of client operation for warnings, and if present, outputs them.

View File

@ -10,6 +10,7 @@ from __future__ import annotations
import abc import abc
import json import json
import shlex import shlex
import typing as t
from ansible.module_utils.basic import AnsibleModule, env_fallback from ansible.module_utils.basic import AnsibleModule, env_fallback
from ansible.module_utils.common.process import get_bin_path from ansible.module_utils.common.process import get_bin_path
@ -31,6 +32,10 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
) )
if t.TYPE_CHECKING:
from collections.abc import Mapping, Sequence
DOCKER_COMMON_ARGS = { DOCKER_COMMON_ARGS = {
"docker_cli": {"type": "path"}, "docker_cli": {"type": "path"},
"docker_host": { "docker_host": {
@ -72,10 +77,16 @@ class DockerException(Exception):
class AnsibleDockerClientBase: class AnsibleDockerClientBase:
docker_api_version_str: str | None
docker_api_version: LooseVersion | None
def __init__( def __init__(
self, common_args, min_docker_api_version=None, needs_api_version=True self,
): common_args,
self._environment = {} min_docker_api_version: str | None = None,
needs_api_version: bool = True,
) -> None:
self._environment: dict[str, str] = {}
if common_args["tls_hostname"]: if common_args["tls_hostname"]:
self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"] self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"]
if common_args["api_version"] and common_args["api_version"] != "auto": if common_args["api_version"] and common_args["api_version"] != "auto":
@ -109,10 +120,10 @@ class AnsibleDockerClientBase:
self._cli_base.extend(["--context", common_args["cli_context"]]) self._cli_base.extend(["--context", common_args["cli_context"]])
# `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0 # `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0
dummy, self._version, dummy = self.call_cli_json( dummy, self._version, dummy2 = self.call_cli_json(
"version", "--format", "{{ json . }}", check_rc=True "version", "--format", "{{ json . }}", check_rc=True
) )
self._info = None self._info: dict[str, t.Any] | None = None
if needs_api_version: if needs_api_version:
if not isinstance(self._version.get("Server"), dict) or not isinstance( if not isinstance(self._version.get("Server"), dict) or not isinstance(
@ -138,32 +149,47 @@ class AnsibleDockerClientBase:
"Internal error: cannot have needs_api_version=False with min_docker_api_version not None" "Internal error: cannot have needs_api_version=False with min_docker_api_version not None"
) )
def log(self, msg, pretty_print=False): def log(self, msg: str, pretty_print: bool = False):
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
def get_cli(self): def get_cli(self) -> str:
return self._cli return self._cli
def get_version_info(self): def get_version_info(self) -> str:
return self._version return self._version
def _compose_cmd(self, args): def _compose_cmd(self, args: t.Sequence[str]) -> list[str]:
return self._cli_base + list(args) return self._cli_base + list(args)
def _compose_cmd_str(self, args): def _compose_cmd_str(self, args: t.Sequence[str]) -> str:
return " ".join(shlex.quote(a) for a in self._compose_cmd(args)) return " ".join(shlex.quote(a) for a in self._compose_cmd(args))
@abc.abstractmethod @abc.abstractmethod
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
pass pass
# def call_cli_json(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): def call_cli_json(
def call_cli_json(self, *args, **kwargs): self,
warn_on_stderr = kwargs.pop("warn_on_stderr", False) *args: str,
rc, stdout, stderr = self.call_cli(*args, **kwargs) check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, t.Any, bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr: if warn_on_stderr and stderr:
self.warn(to_native(stderr)) self.warn(to_native(stderr))
try: try:
@ -174,10 +200,18 @@ class AnsibleDockerClientBase:
) )
return rc, data, stderr return rc, data, stderr
# def call_cli_json_stream(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): def call_cli_json_stream(
def call_cli_json_stream(self, *args, **kwargs): self,
warn_on_stderr = kwargs.pop("warn_on_stderr", False) *args: str,
rc, stdout, stderr = self.call_cli(*args, **kwargs) check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, list[t.Any], bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr: if warn_on_stderr and stderr:
self.warn(to_native(stderr)) self.warn(to_native(stderr))
result = [] result = []
@ -193,25 +227,31 @@ class AnsibleDockerClientBase:
return rc, result, stderr return rc, result, stderr
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs) -> t.NoReturn:
pass pass
@abc.abstractmethod @abc.abstractmethod
def warn(self, msg): def warn(self, msg: str) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass pass
def get_cli_info(self): def get_cli_info(self) -> dict[str, t.Any]:
if self._info is None: if self._info is None:
dummy, self._info, dummy = self.call_cli_json( dummy, self._info, dummy2 = self.call_cli_json(
"info", "--format", "{{ json . }}", check_rc=True "info", "--format", "{{ json . }}", check_rc=True
) )
return self._info return self._info
def get_client_plugin_info(self, component): def get_client_plugin_info(self, component: str) -> dict[str, t.Any] | None:
cli_info = self.get_cli_info() cli_info = self.get_cli_info()
if not isinstance(cli_info.get("ClientInfo"), dict): if not isinstance(cli_info.get("ClientInfo"), dict):
self.fail( self.fail(
@ -222,13 +262,13 @@ class AnsibleDockerClientBase:
return plugin return plugin
return None return None
def _image_lookup(self, name, tag): def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
""" """
Including a tag in the name parameter sent to the Docker SDK for Python images method Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check does not work consistently. Instead, get the result set for name and manually check
if the tag exists. if the tag exists.
""" """
dummy, images, dummy = self.call_cli_json_stream( dummy, images, dummy2 = self.call_cli_json_stream(
"image", "image",
"ls", "ls",
"--format", "--format",
@ -247,7 +287,13 @@ class AnsibleDockerClientBase:
break break
return images return images
def find_image(self, name, tag): @t.overload
def find_image(self, name: None, tag: str) -> None: ...
@t.overload
def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: ...
def find_image(self, name: str | None, tag: str) -> dict[str, t.Any] | None:
""" """
Lookup an image (by name and tag) and return the inspection results. Lookup an image (by name and tag) and return the inspection results.
""" """
@ -298,7 +344,19 @@ class AnsibleDockerClientBase:
self.log(f"Image {name}:{tag} not found.") self.log(f"Image {name}:{tag} not found.")
return None return None
def find_image_by_id(self, image_id, accept_missing_image=False): @t.overload
def find_image_by_id(
self, image_id: None, accept_missing_image: bool = False
) -> None: ...
@t.overload
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None: ...
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
""" """
Lookup an image (by ID) and return the inspection results. Lookup an image (by ID) and return the inspection results.
""" """
@ -320,17 +378,23 @@ class AnsibleDockerClientBase:
class AnsibleModuleDockerClient(AnsibleDockerClientBase): class AnsibleModuleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec=None, argument_spec: dict[str, t.Any] | None = None,
supports_check_mode=False, supports_check_mode: bool = False,
mutually_exclusive=None, mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together=None, required_together: Sequence[Sequence[str]] | None = None,
required_if=None, required_if: (
required_one_of=None, Sequence[
required_by=None, tuple[str, t.Any, Sequence[str]]
min_docker_api_version=None, | tuple[str, t.Any, Sequence[str], bool]
fail_results=None, ]
needs_api_version=True, | None
): ) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: Mapping[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
fail_results: dict[str, t.Any] | None = None,
needs_api_version: bool = True,
) -> None:
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
# in case client.fail() is called. # in case client.fail() is called.
@ -342,12 +406,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec) merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec self.arg_spec = merged_arg_spec
mutually_exclusive_params = [("docker_host", "cli_context")] mutually_exclusive_params: list[Sequence[str]] = [
("docker_host", "cli_context")
]
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive: if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive mutually_exclusive_params += mutually_exclusive
required_together_params = [] required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together: if required_together:
required_together_params += required_together required_together_params += required_together
@ -373,7 +439,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
needs_api_version=needs_api_version, needs_api_version=needs_api_version,
) )
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
environment = self._environment.copy() environment = self._environment.copy()
if environ_update: if environ_update:
environment.update(environ_update) environment.update(environ_update)
@ -390,14 +463,20 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
) )
return rc, stdout, stderr return rc, stdout, stderr
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def warn(self, msg): def warn(self, msg: str) -> None:
self.module.warn(msg) self.module.warn(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate( self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )

View File

@ -51,6 +51,13 @@ else:
HAS_PYYAML = True HAS_PYYAML = True
PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ansible_collections.community.docker.plugins.module_utils._common_cli import (
AnsibleModuleDockerClient as _Client,
)
DOCKER_COMPOSE_FILES = ( DOCKER_COMPOSE_FILES = (
"compose.yaml", "compose.yaml",
@ -241,7 +248,9 @@ _RE_BUILD_PROGRESS_EVENT = re.compile(r"^\s*==>\s+(?P<msg>.*)$")
MINIMUM_COMPOSE_VERSION = "2.18.0" MINIMUM_COMPOSE_VERSION = "2.18.0"
def _extract_event(line, warn_function=None): def _extract_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
match = _RE_RESOURCE_EVENT.match(line) match = _RE_RESOURCE_EVENT.match(line)
if match is not None: if match is not None:
status = match.group("status") status = match.group("status")
@ -324,7 +333,9 @@ def _extract_event(line, warn_function=None):
return None, False return None, False
def _extract_logfmt_event(line, warn_function=None): def _extract_logfmt_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
try: try:
result = _parse_logfmt_line(line, logrus_mode=True) result = _parse_logfmt_line(line, logrus_mode=True)
except _InvalidLogFmt: except _InvalidLogFmt:
@ -339,7 +350,11 @@ def _extract_logfmt_event(line, warn_function=None):
return None, False return None, False
def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_function): def _warn_missing_dry_run_prefix(
line: str,
warn_missing_dry_run_prefix: bool,
warn_function: Callable[[str], None] | None,
) -> None:
if warn_missing_dry_run_prefix and warn_function: if warn_missing_dry_run_prefix and warn_function:
# This could be a bug, a change of docker compose's output format, ... # This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-) # Tell the user to report it to us :-)
@ -350,7 +365,9 @@ def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_functio
) )
def _warn_unparsable_line(line, warn_function): def _warn_unparsable_line(
line: str, warn_function: Callable[[str], None] | None
) -> None:
# This could be a bug, a change of docker compose's output format, ... # This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-) # Tell the user to report it to us :-)
if warn_function: if warn_function:
@ -361,14 +378,16 @@ def _warn_unparsable_line(line, warn_function):
) )
def _find_last_event_for(events, resource_id): def _find_last_event_for(
events: list[Event], resource_id: str
) -> tuple[int, Event] | None:
for index, event in enumerate(reversed(events)): for index, event in enumerate(reversed(events)):
if event.resource_id == resource_id: if event.resource_id == resource_id:
return len(events) - 1 - index, event return len(events) - 1 - index, event
return None return None
def _concat_event_msg(event, append_msg): def _concat_event_msg(event: Event, append_msg: str) -> Event:
return Event( return Event(
event.resource_type, event.resource_type,
event.resource_id, event.resource_id,
@ -383,7 +402,9 @@ _JSON_LEVEL_TO_STATUS_MAP = {
} }
def parse_json_events(stderr, warn_function=None): def parse_json_events(
stderr: bytes, warn_function: Callable[[str], None] | None = None
) -> list[Event]:
events = [] events = []
stderr_lines = stderr.splitlines() stderr_lines = stderr.splitlines()
if stderr_lines and stderr_lines[-1] == b"": if stderr_lines and stderr_lines[-1] == b"":
@ -524,7 +545,12 @@ def parse_json_events(stderr, warn_function=None):
return events return events
def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False): def parse_events(
stderr: bytes,
dry_run: bool = False,
warn_function: Callable[[str], None] | None = None,
nonzero_rc: bool = False,
) -> list[Event]:
events = [] events = []
error_event = None error_event = None
stderr_lines = stderr.splitlines() stderr_lines = stderr.splitlines()
@ -598,7 +624,11 @@ def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False):
return events return events
def has_changes(events, ignore_service_pull_events=False, ignore_build_events=False): def has_changes(
events: Sequence[Event],
ignore_service_pull_events: bool = False,
ignore_build_events: bool = False,
) -> bool:
for event in events: for event in events:
if event.status in DOCKER_STATUS_WORKING: if event.status in DOCKER_STATUS_WORKING:
if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL: if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL:
@ -614,7 +644,7 @@ def has_changes(events, ignore_service_pull_events=False, ignore_build_events=Fa
return False return False
def extract_actions(events): def extract_actions(events: Sequence[Event]) -> list[dict[str, t.Any]]:
actions = [] actions = []
pull_actions = set() pull_actions = set()
for event in events: for event in events:
@ -646,7 +676,9 @@ def extract_actions(events):
return actions return actions
def emit_warnings(events, warn_function): def emit_warnings(
events: Sequence[Event], warn_function: Callable[[str], None]
) -> None:
for event in events: for event in events:
# If a message is present, assume it is a warning # If a message is present, assume it is a warning
if ( if (
@ -657,13 +689,21 @@ def emit_warnings(events, warn_function):
) )
def is_failed(events, rc): def is_failed(events: Sequence[Event], rc: int) -> bool:
if rc: if rc:
return True return True
return False return False
def update_failed(result, events, args, stdout, stderr, rc, cli): def update_failed(
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: str | bytes,
rc: int,
cli: str,
) -> bool:
if not rc: if not rc:
return False return False
errors = [] errors = []
@ -697,7 +737,7 @@ def update_failed(result, events, args, stdout, stderr, rc, cli):
return True return True
def common_compose_argspec(): def common_compose_argspec() -> dict[str, t.Any]:
return { return {
"project_src": {"type": "path"}, "project_src": {"type": "path"},
"project_name": {"type": "str"}, "project_name": {"type": "str"},
@ -709,7 +749,7 @@ def common_compose_argspec():
} }
def common_compose_argspec_ex(): def common_compose_argspec_ex() -> dict[str, t.Any]:
return { return {
"argspec": common_compose_argspec(), "argspec": common_compose_argspec(),
"mutually_exclusive": [("definition", "project_src"), ("definition", "files")], "mutually_exclusive": [("definition", "project_src"), ("definition", "files")],
@ -722,16 +762,18 @@ def common_compose_argspec_ex():
} }
def combine_binary_output(*outputs): def combine_binary_output(*outputs: bytes | None) -> bytes:
return b"\n".join(out for out in outputs if out) return b"\n".join(out for out in outputs if out)
def combine_text_output(*outputs): def combine_text_output(*outputs: str | None) -> str:
return "\n".join(out for out in outputs if out) return "\n".join(out for out in outputs if out)
class BaseComposeManager(DockerBaseClass): class BaseComposeManager(DockerBaseClass):
def __init__(self, client, min_version=MINIMUM_COMPOSE_VERSION): def __init__(
self, client: _Client, min_version: str = MINIMUM_COMPOSE_VERSION
) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
@ -795,12 +837,12 @@ class BaseComposeManager(DockerBaseClass):
# more precisely in https://github.com/docker/compose/pull/11478 # more precisely in https://github.com/docker/compose/pull/11478
self.use_json_events = self.compose_version >= LooseVersion("2.29.0") self.use_json_events = self.compose_version >= LooseVersion("2.29.0")
def get_compose_version(self): def get_compose_version(self) -> str:
return ( return (
self.get_compose_version_from_cli() or self.get_compose_version_from_api() self.get_compose_version_from_cli() or self.get_compose_version_from_api()
) )
def get_compose_version_from_cli(self): def get_compose_version_from_cli(self) -> str | None:
rc, version_info, dummy_stderr = self.client.call_cli( rc, version_info, dummy_stderr = self.client.call_cli(
"compose", "version", "--format", "json" "compose", "version", "--format", "json"
) )
@ -814,7 +856,7 @@ class BaseComposeManager(DockerBaseClass):
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
return None return None
def get_compose_version_from_api(self): def get_compose_version_from_api(self) -> str:
compose = self.client.get_client_plugin_info("compose") compose = self.client.get_client_plugin_info("compose")
if compose is None: if compose is None:
self.fail( self.fail(
@ -827,11 +869,11 @@ class BaseComposeManager(DockerBaseClass):
) )
return compose["Version"].lstrip("v") return compose["Version"].lstrip("v")
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.cleanup() self.cleanup()
self.client.fail(msg, **kwargs) self.client.fail(msg, **kwargs)
def get_base_args(self, plain_progress=False): def get_base_args(self, plain_progress: bool = False) -> list[str]:
args = ["compose", "--ansi", "never"] args = ["compose", "--ansi", "never"]
if self.use_json_events and not plain_progress: if self.use_json_events and not plain_progress:
args.extend(["--progress", "json"]) args.extend(["--progress", "json"])
@ -849,28 +891,33 @@ class BaseComposeManager(DockerBaseClass):
args.extend(["--profile", profile]) args.extend(["--profile", profile])
return args return args
def _handle_failed_cli_call(self, args, rc, stdout, stderr): def _handle_failed_cli_call(
self, args: list[str], rc: int, stdout: str | bytes, stderr: bytes
) -> t.NoReturn:
events = parse_json_events(stderr, warn_function=self.client.warn) events = parse_json_events(stderr, warn_function=self.client.warn)
result = {} result: dict[str, t.Any] = {}
self.update_failed(result, events, args, stdout, stderr, rc) self.update_failed(result, events, args, stdout, stderr, rc)
self.client.module.exit_json(**result) self.client.module.exit_json(**result)
def list_containers_raw(self): def list_containers_raw(self) -> list[dict[str, t.Any]]:
args = self.get_base_args() + ["ps", "--format", "json", "--all"] args = self.get_base_args() + ["ps", "--format", "json", "--all"]
if self.compose_version >= LooseVersion("2.23.0"): if self.compose_version >= LooseVersion("2.23.0"):
# https://github.com/docker/compose/pull/11038 # https://github.com/docker/compose/pull/11038
args.append("--no-trunc") args.append("--no-trunc")
kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events}
if self.compose_version >= LooseVersion("2.21.0"): if self.compose_version >= LooseVersion("2.21.0"):
# Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918 # Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918
rc, containers, stderr = self.client.call_cli_json_stream(*args, **kwargs) rc, containers, stderr = self.client.call_cli_json_stream(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
else: else:
rc, containers, stderr = self.client.call_cli_json(*args, **kwargs) rc, containers, stderr = self.client.call_cli_json(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
if self.use_json_events and rc != 0: if self.use_json_events and rc != 0:
self._handle_failed_cli_call(args, rc, containers, stderr) self._handle_failed_cli_call(args, rc, json.dumps(containers), stderr)
return containers return containers
def list_containers(self): def list_containers(self) -> list[dict[str, t.Any]]:
result = [] result = []
for container in self.list_containers_raw(): for container in self.list_containers_raw():
labels = {} labels = {}
@ -887,10 +934,11 @@ class BaseComposeManager(DockerBaseClass):
result.append(container) result.append(container)
return result return result
def list_images(self): def list_images(self) -> list[str]:
args = self.get_base_args() + ["images", "--format", "json"] args = self.get_base_args() + ["images", "--format", "json"]
kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events} rc, images, stderr = self.client.call_cli_json(
rc, images, stderr = self.client.call_cli_json(*args, **kwargs) *args, cwd=self.project_src, check_rc=not self.use_json_events
)
if self.use_json_events and rc != 0: if self.use_json_events and rc != 0:
self._handle_failed_cli_call(args, rc, images, stderr) self._handle_failed_cli_call(args, rc, images, stderr)
if isinstance(images, dict): if isinstance(images, dict):
@ -900,7 +948,9 @@ class BaseComposeManager(DockerBaseClass):
images = list(images.values()) images = list(images.values())
return images return images
def parse_events(self, stderr, dry_run=False, nonzero_rc=False): def parse_events(
self, stderr: bytes, dry_run: bool = False, nonzero_rc: bool = False
) -> list[Event]:
if self.use_json_events: if self.use_json_events:
return parse_json_events(stderr, warn_function=self.client.warn) return parse_json_events(stderr, warn_function=self.client.warn)
return parse_events( return parse_events(
@ -910,17 +960,17 @@ class BaseComposeManager(DockerBaseClass):
nonzero_rc=nonzero_rc, nonzero_rc=nonzero_rc,
) )
def emit_warnings(self, events): def emit_warnings(self, events: Sequence[Event]) -> None:
emit_warnings(events, warn_function=self.client.warn) emit_warnings(events, warn_function=self.client.warn)
def update_result( def update_result(
self, self,
result, result: dict[str, t.Any],
events, events: Sequence[Event],
stdout, stdout: str,
stderr, stderr: str,
ignore_service_pull_events=False, ignore_service_pull_events: bool = False,
ignore_build_events=False, ignore_build_events: bool = False,
): ):
result["changed"] = result.get("changed", False) or has_changes( result["changed"] = result.get("changed", False) or has_changes(
events, events,
@ -931,7 +981,15 @@ class BaseComposeManager(DockerBaseClass):
result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout)) result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout))
result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr)) result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr))
def update_failed(self, result, events, args, stdout, stderr, rc): def update_failed(
self,
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: bytes,
rc: int,
):
return update_failed( return update_failed(
result, result,
events, events,
@ -942,14 +1000,14 @@ class BaseComposeManager(DockerBaseClass):
cli=self.client.get_cli(), cli=self.client.get_cli(),
) )
def cleanup_result(self, result): def cleanup_result(self, result: dict[str, t.Any]) -> None:
if not result.get("failed"): if not result.get("failed"):
# Only return stdout and stderr if it is not empty # Only return stdout and stderr if it is not empty
for res in ("stdout", "stderr"): for res in ("stdout", "stderr"):
if result.get(res) == "": if result.get(res) == "":
result.pop(res) result.pop(res)
def cleanup(self): def cleanup(self) -> None:
for directory in self.cleanup_dirs: for directory in self.cleanup_dirs:
try: try:
shutil.rmtree(directory, True) shutil.rmtree(directory, True)

View File

@ -16,6 +16,7 @@ import os.path
import shutil import shutil
import stat import stat
import tarfile import tarfile
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@ -25,6 +26,16 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
) )
if t.TYPE_CHECKING:
from collections.abc import Callable
from _typeshed import WriteableBuffer
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class DockerFileCopyError(Exception): class DockerFileCopyError(Exception):
pass pass
@ -37,7 +48,9 @@ class DockerFileNotFound(DockerFileCopyError):
pass pass
def _put_archive(client, container, path, data): def _put_archive(
client: APIClient, container: str, path: str, data: bytes | t.Generator[bytes]
) -> bool:
# data can also be file object for streaming. This is because _put uses requests's put(). # data can also be file object for streaming. This is because _put uses requests's put().
# See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads # See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads
url = client._url("/containers/{0}/archive", container) url = client._url("/containers/{0}/archive", container)
@ -47,8 +60,14 @@ def _put_archive(client, container, path, data):
def _symlink_tar_creator( def _symlink_tar_creator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None b_in_path: bytes,
): file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> bytes:
if not stat.S_ISLNK(file_stat.st_mode): if not stat.S_ISLNK(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a symlink") raise DockerUnexpectedError("stat information is not for a symlink")
bio = io.BytesIO() bio = io.BytesIO()
@ -75,16 +94,28 @@ def _symlink_tar_creator(
def _symlink_tar_generator( def _symlink_tar_generator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None b_in_path: bytes,
): file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> t.Generator[bytes]:
yield _symlink_tar_creator( yield _symlink_tar_creator(
b_in_path, file_stat, out_file, user_id, group_id, mode, user_name b_in_path, file_stat, out_file, user_id, group_id, mode, user_name
) )
def _regular_file_tar_generator( def _regular_file_tar_generator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None b_in_path: bytes,
): file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> t.Generator[bytes]:
if not stat.S_ISREG(file_stat.st_mode): if not stat.S_ISREG(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a regular file") raise DockerUnexpectedError("stat information is not for a regular file")
tarinfo = tarfile.TarInfo() tarinfo = tarfile.TarInfo()
@ -136,8 +167,13 @@ def _regular_file_tar_generator(
def _regular_content_tar_generator( def _regular_content_tar_generator(
content, out_file, user_id, group_id, mode, user_name=None content: bytes,
): out_file: str | bytes,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> t.Generator[bytes]:
tarinfo = tarfile.TarInfo() tarinfo = tarfile.TarInfo()
tarinfo.name = ( tarinfo.name = (
os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/") os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/")
@ -175,16 +211,16 @@ def _regular_content_tar_generator(
def put_file( def put_file(
client, client: APIClient,
container, container: str,
in_path, in_path: str,
out_path, out_path: str,
user_id, user_id: int,
group_id, group_id: int,
mode=None, mode: int | None = None,
user_name=None, user_name: str | None = None,
follow_links=False, follow_links: bool = False,
): ) -> None:
"""Transfer a file from local to Docker container.""" """Transfer a file from local to Docker container."""
if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")): if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")):
raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}") raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}")
@ -232,8 +268,15 @@ def put_file(
def put_file_content( def put_file_content(
client, container, content, out_path, user_id, group_id, mode, user_name=None client: APIClient,
): container: str,
content: bytes,
out_path: str,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> None:
"""Transfer a file from local to Docker container.""" """Transfer a file from local to Docker container."""
out_dir, out_file = os.path.split(out_path) out_dir, out_file = os.path.split(out_path)
@ -248,7 +291,13 @@ def put_file_content(
) )
def stat_file(client, container, in_path, follow_links=False, log=None): def stat_file(
client: APIClient,
container: str,
in_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> tuple[str | bytes, dict[str, t.Any] | None, str | None]:
"""Fetch information on a file from a Docker container to local. """Fetch information on a file from a Docker container to local.
Return a tuple ``(path, stat_data, link_target)`` where: Return a tuple ``(path, stat_data, link_target)`` where:
@ -265,12 +314,12 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
while True: while True:
if in_path in considered_in_paths: if in_path in considered_in_paths:
raise DockerFileCopyError( raise DockerFileCopyError(
f'Found infinite symbolic link loop when trying to stating "{in_path}"' f"Found infinite symbolic link loop when trying to stating {in_path!r}"
) )
considered_in_paths.add(in_path) considered_in_paths.add(in_path)
if log: if log:
log(f'FETCH: Stating "{in_path}"') log(f"FETCH: Stating {in_path!r}")
response = client._head( response = client._head(
client._url("/containers/{0}/archive", container), client._url("/containers/{0}/archive", container),
@ -299,24 +348,24 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
class _RawGeneratorFileobj(io.RawIOBase): class _RawGeneratorFileobj(io.RawIOBase):
def __init__(self, stream): def __init__(self, stream: t.Generator[bytes]):
self._stream = stream self._stream = stream
self._buf = b"" self._buf = b""
def readable(self): def readable(self) -> bool:
return True return True
def _readinto_from_buf(self, b, index, length): def _readinto_from_buf(self, b: WriteableBuffer, index: int, length: int) -> int:
cpy = min(length - index, len(self._buf)) cpy = min(length - index, len(self._buf))
if cpy: if cpy:
b[index : index + cpy] = self._buf[:cpy] b[index : index + cpy] = self._buf[:cpy] # type: ignore # TODO!
self._buf = self._buf[cpy:] self._buf = self._buf[cpy:]
index += cpy index += cpy
return index return index
def readinto(self, b): def readinto(self, b: WriteableBuffer) -> int:
index = 0 index = 0
length = len(b) length = len(b) # type: ignore # TODO!
index = self._readinto_from_buf(b, index, length) index = self._readinto_from_buf(b, index, length)
if index == length: if index == length:
@ -330,25 +379,28 @@ class _RawGeneratorFileobj(io.RawIOBase):
return self._readinto_from_buf(b, index, length) return self._readinto_from_buf(b, index, length)
def _stream_generator_to_fileobj(stream): def _stream_generator_to_fileobj(stream: t.Generator[bytes]) -> io.BufferedReader:
"""Given a generator that generates chunks of bytes, create a readable buffered stream.""" """Given a generator that generates chunks of bytes, create a readable buffered stream."""
raw = _RawGeneratorFileobj(stream) raw = _RawGeneratorFileobj(stream)
return io.BufferedReader(raw) return io.BufferedReader(raw)
_T = t.TypeVar("_T")
def fetch_file_ex( def fetch_file_ex(
client, client: APIClient,
container, container: str,
in_path, in_path: str,
process_none, process_none: Callable[[str], _T],
process_regular, process_regular: Callable[[str, tarfile.TarFile, tarfile.TarInfo], _T],
process_symlink, process_symlink: Callable[[str, tarfile.TarInfo], _T],
process_other, process_other: Callable[[str, tarfile.TarInfo], _T],
follow_links=False, follow_links: bool = False,
log=None, log: Callable[[str], None] | None = None,
): ) -> _T:
"""Fetch a file (as a tar file entry) from a Docker container to local.""" """Fetch a file (as a tar file entry) from a Docker container to local."""
considered_in_paths = set() considered_in_paths: set[str] = set()
while True: while True:
if in_path in considered_in_paths: if in_path in considered_in_paths:
@ -372,8 +424,8 @@ def fetch_file_ex(
with tarfile.open( with tarfile.open(
fileobj=_stream_generator_to_fileobj(stream), mode="r|" fileobj=_stream_generator_to_fileobj(stream), mode="r|"
) as tar: ) as tar:
symlink_member = None symlink_member: tarfile.TarInfo | None = None
result = None result: _T | None = None
found = False found = False
for member in tar: for member in tar:
if found: if found:
@ -398,35 +450,46 @@ def fetch_file_ex(
log(f'FETCH: Following symbolic link to "{in_path}"') log(f'FETCH: Following symbolic link to "{in_path}"')
continue continue
if found: if found:
return result return result # type: ignore
raise DockerUnexpectedError("Received tarfile is empty!") raise DockerUnexpectedError("Received tarfile is empty!")
def fetch_file(client, container, in_path, out_path, follow_links=False, log=None): def fetch_file(
client: APIClient,
container: str,
in_path: str,
out_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> str:
b_out_path = to_bytes(out_path, errors="surrogate_or_strict") b_out_path = to_bytes(out_path, errors="surrogate_or_strict")
def process_none(in_path): def process_none(in_path: str) -> str:
raise DockerFileNotFound( raise DockerFileNotFound(
f"File {in_path} does not exist in container {container}" f"File {in_path} does not exist in container {container}"
) )
def process_regular(in_path, tar, member): def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> str:
if not follow_links and os.path.exists(b_out_path): if not follow_links and os.path.exists(b_out_path):
os.unlink(b_out_path) os.unlink(b_out_path)
with tar.extractfile(member) as in_f: reader = tar.extractfile(member)
with open(b_out_path, "wb") as out_f: if reader:
shutil.copyfileobj(in_f, out_f) with reader as in_f:
with open(b_out_path, "wb") as out_f:
shutil.copyfileobj(in_f, out_f)
return in_path return in_path
def process_symlink(in_path, member): def process_symlink(in_path, member) -> str:
if os.path.exists(b_out_path): if os.path.exists(b_out_path):
os.unlink(b_out_path) os.unlink(b_out_path)
os.symlink(member.linkname, b_out_path) os.symlink(member.linkname, b_out_path)
return in_path return in_path
def process_other(in_path, member): def process_other(in_path, member) -> str:
raise DockerFileCopyError( raise DockerFileCopyError(
f'Remote file "{in_path}" is not a regular file or a symbolic link' f'Remote file "{in_path}" is not a regular file or a symbolic link'
) )
@ -444,7 +507,13 @@ def fetch_file(client, container, in_path, out_path, follow_links=False, log=Non
) )
def _execute_command(client, container, command, log=None, check_rc=False): def _execute_command(
client: APIClient,
container: str,
command: list[str],
log: Callable[[str], None] | None = None,
check_rc: bool = False,
) -> tuple[int, bytes, bytes]:
if log: if log:
log(f"Executing {command} in {container}") log(f"Executing {command} in {container}")
@ -493,13 +562,15 @@ def _execute_command(client, container, command, log=None, check_rc=False):
if check_rc and rc != 0: if check_rc and rc != 0:
command_str = " ".join(command) command_str = " ".join(command)
raise DockerUnexpectedError( raise DockerUnexpectedError(
f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout}\nSTDERR: {stderr}' f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout!r}\nSTDERR: {stderr!r}'
) )
return rc, stdout, stderr return rc, stdout, stderr
def determine_user_group(client, container, log=None): def determine_user_group(
client: APIClient, container: str, log: Callable[[str], None] | None = None
) -> tuple[int, int]:
dummy_rc, stdout, dummy_stderr = _execute_command( dummy_rc, stdout, dummy_stderr = _execute_command(
client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log
) )
@ -507,7 +578,7 @@ def determine_user_group(client, container, log=None):
stdout_lines = stdout.splitlines() stdout_lines = stdout.splitlines()
if len(stdout_lines) != 2: if len(stdout_lines) != 2:
raise DockerUnexpectedError( raise DockerUnexpectedError(
f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout}" f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout!r}"
) )
user_id, group_id = stdout_lines user_id, group_id = stdout_lines
@ -515,5 +586,5 @@ def determine_user_group(client, container, log=None):
return int(user_id), int(group_id) return int(user_id), int(group_id)
except ValueError as exc: except ValueError as exc:
raise DockerUnexpectedError( raise DockerUnexpectedError(
f'Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got "{user_id}" and "{group_id}" instead' f"Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got {user_id!r} and {group_id!r} instead"
) from exc ) from exc

View File

@ -18,12 +18,10 @@ class ImageArchiveManifestSummary:
"docker image save some:tag > some.tar" command. "docker image save some:tag > some.tar" command.
""" """
def __init__(self, image_id, repo_tags): def __init__(self, image_id: str, repo_tags: list[str]) -> None:
""" """
:param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json :param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json
:type image_id: str
:param repo_tags Docker image names, e.g. ["hello-world:latest"] :param repo_tags Docker image names, e.g. ["hello-world:latest"]
:type repo_tags: list[str]
""" """
self.image_id = image_id self.image_id = image_id
@ -34,22 +32,21 @@ class ImageArchiveInvalidException(Exception):
pass pass
def api_image_id(archive_image_id): def api_image_id(archive_image_id: str) -> str:
""" """
Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier
that represents the same image hash, but in the format presented by the Docker Engine API. that represents the same image hash, but in the format presented by the Docker Engine API.
:param archive_image_id: plain image hash :param archive_image_id: plain image hash
:type archive_image_id: str
:returns: Prefixed hash used by REST api :returns: Prefixed hash used by REST api
:rtype: str
""" """
return f"sha256:{archive_image_id}" return f"sha256:{archive_image_id}"
def load_archived_image_manifest(archive_path): def load_archived_image_manifest(
archive_path: str,
) -> list[ImageArchiveManifestSummary] | None:
""" """
Attempts to get image IDs and image names from metadata stored in the image Attempts to get image IDs and image names from metadata stored in the image
archive tar file. archive tar file.
@ -62,10 +59,7 @@ def load_archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read :param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects. :return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects.
:rtype: ImageArchiveManifestSummary
""" """
try: try:
@ -76,8 +70,15 @@ def load_archived_image_manifest(archive_path):
with tarfile.open(archive_path, "r") as tf: with tarfile.open(archive_path, "r") as tf:
try: try:
try: try:
with tf.extractfile("manifest.json") as ef: reader = tf.extractfile("manifest.json")
if reader is None:
raise ImageArchiveInvalidException(
"Failed to read manifest.json"
)
with reader as ef:
manifest = json.load(ef) manifest = json.load(ef)
except ImageArchiveInvalidException:
raise
except Exception as exc: except Exception as exc:
raise ImageArchiveInvalidException( raise ImageArchiveInvalidException(
f"Failed to decode and deserialize manifest.json: {exc}" f"Failed to decode and deserialize manifest.json: {exc}"
@ -139,7 +140,7 @@ def load_archived_image_manifest(archive_path):
) from exc ) from exc
def archived_image_manifest(archive_path): def archived_image_manifest(archive_path: str) -> ImageArchiveManifestSummary | None:
""" """
Attempts to get Image.Id and image name from metadata stored in the image Attempts to get Image.Id and image name from metadata stored in the image
archive tar file. archive tar file.
@ -152,10 +153,7 @@ def archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read :param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix. :return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix.
:rtype: ImageArchiveManifestSummary
""" """
results = load_archived_image_manifest(archive_path) results = load_archived_image_manifest(archive_path)

View File

@ -13,6 +13,9 @@ See https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc for information on
from __future__ import annotations from __future__ import annotations
import typing as t
from enum import Enum
# The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc # The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc
# (look for "EBNFish") # (look for "EBNFish")
@ -22,7 +25,7 @@ class InvalidLogFmt(Exception):
pass pass
class _Mode: class _Mode(Enum):
GARBAGE = 0 GARBAGE = 0
KEY = 1 KEY = 1
EQUAL = 2 EQUAL = 2
@ -68,29 +71,29 @@ _HEX_DICT = {
} }
def _is_ident(cur): def _is_ident(cur: str) -> bool:
return cur > " " and cur not in ('"', "=") return cur > " " and cur not in ('"', "=")
class _Parser: class _Parser:
def __init__(self, line): def __init__(self, line: str) -> None:
self.line = line self.line = line
self.index = 0 self.index = 0
self.length = len(line) self.length = len(line)
def done(self): def done(self) -> bool:
return self.index >= self.length return self.index >= self.length
def cur(self): def cur(self) -> str:
return self.line[self.index] return self.line[self.index]
def next(self): def next(self) -> None:
self.index += 1 self.index += 1
def prev(self): def prev(self) -> None:
self.index -= 1 self.index -= 1
def parse_unicode_sequence(self): def parse_unicode_sequence(self) -> str:
if self.index + 6 > self.length: if self.index + 6 > self.length:
raise InvalidLogFmt("Not enough space for unicode escape") raise InvalidLogFmt("Not enough space for unicode escape")
if self.line[self.index : self.index + 2] != "\\u": if self.line[self.index : self.index + 2] != "\\u":
@ -108,27 +111,27 @@ class _Parser:
return chr(v) return chr(v)
def parse_line(line, logrus_mode=False): def parse_line(line: str, logrus_mode: bool = False) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
parser = _Parser(line) parser = _Parser(line)
key = [] key: list[str] = []
value = [] value: list[str] = []
mode = _Mode.GARBAGE mode = _Mode.GARBAGE
def handle_kv(has_no_value=False): def handle_kv(has_no_value: bool = False) -> None:
k = "".join(key) k = "".join(key)
v = None if has_no_value else "".join(value) v = None if has_no_value else "".join(value)
result[k] = v result[k] = v
del key[:] del key[:]
del value[:] del value[:]
def parse_garbage(cur): def parse_garbage(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
return _Mode.KEY return _Mode.KEY
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_key(cur): def parse_key(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
key.append(cur) key.append(cur)
parser.next() parser.next()
@ -142,7 +145,7 @@ def parse_line(line, logrus_mode=False):
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_equal(cur): def parse_equal(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
value.append(cur) value.append(cur)
parser.next() parser.next()
@ -154,7 +157,7 @@ def parse_line(line, logrus_mode=False):
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_ident_value(cur): def parse_ident_value(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
value.append(cur) value.append(cur)
parser.next() parser.next()
@ -163,7 +166,7 @@ def parse_line(line, logrus_mode=False):
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_quoted_value(cur): def parse_quoted_value(cur: str) -> _Mode:
if cur == "\\": if cur == "\\":
parser.next() parser.next()
if parser.done(): if parser.done():

View File

@ -14,12 +14,13 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
_VALID_STR = re.compile("^[A-Za-z0-9_-]+$") _VALID_STR = re.compile("^[A-Za-z0-9_-]+$")
def _validate_part(string, part, part_name): def _validate_part(string: str, part: str, part_name: str) -> str:
if not part: if not part:
raise ValueError(f'Invalid platform string "{string}": {part_name} is empty') raise ValueError(f'Invalid platform string "{string}": {part_name} is empty')
if not _VALID_STR.match(part): if not _VALID_STR.match(part):
@ -79,7 +80,7 @@ _KNOWN_ARCH = (
) )
def _normalize_os(os_str): def _normalize_os(os_str: str) -> str:
# See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go # See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go
os_str = os_str.lower() os_str = os_str.lower()
if os_str == "macos": if os_str == "macos":
@ -112,7 +113,7 @@ _NORMALIZE_ARCH = {
} }
def _normalize_arch(arch_str, variant_str): def _normalize_arch(arch_str: str, variant_str: str) -> tuple[str, str]:
# See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go # See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go
arch_str = arch_str.lower() arch_str = arch_str.lower()
variant_str = variant_str.lower() variant_str = variant_str.lower()
@ -121,15 +122,16 @@ def _normalize_arch(arch_str, variant_str):
res = _NORMALIZE_ARCH.get((arch_str, None)) res = _NORMALIZE_ARCH.get((arch_str, None))
if res is None: if res is None:
return arch_str, variant_str return arch_str, variant_str
if res is not None: arch_str = res[0]
arch_str = res[0] if res[1] is not None:
if res[1] is not None: variant_str = res[1]
variant_str = res[1] return arch_str, variant_str
return arch_str, variant_str
class _Platform: class _Platform:
def __init__(self, os=None, arch=None, variant=None): def __init__(
self, os: str | None = None, arch: str | None = None, variant: str | None = None
) -> None:
self.os = os self.os = os
self.arch = arch self.arch = arch
self.variant = variant self.variant = variant
@ -140,7 +142,12 @@ class _Platform:
raise ValueError("If variant is given, os must be given too") raise ValueError("If variant is given, os must be given too")
@classmethod @classmethod
def parse_platform_string(cls, string, daemon_os=None, daemon_arch=None): def parse_platform_string(
cls,
string: str | None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> t.Self:
# See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go # See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go
if string is None: if string is None:
return cls() return cls()
@ -182,6 +189,7 @@ class _Platform:
) )
if variant is not None and not variant: if variant is not None and not variant:
raise ValueError(f'Invalid platform string "{string}": variant is empty') raise ValueError(f'Invalid platform string "{string}": variant is empty')
assert arch is not None # otherwise variant would be None as well
arch, variant = _normalize_arch(arch, variant or "") arch, variant = _normalize_arch(arch, variant or "")
if len(parts) == 2 and arch == "arm" and variant == "v7": if len(parts) == 2 and arch == "arm" and variant == "v7":
variant = None variant = None
@ -189,9 +197,12 @@ class _Platform:
variant = "v8" variant = "v8"
return cls(os=_normalize_os(os), arch=arch, variant=variant or None) return cls(os=_normalize_os(os), arch=arch, variant=variant or None)
def __str__(self): def __str__(self) -> str:
if self.variant: if self.variant:
parts = [self.os, self.arch, self.variant] assert (
self.os is not None and self.arch is not None
) # ensured in constructor
parts: list[str] = [self.os, self.arch, self.variant]
elif self.os: elif self.os:
if self.arch: if self.arch:
parts = [self.os, self.arch] parts = [self.os, self.arch]
@ -203,12 +214,14 @@ class _Platform:
parts = [] parts = []
return "/".join(parts) return "/".join(parts)
def __repr__(self): def __repr__(self) -> str:
return ( return (
f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})" f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})"
) )
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, _Platform):
return NotImplemented
return ( return (
self.os == other.os self.os == other.os
and self.arch == other.arch and self.arch == other.arch
@ -216,7 +229,9 @@ class _Platform:
) )
def normalize_platform_string(string, daemon_os=None, daemon_arch=None): def normalize_platform_string(
string: str, daemon_os: str | None = None, daemon_arch: str | None = None
) -> str:
return str( return str(
_Platform.parse_platform_string( _Platform.parse_platform_string(
string, daemon_os=daemon_os, daemon_arch=daemon_arch string, daemon_os=daemon_os, daemon_arch=daemon_arch
@ -225,8 +240,12 @@ def normalize_platform_string(string, daemon_os=None, daemon_arch=None):
def compose_platform_string( def compose_platform_string(
os=None, arch=None, variant=None, daemon_os=None, daemon_arch=None os: str | None = None,
): arch: str | None = None,
variant: str | None = None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> str:
if os is None and daemon_os is not None: if os is None and daemon_os is not None:
os = _normalize_os(daemon_os) os = _normalize_os(daemon_os)
if arch is None and daemon_arch is not None: if arch is None and daemon_arch is not None:
@ -235,7 +254,7 @@ def compose_platform_string(
return str(_Platform(os=os, arch=arch, variant=variant or None)) return str(_Platform(os=os, arch=arch, variant=variant or None))
def compare_platform_strings(string1, string2): def compare_platform_strings(string1: str, string2: str) -> bool:
return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string( return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string(
string2 string2
) )

View File

@ -13,7 +13,7 @@ import random
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
def generate_insecure_key(): def generate_insecure_key() -> bytes:
"""Do NOT use this for cryptographic purposes!""" """Do NOT use this for cryptographic purposes!"""
while True: while True:
# Generate a one-byte key. Right now the functions below do not use more # Generate a one-byte key. Right now the functions below do not use more
@ -24,23 +24,23 @@ def generate_insecure_key():
return key return key
def scramble(value, key): def scramble(value: str, key: bytes) -> str:
"""Do NOT use this for cryptographic purposes!""" """Do NOT use this for cryptographic purposes!"""
if len(key) < 1: if len(key) < 1:
raise ValueError("Key must be at least one byte") raise ValueError("Key must be at least one byte")
value = to_bytes(value) b_value = to_bytes(value)
k = key[0] k = key[0]
value = bytes([k ^ b for b in value]) b_value = bytes([k ^ b for b in b_value])
return "=S=" + to_native(base64.b64encode(value)) return f"=S={to_native(base64.b64encode(b_value))}"
def unscramble(value, key): def unscramble(value: str, key: bytes) -> str:
"""Do NOT use this for cryptographic purposes!""" """Do NOT use this for cryptographic purposes!"""
if len(key) < 1: if len(key) < 1:
raise ValueError("Key must be at least one byte") raise ValueError("Key must be at least one byte")
if not value.startswith("=S="): if not value.startswith("=S="):
raise ValueError("Value does not start with indicator") raise ValueError("Value does not start with indicator")
value = base64.b64decode(value[3:]) b_value = base64.b64decode(value[3:])
k = key[0] k = key[0]
value = bytes([k ^ b for b in value]) b_value = bytes([k ^ b for b in b_value])
return to_text(value) return to_text(b_value)

View File

@ -9,6 +9,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
import typing as t
from time import sleep from time import sleep
@ -28,10 +29,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class AnsibleDockerSwarmClient(AnsibleDockerClient): class AnsibleDockerSwarmClient(AnsibleDockerClient):
def __init__(self, **kwargs): def get_swarm_node_id(self) -> str | None:
super().__init__(**kwargs)
def get_swarm_node_id(self):
""" """
Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID
of Docker host the module is executed on of Docker host the module is executed on
@ -51,7 +49,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return swarm_info["Swarm"]["NodeID"] return swarm_info["Swarm"]["NodeID"]
return None return None
def check_if_swarm_node(self, node_id=None): def check_if_swarm_node(self, node_id: str | None = None) -> bool | None:
""" """
Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host
system information looking if specific key in output exists. If 'node_id' is provided then it tries to system information looking if specific key in output exists. If 'node_id' is provided then it tries to
@ -83,11 +81,11 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
try: try:
node_info = self.get_node_inspect(node_id=node_id) node_info = self.get_node_inspect(node_id=node_id)
except APIError: except APIError:
return return None
return node_info["ID"] is not None return node_info["ID"] is not None
def check_if_swarm_manager(self): def check_if_swarm_manager(self) -> bool:
""" """
Checks if node role is set as Manager in Swarm. The node is the docker host on which module action Checks if node role is set as Manager in Swarm. The node is the docker host on which module action
is performed. The inspect_swarm() will fail if node is not a manager is performed. The inspect_swarm() will fail if node is not a manager
@ -101,7 +99,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
except APIError: except APIError:
return False return False
def fail_task_if_not_swarm_manager(self): def fail_task_if_not_swarm_manager(self) -> None:
""" """
If host is not a swarm manager then Ansible task on this host should end with 'failed' state If host is not a swarm manager then Ansible task on this host should end with 'failed' state
""" """
@ -110,7 +108,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
"Error running docker swarm module: must run on swarm manager node" "Error running docker swarm module: must run on swarm manager node"
) )
def check_if_swarm_worker(self): def check_if_swarm_worker(self) -> bool:
""" """
Checks if node role is set as Worker in Swarm. The node is the docker host on which module action Checks if node role is set as Worker in Swarm. The node is the docker host on which module action
is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node() is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node()
@ -122,7 +120,9 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True return True
return False return False
def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1): def check_if_swarm_node_is_down(
self, node_id: str | None = None, repeat_check: int = 1
) -> bool:
""" """
Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about
node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or
@ -147,7 +147,19 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True return True
return False return False
def get_node_inspect(self, node_id=None, skip_missing=False): @t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: t.Literal[False] = False
) -> dict[str, t.Any]: ...
@t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None: ...
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None:
""" """
Returns Swarm node info as in 'docker node inspect' command about single node Returns Swarm node info as in 'docker node inspect' command about single node
@ -195,7 +207,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info["Status"]["Addr"] = swarm_leader_ip node_info["Status"]["Addr"] = swarm_leader_ip
return node_info return node_info
def get_all_nodes_inspect(self): def get_all_nodes_inspect(self) -> list[dict[str, t.Any]]:
""" """
Returns Swarm node info as in 'docker node inspect' command about all registered nodes Returns Swarm node info as in 'docker node inspect' command about all registered nodes
@ -217,7 +229,17 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info = json.loads(json_str) node_info = json.loads(json_str)
return node_info return node_info
def get_all_nodes_list(self, output="short"): @t.overload
def get_all_nodes_list(self, output: t.Literal["short"] = "short") -> list[str]: ...
@t.overload
def get_all_nodes_list(
self, output: t.Literal["long"]
) -> list[dict[str, t.Any]]: ...
def get_all_nodes_list(
self, output: t.Literal["short", "long"] = "short"
) -> list[str] | list[dict[str, t.Any]]:
""" """
Returns list of nodes registered in Swarm Returns list of nodes registered in Swarm
@ -227,48 +249,46 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
if 'output' is 'long' then returns data is list of dict containing the attributes as in if 'output' is 'long' then returns data is list of dict containing the attributes as in
output of command 'docker node ls' output of command 'docker node ls'
""" """
nodes_list = []
nodes_inspect = self.get_all_nodes_inspect() nodes_inspect = self.get_all_nodes_inspect()
if nodes_inspect is None:
return None
if output == "short": if output == "short":
nodes_list = []
for node in nodes_inspect: for node in nodes_inspect:
nodes_list.append(node["Description"]["Hostname"]) nodes_list.append(node["Description"]["Hostname"])
elif output == "long": return nodes_list
if output == "long":
nodes_info_list = []
for node in nodes_inspect: for node in nodes_inspect:
node_property = {} node_property: dict[str, t.Any] = {}
node_property.update({"ID": node["ID"]}) node_property["ID"] = node["ID"]
node_property.update({"Hostname": node["Description"]["Hostname"]}) node_property["Hostname"] = node["Description"]["Hostname"]
node_property.update({"Status": node["Status"]["State"]}) node_property["Status"] = node["Status"]["State"]
node_property.update({"Availability": node["Spec"]["Availability"]}) node_property["Availability"] = node["Spec"]["Availability"]
if "ManagerStatus" in node: if "ManagerStatus" in node:
if node["ManagerStatus"]["Leader"] is True: if node["ManagerStatus"]["Leader"] is True:
node_property.update({"Leader": True}) node_property["Leader"] = True
node_property.update( node_property["ManagerStatus"] = node["ManagerStatus"][
{"ManagerStatus": node["ManagerStatus"]["Reachability"]} "Reachability"
) ]
node_property.update( node_property["EngineVersion"] = node["Description"]["Engine"][
{"EngineVersion": node["Description"]["Engine"]["EngineVersion"]} "EngineVersion"
) ]
nodes_list.append(node_property) nodes_info_list.append(node_property)
else: return nodes_info_list
return None
return nodes_list def get_node_name_by_id(self, nodeid: str) -> str:
def get_node_name_by_id(self, nodeid):
return self.get_node_inspect(nodeid)["Description"]["Hostname"] return self.get_node_inspect(nodeid)["Description"]["Hostname"]
def get_unlock_key(self): def get_unlock_key(self) -> str | None:
if self.docker_py_version < LooseVersion("2.7.0"): if self.docker_py_version < LooseVersion("2.7.0"):
return None return None
return super().get_unlock_key() return super().get_unlock_key()
def get_service_inspect(self, service_id, skip_missing=False): def get_service_inspect(
self, service_id: str, skip_missing: bool = False
) -> dict[str, t.Any] | None:
""" """
Returns Swarm service info as in 'docker service inspect' command about single service Returns Swarm service info as in 'docker service inspect' command about single service

View File

@ -9,6 +9,7 @@ from __future__ import annotations
import json import json
import re import re
import typing as t
from datetime import timedelta from datetime import timedelta
from urllib.parse import urlparse from urllib.parse import urlparse
@ -17,6 +18,12 @@ from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock" DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock"
DEFAULT_TLS = False DEFAULT_TLS = False
DEFAULT_TLS_VERIFY = False DEFAULT_TLS_VERIFY = False
@ -79,14 +86,14 @@ DEFAULT_DOCKER_REGISTRY = "https://index.docker.io/v1/"
BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"] BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"]
def is_image_name_id(name): def is_image_name_id(name: str) -> bool:
"""Check whether the given image name is in fact an image ID (hash).""" """Check whether the given image name is in fact an image ID (hash)."""
if re.match("^sha256:[0-9a-fA-F]{64}$", name): if re.match("^sha256:[0-9a-fA-F]{64}$", name):
return True return True
return False return False
def is_valid_tag(tag, allow_empty=False): def is_valid_tag(tag: str, allow_empty: bool = False) -> bool:
"""Check whether the given string is a valid docker tag name.""" """Check whether the given string is a valid docker tag name."""
if not tag: if not tag:
return allow_empty return allow_empty
@ -95,7 +102,7 @@ def is_valid_tag(tag, allow_empty=False):
return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag)) return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag))
def sanitize_result(data): def sanitize_result(data: t.Any) -> t.Any:
"""Sanitize data object for return to Ansible. """Sanitize data object for return to Ansible.
When the data object contains types such as docker.types.containers.HostConfig, When the data object contains types such as docker.types.containers.HostConfig,
@ -112,7 +119,7 @@ def sanitize_result(data):
return data return data
def log_debug(msg, pretty_print=False): def log_debug(msg: t.Any, pretty_print: bool = False):
"""Write a log message to docker.log. """Write a log message to docker.log.
If ``pretty_print=True``, the message will be pretty-printed as JSON. If ``pretty_print=True``, the message will be pretty-printed as JSON.
@ -128,25 +135,28 @@ def log_debug(msg, pretty_print=False):
class DockerBaseClass: class DockerBaseClass:
def __init__(self): def __init__(self) -> None:
self.debug = False self.debug = False
def log(self, msg, pretty_print=False): def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
def update_tls_hostname( def update_tls_hostname(
result, old_behavior=False, deprecate_function=None, uses_tls=True result: dict[str, t.Any],
): old_behavior: bool = False,
deprecate_function: Callable[[str], None] | None = None,
uses_tls: bool = True,
) -> None:
if result["tls_hostname"] is None: if result["tls_hostname"] is None:
# get default machine name from the url # get default machine name from the url
parsed_url = urlparse(result["docker_host"]) parsed_url = urlparse(result["docker_host"])
result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0] result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0]
def compare_dict_allow_more_present(av, bv): def compare_dict_allow_more_present(av: dict, bv: dict) -> bool:
""" """
Compare two dictionaries for whether every entry of the first is in the second. Compare two dictionaries for whether every entry of the first is in the second.
""" """
@ -158,7 +168,12 @@ def compare_dict_allow_more_present(av, bv):
return True return True
def compare_generic(a, b, method, datatype): def compare_generic(
a: t.Any,
b: t.Any,
method: t.Literal["ignore", "strict", "allow_more_present"],
datatype: t.Literal["value", "list", "set", "set(dict)", "dict"],
) -> bool:
""" """
Compare values a and b as described by method and datatype. Compare values a and b as described by method and datatype.
@ -249,10 +264,10 @@ def compare_generic(a, b, method, datatype):
class DifferenceTracker: class DifferenceTracker:
def __init__(self): def __init__(self) -> None:
self._diff = [] self._diff: list[dict[str, t.Any]] = []
def add(self, name, parameter=None, active=None): def add(self, name: str, parameter: t.Any = None, active: t.Any = None) -> None:
self._diff.append( self._diff.append(
{ {
"name": name, "name": name,
@ -261,14 +276,14 @@ class DifferenceTracker:
} }
) )
def merge(self, other_tracker): def merge(self, other_tracker: DifferenceTracker) -> None:
self._diff.extend(other_tracker._diff) self._diff.extend(other_tracker._diff)
@property @property
def empty(self): def empty(self) -> bool:
return len(self._diff) == 0 return len(self._diff) == 0
def get_before_after(self): def get_before_after(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
""" """
Return texts ``before`` and ``after``. Return texts ``before`` and ``after``.
""" """
@ -279,13 +294,13 @@ class DifferenceTracker:
after[item["name"]] = item["parameter"] after[item["name"]] = item["parameter"]
return before, after return before, after
def has_difference_for(self, name): def has_difference_for(self, name: str) -> bool:
""" """
Returns a boolean if a difference exists for name Returns a boolean if a difference exists for name
""" """
return any(diff for diff in self._diff if diff["name"] == name) return any(diff for diff in self._diff if diff["name"] == name)
def get_legacy_docker_container_diffs(self): def get_legacy_docker_container_diffs(self) -> list[dict[str, t.Any]]:
""" """
Return differences in the docker_container legacy format. Return differences in the docker_container legacy format.
""" """
@ -299,7 +314,7 @@ class DifferenceTracker:
result.append(item) result.append(item)
return result return result
def get_legacy_docker_diffs(self): def get_legacy_docker_diffs(self) -> list[str]:
""" """
Return differences in the docker_container legacy format. Return differences in the docker_container legacy format.
""" """
@ -307,8 +322,13 @@ class DifferenceTracker:
return result return result
def sanitize_labels(labels, labels_field, client=None, module=None): def sanitize_labels(
def fail(msg): labels: dict[str, t.Any] | None,
labels_field: str,
client=None,
module: AnsibleModule | None = None,
) -> None:
def fail(msg: str) -> t.NoReturn:
if client is not None: if client is not None:
client.fail(msg) client.fail(msg)
if module is not None: if module is not None:
@ -327,7 +347,21 @@ def sanitize_labels(labels, labels_field, client=None, module=None):
labels[k] = to_text(v) labels[k] = to_text(v)
def clean_dict_booleans_for_docker_api(data, allow_sequences=False): @t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: t.Literal[False] = False
) -> dict[str, str]: ...
@t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: bool
) -> dict[str, str | list[str]]: ...
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: bool = False
) -> dict[str, str] | dict[str, str | list[str]]:
""" """
Go does not like Python booleans 'True' or 'False', while Ansible is just Go does not like Python booleans 'True' or 'False', while Ansible is just
fine with them in YAML. As such, they need to be converted in cases where fine with them in YAML. As such, they need to be converted in cases where
@ -355,7 +389,7 @@ def clean_dict_booleans_for_docker_api(data, allow_sequences=False):
return result return result
def convert_duration_to_nanosecond(time_str): def convert_duration_to_nanosecond(time_str: str) -> int:
""" """
Return time duration in nanosecond. Return time duration in nanosecond.
""" """
@ -374,9 +408,9 @@ def convert_duration_to_nanosecond(time_str):
if not parts: if not parts:
raise ValueError(f"Invalid time duration - {time_str}") raise ValueError(f"Invalid time duration - {time_str}")
parts = parts.groupdict() parts_dict = parts.groupdict()
time_params = {} time_params = {}
for name, value in parts.items(): for name, value in parts_dict.items():
if value: if value:
time_params[name] = int(value) time_params[name] = int(value)
@ -388,13 +422,15 @@ def convert_duration_to_nanosecond(time_str):
return time_in_nanoseconds return time_in_nanoseconds
def normalize_healthcheck_test(test): def normalize_healthcheck_test(test: t.Any) -> list[str]:
if isinstance(test, (tuple, list)): if isinstance(test, (tuple, list)):
return [str(e) for e in test] return [str(e) for e in test]
return ["CMD-SHELL", str(test)] return ["CMD-SHELL", str(test)]
def normalize_healthcheck(healthcheck, normalize_test=False): def normalize_healthcheck(
healthcheck: dict[str, t.Any], normalize_test: bool = False
) -> dict[str, t.Any]:
""" """
Return dictionary of healthcheck parameters. Return dictionary of healthcheck parameters.
""" """
@ -440,7 +476,9 @@ def normalize_healthcheck(healthcheck, normalize_test=False):
return result return result
def parse_healthcheck(healthcheck): def parse_healthcheck(
healthcheck: dict[str, t.Any] | None,
) -> tuple[dict[str, t.Any] | None, bool | None]:
""" """
Return dictionary of healthcheck parameters and boolean if Return dictionary of healthcheck parameters and boolean if
healthcheck defined in image was requested to be disabled. healthcheck defined in image was requested to be disabled.
@ -458,8 +496,8 @@ def parse_healthcheck(healthcheck):
return result, False return result, False
def omit_none_from_dict(d): def omit_none_from_dict(d: dict[str, t.Any]) -> dict[str, t.Any]:
""" """
Return a copy of the dictionary with all keys with value None omitted. Return a copy of the dictionary with all keys with value None omitted.
""" """
return dict((k, v) for (k, v) in d.items() if v is not None) return {k: v for (k, v) in d.items() if v is not None}

View File

@ -255,7 +255,7 @@ class DockerHostManager(DockerBaseClass):
returned_name = docker_object returned_name = docker_object
filter_name = docker_object + "_filters" filter_name = docker_object + "_filters"
filters = clean_dict_booleans_for_docker_api( filters = clean_dict_booleans_for_docker_api(
client.module.params.get(filter_name), True client.module.params.get(filter_name), allow_sequences=True
) )
self.results[returned_name] = self.get_docker_items_list( self.results[returned_name] = self.get_docker_items_list(
docker_object, filters docker_object, filters