Add typing information, 1/2 (#1176)

* Re-enable typing and improve config.

* Make mypy pass.

* Improve settings.

* First batch of types.

* Add more type hints.

* Fixes.

* Format.

* Fix split_port() without returning to previous type chaos.

* Continue with type hints (and ignores).
This commit is contained in:
Felix Fontein 2025-10-23 07:05:42 +02:00 committed by GitHub
parent 24f35644e3
commit 3350283bcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
92 changed files with 4366 additions and 2272 deletions

View File

@ -7,17 +7,25 @@
# disallow_untyped_defs = True -- for later # disallow_untyped_defs = True -- for later
# strict = True -- only try to enable once everything (including dependencies!) is typed # strict = True -- only try to enable once everything (including dependencies!) is typed
# strict_equality = True -- for later strict_equality = True
# strict_bytes = True -- for later strict_bytes = True
# warn_redundant_casts = True -- for later warn_redundant_casts = True
# warn_return_any = True -- for later # warn_return_any = True -- for later
# warn_unreachable = True -- for later warn_unreachable = True
[mypy-ansible.*] [mypy-ansible.*]
# ansible-core has partial typing information # ansible-core has partial typing information
follow_untyped_imports = True follow_untyped_imports = True
[mypy-docker.*]
# Docker SDK for Python has partial typing information
follow_untyped_imports = True
[mypy-ansible_collections.community.internal_test_tools.*] [mypy-ansible_collections.community.internal_test_tools.*]
# community.internal_test_tools has no typing information # community.internal_test_tools has no typing information
ignore_missing_imports = True ignore_missing_imports = True
[mypy-jsondiff.*]
# jsondiff has no typing information
ignore_missing_imports = True

View File

@ -27,7 +27,7 @@ run_yamllint = true
yamllint_config = ".yamllint" yamllint_config = ".yamllint"
yamllint_config_plugins = ".yamllint-docs" yamllint_config_plugins = ".yamllint-docs"
yamllint_config_plugins_examples = ".yamllint-examples" yamllint_config_plugins_examples = ".yamllint-examples"
run_mypy = false run_mypy = true
mypy_ansible_core_package = "ansible-core>=2.19.0" mypy_ansible_core_package = "ansible-core>=2.19.0"
mypy_config = ".mypy.ini" mypy_config = ".mypy.ini"
mypy_extra_deps = [ mypy_extra_deps = [
@ -35,7 +35,11 @@ mypy_extra_deps = [
"paramiko", "paramiko",
"urllib3", "urllib3",
"requests", "requests",
"types-mock",
"types-paramiko",
"types-pywin32",
"types-PyYAML", "types-PyYAML",
"types-requests",
] ]
[sessions.docs_check] [sessions.docs_check]

View File

@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
import typing as t
from ansible import constants as C from ansible import constants as C
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
@ -19,14 +20,17 @@ class ActionModule(ActionBase):
# Set to True when transferring files to the remote # Set to True when transferring files to the remote
TRANSFERS_FILES = False TRANSFERS_FILES = False
def run(self, tmp=None, task_vars=None): def run(
self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None
) -> dict[str, t.Any]:
self._supports_check_mode = True self._supports_check_mode = True
self._supports_async = True self._supports_async = True
result = super().run(tmp, task_vars) result = super().run(tmp, task_vars)
del tmp # tmp no longer has any effect del tmp # tmp no longer has any effect
self._task.args["_max_file_size_for_diff"] = C.MAX_FILE_SIZE_FOR_DIFF max_file_size_for_diff: int = C.MAX_FILE_SIZE_FOR_DIFF # type: ignore
self._task.args["_max_file_size_for_diff"] = max_file_size_for_diff
result = merge_hash( result = merge_hash(
result, result,

View File

@ -118,6 +118,7 @@ import os.path
import re import re
import selectors import selectors
import subprocess import subprocess
import typing as t
from shlex import quote from shlex import quote
from ansible.errors import AnsibleConnectionFailure, AnsibleError, AnsibleFileNotFound from ansible.errors import AnsibleConnectionFailure, AnsibleError, AnsibleFileNotFound
@ -140,8 +141,8 @@ class Connection(ConnectionBase):
transport = "community.docker.docker" transport = "community.docker.docker"
has_pipelining = True has_pipelining = True
def __init__(self, play_context, new_stdin, *args, **kwargs): def __init__(self, *args, **kwargs) -> None:
super().__init__(play_context, new_stdin, *args, **kwargs) super().__init__(*args, **kwargs)
# Note: docker supports running as non-root in some configurations. # Note: docker supports running as non-root in some configurations.
# (For instance, setting the UNIX socket file to be readable and # (For instance, setting the UNIX socket file to be readable and
@ -152,11 +153,11 @@ class Connection(ConnectionBase):
# configured to be connected to by root and they are not running as # configured to be connected to by root and they are not running as
# root. # root.
self._docker_args = [] self._docker_args: list[bytes | str] = []
self._container_user_cache = {} self._container_user_cache: dict[str, str | None] = {}
self._version = None self._version: str | None = None
self.remote_user = None self.remote_user: str | None = None
self.timeout = None self.timeout: int | float | None = None
# Windows uses Powershell modules # Windows uses Powershell modules
if getattr(self._shell, "_IS_WINDOWS", False): if getattr(self._shell, "_IS_WINDOWS", False):
@ -171,12 +172,12 @@ class Connection(ConnectionBase):
raise AnsibleError("docker command not found in PATH") from exc raise AnsibleError("docker command not found in PATH") from exc
@staticmethod @staticmethod
def _sanitize_version(version): def _sanitize_version(version: str) -> str:
version = re.sub("[^0-9a-zA-Z.]", "", version) version = re.sub("[^0-9a-zA-Z.]", "", version)
version = re.sub("^v", "", version) version = re.sub("^v", "", version)
return version return version
def _old_docker_version(self): def _old_docker_version(self) -> tuple[list[str], str, bytes, int]:
cmd_args = self._docker_args cmd_args = self._docker_args
old_version_subcommand = ["version"] old_version_subcommand = ["version"]
@ -189,7 +190,7 @@ class Connection(ConnectionBase):
return old_docker_cmd, to_native(cmd_output), err, p.returncode return old_docker_cmd, to_native(cmd_output), err, p.returncode
def _new_docker_version(self): def _new_docker_version(self) -> tuple[list[str], str, bytes, int]:
# no result yet, must be newer Docker version # no result yet, must be newer Docker version
cmd_args = self._docker_args cmd_args = self._docker_args
@ -202,8 +203,7 @@ class Connection(ConnectionBase):
cmd_output, err = p.communicate() cmd_output, err = p.communicate()
return new_docker_cmd, to_native(cmd_output), err, p.returncode return new_docker_cmd, to_native(cmd_output), err, p.returncode
def _get_docker_version(self): def _get_docker_version(self) -> str:
cmd, cmd_output, err, returncode = self._old_docker_version() cmd, cmd_output, err, returncode = self._old_docker_version()
if returncode == 0: if returncode == 0:
for line in to_text(cmd_output, errors="surrogate_or_strict").split("\n"): for line in to_text(cmd_output, errors="surrogate_or_strict").split("\n"):
@ -218,7 +218,7 @@ class Connection(ConnectionBase):
return self._sanitize_version(to_text(cmd_output, errors="surrogate_or_strict")) return self._sanitize_version(to_text(cmd_output, errors="surrogate_or_strict"))
def _get_docker_remote_user(self): def _get_docker_remote_user(self) -> str | None:
"""Get the default user configured in the docker container""" """Get the default user configured in the docker container"""
container = self.get_option("remote_addr") container = self.get_option("remote_addr")
if container in self._container_user_cache: if container in self._container_user_cache:
@ -243,7 +243,7 @@ class Connection(ConnectionBase):
self._container_user_cache[container] = user self._container_user_cache[container] = user
return user return user
def _build_exec_cmd(self, cmd): def _build_exec_cmd(self, cmd: list[bytes | str]) -> list[bytes | str]:
"""Build the local docker exec command to run cmd on remote_host """Build the local docker exec command to run cmd on remote_host
If remote_user is available and is supported by the docker If remote_user is available and is supported by the docker
@ -298,7 +298,7 @@ class Connection(ConnectionBase):
return local_cmd return local_cmd
def _set_docker_args(self): def _set_docker_args(self) -> None:
# TODO: this is mostly for backwards compatibility, play_context is used as fallback for older versions # TODO: this is mostly for backwards compatibility, play_context is used as fallback for older versions
# docker arguments # docker arguments
del self._docker_args[:] del self._docker_args[:]
@ -308,7 +308,7 @@ class Connection(ConnectionBase):
if extra_args: if extra_args:
self._docker_args += extra_args.split(" ") self._docker_args += extra_args.split(" ")
def _set_conn_data(self): def _set_conn_data(self) -> None:
"""initialize for the connection, cannot do only in init since all data is not ready at that point""" """initialize for the connection, cannot do only in init since all data is not ready at that point"""
self._set_docker_args() self._set_docker_args()
@ -323,8 +323,7 @@ class Connection(ConnectionBase):
self.timeout = self._play_context.timeout self.timeout = self._play_context.timeout
@property @property
def docker_version(self): def docker_version(self) -> str:
if not self._version: if not self._version:
self._set_docker_args() self._set_docker_args()
@ -341,7 +340,7 @@ class Connection(ConnectionBase):
) )
return self._version return self._version
def _get_actual_user(self): def _get_actual_user(self) -> str | None:
if self.remote_user is not None: if self.remote_user is not None:
# An explicit user is provided # An explicit user is provided
if self.docker_version == "dev" or LooseVersion( if self.docker_version == "dev" or LooseVersion(
@ -353,7 +352,7 @@ class Connection(ConnectionBase):
actual_user = self._get_docker_remote_user() actual_user = self._get_docker_remote_user()
if actual_user != self.get_option("remote_user"): if actual_user != self.get_option("remote_user"):
display.warning( display.warning(
f'docker {self.docker_version} does not support remote_user, using container default: {actual_user or "?"}' f"docker {self.docker_version} does not support remote_user, using container default: {actual_user or '?'}"
) )
return actual_user return actual_user
if self._display.verbosity > 2: if self._display.verbosity > 2:
@ -363,9 +362,9 @@ class Connection(ConnectionBase):
return self._get_docker_remote_user() return self._get_docker_remote_user()
return None return None
def _connect(self, port=None): def _connect(self) -> t.Self:
"""Connect to the container. Nothing to do""" """Connect to the container. Nothing to do"""
super()._connect() super()._connect() # type: ignore[safe-super]
if not self._connected: if not self._connected:
self._set_conn_data() self._set_conn_data()
actual_user = self._get_actual_user() actual_user = self._get_actual_user()
@ -374,13 +373,16 @@ class Connection(ConnectionBase):
host=self.get_option("remote_addr"), host=self.get_option("remote_addr"),
) )
self._connected = True self._connected = True
return self
def exec_command(self, cmd, in_data=None, sudoable=False): def exec_command(
self, cmd: str, in_data: bytes | None = None, sudoable: bool = False
) -> tuple[int, bytes, bytes]:
"""Run a command on the docker host""" """Run a command on the docker host"""
self._set_conn_data() self._set_conn_data()
super().exec_command(cmd, in_data=in_data, sudoable=sudoable) super().exec_command(cmd, in_data=in_data, sudoable=sudoable) # type: ignore[safe-super]
local_cmd = self._build_exec_cmd([self._play_context.executable, "-c", cmd]) local_cmd = self._build_exec_cmd([self._play_context.executable, "-c", cmd])
@ -395,6 +397,9 @@ class Connection(ConnectionBase):
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
) as p: ) as p:
assert p.stdin is not None
assert p.stdout is not None
assert p.stderr is not None
display.debug("done running command with Popen()") display.debug("done running command with Popen()")
if self.become and self.become.expect_prompt() and sudoable: if self.become and self.become.expect_prompt() and sudoable:
@ -489,10 +494,10 @@ class Connection(ConnectionBase):
remote_path = os.path.join(os.path.sep, remote_path) remote_path = os.path.join(os.path.sep, remote_path)
return os.path.normpath(remote_path) return os.path.normpath(remote_path)
def put_file(self, in_path, out_path): def put_file(self, in_path: str, out_path: str) -> None:
"""Transfer a file from local to docker container""" """Transfer a file from local to docker container"""
self._set_conn_data() self._set_conn_data()
super().put_file(in_path, out_path) super().put_file(in_path, out_path) # type: ignore[safe-super]
display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr")) display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr"))
out_path = self._prefix_login_path(out_path) out_path = self._prefix_login_path(out_path)
@ -534,10 +539,10 @@ class Connection(ConnectionBase):
f"failed to transfer file {to_native(in_path)} to {to_native(out_path)}:\n{to_native(stdout)}\n{to_native(stderr)}" f"failed to transfer file {to_native(in_path)} to {to_native(out_path)}:\n{to_native(stdout)}\n{to_native(stderr)}"
) )
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path: str, out_path: str) -> None:
"""Fetch a file from container to local.""" """Fetch a file from container to local."""
self._set_conn_data() self._set_conn_data()
super().fetch_file(in_path, out_path) super().fetch_file(in_path, out_path) # type: ignore[safe-super]
display.vvv( display.vvv(
f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr") f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr")
) )
@ -596,7 +601,7 @@ class Connection(ConnectionBase):
if pp.returncode != 0: if pp.returncode != 0:
raise AnsibleError( raise AnsibleError(
f"failed to fetch file {in_path} to {out_path}:\n{stdout}\n{stderr}" f"failed to fetch file {in_path} to {out_path}:\n{stdout!r}\n{stderr!r}"
) )
# Rename if needed # Rename if needed
@ -606,11 +611,11 @@ class Connection(ConnectionBase):
to_bytes(out_path, errors="strict"), to_bytes(out_path, errors="strict"),
) )
def close(self): def close(self) -> None:
"""Terminate the connection. Nothing to do for Docker""" """Terminate the connection. Nothing to do for Docker"""
super().close() super().close() # type: ignore[safe-super]
self._connected = False self._connected = False
def reset(self): def reset(self) -> None:
# Clear container user cache # Clear container user cache
self._container_user_cache = {} self._container_user_cache = {}

View File

@ -107,6 +107,7 @@ options:
import os import os
import os.path import os.path
import typing as t
from ansible.errors import AnsibleConnectionFailure, AnsibleFileNotFound from ansible.errors import AnsibleConnectionFailure, AnsibleFileNotFound
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@ -138,6 +139,12 @@ from ansible_collections.community.docker.plugins.plugin_utils._socket_handler i
) )
if t.TYPE_CHECKING:
from collections.abc import Callable
_T = t.TypeVar("_T")
MIN_DOCKER_API = None MIN_DOCKER_API = None
@ -150,10 +157,16 @@ class Connection(ConnectionBase):
transport = "community.docker.docker_api" transport = "community.docker.docker_api"
has_pipelining = True has_pipelining = True
def _call_client(self, f, not_found_can_be_resource=False): def _call_client(
self,
f: Callable[[AnsibleDockerClient], _T],
not_found_can_be_resource: bool = False,
) -> _T:
if self.client is None:
raise AssertionError("Client must be present")
remote_addr = self.get_option("remote_addr") remote_addr = self.get_option("remote_addr")
try: try:
return f() return f(self.client)
except NotFound as e: except NotFound as e:
if not_found_can_be_resource: if not_found_can_be_resource:
raise AnsibleConnectionFailure( raise AnsibleConnectionFailure(
@ -179,21 +192,21 @@ class Connection(ConnectionBase):
f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}' f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}'
) )
def __init__(self, play_context, new_stdin, *args, **kwargs): def __init__(self, *args, **kwargs) -> None:
super().__init__(play_context, new_stdin, *args, **kwargs) super().__init__(*args, **kwargs)
self.client = None self.client: AnsibleDockerClient | None = None
self.ids = {} self.ids: dict[str | None, tuple[int, int]] = {}
# Windows uses Powershell modules # Windows uses Powershell modules
if getattr(self._shell, "_IS_WINDOWS", False): if getattr(self._shell, "_IS_WINDOWS", False):
self.module_implementation_preferences = (".ps1", ".exe", "") self.module_implementation_preferences = (".ps1", ".exe", "")
self.actual_user = None self.actual_user: str | None = None
def _connect(self, port=None): def _connect(self) -> Connection:
"""Connect to the container. Nothing to do""" """Connect to the container. Nothing to do"""
super()._connect() super()._connect() # type: ignore[safe-super]
if not self._connected: if not self._connected:
self.actual_user = self.get_option("remote_user") self.actual_user = self.get_option("remote_user")
display.vvv( display.vvv(
@ -212,7 +225,7 @@ class Connection(ConnectionBase):
# This saves overhead from calling into docker when we do not need to # This saves overhead from calling into docker when we do not need to
display.vvv("Trying to determine actual user") display.vvv("Trying to determine actual user")
result = self._call_client( result = self._call_client(
lambda: self.client.get_json( lambda client: client.get_json(
"/containers/{0}/json", self.get_option("remote_addr") "/containers/{0}/json", self.get_option("remote_addr")
) )
) )
@ -221,12 +234,19 @@ class Connection(ConnectionBase):
if self.actual_user is not None: if self.actual_user is not None:
display.vvv(f"Actual user is '{self.actual_user}'") display.vvv(f"Actual user is '{self.actual_user}'")
def exec_command(self, cmd, in_data=None, sudoable=False): return self
def exec_command(
self, cmd: str, in_data: bytes | None = None, sudoable: bool = False
) -> tuple[int, bytes, bytes]:
"""Run a command on the docker host""" """Run a command on the docker host"""
super().exec_command(cmd, in_data=in_data, sudoable=sudoable) super().exec_command(cmd, in_data=in_data, sudoable=sudoable) # type: ignore[safe-super]
command = [self._play_context.executable, "-c", to_text(cmd)] if self.client is None:
raise AssertionError("Client must be present")
command = [self._play_context.executable, "-c", cmd]
do_become = self.become and self.become.expect_prompt() and sudoable do_become = self.become and self.become.expect_prompt() and sudoable
@ -277,7 +297,7 @@ class Connection(ConnectionBase):
) )
exec_data = self._call_client( exec_data = self._call_client(
lambda: self.client.post_json_to_json( lambda client: client.post_json_to_json(
"/containers/{0}/exec", self.get_option("remote_addr"), data=data "/containers/{0}/exec", self.get_option("remote_addr"), data=data
) )
) )
@ -286,7 +306,7 @@ class Connection(ConnectionBase):
data = {"Tty": False, "Detach": False} data = {"Tty": False, "Detach": False}
if need_stdin: if need_stdin:
exec_socket = self._call_client( exec_socket = self._call_client(
lambda: self.client.post_json_to_stream_socket( lambda client: client.post_json_to_stream_socket(
"/exec/{0}/start", exec_id, data=data "/exec/{0}/start", exec_id, data=data
) )
) )
@ -295,6 +315,8 @@ class Connection(ConnectionBase):
display, exec_socket, container=self.get_option("remote_addr") display, exec_socket, container=self.get_option("remote_addr")
) as exec_socket_handler: ) as exec_socket_handler:
if do_become: if do_become:
assert self.become is not None
become_output = [b""] become_output = [b""]
def append_become_output(stream_id, data): def append_become_output(stream_id, data):
@ -339,7 +361,7 @@ class Connection(ConnectionBase):
exec_socket.close() exec_socket.close()
else: else:
stdout, stderr = self._call_client( stdout, stderr = self._call_client(
lambda: self.client.post_json_to_stream( lambda client: client.post_json_to_stream(
"/exec/{0}/start", "/exec/{0}/start",
exec_id, exec_id,
stream=False, stream=False,
@ -350,12 +372,12 @@ class Connection(ConnectionBase):
) )
result = self._call_client( result = self._call_client(
lambda: self.client.get_json("/exec/{0}/json", exec_id) lambda client: client.get_json("/exec/{0}/json", exec_id)
) )
return result.get("ExitCode") or 0, stdout or b"", stderr or b"" return result.get("ExitCode") or 0, stdout or b"", stderr or b""
def _prefix_login_path(self, remote_path): def _prefix_login_path(self, remote_path: str) -> str:
"""Make sure that we put files into a standard path """Make sure that we put files into a standard path
If a path is relative, then we need to choose where to put it. If a path is relative, then we need to choose where to put it.
@ -373,19 +395,23 @@ class Connection(ConnectionBase):
remote_path = os.path.join(os.path.sep, remote_path) remote_path = os.path.join(os.path.sep, remote_path)
return os.path.normpath(remote_path) return os.path.normpath(remote_path)
def put_file(self, in_path, out_path): def put_file(self, in_path: str, out_path: str) -> None:
"""Transfer a file from local to docker container""" """Transfer a file from local to docker container"""
super().put_file(in_path, out_path) super().put_file(in_path, out_path) # type: ignore[safe-super]
display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr")) display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr"))
if self.client is None:
raise AssertionError("Client must be present")
out_path = self._prefix_login_path(out_path) out_path = self._prefix_login_path(out_path)
if self.actual_user not in self.ids: if self.actual_user not in self.ids:
dummy, ids, dummy = self.exec_command(b"id -u && id -g") dummy, ids, dummy2 = self.exec_command("id -u && id -g")
remote_addr = self.get_option("remote_addr") remote_addr = self.get_option("remote_addr")
try: try:
user_id, group_id = ids.splitlines() b_user_id, b_group_id = ids.splitlines()
self.ids[self.actual_user] = int(user_id), int(group_id) user_id, group_id = int(b_user_id), int(b_group_id)
self.ids[self.actual_user] = user_id, group_id
display.vvvv( display.vvvv(
f'PUT: Determined uid={user_id} and gid={group_id} for user "{self.actual_user}"', f'PUT: Determined uid={user_id} and gid={group_id} for user "{self.actual_user}"',
host=remote_addr, host=remote_addr,
@ -398,8 +424,8 @@ class Connection(ConnectionBase):
user_id, group_id = self.ids[self.actual_user] user_id, group_id = self.ids[self.actual_user]
try: try:
self._call_client( self._call_client(
lambda: put_file( lambda client: put_file(
self.client, client,
container=self.get_option("remote_addr"), container=self.get_option("remote_addr"),
in_path=in_path, in_path=in_path,
out_path=out_path, out_path=out_path,
@ -415,19 +441,22 @@ class Connection(ConnectionBase):
except DockerFileCopyError as exc: except DockerFileCopyError as exc:
raise AnsibleConnectionFailure(to_native(exc)) from exc raise AnsibleConnectionFailure(to_native(exc)) from exc
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path: str, out_path: str) -> None:
"""Fetch a file from container to local.""" """Fetch a file from container to local."""
super().fetch_file(in_path, out_path) super().fetch_file(in_path, out_path) # type: ignore[safe-super]
display.vvv( display.vvv(
f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr") f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr")
) )
if self.client is None:
raise AssertionError("Client must be present")
in_path = self._prefix_login_path(in_path) in_path = self._prefix_login_path(in_path)
try: try:
self._call_client( self._call_client(
lambda: fetch_file( lambda client: fetch_file(
self.client, client,
container=self.get_option("remote_addr"), container=self.get_option("remote_addr"),
in_path=in_path, in_path=in_path,
out_path=out_path, out_path=out_path,
@ -443,10 +472,10 @@ class Connection(ConnectionBase):
except DockerFileCopyError as exc: except DockerFileCopyError as exc:
raise AnsibleConnectionFailure(to_native(exc)) from exc raise AnsibleConnectionFailure(to_native(exc)) from exc
def close(self): def close(self) -> None:
"""Terminate the connection. Nothing to do for Docker""" """Terminate the connection. Nothing to do for Docker"""
super().close() super().close() # type: ignore[safe-super]
self._connected = False self._connected = False
def reset(self): def reset(self) -> None:
self.ids.clear() self.ids.clear()

View File

@ -44,7 +44,9 @@ import fcntl
import os import os
import pty import pty
import selectors import selectors
import shlex
import subprocess import subprocess
import typing as t
import ansible.constants as C import ansible.constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
@ -63,12 +65,12 @@ class Connection(ConnectionBase):
transport = "community.docker.nsenter" transport = "community.docker.nsenter"
has_pipelining = False has_pipelining = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cwd = None self.cwd = None
self._nsenter_pid = None self._nsenter_pid = None
def _connect(self): def _connect(self) -> t.Self:
self._nsenter_pid = self.get_option("nsenter_pid") self._nsenter_pid = self.get_option("nsenter_pid")
# Because nsenter requires very high privileges, our remote user # Because nsenter requires very high privileges, our remote user
@ -83,12 +85,15 @@ class Connection(ConnectionBase):
self._connected = True self._connected = True
return self return self
def exec_command(self, cmd, in_data=None, sudoable=True): def exec_command(
super().exec_command(cmd, in_data=in_data, sudoable=sudoable) self, cmd: str, in_data: bytes | None = None, sudoable: bool = True
) -> tuple[int, bytes, bytes]:
super().exec_command(cmd, in_data=in_data, sudoable=sudoable) # type: ignore[safe-super]
display.debug("in nsenter.exec_command()") display.debug("in nsenter.exec_command()")
executable = C.DEFAULT_EXECUTABLE.split()[0] if C.DEFAULT_EXECUTABLE else None def_executable: str | None = C.DEFAULT_EXECUTABLE # type: ignore[attr-defined]
executable = def_executable.split()[0] if def_executable else None
if not os.path.exists(to_bytes(executable, errors="surrogate_or_strict")): if not os.path.exists(to_bytes(executable, errors="surrogate_or_strict")):
raise AnsibleError( raise AnsibleError(
@ -109,12 +114,8 @@ class Connection(ConnectionBase):
"--", "--",
] ]
if isinstance(cmd, (str, bytes)): cmd_parts = nsenter_cmd_parts + [cmd]
cmd_parts = nsenter_cmd_parts + [cmd] cmd = to_bytes(" ".join(cmd_parts))
cmd = to_bytes(" ".join(cmd_parts))
else:
cmd_parts = nsenter_cmd_parts + cmd
cmd = [to_bytes(arg) for arg in cmd_parts]
display.vvv(f"EXEC {to_text(cmd)}", host=self._play_context.remote_addr) display.vvv(f"EXEC {to_text(cmd)}", host=self._play_context.remote_addr)
display.debug("opening command with Popen()") display.debug("opening command with Popen()")
@ -143,6 +144,9 @@ class Connection(ConnectionBase):
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
) as p: ) as p:
assert p.stderr is not None
assert p.stdin is not None
assert p.stdout is not None
# if we created a master, we can close the other half of the pty now, otherwise master is stdin # if we created a master, we can close the other half of the pty now, otherwise master is stdin
if master is not None: if master is not None:
os.close(stdin) os.close(stdin)
@ -234,8 +238,8 @@ class Connection(ConnectionBase):
display.debug("done with nsenter.exec_command()") display.debug("done with nsenter.exec_command()")
return (p.returncode, stdout, stderr) return (p.returncode, stdout, stderr)
def put_file(self, in_path, out_path): def put_file(self, in_path: str, out_path: str) -> None:
super().put_file(in_path, out_path) super().put_file(in_path, out_path) # type: ignore[safe-super]
in_path = unfrackpath(in_path, basedir=self.cwd) in_path = unfrackpath(in_path, basedir=self.cwd)
out_path = unfrackpath(out_path, basedir=self.cwd) out_path = unfrackpath(out_path, basedir=self.cwd)
@ -245,26 +249,30 @@ class Connection(ConnectionBase):
with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as in_file: with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as in_file:
in_data = in_file.read() in_data = in_file.read()
rc, dummy_out, err = self.exec_command( rc, dummy_out, err = self.exec_command(
cmd=["tee", out_path], in_data=in_data cmd=f"tee {shlex.quote(out_path)}", in_data=in_data
) )
if rc != 0: if rc != 0:
raise AnsibleError(f"failed to transfer file to {out_path}: {err}") raise AnsibleError(
f"failed to transfer file to {out_path}: {to_text(err)}"
)
except IOError as e: except IOError as e:
raise AnsibleError(f"failed to transfer file to {out_path}: {e}") from e raise AnsibleError(f"failed to transfer file to {out_path}: {e}") from e
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path: str, out_path: str) -> None:
super().fetch_file(in_path, out_path) super().fetch_file(in_path, out_path) # type: ignore[safe-super]
in_path = unfrackpath(in_path, basedir=self.cwd) in_path = unfrackpath(in_path, basedir=self.cwd)
out_path = unfrackpath(out_path, basedir=self.cwd) out_path = unfrackpath(out_path, basedir=self.cwd)
try: try:
rc, out, err = self.exec_command(cmd=["cat", in_path]) rc, out, err = self.exec_command(cmd=f"cat {shlex.quote(in_path)}")
display.vvv( display.vvv(
f"FETCH {in_path} TO {out_path}", host=self._play_context.remote_addr f"FETCH {in_path} TO {out_path}", host=self._play_context.remote_addr
) )
if rc != 0: if rc != 0:
raise AnsibleError(f"failed to transfer file to {in_path}: {err}") raise AnsibleError(
f"failed to transfer file to {in_path}: {to_text(err)}"
)
with open( with open(
to_bytes(out_path, errors="surrogate_or_strict"), "wb" to_bytes(out_path, errors="surrogate_or_strict"), "wb"
) as out_file: ) as out_file:
@ -274,6 +282,6 @@ class Connection(ConnectionBase):
f"failed to transfer file to {to_native(out_path)}: {e}" f"failed to transfer file to {to_native(out_path)}: {e}"
) from e ) from e
def close(self): def close(self) -> None:
"""terminate the connection; nothing to do here""" """terminate the connection; nothing to do here"""
self._connected = False self._connected = False

View File

@ -169,6 +169,7 @@ filters:
""" """
import re import re
import typing as t
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.plugins.inventory import BaseInventoryPlugin, Constructable from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
@ -195,6 +196,11 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
) )
if t.TYPE_CHECKING:
from ansible.inventory.data import InventoryData
from ansible.parsing.dataloader import DataLoader
MIN_DOCKER_API = None MIN_DOCKER_API = None
@ -203,11 +209,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
NAME = "community.docker.docker_containers" NAME = "community.docker.docker_containers"
def _slugify(self, value): def _slugify(self, value: str) -> str:
slug = re.sub(r"[^\w-]", "_", value).lower().lstrip("_") slug = re.sub(r"[^\w-]", "_", value).lower().lstrip("_")
return f"docker_{slug}" return f"docker_{slug}"
def _populate(self, client): def _populate(self, client: AnsibleDockerClient) -> None:
strict = self.get_option("strict") strict = self.get_option("strict")
ssh_port = self.get_option("private_ssh_port") ssh_port = self.get_option("private_ssh_port")
@ -217,6 +223,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
connection_type = self.get_option("connection_type") connection_type = self.get_option("connection_type")
add_legacy_groups = self.get_option("add_legacy_groups") add_legacy_groups = self.get_option("add_legacy_groups")
if self.inventory is None:
raise AssertionError("Inventory must be there")
try: try:
params = { params = {
"limit": -1, "limit": -1,
@ -298,7 +307,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
# Lookup the public facing port Nat'ed to ssh port. # Lookup the public facing port Nat'ed to ssh port.
network_settings = inspect.get("NetworkSettings") or {} network_settings = inspect.get("NetworkSettings") or {}
port_settings = network_settings.get("Ports") or {} port_settings = network_settings.get("Ports") or {}
port = port_settings.get(f"{ssh_port}/tcp")[0] port = port_settings.get(f"{ssh_port}/tcp")[0] # type: ignore[index]
except (IndexError, AttributeError, TypeError): except (IndexError, AttributeError, TypeError):
port = {} port = {}
@ -387,16 +396,22 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
else: else:
self.inventory.add_host(name, group="stopped") self.inventory.add_host(name, group="stopped")
def verify_file(self, path): def verify_file(self, path: str) -> bool:
"""Return the possibly of a file being consumable by this plugin.""" """Return the possibly of a file being consumable by this plugin."""
return super().verify_file(path) and path.endswith( return super().verify_file(path) and path.endswith(
("docker.yaml", "docker.yml") ("docker.yaml", "docker.yml")
) )
def _create_client(self): def _create_client(self) -> AnsibleDockerClient:
return AnsibleDockerClient(self, min_docker_api_version=MIN_DOCKER_API) return AnsibleDockerClient(self, min_docker_api_version=MIN_DOCKER_API)
def parse(self, inventory, loader, path, cache=True): def parse(
self,
inventory: InventoryData,
loader: DataLoader,
path: str,
cache: bool = True,
) -> None:
super().parse(inventory, loader, path, cache) super().parse(inventory, loader, path, cache)
self._read_config_data(path) self._read_config_data(path)
client = self._create_client() client = self._create_client()

View File

@ -101,6 +101,7 @@ compose:
import json import json
import re import re
import subprocess import subprocess
import typing as t
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.module_utils.common.process import get_bin_path from ansible.module_utils.common.process import get_bin_path
@ -117,6 +118,15 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
) )
if t.TYPE_CHECKING:
from ansible.inventory.data import InventoryData
from ansible.parsing.dataloader import DataLoader
DaemonEnv = t.Literal[
"require", "require-silently", "optional", "optional-silently", "skip"
]
display = Display() display = Display()
@ -125,9 +135,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
NAME = "community.docker.docker_machine" NAME = "community.docker.docker_machine"
docker_machine_path = None docker_machine_path: str | None = None
def _run_command(self, args): def _run_command(self, args: list[str]) -> str:
if not self.docker_machine_path: if not self.docker_machine_path:
try: try:
self.docker_machine_path = get_bin_path("docker-machine") self.docker_machine_path = get_bin_path("docker-machine")
@ -147,7 +157,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return to_text(result).strip() return to_text(result).strip()
def _get_docker_daemon_variables(self, machine_name): def _get_docker_daemon_variables(self, machine_name: str) -> list[tuple[str, str]]:
""" """
Capture settings from Docker Machine that would be needed to connect to the remote Docker daemon installed on Capture settings from Docker Machine that would be needed to connect to the remote Docker daemon installed on
the Docker Machine remote host. Note: passing '--shell=sh' is a workaround for 'Error: Unknown shell'. the Docker Machine remote host. Note: passing '--shell=sh' is a workaround for 'Error: Unknown shell'.
@ -180,7 +190,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return env_vars return env_vars
def _get_machine_names(self): def _get_machine_names(self) -> list[str]:
# Filter out machines that are not in the Running state as we probably cannot do anything useful actions # Filter out machines that are not in the Running state as we probably cannot do anything useful actions
# with them. # with them.
ls_command = ["ls", "-q"] ls_command = ["ls", "-q"]
@ -194,7 +204,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return ls_lines.splitlines() return ls_lines.splitlines()
def _inspect_docker_machine_host(self, node): def _inspect_docker_machine_host(self, node: str) -> t.Any | None:
try: try:
inspect_lines = self._run_command(["inspect", node]) inspect_lines = self._run_command(["inspect", node])
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
@ -202,7 +212,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return json.loads(inspect_lines) return json.loads(inspect_lines)
def _ip_addr_docker_machine_host(self, node): def _ip_addr_docker_machine_host(self, node: str) -> t.Any | None:
try: try:
ip_addr = self._run_command(["ip", node]) ip_addr = self._run_command(["ip", node])
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
@ -210,7 +220,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return ip_addr return ip_addr
def _should_skip_host(self, machine_name, env_var_tuples, daemon_env): def _should_skip_host(
self, machine_name: str, env_var_tuples, daemon_env: DaemonEnv
) -> bool:
if not env_var_tuples: if not env_var_tuples:
warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}" warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}"
if daemon_env in ("require", "require-silently"): if daemon_env in ("require", "require-silently"):
@ -224,8 +236,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
# daemon_env is 'optional-silently' # daemon_env is 'optional-silently'
return False return False
def _populate(self): def _populate(self) -> None:
daemon_env = self.get_option("daemon_env") if self.inventory is None:
raise AssertionError("Inventory must be there")
daemon_env: DaemonEnv = self.get_option("daemon_env")
filters = parse_filters(self.get_option("filters")) filters = parse_filters(self.get_option("filters"))
try: try:
for node in self._get_machine_names(): for node in self._get_machine_names():
@ -325,13 +340,19 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
f"Unable to fetch hosts from Docker Machine, this was the original exception: {e}" f"Unable to fetch hosts from Docker Machine, this was the original exception: {e}"
) from e ) from e
def verify_file(self, path): def verify_file(self, path: str) -> bool:
"""Return the possibility of a file being consumable by this plugin.""" """Return the possibility of a file being consumable by this plugin."""
return super().verify_file(path) and path.endswith( return super().verify_file(path) and path.endswith(
("docker_machine.yaml", "docker_machine.yml") ("docker_machine.yaml", "docker_machine.yml")
) )
def parse(self, inventory, loader, path, cache=True): def parse(
self,
inventory: InventoryData,
loader: DataLoader,
path: str,
cache: bool = True,
) -> None:
super().parse(inventory, loader, path, cache) super().parse(inventory, loader, path, cache)
self._read_config_data(path) self._read_config_data(path)
self._populate() self._populate()

View File

@ -148,6 +148,8 @@ keyed_groups:
prefix: label prefix: label
""" """
import typing as t
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.parsing.utils.addresses import parse_address from ansible.parsing.utils.addresses import parse_address
from ansible.plugins.inventory import BaseInventoryPlugin, Constructable from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
@ -167,6 +169,11 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
) )
if t.TYPE_CHECKING:
from ansible.inventory.data import InventoryData
from ansible.parsing.dataloader import DataLoader
try: try:
import docker import docker
@ -180,10 +187,13 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
NAME = "community.docker.docker_swarm" NAME = "community.docker.docker_swarm"
def _fail(self, msg): def _fail(self, msg: str) -> t.NoReturn:
raise AnsibleError(msg) raise AnsibleError(msg)
def _populate(self): def _populate(self) -> None:
if self.inventory is None:
raise AssertionError("Inventory must be there")
raw_params = { raw_params = {
"docker_host": self.get_option("docker_host"), "docker_host": self.get_option("docker_host"),
"tls": self.get_option("tls"), "tls": self.get_option("tls"),
@ -307,13 +317,19 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
f"Unable to fetch hosts from Docker swarm API, this was the original exception: {e}" f"Unable to fetch hosts from Docker swarm API, this was the original exception: {e}"
) from e ) from e
def verify_file(self, path): def verify_file(self, path: str) -> bool:
"""Return the possibly of a file being consumable by this plugin.""" """Return the possibly of a file being consumable by this plugin."""
return super().verify_file(path) and path.endswith( return super().verify_file(path) and path.endswith(
("docker_swarm.yaml", "docker_swarm.yml") ("docker_swarm.yaml", "docker_swarm.yml")
) )
def parse(self, inventory, loader, path, cache=True): def parse(
self,
inventory: InventoryData,
loader: DataLoader,
path: str,
cache: bool = True,
) -> None:
if not HAS_DOCKER: if not HAS_DOCKER:
raise AnsibleError( raise AnsibleError(
"The Docker swarm dynamic inventory plugin requires the Docker SDK for Python: " "The Docker swarm dynamic inventory plugin requires the Docker SDK for Python: "

View File

@ -12,8 +12,10 @@
from __future__ import annotations from __future__ import annotations
import traceback import traceback
import typing as t
REQUESTS_IMPORT_ERROR: str | None # pylint: disable=invalid-name
try: try:
from requests import Session # noqa: F401, pylint: disable=unused-import from requests import Session # noqa: F401, pylint: disable=unused-import
from requests.adapters import ( # noqa: F401, pylint: disable=unused-import from requests.adapters import ( # noqa: F401, pylint: disable=unused-import
@ -26,28 +28,29 @@ try:
except ImportError: except ImportError:
REQUESTS_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name REQUESTS_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name
class Session: class Session: # type: ignore
__attrs__ = [] __attrs__: list[t.Never] = []
class HTTPAdapter: class HTTPAdapter: # type: ignore
__attrs__ = [] __attrs__: list[t.Never] = []
class HTTPError(Exception): class HTTPError(Exception): # type: ignore
pass pass
class InvalidSchema(Exception): class InvalidSchema(Exception): # type: ignore
pass pass
else: else:
REQUESTS_IMPORT_ERROR = None # pylint: disable=invalid-name REQUESTS_IMPORT_ERROR = None # pylint: disable=invalid-name
URLLIB3_IMPORT_ERROR = None # pylint: disable=invalid-name URLLIB3_IMPORT_ERROR: str | None = None # pylint: disable=invalid-name
try: try:
from requests.packages import urllib3 # pylint: disable=unused-import from requests.packages import urllib3 # pylint: disable=unused-import
# pylint: disable-next=unused-import from requests.packages.urllib3 import ( # type: ignore # pylint: disable=unused-import # isort: skip
from requests.packages.urllib3 import connection as urllib3_connection connection as urllib3_connection,
)
except ImportError: except ImportError:
try: try:
import urllib3 # pylint: disable=unused-import import urllib3 # pylint: disable=unused-import

View File

@ -13,8 +13,9 @@ from __future__ import annotations
import json import json
import logging import logging
import os
import struct import struct
from functools import partial import typing as t
from urllib.parse import quote from urllib.parse import quote
from .. import auth from .. import auth
@ -47,16 +48,21 @@ from ..transport.sshconn import PARAMIKO_IMPORT_ERROR, SSHHTTPAdapter
from ..transport.ssladapter import SSLHTTPAdapter from ..transport.ssladapter import SSLHTTPAdapter
from ..transport.unixconn import UnixHTTPAdapter from ..transport.unixconn import UnixHTTPAdapter
from ..utils import config, json_stream, utils from ..utils import config, json_stream, utils
from ..utils.decorators import update_headers from ..utils.decorators import minimum_version, update_headers
from ..utils.proxy import ProxyConfig from ..utils.proxy import ProxyConfig
from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
from .daemon import DaemonApiMixin
if t.TYPE_CHECKING:
from requests import Response
from ..._socket_helper import SocketLike
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class APIClient(_Session, DaemonApiMixin): class APIClient(_Session):
""" """
A low-level client for the Docker Engine API. A low-level client for the Docker Engine API.
@ -105,16 +111,16 @@ class APIClient(_Session, DaemonApiMixin):
def __init__( def __init__(
self, self,
base_url=None, base_url: str | None = None,
version=None, version: str | None = None,
timeout=DEFAULT_TIMEOUT_SECONDS, timeout: int | float = DEFAULT_TIMEOUT_SECONDS,
tls=False, tls: bool | TLSConfig = False,
user_agent=DEFAULT_USER_AGENT, user_agent: str = DEFAULT_USER_AGENT,
num_pools=None, num_pools: int | None = None,
credstore_env=None, credstore_env: dict[str, str] | None = None,
use_ssh_client=False, use_ssh_client: bool = False,
max_pool_size=DEFAULT_MAX_POOL_SIZE, max_pool_size: int = DEFAULT_MAX_POOL_SIZE,
): ) -> None:
super().__init__() super().__init__()
fail_on_missing_imports() fail_on_missing_imports()
@ -124,7 +130,6 @@ class APIClient(_Session, DaemonApiMixin):
"If using TLS, the base_url argument must be provided." "If using TLS, the base_url argument must be provided."
) )
self.base_url = base_url
self.timeout = timeout self.timeout = timeout
self.headers["User-Agent"] = user_agent self.headers["User-Agent"] = user_agent
@ -145,6 +150,7 @@ class APIClient(_Session, DaemonApiMixin):
self.credstore_env = credstore_env self.credstore_env = credstore_env
base_url = utils.parse_host(base_url, IS_WINDOWS_PLATFORM, tls=bool(tls)) base_url = utils.parse_host(base_url, IS_WINDOWS_PLATFORM, tls=bool(tls))
self.base_url = base_url
# SSH has a different default for num_pools to all other adapters # SSH has a different default for num_pools to all other adapters
num_pools = ( num_pools = (
num_pools or DEFAULT_NUM_POOLS_SSH num_pools or DEFAULT_NUM_POOLS_SSH
@ -152,6 +158,9 @@ class APIClient(_Session, DaemonApiMixin):
else DEFAULT_NUM_POOLS else DEFAULT_NUM_POOLS
) )
self._custom_adapter: (
UnixHTTPAdapter | NpipeHTTPAdapter | SSHHTTPAdapter | SSLHTTPAdapter | None
) = None
if base_url.startswith("http+unix://"): if base_url.startswith("http+unix://"):
self._custom_adapter = UnixHTTPAdapter( self._custom_adapter = UnixHTTPAdapter(
base_url, base_url,
@ -223,7 +232,7 @@ class APIClient(_Session, DaemonApiMixin):
f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library." f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library."
) )
def _retrieve_server_version(self): def _retrieve_server_version(self) -> str:
try: try:
version_result = self.version(api_version=False) version_result = self.version(api_version=False)
except Exception as e: except Exception as e:
@ -242,54 +251,87 @@ class APIClient(_Session, DaemonApiMixin):
f"Error while fetching server API version: {e}. Response seems to be broken." f"Error while fetching server API version: {e}. Response seems to be broken."
) from e ) from e
def _set_request_timeout(self, kwargs): def _set_request_timeout(self, kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
"""Prepare the kwargs for an HTTP request by inserting the timeout """Prepare the kwargs for an HTTP request by inserting the timeout
parameter, if not already present.""" parameter, if not already present."""
kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("timeout", self.timeout)
return kwargs return kwargs
@update_headers @update_headers
def _post(self, url, **kwargs): def _post(self, url: str, **kwargs):
return self.post(url, **self._set_request_timeout(kwargs)) return self.post(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _get(self, url, **kwargs): def _get(self, url: str, **kwargs):
return self.get(url, **self._set_request_timeout(kwargs)) return self.get(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _head(self, url, **kwargs): def _head(self, url: str, **kwargs):
return self.head(url, **self._set_request_timeout(kwargs)) return self.head(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _put(self, url, **kwargs): def _put(self, url: str, **kwargs):
return self.put(url, **self._set_request_timeout(kwargs)) return self.put(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _delete(self, url, **kwargs): def _delete(self, url: str, **kwargs):
return self.delete(url, **self._set_request_timeout(kwargs)) return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt, *args, **kwargs): def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
for arg in args: for arg in args:
if not isinstance(arg, str): if not isinstance(arg, str):
raise ValueError( raise ValueError(
f"Expected a string but found {arg} ({type(arg)}) instead" f"Expected a string but found {arg} ({type(arg)}) instead"
) )
quote_f = partial(quote, safe="/:") q_args = [quote(arg, safe="/:") for arg in args]
args = map(quote_f, args)
if kwargs.get("versioned_api", True): if versioned_api:
return f"{self.base_url}/v{self._version}{pathfmt.format(*args)}" return f"{self.base_url}/v{self._version}{pathfmt.format(*q_args)}"
return f"{self.base_url}{pathfmt.format(*args)}" return f"{self.base_url}{pathfmt.format(*q_args)}"
def _raise_for_status(self, response): def _raise_for_status(self, response: Response) -> None:
"""Raises stored :class:`APIError`, if one occurred.""" """Raises stored :class:`APIError`, if one occurred."""
try: try:
response.raise_for_status() response.raise_for_status()
except _HTTPError as e: except _HTTPError as e:
create_api_error_from_http_exception(e) create_api_error_from_http_exception(e)
def _result(self, response, get_json=False, get_binary=False): @t.overload
def _result(
self,
response: Response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[False] = False,
) -> str: ...
@t.overload
def _result(
self,
response: Response,
*,
get_json: t.Literal[True],
get_binary: t.Literal[False] = False,
) -> t.Any: ...
@t.overload
def _result(
self,
response: Response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[True],
) -> bytes: ...
@t.overload
def _result(
self, response: Response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes: ...
def _result(
self, response: Response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes:
if get_json and get_binary: if get_json and get_binary:
raise AssertionError("json and binary must not be both True") raise AssertionError("json and binary must not be both True")
self._raise_for_status(response) self._raise_for_status(response)
@ -300,10 +342,12 @@ class APIClient(_Session, DaemonApiMixin):
return response.content return response.content
return response.text return response.text
def _post_json(self, url, data, **kwargs): def _post_json(
self, url: str, data: dict[str, str | None] | t.Any, **kwargs
) -> Response:
# Go <1.1 cannot unserialize null to a string # Go <1.1 cannot unserialize null to a string
# so we do this disgusting thing here. # so we do this disgusting thing here.
data2 = {} data2: dict[str, t.Any] = {}
if data is not None and isinstance(data, dict): if data is not None and isinstance(data, dict):
for k, v in data.items(): for k, v in data.items():
if v is not None: if v is not None:
@ -316,19 +360,19 @@ class APIClient(_Session, DaemonApiMixin):
kwargs["headers"]["Content-Type"] = "application/json" kwargs["headers"]["Content-Type"] = "application/json"
return self._post(url, data=json.dumps(data2), **kwargs) return self._post(url, data=json.dumps(data2), **kwargs)
def _attach_params(self, override=None): def _attach_params(self, override: dict[str, int] | None = None) -> dict[str, int]:
return override or {"stdout": 1, "stderr": 1, "stream": 1} return override or {"stdout": 1, "stderr": 1, "stream": 1}
def _get_raw_response_socket(self, response): def _get_raw_response_socket(self, response: Response) -> SocketLike:
self._raise_for_status(response) self._raise_for_status(response)
if self.base_url == "http+docker://localnpipe": if self.base_url == "http+docker://localnpipe":
sock = response.raw._fp.fp.raw.sock sock = response.raw._fp.fp.raw.sock # type: ignore[union-attr]
elif self.base_url.startswith("http+docker://ssh"): elif self.base_url.startswith("http+docker://ssh"):
sock = response.raw._fp.fp.channel sock = response.raw._fp.fp.channel # type: ignore[union-attr]
else: else:
sock = response.raw._fp.fp.raw sock = response.raw._fp.fp.raw # type: ignore[union-attr]
if self.base_url.startswith("https://"): if self.base_url.startswith("https://"):
sock = sock._sock sock = sock._sock # type: ignore[union-attr]
try: try:
# Keep a reference to the response to stop it being garbage # Keep a reference to the response to stop it being garbage
# collected. If the response is garbage collected, it will # collected. If the response is garbage collected, it will
@ -341,12 +385,26 @@ class APIClient(_Session, DaemonApiMixin):
return sock return sock
def _stream_helper(self, response, decode=False): @t.overload
def _stream_helper(
self, response: Response, *, decode: t.Literal[False] = False
) -> t.Generator[bytes]: ...
@t.overload
def _stream_helper(
self, response: Response, *, decode: t.Literal[True]
) -> t.Generator[t.Any]: ...
def _stream_helper(
self, response: Response, *, decode: bool = False
) -> t.Generator[t.Any]:
"""Generator for data coming from a chunked-encoded HTTP response.""" """Generator for data coming from a chunked-encoded HTTP response."""
if response.raw._fp.chunked: if response.raw._fp.chunked: # type: ignore[union-attr]
if decode: if decode:
yield from json_stream.json_stream(self._stream_helper(response, False)) yield from json_stream.json_stream(
self._stream_helper(response, decode=False)
)
else: else:
reader = response.raw reader = response.raw
while not reader.closed: while not reader.closed:
@ -354,15 +412,15 @@ class APIClient(_Session, DaemonApiMixin):
data = reader.read(1) data = reader.read(1)
if not data: if not data:
break break
if reader._fp.chunk_left: if reader._fp.chunk_left: # type: ignore[union-attr]
data += reader.read(reader._fp.chunk_left) data += reader.read(reader._fp.chunk_left) # type: ignore[union-attr]
yield data yield data
else: else:
# Response is not chunked, meaning we probably # Response is not chunked, meaning we probably
# encountered an error immediately # encountered an error immediately
yield self._result(response, get_json=decode) yield self._result(response, get_json=decode)
def _multiplexed_buffer_helper(self, response): def _multiplexed_buffer_helper(self, response: Response) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks read from a buffered """A generator of multiplexed data blocks read from a buffered
response.""" response."""
buf = self._result(response, get_binary=True) buf = self._result(response, get_binary=True)
@ -378,7 +436,9 @@ class APIClient(_Session, DaemonApiMixin):
walker = end walker = end
yield buf[start:end] yield buf[start:end]
def _multiplexed_response_stream_helper(self, response): def _multiplexed_response_stream_helper(
self, response: Response
) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks coming from a response """A generator of multiplexed data blocks coming from a response
stream.""" stream."""
@ -399,7 +459,19 @@ class APIClient(_Session, DaemonApiMixin):
break break
yield data yield data
def _stream_raw_result(self, response, chunk_size=1, decode=True): @t.overload
def _stream_raw_result(
self, response: Response, *, chunk_size: int = 1, decode: t.Literal[True] = True
) -> t.Generator[str]: ...
@t.overload
def _stream_raw_result(
self, response: Response, *, chunk_size: int = 1, decode: t.Literal[False]
) -> t.Generator[bytes]: ...
def _stream_raw_result(
self, response: Response, *, chunk_size: int = 1, decode: bool = True
) -> t.Generator[str | bytes]:
"""Stream result for TTY-enabled container and raw binary data""" """Stream result for TTY-enabled container and raw binary data"""
self._raise_for_status(response) self._raise_for_status(response)
@ -410,14 +482,81 @@ class APIClient(_Session, DaemonApiMixin):
yield from response.iter_content(chunk_size, decode) yield from response.iter_content(chunk_size, decode)
def _read_from_socket(self, response, stream, tty=True, demux=False): @t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
) -> t.Generator[bytes]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
) -> bytes: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> tuple[bytes, None]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
) -> tuple[bytes, bytes]: ...
@t.overload
def _read_from_socket(
self, response: Response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any: ...
def _read_from_socket(
self, response: Response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any:
"""Consume all data from the socket, close the response and return the """Consume all data from the socket, close the response and return the
data. If stream=True, then a generator is returned instead and the data. If stream=True, then a generator is returned instead and the
caller is responsible for closing the response. caller is responsible for closing the response.
""" """
socket = self._get_raw_response_socket(response) socket = self._get_raw_response_socket(response)
gen = frames_iter(socket, tty) gen: t.Generator = frames_iter(socket, tty)
if demux: if demux:
# The generator will output tuples (stdout, stderr) # The generator will output tuples (stdout, stderr)
@ -434,7 +573,7 @@ class APIClient(_Session, DaemonApiMixin):
finally: finally:
response.close() response.close()
def _disable_socket_timeout(self, socket): def _disable_socket_timeout(self, socket: SocketLike) -> None:
"""Depending on the combination of python version and whether we are """Depending on the combination of python version and whether we are
connecting over http or https, we might need to access _sock, which connecting over http or https, we might need to access _sock, which
may or may not exist; or we may need to just settimeout on socket may or may not exist; or we may need to just settimeout on socket
@ -451,18 +590,38 @@ class APIClient(_Session, DaemonApiMixin):
if not hasattr(s, "settimeout"): if not hasattr(s, "settimeout"):
continue continue
timeout = -1 timeout: int | float | None = -1
if hasattr(s, "gettimeout"): if hasattr(s, "gettimeout"):
timeout = s.gettimeout() timeout = s.gettimeout() # type: ignore[union-attr]
# Do not change the timeout if it is already disabled. # Do not change the timeout if it is already disabled.
if timeout is None or timeout == 0.0: if timeout is None or timeout == 0.0:
continue continue
s.settimeout(None) s.settimeout(None) # type: ignore[union-attr]
def _get_result_tty(self, stream, res, is_tty): @t.overload
def _get_result_tty(
self, stream: t.Literal[True], res: Response, is_tty: t.Literal[True]
) -> t.Generator[str]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[True], res: Response, is_tty: t.Literal[False]
) -> t.Generator[bytes]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res: Response, is_tty: t.Literal[True]
) -> bytes: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res: Response, is_tty: t.Literal[False]
) -> bytes: ...
def _get_result_tty(self, stream: bool, res: Response, is_tty: bool) -> t.Any:
# We should also use raw streaming (without keep-alive) # We should also use raw streaming (without keep-alive)
# if we are dealing with a tty-enabled container. # if we are dealing with a tty-enabled container.
if is_tty: if is_tty:
@ -478,11 +637,11 @@ class APIClient(_Session, DaemonApiMixin):
return self._multiplexed_response_stream_helper(res) return self._multiplexed_response_stream_helper(res)
return sep.join(list(self._multiplexed_buffer_helper(res))) return sep.join(list(self._multiplexed_buffer_helper(res)))
def _unmount(self, *args): def _unmount(self, *args) -> None:
for proto in args: for proto in args:
self.adapters.pop(proto) self.adapters.pop(proto)
def get_adapter(self, url): def get_adapter(self, url: str):
try: try:
return super().get_adapter(url) return super().get_adapter(url)
except _InvalidSchema as e: except _InvalidSchema as e:
@ -491,10 +650,10 @@ class APIClient(_Session, DaemonApiMixin):
raise e raise e
@property @property
def api_version(self): def api_version(self) -> str:
return self._version return self._version
def reload_config(self, dockercfg_path=None): def reload_config(self, dockercfg_path: str | None = None) -> None:
""" """
Force a reload of the auth configuration Force a reload of the auth configuration
@ -510,7 +669,7 @@ class APIClient(_Session, DaemonApiMixin):
dockercfg_path, credstore_env=self.credstore_env dockercfg_path, credstore_env=self.credstore_env
) )
def _set_auth_headers(self, headers): def _set_auth_headers(self, headers: dict[str, str | bytes]) -> None:
log.debug("Looking for auth config") log.debug("Looking for auth config")
# If we do not have any auth data so far, try reloading the config # If we do not have any auth data so far, try reloading the config
@ -537,57 +696,62 @@ class APIClient(_Session, DaemonApiMixin):
else: else:
log.debug("No auth config found") log.debug("No auth config found")
def get_binary(self, pathfmt, *args, **kwargs): def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_binary=True, get_binary=True,
) )
def get_json(self, pathfmt, *args, **kwargs): def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def get_text(self, pathfmt, *args, **kwargs): def get_text(self, pathfmt: str, *args: str, **kwargs) -> str:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def get_raw_stream(self, pathfmt, *args, **kwargs): def get_raw_stream(
chunk_size = kwargs.pop("chunk_size", DEFAULT_DATA_CHUNK_SIZE) self,
pathfmt: str,
*args: str,
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
**kwargs,
) -> t.Generator[bytes]:
res = self._get( res = self._get(
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
) )
self._raise_for_status(res) self._raise_for_status(res)
return self._stream_raw_result(res, chunk_size, False) return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
def delete_call(self, pathfmt, *args, **kwargs): def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status( self._raise_for_status(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def delete_json(self, pathfmt, *args, **kwargs): def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result( return self._result(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def post_call(self, pathfmt, *args, **kwargs): def post_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status( self._raise_for_status(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def post_json(self, pathfmt, *args, **kwargs): def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None:
data = kwargs.pop("data", None)
self._raise_for_status( self._raise_for_status(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
) )
) )
def post_json_to_binary(self, pathfmt, *args, **kwargs): def post_json_to_binary(
data = kwargs.pop("data", None) self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> bytes:
return self._result( return self._result(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -595,8 +759,9 @@ class APIClient(_Session, DaemonApiMixin):
get_binary=True, get_binary=True,
) )
def post_json_to_json(self, pathfmt, *args, **kwargs): def post_json_to_json(
data = kwargs.pop("data", None) self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> t.Any:
return self._result( return self._result(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -604,17 +769,24 @@ class APIClient(_Session, DaemonApiMixin):
get_json=True, get_json=True,
) )
def post_json_to_text(self, pathfmt, *args, **kwargs): def post_json_to_text(
data = kwargs.pop("data", None) self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> str:
return self._result( return self._result(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
), ),
) )
def post_json_to_stream_socket(self, pathfmt, *args, **kwargs): def post_json_to_stream_socket(
data = kwargs.pop("data", None) self,
headers = (kwargs.pop("headers", None) or {}).copy() pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
**kwargs,
) -> SocketLike:
headers = headers.copy() if headers else {}
headers.update( headers.update(
{ {
"Connection": "Upgrade", "Connection": "Upgrade",
@ -631,18 +803,102 @@ class APIClient(_Session, DaemonApiMixin):
) )
) )
def post_json_to_stream(self, pathfmt, *args, **kwargs): @t.overload
data = kwargs.pop("data", None) def post_json_to_stream(
headers = (kwargs.pop("headers", None) or {}).copy() self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> t.Generator[bytes]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> bytes: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, None]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, bytes]: ...
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: bool = False,
demux: bool = False,
tty: bool = False,
**kwargs,
) -> t.Any:
headers = headers.copy() if headers else {}
headers.update( headers.update(
{ {
"Connection": "Upgrade", "Connection": "Upgrade",
"Upgrade": "tcp", "Upgrade": "tcp",
} }
) )
stream = kwargs.pop("stream", False)
demux = kwargs.pop("demux", False)
tty = kwargs.pop("tty", False)
return self._read_from_socket( return self._read_from_socket(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), self._url(pathfmt, *args, versioned_api=True),
@ -651,13 +907,133 @@ class APIClient(_Session, DaemonApiMixin):
stream=True, stream=True,
**kwargs, **kwargs,
), ),
stream, stream=stream,
tty=tty, tty=tty,
demux=demux, demux=demux,
) )
def post_to_json(self, pathfmt, *args, **kwargs): def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result( return self._result(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
@minimum_version("1.25")
def df(self) -> dict[str, t.Any]:
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self) -> dict[str, t.Any]:
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username: str,
password: str | None = None,
email: str | None = None,
registry: str | None = None,
reauth: bool = False,
dockercfg_path: str | None = None,
) -> dict[str, t.Any]:
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self) -> bool:
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version: bool = True) -> dict[str, t.Any]:
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -1,139 +0,0 @@
# This code is part of the Ansible collection community.docker, but is an independent component.
# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
#
# Copyright (c) 2016-2022 Docker, Inc.
#
# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
# SPDX-License-Identifier: Apache-2.0
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
from .. import auth
from ..utils.decorators import minimum_version
class DaemonApiMixin:
@minimum_version("1.25")
def df(self):
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self):
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username,
password=None,
email=None,
registry=None,
reauth=False,
dockercfg_path=None,
):
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self):
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version=True):
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import base64 import base64
import json import json
import logging import logging
import typing as t
from . import errors from . import errors
from .credentials.errors import CredentialsNotFound, StoreError from .credentials.errors import CredentialsNotFound, StoreError
@ -21,6 +22,12 @@ from .credentials.store import Store
from .utils import config from .utils import config
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
INDEX_NAME = "docker.io" INDEX_NAME = "docker.io"
INDEX_URL = f"https://index.{INDEX_NAME}/v1/" INDEX_URL = f"https://index.{INDEX_NAME}/v1/"
TOKEN_USERNAME = "<token>" TOKEN_USERNAME = "<token>"
@ -28,7 +35,7 @@ TOKEN_USERNAME = "<token>"
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def resolve_repository_name(repo_name): def resolve_repository_name(repo_name: str) -> tuple[str, str]:
if "://" in repo_name: if "://" in repo_name:
raise errors.InvalidRepository( raise errors.InvalidRepository(
f"Repository name cannot contain a scheme ({repo_name})" f"Repository name cannot contain a scheme ({repo_name})"
@ -42,14 +49,14 @@ def resolve_repository_name(repo_name):
return resolve_index_name(index_name), remote_name return resolve_index_name(index_name), remote_name
def resolve_index_name(index_name): def resolve_index_name(index_name: str) -> str:
index_name = convert_to_hostname(index_name) index_name = convert_to_hostname(index_name)
if index_name == "index." + INDEX_NAME: if index_name == "index." + INDEX_NAME:
index_name = INDEX_NAME index_name = INDEX_NAME
return index_name return index_name
def get_config_header(client, registry): def get_config_header(client: APIClient, registry: str) -> bytes | None:
log.debug("Looking for auth config") log.debug("Looking for auth config")
if not client._auth_configs or client._auth_configs.is_empty: if not client._auth_configs or client._auth_configs.is_empty:
log.debug("No auth config in memory - loading from filesystem") log.debug("No auth config in memory - loading from filesystem")
@ -69,32 +76,38 @@ def get_config_header(client, registry):
return None return None
def split_repo_name(repo_name): def split_repo_name(repo_name: str) -> tuple[str, str]:
parts = repo_name.split("/", 1) parts = repo_name.split("/", 1)
if len(parts) == 1 or ( if len(parts) == 1 or (
"." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost" "." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost"
): ):
# This is a docker index repo (ex: username/foobar or ubuntu) # This is a docker index repo (ex: username/foobar or ubuntu)
return INDEX_NAME, repo_name return INDEX_NAME, repo_name
return tuple(parts) return tuple(parts) # type: ignore
def get_credential_store(authconfig, registry): def get_credential_store(
authconfig: dict[str, t.Any] | AuthConfig, registry: str
) -> str | None:
if not isinstance(authconfig, AuthConfig): if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig) authconfig = AuthConfig(authconfig)
return authconfig.get_credential_store(registry) return authconfig.get_credential_store(registry)
class AuthConfig(dict): class AuthConfig(dict):
def __init__(self, dct, credstore_env=None): def __init__(
self, dct: dict[str, t.Any], credstore_env: dict[str, str] | None = None
):
if "auths" not in dct: if "auths" not in dct:
dct["auths"] = {} dct["auths"] = {}
self.update(dct) self.update(dct)
self._credstore_env = credstore_env self._credstore_env = credstore_env
self._stores = {} self._stores: dict[str, Store] = {}
@classmethod @classmethod
def parse_auth(cls, entries, raise_on_error=False): def parse_auth(
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False
) -> dict[str, dict[str, t.Any]]:
""" """
Parses authentication entries Parses authentication entries
@ -107,10 +120,10 @@ class AuthConfig(dict):
Authentication registry. Authentication registry.
""" """
conf = {} conf: dict[str, dict[str, t.Any]] = {}
for registry, entry in entries.items(): for registry, entry in entries.items():
if not isinstance(entry, dict): if not isinstance(entry, dict):
log.debug("Config entry for key %s is not auth config", registry) log.debug("Config entry for key %s is not auth config", registry) # type: ignore
# We sometimes fall back to parsing the whole config as if it # We sometimes fall back to parsing the whole config as if it
# was the auth config by itself, for legacy purposes. In that # was the auth config by itself, for legacy purposes. In that
# case, we fail silently and return an empty conf if any of the # case, we fail silently and return an empty conf if any of the
@ -150,7 +163,12 @@ class AuthConfig(dict):
return conf return conf
@classmethod @classmethod
def load_config(cls, config_path, config_dict, credstore_env=None): def load_config(
cls,
config_path: str | None,
config_dict: dict[str, t.Any] | None,
credstore_env: dict[str, str] | None = None,
) -> t.Self:
""" """
Loads authentication data from a Docker configuration file in the given Loads authentication data from a Docker configuration file in the given
root directory or if config_path is passed use given path. root directory or if config_path is passed use given path.
@ -196,22 +214,24 @@ class AuthConfig(dict):
return cls({"auths": cls.parse_auth(config_dict)}, credstore_env) return cls({"auths": cls.parse_auth(config_dict)}, credstore_env)
@property @property
def auths(self): def auths(self) -> dict[str, dict[str, t.Any]]:
return self.get("auths", {}) return self.get("auths", {})
@property @property
def creds_store(self): def creds_store(self) -> str | None:
return self.get("credsStore", None) return self.get("credsStore", None)
@property @property
def cred_helpers(self): def cred_helpers(self) -> dict[str, t.Any]:
return self.get("credHelpers", {}) return self.get("credHelpers", {})
@property @property
def is_empty(self): def is_empty(self) -> bool:
return not self.auths and not self.creds_store and not self.cred_helpers return not self.auths and not self.creds_store and not self.cred_helpers
def resolve_authconfig(self, registry=None): def resolve_authconfig(
self, registry: str | None = None
) -> dict[str, t.Any] | None:
""" """
Returns the authentication data from the given auth configuration for a Returns the authentication data from the given auth configuration for a
specific registry. As with the Docker client, legacy entries in the specific registry. As with the Docker client, legacy entries in the
@ -244,7 +264,9 @@ class AuthConfig(dict):
log.debug("No entry found") log.debug("No entry found")
return None return None
def _resolve_authconfig_credstore(self, registry, credstore_name): def _resolve_authconfig_credstore(
self, registry: str | None, credstore_name: str
) -> dict[str, t.Any] | None:
if not registry or registry == INDEX_NAME: if not registry or registry == INDEX_NAME:
# The ecosystem is a little schizophrenic with index.docker.io VS # The ecosystem is a little schizophrenic with index.docker.io VS
# docker.io - in that case, it seems the full URL is necessary. # docker.io - in that case, it seems the full URL is necessary.
@ -272,19 +294,19 @@ class AuthConfig(dict):
except StoreError as e: except StoreError as e:
raise errors.DockerException(f"Credentials store error: {e}") raise errors.DockerException(f"Credentials store error: {e}")
def _get_store_instance(self, name): def _get_store_instance(self, name: str):
if name not in self._stores: if name not in self._stores:
self._stores[name] = Store(name, environment=self._credstore_env) self._stores[name] = Store(name, environment=self._credstore_env)
return self._stores[name] return self._stores[name]
def get_credential_store(self, registry): def get_credential_store(self, registry: str | None) -> str | None:
if not registry or registry == INDEX_NAME: if not registry or registry == INDEX_NAME:
registry = INDEX_URL registry = INDEX_URL
return self.cred_helpers.get(registry) or self.creds_store return self.cred_helpers.get(registry) or self.creds_store
def get_all_credentials(self): def get_all_credentials(self) -> dict[str, dict[str, t.Any] | None]:
auth_data = self.auths.copy() auth_data: dict[str, dict[str, t.Any] | None] = self.auths.copy() # type: ignore
if self.creds_store: if self.creds_store:
# Retrieve all credentials from the default store # Retrieve all credentials from the default store
store = self._get_store_instance(self.creds_store) store = self._get_store_instance(self.creds_store)
@ -299,21 +321,23 @@ class AuthConfig(dict):
return auth_data return auth_data
def add_auth(self, reg, data): def add_auth(self, reg: str, data: dict[str, t.Any]) -> None:
self["auths"][reg] = data self["auths"][reg] = data
def resolve_authconfig(authconfig, registry=None, credstore_env=None): def resolve_authconfig(
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None
):
if not isinstance(authconfig, AuthConfig): if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig, credstore_env) authconfig = AuthConfig(authconfig, credstore_env)
return authconfig.resolve_authconfig(registry) return authconfig.resolve_authconfig(registry)
def convert_to_hostname(url): def convert_to_hostname(url: str) -> str:
return url.replace("http://", "").replace("https://", "").split("/", 1)[0] return url.replace("http://", "").replace("https://", "").split("/", 1)[0]
def decode_auth(auth): def decode_auth(auth: str | bytes) -> tuple[str, str]:
if isinstance(auth, str): if isinstance(auth, str):
auth = auth.encode("ascii") auth = auth.encode("ascii")
s = base64.b64decode(auth) s = base64.b64decode(auth)
@ -321,12 +345,14 @@ def decode_auth(auth):
return login.decode("utf8"), pwd.decode("utf8") return login.decode("utf8"), pwd.decode("utf8")
def encode_header(auth): def encode_header(auth: dict[str, t.Any]) -> bytes:
auth_json = json.dumps(auth).encode("ascii") auth_json = json.dumps(auth).encode("ascii")
return base64.urlsafe_b64encode(auth_json) return base64.urlsafe_b64encode(auth_json)
def parse_auth(entries, raise_on_error=False): def parse_auth(
entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
) -> dict[str, dict[str, t.Any]]:
""" """
Parses authentication entries Parses authentication entries
@ -342,11 +368,15 @@ def parse_auth(entries, raise_on_error=False):
return AuthConfig.parse_auth(entries, raise_on_error) return AuthConfig.parse_auth(entries, raise_on_error)
def load_config(config_path=None, config_dict=None, credstore_env=None): def load_config(
config_path: str | None = None,
config_dict: dict[str, t.Any] | None = None,
credstore_env: dict[str, str] | None = None,
) -> AuthConfig:
return AuthConfig.load_config(config_path, config_dict, credstore_env) return AuthConfig.load_config(config_path, config_dict, credstore_env)
def _load_legacy_config(config_file): def _load_legacy_config(config_file: str) -> dict[str, dict[str, t.Any]]:
log.debug("Attempting to parse legacy auth file format") log.debug("Attempting to parse legacy auth file format")
try: try:
data = [] data = []

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json import json
import os import os
import typing as t
from .. import errors from .. import errors
from .config import ( from .config import (
@ -24,7 +25,11 @@ from .config import (
from .context import Context from .context import Context
def create_default_context(): if t.TYPE_CHECKING:
from ..tls import TLSConfig
def create_default_context() -> Context:
host = None host = None
if os.environ.get("DOCKER_HOST"): if os.environ.get("DOCKER_HOST"):
host = os.environ.get("DOCKER_HOST") host = os.environ.get("DOCKER_HOST")
@ -42,7 +47,7 @@ class ContextAPI:
DEFAULT_CONTEXT = None DEFAULT_CONTEXT = None
@classmethod @classmethod
def get_default_context(cls): def get_default_context(cls) -> Context:
context = cls.DEFAULT_CONTEXT context = cls.DEFAULT_CONTEXT
if context is None: if context is None:
context = create_default_context() context = create_default_context()
@ -52,13 +57,13 @@ class ContextAPI:
@classmethod @classmethod
def create_context( def create_context(
cls, cls,
name, name: str,
orchestrator=None, orchestrator: str | None = None,
host=None, host: str | None = None,
tls_cfg=None, tls_cfg: TLSConfig | None = None,
default_namespace=None, default_namespace: str | None = None,
skip_tls_verify=False, skip_tls_verify: bool = False,
): ) -> Context:
"""Creates a new context. """Creates a new context.
Returns: Returns:
(Context): a Context object. (Context): a Context object.
@ -108,7 +113,7 @@ class ContextAPI:
return ctx return ctx
@classmethod @classmethod
def get_context(cls, name=None): def get_context(cls, name: str | None = None) -> Context | None:
"""Retrieves a context object. """Retrieves a context object.
Args: Args:
name (str): The name of the context name (str): The name of the context
@ -136,7 +141,7 @@ class ContextAPI:
return Context.load_context(name) return Context.load_context(name)
@classmethod @classmethod
def contexts(cls): def contexts(cls) -> list[Context]:
"""Context list. """Context list.
Returns: Returns:
(Context): List of context objects. (Context): List of context objects.
@ -170,7 +175,7 @@ class ContextAPI:
return contexts return contexts
@classmethod @classmethod
def get_current_context(cls): def get_current_context(cls) -> Context | None:
"""Get current context. """Get current context.
Returns: Returns:
(Context): current context object. (Context): current context object.
@ -178,7 +183,7 @@ class ContextAPI:
return cls.get_context() return cls.get_context()
@classmethod @classmethod
def set_current_context(cls, name="default"): def set_current_context(cls, name: str = "default") -> None:
ctx = cls.get_context(name) ctx = cls.get_context(name)
if not ctx: if not ctx:
raise errors.ContextNotFound(name) raise errors.ContextNotFound(name)
@ -188,7 +193,7 @@ class ContextAPI:
raise errors.ContextException(f"Failed to set current context: {err}") raise errors.ContextException(f"Failed to set current context: {err}")
@classmethod @classmethod
def remove_context(cls, name): def remove_context(cls, name: str) -> None:
"""Remove a context. Similar to the ``docker context rm`` command. """Remove a context. Similar to the ``docker context rm`` command.
Args: Args:
@ -220,7 +225,7 @@ class ContextAPI:
ctx.remove() ctx.remove()
@classmethod @classmethod
def inspect_context(cls, name="default"): def inspect_context(cls, name: str = "default") -> dict[str, t.Any]:
"""Inspect a context. Similar to the ``docker context inspect`` command. """Inspect a context. Similar to the ``docker context inspect`` command.
Args: Args:

View File

@ -23,7 +23,7 @@ from ..utils.utils import parse_host
METAFILE = "meta.json" METAFILE = "meta.json"
def get_current_context_name_with_source(): def get_current_context_name_with_source() -> tuple[str, str]:
if os.environ.get("DOCKER_HOST"): if os.environ.get("DOCKER_HOST"):
return "default", "DOCKER_HOST environment variable set" return "default", "DOCKER_HOST environment variable set"
if os.environ.get("DOCKER_CONTEXT"): if os.environ.get("DOCKER_CONTEXT"):
@ -41,11 +41,11 @@ def get_current_context_name_with_source():
return "default", "fallback value" return "default", "fallback value"
def get_current_context_name(): def get_current_context_name() -> str:
return get_current_context_name_with_source()[0] return get_current_context_name_with_source()[0]
def write_context_name_to_docker_config(name=None): def write_context_name_to_docker_config(name: str | None = None) -> Exception | None:
if name == "default": if name == "default":
name = None name = None
docker_cfg_path = find_config_file() docker_cfg_path = find_config_file()
@ -62,44 +62,45 @@ def write_context_name_to_docker_config(name=None):
elif name: elif name:
config["currentContext"] = name config["currentContext"] = name
else: else:
return return None
if not docker_cfg_path: if not docker_cfg_path:
docker_cfg_path = get_default_config_file() docker_cfg_path = get_default_config_file()
try: try:
with open(docker_cfg_path, "wt", encoding="utf-8") as f: with open(docker_cfg_path, "wt", encoding="utf-8") as f:
json.dump(config, f, indent=4) json.dump(config, f, indent=4)
return None
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e: # pylint: disable=broad-exception-caught
return e return e
def get_context_id(name): def get_context_id(name: str) -> str:
return hashlib.sha256(name.encode("utf-8")).hexdigest() return hashlib.sha256(name.encode("utf-8")).hexdigest()
def get_context_dir(): def get_context_dir() -> str:
docker_cfg_path = find_config_file() or get_default_config_file() docker_cfg_path = find_config_file() or get_default_config_file()
return os.path.join(os.path.dirname(docker_cfg_path), "contexts") return os.path.join(os.path.dirname(docker_cfg_path), "contexts")
def get_meta_dir(name=None): def get_meta_dir(name: str | None = None) -> str:
meta_dir = os.path.join(get_context_dir(), "meta") meta_dir = os.path.join(get_context_dir(), "meta")
if name: if name:
return os.path.join(meta_dir, get_context_id(name)) return os.path.join(meta_dir, get_context_id(name))
return meta_dir return meta_dir
def get_meta_file(name): def get_meta_file(name) -> str:
return os.path.join(get_meta_dir(name), METAFILE) return os.path.join(get_meta_dir(name), METAFILE)
def get_tls_dir(name=None, endpoint=""): def get_tls_dir(name: str | None = None, endpoint: str = "") -> str:
context_dir = get_context_dir() context_dir = get_context_dir()
if name: if name:
return os.path.join(context_dir, "tls", get_context_id(name), endpoint) return os.path.join(context_dir, "tls", get_context_id(name), endpoint)
return os.path.join(context_dir, "tls") return os.path.join(context_dir, "tls")
def get_context_host(path=None, tls=False): def get_context_host(path: str | None = None, tls: bool = False) -> str:
host = parse_host(path, IS_WINDOWS_PLATFORM, tls) host = parse_host(path, IS_WINDOWS_PLATFORM, tls)
if host == DEFAULT_UNIX_SOCKET: if host == DEFAULT_UNIX_SOCKET:
# remove http+ from default docker socket url # remove http+ from default docker socket url

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json import json
import os import os
import typing as t
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from ..errors import ContextException from ..errors import ContextException
@ -33,21 +34,21 @@ class Context:
def __init__( def __init__(
self, self,
name, name: str,
orchestrator=None, orchestrator: str | None = None,
host=None, host: str | None = None,
endpoints=None, endpoints: dict[str, dict[str, t.Any]] | None = None,
skip_tls_verify=False, skip_tls_verify: bool = False,
tls=False, tls: bool = False,
description=None, description: str | None = None,
): ) -> None:
if not name: if not name:
raise ValueError("Name not provided") raise ValueError("Name not provided")
self.name = name self.name = name
self.context_type = None self.context_type = None
self.orchestrator = orchestrator self.orchestrator = orchestrator
self.endpoints = {} self.endpoints = {}
self.tls_cfg = {} self.tls_cfg: dict[str, TLSConfig] = {}
self.meta_path = IN_MEMORY self.meta_path = IN_MEMORY
self.tls_path = IN_MEMORY self.tls_path = IN_MEMORY
self.description = description self.description = description
@ -89,12 +90,12 @@ class Context:
def set_endpoint( def set_endpoint(
self, self,
name="docker", name: str = "docker",
host=None, host: str | None = None,
tls_cfg=None, tls_cfg: TLSConfig | None = None,
skip_tls_verify=False, skip_tls_verify: bool = False,
def_namespace=None, def_namespace: str | None = None,
): ) -> None:
self.endpoints[name] = { self.endpoints[name] = {
"Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None), "Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None),
"SkipTLSVerify": skip_tls_verify, "SkipTLSVerify": skip_tls_verify,
@ -105,11 +106,11 @@ class Context:
if tls_cfg: if tls_cfg:
self.tls_cfg[name] = tls_cfg self.tls_cfg[name] = tls_cfg
def inspect(self): def inspect(self) -> dict[str, t.Any]:
return self() return self()
@classmethod @classmethod
def load_context(cls, name): def load_context(cls, name: str) -> t.Self | None:
meta = Context._load_meta(name) meta = Context._load_meta(name)
if meta: if meta:
instance = cls( instance = cls(
@ -125,12 +126,12 @@ class Context:
return None return None
@classmethod @classmethod
def _load_meta(cls, name): def _load_meta(cls, name: str) -> dict[str, t.Any] | None:
meta_file = get_meta_file(name) meta_file = get_meta_file(name)
if not os.path.isfile(meta_file): if not os.path.isfile(meta_file):
return None return None
metadata = {} metadata: dict[str, t.Any] = {}
try: try:
with open(meta_file, "rt", encoding="utf-8") as f: with open(meta_file, "rt", encoding="utf-8") as f:
metadata = json.load(f) metadata = json.load(f)
@ -154,7 +155,7 @@ class Context:
return metadata return metadata
def _load_certs(self): def _load_certs(self) -> None:
certs = {} certs = {}
tls_dir = get_tls_dir(self.name) tls_dir = get_tls_dir(self.name)
for endpoint in self.endpoints: for endpoint in self.endpoints:
@ -184,7 +185,7 @@ class Context:
self.tls_cfg = certs self.tls_cfg = certs
self.tls_path = tls_dir self.tls_path = tls_dir
def save(self): def save(self) -> None:
meta_dir = get_meta_dir(self.name) meta_dir = get_meta_dir(self.name)
if not os.path.isdir(meta_dir): if not os.path.isdir(meta_dir):
os.makedirs(meta_dir) os.makedirs(meta_dir)
@ -216,54 +217,54 @@ class Context:
self.meta_path = get_meta_dir(self.name) self.meta_path = get_meta_dir(self.name)
self.tls_path = get_tls_dir(self.name) self.tls_path = get_tls_dir(self.name)
def remove(self): def remove(self) -> None:
if os.path.isdir(self.meta_path): if os.path.isdir(self.meta_path):
rmtree(self.meta_path) rmtree(self.meta_path)
if os.path.isdir(self.tls_path): if os.path.isdir(self.tls_path):
rmtree(self.tls_path) rmtree(self.tls_path)
def __repr__(self): def __repr__(self) -> str:
return f"<{self.__class__.__name__}: '{self.name}'>" return f"<{self.__class__.__name__}: '{self.name}'>"
def __str__(self): def __str__(self) -> str:
return json.dumps(self.__call__(), indent=2) return json.dumps(self.__call__(), indent=2)
def __call__(self): def __call__(self) -> dict[str, t.Any]:
result = self.Metadata result = self.Metadata
result.update(self.TLSMaterial) result.update(self.TLSMaterial)
result.update(self.Storage) result.update(self.Storage)
return result return result
def is_docker_host(self): def is_docker_host(self) -> bool:
return self.context_type is None return self.context_type is None
@property @property
def Name(self): # pylint: disable=invalid-name def Name(self) -> str: # pylint: disable=invalid-name
return self.name return self.name
@property @property
def Host(self): # pylint: disable=invalid-name def Host(self) -> str | None: # pylint: disable=invalid-name
if not self.orchestrator or self.orchestrator == "swarm": if not self.orchestrator or self.orchestrator == "swarm":
endpoint = self.endpoints.get("docker", None) endpoint = self.endpoints.get("docker", None)
if endpoint: if endpoint:
return endpoint.get("Host", None) return endpoint.get("Host", None) # type: ignore
return None return None
return self.endpoints[self.orchestrator].get("Host", None) return self.endpoints[self.orchestrator].get("Host", None) # type: ignore
@property @property
def Orchestrator(self): # pylint: disable=invalid-name def Orchestrator(self) -> str | None: # pylint: disable=invalid-name
return self.orchestrator return self.orchestrator
@property @property
def Metadata(self): # pylint: disable=invalid-name def Metadata(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
meta = {} meta: dict[str, t.Any] = {}
if self.orchestrator: if self.orchestrator:
meta = {"StackOrchestrator": self.orchestrator} meta = {"StackOrchestrator": self.orchestrator}
return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints} return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints}
@property @property
def TLSConfig(self): # pylint: disable=invalid-name def TLSConfig(self) -> TLSConfig | None: # pylint: disable=invalid-name
key = self.orchestrator key = self.orchestrator
if not key or key == "swarm": if not key or key == "swarm":
key = "docker" key = "docker"
@ -272,13 +273,15 @@ class Context:
return None return None
@property @property
def TLSMaterial(self): # pylint: disable=invalid-name def TLSMaterial(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
certs = {} certs: dict[str, t.Any] = {}
for endpoint, tls in self.tls_cfg.items(): for endpoint, tls in self.tls_cfg.items():
cert, key = tls.cert paths = [tls.ca_cert, *tls.cert] if tls.cert else [tls.ca_cert]
certs[endpoint] = list(map(os.path.basename, [tls.ca_cert, cert, key])) certs[endpoint] = [
os.path.basename(path) if path else None for path in paths
]
return {"TLSMaterial": certs} return {"TLSMaterial": certs}
@property @property
def Storage(self): # pylint: disable=invalid-name def Storage(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}} return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}}

View File

@ -11,6 +11,12 @@
from __future__ import annotations from __future__ import annotations
import typing as t
if t.TYPE_CHECKING:
from subprocess import CalledProcessError
class StoreError(RuntimeError): class StoreError(RuntimeError):
pass pass
@ -24,7 +30,7 @@ class InitializationError(StoreError):
pass pass
def process_store_error(cpe, program): def process_store_error(cpe: CalledProcessError, program: str) -> StoreError:
message = cpe.output.decode("utf-8") message = cpe.output.decode("utf-8")
if "credentials not found in native keychain" in message: if "credentials not found in native keychain" in message:
return CredentialsNotFound(f"No matching credentials in {program}") return CredentialsNotFound(f"No matching credentials in {program}")

View File

@ -14,13 +14,14 @@ from __future__ import annotations
import errno import errno
import json import json
import subprocess import subprocess
import typing as t
from . import constants, errors from . import constants, errors
from .utils import create_environment_dict, find_executable from .utils import create_environment_dict, find_executable
class Store: class Store:
def __init__(self, program, environment=None): def __init__(self, program: str, environment: dict[str, str] | None = None) -> None:
"""Create a store object that acts as an interface to """Create a store object that acts as an interface to
perform the basic operations for storing, retrieving perform the basic operations for storing, retrieving
and erasing credentials using `program`. and erasing credentials using `program`.
@ -33,7 +34,7 @@ class Store:
f"{self.program} not installed or not available in PATH" f"{self.program} not installed or not available in PATH"
) )
def get(self, server): def get(self, server: str | bytes) -> dict[str, t.Any]:
"""Retrieve credentials for `server`. If no credentials are found, """Retrieve credentials for `server`. If no credentials are found,
a `StoreError` will be raised. a `StoreError` will be raised.
""" """
@ -53,7 +54,7 @@ class Store:
return result return result
def store(self, server, username, secret): def store(self, server: str, username: str, secret: str) -> bytes:
"""Store credentials for `server`. Raises a `StoreError` if an error """Store credentials for `server`. Raises a `StoreError` if an error
occurs. occurs.
""" """
@ -62,7 +63,7 @@ class Store:
).encode("utf-8") ).encode("utf-8")
return self._execute("store", data_input) return self._execute("store", data_input)
def erase(self, server): def erase(self, server: str | bytes) -> None:
"""Erase credentials for `server`. Raises a `StoreError` if an error """Erase credentials for `server`. Raises a `StoreError` if an error
occurs. occurs.
""" """
@ -70,12 +71,16 @@ class Store:
server = server.encode("utf-8") server = server.encode("utf-8")
self._execute("erase", server) self._execute("erase", server)
def list(self): def list(self) -> t.Any:
"""List stored credentials. Requires v0.4.0+ of the helper.""" """List stored credentials. Requires v0.4.0+ of the helper."""
data = self._execute("list", None) data = self._execute("list", None)
return json.loads(data.decode("utf-8")) return json.loads(data.decode("utf-8"))
def _execute(self, subcmd, data_input): def _execute(self, subcmd: str, data_input: bytes | None) -> bytes:
if self.exe is None:
raise errors.StoreError(
f"{self.program} not installed or not available in PATH"
)
output = None output = None
env = create_environment_dict(self.environment) env = create_environment_dict(self.environment)
try: try:

View File

@ -15,7 +15,7 @@ import os
from shutil import which from shutil import which
def find_executable(executable, path=None): def find_executable(executable: str, path: str | None = None) -> str | None:
""" """
As distutils.spawn.find_executable, but on Windows, look up As distutils.spawn.find_executable, but on Windows, look up
every extension declared in PATHEXT instead of just `.exe` every extension declared in PATHEXT instead of just `.exe`
@ -26,7 +26,7 @@ def find_executable(executable, path=None):
return which(executable, path=path) return which(executable, path=path)
def create_environment_dict(overrides): def create_environment_dict(overrides: dict[str, str] | None) -> dict[str, str]:
""" """
Create and return a copy of os.environ with the specified overrides Create and return a copy of os.environ with the specified overrides
""" """

View File

@ -11,6 +11,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
from ._import_helper import HTTPError as _HTTPError from ._import_helper import HTTPError as _HTTPError
@ -25,7 +27,7 @@ class DockerException(Exception):
""" """
def create_api_error_from_http_exception(e): def create_api_error_from_http_exception(e: _HTTPError) -> t.NoReturn:
""" """
Create a suitable APIError from requests.exceptions.HTTPError. Create a suitable APIError from requests.exceptions.HTTPError.
""" """
@ -52,14 +54,16 @@ class APIError(_HTTPError, DockerException):
An HTTP error from the API. An HTTP error from the API.
""" """
def __init__(self, message, response=None, explanation=None): def __init__(
self, message: str | Exception, response=None, explanation: str | None = None
) -> None:
# requests 1.2 supports response as a keyword argument, but # requests 1.2 supports response as a keyword argument, but
# requests 1.1 does not # requests 1.1 does not
super().__init__(message) super().__init__(message)
self.response = response self.response = response
self.explanation = explanation self.explanation = explanation or ""
def __str__(self): def __str__(self) -> str:
message = super().__str__() message = super().__str__()
if self.is_client_error(): if self.is_client_error():
@ -74,19 +78,20 @@ class APIError(_HTTPError, DockerException):
return message return message
@property @property
def status_code(self): def status_code(self) -> int | None:
if self.response is not None: if self.response is not None:
return self.response.status_code return self.response.status_code
return None
def is_error(self): def is_error(self) -> bool:
return self.is_client_error() or self.is_server_error() return self.is_client_error() or self.is_server_error()
def is_client_error(self): def is_client_error(self) -> bool:
if self.status_code is None: if self.status_code is None:
return False return False
return 400 <= self.status_code < 500 return 400 <= self.status_code < 500
def is_server_error(self): def is_server_error(self) -> bool:
if self.status_code is None: if self.status_code is None:
return False return False
return 500 <= self.status_code < 600 return 500 <= self.status_code < 600
@ -121,10 +126,10 @@ class DeprecatedMethod(DockerException):
class TLSParameterError(DockerException): class TLSParameterError(DockerException):
def __init__(self, msg): def __init__(self, msg: str) -> None:
self.msg = msg self.msg = msg
def __str__(self): def __str__(self) -> str:
return self.msg + ( return self.msg + (
". TLS configurations should map the Docker CLI " ". TLS configurations should map the Docker CLI "
"client configurations. See " "client configurations. See "
@ -142,7 +147,14 @@ class ContainerError(DockerException):
Represents a container that has exited with a non-zero exit code. Represents a container that has exited with a non-zero exit code.
""" """
def __init__(self, container, exit_status, command, image, stderr): def __init__(
self,
container: str,
exit_status: int,
command: list[str],
image: str,
stderr: str | None,
):
self.container = container self.container = container
self.exit_status = exit_status self.exit_status = exit_status
self.command = command self.command = command
@ -156,12 +168,12 @@ class ContainerError(DockerException):
class StreamParseError(RuntimeError): class StreamParseError(RuntimeError):
def __init__(self, reason): def __init__(self, reason: Exception) -> None:
self.msg = reason self.msg = reason
class BuildError(DockerException): class BuildError(DockerException):
def __init__(self, reason, build_log): def __init__(self, reason: str, build_log: str) -> None:
super().__init__(reason) super().__init__(reason)
self.msg = reason self.msg = reason
self.build_log = build_log self.build_log = build_log
@ -171,7 +183,7 @@ class ImageLoadError(DockerException):
pass pass
def create_unexpected_kwargs_error(name, kwargs): def create_unexpected_kwargs_error(name: str, kwargs: dict[str, t.Any]) -> TypeError:
quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)] quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)]
text = [f"{name}() "] text = [f"{name}() "]
if len(quoted_kwargs) == 1: if len(quoted_kwargs) == 1:
@ -183,42 +195,44 @@ def create_unexpected_kwargs_error(name, kwargs):
class MissingContextParameter(DockerException): class MissingContextParameter(DockerException):
def __init__(self, param): def __init__(self, param: str) -> None:
self.param = param self.param = param
def __str__(self): def __str__(self) -> str:
return f"missing parameter: {self.param}" return f"missing parameter: {self.param}"
class ContextAlreadyExists(DockerException): class ContextAlreadyExists(DockerException):
def __init__(self, name): def __init__(self, name: str) -> None:
self.name = name self.name = name
def __str__(self): def __str__(self) -> str:
return f"context {self.name} already exists" return f"context {self.name} already exists"
class ContextException(DockerException): class ContextException(DockerException):
def __init__(self, msg): def __init__(self, msg: str) -> None:
self.msg = msg self.msg = msg
def __str__(self): def __str__(self) -> str:
return self.msg return self.msg
class ContextNotFound(DockerException): class ContextNotFound(DockerException):
def __init__(self, name): def __init__(self, name: str) -> None:
self.name = name self.name = name
def __str__(self): def __str__(self) -> str:
return f"context '{self.name}' not found" return f"context '{self.name}' not found"
class MissingRequirementException(DockerException): class MissingRequirementException(DockerException):
def __init__(self, msg, requirement, import_exception): def __init__(
self, msg: str, requirement: str, import_exception: ImportError | str
) -> None:
self.msg = msg self.msg = msg
self.requirement = requirement self.requirement = requirement
self.import_exception = import_exception self.import_exception = import_exception
def __str__(self): def __str__(self) -> str:
return self.msg return self.msg

View File

@ -12,12 +12,18 @@
from __future__ import annotations from __future__ import annotations
import os import os
import ssl import typing as t
from . import errors from . import errors
from .transport.ssladapter import SSLHTTPAdapter from .transport.ssladapter import SSLHTTPAdapter
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class TLSConfig: class TLSConfig:
""" """
TLS configuration. TLS configuration.
@ -27,25 +33,22 @@ class TLSConfig:
ca_cert (str): Path to CA cert file. ca_cert (str): Path to CA cert file.
verify (bool or str): This can be ``False`` or a path to a CA cert verify (bool or str): This can be ``False`` or a path to a CA cert
file. file.
ssl_version (int): A valid `SSL version`_.
assert_hostname (bool): Verify the hostname of the server. assert_hostname (bool): Verify the hostname of the server.
.. _`SSL version`: .. _`SSL version`:
https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1 https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1
""" """
cert = None cert: tuple[str, str] | None = None
ca_cert = None ca_cert: str | None = None
verify = None verify: bool | None = None
ssl_version = None
def __init__( def __init__(
self, self,
client_cert=None, client_cert: tuple[str, str] | None = None,
ca_cert=None, ca_cert: str | None = None,
verify=None, verify: bool | None = None,
ssl_version=None, assert_hostname: bool | None = None,
assert_hostname=None,
): ):
# Argument compatibility/mapping with # Argument compatibility/mapping with
# https://docs.docker.com/engine/articles/https/ # https://docs.docker.com/engine/articles/https/
@ -55,12 +58,6 @@ class TLSConfig:
self.assert_hostname = assert_hostname self.assert_hostname = assert_hostname
# If the user provides an SSL version, we should use their preference
if ssl_version:
self.ssl_version = ssl_version
else:
self.ssl_version = ssl.PROTOCOL_TLS_CLIENT
# "client_cert" must have both or neither cert/key files. In # "client_cert" must have both or neither cert/key files. In
# either case, Alert the user when both are expected, but any are # either case, Alert the user when both are expected, but any are
# missing. # missing.
@ -90,11 +87,10 @@ class TLSConfig:
"Invalid CA certificate provided for `ca_cert`." "Invalid CA certificate provided for `ca_cert`."
) )
def configure_client(self, client): def configure_client(self, client: APIClient) -> None:
""" """
Configure a client with these TLS options. Configure a client with these TLS options.
""" """
client.ssl_version = self.ssl_version
if self.verify and self.ca_cert: if self.verify and self.ca_cert:
client.verify = self.ca_cert client.verify = self.ca_cert
@ -107,7 +103,6 @@ class TLSConfig:
client.mount( client.mount(
"https://", "https://",
SSLHTTPAdapter( SSLHTTPAdapter(
ssl_version=self.ssl_version,
assert_hostname=self.assert_hostname, assert_hostname=self.assert_hostname,
), ),
) )

View File

@ -15,7 +15,7 @@ from .._import_helper import HTTPAdapter as _HTTPAdapter
class BaseHTTPAdapter(_HTTPAdapter): class BaseHTTPAdapter(_HTTPAdapter):
def close(self): def close(self) -> None:
super().close() super().close()
if hasattr(self, "pools"): if hasattr(self, "pools"):
self.pools.clear() self.pools.clear()
@ -24,10 +24,10 @@ class BaseHTTPAdapter(_HTTPAdapter):
# https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356 # https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356
# changes requests.adapters.HTTPAdapter to no longer call get_connection() from # changes requests.adapters.HTTPAdapter to no longer call get_connection() from
# send(), but instead call _get_connection(). # send(), but instead call _get_connection().
def _get_connection(self, request, *args, **kwargs): def _get_connection(self, request, *args, **kwargs): # type: ignore
return self.get_connection(request.url, kwargs.get("proxies")) return self.get_connection(request.url, kwargs.get("proxies"))
# Fix for requests 2.32.2+: # Fix for requests 2.32.2+:
# https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05 # https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): # type: ignore
return self.get_connection(request.url, proxies) return self.get_connection(request.url, proxies)

View File

@ -23,12 +23,12 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class NpipeHTTPConnection(urllib3_connection.HTTPConnection): class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(self, npipe_path, timeout=60): def __init__(self, npipe_path: str, timeout: int | float = 60) -> None:
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.npipe_path = npipe_path self.npipe_path = npipe_path
self.timeout = timeout self.timeout = timeout
def connect(self): def connect(self) -> None:
sock = NpipeSocket() sock = NpipeSocket()
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.connect(self.npipe_path) sock.connect(self.npipe_path)
@ -36,18 +36,20 @@ class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, npipe_path, timeout=60, maxsize=10): def __init__(
self, npipe_path: str, timeout: int | float = 60, maxsize: int = 10
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.npipe_path = npipe_path self.npipe_path = npipe_path
self.timeout = timeout self.timeout = timeout
def _new_conn(self): def _new_conn(self) -> NpipeHTTPConnection:
return NpipeHTTPConnection(self.npipe_path, self.timeout) return NpipeHTTPConnection(self.npipe_path, self.timeout)
# When re-using connections, urllib3 tries to call select() on our # When re-using connections, urllib3 tries to call select() on our
# NpipeSocket instance, causing a crash. To circumvent this, we override # NpipeSocket instance, causing a crash. To circumvent this, we override
# _get_conn, where that check happens. # _get_conn, where that check happens.
def _get_conn(self, timeout): def _get_conn(self, timeout: int | float) -> NpipeHTTPConnection:
conn = None conn = None
try: try:
conn = self.pool.get(block=self.block, timeout=timeout) conn = self.pool.get(block=self.block, timeout=timeout)
@ -67,7 +69,6 @@ class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class NpipeHTTPAdapter(BaseHTTPAdapter): class NpipeHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [ __attrs__ = HTTPAdapter.__attrs__ + [
"npipe_path", "npipe_path",
"pools", "pools",
@ -77,11 +78,11 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
base_url, base_url: str,
timeout=60, timeout: int | float = 60,
pool_connections=constants.DEFAULT_NUM_POOLS, pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
): ) -> None:
self.npipe_path = base_url.replace("npipe://", "") self.npipe_path = base_url.replace("npipe://", "")
self.timeout = timeout self.timeout = timeout
self.max_pool_size = max_pool_size self.max_pool_size = max_pool_size
@ -90,7 +91,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def get_connection(self, url, proxies=None): def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -103,7 +104,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies): def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -15,8 +15,10 @@ import functools
import io import io
import time import time
import traceback import traceback
import typing as t
PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name
try: try:
import pywintypes import pywintypes
import win32api import win32api
@ -28,6 +30,13 @@ except ImportError:
else: else:
PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer, Callable
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
ERROR_PIPE_BUSY = 0xE7 ERROR_PIPE_BUSY = 0xE7
SECURITY_SQOS_PRESENT = 0x100000 SECURITY_SQOS_PRESENT = 0x100000
@ -36,10 +45,12 @@ SECURITY_ANONYMOUS = 0
MAXIMUM_RETRY_COUNT = 10 MAXIMUM_RETRY_COUNT = 10
def check_closed(f): def check_closed(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f) @functools.wraps(f)
def wrapped(self, *args, **kwargs): def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if self._closed: if self._closed: # type: ignore
raise RuntimeError("Can not reuse socket after connection was closed.") raise RuntimeError("Can not reuse socket after connection was closed.")
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
@ -53,25 +64,25 @@ class NpipeSocket:
implemented. implemented.
""" """
def __init__(self, handle=None): def __init__(self, handle=None) -> None:
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
self._handle = handle self._handle = handle
self._address = None self._address: str | None = None
self._closed = False self._closed = False
self.flags = None self.flags: int | None = None
def accept(self): def accept(self) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def bind(self, address): def bind(self, address) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def close(self): def close(self) -> None:
self._handle.Close() self._handle.Close()
self._closed = True self._closed = True
@check_closed @check_closed
def connect(self, address, retry_count=0): def connect(self, address, retry_count: int = 0) -> None:
try: try:
handle = win32file.CreateFile( handle = win32file.CreateFile(
address, address,
@ -99,14 +110,14 @@ class NpipeSocket:
return self.connect(address, retry_count) return self.connect(address, retry_count)
raise e raise e
self.flags = win32pipe.GetNamedPipeInfo(handle)[0] self.flags = win32pipe.GetNamedPipeInfo(handle)[0] # type: ignore
self._handle = handle self._handle = handle
self._address = address self._address = address
@check_closed @check_closed
def connect_ex(self, address): def connect_ex(self, address) -> None:
return self.connect(address) self.connect(address)
@check_closed @check_closed
def detach(self): def detach(self):
@ -114,25 +125,25 @@ class NpipeSocket:
return self._handle return self._handle
@check_closed @check_closed
def dup(self): def dup(self) -> NpipeSocket:
return NpipeSocket(self._handle) return NpipeSocket(self._handle)
def getpeername(self): def getpeername(self) -> str | None:
return self._address return self._address
def getsockname(self): def getsockname(self) -> str | None:
return self._address return self._address
def getsockopt(self, level, optname, buflen=None): def getsockopt(self, level, optname, buflen=None) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def ioctl(self, control, option): def ioctl(self, control, option) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def listen(self, backlog): def listen(self, backlog) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def makefile(self, mode=None, bufsize=None): def makefile(self, mode: str, bufsize: int | None = None):
if mode.strip("b") != "r": if mode.strip("b") != "r":
raise NotImplementedError() raise NotImplementedError()
rawio = NpipeFileIOBase(self) rawio = NpipeFileIOBase(self)
@ -141,30 +152,30 @@ class NpipeSocket:
return io.BufferedReader(rawio, buffer_size=bufsize) return io.BufferedReader(rawio, buffer_size=bufsize)
@check_closed @check_closed
def recv(self, bufsize, flags=0): def recv(self, bufsize: int, flags: int = 0) -> str:
dummy_err, data = win32file.ReadFile(self._handle, bufsize) dummy_err, data = win32file.ReadFile(self._handle, bufsize)
return data return data
@check_closed @check_closed
def recvfrom(self, bufsize, flags=0): def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[str, str | None]:
data = self.recv(bufsize, flags) data = self.recv(bufsize, flags)
return (data, self._address) return (data, self._address)
@check_closed @check_closed
def recvfrom_into(self, buf, nbytes=0, flags=0): def recvfrom_into(
return self.recv_into(buf, nbytes, flags), self._address self, buf: Buffer, nbytes: int = 0, flags: int = 0
) -> tuple[int, str | None]:
return self.recv_into(buf, nbytes), self._address
@check_closed @check_closed
def recv_into(self, buf, nbytes=0): def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
readbuf = buf readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
if not isinstance(buf, memoryview):
readbuf = memoryview(buf)
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
try: try:
overlapped = pywintypes.OVERLAPPED() overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event overlapped.hEvent = event
dummy_err, dummy_data = win32file.ReadFile( dummy_err, dummy_data = win32file.ReadFile( # type: ignore
self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped
) )
wait_result = win32event.WaitForSingleObject(event, self._timeout) wait_result = win32event.WaitForSingleObject(event, self._timeout)
@ -176,12 +187,12 @@ class NpipeSocket:
win32api.CloseHandle(event) win32api.CloseHandle(event)
@check_closed @check_closed
def send(self, string, flags=0): def send(self, string: Buffer, flags: int = 0) -> int:
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
try: try:
overlapped = pywintypes.OVERLAPPED() overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event overlapped.hEvent = event
win32file.WriteFile(self._handle, string, overlapped) win32file.WriteFile(self._handle, string, overlapped) # type: ignore
wait_result = win32event.WaitForSingleObject(event, self._timeout) wait_result = win32event.WaitForSingleObject(event, self._timeout)
if wait_result == win32event.WAIT_TIMEOUT: if wait_result == win32event.WAIT_TIMEOUT:
win32file.CancelIo(self._handle) win32file.CancelIo(self._handle)
@ -191,20 +202,20 @@ class NpipeSocket:
win32api.CloseHandle(event) win32api.CloseHandle(event)
@check_closed @check_closed
def sendall(self, string, flags=0): def sendall(self, string: Buffer, flags: int = 0) -> int:
return self.send(string, flags) return self.send(string, flags)
@check_closed @check_closed
def sendto(self, string, address): def sendto(self, string: Buffer, address: str) -> int:
self.connect(address) self.connect(address)
return self.send(string) return self.send(string)
def setblocking(self, flag): def setblocking(self, flag: bool):
if flag: if flag:
return self.settimeout(None) return self.settimeout(None)
return self.settimeout(0) return self.settimeout(0)
def settimeout(self, value): def settimeout(self, value: int | float | None) -> None:
if value is None: if value is None:
# Blocking mode # Blocking mode
self._timeout = win32event.INFINITE self._timeout = win32event.INFINITE
@ -214,39 +225,39 @@ class NpipeSocket:
# Timeout mode - Value converted to milliseconds # Timeout mode - Value converted to milliseconds
self._timeout = int(value * 1000) self._timeout = int(value * 1000)
def gettimeout(self): def gettimeout(self) -> int | float | None:
return self._timeout return self._timeout
def setsockopt(self, level, optname, value): def setsockopt(self, level, optname, value) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
@check_closed @check_closed
def shutdown(self, how): def shutdown(self, how) -> None:
return self.close() return self.close()
class NpipeFileIOBase(io.RawIOBase): class NpipeFileIOBase(io.RawIOBase):
def __init__(self, npipe_socket): def __init__(self, npipe_socket) -> None:
self.sock = npipe_socket self.sock = npipe_socket
def close(self): def close(self) -> None:
super().close() super().close()
self.sock = None self.sock = None
def fileno(self): def fileno(self) -> int:
return self.sock.fileno() return self.sock.fileno()
def isatty(self): def isatty(self) -> bool:
return False return False
def readable(self): def readable(self) -> bool:
return True return True
def readinto(self, buf): def readinto(self, buf: Buffer) -> int:
return self.sock.recv_into(buf) return self.sock.recv_into(buf)
def seekable(self): def seekable(self) -> bool:
return False return False
def writable(self): def writable(self) -> bool:
return False return False

View File

@ -17,6 +17,7 @@ import signal
import socket import socket
import subprocess import subprocess
import traceback import traceback
import typing as t
from queue import Empty from queue import Empty
from urllib.parse import urlparse from urllib.parse import urlparse
@ -25,6 +26,7 @@ from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
from .basehttpadapter import BaseHTTPAdapter from .basehttpadapter import BaseHTTPAdapter
PARAMIKO_IMPORT_ERROR: str | None # pylint: disable=invalid-name
try: try:
import paramiko import paramiko
except ImportError: except ImportError:
@ -32,12 +34,15 @@ except ImportError:
else: else:
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class SSHSocket(socket.socket): class SSHSocket(socket.socket):
def __init__(self, host): def __init__(self, host: str) -> None:
super().__init__(socket.AF_INET, socket.SOCK_STREAM) super().__init__(socket.AF_INET, socket.SOCK_STREAM)
self.host = host self.host = host
self.port = None self.port = None
@ -47,9 +52,9 @@ class SSHSocket(socket.socket):
if "@" in self.host: if "@" in self.host:
self.user, self.host = self.host.split("@") self.user, self.host = self.host.split("@")
self.proc = None self.proc: subprocess.Popen | None = None
def connect(self, **kwargs): def connect(self, *args_: t.Any, **kwargs: t.Any) -> None:
args = ["ssh"] args = ["ssh"]
if self.user: if self.user:
args = args + ["-l", self.user] args = args + ["-l", self.user]
@ -81,37 +86,48 @@ class SSHSocket(socket.socket):
preexec_fn=preexec_func, preexec_fn=preexec_func,
) )
def _write(self, data): def _write(self, data: Buffer) -> int:
if not self.proc or self.proc.stdin.closed: if not self.proc:
raise RuntimeError( raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first." "SSH subprocess not initiated. connect() must be called first."
) )
assert self.proc.stdin is not None
if self.proc.stdin.closed:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first after close()."
)
written = self.proc.stdin.write(data) written = self.proc.stdin.write(data)
self.proc.stdin.flush() self.proc.stdin.flush()
return written return written
def sendall(self, data): def sendall(self, data: Buffer, *args, **kwargs) -> None:
self._write(data) self._write(data)
def send(self, data): def send(self, data: Buffer, *args, **kwargs) -> int:
return self._write(data) return self._write(data)
def recv(self, n): def recv(self, n: int, *args, **kwargs) -> bytes:
if not self.proc: if not self.proc:
raise RuntimeError( raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first." "SSH subprocess not initiated. connect() must be called first."
) )
assert self.proc.stdout is not None
return self.proc.stdout.read(n) return self.proc.stdout.read(n)
def makefile(self, mode): def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore
if not self.proc: if not self.proc:
self.connect() self.connect()
self.proc.stdout.channel = self assert self.proc is not None
assert self.proc.stdout is not None
self.proc.stdout.channel = self # type: ignore
return self.proc.stdout return self.proc.stdout
def close(self): def close(self) -> None:
if not self.proc or self.proc.stdin.closed: if not self.proc:
return
assert self.proc.stdin is not None
if self.proc.stdin.closed:
return return
self.proc.stdin.write(b"\n\n") self.proc.stdin.write(b"\n\n")
self.proc.stdin.flush() self.proc.stdin.flush()
@ -119,13 +135,19 @@ class SSHSocket(socket.socket):
class SSHConnection(urllib3_connection.HTTPConnection): class SSHConnection(urllib3_connection.HTTPConnection):
def __init__(self, ssh_transport=None, timeout=60, host=None): def __init__(
self,
*,
ssh_transport=None,
timeout: int | float = 60,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.ssh_transport = ssh_transport self.ssh_transport = ssh_transport
self.timeout = timeout self.timeout = timeout
self.ssh_host = host self.ssh_host = host
def connect(self): def connect(self) -> None:
if self.ssh_transport: if self.ssh_transport:
sock = self.ssh_transport.open_session() sock = self.ssh_transport.open_session()
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
@ -141,7 +163,14 @@ class SSHConnection(urllib3_connection.HTTPConnection):
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
scheme = "ssh" scheme = "ssh"
def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None): def __init__(
self,
*,
ssh_client: paramiko.SSHClient | None = None,
timeout: int | float = 60,
maxsize: int = 10,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.ssh_transport = None self.ssh_transport = None
self.timeout = timeout self.timeout = timeout
@ -149,13 +178,17 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
self.ssh_transport = ssh_client.get_transport() self.ssh_transport = ssh_client.get_transport()
self.ssh_host = host self.ssh_host = host
def _new_conn(self): def _new_conn(self) -> SSHConnection:
return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host) return SSHConnection(
ssh_transport=self.ssh_transport,
timeout=self.timeout,
host=self.ssh_host,
)
# When re-using connections, urllib3 calls fileno() on our # When re-using connections, urllib3 calls fileno() on our
# SSH channel instance, quickly overloading our fd limit. To avoid this, # SSH channel instance, quickly overloading our fd limit. To avoid this,
# we override _get_conn # we override _get_conn
def _get_conn(self, timeout): def _get_conn(self, timeout: int | float) -> SSHConnection:
conn = None conn = None
try: try:
conn = self.pool.get(block=self.block, timeout=timeout) conn = self.pool.get(block=self.block, timeout=timeout)
@ -175,7 +208,6 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class SSHHTTPAdapter(BaseHTTPAdapter): class SSHHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [ __attrs__ = HTTPAdapter.__attrs__ + [
"pools", "pools",
"timeout", "timeout",
@ -186,13 +218,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
base_url, base_url: str,
timeout=60, timeout: int | float = 60,
pool_connections=constants.DEFAULT_NUM_POOLS, pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
shell_out=False, shell_out: bool = False,
): ) -> None:
self.ssh_client = None self.ssh_client: paramiko.SSHClient | None = None
if not shell_out: if not shell_out:
self._create_paramiko_client(base_url) self._create_paramiko_client(base_url)
self._connect() self._connect()
@ -208,30 +240,31 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def _create_paramiko_client(self, base_url): def _create_paramiko_client(self, base_url: str) -> None:
logging.getLogger("paramiko").setLevel(logging.WARNING) logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh_client = paramiko.SSHClient() self.ssh_client = paramiko.SSHClient()
base_url = urlparse(base_url) base_url_p = urlparse(base_url)
self.ssh_params = { assert base_url_p.hostname is not None
"hostname": base_url.hostname, self.ssh_params: dict[str, t.Any] = {
"port": base_url.port, "hostname": base_url_p.hostname,
"username": base_url.username, "port": base_url_p.port,
"username": base_url_p.username,
} }
ssh_config_file = os.path.expanduser("~/.ssh/config") ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file): if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig() conf = paramiko.SSHConfig()
with open(ssh_config_file, "rt", encoding="utf-8") as f: with open(ssh_config_file, "rt", encoding="utf-8") as f:
conf.parse(f) conf.parse(f)
host_config = conf.lookup(base_url.hostname) host_config = conf.lookup(base_url_p.hostname)
if "proxycommand" in host_config: if "proxycommand" in host_config:
self.ssh_params["sock"] = paramiko.ProxyCommand( self.ssh_params["sock"] = paramiko.ProxyCommand(
host_config["proxycommand"] host_config["proxycommand"]
) )
if "hostname" in host_config: if "hostname" in host_config:
self.ssh_params["hostname"] = host_config["hostname"] self.ssh_params["hostname"] = host_config["hostname"]
if base_url.port is None and "port" in host_config: if base_url_p.port is None and "port" in host_config:
self.ssh_params["port"] = host_config["port"] self.ssh_params["port"] = host_config["port"]
if base_url.username is None and "user" in host_config: if base_url_p.username is None and "user" in host_config:
self.ssh_params["username"] = host_config["user"] self.ssh_params["username"] = host_config["user"]
if "identityfile" in host_config: if "identityfile" in host_config:
self.ssh_params["key_filename"] = host_config["identityfile"] self.ssh_params["key_filename"] = host_config["identityfile"]
@ -239,11 +272,11 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
self.ssh_client.load_system_host_keys() self.ssh_client.load_system_host_keys()
self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy()) self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())
def _connect(self): def _connect(self) -> None:
if self.ssh_client: if self.ssh_client:
self.ssh_client.connect(**self.ssh_params) self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url, proxies=None): def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool:
if not self.ssh_client: if not self.ssh_client:
return SSHConnectionPool( return SSHConnectionPool(
ssh_client=self.ssh_client, ssh_client=self.ssh_client,
@ -270,7 +303,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def close(self): def close(self) -> None:
super().close() super().close()
if self.ssh_client: if self.ssh_client:
self.ssh_client.close() self.ssh_client.close()

View File

@ -11,9 +11,7 @@
from __future__ import annotations from __future__ import annotations
from ansible_collections.community.docker.plugins.module_utils._version import ( import typing as t
LooseVersion,
)
from .._import_helper import HTTPAdapter, urllib3 from .._import_helper import HTTPAdapter, urllib3
from .basehttpadapter import BaseHTTPAdapter from .basehttpadapter import BaseHTTPAdapter
@ -30,14 +28,19 @@ PoolManager = urllib3.poolmanager.PoolManager
class SSLHTTPAdapter(BaseHTTPAdapter): class SSLHTTPAdapter(BaseHTTPAdapter):
"""An HTTPS Transport Adapter that uses an arbitrary SSL version.""" """An HTTPS Transport Adapter that uses an arbitrary SSL version."""
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname", "ssl_version"] __attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname"]
def __init__(self, ssl_version=None, assert_hostname=None, **kwargs): def __init__(
self.ssl_version = ssl_version self,
assert_hostname: bool | None = None,
**kwargs,
) -> None:
self.assert_hostname = assert_hostname self.assert_hostname = assert_hostname
super().__init__(**kwargs) super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False): def init_poolmanager(
self, connections: int, maxsize: int, block: bool = False, **kwargs: t.Any
) -> None:
kwargs = { kwargs = {
"num_pools": connections, "num_pools": connections,
"maxsize": maxsize, "maxsize": maxsize,
@ -45,12 +48,10 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
} }
if self.assert_hostname is not None: if self.assert_hostname is not None:
kwargs["assert_hostname"] = self.assert_hostname kwargs["assert_hostname"] = self.assert_hostname
if self.ssl_version and self.can_override_ssl_version():
kwargs["ssl_version"] = self.ssl_version
self.poolmanager = PoolManager(**kwargs) self.poolmanager = PoolManager(**kwargs)
def get_connection(self, *args, **kwargs): def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool:
""" """
Ensure assert_hostname is set correctly on our pool Ensure assert_hostname is set correctly on our pool
@ -61,15 +62,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
conn = super().get_connection(*args, **kwargs) conn = super().get_connection(*args, **kwargs)
if ( if (
self.assert_hostname is not None self.assert_hostname is not None
and conn.assert_hostname != self.assert_hostname and conn.assert_hostname != self.assert_hostname # type: ignore
): ):
conn.assert_hostname = self.assert_hostname conn.assert_hostname = self.assert_hostname # type: ignore
return conn return conn
def can_override_ssl_version(self):
urllib_ver = urllib3.__version__.split("-")[0]
if urllib_ver is None:
return False
if urllib_ver == "dev":
return True
return LooseVersion(urllib_ver) > LooseVersion("1.5")

View File

@ -12,6 +12,7 @@
from __future__ import annotations from __future__ import annotations
import socket import socket
import typing as t
from .. import constants from .. import constants
from .._import_helper import HTTPAdapter, urllib3, urllib3_connection from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
@ -22,26 +23,27 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class UnixHTTPConnection(urllib3_connection.HTTPConnection): class UnixHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(
def __init__(self, base_url, unix_socket, timeout=60): self, base_url: str | bytes, unix_socket, timeout: int | float = 60
) -> None:
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.base_url = base_url self.base_url = base_url
self.unix_socket = unix_socket self.unix_socket = unix_socket
self.timeout = timeout self.timeout = timeout
self.disable_buffering = False self.disable_buffering = False
def connect(self): def connect(self) -> None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.connect(self.unix_socket) sock.connect(self.unix_socket)
self.sock = sock self.sock = sock
def putheader(self, header, *values): def putheader(self, header: str, *values: str) -> None:
super().putheader(header, *values) super().putheader(header, *values)
if header == "Connection" and "Upgrade" in values: if header == "Connection" and "Upgrade" in values:
self.disable_buffering = True self.disable_buffering = True
def response_class(self, sock, *args, **kwargs): def response_class(self, sock, *args, **kwargs) -> t.Any:
# FIXME: We may need to disable buffering on Py3, # FIXME: We may need to disable buffering on Py3,
# but there's no clear way to do it at the moment. See: # but there's no clear way to do it at the moment. See:
# https://github.com/docker/docker-py/issues/1799 # https://github.com/docker/docker-py/issues/1799
@ -49,18 +51,23 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, base_url, socket_path, timeout=60, maxsize=10): def __init__(
self,
base_url: str | bytes,
socket_path: str,
timeout: int | float = 60,
maxsize: int = 10,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.base_url = base_url self.base_url = base_url
self.socket_path = socket_path self.socket_path = socket_path
self.timeout = timeout self.timeout = timeout
def _new_conn(self): def _new_conn(self) -> UnixHTTPConnection:
return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout) return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout)
class UnixHTTPAdapter(BaseHTTPAdapter): class UnixHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [ __attrs__ = HTTPAdapter.__attrs__ + [
"pools", "pools",
"socket_path", "socket_path",
@ -70,11 +77,11 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
socket_url, socket_url: str,
timeout=60, timeout: int | float = 60,
pool_connections=constants.DEFAULT_NUM_POOLS, pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
): ) -> None:
socket_path = socket_url.replace("http+unix://", "") socket_path = socket_url.replace("http+unix://", "")
if not socket_path.startswith("/"): if not socket_path.startswith("/"):
socket_path = "/" + socket_path socket_path = "/" + socket_path
@ -86,7 +93,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def get_connection(self, url, proxies=None): def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -99,7 +106,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies): def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -12,6 +12,7 @@
from __future__ import annotations from __future__ import annotations
import socket import socket
import typing as t
from .._import_helper import urllib3 from .._import_helper import urllib3
from ..errors import DockerException from ..errors import DockerException
@ -29,11 +30,11 @@ class CancellableStream:
>>> events.close() >>> events.close()
""" """
def __init__(self, stream, response): def __init__(self, stream, response) -> None:
self._stream = stream self._stream = stream
self._response = response self._response = response
def __iter__(self): def __iter__(self) -> t.Self:
return self return self
def __next__(self): def __next__(self):
@ -46,7 +47,7 @@ class CancellableStream:
next = __next__ next = __next__
def close(self): def close(self) -> None:
""" """
Closes the event streaming. Closes the event streaming.
""" """

View File

@ -17,26 +17,38 @@ import random
import re import re
import tarfile import tarfile
import tempfile import tempfile
import typing as t
from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX
from . import fnmatch from . import fnmatch
if t.TYPE_CHECKING:
from collections.abc import Sequence
_SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/") _SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/")
def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False): def tar(
path: str,
exclude: list[str] | None = None,
dockerfile: tuple[str, str | None] | tuple[None, None] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
) -> t.IO[bytes]:
root = os.path.abspath(path) root = os.path.abspath(path)
exclude = exclude or [] exclude = exclude or []
dockerfile = dockerfile or (None, None) dockerfile = dockerfile or (None, None)
extra_files = [] extra_files: list[tuple[str, str]] = []
if dockerfile[1] is not None: if dockerfile[1] is not None:
assert dockerfile[0] is not None
dockerignore_contents = "\n".join( dockerignore_contents = "\n".join(
(exclude or [".dockerignore"]) + [dockerfile[0]] (exclude or [".dockerignore"]) + [dockerfile[0]]
) )
extra_files = [ extra_files = [
(".dockerignore", dockerignore_contents), (".dockerignore", dockerignore_contents),
dockerfile, dockerfile, # type: ignore
] ]
return create_archive( return create_archive(
files=sorted(exclude_paths(root, exclude, dockerfile=dockerfile[0])), files=sorted(exclude_paths(root, exclude, dockerfile=dockerfile[0])),
@ -47,7 +59,9 @@ def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
) )
def exclude_paths(root, patterns, dockerfile=None): def exclude_paths(
root: str, patterns: list[str], dockerfile: str | None = None
) -> set[str]:
""" """
Given a root directory path and a list of .dockerignore patterns, return Given a root directory path and a list of .dockerignore patterns, return
an iterator of all paths (both regular files and directories) in the root an iterator of all paths (both regular files and directories) in the root
@ -64,7 +78,7 @@ def exclude_paths(root, patterns, dockerfile=None):
return set(pm.walk(root)) return set(pm.walk(root))
def build_file_list(root): def build_file_list(root: str) -> list[str]:
files = [] files = []
for dirname, dirnames, fnames in os.walk(root): for dirname, dirnames, fnames in os.walk(root):
for filename in fnames + dirnames: for filename in fnames + dirnames:
@ -74,7 +88,13 @@ def build_file_list(root):
return files return files
def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None): def create_archive(
root: str,
files: Sequence[str] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
extra_files: Sequence[tuple[str, str]] | None = None,
) -> t.IO[bytes]:
extra_files = extra_files or [] extra_files = extra_files or []
if not fileobj: if not fileobj:
fileobj = tempfile.NamedTemporaryFile() fileobj = tempfile.NamedTemporaryFile()
@ -92,7 +112,7 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
if i is None: if i is None:
# This happens when we encounter a socket file. We can safely # This happens when we encounter a socket file. We can safely
# ignore it and proceed. # ignore it and proceed.
continue continue # type: ignore
# Workaround https://bugs.python.org/issue32713 # Workaround https://bugs.python.org/issue32713
if i.mtime < 0 or i.mtime > 8**11 - 1: if i.mtime < 0 or i.mtime > 8**11 - 1:
@ -124,11 +144,11 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
return fileobj return fileobj
def mkbuildcontext(dockerfile): def mkbuildcontext(dockerfile: io.BytesIO | t.IO[bytes]) -> t.IO[bytes]:
f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with
try: try:
with tarfile.open(mode="w", fileobj=f) as t: with tarfile.open(mode="w", fileobj=f) as t:
if isinstance(dockerfile, io.StringIO): if isinstance(dockerfile, io.StringIO): # type: ignore
raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles") raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles")
if isinstance(dockerfile, io.BytesIO): if isinstance(dockerfile, io.BytesIO):
dfinfo = tarfile.TarInfo("Dockerfile") dfinfo = tarfile.TarInfo("Dockerfile")
@ -144,17 +164,17 @@ def mkbuildcontext(dockerfile):
return f return f
def split_path(p): def split_path(p: str) -> list[str]:
return [pt for pt in re.split(_SEP, p) if pt and pt != "."] return [pt for pt in re.split(_SEP, p) if pt and pt != "."]
def normalize_slashes(p): def normalize_slashes(p: str) -> str:
if IS_WINDOWS_PLATFORM: if IS_WINDOWS_PLATFORM:
return "/".join(split_path(p)) return "/".join(split_path(p))
return p return p
def walk(root, patterns, default=True): def walk(root: str, patterns: Sequence[str], default: bool = True) -> t.Generator[str]:
pm = PatternMatcher(patterns) pm = PatternMatcher(patterns)
return pm.walk(root) return pm.walk(root)
@ -162,11 +182,11 @@ def walk(root, patterns, default=True):
# Heavily based on # Heavily based on
# https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go # https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go
class PatternMatcher: class PatternMatcher:
def __init__(self, patterns): def __init__(self, patterns: Sequence[str]) -> None:
self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns])) self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns]))
self.patterns.append(Pattern("!.dockerignore")) self.patterns.append(Pattern("!.dockerignore"))
def matches(self, filepath): def matches(self, filepath: str) -> bool:
matched = False matched = False
parent_path = os.path.dirname(filepath) parent_path = os.path.dirname(filepath)
parent_path_dirs = split_path(parent_path) parent_path_dirs = split_path(parent_path)
@ -185,8 +205,8 @@ class PatternMatcher:
return matched return matched
def walk(self, root): def walk(self, root: str) -> t.Generator[str]:
def rec_walk(current_dir): def rec_walk(current_dir: str) -> t.Generator[str]:
for f in os.listdir(current_dir): for f in os.listdir(current_dir):
fpath = os.path.join(os.path.relpath(current_dir, root), f) fpath = os.path.join(os.path.relpath(current_dir, root), f)
if fpath.startswith("." + os.path.sep): if fpath.startswith("." + os.path.sep):
@ -220,7 +240,7 @@ class PatternMatcher:
class Pattern: class Pattern:
def __init__(self, pattern_str): def __init__(self, pattern_str: str) -> None:
self.exclusion = False self.exclusion = False
if pattern_str.startswith("!"): if pattern_str.startswith("!"):
self.exclusion = True self.exclusion = True
@ -230,8 +250,7 @@ class Pattern:
self.cleaned_pattern = "/".join(self.dirs) self.cleaned_pattern = "/".join(self.dirs)
@classmethod @classmethod
def normalize(cls, p): def normalize(cls, p: str) -> list[str]:
# Remove trailing spaces # Remove trailing spaces
p = p.strip() p = p.strip()
@ -256,11 +275,13 @@ class Pattern:
i += 1 i += 1
return split return split
def match(self, filepath): def match(self, filepath: str) -> bool:
return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern) return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern)
def process_dockerfile(dockerfile, path): def process_dockerfile(
dockerfile: str | None, path: str
) -> tuple[str, str | None] | tuple[None, None]:
if not dockerfile: if not dockerfile:
return (None, None) return (None, None)
@ -268,7 +289,7 @@ def process_dockerfile(dockerfile, path):
if not os.path.isabs(dockerfile): if not os.path.isabs(dockerfile):
abs_dockerfile = os.path.join(path, dockerfile) abs_dockerfile = os.path.join(path, dockerfile)
if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX): if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX):
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX):])}" abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX) :])}"
if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[ if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[
0 0
] or os.path.relpath(abs_dockerfile, path).startswith(".."): ] or os.path.relpath(abs_dockerfile, path).startswith(".."):

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import json import json
import logging import logging
import os import os
import typing as t
from ..constants import IS_WINDOWS_PLATFORM from ..constants import IS_WINDOWS_PLATFORM
@ -24,11 +25,11 @@ LEGACY_DOCKER_CONFIG_FILENAME = ".dockercfg"
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def get_default_config_file(): def get_default_config_file() -> str:
return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME) return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME)
def find_config_file(config_path=None): def find_config_file(config_path: str | None = None) -> str | None:
homedir = home_dir() homedir = home_dir()
paths = list( paths = list(
filter( filter(
@ -54,14 +55,14 @@ def find_config_file(config_path=None):
return None return None
def config_path_from_environment(): def config_path_from_environment() -> str | None:
config_dir = os.environ.get("DOCKER_CONFIG") config_dir = os.environ.get("DOCKER_CONFIG")
if not config_dir: if not config_dir:
return None return None
return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME)) return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME))
def home_dir(): def home_dir() -> str:
""" """
Get the user's home directory, using the same logic as the Docker Engine Get the user's home directory, using the same logic as the Docker Engine
client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX. client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX.
@ -71,7 +72,7 @@ def home_dir():
return os.path.expanduser("~") return os.path.expanduser("~")
def load_general_config(config_path=None): def load_general_config(config_path: str | None = None) -> dict[str, t.Any]:
config_file = find_config_file(config_path) config_file = find_config_file(config_path)
if not config_file: if not config_file:

View File

@ -12,16 +12,37 @@
from __future__ import annotations from __future__ import annotations
import functools import functools
import typing as t
from .. import errors from .. import errors
from . import utils from . import utils
def minimum_version(version): if t.TYPE_CHECKING:
def decorator(f): from collections.abc import Callable
from ..api.client import APIClient
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
def minimum_version(
version: str,
) -> Callable[
[Callable[t.Concatenate[_Self, _P], _R]],
Callable[t.Concatenate[_Self, _P], _R],
]:
def decorator(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f) @functools.wraps(f)
def wrapper(self, *args, **kwargs): def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if utils.version_lt(self._version, version): # We use _Self instead of APIClient since this is used for mixins for APIClient.
# This unfortunately means that self._version does not exist in the mixin,
# it only exists after mixing in. This is why we ignore types here.
if utils.version_lt(self._version, version): # type: ignore
raise errors.InvalidVersion( raise errors.InvalidVersion(
f"{f.__name__} is not available for version < {version}" f"{f.__name__} is not available for version < {version}"
) )
@ -32,13 +53,16 @@ def minimum_version(version):
return decorator return decorator
def update_headers(f): def update_headers(
def inner(self, *args, **kwargs): f: Callable[t.Concatenate[APIClient, _P], _R],
) -> Callable[t.Concatenate[APIClient, _P], _R]:
def inner(self: APIClient, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if "HttpHeaders" in self._general_configs: if "HttpHeaders" in self._general_configs:
if not kwargs.get("headers"): if not kwargs.get("headers"):
kwargs["headers"] = self._general_configs["HttpHeaders"] kwargs["headers"] = self._general_configs["HttpHeaders"]
else: else:
kwargs["headers"].update(self._general_configs["HttpHeaders"]) # We cannot (yet) model that kwargs["headers"] should be a dictionary
kwargs["headers"].update(self._general_configs["HttpHeaders"]) # type: ignore
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
return inner return inner

View File

@ -28,16 +28,16 @@ import re
__all__ = ["fnmatch", "fnmatchcase", "translate"] __all__ = ["fnmatch", "fnmatchcase", "translate"]
_cache = {} _cache: dict[str, re.Pattern] = {}
_MAXCACHE = 100 _MAXCACHE = 100
def _purge(): def _purge() -> None:
"""Clear the pattern cache""" """Clear the pattern cache"""
_cache.clear() _cache.clear()
def fnmatch(name, pat): def fnmatch(name: str, pat: str):
"""Test whether FILENAME matches PATTERN. """Test whether FILENAME matches PATTERN.
Patterns are Unix shell style: Patterns are Unix shell style:
@ -58,7 +58,7 @@ def fnmatch(name, pat):
return fnmatchcase(name, pat) return fnmatchcase(name, pat)
def fnmatchcase(name, pat): def fnmatchcase(name: str, pat: str) -> bool:
"""Test whether FILENAME matches PATTERN, including case. """Test whether FILENAME matches PATTERN, including case.
This is a version of fnmatch() which does not case-normalize This is a version of fnmatch() which does not case-normalize
its arguments. its arguments.
@ -74,7 +74,7 @@ def fnmatchcase(name, pat):
return re_pat.match(name) is not None return re_pat.match(name) is not None
def translate(pat): def translate(pat: str) -> str:
"""Translate a shell PATTERN to a regular expression. """Translate a shell PATTERN to a regular expression.
There is no way to quote meta-characters. There is no way to quote meta-characters.

View File

@ -13,14 +13,22 @@ from __future__ import annotations
import json import json
import json.decoder import json.decoder
import typing as t
from ..errors import StreamParseError from ..errors import StreamParseError
if t.TYPE_CHECKING:
import re
from collections.abc import Callable
_T = t.TypeVar("_T")
json_decoder = json.JSONDecoder() json_decoder = json.JSONDecoder()
def stream_as_text(stream): def stream_as_text(stream: t.Generator[bytes | str]) -> t.Generator[str]:
""" """
Given a stream of bytes or text, if any of the items in the stream Given a stream of bytes or text, if any of the items in the stream
are bytes convert them to text. are bytes convert them to text.
@ -33,20 +41,22 @@ def stream_as_text(stream):
yield data yield data
def json_splitter(buffer): def json_splitter(buffer: str) -> tuple[t.Any, str] | None:
"""Attempt to parse a json object from a buffer. If there is at least one """Attempt to parse a json object from a buffer. If there is at least one
object, return it and the rest of the buffer, otherwise return None. object, return it and the rest of the buffer, otherwise return None.
""" """
buffer = buffer.strip() buffer = buffer.strip()
try: try:
obj, index = json_decoder.raw_decode(buffer) obj, index = json_decoder.raw_decode(buffer)
rest = buffer[json.decoder.WHITESPACE.match(buffer, index).end() :] ws: re.Pattern = json.decoder.WHITESPACE # type: ignore[attr-defined]
m = ws.match(buffer, index)
rest = buffer[m.end() :] if m else buffer[index:]
return obj, rest return obj, rest
except ValueError: except ValueError:
return None return None
def json_stream(stream): def json_stream(stream: t.Generator[str | bytes]) -> t.Generator[t.Any]:
"""Given a stream of text, return a stream of json objects. """Given a stream of text, return a stream of json objects.
This handles streams which are inconsistently buffered (some entries may This handles streams which are inconsistently buffered (some entries may
be newline delimited, and others are not). be newline delimited, and others are not).
@ -54,21 +64,24 @@ def json_stream(stream):
return split_buffer(stream, json_splitter, json_decoder.decode) return split_buffer(stream, json_splitter, json_decoder.decode)
def line_splitter(buffer, separator="\n"): def line_splitter(buffer: str, separator: str = "\n") -> tuple[str, str] | None:
index = buffer.find(str(separator)) index = buffer.find(str(separator))
if index == -1: if index == -1:
return None return None
return buffer[: index + 1], buffer[index + 1 :] return buffer[: index + 1], buffer[index + 1 :]
def split_buffer(stream, splitter=None, decoder=lambda a: a): def split_buffer(
stream: t.Generator[str | bytes],
splitter: Callable[[str], tuple[_T, str] | None],
decoder: Callable[[str], _T],
) -> t.Generator[_T | str]:
"""Given a generator which yields strings and a splitter function, """Given a generator which yields strings and a splitter function,
joins all input, splits on the separator and yields each chunk. joins all input, splits on the separator and yields each chunk.
Unlike string.split(), each chunk includes the trailing Unlike string.split(), each chunk includes the trailing
separator, except for the last one if none was found on the end separator, except for the last one if none was found on the end
of the input. of the input.
""" """
splitter = splitter or line_splitter
buffered = "" buffered = ""
for data in stream_as_text(stream): for data in stream_as_text(stream):

View File

@ -12,6 +12,11 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
if t.TYPE_CHECKING:
from collections.abc import Collection, Sequence
PORT_SPEC = re.compile( PORT_SPEC = re.compile(
@ -26,32 +31,42 @@ PORT_SPEC = re.compile(
) )
def add_port_mapping(port_bindings, internal_port, external): def add_port_mapping(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port: str,
external: str | tuple[str, str | None] | None,
) -> None:
if internal_port in port_bindings: if internal_port in port_bindings:
port_bindings[internal_port].append(external) port_bindings[internal_port].append(external)
else: else:
port_bindings[internal_port] = [external] port_bindings[internal_port] = [external]
def add_port(port_bindings, internal_port_range, external_range): def add_port(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port_range: list[str],
external_range: list[str] | list[tuple[str, str | None]] | None,
) -> None:
if external_range is None: if external_range is None:
for internal_port in internal_port_range: for internal_port in internal_port_range:
add_port_mapping(port_bindings, internal_port, None) add_port_mapping(port_bindings, internal_port, None)
else: else:
ports = zip(internal_port_range, external_range) for internal_port, external_port in zip(internal_port_range, external_range):
for internal_port, external_port in ports: # mypy loses the exact type of eternal_port elements for some reason...
add_port_mapping(port_bindings, internal_port, external_port) add_port_mapping(port_bindings, internal_port, external_port) # type: ignore
def build_port_bindings(ports): def build_port_bindings(
port_bindings = {} ports: Collection[str],
) -> dict[str, list[str | tuple[str, str | None] | None]]:
port_bindings: dict[str, list[str | tuple[str, str | None] | None]] = {}
for port in ports: for port in ports:
internal_port_range, external_range = split_port(port) internal_port_range, external_range = split_port(port)
add_port(port_bindings, internal_port_range, external_range) add_port(port_bindings, internal_port_range, external_range)
return port_bindings return port_bindings
def _raise_invalid_port(port): def _raise_invalid_port(port: str) -> t.NoReturn:
raise ValueError( raise ValueError(
f'Invalid port "{port}", should be ' f'Invalid port "{port}", should be '
"[[remote_ip:]remote_port[-remote_port]:]" "[[remote_ip:]remote_port[-remote_port]:]"
@ -59,39 +74,64 @@ def _raise_invalid_port(port):
) )
def port_range(start, end, proto, randomly_available_port=False): @t.overload
if not start: def port_range(
start: str,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str]: ...
@t.overload
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None: ...
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None:
if start is None:
return start return start
if not end: if end is None:
return [f"{start}{proto}"] return [f"{start}{proto}"]
if randomly_available_port: if randomly_available_port:
return [f"{start}-{end}{proto}"] return [f"{start}-{end}{proto}"]
return [f"{port}{proto}" for port in range(int(start), int(end) + 1)] return [f"{port}{proto}" for port in range(int(start), int(end) + 1)]
def split_port(port): def split_port(
if hasattr(port, "legacy_repr"): port: str,
# This is the worst hack, but it prevents a bug in Compose 1.14.0 ) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]:
# https://github.com/docker/docker-py/issues/1668
# TODO: remove once fixed in Compose stable
port = port.legacy_repr()
port = str(port) port = str(port)
match = PORT_SPEC.match(port) match = PORT_SPEC.match(port)
if match is None: if match is None:
_raise_invalid_port(port) _raise_invalid_port(port)
parts = match.groupdict() parts = match.groupdict()
host = parts["host"] host: str | None = parts["host"]
proto = parts["proto"] or "" proto: str = parts["proto"] or ""
internal = port_range(parts["int"], parts["int_end"], proto) int_p: str = parts["int"]
external = port_range(parts["ext"], parts["ext_end"], "", len(internal) == 1) ext_p: str = parts["ext"]
internal: list[str] = port_range(int_p, parts["int_end"], proto) # type: ignore
external = port_range(ext_p or None, parts["ext_end"], "", len(internal) == 1)
if host is None: if host is None:
if external is not None and len(internal) != len(external): if (external is not None and len(internal) != len(external)) or ext_p == "":
raise ValueError("Port ranges don't match in length") raise ValueError("Port ranges don't match in length")
return internal, external return internal, external
external_or_none: Sequence[str | None]
if not external: if not external:
external = [None] * len(internal) external_or_none = [None] * len(internal)
elif len(internal) != len(external): else:
raise ValueError("Port ranges don't match in length") external_or_none = external
return internal, [(host, ext_port) for ext_port in external] if len(internal) != len(external_or_none):
raise ValueError("Port ranges don't match in length")
return internal, [(host, ext_port) for ext_port in external_or_none]

View File

@ -20,23 +20,23 @@ class ProxyConfig(dict):
""" """
@property @property
def http(self): def http(self) -> str | None:
return self.get("http") return self.get("http")
@property @property
def https(self): def https(self) -> str | None:
return self.get("https") return self.get("https")
@property @property
def ftp(self): def ftp(self) -> str | None:
return self.get("ftp") return self.get("ftp")
@property @property
def no_proxy(self): def no_proxy(self) -> str | None:
return self.get("no_proxy") return self.get("no_proxy")
@staticmethod @staticmethod
def from_dict(config): def from_dict(config: dict[str, str]) -> ProxyConfig:
""" """
Instantiate a new ProxyConfig from a dictionary that represents a Instantiate a new ProxyConfig from a dictionary that represents a
client configuration, as described in `the documentation`_. client configuration, as described in `the documentation`_.
@ -51,7 +51,7 @@ class ProxyConfig(dict):
no_proxy=config.get("noProxy"), no_proxy=config.get("noProxy"),
) )
def get_environment(self): def get_environment(self) -> dict[str, str]:
""" """
Return a dictionary representing the environment variables used to Return a dictionary representing the environment variables used to
set the proxy settings. set the proxy settings.
@ -67,7 +67,7 @@ class ProxyConfig(dict):
env["no_proxy"] = env["NO_PROXY"] = self.no_proxy env["no_proxy"] = env["NO_PROXY"] = self.no_proxy
return env return env
def inject_proxy_environment(self, environment): def inject_proxy_environment(self, environment: list[str]) -> list[str]:
""" """
Given a list of strings representing environment variables, prepend the Given a list of strings representing environment variables, prepend the
environment variables corresponding to the proxy settings. environment variables corresponding to the proxy settings.
@ -82,5 +82,5 @@ class ProxyConfig(dict):
# variables defined in "environment" to take precedence. # variables defined in "environment" to take precedence.
return proxy_env + environment return proxy_env + environment
def __str__(self): def __str__(self) -> str:
return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})" return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})"

View File

@ -16,10 +16,15 @@ import os
import select import select
import socket as pysocket import socket as pysocket
import struct import struct
import typing as t
from ..transport.npipesocket import NpipeSocket from ..transport.npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Iterable
STDOUT = 1 STDOUT = 1
STDERR = 2 STDERR = 2
@ -33,7 +38,7 @@ class SocketError(Exception):
NPIPE_ENDED = 109 NPIPE_ENDED = 109
def read(socket, n=4096): def read(socket, n: int = 4096) -> bytes | None:
""" """
Reads at most n bytes from socket Reads at most n bytes from socket
""" """
@ -58,6 +63,7 @@ def read(socket, n=4096):
except EnvironmentError as e: except EnvironmentError as e:
if e.errno not in recoverable_errors: if e.errno not in recoverable_errors:
raise raise
return None # TODO ???
except Exception as e: except Exception as e:
is_pipe_ended = ( is_pipe_ended = (
isinstance(socket, NpipeSocket) isinstance(socket, NpipeSocket)
@ -67,11 +73,11 @@ def read(socket, n=4096):
if is_pipe_ended: if is_pipe_ended:
# npipes do not support duplex sockets, so we interpret # npipes do not support duplex sockets, so we interpret
# a PIPE_ENDED error as a close operation (0-length read). # a PIPE_ENDED error as a close operation (0-length read).
return "" return b""
raise raise
def read_exactly(socket, n): def read_exactly(socket, n: int) -> bytes:
""" """
Reads exactly n bytes from socket Reads exactly n bytes from socket
Raises SocketError if there is not enough data Raises SocketError if there is not enough data
@ -85,7 +91,7 @@ def read_exactly(socket, n):
return data return data
def next_frame_header(socket): def next_frame_header(socket) -> tuple[int, int]:
""" """
Returns the stream and size of the next frame of data waiting to be read Returns the stream and size of the next frame of data waiting to be read
from socket, according to the protocol defined here: from socket, according to the protocol defined here:
@ -101,7 +107,7 @@ def next_frame_header(socket):
return (stream, actual) return (stream, actual)
def frames_iter(socket, tty): def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
""" """
Return a generator of frames read from socket. A frame is a tuple where Return a generator of frames read from socket. A frame is a tuple where
the first item is the stream number and the second item is a chunk of data. the first item is the stream number and the second item is a chunk of data.
@ -114,7 +120,7 @@ def frames_iter(socket, tty):
return frames_iter_no_tty(socket) return frames_iter_no_tty(socket)
def frames_iter_no_tty(socket): def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
""" """
Returns a generator of data read from the socket when the tty setting is Returns a generator of data read from the socket when the tty setting is
not enabled. not enabled.
@ -135,20 +141,34 @@ def frames_iter_no_tty(socket):
yield (stream, result) yield (stream, result)
def frames_iter_tty(socket): def frames_iter_tty(socket) -> t.Generator[bytes]:
""" """
Return a generator of data read from the socket when the tty setting is Return a generator of data read from the socket when the tty setting is
enabled. enabled.
""" """
while True: while True:
result = read(socket) result = read(socket)
if len(result) == 0: if not result:
# We have reached EOF # We have reached EOF
return return
yield result yield result
def consume_socket_output(frames, demux=False): @t.overload
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ...
@t.overload
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
@t.overload
def consume_socket_output(
frames, demux: bool = False
) -> bytes | tuple[bytes, bytes]: ...
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]:
""" """
Iterate through frames read from the socket and return the result. Iterate through frames read from the socket and return the result.
@ -167,7 +187,7 @@ def consume_socket_output(frames, demux=False):
# If the streams are demultiplexed, the generator yields tuples # If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr) # (stdout, stderr)
out = [None, None] out: list[bytes | None] = [None, None]
for frame in frames: for frame in frames:
# It is guaranteed that for each frame, one and only one stream # It is guaranteed that for each frame, one and only one stream
# is not None. # is not None.
@ -183,10 +203,10 @@ def consume_socket_output(frames, demux=False):
out[1] = frame[1] out[1] = frame[1]
else: else:
out[1] += frame[1] out[1] += frame[1]
return tuple(out) return tuple(out) # type: ignore
def demux_adaptor(stream_id, data): def demux_adaptor(stream_id: int, data: bytes) -> tuple[bytes | None, bytes | None]:
""" """
Utility to demultiplex stdout and stderr when reading frames from the Utility to demultiplex stdout and stderr when reading frames from the
socket. socket.

View File

@ -18,6 +18,7 @@ import os
import os.path import os.path
import shlex import shlex
import string import string
import typing as t
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from ansible_collections.community.docker.plugins.module_utils._version import ( from ansible_collections.community.docker.plugins.module_utils._version import (
@ -34,32 +35,23 @@ from ..constants import (
from ..tls import TLSConfig from ..tls import TLSConfig
if t.TYPE_CHECKING:
import ssl
from collections.abc import Mapping, Sequence
URLComponents = collections.namedtuple( URLComponents = collections.namedtuple(
"URLComponents", "URLComponents",
"scheme netloc url params query fragment", "scheme netloc url params query fragment",
) )
def create_ipam_pool(*args, **kwargs): def decode_json_header(header: str) -> dict[str, t.Any]:
raise errors.DeprecatedMethod(
"utils.create_ipam_pool has been removed. Please use a "
"docker.types.IPAMPool object instead."
)
def create_ipam_config(*args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_ipam_config has been removed. Please use a "
"docker.types.IPAMConfig object instead."
)
def decode_json_header(header):
data = base64.b64decode(header).decode("utf-8") data = base64.b64decode(header).decode("utf-8")
return json.loads(data) return json.loads(data)
def compare_version(v1, v2): def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]:
"""Compare docker versions """Compare docker versions
>>> v1 = '1.9' >>> v1 = '1.9'
@ -80,43 +72,64 @@ def compare_version(v1, v2):
return 1 return 1
def version_lt(v1, v2): def version_lt(v1: str, v2: str) -> bool:
return compare_version(v1, v2) > 0 return compare_version(v1, v2) > 0
def version_gte(v1, v2): def version_gte(v1: str, v2: str) -> bool:
return not version_lt(v1, v2) return not version_lt(v1, v2)
def _convert_port_binding(binding): def _convert_port_binding(
binding: (
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
),
) -> dict[str, str]:
result = {"HostIp": "", "HostPort": ""} result = {"HostIp": "", "HostPort": ""}
host_port: str | int | None = ""
if isinstance(binding, tuple): if isinstance(binding, tuple):
if len(binding) == 2: if len(binding) == 2:
result["HostPort"] = binding[1] host_port = binding[1] # type: ignore
result["HostIp"] = binding[0] result["HostIp"] = binding[0]
elif isinstance(binding[0], str): elif isinstance(binding[0], str):
result["HostIp"] = binding[0] result["HostIp"] = binding[0]
else: else:
result["HostPort"] = binding[0] host_port = binding[0]
elif isinstance(binding, dict): elif isinstance(binding, dict):
if "HostPort" in binding: if "HostPort" in binding:
result["HostPort"] = binding["HostPort"] host_port = binding["HostPort"]
if "HostIp" in binding: if "HostIp" in binding:
result["HostIp"] = binding["HostIp"] result["HostIp"] = binding["HostIp"]
else: else:
raise ValueError(binding) raise ValueError(binding)
else: else:
result["HostPort"] = binding host_port = binding
if result["HostPort"] is None:
result["HostPort"] = ""
else:
result["HostPort"] = str(result["HostPort"])
result["HostPort"] = str(host_port) if host_port is not None else ""
return result return result
def convert_port_bindings(port_bindings): def convert_port_bindings(
port_bindings: dict[
str | int,
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
| list[
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
],
],
) -> dict[str, list[dict[str, str]]]:
result = {} result = {}
for k, v in port_bindings.items(): for k, v in port_bindings.items():
key = str(k) key = str(k)
@ -129,9 +142,11 @@ def convert_port_bindings(port_bindings):
return result return result
def convert_volume_binds(binds): def convert_volume_binds(
binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int],
) -> list[str]:
if isinstance(binds, list): if isinstance(binds, list):
return binds return binds # type: ignore
result = [] result = []
for k, v in binds.items(): for k, v in binds.items():
@ -149,7 +164,7 @@ def convert_volume_binds(binds):
if "ro" in v: if "ro" in v:
mode = "ro" if v["ro"] else "rw" mode = "ro" if v["ro"] else "rw"
elif "mode" in v: elif "mode" in v:
mode = v["mode"] mode = v["mode"] # type: ignore # TODO
else: else:
mode = "rw" mode = "rw"
@ -165,9 +180,9 @@ def convert_volume_binds(binds):
] ]
if "propagation" in v and v["propagation"] in propagation_modes: if "propagation" in v and v["propagation"] in propagation_modes:
if mode: if mode:
mode = ",".join([mode, v["propagation"]]) mode = ",".join([mode, v["propagation"]]) # type: ignore # TODO
else: else:
mode = v["propagation"] mode = v["propagation"] # type: ignore # TODO
result.append(f"{k}:{bind}:{mode}") result.append(f"{k}:{bind}:{mode}")
else: else:
@ -177,7 +192,7 @@ def convert_volume_binds(binds):
return result return result
def convert_tmpfs_mounts(tmpfs): def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]:
if isinstance(tmpfs, dict): if isinstance(tmpfs, dict):
return tmpfs return tmpfs
@ -204,9 +219,11 @@ def convert_tmpfs_mounts(tmpfs):
return result return result
def convert_service_networks(networks): def convert_service_networks(
networks: list[str | dict[str, str]],
) -> list[dict[str, str]]:
if not networks: if not networks:
return networks return networks # type: ignore
if not isinstance(networks, list): if not isinstance(networks, list):
raise TypeError("networks parameter must be a list.") raise TypeError("networks parameter must be a list.")
@ -218,17 +235,17 @@ def convert_service_networks(networks):
return result return result
def parse_repository_tag(repo_name): def parse_repository_tag(repo_name: str) -> tuple[str, str | None]:
parts = repo_name.rsplit("@", 1) parts = repo_name.rsplit("@", 1)
if len(parts) == 2: if len(parts) == 2:
return tuple(parts) return tuple(parts) # type: ignore
parts = repo_name.rsplit(":", 1) parts = repo_name.rsplit(":", 1)
if len(parts) == 2 and "/" not in parts[1]: if len(parts) == 2 and "/" not in parts[1]:
return tuple(parts) return tuple(parts) # type: ignore
return repo_name, None return repo_name, None
def parse_host(addr, is_win32=False, tls=False): def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str:
# Sensible defaults # Sensible defaults
if not addr and is_win32: if not addr and is_win32:
return DEFAULT_NPIPE return DEFAULT_NPIPE
@ -308,7 +325,7 @@ def parse_host(addr, is_win32=False, tls=False):
).rstrip("/") ).rstrip("/")
def parse_devices(devices): def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]:
device_list = [] device_list = []
for device in devices: for device in devices:
if isinstance(device, dict): if isinstance(device, dict):
@ -337,7 +354,10 @@ def parse_devices(devices):
return device_list return device_list
def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): def kwargs_from_env(
assert_hostname: bool | None = None,
environment: Mapping[str, str] | None = None,
) -> dict[str, t.Any]:
if not environment: if not environment:
environment = os.environ environment = os.environ
host = environment.get("DOCKER_HOST") host = environment.get("DOCKER_HOST")
@ -347,14 +367,14 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
# empty string for tls verify counts as "false". # empty string for tls verify counts as "false".
# Any value or 'unset' counts as true. # Any value or 'unset' counts as true.
tls_verify = environment.get("DOCKER_TLS_VERIFY") tls_verify_str = environment.get("DOCKER_TLS_VERIFY")
if tls_verify == "": if tls_verify_str == "":
tls_verify = False tls_verify = False
else: else:
tls_verify = tls_verify is not None tls_verify = tls_verify_str is not None
enable_tls = cert_path or tls_verify enable_tls = cert_path or tls_verify
params = {} params: dict[str, t.Any] = {}
if host: if host:
params["base_url"] = host params["base_url"] = host
@ -377,14 +397,13 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
), ),
ca_cert=os.path.join(cert_path, "ca.pem"), ca_cert=os.path.join(cert_path, "ca.pem"),
verify=tls_verify, verify=tls_verify,
ssl_version=ssl_version,
assert_hostname=assert_hostname, assert_hostname=assert_hostname,
) )
return params return params
def convert_filters(filters): def convert_filters(filters: Mapping[str, bool | str | list[str]]) -> str:
result = {} result = {}
for k, v in filters.items(): for k, v in filters.items():
if isinstance(v, bool): if isinstance(v, bool):
@ -397,7 +416,7 @@ def convert_filters(filters):
return json.dumps(result) return json.dumps(result)
def parse_bytes(s): def parse_bytes(s: int | float | str) -> int | float:
if isinstance(s, (int, float)): if isinstance(s, (int, float)):
return s return s
if len(s) == 0: if len(s) == 0:
@ -435,14 +454,16 @@ def parse_bytes(s):
return s return s
def normalize_links(links): def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]:
if isinstance(links, dict): if isinstance(links, dict):
links = links.items() sorted_links = sorted(links.items())
else:
sorted_links = sorted(links)
return [f"{k}:{v}" if v else k for k, v in sorted(links)] return [f"{k}:{v}" if v else k for k, v in sorted_links]
def parse_env_file(env_file): def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]:
""" """
Reads a line-separated environment file. Reads a line-separated environment file.
The format of each line should be "key=value". The format of each line should be "key=value".
@ -451,7 +472,6 @@ def parse_env_file(env_file):
with open(env_file, "rt", encoding="utf-8") as f: with open(env_file, "rt", encoding="utf-8") as f:
for line in f: for line in f:
if line[0] == "#": if line[0] == "#":
continue continue
@ -471,11 +491,11 @@ def parse_env_file(env_file):
return environment return environment
def split_command(command): def split_command(command: str) -> list[str]:
return shlex.split(command) return shlex.split(command)
def format_environment(environment): def format_environment(environment: Mapping[str, str | bytes]) -> list[str]:
def format_env(key, value): def format_env(key, value):
if value is None: if value is None:
return key return key
@ -487,16 +507,9 @@ def format_environment(environment):
return [format_env(*var) for var in environment.items()] return [format_env(*var) for var in environment.items()]
def format_extra_hosts(extra_hosts, task=False): def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]:
# Use format dictated by Swarm API if container is part of a task # Use format dictated by Swarm API if container is part of a task
if task: if task:
return [f"{v} {k}" for k, v in sorted(extra_hosts.items())] return [f"{v} {k}" for k, v in sorted(extra_hosts.items())]
return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())] return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())]
def create_host_config(self, *args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_host_config has been removed. Please use a "
"docker.types.HostConfig object instead."
)

View File

@ -13,6 +13,7 @@ import platform
import re import re
import sys import sys
import traceback import traceback
import typing as t
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -36,6 +37,9 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
HAS_DOCKER_PY_2 = False # pylint: disable=invalid-name HAS_DOCKER_PY_2 = False # pylint: disable=invalid-name
HAS_DOCKER_PY_3 = False # pylint: disable=invalid-name HAS_DOCKER_PY_3 = False # pylint: disable=invalid-name
HAS_DOCKER_ERROR: None | str # pylint: disable=invalid-name
HAS_DOCKER_TRACEBACK: None | str # pylint: disable=invalid-name
docker_version: str | None # pylint: disable=invalid-name
try: try:
from docker import __version__ as docker_version from docker import __version__ as docker_version
@ -51,12 +55,13 @@ try:
HAS_DOCKER_PY_2 = True # pylint: disable=invalid-name HAS_DOCKER_PY_2 = True # pylint: disable=invalid-name
from docker import APIClient as Client from docker import APIClient as Client
else: else:
from docker import Client from docker import Client # type: ignore
except ImportError as exc: except ImportError as exc:
HAS_DOCKER_ERROR = str(exc) # pylint: disable=invalid-name HAS_DOCKER_ERROR = str(exc) # pylint: disable=invalid-name
HAS_DOCKER_TRACEBACK = traceback.format_exc() # pylint: disable=invalid-name HAS_DOCKER_TRACEBACK = traceback.format_exc() # pylint: disable=invalid-name
HAS_DOCKER_PY = False # pylint: disable=invalid-name HAS_DOCKER_PY = False # pylint: disable=invalid-name
docker_version = None # pylint: disable=invalid-name
else: else:
HAS_DOCKER_PY = True # pylint: disable=invalid-name HAS_DOCKER_PY = True # pylint: disable=invalid-name
HAS_DOCKER_ERROR = None # pylint: disable=invalid-name HAS_DOCKER_ERROR = None # pylint: disable=invalid-name
@ -71,30 +76,34 @@ except ImportError:
# Either Docker SDK for Python is no longer using requests, or Docker SDK for Python is not around either, # Either Docker SDK for Python is no longer using requests, or Docker SDK for Python is not around either,
# or Docker SDK for Python's dependency requests is missing. In any case, define an exception # or Docker SDK for Python's dependency requests is missing. In any case, define an exception
# class RequestException so that our code does not break. # class RequestException so that our code does not break.
class RequestException(Exception): class RequestException(Exception): # type: ignore
pass pass
if t.TYPE_CHECKING:
from collections.abc import Callable
MIN_DOCKER_VERSION = "2.0.0" MIN_DOCKER_VERSION = "2.0.0"
if not HAS_DOCKER_PY: if not HAS_DOCKER_PY:
docker_version = None # pylint: disable=invalid-name
# No Docker SDK for Python. Create a place holder client to allow # No Docker SDK for Python. Create a place holder client to allow
# instantiation of AnsibleModule and proper error handing # instantiation of AnsibleModule and proper error handing
class Client: # noqa: F811, pylint: disable=function-redefined class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined
def __init__(self, **kwargs): def __init__(self, **kwargs):
pass pass
class APIError(Exception): # noqa: F811, pylint: disable=function-redefined class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined
pass pass
class NotFound(Exception): # noqa: F811, pylint: disable=function-redefined class NotFound(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined
pass pass
def _get_tls_config(fail_function, **kwargs): def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion( if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion(
"7.0.0b1" "7.0.0b1"
): ):
@ -109,17 +118,18 @@ def _get_tls_config(fail_function, **kwargs):
# Filter out all None parameters # Filter out all None parameters
kwargs = dict((k, v) for k, v in kwargs.items() if v is not None) kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)
try: try:
tls_config = TLSConfig(**kwargs) return TLSConfig(**kwargs)
return tls_config
except TLSParameterError as exc: except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}") fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data): def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"] return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function): def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data): if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace( auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://" "tcp://", "https://"
@ -171,7 +181,11 @@ DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade.
class AnsibleDockerClientBase(Client): class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_version=None, min_docker_api_version=None): def __init__(
self,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
) -> None:
if min_docker_version is None: if min_docker_version is None:
min_docker_version = MIN_DOCKER_VERSION min_docker_version = MIN_DOCKER_VERSION
@ -212,23 +226,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg, pretty_print=False): def log(self, msg: t.Any, pretty_print: bool = False):
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass pass
def deprecate(self, msg, version=None, date=None, collection_name=None): @abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass pass
@staticmethod @staticmethod
def _get_value( def _get_value(
param_name, param_value, env_variable, default_value, value_type="str" param_name: str,
): param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None: if param_value is not None:
# take module parameter value # take module parameter value
if value_type == "bool": if value_type == "bool":
@ -265,11 +290,11 @@ class AnsibleDockerClientBase(Client):
return default_value return default_value
@abc.abstractmethod @abc.abstractmethod
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
pass pass
@property @property
def auth_params(self): def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials. # Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults. # Precedence: module parameters-> environment variables-> defaults.
@ -354,7 +379,7 @@ class AnsibleDockerClientBase(Client):
return result return result
def _handle_ssl_error(self, error): def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match: if match:
hostname = self.auth_params["tls_hostname"] hostname = self.auth_params["tls_hostname"]
@ -366,7 +391,7 @@ class AnsibleDockerClientBase(Client):
) )
self.fail(f"SSL Exception: {error}") self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id): def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try: try:
self.log(f"Inspecting container Id {container_id}") self.log(f"Inspecting container Id {container_id}")
result = self.inspect_container(container=container_id) result = self.inspect_container(container=container_id)
@ -377,7 +402,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}") self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None): def get_container(self, name: str | None) -> dict[str, t.Any] | None:
""" """
Lookup a container and return the inspection results. Lookup a container and return the inspection results.
""" """
@ -414,7 +439,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"]) return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None): def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
""" """
Lookup a network and return the inspection results. Lookup a network and return the inspection results.
""" """
@ -453,7 +480,7 @@ class AnsibleDockerClientBase(Client):
return result return result
def find_image(self, name, tag): def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None:
""" """
Lookup an image (by name and tag) and return the inspection results. Lookup an image (by name and tag) and return the inspection results.
""" """
@ -505,7 +532,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.") self.log(f"Image {name}:{tag} not found.")
return None return None
def find_image_by_id(self, image_id, accept_missing_image=False): def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
""" """
Lookup an image (by ID) and return the inspection results. Lookup an image (by ID) and return the inspection results.
""" """
@ -524,7 +553,7 @@ class AnsibleDockerClientBase(Client):
self.fail(f"Error inspecting image ID {image_id} - {exc}") self.fail(f"Error inspecting image ID {image_id} - {exc}")
return inspection return inspection
def _image_lookup(self, name, tag): def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
""" """
Including a tag in the name parameter sent to the Docker SDK for Python images method Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check does not work consistently. Instead, get the result set for name and manually check
@ -547,7 +576,9 @@ class AnsibleDockerClientBase(Client):
break break
return images return images
def pull_image(self, name, tag="latest", image_platform=None): def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
""" """
Pull an image Pull an image
""" """
@ -578,7 +609,7 @@ class AnsibleDockerClientBase(Client):
return new_tag, old_tag == new_tag return new_tag, old_tag == new_tag
def inspect_distribution(self, image, **kwargs): def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]:
""" """
Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
since prior versions did not support accessing private repositories. since prior versions did not support accessing private repositories.
@ -592,7 +623,7 @@ class AnsibleDockerClientBase(Client):
self._url("/distribution/{0}/json", image), self._url("/distribution/{0}/json", image),
headers={"X-Registry-Auth": header}, headers={"X-Registry-Auth": header},
), ),
get_json=True, json=True,
) )
return super().inspect_distribution(image, **kwargs) return super().inspect_distribution(image, **kwargs)
@ -601,18 +632,24 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec=None, argument_spec: dict[str, t.Any] | None = None,
supports_check_mode=False, supports_check_mode: bool = False,
mutually_exclusive=None, mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together=None, required_together: Sequence[Sequence[str]] | None = None,
required_if=None, required_if: (
required_one_of=None, Sequence[
required_by=None, tuple[str, t.Any, Sequence[str]]
min_docker_version=None, | tuple[str, t.Any, Sequence[str], bool]
min_docker_api_version=None, ]
option_minimal_versions=None, | None
option_minimal_versions_ignore_params=None, ) = None,
fail_results=None, required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
): ):
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
@ -625,12 +662,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec) merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec self.arg_spec = merged_arg_spec
mutually_exclusive_params = [] mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive: if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive mutually_exclusive_params += mutually_exclusive
required_together_params = [] required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together: if required_together:
required_together_params += required_together required_together_params += required_together
@ -658,20 +695,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params option_minimal_versions, option_minimal_versions_ignore_params
) )
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate( self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
return self.module.params return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): def _get_minimal_versions(
self.option_minimal_versions = {} self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec: for option in self.module.argument_spec:
if ignore_params is not None: if ignore_params is not None:
if option in ignore_params: if option in ignore_params:
@ -722,7 +769,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration." msg = f"Cannot {usg} with your configuration."
self.fail(msg) self.fail(msg)
def report_warnings(self, result, warnings_key=None): def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
""" """
Checks result of client operation for warnings, and if present, outputs them. Checks result of client operation for warnings, and if present, outputs them.

View File

@ -11,6 +11,7 @@ from __future__ import annotations
import abc import abc
import os import os
import re import re
import typing as t
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -28,7 +29,7 @@ try:
) )
except ImportError: except ImportError:
# Define an exception class RequestException so that our code does not break. # Define an exception class RequestException so that our code does not break.
class RequestException(Exception): class RequestException(Exception): # type: ignore
pass pass
@ -60,19 +61,26 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
def _get_tls_config(fail_function, **kwargs): if t.TYPE_CHECKING:
from collections.abc import Callable
def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
try: try:
tls_config = TLSConfig(**kwargs) return TLSConfig(**kwargs)
return tls_config
except TLSParameterError as exc: except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}") fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data): def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"] return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function): def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data): if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace( auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://" "tcp://", "https://"
@ -114,7 +122,7 @@ def get_connect_params(auth_data, fail_function):
class AnsibleDockerClientBase(Client): class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_api_version=None): def __init__(self, min_docker_api_version: str | None = None) -> None:
self._connect_params = get_connect_params( self._connect_params = get_connect_params(
self.auth_params, fail_function=self.fail self.auth_params, fail_function=self.fail
) )
@ -138,23 +146,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg, pretty_print=False): def log(self, msg: t.Any, pretty_print: bool = False):
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass pass
def deprecate(self, msg, version=None, date=None, collection_name=None): @abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass pass
@staticmethod @staticmethod
def _get_value( def _get_value(
param_name, param_value, env_variable, default_value, value_type="str" param_name: str,
): param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None: if param_value is not None:
# take module parameter value # take module parameter value
if value_type == "bool": if value_type == "bool":
@ -191,11 +210,11 @@ class AnsibleDockerClientBase(Client):
return default_value return default_value
@abc.abstractmethod @abc.abstractmethod
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
pass pass
@property @property
def auth_params(self): def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials. # Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults. # Precedence: module parameters-> environment variables-> defaults.
@ -288,7 +307,7 @@ class AnsibleDockerClientBase(Client):
return result return result
def _handle_ssl_error(self, error): def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match: if match:
hostname = self.auth_params["tls_hostname"] hostname = self.auth_params["tls_hostname"]
@ -300,7 +319,7 @@ class AnsibleDockerClientBase(Client):
) )
self.fail(f"SSL Exception: {error}") self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id): def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try: try:
self.log(f"Inspecting container Id {container_id}") self.log(f"Inspecting container Id {container_id}")
result = self.get_json("/containers/{0}/json", container_id) result = self.get_json("/containers/{0}/json", container_id)
@ -311,7 +330,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}") self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None): def get_container(self, name: str | None) -> dict[str, t.Any] | None:
""" """
Lookup a container and return the inspection results. Lookup a container and return the inspection results.
""" """
@ -355,7 +374,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"]) return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None): def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
""" """
Lookup a network and return the inspection results. Lookup a network and return the inspection results.
""" """
@ -395,14 +416,14 @@ class AnsibleDockerClientBase(Client):
return result return result
def _image_lookup(self, name, tag): def _image_lookup(self, name: str, tag: str | None) -> list[dict[str, t.Any]]:
""" """
Including a tag in the name parameter sent to the Docker SDK for Python images method Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check does not work consistently. Instead, get the result set for name and manually check
if the tag exists. if the tag exists.
""" """
try: try:
params = { params: dict[str, t.Any] = {
"only_ids": 0, "only_ids": 0,
"all": 0, "all": 0,
} }
@ -427,7 +448,7 @@ class AnsibleDockerClientBase(Client):
break break
return images return images
def find_image(self, name, tag): def find_image(self, name: str, tag: str | None) -> dict[str, t.Any] | None:
""" """
Lookup an image (by name and tag) and return the inspection results. Lookup an image (by name and tag) and return the inspection results.
""" """
@ -478,7 +499,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.") self.log(f"Image {name}:{tag} not found.")
return None return None
def find_image_by_id(self, image_id, accept_missing_image=False): def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
""" """
Lookup an image (by ID) and return the inspection results. Lookup an image (by ID) and return the inspection results.
""" """
@ -496,7 +519,9 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting image ID {image_id} - {exc}") self.fail(f"Error inspecting image ID {image_id} - {exc}")
def pull_image(self, name, tag="latest", image_platform=None): def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
""" """
Pull an image Pull an image
""" """
@ -544,22 +569,26 @@ class AnsibleDockerClientBase(Client):
class AnsibleDockerClient(AnsibleDockerClientBase): class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec=None, argument_spec: dict[str, t.Any] | None = None,
supports_check_mode=False, supports_check_mode: bool = False,
mutually_exclusive=None, mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together=None, required_together: Sequence[Sequence[str]] | None = None,
required_if=None, required_if: (
required_one_of=None, Sequence[
required_by=None, tuple[str, t.Any, Sequence[str]]
min_docker_api_version=None, | tuple[str, t.Any, Sequence[str], bool]
option_minimal_versions=None, ]
option_minimal_versions_ignore_params=None, | None
fail_results=None, ) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
): ):
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
# in case client.fail() is called. # in case client.fail() is called.
self.fail_results = fail_results or {} self.fail_results = fail_results or {}
@ -570,12 +599,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec) merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec self.arg_spec = merged_arg_spec
mutually_exclusive_params = [] mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive: if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive mutually_exclusive_params += mutually_exclusive
required_together_params = [] required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together: if required_together:
required_together_params += required_together required_together_params += required_together
@ -600,20 +629,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params option_minimal_versions, option_minimal_versions_ignore_params
) )
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate( self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
return self.module.params return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): def _get_minimal_versions(
self.option_minimal_versions = {} self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec: for option in self.module.argument_spec:
if ignore_params is not None: if ignore_params is not None:
if option in ignore_params: if option in ignore_params:
@ -654,7 +693,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration." msg = f"Cannot {usg} with your configuration."
self.fail(msg) self.fail(msg)
def report_warnings(self, result, warnings_key=None): def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
""" """
Checks result of client operation for warnings, and if present, outputs them. Checks result of client operation for warnings, and if present, outputs them.

View File

@ -10,6 +10,7 @@ from __future__ import annotations
import abc import abc
import json import json
import shlex import shlex
import typing as t
from ansible.module_utils.basic import AnsibleModule, env_fallback from ansible.module_utils.basic import AnsibleModule, env_fallback
from ansible.module_utils.common.process import get_bin_path from ansible.module_utils.common.process import get_bin_path
@ -31,6 +32,10 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
) )
if t.TYPE_CHECKING:
from collections.abc import Mapping, Sequence
DOCKER_COMMON_ARGS = { DOCKER_COMMON_ARGS = {
"docker_cli": {"type": "path"}, "docker_cli": {"type": "path"},
"docker_host": { "docker_host": {
@ -72,10 +77,16 @@ class DockerException(Exception):
class AnsibleDockerClientBase: class AnsibleDockerClientBase:
docker_api_version_str: str | None
docker_api_version: LooseVersion | None
def __init__( def __init__(
self, common_args, min_docker_api_version=None, needs_api_version=True self,
): common_args,
self._environment = {} min_docker_api_version: str | None = None,
needs_api_version: bool = True,
) -> None:
self._environment: dict[str, str] = {}
if common_args["tls_hostname"]: if common_args["tls_hostname"]:
self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"] self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"]
if common_args["api_version"] and common_args["api_version"] != "auto": if common_args["api_version"] and common_args["api_version"] != "auto":
@ -109,10 +120,10 @@ class AnsibleDockerClientBase:
self._cli_base.extend(["--context", common_args["cli_context"]]) self._cli_base.extend(["--context", common_args["cli_context"]])
# `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0 # `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0
dummy, self._version, dummy = self.call_cli_json( dummy, self._version, dummy2 = self.call_cli_json(
"version", "--format", "{{ json . }}", check_rc=True "version", "--format", "{{ json . }}", check_rc=True
) )
self._info = None self._info: dict[str, t.Any] | None = None
if needs_api_version: if needs_api_version:
if not isinstance(self._version.get("Server"), dict) or not isinstance( if not isinstance(self._version.get("Server"), dict) or not isinstance(
@ -138,32 +149,47 @@ class AnsibleDockerClientBase:
"Internal error: cannot have needs_api_version=False with min_docker_api_version not None" "Internal error: cannot have needs_api_version=False with min_docker_api_version not None"
) )
def log(self, msg, pretty_print=False): def log(self, msg: str, pretty_print: bool = False):
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
def get_cli(self): def get_cli(self) -> str:
return self._cli return self._cli
def get_version_info(self): def get_version_info(self) -> str:
return self._version return self._version
def _compose_cmd(self, args): def _compose_cmd(self, args: t.Sequence[str]) -> list[str]:
return self._cli_base + list(args) return self._cli_base + list(args)
def _compose_cmd_str(self, args): def _compose_cmd_str(self, args: t.Sequence[str]) -> str:
return " ".join(shlex.quote(a) for a in self._compose_cmd(args)) return " ".join(shlex.quote(a) for a in self._compose_cmd(args))
@abc.abstractmethod @abc.abstractmethod
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
pass pass
# def call_cli_json(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): def call_cli_json(
def call_cli_json(self, *args, **kwargs): self,
warn_on_stderr = kwargs.pop("warn_on_stderr", False) *args: str,
rc, stdout, stderr = self.call_cli(*args, **kwargs) check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, t.Any, bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr: if warn_on_stderr and stderr:
self.warn(to_native(stderr)) self.warn(to_native(stderr))
try: try:
@ -174,10 +200,18 @@ class AnsibleDockerClientBase:
) )
return rc, data, stderr return rc, data, stderr
# def call_cli_json_stream(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False): def call_cli_json_stream(
def call_cli_json_stream(self, *args, **kwargs): self,
warn_on_stderr = kwargs.pop("warn_on_stderr", False) *args: str,
rc, stdout, stderr = self.call_cli(*args, **kwargs) check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, list[t.Any], bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr: if warn_on_stderr and stderr:
self.warn(to_native(stderr)) self.warn(to_native(stderr))
result = [] result = []
@ -193,25 +227,31 @@ class AnsibleDockerClientBase:
return rc, result, stderr return rc, result, stderr
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs) -> t.NoReturn:
pass pass
@abc.abstractmethod @abc.abstractmethod
def warn(self, msg): def warn(self, msg: str) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass pass
def get_cli_info(self): def get_cli_info(self) -> dict[str, t.Any]:
if self._info is None: if self._info is None:
dummy, self._info, dummy = self.call_cli_json( dummy, self._info, dummy2 = self.call_cli_json(
"info", "--format", "{{ json . }}", check_rc=True "info", "--format", "{{ json . }}", check_rc=True
) )
return self._info return self._info
def get_client_plugin_info(self, component): def get_client_plugin_info(self, component: str) -> dict[str, t.Any] | None:
cli_info = self.get_cli_info() cli_info = self.get_cli_info()
if not isinstance(cli_info.get("ClientInfo"), dict): if not isinstance(cli_info.get("ClientInfo"), dict):
self.fail( self.fail(
@ -222,13 +262,13 @@ class AnsibleDockerClientBase:
return plugin return plugin
return None return None
def _image_lookup(self, name, tag): def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
""" """
Including a tag in the name parameter sent to the Docker SDK for Python images method Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check does not work consistently. Instead, get the result set for name and manually check
if the tag exists. if the tag exists.
""" """
dummy, images, dummy = self.call_cli_json_stream( dummy, images, dummy2 = self.call_cli_json_stream(
"image", "image",
"ls", "ls",
"--format", "--format",
@ -247,7 +287,13 @@ class AnsibleDockerClientBase:
break break
return images return images
def find_image(self, name, tag): @t.overload
def find_image(self, name: None, tag: str) -> None: ...
@t.overload
def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: ...
def find_image(self, name: str | None, tag: str) -> dict[str, t.Any] | None:
""" """
Lookup an image (by name and tag) and return the inspection results. Lookup an image (by name and tag) and return the inspection results.
""" """
@ -298,7 +344,19 @@ class AnsibleDockerClientBase:
self.log(f"Image {name}:{tag} not found.") self.log(f"Image {name}:{tag} not found.")
return None return None
def find_image_by_id(self, image_id, accept_missing_image=False): @t.overload
def find_image_by_id(
self, image_id: None, accept_missing_image: bool = False
) -> None: ...
@t.overload
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None: ...
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
""" """
Lookup an image (by ID) and return the inspection results. Lookup an image (by ID) and return the inspection results.
""" """
@ -320,17 +378,23 @@ class AnsibleDockerClientBase:
class AnsibleModuleDockerClient(AnsibleDockerClientBase): class AnsibleModuleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec=None, argument_spec: dict[str, t.Any] | None = None,
supports_check_mode=False, supports_check_mode: bool = False,
mutually_exclusive=None, mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together=None, required_together: Sequence[Sequence[str]] | None = None,
required_if=None, required_if: (
required_one_of=None, Sequence[
required_by=None, tuple[str, t.Any, Sequence[str]]
min_docker_api_version=None, | tuple[str, t.Any, Sequence[str], bool]
fail_results=None, ]
needs_api_version=True, | None
): ) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: Mapping[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
fail_results: dict[str, t.Any] | None = None,
needs_api_version: bool = True,
) -> None:
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
# in case client.fail() is called. # in case client.fail() is called.
@ -342,12 +406,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec) merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec self.arg_spec = merged_arg_spec
mutually_exclusive_params = [("docker_host", "cli_context")] mutually_exclusive_params: list[Sequence[str]] = [
("docker_host", "cli_context")
]
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive: if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive mutually_exclusive_params += mutually_exclusive
required_together_params = [] required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together: if required_together:
required_together_params += required_together required_together_params += required_together
@ -373,7 +439,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
needs_api_version=needs_api_version, needs_api_version=needs_api_version,
) )
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None): def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
environment = self._environment.copy() environment = self._environment.copy()
if environ_update: if environ_update:
environment.update(environ_update) environment.update(environ_update)
@ -390,14 +463,20 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
) )
return rc, stdout, stderr return rc, stdout, stderr
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def warn(self, msg): def warn(self, msg: str) -> None:
self.module.warn(msg) self.module.warn(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate( self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )

View File

@ -14,6 +14,7 @@ import re
import shutil import shutil
import tempfile import tempfile
import traceback import traceback
import typing as t
from collections import namedtuple from collections import namedtuple
from shlex import quote from shlex import quote
@ -34,6 +35,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
) )
PYYAML_IMPORT_ERROR: None | str # pylint: disable=invalid-name
try: try:
import yaml import yaml
@ -41,7 +43,7 @@ try:
# use C version if possible for speedup # use C version if possible for speedup
from yaml import CSafeDumper as _SafeDumper from yaml import CSafeDumper as _SafeDumper
except ImportError: except ImportError:
from yaml import SafeDumper as _SafeDumper from yaml import SafeDumper as _SafeDumper # type: ignore
except ImportError: except ImportError:
HAS_PYYAML = False HAS_PYYAML = False
PYYAML_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name PYYAML_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name
@ -49,6 +51,13 @@ else:
HAS_PYYAML = True HAS_PYYAML = True
PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ansible_collections.community.docker.plugins.module_utils._common_cli import (
AnsibleModuleDockerClient as _Client,
)
DOCKER_COMPOSE_FILES = ( DOCKER_COMPOSE_FILES = (
"compose.yaml", "compose.yaml",
@ -144,8 +153,7 @@ class ResourceType:
SERVICE = "service" SERVICE = "service"
@classmethod @classmethod
def from_docker_compose_event(cls, resource_type): def from_docker_compose_event(cls, resource_type: str) -> t.Any:
# type: (Type[ResourceType], Text) -> Any
return { return {
"Network": cls.NETWORK, "Network": cls.NETWORK,
"Image": cls.IMAGE, "Image": cls.IMAGE,
@ -240,7 +248,9 @@ _RE_BUILD_PROGRESS_EVENT = re.compile(r"^\s*==>\s+(?P<msg>.*)$")
MINIMUM_COMPOSE_VERSION = "2.18.0" MINIMUM_COMPOSE_VERSION = "2.18.0"
def _extract_event(line, warn_function=None): def _extract_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
match = _RE_RESOURCE_EVENT.match(line) match = _RE_RESOURCE_EVENT.match(line)
if match is not None: if match is not None:
status = match.group("status") status = match.group("status")
@ -323,7 +333,9 @@ def _extract_event(line, warn_function=None):
return None, False return None, False
def _extract_logfmt_event(line, warn_function=None): def _extract_logfmt_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
try: try:
result = _parse_logfmt_line(line, logrus_mode=True) result = _parse_logfmt_line(line, logrus_mode=True)
except _InvalidLogFmt: except _InvalidLogFmt:
@ -338,7 +350,11 @@ def _extract_logfmt_event(line, warn_function=None):
return None, False return None, False
def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_function): def _warn_missing_dry_run_prefix(
line: str,
warn_missing_dry_run_prefix: bool,
warn_function: Callable[[str], None] | None,
) -> None:
if warn_missing_dry_run_prefix and warn_function: if warn_missing_dry_run_prefix and warn_function:
# This could be a bug, a change of docker compose's output format, ... # This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-) # Tell the user to report it to us :-)
@ -349,7 +365,9 @@ def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_functio
) )
def _warn_unparsable_line(line, warn_function): def _warn_unparsable_line(
line: str, warn_function: Callable[[str], None] | None
) -> None:
# This could be a bug, a change of docker compose's output format, ... # This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-) # Tell the user to report it to us :-)
if warn_function: if warn_function:
@ -360,14 +378,16 @@ def _warn_unparsable_line(line, warn_function):
) )
def _find_last_event_for(events, resource_id): def _find_last_event_for(
events: list[Event], resource_id: str
) -> tuple[int, Event] | None:
for index, event in enumerate(reversed(events)): for index, event in enumerate(reversed(events)):
if event.resource_id == resource_id: if event.resource_id == resource_id:
return len(events) - 1 - index, event return len(events) - 1 - index, event
return None return None
def _concat_event_msg(event, append_msg): def _concat_event_msg(event: Event, append_msg: str) -> Event:
return Event( return Event(
event.resource_type, event.resource_type,
event.resource_id, event.resource_id,
@ -382,7 +402,9 @@ _JSON_LEVEL_TO_STATUS_MAP = {
} }
def parse_json_events(stderr, warn_function=None): def parse_json_events(
stderr: bytes, warn_function: Callable[[str], None] | None = None
) -> list[Event]:
events = [] events = []
stderr_lines = stderr.splitlines() stderr_lines = stderr.splitlines()
if stderr_lines and stderr_lines[-1] == b"": if stderr_lines and stderr_lines[-1] == b"":
@ -523,7 +545,12 @@ def parse_json_events(stderr, warn_function=None):
return events return events
def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False): def parse_events(
stderr: bytes,
dry_run: bool = False,
warn_function: Callable[[str], None] | None = None,
nonzero_rc: bool = False,
) -> list[Event]:
events = [] events = []
error_event = None error_event = None
stderr_lines = stderr.splitlines() stderr_lines = stderr.splitlines()
@ -597,7 +624,11 @@ def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False):
return events return events
def has_changes(events, ignore_service_pull_events=False, ignore_build_events=False): def has_changes(
events: Sequence[Event],
ignore_service_pull_events: bool = False,
ignore_build_events: bool = False,
) -> bool:
for event in events: for event in events:
if event.status in DOCKER_STATUS_WORKING: if event.status in DOCKER_STATUS_WORKING:
if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL: if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL:
@ -613,7 +644,7 @@ def has_changes(events, ignore_service_pull_events=False, ignore_build_events=Fa
return False return False
def extract_actions(events): def extract_actions(events: Sequence[Event]) -> list[dict[str, t.Any]]:
actions = [] actions = []
pull_actions = set() pull_actions = set()
for event in events: for event in events:
@ -645,7 +676,9 @@ def extract_actions(events):
return actions return actions
def emit_warnings(events, warn_function): def emit_warnings(
events: Sequence[Event], warn_function: Callable[[str], None]
) -> None:
for event in events: for event in events:
# If a message is present, assume it is a warning # If a message is present, assume it is a warning
if ( if (
@ -656,13 +689,21 @@ def emit_warnings(events, warn_function):
) )
def is_failed(events, rc): def is_failed(events: Sequence[Event], rc: int) -> bool:
if rc: if rc:
return True return True
return False return False
def update_failed(result, events, args, stdout, stderr, rc, cli): def update_failed(
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: str | bytes,
rc: int,
cli: str,
) -> bool:
if not rc: if not rc:
return False return False
errors = [] errors = []
@ -696,7 +737,7 @@ def update_failed(result, events, args, stdout, stderr, rc, cli):
return True return True
def common_compose_argspec(): def common_compose_argspec() -> dict[str, t.Any]:
return { return {
"project_src": {"type": "path"}, "project_src": {"type": "path"},
"project_name": {"type": "str"}, "project_name": {"type": "str"},
@ -708,7 +749,7 @@ def common_compose_argspec():
} }
def common_compose_argspec_ex(): def common_compose_argspec_ex() -> dict[str, t.Any]:
return { return {
"argspec": common_compose_argspec(), "argspec": common_compose_argspec(),
"mutually_exclusive": [("definition", "project_src"), ("definition", "files")], "mutually_exclusive": [("definition", "project_src"), ("definition", "files")],
@ -721,16 +762,18 @@ def common_compose_argspec_ex():
} }
def combine_binary_output(*outputs): def combine_binary_output(*outputs: bytes | None) -> bytes:
return b"\n".join(out for out in outputs if out) return b"\n".join(out for out in outputs if out)
def combine_text_output(*outputs): def combine_text_output(*outputs: str | None) -> str:
return "\n".join(out for out in outputs if out) return "\n".join(out for out in outputs if out)
class BaseComposeManager(DockerBaseClass): class BaseComposeManager(DockerBaseClass):
def __init__(self, client, min_version=MINIMUM_COMPOSE_VERSION): def __init__(
self, client: _Client, min_version: str = MINIMUM_COMPOSE_VERSION
) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
@ -794,12 +837,12 @@ class BaseComposeManager(DockerBaseClass):
# more precisely in https://github.com/docker/compose/pull/11478 # more precisely in https://github.com/docker/compose/pull/11478
self.use_json_events = self.compose_version >= LooseVersion("2.29.0") self.use_json_events = self.compose_version >= LooseVersion("2.29.0")
def get_compose_version(self): def get_compose_version(self) -> str:
return ( return (
self.get_compose_version_from_cli() or self.get_compose_version_from_api() self.get_compose_version_from_cli() or self.get_compose_version_from_api()
) )
def get_compose_version_from_cli(self): def get_compose_version_from_cli(self) -> str | None:
rc, version_info, dummy_stderr = self.client.call_cli( rc, version_info, dummy_stderr = self.client.call_cli(
"compose", "version", "--format", "json" "compose", "version", "--format", "json"
) )
@ -813,7 +856,7 @@ class BaseComposeManager(DockerBaseClass):
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
return None return None
def get_compose_version_from_api(self): def get_compose_version_from_api(self) -> str:
compose = self.client.get_client_plugin_info("compose") compose = self.client.get_client_plugin_info("compose")
if compose is None: if compose is None:
self.fail( self.fail(
@ -826,11 +869,11 @@ class BaseComposeManager(DockerBaseClass):
) )
return compose["Version"].lstrip("v") return compose["Version"].lstrip("v")
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.cleanup() self.cleanup()
self.client.fail(msg, **kwargs) self.client.fail(msg, **kwargs)
def get_base_args(self, plain_progress=False): def get_base_args(self, plain_progress: bool = False) -> list[str]:
args = ["compose", "--ansi", "never"] args = ["compose", "--ansi", "never"]
if self.use_json_events and not plain_progress: if self.use_json_events and not plain_progress:
args.extend(["--progress", "json"]) args.extend(["--progress", "json"])
@ -848,28 +891,33 @@ class BaseComposeManager(DockerBaseClass):
args.extend(["--profile", profile]) args.extend(["--profile", profile])
return args return args
def _handle_failed_cli_call(self, args, rc, stdout, stderr): def _handle_failed_cli_call(
self, args: list[str], rc: int, stdout: str | bytes, stderr: bytes
) -> t.NoReturn:
events = parse_json_events(stderr, warn_function=self.client.warn) events = parse_json_events(stderr, warn_function=self.client.warn)
result = {} result: dict[str, t.Any] = {}
self.update_failed(result, events, args, stdout, stderr, rc) self.update_failed(result, events, args, stdout, stderr, rc)
self.client.module.exit_json(**result) self.client.module.exit_json(**result)
def list_containers_raw(self): def list_containers_raw(self) -> list[dict[str, t.Any]]:
args = self.get_base_args() + ["ps", "--format", "json", "--all"] args = self.get_base_args() + ["ps", "--format", "json", "--all"]
if self.compose_version >= LooseVersion("2.23.0"): if self.compose_version >= LooseVersion("2.23.0"):
# https://github.com/docker/compose/pull/11038 # https://github.com/docker/compose/pull/11038
args.append("--no-trunc") args.append("--no-trunc")
kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events}
if self.compose_version >= LooseVersion("2.21.0"): if self.compose_version >= LooseVersion("2.21.0"):
# Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918 # Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918
rc, containers, stderr = self.client.call_cli_json_stream(*args, **kwargs) rc, containers, stderr = self.client.call_cli_json_stream(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
else: else:
rc, containers, stderr = self.client.call_cli_json(*args, **kwargs) rc, containers, stderr = self.client.call_cli_json(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
if self.use_json_events and rc != 0: if self.use_json_events and rc != 0:
self._handle_failed_cli_call(args, rc, containers, stderr) self._handle_failed_cli_call(args, rc, json.dumps(containers), stderr)
return containers return containers
def list_containers(self): def list_containers(self) -> list[dict[str, t.Any]]:
result = [] result = []
for container in self.list_containers_raw(): for container in self.list_containers_raw():
labels = {} labels = {}
@ -886,10 +934,11 @@ class BaseComposeManager(DockerBaseClass):
result.append(container) result.append(container)
return result return result
def list_images(self): def list_images(self) -> list[str]:
args = self.get_base_args() + ["images", "--format", "json"] args = self.get_base_args() + ["images", "--format", "json"]
kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events} rc, images, stderr = self.client.call_cli_json(
rc, images, stderr = self.client.call_cli_json(*args, **kwargs) *args, cwd=self.project_src, check_rc=not self.use_json_events
)
if self.use_json_events and rc != 0: if self.use_json_events and rc != 0:
self._handle_failed_cli_call(args, rc, images, stderr) self._handle_failed_cli_call(args, rc, images, stderr)
if isinstance(images, dict): if isinstance(images, dict):
@ -899,7 +948,9 @@ class BaseComposeManager(DockerBaseClass):
images = list(images.values()) images = list(images.values())
return images return images
def parse_events(self, stderr, dry_run=False, nonzero_rc=False): def parse_events(
self, stderr: bytes, dry_run: bool = False, nonzero_rc: bool = False
) -> list[Event]:
if self.use_json_events: if self.use_json_events:
return parse_json_events(stderr, warn_function=self.client.warn) return parse_json_events(stderr, warn_function=self.client.warn)
return parse_events( return parse_events(
@ -909,17 +960,17 @@ class BaseComposeManager(DockerBaseClass):
nonzero_rc=nonzero_rc, nonzero_rc=nonzero_rc,
) )
def emit_warnings(self, events): def emit_warnings(self, events: Sequence[Event]) -> None:
emit_warnings(events, warn_function=self.client.warn) emit_warnings(events, warn_function=self.client.warn)
def update_result( def update_result(
self, self,
result, result: dict[str, t.Any],
events, events: Sequence[Event],
stdout, stdout: str | bytes,
stderr, stderr: str | bytes,
ignore_service_pull_events=False, ignore_service_pull_events: bool = False,
ignore_build_events=False, ignore_build_events: bool = False,
): ):
result["changed"] = result.get("changed", False) or has_changes( result["changed"] = result.get("changed", False) or has_changes(
events, events,
@ -930,7 +981,15 @@ class BaseComposeManager(DockerBaseClass):
result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout)) result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout))
result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr)) result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr))
def update_failed(self, result, events, args, stdout, stderr, rc): def update_failed(
self,
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: bytes,
rc: int,
):
return update_failed( return update_failed(
result, result,
events, events,
@ -941,14 +1000,14 @@ class BaseComposeManager(DockerBaseClass):
cli=self.client.get_cli(), cli=self.client.get_cli(),
) )
def cleanup_result(self, result): def cleanup_result(self, result: dict[str, t.Any]) -> None:
if not result.get("failed"): if not result.get("failed"):
# Only return stdout and stderr if it is not empty # Only return stdout and stderr if it is not empty
for res in ("stdout", "stderr"): for res in ("stdout", "stderr"):
if result.get(res) == "": if result.get(res) == "":
result.pop(res) result.pop(res)
def cleanup(self): def cleanup(self) -> None:
for directory in self.cleanup_dirs: for directory in self.cleanup_dirs:
try: try:
shutil.rmtree(directory, True) shutil.rmtree(directory, True)

View File

@ -16,6 +16,7 @@ import os.path
import shutil import shutil
import stat import stat
import tarfile import tarfile
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@ -25,6 +26,16 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
) )
if t.TYPE_CHECKING:
from collections.abc import Callable
from _typeshed import WriteableBuffer
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class DockerFileCopyError(Exception): class DockerFileCopyError(Exception):
pass pass
@ -37,7 +48,9 @@ class DockerFileNotFound(DockerFileCopyError):
pass pass
def _put_archive(client, container, path, data): def _put_archive(
client: APIClient, container: str, path: str, data: bytes | t.Generator[bytes]
) -> bool:
# data can also be file object for streaming. This is because _put uses requests's put(). # data can also be file object for streaming. This is because _put uses requests's put().
# See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads # See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads
url = client._url("/containers/{0}/archive", container) url = client._url("/containers/{0}/archive", container)
@ -47,8 +60,14 @@ def _put_archive(client, container, path, data):
def _symlink_tar_creator( def _symlink_tar_creator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None b_in_path: bytes,
): file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> bytes:
if not stat.S_ISLNK(file_stat.st_mode): if not stat.S_ISLNK(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a symlink") raise DockerUnexpectedError("stat information is not for a symlink")
bio = io.BytesIO() bio = io.BytesIO()
@ -75,16 +94,28 @@ def _symlink_tar_creator(
def _symlink_tar_generator( def _symlink_tar_generator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None b_in_path: bytes,
): file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> t.Generator[bytes]:
yield _symlink_tar_creator( yield _symlink_tar_creator(
b_in_path, file_stat, out_file, user_id, group_id, mode, user_name b_in_path, file_stat, out_file, user_id, group_id, mode, user_name
) )
def _regular_file_tar_generator( def _regular_file_tar_generator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None b_in_path: bytes,
): file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> t.Generator[bytes]:
if not stat.S_ISREG(file_stat.st_mode): if not stat.S_ISREG(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a regular file") raise DockerUnexpectedError("stat information is not for a regular file")
tarinfo = tarfile.TarInfo() tarinfo = tarfile.TarInfo()
@ -136,8 +167,13 @@ def _regular_file_tar_generator(
def _regular_content_tar_generator( def _regular_content_tar_generator(
content, out_file, user_id, group_id, mode, user_name=None content: bytes,
): out_file: str | bytes,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> t.Generator[bytes]:
tarinfo = tarfile.TarInfo() tarinfo = tarfile.TarInfo()
tarinfo.name = ( tarinfo.name = (
os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/") os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/")
@ -175,16 +211,16 @@ def _regular_content_tar_generator(
def put_file( def put_file(
client, client: APIClient,
container, container: str,
in_path, in_path: str,
out_path, out_path: str,
user_id, user_id: int,
group_id, group_id: int,
mode=None, mode: int | None = None,
user_name=None, user_name: str | None = None,
follow_links=False, follow_links: bool = False,
): ) -> None:
"""Transfer a file from local to Docker container.""" """Transfer a file from local to Docker container."""
if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")): if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")):
raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}") raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}")
@ -232,8 +268,15 @@ def put_file(
def put_file_content( def put_file_content(
client, container, content, out_path, user_id, group_id, mode, user_name=None client: APIClient,
): container: str,
content: bytes,
out_path: str,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> None:
"""Transfer a file from local to Docker container.""" """Transfer a file from local to Docker container."""
out_dir, out_file = os.path.split(out_path) out_dir, out_file = os.path.split(out_path)
@ -248,7 +291,13 @@ def put_file_content(
) )
def stat_file(client, container, in_path, follow_links=False, log=None): def stat_file(
client: APIClient,
container: str,
in_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> tuple[str, dict[str, t.Any] | None, str | None]:
"""Fetch information on a file from a Docker container to local. """Fetch information on a file from a Docker container to local.
Return a tuple ``(path, stat_data, link_target)`` where: Return a tuple ``(path, stat_data, link_target)`` where:
@ -265,12 +314,12 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
while True: while True:
if in_path in considered_in_paths: if in_path in considered_in_paths:
raise DockerFileCopyError( raise DockerFileCopyError(
f'Found infinite symbolic link loop when trying to stating "{in_path}"' f"Found infinite symbolic link loop when trying to stating {in_path!r}"
) )
considered_in_paths.add(in_path) considered_in_paths.add(in_path)
if log: if log:
log(f'FETCH: Stating "{in_path}"') log(f"FETCH: Stating {in_path!r}")
response = client._head( response = client._head(
client._url("/containers/{0}/archive", container), client._url("/containers/{0}/archive", container),
@ -299,24 +348,24 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
class _RawGeneratorFileobj(io.RawIOBase): class _RawGeneratorFileobj(io.RawIOBase):
def __init__(self, stream): def __init__(self, stream: t.Generator[bytes]):
self._stream = stream self._stream = stream
self._buf = b"" self._buf = b""
def readable(self): def readable(self) -> bool:
return True return True
def _readinto_from_buf(self, b, index, length): def _readinto_from_buf(self, b: WriteableBuffer, index: int, length: int) -> int:
cpy = min(length - index, len(self._buf)) cpy = min(length - index, len(self._buf))
if cpy: if cpy:
b[index : index + cpy] = self._buf[:cpy] b[index : index + cpy] = self._buf[:cpy] # type: ignore # TODO!
self._buf = self._buf[cpy:] self._buf = self._buf[cpy:]
index += cpy index += cpy
return index return index
def readinto(self, b): def readinto(self, b: WriteableBuffer) -> int:
index = 0 index = 0
length = len(b) length = len(b) # type: ignore # TODO!
index = self._readinto_from_buf(b, index, length) index = self._readinto_from_buf(b, index, length)
if index == length: if index == length:
@ -330,25 +379,28 @@ class _RawGeneratorFileobj(io.RawIOBase):
return self._readinto_from_buf(b, index, length) return self._readinto_from_buf(b, index, length)
def _stream_generator_to_fileobj(stream): def _stream_generator_to_fileobj(stream: t.Generator[bytes]) -> io.BufferedReader:
"""Given a generator that generates chunks of bytes, create a readable buffered stream.""" """Given a generator that generates chunks of bytes, create a readable buffered stream."""
raw = _RawGeneratorFileobj(stream) raw = _RawGeneratorFileobj(stream)
return io.BufferedReader(raw) return io.BufferedReader(raw)
_T = t.TypeVar("_T")
def fetch_file_ex( def fetch_file_ex(
client, client: APIClient,
container, container: str,
in_path, in_path: str,
process_none, process_none: Callable[[str], _T],
process_regular, process_regular: Callable[[str, tarfile.TarFile, tarfile.TarInfo], _T],
process_symlink, process_symlink: Callable[[str, tarfile.TarInfo], _T],
process_other, process_other: Callable[[str, tarfile.TarInfo], _T],
follow_links=False, follow_links: bool = False,
log=None, log: Callable[[str], None] | None = None,
): ) -> _T:
"""Fetch a file (as a tar file entry) from a Docker container to local.""" """Fetch a file (as a tar file entry) from a Docker container to local."""
considered_in_paths = set() considered_in_paths: set[str] = set()
while True: while True:
if in_path in considered_in_paths: if in_path in considered_in_paths:
@ -372,8 +424,8 @@ def fetch_file_ex(
with tarfile.open( with tarfile.open(
fileobj=_stream_generator_to_fileobj(stream), mode="r|" fileobj=_stream_generator_to_fileobj(stream), mode="r|"
) as tar: ) as tar:
symlink_member = None symlink_member: tarfile.TarInfo | None = None
result = None result: _T | None = None
found = False found = False
for member in tar: for member in tar:
if found: if found:
@ -398,35 +450,46 @@ def fetch_file_ex(
log(f'FETCH: Following symbolic link to "{in_path}"') log(f'FETCH: Following symbolic link to "{in_path}"')
continue continue
if found: if found:
return result return result # type: ignore
raise DockerUnexpectedError("Received tarfile is empty!") raise DockerUnexpectedError("Received tarfile is empty!")
def fetch_file(client, container, in_path, out_path, follow_links=False, log=None): def fetch_file(
client: APIClient,
container: str,
in_path: str,
out_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> str:
b_out_path = to_bytes(out_path, errors="surrogate_or_strict") b_out_path = to_bytes(out_path, errors="surrogate_or_strict")
def process_none(in_path): def process_none(in_path: str) -> str:
raise DockerFileNotFound( raise DockerFileNotFound(
f"File {in_path} does not exist in container {container}" f"File {in_path} does not exist in container {container}"
) )
def process_regular(in_path, tar, member): def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> str:
if not follow_links and os.path.exists(b_out_path): if not follow_links and os.path.exists(b_out_path):
os.unlink(b_out_path) os.unlink(b_out_path)
with tar.extractfile(member) as in_f: reader = tar.extractfile(member)
with open(b_out_path, "wb") as out_f: if reader:
shutil.copyfileobj(in_f, out_f) with reader as in_f:
with open(b_out_path, "wb") as out_f:
shutil.copyfileobj(in_f, out_f)
return in_path return in_path
def process_symlink(in_path, member): def process_symlink(in_path, member) -> str:
if os.path.exists(b_out_path): if os.path.exists(b_out_path):
os.unlink(b_out_path) os.unlink(b_out_path)
os.symlink(member.linkname, b_out_path) os.symlink(member.linkname, b_out_path)
return in_path return in_path
def process_other(in_path, member): def process_other(in_path, member) -> str:
raise DockerFileCopyError( raise DockerFileCopyError(
f'Remote file "{in_path}" is not a regular file or a symbolic link' f'Remote file "{in_path}" is not a regular file or a symbolic link'
) )
@ -444,7 +507,13 @@ def fetch_file(client, container, in_path, out_path, follow_links=False, log=Non
) )
def _execute_command(client, container, command, log=None, check_rc=False): def _execute_command(
client: APIClient,
container: str,
command: list[str],
log: Callable[[str], None] | None = None,
check_rc: bool = False,
) -> tuple[int, bytes, bytes]:
if log: if log:
log(f"Executing {command} in {container}") log(f"Executing {command} in {container}")
@ -483,7 +552,7 @@ def _execute_command(client, container, command, log=None, check_rc=False):
result = client.get_json("/exec/{0}/json", exec_id) result = client.get_json("/exec/{0}/json", exec_id)
rc = result.get("ExitCode") or 0 rc: int = result.get("ExitCode") or 0
stdout = stdout or b"" stdout = stdout or b""
stderr = stderr or b"" stderr = stderr or b""
@ -493,13 +562,15 @@ def _execute_command(client, container, command, log=None, check_rc=False):
if check_rc and rc != 0: if check_rc and rc != 0:
command_str = " ".join(command) command_str = " ".join(command)
raise DockerUnexpectedError( raise DockerUnexpectedError(
f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout}\nSTDERR: {stderr}' f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout!r}\nSTDERR: {stderr!r}'
) )
return rc, stdout, stderr return rc, stdout, stderr
def determine_user_group(client, container, log=None): def determine_user_group(
client: APIClient, container: str, log: Callable[[str], None] | None = None
) -> tuple[int, int]:
dummy_rc, stdout, dummy_stderr = _execute_command( dummy_rc, stdout, dummy_stderr = _execute_command(
client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log
) )
@ -507,7 +578,7 @@ def determine_user_group(client, container, log=None):
stdout_lines = stdout.splitlines() stdout_lines = stdout.splitlines()
if len(stdout_lines) != 2: if len(stdout_lines) != 2:
raise DockerUnexpectedError( raise DockerUnexpectedError(
f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout}" f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout!r}"
) )
user_id, group_id = stdout_lines user_id, group_id = stdout_lines
@ -515,5 +586,5 @@ def determine_user_group(client, container, log=None):
return int(user_id), int(group_id) return int(user_id), int(group_id)
except ValueError as exc: except ValueError as exc:
raise DockerUnexpectedError( raise DockerUnexpectedError(
f'Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got "{user_id}" and "{group_id}" instead' f"Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got {user_id!r} and {group_id!r} instead"
) from exc ) from exc

View File

@ -18,12 +18,10 @@ class ImageArchiveManifestSummary:
"docker image save some:tag > some.tar" command. "docker image save some:tag > some.tar" command.
""" """
def __init__(self, image_id, repo_tags): def __init__(self, image_id: str, repo_tags: list[str]) -> None:
""" """
:param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json :param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json
:type image_id: str
:param repo_tags Docker image names, e.g. ["hello-world:latest"] :param repo_tags Docker image names, e.g. ["hello-world:latest"]
:type repo_tags: list[str]
""" """
self.image_id = image_id self.image_id = image_id
@ -34,22 +32,21 @@ class ImageArchiveInvalidException(Exception):
pass pass
def api_image_id(archive_image_id): def api_image_id(archive_image_id: str) -> str:
""" """
Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier
that represents the same image hash, but in the format presented by the Docker Engine API. that represents the same image hash, but in the format presented by the Docker Engine API.
:param archive_image_id: plain image hash :param archive_image_id: plain image hash
:type archive_image_id: str
:returns: Prefixed hash used by REST api :returns: Prefixed hash used by REST api
:rtype: str
""" """
return f"sha256:{archive_image_id}" return f"sha256:{archive_image_id}"
def load_archived_image_manifest(archive_path): def load_archived_image_manifest(
archive_path: str,
) -> list[ImageArchiveManifestSummary] | None:
""" """
Attempts to get image IDs and image names from metadata stored in the image Attempts to get image IDs and image names from metadata stored in the image
archive tar file. archive tar file.
@ -62,10 +59,7 @@ def load_archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read :param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects. :return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects.
:rtype: ImageArchiveManifestSummary
""" """
try: try:
@ -76,8 +70,15 @@ def load_archived_image_manifest(archive_path):
with tarfile.open(archive_path, "r") as tf: with tarfile.open(archive_path, "r") as tf:
try: try:
try: try:
with tf.extractfile("manifest.json") as ef: reader = tf.extractfile("manifest.json")
if reader is None:
raise ImageArchiveInvalidException(
"Failed to read manifest.json"
)
with reader as ef:
manifest = json.load(ef) manifest = json.load(ef)
except ImageArchiveInvalidException:
raise
except Exception as exc: except Exception as exc:
raise ImageArchiveInvalidException( raise ImageArchiveInvalidException(
f"Failed to decode and deserialize manifest.json: {exc}" f"Failed to decode and deserialize manifest.json: {exc}"
@ -139,7 +140,7 @@ def load_archived_image_manifest(archive_path):
) from exc ) from exc
def archived_image_manifest(archive_path): def archived_image_manifest(archive_path: str) -> ImageArchiveManifestSummary | None:
""" """
Attempts to get Image.Id and image name from metadata stored in the image Attempts to get Image.Id and image name from metadata stored in the image
archive tar file. archive tar file.
@ -152,10 +153,7 @@ def archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it. ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read :param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix. :return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix.
:rtype: ImageArchiveManifestSummary
""" """
results = load_archived_image_manifest(archive_path) results = load_archived_image_manifest(archive_path)

View File

@ -13,6 +13,9 @@ See https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc for information on
from __future__ import annotations from __future__ import annotations
import typing as t
from enum import Enum
# The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc # The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc
# (look for "EBNFish") # (look for "EBNFish")
@ -22,7 +25,7 @@ class InvalidLogFmt(Exception):
pass pass
class _Mode: class _Mode(Enum):
GARBAGE = 0 GARBAGE = 0
KEY = 1 KEY = 1
EQUAL = 2 EQUAL = 2
@ -68,29 +71,29 @@ _HEX_DICT = {
} }
def _is_ident(cur): def _is_ident(cur: str) -> bool:
return cur > " " and cur not in ('"', "=") return cur > " " and cur not in ('"', "=")
class _Parser: class _Parser:
def __init__(self, line): def __init__(self, line: str) -> None:
self.line = line self.line = line
self.index = 0 self.index = 0
self.length = len(line) self.length = len(line)
def done(self): def done(self) -> bool:
return self.index >= self.length return self.index >= self.length
def cur(self): def cur(self) -> str:
return self.line[self.index] return self.line[self.index]
def next(self): def next(self) -> None:
self.index += 1 self.index += 1
def prev(self): def prev(self) -> None:
self.index -= 1 self.index -= 1
def parse_unicode_sequence(self): def parse_unicode_sequence(self) -> str:
if self.index + 6 > self.length: if self.index + 6 > self.length:
raise InvalidLogFmt("Not enough space for unicode escape") raise InvalidLogFmt("Not enough space for unicode escape")
if self.line[self.index : self.index + 2] != "\\u": if self.line[self.index : self.index + 2] != "\\u":
@ -108,27 +111,27 @@ class _Parser:
return chr(v) return chr(v)
def parse_line(line, logrus_mode=False): def parse_line(line: str, logrus_mode: bool = False) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
parser = _Parser(line) parser = _Parser(line)
key = [] key: list[str] = []
value = [] value: list[str] = []
mode = _Mode.GARBAGE mode = _Mode.GARBAGE
def handle_kv(has_no_value=False): def handle_kv(has_no_value: bool = False) -> None:
k = "".join(key) k = "".join(key)
v = None if has_no_value else "".join(value) v = None if has_no_value else "".join(value)
result[k] = v result[k] = v
del key[:] del key[:]
del value[:] del value[:]
def parse_garbage(cur): def parse_garbage(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
return _Mode.KEY return _Mode.KEY
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_key(cur): def parse_key(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
key.append(cur) key.append(cur)
parser.next() parser.next()
@ -142,7 +145,7 @@ def parse_line(line, logrus_mode=False):
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_equal(cur): def parse_equal(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
value.append(cur) value.append(cur)
parser.next() parser.next()
@ -154,7 +157,7 @@ def parse_line(line, logrus_mode=False):
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_ident_value(cur): def parse_ident_value(cur: str) -> _Mode:
if _is_ident(cur): if _is_ident(cur):
value.append(cur) value.append(cur)
parser.next() parser.next()
@ -163,7 +166,7 @@ def parse_line(line, logrus_mode=False):
parser.next() parser.next()
return _Mode.GARBAGE return _Mode.GARBAGE
def parse_quoted_value(cur): def parse_quoted_value(cur: str) -> _Mode:
if cur == "\\": if cur == "\\":
parser.next() parser.next()
if parser.done(): if parser.done():

View File

@ -12,6 +12,7 @@ import abc
import os import os
import re import re
import shlex import shlex
import typing as t
from functools import partial from functools import partial
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
@ -32,6 +33,23 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
if t.TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.docker.plugins.module_utils._version import (
LooseVersion,
)
ValueType = t.Literal["set", "list", "dict", "bool", "int", "float", "str"]
AnsibleType = t.Literal["list", "dict", "bool", "int", "float", "str"]
ComparisonMode = t.Literal["ignore", "strict", "allow_more_present"]
ComparisonType = t.Literal["set", "set(dict)", "list", "dict", "value"]
Client = t.TypeVar("Client")
_DEFAULT_IP_REPLACEMENT_STRING = ( _DEFAULT_IP_REPLACEMENT_STRING = (
"[[DEFAULT_IP:iewahhaeB4Sae6Aen8IeShairoh4zeph7xaekoh8Geingunaesaeweiy3ooleiwi]]" "[[DEFAULT_IP:iewahhaeB4Sae6Aen8IeShairoh4zeph7xaekoh8Geingunaesaeweiy3ooleiwi]]"
) )
@ -54,7 +72,9 @@ _MOUNT_OPTION_TYPES = {
} }
def _get_ansible_type(value_type): def _get_ansible_type(
value_type: ValueType,
) -> AnsibleType:
if value_type == "set": if value_type == "set":
return "list" return "list"
if value_type not in ("list", "dict", "bool", "int", "float", "str"): if value_type not in ("list", "dict", "bool", "int", "float", "str"):
@ -65,21 +85,22 @@ def _get_ansible_type(value_type):
class Option: class Option:
def __init__( def __init__(
self, self,
name, name: str,
value_type, *,
owner, value_type: ValueType,
ansible_type=None, owner: OptionGroup,
elements=None, ansible_type: AnsibleType | None = None,
ansible_elements=None, elements: ValueType | None = None,
ansible_suboptions=None, ansible_elements: AnsibleType | None = None,
ansible_aliases=None, ansible_suboptions: dict[str, t.Any] | None = None,
ansible_choices=None, ansible_aliases: Sequence[str] | None = None,
needs_no_suboptions=False, ansible_choices: Sequence[str] | None = None,
default_comparison=None, needs_no_suboptions: bool = False,
not_a_container_option=False, default_comparison: ComparisonMode | None = None,
not_an_ansible_option=False, not_a_container_option: bool = False,
copy_comparison_from=None, not_an_ansible_option: bool = False,
compare=None, copy_comparison_from: str | None = None,
compare: Callable[[Option, t.Any, t.Any], bool] | None = None,
): ):
self.name = name self.name = name
self.value_type = value_type self.value_type = value_type
@ -95,8 +116,8 @@ class Option:
if (elements is None and ansible_elements is None) and needs_ansible_elements: if (elements is None and ansible_elements is None) and needs_ansible_elements:
raise ValueError("Ansible elements required for Ansible lists") raise ValueError("Ansible elements required for Ansible lists")
self.elements = elements if needs_elements else None self.elements = elements if needs_elements else None
self.ansible_elements = ( self.ansible_elements: AnsibleType | None = (
(ansible_elements or _get_ansible_type(elements)) (ansible_elements or _get_ansible_type(elements or "str"))
if needs_ansible_elements if needs_ansible_elements
else None else None
) )
@ -119,10 +140,12 @@ class Option:
self.ansible_suboptions = ansible_suboptions if needs_suboptions else None self.ansible_suboptions = ansible_suboptions if needs_suboptions else None
self.ansible_aliases = ansible_aliases or [] self.ansible_aliases = ansible_aliases or []
self.ansible_choices = ansible_choices self.ansible_choices = ansible_choices
comparison_type = self.value_type comparison_type: ComparisonType
if comparison_type == "set" and self.elements == "dict": if self.value_type == "set" and self.elements == "dict":
comparison_type = "set(dict)" comparison_type = "set(dict)"
elif comparison_type not in ("set", "list", "dict"): elif self.value_type in ("set", "list", "dict"):
comparison_type = self.value_type # type: ignore
else:
comparison_type = "value" comparison_type = "value"
self.comparison_type = comparison_type self.comparison_type = comparison_type
if default_comparison is not None: if default_comparison is not None:
@ -152,36 +175,45 @@ class Option:
class OptionGroup: class OptionGroup:
def __init__( def __init__(
self, self,
preprocess=None, *,
ansible_mutually_exclusive=None, preprocess: (
ansible_required_together=None, Callable[[AnsibleModule, dict[str, t.Any]], dict[str, t.Any]] | None
ansible_required_one_of=None, ) = None,
ansible_required_if=None, ansible_mutually_exclusive: Sequence[Sequence[str]] | None = None,
ansible_required_by=None, ansible_required_together: Sequence[Sequence[str]] | None = None,
): ansible_required_one_of: Sequence[Sequence[str]] | None = None,
ansible_required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
ansible_required_by: dict[str, Sequence[str]] | None = None,
) -> None:
if preprocess is None: if preprocess is None:
def preprocess(module, values): def preprocess(module, values):
return values return values
self.preprocess = preprocess self.preprocess = preprocess
self.options = [] self.options: list[Option] = []
self.all_options = [] self.all_options: list[Option] = []
self.engines = {} self.engines: dict[str, Engine] = {}
self.ansible_mutually_exclusive = ansible_mutually_exclusive or [] self.ansible_mutually_exclusive = ansible_mutually_exclusive or []
self.ansible_required_together = ansible_required_together or [] self.ansible_required_together = ansible_required_together or []
self.ansible_required_one_of = ansible_required_one_of or [] self.ansible_required_one_of = ansible_required_one_of or []
self.ansible_required_if = ansible_required_if or [] self.ansible_required_if = ansible_required_if or []
self.ansible_required_by = ansible_required_by or {} self.ansible_required_by = ansible_required_by or {}
self.argument_spec = {} self.argument_spec: dict[str, t.Any] = {}
def add_option(self, *args, **kwargs): def add_option(self, *args, **kwargs) -> OptionGroup:
option = Option(*args, owner=self, **kwargs) option = Option(*args, owner=self, **kwargs)
if not option.not_a_container_option: if not option.not_a_container_option:
self.options.append(option) self.options.append(option)
self.all_options.append(option) self.all_options.append(option)
if not option.not_an_ansible_option: if not option.not_an_ansible_option:
ansible_option = { ansible_option: dict[str, t.Any] = {
"type": option.ansible_type, "type": option.ansible_type,
} }
if option.ansible_elements is not None: if option.ansible_elements is not None:
@ -195,213 +227,297 @@ class OptionGroup:
self.argument_spec[option.name] = ansible_option self.argument_spec[option.name] = ansible_option
return self return self
def supports_engine(self, engine_name): def supports_engine(self, engine_name: str) -> bool:
return engine_name in self.engines return engine_name in self.engines
def get_engine(self, engine_name): def get_engine(self, engine_name: str) -> Engine:
return self.engines[engine_name] return self.engines[engine_name]
def add_engine(self, engine_name, engine): def add_engine(self, engine_name: str, engine: Engine) -> OptionGroup:
self.engines[engine_name] = engine self.engines[engine_name] = engine
return self return self
class Engine: class Engine(t.Generic[Client]):
min_api_version = None # string or None min_api_version: str | None = None
min_api_version_obj = None # LooseVersion object or None min_api_version_obj: LooseVersion | None = None
extra_option_minimal_versions = None # dict[str, dict[str, Any]] or None extra_option_minimal_versions: dict[str, dict[str, t.Any]] | None = None
@abc.abstractmethod @abc.abstractmethod
def get_value(self, module, container, api_version, options, image, host_info): def get_value(
self,
module: AnsibleModule,
container: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> dict[str, t.Any]:
pass pass
def compare_value(self, option, param_value, container_value): def compare_value(
self, option: Option, param_value: t.Any, container_value: t.Any
) -> bool:
return option.compare(param_value, container_value) return option.compare(param_value, container_value)
@abc.abstractmethod @abc.abstractmethod
def set_value(self, module, data, api_version, options, values): def set_value(
self,
module: AnsibleModule,
data: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
values: dict[str, t.Any],
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_expected_values( def get_expected_values(
self, module, client, api_version, options, image, values, host_info self,
): module: AnsibleModule,
client: Client,
api_version: LooseVersion,
options: list[Option],
image: dict[str, t.Any] | None,
values: dict[str, t.Any],
host_info: dict[str, t.Any] | None,
) -> dict[str, t.Any]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def ignore_mismatching_result( def ignore_mismatching_result(
self, self,
module, module: AnsibleModule,
client, client: Client,
api_version, api_version: LooseVersion,
option, option: Option,
image, image: dict[str, t.Any] | None,
container_value, container_value: t.Any,
expected_value, expected_value: t.Any,
host_info, host_info: dict[str, t.Any] | None,
): ) -> bool:
pass pass
@abc.abstractmethod @abc.abstractmethod
def preprocess_value(self, module, client, api_version, options, values): def preprocess_value(
self,
module: AnsibleModule,
client: Client,
api_version: LooseVersion,
options: list[Option],
values: dict[str, t.Any],
) -> dict[str, t.Any]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def update_value(self, module, data, api_version, options, values): def update_value(
self,
module: AnsibleModule,
data: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
values: dict[str, t.Any],
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def can_set_value(self, api_version): def can_set_value(self, api_version: LooseVersion) -> bool:
pass pass
@abc.abstractmethod @abc.abstractmethod
def can_update_value(self, api_version): def can_update_value(self, api_version: LooseVersion) -> bool:
pass pass
@abc.abstractmethod @abc.abstractmethod
def needs_container_image(self, values): def needs_container_image(self, values: dict[str, t.Any]) -> bool:
pass pass
@abc.abstractmethod @abc.abstractmethod
def needs_host_info(self, values): def needs_host_info(self, values: dict[str, t.Any]) -> bool:
pass pass
class EngineDriver: class EngineDriver(t.Generic[Client]):
name = None # string name: str
@abc.abstractmethod @abc.abstractmethod
def setup( def setup(
self, self,
argument_spec, argument_spec: dict[str, t.Any],
mutually_exclusive=None, mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together=None, required_together: Sequence[Sequence[str]] | None = None,
required_one_of=None, required_one_of: Sequence[Sequence[str]] | None = None,
required_if=None, required_if: (
required_by=None, Sequence[
): tuple[str, t.Any, Sequence[str]]
# Return (module, active_options, client) | tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_by: dict[str, Sequence[str]] | None = None,
) -> tuple[AnsibleModule, list[OptionGroup], Client]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_host_info(self, client): def get_host_info(self, client: Client) -> dict[str, t.Any]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_api_version(self, client): def get_api_version(self, client: Client) -> LooseVersion:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_container_id(self, container): def get_container_id(self, container: dict[str, t.Any]) -> str:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_image_from_container(self, container): def get_image_from_container(self, container: dict[str, t.Any]) -> str:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_image_name_from_container(self, container): def get_image_name_from_container(self, container: dict[str, t.Any]) -> str | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def is_container_removing(self, container): def is_container_removing(self, container: dict[str, t.Any]) -> bool:
pass pass
@abc.abstractmethod @abc.abstractmethod
def is_container_running(self, container): def is_container_running(self, container: dict[str, t.Any]) -> bool:
pass pass
@abc.abstractmethod @abc.abstractmethod
def is_container_paused(self, container): def is_container_paused(self, container: dict[str, t.Any]) -> bool:
pass pass
@abc.abstractmethod @abc.abstractmethod
def inspect_container_by_name(self, client, container_name): def inspect_container_by_name(
self, client: Client, container_name: str
) -> dict[str, t.Any] | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def inspect_container_by_id(self, client, container_id): def inspect_container_by_id(
self, client: Client, container_id: str
) -> dict[str, t.Any] | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def inspect_image_by_id(self, client, image_id): def inspect_image_by_id(
self, client: Client, image_id: str
) -> dict[str, t.Any] | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def inspect_image_by_name(self, client, repository, tag): def inspect_image_by_name(
self, client: Client, repository: str, tag: str
) -> dict[str, t.Any] | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def pull_image(self, client, repository, tag, image_platform=None): def pull_image(
self,
client: Client,
repository: str,
tag: str,
image_platform: str | None = None,
) -> tuple[dict[str, t.Any] | None, bool]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def pause_container(self, client, container_id): def pause_container(self, client: Client, container_id: str) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def unpause_container(self, client, container_id): def unpause_container(self, client: Client, container_id: str) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def disconnect_container_from_network(self, client, container_id, network_id): def disconnect_container_from_network(
self, client: Client, container_id: str, network_id: str
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def connect_container_to_network( def connect_container_to_network(
self, client, container_id, network_id, parameters=None self,
): client: Client,
container_id: str,
network_id: str,
parameters: dict[str, t.Any] | None = None,
) -> None:
pass pass
def create_container_supports_more_than_one_network(self, client): def create_container_supports_more_than_one_network(self, client: Client) -> bool:
return False return False
@abc.abstractmethod @abc.abstractmethod
def create_container( def create_container(
self, client, container_name, create_parameters, networks=None self,
): client: Client,
container_name: str,
create_parameters: dict[str, t.Any],
networks: dict[str, dict[str, t.Any]] | None = None,
) -> str:
pass pass
@abc.abstractmethod @abc.abstractmethod
def start_container(self, client, container_id): def start_container(self, client: Client, container_id: str) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def wait_for_container(self, client, container_id, timeout=None): def wait_for_container(
self, client: Client, container_id: str, timeout: int | float | None = None
) -> int | None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_container_output(self, client, container_id): def get_container_output(
self, client: Client, container_id: str
) -> tuple[bytes, t.Literal[True]] | tuple[str, t.Literal[False]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def update_container(self, client, container_id, update_parameters): def update_container(
self, client: Client, container_id: str, update_parameters: dict[str, t.Any]
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def restart_container(self, client, container_id, timeout=None): def restart_container(
self, client: Client, container_id: str, timeout: int | float | None = None
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def kill_container(self, client, container_id, kill_signal=None): def kill_container(
self, client: Client, container_id: str, kill_signal: str | None = None
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def stop_container(self, client, container_id, timeout=None): def stop_container(
self, client: Client, container_id: str, timeout: int | float | None = None
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def remove_container( def remove_container(
self, client, container_id, remove_volumes=False, link=False, force=False self,
): client: Client,
container_id: str,
remove_volumes: bool = False,
link: bool = False,
force: bool = False,
) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def run(self, runner, client): def run(self, runner: Callable[[], None], client: Client) -> None:
pass pass
def _is_volume_permissions(mode): def _is_volume_permissions(mode: str) -> bool:
for part in mode.split(","): for part in mode.split(","):
if part not in ( if part not in (
"rw", "rw",
@ -423,7 +539,7 @@ def _is_volume_permissions(mode):
return True return True
def _parse_port_range(range_or_port, module): def _parse_port_range(range_or_port: str, module: AnsibleModule) -> list[int]:
""" """
Parses a string containing either a single port or a range of ports. Parses a string containing either a single port or a range of ports.
@ -443,7 +559,7 @@ def _parse_port_range(range_or_port, module):
module.fail_json(msg=f'Invalid port: "{range_or_port}"') module.fail_json(msg=f'Invalid port: "{range_or_port}"')
def _split_colon_ipv6(text, module): def _split_colon_ipv6(text: str, module: AnsibleModule) -> list[str]:
""" """
Split string by ':', while keeping IPv6 addresses in square brackets in one component. Split string by ':', while keeping IPv6 addresses in square brackets in one component.
""" """
@ -475,7 +591,9 @@ def _split_colon_ipv6(text, module):
return result return result
def _preprocess_command(module, values): def _preprocess_command(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "command" not in values: if "command" not in values:
return values return values
value = values["command"] value = values["command"]
@ -502,7 +620,9 @@ def _preprocess_command(module, values):
} }
def _preprocess_entrypoint(module, values): def _preprocess_entrypoint(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "entrypoint" not in values: if "entrypoint" not in values:
return values return values
value = values["entrypoint"] value = values["entrypoint"]
@ -522,7 +642,9 @@ def _preprocess_entrypoint(module, values):
} }
def _preprocess_env(module, values): def _preprocess_env(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if not values: if not values:
return {} return {}
final_env = {} final_env = {}
@ -546,7 +668,9 @@ def _preprocess_env(module, values):
} }
def _preprocess_healthcheck(module, values): def _preprocess_healthcheck(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if not values: if not values:
return {} return {}
return { return {
@ -556,7 +680,12 @@ def _preprocess_healthcheck(module, values):
} }
def _preprocess_convert_to_bytes(module, values, name, unlimited_value=None): def _preprocess_convert_to_bytes(
module: AnsibleModule,
values: dict[str, t.Any],
name: str,
unlimited_value: int | None = None,
) -> dict[str, t.Any]:
if name not in values: if name not in values:
return values return values
try: try:
@ -571,7 +700,9 @@ def _preprocess_convert_to_bytes(module, values, name, unlimited_value=None):
module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}") module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")
def _preprocess_mac_address(module, values): def _preprocess_mac_address(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "mac_address" not in values: if "mac_address" not in values:
return values return values
return { return {
@ -579,7 +710,9 @@ def _preprocess_mac_address(module, values):
} }
def _preprocess_networks(module, values): def _preprocess_networks(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if ( if (
module.params["networks_cli_compatible"] is True module.params["networks_cli_compatible"] is True
and values.get("networks") and values.get("networks")
@ -605,14 +738,18 @@ def _preprocess_networks(module, values):
return values return values
def _preprocess_sysctls(
    module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
    """Coerce every value in ``values['sysctls']`` to a text string, in place.

    ``module`` is unused here but kept for signature parity with the other
    ``_preprocess_*`` hooks. Returns ``values`` (mutated).
    """
    if "sysctls" in values:
        sysctls = values["sysctls"]
        for key in list(sysctls):
            sysctls[key] = to_text(sysctls[key], errors="surrogate_or_strict")
    return values
def _preprocess_tmpfs(module, values): def _preprocess_tmpfs(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "tmpfs" not in values: if "tmpfs" not in values:
return values return values
result = {} result = {}
@ -625,7 +762,9 @@ def _preprocess_tmpfs(module, values):
return {"tmpfs": result} return {"tmpfs": result}
def _preprocess_ulimits(module, values): def _preprocess_ulimits(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "ulimits" not in values: if "ulimits" not in values:
return values return values
result = [] result = []
@ -644,8 +783,10 @@ def _preprocess_ulimits(module, values):
} }
def _preprocess_mounts(module, values): def _preprocess_mounts(
last = {} module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
last: dict[str, str] = {}
def check_collision(t, name): def check_collision(t, name):
if t in last: if t in last:
@ -776,7 +917,9 @@ def _preprocess_mounts(module, values):
return values return values
def _preprocess_labels(module, values): def _preprocess_labels(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
result = {} result = {}
if "labels" in values: if "labels" in values:
labels = values["labels"] labels = values["labels"]
@ -787,13 +930,15 @@ def _preprocess_labels(module, values):
return result return result
def _preprocess_log(module, values): def _preprocess_log(
result = {} module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
if "log_driver" not in values: if "log_driver" not in values:
return result return result
result["log_driver"] = values["log_driver"] result["log_driver"] = values["log_driver"]
if "log_options" in values: if "log_options" in values:
options = {} options: dict[str, str] = {}
for k, v in values["log_options"].items(): for k, v in values["log_options"].items():
if not isinstance(v, str): if not isinstance(v, str):
value = to_text(v, errors="surrogate_or_strict") value = to_text(v, errors="surrogate_or_strict")
@ -807,7 +952,9 @@ def _preprocess_log(module, values):
return result return result
def _preprocess_ports(module, values): def _preprocess_ports(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "published_ports" in values: if "published_ports" in values:
if "all" in values["published_ports"]: if "all" in values["published_ports"]:
module.fail_json( module.fail_json(
@ -815,7 +962,12 @@ def _preprocess_ports(module, values):
"to randomly assign port mappings for those not specified by published_ports." "to randomly assign port mappings for those not specified by published_ports."
) )
binds = {} binds: dict[
str | int,
tuple[str]
| tuple[str, str | int]
| list[tuple[str] | tuple[str, str | int]],
] = {}
for port in values["published_ports"]: for port in values["published_ports"]:
parts = _split_colon_ipv6( parts = _split_colon_ipv6(
to_text(port, errors="surrogate_or_strict"), module to_text(port, errors="surrogate_or_strict"), module
@ -827,6 +979,7 @@ def _preprocess_ports(module, values):
container_ports = _parse_port_range(container_port, module) container_ports = _parse_port_range(container_port, module)
p_len = len(parts) p_len = len(parts)
port_binds: Sequence[tuple[str] | tuple[str, str | int]]
if p_len == 1: if p_len == 1:
port_binds = len(container_ports) * [(_DEFAULT_IP_REPLACEMENT_STRING,)] port_binds = len(container_ports) * [(_DEFAULT_IP_REPLACEMENT_STRING,)]
elif p_len == 2: elif p_len == 2:
@ -865,8 +1018,12 @@ def _preprocess_ports(module, values):
"Maybe you forgot to use square brackets ([...]) around an IPv6 address?" "Maybe you forgot to use square brackets ([...]) around an IPv6 address?"
) )
for bind, container_port in zip(port_binds, container_ports): for bind, container_port_val in zip(port_binds, container_ports):
idx = f"{container_port}/{protocol}" if protocol else container_port idx = (
f"{container_port_val}/{protocol}"
if protocol
else container_port_val
)
if idx in binds: if idx in binds:
old_bind = binds[idx] old_bind = binds[idx]
if isinstance(old_bind, list): if isinstance(old_bind, list):
@ -882,9 +1039,9 @@ def _preprocess_ports(module, values):
for port in values["exposed_ports"]: for port in values["exposed_ports"]:
port = to_text(port, errors="surrogate_or_strict").strip() port = to_text(port, errors="surrogate_or_strict").strip()
protocol = "tcp" protocol = "tcp"
match = re.search(r"(/.+$)", port) matcher = re.search(r"(/.+$)", port)
if match: if matcher:
protocol = match.group(1).replace("/", "") protocol = matcher.group(1).replace("/", "")
port = re.sub(r"/.+$", "", port) port = re.sub(r"/.+$", "", port)
exposed.append((port, protocol)) exposed.append((port, protocol))
if "published_ports" in values: if "published_ports" in values:
@ -912,7 +1069,7 @@ def _preprocess_ports(module, values):
return values return values
def _compare_platform(option, param_value, container_value): def _compare_platform(option: Option, param_value: t.Any, container_value: t.Any):
if option.comparison == "ignore": if option.comparison == "ignore":
return True return True
try: try:

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
from time import sleep from time import sleep
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
@ -16,6 +17,9 @@ from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.docker.plugins.module_utils._api.utils.utils import ( from ansible_collections.community.docker.plugins.module_utils._api.utils.utils import (
parse_repository_tag, parse_repository_tag,
) )
from ansible_collections.community.docker.plugins.module_utils._module_container.base import (
Client,
)
from ansible_collections.community.docker.plugins.module_utils._util import ( from ansible_collections.community.docker.plugins.module_utils._util import (
DifferenceTracker, DifferenceTracker,
DockerBaseClass, DockerBaseClass,
@ -25,13 +29,23 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
if t.TYPE_CHECKING:
from collections.abc import Sequence
from ansible.module_utils.basic import AnsibleModule
from .base import Engine, EngineDriver, Option, OptionGroup
class Container(DockerBaseClass): class Container(DockerBaseClass):
def __init__(self, container, engine_driver): def __init__(
self, container: dict[str, t.Any] | None, engine_driver: EngineDriver
) -> None:
super().__init__() super().__init__()
self.raw = container self.raw = container
self.id = None self.id: str | None = None
self.image = None self.image: str | None = None
self.image_name = None self.image_name: str | None = None
self.container = container self.container = container
self.engine_driver = engine_driver self.engine_driver = engine_driver
if container: if container:
@ -41,11 +55,11 @@ class Container(DockerBaseClass):
self.log(self.container, pretty_print=True) self.log(self.container, pretty_print=True)
@property
def exists(self) -> bool:
    """True if inspection data for the container is present (container found)."""
    return bool(self.container)
@property @property
def removing(self): def removing(self) -> bool:
return ( return (
self.engine_driver.is_container_removing(self.container) self.engine_driver.is_container_removing(self.container)
if self.container if self.container
@ -53,7 +67,7 @@ class Container(DockerBaseClass):
) )
@property @property
def running(self): def running(self) -> bool:
return ( return (
self.engine_driver.is_container_running(self.container) self.engine_driver.is_container_running(self.container)
if self.container if self.container
@ -61,7 +75,7 @@ class Container(DockerBaseClass):
) )
@property @property
def paused(self): def paused(self) -> bool:
return ( return (
self.engine_driver.is_container_paused(self.container) self.engine_driver.is_container_paused(self.container)
if self.container if self.container
@ -69,8 +83,14 @@ class Container(DockerBaseClass):
) )
class ContainerManager(DockerBaseClass): class ContainerManager(DockerBaseClass, t.Generic[Client]):
def __init__(self, module, engine_driver, client, active_options): def __init__(
self,
module: AnsibleModule,
engine_driver: EngineDriver,
client: Client,
active_options: list[OptionGroup],
) -> None:
super().__init__() super().__init__()
self.module = module self.module = module
self.engine_driver = engine_driver self.engine_driver = engine_driver
@ -78,46 +98,64 @@ class ContainerManager(DockerBaseClass):
self.options = active_options self.options = active_options
self.all_options = self._collect_all_options(active_options) self.all_options = self._collect_all_options(active_options)
self.check_mode = self.module.check_mode self.check_mode = self.module.check_mode
self.param_cleanup = self.module.params["cleanup"] self.param_cleanup: bool = self.module.params["cleanup"]
self.param_container_default_behavior = self.module.params[ self.param_container_default_behavior: t.Literal[
"container_default_behavior" "compatibility", "no_defaults"
] ] = self.module.params["container_default_behavior"]
self.param_default_host_ip = self.module.params["default_host_ip"] self.param_default_host_ip: str | None = self.module.params["default_host_ip"]
self.param_debug = self.module.params["debug"] self.param_debug: bool = self.module.params["debug"]
self.param_force_kill = self.module.params["force_kill"] self.param_force_kill: bool = self.module.params["force_kill"]
self.param_image = self.module.params["image"] self.param_image: str | None = self.module.params["image"]
self.param_image_comparison = self.module.params["image_comparison"] self.param_image_comparison: t.Literal["desired-image", "current-image"] = (
self.param_image_label_mismatch = self.module.params["image_label_mismatch"] self.module.params["image_comparison"]
self.param_image_name_mismatch = self.module.params["image_name_mismatch"] )
self.param_keep_volumes = self.module.params["keep_volumes"] self.param_image_label_mismatch: t.Literal["ignore", "fail"] = (
self.param_kill_signal = self.module.params["kill_signal"] self.module.params["image_label_mismatch"]
self.param_name = self.module.params["name"] )
self.param_networks_cli_compatible = self.module.params[ self.param_image_name_mismatch: t.Literal["ignore", "recreate"] = (
self.module.params["image_name_mismatch"]
)
self.param_keep_volumes: bool = self.module.params["keep_volumes"]
self.param_kill_signal: str | None = self.module.params["kill_signal"]
self.param_name: str = self.module.params["name"]
self.param_networks_cli_compatible: bool = self.module.params[
"networks_cli_compatible" "networks_cli_compatible"
] ]
self.param_output_logs = self.module.params["output_logs"] self.param_output_logs: bool = self.module.params["output_logs"]
self.param_paused = self.module.params["paused"] self.param_paused: bool | None = self.module.params["paused"]
self.param_pull = self.module.params["pull"] param_pull: t.Literal["never", "missing", "always", True, False] = (
if self.param_pull is True: self.module.params["pull"]
self.param_pull = "always" )
if self.param_pull is False: if param_pull is True:
self.param_pull = "missing" param_pull = "always"
self.param_pull_check_mode_behavior = self.module.params[ if param_pull is False:
"pull_check_mode_behavior" param_pull = "missing"
self.param_pull: t.Literal["never", "missing", "always"] = param_pull
self.param_pull_check_mode_behavior: t.Literal[
"image_not_present", "always"
] = self.module.params["pull_check_mode_behavior"]
self.param_recreate: bool = self.module.params["recreate"]
self.param_removal_wait_timeout: int | float | None = self.module.params[
"removal_wait_timeout"
] ]
self.param_recreate = self.module.params["recreate"] self.param_healthy_wait_timeout: int | float | None = self.module.params[
self.param_removal_wait_timeout = self.module.params["removal_wait_timeout"] "healthy_wait_timeout"
self.param_healthy_wait_timeout = self.module.params["healthy_wait_timeout"] ]
if self.param_healthy_wait_timeout <= 0: if (
self.param_healthy_wait_timeout is not None
and self.param_healthy_wait_timeout <= 0
):
self.param_healthy_wait_timeout = None self.param_healthy_wait_timeout = None
self.param_restart = self.module.params["restart"] self.param_restart: bool = self.module.params["restart"]
self.param_state = self.module.params["state"] self.param_state: t.Literal[
"absent", "present", "healthy", "started", "stopped"
] = self.module.params["state"]
self._parse_comparisons() self._parse_comparisons()
self._update_params() self._update_params()
self.results = {"changed": False, "actions": []} self.results = {"changed": False, "actions": []}
self.diff = {} self.diff: dict[str, t.Any] = {}
self.diff_tracker = DifferenceTracker() self.diff_tracker = DifferenceTracker()
self.facts = {} self.facts: dict[str, t.Any] | None = {}
if self.param_default_host_ip: if self.param_default_host_ip:
valid_ip = False valid_ip = False
if re.match( if re.match(
@ -134,16 +172,22 @@ class ContainerManager(DockerBaseClass):
"The value of default_host_ip must be an empty string, an IPv4 address, " "The value of default_host_ip must be an empty string, an IPv4 address, "
f'or an IPv6 address. Got "{self.param_default_host_ip}" instead.' f'or an IPv6 address. Got "{self.param_default_host_ip}" instead.'
) )
self.parameters = None self.parameters: list[tuple[OptionGroup, dict[str, t.Any]]] | None = None
def _add_action(self, action: dict[str, t.Any]) -> None:
    """Append ``action`` to the module result's ``actions`` log."""
    # results["actions"] is initialized as a list; the ignore silences
    # mypy about indexing the loosely-typed results dict.
    self.results["actions"].append(action)  # type: ignore
def _collect_all_options(
    self, active_options: list[OptionGroup]
) -> dict[str, Option]:
    """Flatten all option groups into one mapping of option name to Option."""
    return {
        option.name: option
        for group in active_options
        for option in group.options
    }
def _collect_all_module_params(self): def _collect_all_module_params(self) -> set[str]:
all_module_options = set() all_module_options = set()
for option, data in self.module.argument_spec.items(): for option, data in self.module.argument_spec.items():
all_module_options.add(option) all_module_options.add(option)
@ -152,7 +196,7 @@ class ContainerManager(DockerBaseClass):
all_module_options.add(alias) all_module_options.add(alias)
return all_module_options return all_module_options
def _parse_comparisons(self): def _parse_comparisons(self) -> None:
# Keep track of all module params and all option aliases # Keep track of all module params and all option aliases
all_module_options = self._collect_all_module_params() all_module_options = self._collect_all_module_params()
comp_aliases = {} comp_aliases = {}
@ -163,10 +207,11 @@ class ContainerManager(DockerBaseClass):
for alias in option.ansible_aliases: for alias in option.ansible_aliases:
comp_aliases[alias] = option_name comp_aliases[alias] = option_name
# Process comparisons specified by user # Process comparisons specified by user
if self.module.params.get("comparisons"): comparisons: dict[str, t.Any] | None = self.module.params.get("comparisons")
if comparisons:
# If '*' appears in comparisons, process it first # If '*' appears in comparisons, process it first
if "*" in self.module.params["comparisons"]: if "*" in comparisons:
value = self.module.params["comparisons"]["*"] value = comparisons["*"]
if value not in ("strict", "ignore"): if value not in ("strict", "ignore"):
self.fail( self.fail(
"The wildcard can only be used with comparison modes 'strict' and 'ignore'!" "The wildcard can only be used with comparison modes 'strict' and 'ignore'!"
@ -179,8 +224,8 @@ class ContainerManager(DockerBaseClass):
continue continue
option.comparison = value option.comparison = value
# Now process all other comparisons. # Now process all other comparisons.
comp_aliases_used = {} comp_aliases_used: dict[str, str] = {}
for key, value in self.module.params["comparisons"].items(): for key, value in comparisons.items():
if key == "*": if key == "*":
continue continue
# Find main key # Find main key
@ -220,7 +265,7 @@ class ContainerManager(DockerBaseClass):
option.copy_comparison_from option.copy_comparison_from
].comparison ].comparison
def _update_params(self): def _update_params(self) -> None:
if ( if (
self.param_networks_cli_compatible is True self.param_networks_cli_compatible is True
and self.module.params["networks"] and self.module.params["networks"]
@ -247,12 +292,14 @@ class ContainerManager(DockerBaseClass):
if self.module.params[param] is None: if self.module.params[param] is None:
self.module.params[param] = value self.module.params[param] = value
def fail(self, *args, **kwargs) -> t.NoReturn:
    """Abort the module run via the client's ``fail`` (does not return)."""
    # mypy doesn't know that Client has fail() method
    # NOTE(review): relies on client.fail() raising/exiting — confirm, since
    # the NoReturn annotation is only sound if fail() never returns.
    self.client.fail(*args, **kwargs)  # type: ignore
def run(self) -> None:
    """Entry point: dispatch to present() or absent() based on the desired state."""
    if self.param_state in ("stopped", "started", "present", "healthy"):
        # mypy doesn't get that self.param_state has only one of the above values
        self.present(self.param_state)  # type: ignore
    elif self.param_state == "absent":
        self.absent()
@ -270,15 +317,16 @@ class ContainerManager(DockerBaseClass):
def wait_for_state( def wait_for_state(
self, self,
container_id, container_id: str,
complete_states=None, *,
wait_states=None, complete_states: Sequence[str | None] | None = None,
accept_removal=False, wait_states: Sequence[str | None] | None = None,
max_wait=None, accept_removal: bool = False,
health_state=False, max_wait: int | float | None = None,
): health_state: bool = False,
) -> dict[str, t.Any] | None:
delay = 1.0 delay = 1.0
total_wait = 0 total_wait = 0.0
while True: while True:
# Inspect container # Inspect container
result = self.engine_driver.inspect_container_by_id( result = self.engine_driver.inspect_container_by_id(
@ -314,7 +362,9 @@ class ContainerManager(DockerBaseClass):
# code will have slept for ~1.5 minutes.) # code will have slept for ~1.5 minutes.)
delay = min(delay * 1.1, 10) delay = min(delay * 1.1, 10)
def _collect_params(self, active_options): def _collect_params(
self, active_options: list[OptionGroup]
) -> list[tuple[OptionGroup, dict[str, t.Any]]]:
parameters = [] parameters = []
for options in active_options: for options in active_options:
values = {} values = {}
@ -336,21 +386,25 @@ class ContainerManager(DockerBaseClass):
parameters.append((options, values)) parameters.append((options, values))
return parameters return parameters
def _needs_container_image(self) -> bool:
    """Return True if any active option's engine needs the container's image data."""
    assert self.parameters is not None
    driver_name = self.engine_driver.name
    return any(
        group.get_engine(driver_name).needs_container_image(values)
        for group, values in self.parameters
    )
def _needs_host_info(self) -> bool:
    """Return True if any active option's engine requires host/daemon information."""
    assert self.parameters is not None
    driver_name = self.engine_driver.name
    return any(
        group.get_engine(driver_name).needs_host_info(values)
        for group, values in self.parameters
    )
def present(self, state): def present(
self, state: t.Literal["stopped", "started", "present", "healthy"]
) -> None:
self.parameters = self._collect_params(self.options) self.parameters = self._collect_params(self.options)
container = self._get_container(self.param_name) container = self._get_container(self.param_name)
was_running = container.running was_running = container.running
@ -382,6 +436,7 @@ class ContainerManager(DockerBaseClass):
self.diff_tracker.add("exists", parameter=True, active=False) self.diff_tracker.add("exists", parameter=True, active=False)
if container.removing and not self.check_mode: if container.removing and not self.check_mode:
# Wait for container to be removed before trying to create it # Wait for container to be removed before trying to create it
assert container.id is not None
self.wait_for_state( self.wait_for_state(
container.id, container.id,
wait_states=["removing"], wait_states=["removing"],
@ -394,6 +449,7 @@ class ContainerManager(DockerBaseClass):
container_created = True container_created = True
else: else:
# Existing container # Existing container
assert container.id is not None
different, differences = self.has_different_configuration( different, differences = self.has_different_configuration(
container, container_image, comparison_image, host_info container, container_image, comparison_image, host_info
) )
@ -453,13 +509,16 @@ class ContainerManager(DockerBaseClass):
if state in ("started", "healthy") and not container.running: if state in ("started", "healthy") and not container.running:
self.diff_tracker.add("running", parameter=True, active=was_running) self.diff_tracker.add("running", parameter=True, active=was_running)
assert container.id is not None
container = self.container_start(container.id) container = self.container_start(container.id)
elif state in ("started", "healthy") and self.param_restart: elif state in ("started", "healthy") and self.param_restart:
self.diff_tracker.add("running", parameter=True, active=was_running) self.diff_tracker.add("running", parameter=True, active=was_running)
self.diff_tracker.add("restarted", parameter=True, active=False) self.diff_tracker.add("restarted", parameter=True, active=False)
assert container.id is not None
container = self.container_restart(container.id) container = self.container_restart(container.id)
elif state == "stopped" and container.running: elif state == "stopped" and container.running:
self.diff_tracker.add("running", parameter=False, active=was_running) self.diff_tracker.add("running", parameter=False, active=was_running)
assert container.id is not None
self.container_stop(container.id) self.container_stop(container.id)
container = self._get_container(container.id) container = self._get_container(container.id)
@ -472,6 +531,7 @@ class ContainerManager(DockerBaseClass):
"paused", parameter=self.param_paused, active=was_paused "paused", parameter=self.param_paused, active=was_paused
) )
if not self.check_mode: if not self.check_mode:
assert container.id is not None
try: try:
if self.param_paused: if self.param_paused:
self.engine_driver.pause_container( self.engine_driver.pause_container(
@ -487,12 +547,13 @@ class ContainerManager(DockerBaseClass):
) )
container = self._get_container(container.id) container = self._get_container(container.id)
self.results["changed"] = True self.results["changed"] = True
self.results["actions"].append({"set_paused": self.param_paused}) self._add_action({"set_paused": self.param_paused})
self.facts = container.raw self.facts = container.raw
if state == "healthy" and not self.check_mode: if state == "healthy" and not self.check_mode:
# `None` means that no health check enabled; simply treat this as 'healthy' # `None` means that no health check enabled; simply treat this as 'healthy'
assert container.id is not None
inspect_result = self.wait_for_state( inspect_result = self.wait_for_state(
container.id, container.id,
wait_states=["starting", "unhealthy"], wait_states=["starting", "unhealthy"],
@ -504,41 +565,51 @@ class ContainerManager(DockerBaseClass):
# Return the latest inspection results retrieved # Return the latest inspection results retrieved
self.facts = inspect_result self.facts = inspect_result
def absent(self) -> None:
    """Ensure the container is gone: stop it first if running, then remove it."""
    container = self._get_container(self.param_name)
    if not container.exists:
        return
    assert container.id is not None
    container_id = container.id
    if container.running:
        self.diff_tracker.add("running", parameter=False, active=True)
        self.container_stop(container_id)
    self.diff_tracker.add("exists", parameter=False, active=True)
    self.container_remove(container_id)
def _output_logs(self, msg: str | bytes) -> None:
    """Forward container log output to the Ansible module's log."""
    self.module.log(msg=msg)
def _get_container(self, container: str) -> Container:
    """Look up a container by ID or name and wrap the result in a Container.

    The wrapper is returned even when nothing was found (``exists`` is then
    False).
    """
    raw = self.engine_driver.inspect_container_by_name(self.client, container)
    return Container(raw, self.engine_driver)
def _get_container_image(
    self, container: Container, fallback: dict[str, t.Any] | None = None
) -> dict[str, t.Any] | None:
    """Inspect the image the container is based on.

    Returns ``fallback`` when the container does not exist, is being
    removed, or the image cannot be inspected.
    """
    if not container.exists or container.removing:
        return fallback
    image_ref = container.image
    assert image_ref is not None
    if is_image_name_id(image_ref):
        inspected = self.engine_driver.inspect_image_by_id(self.client, image_ref)
    else:
        repository, tag = parse_repository_tag(image_ref)
        # An untagged reference defaults to "latest".
        inspected = self.engine_driver.inspect_image_by_name(
            self.client, repository, tag or "latest"
        )
    return inspected or fallback
def _get_image(self, container, needs_container_image=False): def _get_image(
self, container: Container, needs_container_image: bool = False
) -> tuple[
dict[str, t.Any] | None, dict[str, t.Any] | None, dict[str, t.Any] | None
]:
image_parameter = self.param_image image_parameter = self.param_image
get_container_image = needs_container_image or not image_parameter get_container_image = needs_container_image or not image_parameter
container_image = ( container_image = (
@ -553,7 +624,7 @@ class ContainerManager(DockerBaseClass):
if is_image_name_id(image_parameter): if is_image_name_id(image_parameter):
image = self.engine_driver.inspect_image_by_id(self.client, image_parameter) image = self.engine_driver.inspect_image_by_id(self.client, image_parameter)
if image is None: if image is None:
self.client.fail(f"Cannot find image with ID {image_parameter}") self.fail(f"Cannot find image with ID {image_parameter}")
else: else:
repository, tag = parse_repository_tag(image_parameter) repository, tag = parse_repository_tag(image_parameter)
if not tag: if not tag:
@ -562,7 +633,7 @@ class ContainerManager(DockerBaseClass):
self.client, repository, tag self.client, repository, tag
) )
if not image and self.param_pull == "never": if not image and self.param_pull == "never":
self.client.fail( self.fail(
f"Cannot find image with name {repository}:{tag}, and pull=never" f"Cannot find image with name {repository}:{tag}, and pull=never"
) )
if not image or self.param_pull == "always": if not image or self.param_pull == "always":
@ -576,12 +647,12 @@ class ContainerManager(DockerBaseClass):
) )
if already_to_latest: if already_to_latest:
self.results["changed"] = False self.results["changed"] = False
self.results["actions"].append( self._add_action(
{"pulled_image": f"{repository}:{tag}", "changed": False} {"pulled_image": f"{repository}:{tag}", "changed": False}
) )
else: else:
self.results["changed"] = True self.results["changed"] = True
self.results["actions"].append( self._add_action(
{"pulled_image": f"{repository}:{tag}", "changed": True} {"pulled_image": f"{repository}:{tag}", "changed": True}
) )
elif not image or self.param_pull_check_mode_behavior == "always": elif not image or self.param_pull_check_mode_behavior == "always":
@ -589,10 +660,10 @@ class ContainerManager(DockerBaseClass):
# pull. (Implicitly: if the image is there, claim it already was latest unless # pull. (Implicitly: if the image is there, claim it already was latest unless
# pull_check_mode_behavior == 'always'.) # pull_check_mode_behavior == 'always'.)
self.results["changed"] = True self.results["changed"] = True
action = {"pulled_image": f"{repository}:{tag}"} action: dict[str, t.Any] = {"pulled_image": f"{repository}:{tag}"}
if not image: if not image:
action["changed"] = True action["changed"] = True
self.results["actions"].append(action) self._add_action(action)
self.log("image") self.log("image")
self.log(image, pretty_print=True) self.log(image, pretty_print=True)
@ -605,7 +676,9 @@ class ContainerManager(DockerBaseClass):
return image, container_image, comparison_image return image, container_image, comparison_image
def _image_is_different(self, image, container): def _image_is_different(
self, image: dict[str, t.Any] | None, container: Container
) -> bool:
if image and image.get("Id"): if image and image.get("Id"):
if container and container.image: if container and container.image:
if image.get("Id") != container.image: if image.get("Id") != container.image:
@ -615,8 +688,9 @@ class ContainerManager(DockerBaseClass):
return True return True
return False return False
def _compose_create_parameters(self, image): def _compose_create_parameters(self, image: str) -> dict[str, t.Any]:
params = {} params: dict[str, t.Any] = {}
assert self.parameters is not None
for options, values in self.parameters: for options, values in self.parameters:
engine = options.get_engine(self.engine_driver.name) engine = options.get_engine(self.engine_driver.name)
if engine.can_set_value(self.engine_driver.get_api_version(self.client)): if engine.can_set_value(self.engine_driver.get_api_version(self.client)):
@ -632,15 +706,16 @@ class ContainerManager(DockerBaseClass):
def _record_differences( def _record_differences(
self, self,
differences, differences: DifferenceTracker,
options, options: OptionGroup,
param_values, param_values: dict[str, t.Any],
engine, engine: Engine,
container, container: Container,
container_image, container_image: dict[str, t.Any] | None,
image, image: dict[str, t.Any] | None,
host_info, host_info: dict[str, t.Any] | None,
): ):
assert container.raw is not None
container_values = engine.get_value( container_values = engine.get_value(
self.module, self.module,
container.raw, container.raw,
@ -709,9 +784,16 @@ class ContainerManager(DockerBaseClass):
c = sorted(c, key=sort_key_fn) c = sorted(c, key=sort_key_fn)
differences.add(option.name, parameter=p, active=c) differences.add(option.name, parameter=p, active=c)
def has_different_configuration(self, container, container_image, image, host_info): def has_different_configuration(
self,
container: Container,
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> tuple[bool, DifferenceTracker]:
differences = DifferenceTracker() differences = DifferenceTracker()
update_differences = DifferenceTracker() update_differences = DifferenceTracker()
assert self.parameters is not None
for options, param_values in self.parameters: for options, param_values in self.parameters:
engine = options.get_engine(self.engine_driver.name) engine = options.get_engine(self.engine_driver.name)
if engine.can_update_value(self.engine_driver.get_api_version(self.client)): if engine.can_update_value(self.engine_driver.get_api_version(self.client)):
@ -743,9 +825,14 @@ class ContainerManager(DockerBaseClass):
return has_differences, differences return has_differences, differences
def has_different_resource_limits( def has_different_resource_limits(
self, container, container_image, image, host_info self,
): container: Container,
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> tuple[bool, DifferenceTracker]:
differences = DifferenceTracker() differences = DifferenceTracker()
assert self.parameters is not None
for options, param_values in self.parameters: for options, param_values in self.parameters:
engine = options.get_engine(self.engine_driver.name) engine = options.get_engine(self.engine_driver.name)
if not engine.can_update_value( if not engine.can_update_value(
@ -765,8 +852,9 @@ class ContainerManager(DockerBaseClass):
has_differences = not differences.empty has_differences = not differences.empty
return has_differences, differences return has_differences, differences
def _compose_update_parameters(self): def _compose_update_parameters(self) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
assert self.parameters is not None
for options, values in self.parameters: for options, values in self.parameters:
engine = options.get_engine(self.engine_driver.name) engine = options.get_engine(self.engine_driver.name)
if not engine.can_update_value( if not engine.can_update_value(
@ -782,7 +870,13 @@ class ContainerManager(DockerBaseClass):
) )
return result return result
def update_limits(self, container, container_image, image, host_info): def update_limits(
self,
container: Container,
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> Container:
limits_differ, different_limits = self.has_different_resource_limits( limits_differ, different_limits = self.has_different_resource_limits(
container, container_image, image, host_info container, container_image, image, host_info
) )
@ -793,20 +887,24 @@ class ContainerManager(DockerBaseClass):
) )
self.diff_tracker.merge(different_limits) self.diff_tracker.merge(different_limits)
if limits_differ and not self.check_mode: if limits_differ and not self.check_mode:
assert container.id is not None
self.container_update(container.id, self._compose_update_parameters()) self.container_update(container.id, self._compose_update_parameters())
return self._get_container(container.id) return self._get_container(container.id)
return container return container
def has_network_differences(self, container): def has_network_differences(
self, container: Container
) -> tuple[bool, list[dict[str, t.Any]]]:
""" """
Check if the container is connected to requested networks with expected options: links, aliases, ipv4, ipv6 Check if the container is connected to requested networks with expected options: links, aliases, ipv4, ipv6
""" """
different = False different = False
differences = [] differences: list[dict[str, t.Any]] = []
if not self.module.params["networks"]: if not self.module.params["networks"]:
return different, differences return different, differences
assert container.container is not None
if not container.container.get("NetworkSettings"): if not container.container.get("NetworkSettings"):
self.fail( self.fail(
"has_missing_networks: Error parsing container properties. NetworkSettings missing." "has_missing_networks: Error parsing container properties. NetworkSettings missing."
@ -869,13 +967,16 @@ class ContainerManager(DockerBaseClass):
) )
return different, differences return different, differences
def has_extra_networks(self, container): def has_extra_networks(
self, container: Container
) -> tuple[bool, list[dict[str, t.Any]]]:
""" """
Check if the container is connected to non-requested networks Check if the container is connected to non-requested networks
""" """
extra_networks = [] extra_networks: list[dict[str, t.Any]] = []
extra = False extra = False
assert container.container is not None
if not container.container.get("NetworkSettings"): if not container.container.get("NetworkSettings"):
self.fail( self.fail(
"has_extra_networks: Error parsing container properties. NetworkSettings missing." "has_extra_networks: Error parsing container properties. NetworkSettings missing."
@ -896,7 +997,9 @@ class ContainerManager(DockerBaseClass):
) )
return extra, extra_networks return extra, extra_networks
def update_networks(self, container, container_created): def update_networks(
self, container: Container, container_created: bool
) -> Container:
updated_container = container updated_container = container
if self.all_options["networks"].comparison != "ignore" or container_created: if self.all_options["networks"].comparison != "ignore" or container_created:
has_network_differences, network_differences = self.has_network_differences( has_network_differences, network_differences = self.has_network_differences(
@ -939,13 +1042,14 @@ class ContainerManager(DockerBaseClass):
updated_container = self._purge_networks(container, extra_networks) updated_container = self._purge_networks(container, extra_networks)
return updated_container return updated_container
def _add_networks(self, container, differences): def _add_networks(
self, container: Container, differences: list[dict[str, t.Any]]
) -> Container:
assert container.id is not None
for diff in differences: for diff in differences:
# remove the container from the network, if connected # remove the container from the network, if connected
if diff.get("container"): if diff.get("container"):
self.results["actions"].append( self._add_action({"removed_from_network": diff["parameter"]["name"]})
{"removed_from_network": diff["parameter"]["name"]}
)
if not self.check_mode: if not self.check_mode:
try: try:
self.engine_driver.disconnect_container_from_network( self.engine_driver.disconnect_container_from_network(
@ -956,7 +1060,7 @@ class ContainerManager(DockerBaseClass):
f"Error disconnecting container from network {diff['parameter']['name']} - {exc}" f"Error disconnecting container from network {diff['parameter']['name']} - {exc}"
) )
# connect to the network # connect to the network
self.results["actions"].append( self._add_action(
{ {
"added_to_network": diff["parameter"]["name"], "added_to_network": diff["parameter"]["name"],
"network_parameters": diff["parameter"], "network_parameters": diff["parameter"],
@ -982,9 +1086,12 @@ class ContainerManager(DockerBaseClass):
) )
return self._get_container(container.id) return self._get_container(container.id)
def _purge_networks(self, container, networks): def _purge_networks(
self, container: Container, networks: list[dict[str, t.Any]]
) -> Container:
assert container.id is not None
for network in networks: for network in networks:
self.results["actions"].append({"removed_from_network": network["name"]}) self._add_action({"removed_from_network": network["name"]})
if not self.check_mode: if not self.check_mode:
try: try:
self.engine_driver.disconnect_container_from_network( self.engine_driver.disconnect_container_from_network(
@ -996,7 +1103,7 @@ class ContainerManager(DockerBaseClass):
) )
return self._get_container(container.id) return self._get_container(container.id)
def container_create(self, image): def container_create(self, image: str) -> Container | None:
create_parameters = self._compose_create_parameters(image) create_parameters = self._compose_create_parameters(image)
self.log("create container") self.log("create container")
self.log(f"image: {image} parameters:") self.log(f"image: {image} parameters:")
@ -1014,7 +1121,7 @@ class ContainerManager(DockerBaseClass):
for key, value in network.items() for key, value in network.items()
if key not in ("name", "id") if key not in ("name", "id")
} }
self.results["actions"].append( self._add_action(
{ {
"created": "Created container", "created": "Created container",
"create_parameters": create_parameters, "create_parameters": create_parameters,
@ -1022,7 +1129,6 @@ class ContainerManager(DockerBaseClass):
} }
) )
self.results["changed"] = True self.results["changed"] = True
new_container = None
if not self.check_mode: if not self.check_mode:
try: try:
container_id = self.engine_driver.create_container( container_id = self.engine_driver.create_container(
@ -1031,11 +1137,11 @@ class ContainerManager(DockerBaseClass):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error creating container: {exc}") self.fail(f"Error creating container: {exc}")
return self._get_container(container_id) return self._get_container(container_id)
return new_container return None
def container_start(self, container_id): def container_start(self, container_id: str) -> Container:
self.log(f"start container {container_id}") self.log(f"start container {container_id}")
self.results["actions"].append({"started": container_id}) self._add_action({"started": container_id})
self.results["changed"] = True self.results["changed"] = True
if not self.check_mode: if not self.check_mode:
try: try:
@ -1047,9 +1153,11 @@ class ContainerManager(DockerBaseClass):
status = self.engine_driver.wait_for_container( status = self.engine_driver.wait_for_container(
self.client, container_id self.client, container_id
) )
self.client.fail_results["status"] = status # mypy doesn't know that Client has fail_results property
self.client.fail_results["status"] = status # type: ignore
self.results["status"] = status self.results["status"] = status
output: str | bytes
if self.module.params["auto_remove"]: if self.module.params["auto_remove"]:
output = "Cannot retrieve result as auto_remove is enabled" output = "Cannot retrieve result as auto_remove is enabled"
if self.param_output_logs: if self.param_output_logs:
@ -1077,12 +1185,14 @@ class ContainerManager(DockerBaseClass):
return insp return insp
return self._get_container(container_id) return self._get_container(container_id)
def container_remove(self, container_id, link=False, force=False): def container_remove(
self, container_id: str, link: bool = False, force: bool = False
) -> None:
volume_state = not self.param_keep_volumes volume_state = not self.param_keep_volumes
self.log( self.log(
f"remove container container:{container_id} v:{volume_state} link:{link} force{force}" f"remove container container:{container_id} v:{volume_state} link:{link} force{force}"
) )
self.results["actions"].append( self._add_action(
{ {
"removed": container_id, "removed": container_id,
"volume_state": volume_state, "volume_state": volume_state,
@ -1101,13 +1211,15 @@ class ContainerManager(DockerBaseClass):
force=force, force=force,
) )
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.client.fail(f"Error removing container {container_id}: {exc}") self.fail(f"Error removing container {container_id}: {exc}")
def container_update(self, container_id, update_parameters): def container_update(
self, container_id: str, update_parameters: dict[str, t.Any]
) -> Container:
if update_parameters: if update_parameters:
self.log(f"update container {container_id}") self.log(f"update container {container_id}")
self.log(update_parameters, pretty_print=True) self.log(update_parameters, pretty_print=True)
self.results["actions"].append( self._add_action(
{"updated": container_id, "update_parameters": update_parameters} {"updated": container_id, "update_parameters": update_parameters}
) )
self.results["changed"] = True self.results["changed"] = True
@ -1120,10 +1232,8 @@ class ContainerManager(DockerBaseClass):
self.fail(f"Error updating container {container_id}: {exc}") self.fail(f"Error updating container {container_id}: {exc}")
return self._get_container(container_id) return self._get_container(container_id)
def container_kill(self, container_id): def container_kill(self, container_id: str) -> None:
self.results["actions"].append( self._add_action({"killed": container_id, "signal": self.param_kill_signal})
{"killed": container_id, "signal": self.param_kill_signal}
)
self.results["changed"] = True self.results["changed"] = True
if not self.check_mode: if not self.check_mode:
try: try:
@ -1133,8 +1243,8 @@ class ContainerManager(DockerBaseClass):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error killing container {container_id}: {exc}") self.fail(f"Error killing container {container_id}: {exc}")
def container_restart(self, container_id): def container_restart(self, container_id: str) -> Container:
self.results["actions"].append( self._add_action(
{"restarted": container_id, "timeout": self.module.params["stop_timeout"]} {"restarted": container_id, "timeout": self.module.params["stop_timeout"]}
) )
self.results["changed"] = True self.results["changed"] = True
@ -1147,11 +1257,11 @@ class ContainerManager(DockerBaseClass):
self.fail(f"Error restarting container {container_id}: {exc}") self.fail(f"Error restarting container {container_id}: {exc}")
return self._get_container(container_id) return self._get_container(container_id)
def container_stop(self, container_id): def container_stop(self, container_id: str) -> None:
if self.param_force_kill: if self.param_force_kill:
self.container_kill(container_id) self.container_kill(container_id)
return return
self.results["actions"].append( self._add_action(
{"stopped": container_id, "timeout": self.module.params["stop_timeout"]} {"stopped": container_id, "timeout": self.module.params["stop_timeout"]}
) )
self.results["changed"] = True self.results["changed"] = True
@ -1164,7 +1274,7 @@ class ContainerManager(DockerBaseClass):
self.fail(f"Error stopping container {container_id}: {exc}") self.fail(f"Error stopping container {container_id}: {exc}")
def run_module(engine_driver): def run_module(engine_driver: EngineDriver) -> None:
module, active_options, client = engine_driver.setup( module, active_options, client = engine_driver.setup(
argument_spec={ argument_spec={
"cleanup": {"type": "bool", "default": False}, "cleanup": {"type": "bool", "default": False},
@ -1228,7 +1338,7 @@ def run_module(engine_driver):
], ],
) )
def execute(): def execute() -> t.NoReturn:
cm = ContainerManager(module, engine_driver, client, active_options) cm = ContainerManager(module, engine_driver, client, active_options)
cm.run() cm.run()
module.exit_json(**sanitize_result(cm.results)) module.exit_json(**sanitize_result(cm.results))

View File

@ -14,12 +14,13 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
_VALID_STR = re.compile("^[A-Za-z0-9_-]+$") _VALID_STR = re.compile("^[A-Za-z0-9_-]+$")
def _validate_part(string, part, part_name): def _validate_part(string: str, part: str, part_name: str) -> str:
if not part: if not part:
raise ValueError(f'Invalid platform string "{string}": {part_name} is empty') raise ValueError(f'Invalid platform string "{string}": {part_name} is empty')
if not _VALID_STR.match(part): if not _VALID_STR.match(part):
@ -79,7 +80,7 @@ _KNOWN_ARCH = (
) )
def _normalize_os(os_str): def _normalize_os(os_str: str) -> str:
# See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go # See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go
os_str = os_str.lower() os_str = os_str.lower()
if os_str == "macos": if os_str == "macos":
@ -112,7 +113,7 @@ _NORMALIZE_ARCH = {
} }
def _normalize_arch(arch_str, variant_str): def _normalize_arch(arch_str: str, variant_str: str) -> tuple[str, str]:
# See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go # See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go
arch_str = arch_str.lower() arch_str = arch_str.lower()
variant_str = variant_str.lower() variant_str = variant_str.lower()
@ -121,15 +122,16 @@ def _normalize_arch(arch_str, variant_str):
res = _NORMALIZE_ARCH.get((arch_str, None)) res = _NORMALIZE_ARCH.get((arch_str, None))
if res is None: if res is None:
return arch_str, variant_str return arch_str, variant_str
if res is not None: arch_str = res[0]
arch_str = res[0] if res[1] is not None:
if res[1] is not None: variant_str = res[1]
variant_str = res[1] return arch_str, variant_str
return arch_str, variant_str
class _Platform: class _Platform:
def __init__(self, os=None, arch=None, variant=None): def __init__(
self, os: str | None = None, arch: str | None = None, variant: str | None = None
) -> None:
self.os = os self.os = os
self.arch = arch self.arch = arch
self.variant = variant self.variant = variant
@ -140,7 +142,12 @@ class _Platform:
raise ValueError("If variant is given, os must be given too") raise ValueError("If variant is given, os must be given too")
@classmethod @classmethod
def parse_platform_string(cls, string, daemon_os=None, daemon_arch=None): def parse_platform_string(
cls,
string: str | None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> t.Self:
# See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go # See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go
if string is None: if string is None:
return cls() return cls()
@ -182,6 +189,7 @@ class _Platform:
) )
if variant is not None and not variant: if variant is not None and not variant:
raise ValueError(f'Invalid platform string "{string}": variant is empty') raise ValueError(f'Invalid platform string "{string}": variant is empty')
assert arch is not None # otherwise variant would be None as well
arch, variant = _normalize_arch(arch, variant or "") arch, variant = _normalize_arch(arch, variant or "")
if len(parts) == 2 and arch == "arm" and variant == "v7": if len(parts) == 2 and arch == "arm" and variant == "v7":
variant = None variant = None
@ -189,9 +197,12 @@ class _Platform:
variant = "v8" variant = "v8"
return cls(os=_normalize_os(os), arch=arch, variant=variant or None) return cls(os=_normalize_os(os), arch=arch, variant=variant or None)
def __str__(self): def __str__(self) -> str:
if self.variant: if self.variant:
parts = [self.os, self.arch, self.variant] assert (
self.os is not None and self.arch is not None
) # ensured in constructor
parts: list[str] = [self.os, self.arch, self.variant]
elif self.os: elif self.os:
if self.arch: if self.arch:
parts = [self.os, self.arch] parts = [self.os, self.arch]
@ -203,12 +214,14 @@ class _Platform:
parts = [] parts = []
return "/".join(parts) return "/".join(parts)
def __repr__(self): def __repr__(self) -> str:
return ( return (
f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})" f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})"
) )
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, _Platform):
return NotImplemented
return ( return (
self.os == other.os self.os == other.os
and self.arch == other.arch and self.arch == other.arch
@ -216,7 +229,9 @@ class _Platform:
) )
def normalize_platform_string(string, daemon_os=None, daemon_arch=None): def normalize_platform_string(
string: str, daemon_os: str | None = None, daemon_arch: str | None = None
) -> str:
return str( return str(
_Platform.parse_platform_string( _Platform.parse_platform_string(
string, daemon_os=daemon_os, daemon_arch=daemon_arch string, daemon_os=daemon_os, daemon_arch=daemon_arch
@ -225,8 +240,12 @@ def normalize_platform_string(string, daemon_os=None, daemon_arch=None):
def compose_platform_string( def compose_platform_string(
os=None, arch=None, variant=None, daemon_os=None, daemon_arch=None os: str | None = None,
): arch: str | None = None,
variant: str | None = None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> str:
if os is None and daemon_os is not None: if os is None and daemon_os is not None:
os = _normalize_os(daemon_os) os = _normalize_os(daemon_os)
if arch is None and daemon_arch is not None: if arch is None and daemon_arch is not None:
@ -235,7 +254,7 @@ def compose_platform_string(
return str(_Platform(os=os, arch=arch, variant=variant or None)) return str(_Platform(os=os, arch=arch, variant=variant or None))
def compare_platform_strings(string1, string2): def compare_platform_strings(string1: str, string2: str) -> bool:
return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string( return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string(
string2 string2
) )

View File

@ -13,7 +13,7 @@ import random
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
def generate_insecure_key(): def generate_insecure_key() -> bytes:
"""Do NOT use this for cryptographic purposes!""" """Do NOT use this for cryptographic purposes!"""
while True: while True:
# Generate a one-byte key. Right now the functions below do not use more # Generate a one-byte key. Right now the functions below do not use more
@ -24,23 +24,23 @@ def generate_insecure_key():
return key return key
def scramble(value, key): def scramble(value: str, key: bytes) -> str:
"""Do NOT use this for cryptographic purposes!""" """Do NOT use this for cryptographic purposes!"""
if len(key) < 1: if len(key) < 1:
raise ValueError("Key must be at least one byte") raise ValueError("Key must be at least one byte")
value = to_bytes(value) b_value = to_bytes(value)
k = key[0] k = key[0]
value = bytes([k ^ b for b in value]) b_value = bytes([k ^ b for b in b_value])
return "=S=" + to_native(base64.b64encode(value)) return f"=S={to_native(base64.b64encode(b_value))}"
def unscramble(value, key): def unscramble(value: str, key: bytes) -> str:
"""Do NOT use this for cryptographic purposes!""" """Do NOT use this for cryptographic purposes!"""
if len(key) < 1: if len(key) < 1:
raise ValueError("Key must be at least one byte") raise ValueError("Key must be at least one byte")
if not value.startswith("=S="): if not value.startswith("=S="):
raise ValueError("Value does not start with indicator") raise ValueError("Value does not start with indicator")
value = base64.b64decode(value[3:]) b_value = base64.b64decode(value[3:])
k = key[0] k = key[0]
value = bytes([k ^ b for b in value]) b_value = bytes([k ^ b for b in b_value])
return to_text(value) return to_text(b_value)

View File

@ -12,6 +12,7 @@ import os.path
import selectors import selectors
import socket as pysocket import socket as pysocket
import struct import struct
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.utils import ( from ansible_collections.community.docker.plugins.module_utils._api.utils import (
socket as docker_socket, socket as docker_socket,
@ -23,58 +24,74 @@ from ansible_collections.community.docker.plugins.module_utils._socket_helper im
) )
if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.docker.plugins.module_utils._socket_helper import (
SocketLike,
)
PARAMIKO_POLL_TIMEOUT = 0.01 # 10 milliseconds PARAMIKO_POLL_TIMEOUT = 0.01 # 10 milliseconds
def _empty_writer(msg: str) -> None:
pass
class DockerSocketHandlerBase: class DockerSocketHandlerBase:
def __init__(self, sock, log=None): def __init__(
self, sock: SocketLike, log: Callable[[str], None] | None = None
) -> None:
make_unblocking(sock) make_unblocking(sock)
if log is not None: self._log = log or _empty_writer
self._log = log
else:
self._log = lambda msg: True
self._paramiko_read_workaround = hasattr( self._paramiko_read_workaround = hasattr(
sock, "send_ready" sock, "send_ready"
) and "paramiko" in str(type(sock)) ) and "paramiko" in str(type(sock))
self._sock = sock self._sock = sock
self._block_done_callback = None self._block_done_callback: Callable[[int, bytes], None] | None = None
self._block_buffer = [] self._block_buffer: list[tuple[int, bytes]] = []
self._eof = False self._eof = False
self._read_buffer = b"" self._read_buffer = b""
self._write_buffer = b"" self._write_buffer = b""
self._end_of_writing = False self._end_of_writing = False
self._current_stream = None self._current_stream: int | None = None
self._current_missing = 0 self._current_missing = 0
self._current_buffer = b"" self._current_buffer = b""
self._selector = selectors.DefaultSelector() self._selector = selectors.DefaultSelector()
self._selector.register(self._sock, selectors.EVENT_READ) self._selector.register(self._sock, selectors.EVENT_READ)
def __enter__(self): def __enter__(self) -> t.Self:
return self return self
def __exit__(self, type_, value, tb): def __exit__(self, type_, value, tb) -> None:
self._selector.close() self._selector.close()
def set_block_done_callback(self, block_done_callback): def set_block_done_callback(
self, block_done_callback: Callable[[int, bytes], None]
) -> None:
self._block_done_callback = block_done_callback self._block_done_callback = block_done_callback
if self._block_done_callback is not None: if self._block_done_callback is not None:
while self._block_buffer: while self._block_buffer:
elt = self._block_buffer.pop(0) elt = self._block_buffer.pop(0)
self._block_done_callback(*elt) self._block_done_callback(*elt)
def _add_block(self, stream_id, data): def _add_block(self, stream_id: int, data: bytes) -> None:
if self._block_done_callback is not None: if self._block_done_callback is not None:
self._block_done_callback(stream_id, data) self._block_done_callback(stream_id, data)
else: else:
self._block_buffer.append((stream_id, data)) self._block_buffer.append((stream_id, data))
def _read(self): def _read(self) -> None:
if self._eof: if self._eof:
return return
data: bytes | None
if hasattr(self._sock, "recv"): if hasattr(self._sock, "recv"):
try: try:
data = self._sock.recv(262144) data = self._sock.recv(262144)
@ -86,13 +103,13 @@ class DockerSocketHandlerBase:
self._eof = True self._eof = True
return return
raise raise
elif isinstance(self._sock, getattr(pysocket, "SocketIO")): elif isinstance(self._sock, pysocket.SocketIO): # type: ignore[unreachable]
data = self._sock.read() data = self._sock.read() # type: ignore
else: else:
data = os.read(self._sock.fileno()) data = os.read(self._sock.fileno()) # type: ignore # TODO does this really work?!
if data is None: if data is None:
# no data available # no data available
return return # type: ignore[unreachable]
self._log(f"read {len(data)} bytes") self._log(f"read {len(data)} bytes")
if len(data) == 0: if len(data) == 0:
# Stream EOF # Stream EOF
@ -106,6 +123,7 @@ class DockerSocketHandlerBase:
self._read_buffer = self._read_buffer[n:] self._read_buffer = self._read_buffer[n:]
self._current_missing -= n self._current_missing -= n
if self._current_missing == 0: if self._current_missing == 0:
assert self._current_stream is not None
self._add_block(self._current_stream, self._current_buffer) self._add_block(self._current_stream, self._current_buffer)
self._current_buffer = b"" self._current_buffer = b""
if len(self._read_buffer) < 8: if len(self._read_buffer) < 8:
@ -119,13 +137,13 @@ class DockerSocketHandlerBase:
self._eof = True self._eof = True
break break
def _handle_end_of_writing(self): def _handle_end_of_writing(self) -> None:
if self._end_of_writing and len(self._write_buffer) == 0: if self._end_of_writing and len(self._write_buffer) == 0:
self._end_of_writing = False self._end_of_writing = False
self._log("Shutting socket down for writing") self._log("Shutting socket down for writing")
shutdown_writing(self._sock, self._log) shutdown_writing(self._sock, self._log)
def _write(self): def _write(self) -> None:
if len(self._write_buffer) > 0: if len(self._write_buffer) > 0:
written = write_to_socket(self._sock, self._write_buffer) written = write_to_socket(self._sock, self._write_buffer)
self._write_buffer = self._write_buffer[written:] self._write_buffer = self._write_buffer[written:]
@ -138,7 +156,9 @@ class DockerSocketHandlerBase:
self._selector.modify(self._sock, selectors.EVENT_READ) self._selector.modify(self._sock, selectors.EVENT_READ)
self._handle_end_of_writing() self._handle_end_of_writing()
def select(self, timeout=None, _internal_recursion=False): def select(
self, timeout: int | float | None = None, _internal_recursion: bool = False
) -> bool:
if ( if (
not _internal_recursion not _internal_recursion
and self._paramiko_read_workaround and self._paramiko_read_workaround
@ -147,12 +167,14 @@ class DockerSocketHandlerBase:
# When the SSH transport is used, Docker SDK for Python internally uses Paramiko, whose # When the SSH transport is used, Docker SDK for Python internally uses Paramiko, whose
# Channel object supports select(), but only for reading # Channel object supports select(), but only for reading
# (https://github.com/paramiko/paramiko/issues/695). # (https://github.com/paramiko/paramiko/issues/695).
if self._sock.send_ready(): if self._sock.send_ready(): # type: ignore
self._write() self._write()
return True return True
while timeout is None or timeout > PARAMIKO_POLL_TIMEOUT: while timeout is None or timeout > PARAMIKO_POLL_TIMEOUT:
result = self.select(PARAMIKO_POLL_TIMEOUT, _internal_recursion=True) result = int(
if self._sock.send_ready(): self.select(PARAMIKO_POLL_TIMEOUT, _internal_recursion=True)
)
if self._sock.send_ready(): # type: ignore
self._read() self._read()
result += 1 result += 1
if result > 0: if result > 0:
@ -172,19 +194,19 @@ class DockerSocketHandlerBase:
self._write() self._write()
result = len(events) result = len(events)
if self._paramiko_read_workaround and len(self._write_buffer) > 0: if self._paramiko_read_workaround and len(self._write_buffer) > 0:
if self._sock.send_ready(): if self._sock.send_ready(): # type: ignore
self._write() self._write()
result += 1 result += 1
return result > 0 return result > 0
def is_eof(self): def is_eof(self) -> bool:
return self._eof return self._eof
def end_of_writing(self): def end_of_writing(self) -> None:
self._end_of_writing = True self._end_of_writing = True
self._handle_end_of_writing() self._handle_end_of_writing()
def consume(self): def consume(self) -> tuple[bytes, bytes]:
stdout = [] stdout = []
stderr = [] stderr = []
@ -203,12 +225,12 @@ class DockerSocketHandlerBase:
self.select() self.select()
return b"".join(stdout), b"".join(stderr) return b"".join(stdout), b"".join(stderr)
def write(self, str_to_write): def write(self, str_to_write: bytes) -> None:
self._write_buffer += str_to_write self._write_buffer += str_to_write
if len(self._write_buffer) == len(str_to_write): if len(self._write_buffer) == len(str_to_write):
self._write() self._write()
class DockerSocketHandlerModule(DockerSocketHandlerBase): class DockerSocketHandlerModule(DockerSocketHandlerBase):
def __init__(self, sock, module): def __init__(self, sock: SocketLike, module: AnsibleModule) -> None:
super().__init__(sock, module.debug) super().__init__(sock, module.debug)

View File

@ -12,9 +12,14 @@ import os
import os.path import os.path
import socket as pysocket import socket as pysocket
import typing as t import typing as t
from collections.abc import Callable
def make_file_unblocking(file) -> None: if t.TYPE_CHECKING:
SocketLike = pysocket.socket
def make_file_unblocking(file: SocketLike) -> None:
fcntl.fcntl( fcntl.fcntl(
file.fileno(), file.fileno(),
fcntl.F_SETFL, fcntl.F_SETFL,
@ -22,7 +27,7 @@ def make_file_unblocking(file) -> None:
) )
def make_file_blocking(file) -> None: def make_file_blocking(file: SocketLike) -> None:
fcntl.fcntl( fcntl.fcntl(
file.fileno(), file.fileno(),
fcntl.F_SETFL, fcntl.F_SETFL,
@ -30,11 +35,11 @@ def make_file_blocking(file) -> None:
) )
def make_unblocking(sock) -> None: def make_unblocking(sock: SocketLike) -> None:
if hasattr(sock, "_sock"): if hasattr(sock, "_sock"):
sock._sock.setblocking(0) sock._sock.setblocking(0)
elif hasattr(sock, "setblocking"): elif hasattr(sock, "setblocking"):
sock.setblocking(0) sock.setblocking(0) # type: ignore # TODO: CHECK!
else: else:
make_file_unblocking(sock) make_file_unblocking(sock)
@ -43,7 +48,9 @@ def _empty_writer(msg: str) -> None:
pass pass
def shutdown_writing(sock, log: t.Callable[[str], None] = _empty_writer) -> None: def shutdown_writing(
sock: SocketLike, log: Callable[[str], None] = _empty_writer
) -> None:
# FIXME: This does **not work with SSLSocket**! Apparently SSLSocket does not allow to send # FIXME: This does **not work with SSLSocket**! Apparently SSLSocket does not allow to send
# a close_notify TLS alert without completely shutting down the connection. # a close_notify TLS alert without completely shutting down the connection.
# Calling sock.shutdown(pysocket.SHUT_WR) simply turns of TLS encryption and from that # Calling sock.shutdown(pysocket.SHUT_WR) simply turns of TLS encryption and from that
@ -56,14 +63,14 @@ def shutdown_writing(sock, log: t.Callable[[str], None] = _empty_writer) -> None
except TypeError as e: except TypeError as e:
# probably: "TypeError: shutdown() takes 1 positional argument but 2 were given" # probably: "TypeError: shutdown() takes 1 positional argument but 2 were given"
log(f"Shutting down for writing not possible; trying shutdown instead: {e}") log(f"Shutting down for writing not possible; trying shutdown instead: {e}")
sock.shutdown() sock.shutdown() # type: ignore
elif isinstance(sock, getattr(pysocket, "SocketIO")): elif isinstance(sock, getattr(pysocket, "SocketIO")):
sock._sock.shutdown(pysocket.SHUT_WR) sock._sock.shutdown(pysocket.SHUT_WR)
else: else:
log("No idea how to signal end of writing") log("No idea how to signal end of writing")
def write_to_socket(sock, data: bytes) -> None: def write_to_socket(sock: SocketLike, data: bytes) -> int:
if hasattr(sock, "_send_until_done"): if hasattr(sock, "_send_until_done"):
# WrappedSocket (urllib3/contrib/pyopenssl) does not have `send`, but # WrappedSocket (urllib3/contrib/pyopenssl) does not have `send`, but
# only `sendall`, which uses `_send_until_done` under the hood. # only `sendall`, which uses `_send_until_done` under the hood.

View File

@ -9,6 +9,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
import typing as t
from time import sleep from time import sleep
@ -28,10 +29,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class AnsibleDockerSwarmClient(AnsibleDockerClient): class AnsibleDockerSwarmClient(AnsibleDockerClient):
def __init__(self, **kwargs): def get_swarm_node_id(self) -> str | None:
super().__init__(**kwargs)
def get_swarm_node_id(self):
""" """
Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID
of Docker host the module is executed on of Docker host the module is executed on
@ -51,7 +49,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return swarm_info["Swarm"]["NodeID"] return swarm_info["Swarm"]["NodeID"]
return None return None
def check_if_swarm_node(self, node_id=None): def check_if_swarm_node(self, node_id: str | None = None) -> bool | None:
""" """
Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host
system information looking if specific key in output exists. If 'node_id' is provided then it tries to system information looking if specific key in output exists. If 'node_id' is provided then it tries to
@ -83,11 +81,11 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
try: try:
node_info = self.get_node_inspect(node_id=node_id) node_info = self.get_node_inspect(node_id=node_id)
except APIError: except APIError:
return return None
return node_info["ID"] is not None return node_info["ID"] is not None
def check_if_swarm_manager(self): def check_if_swarm_manager(self) -> bool:
""" """
Checks if node role is set as Manager in Swarm. The node is the docker host on which module action Checks if node role is set as Manager in Swarm. The node is the docker host on which module action
is performed. The inspect_swarm() will fail if node is not a manager is performed. The inspect_swarm() will fail if node is not a manager
@ -101,7 +99,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
except APIError: except APIError:
return False return False
def fail_task_if_not_swarm_manager(self): def fail_task_if_not_swarm_manager(self) -> None:
""" """
If host is not a swarm manager then Ansible task on this host should end with 'failed' state If host is not a swarm manager then Ansible task on this host should end with 'failed' state
""" """
@ -110,7 +108,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
"Error running docker swarm module: must run on swarm manager node" "Error running docker swarm module: must run on swarm manager node"
) )
def check_if_swarm_worker(self): def check_if_swarm_worker(self) -> bool:
""" """
Checks if node role is set as Worker in Swarm. The node is the docker host on which module action Checks if node role is set as Worker in Swarm. The node is the docker host on which module action
is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node() is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node()
@ -122,7 +120,9 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True return True
return False return False
def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1): def check_if_swarm_node_is_down(
self, node_id: str | None = None, repeat_check: int = 1
) -> bool:
""" """
Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about
node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or
@ -147,7 +147,19 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True return True
return False return False
def get_node_inspect(self, node_id=None, skip_missing=False): @t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: t.Literal[False] = False
) -> dict[str, t.Any]: ...
@t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None: ...
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None:
""" """
Returns Swarm node info as in 'docker node inspect' command about single node Returns Swarm node info as in 'docker node inspect' command about single node
@ -195,7 +207,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info["Status"]["Addr"] = swarm_leader_ip node_info["Status"]["Addr"] = swarm_leader_ip
return node_info return node_info
def get_all_nodes_inspect(self): def get_all_nodes_inspect(self) -> list[dict[str, t.Any]]:
""" """
Returns Swarm node info as in 'docker node inspect' command about all registered nodes Returns Swarm node info as in 'docker node inspect' command about all registered nodes
@ -217,7 +229,17 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info = json.loads(json_str) node_info = json.loads(json_str)
return node_info return node_info
def get_all_nodes_list(self, output="short"): @t.overload
def get_all_nodes_list(self, output: t.Literal["short"] = "short") -> list[str]: ...
@t.overload
def get_all_nodes_list(
self, output: t.Literal["long"]
) -> list[dict[str, t.Any]]: ...
def get_all_nodes_list(
self, output: t.Literal["short", "long"] = "short"
) -> list[str] | list[dict[str, t.Any]]:
""" """
Returns list of nodes registered in Swarm Returns list of nodes registered in Swarm
@ -227,48 +249,46 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
if 'output' is 'long' then returns data is list of dict containing the attributes as in if 'output' is 'long' then returns data is list of dict containing the attributes as in
output of command 'docker node ls' output of command 'docker node ls'
""" """
nodes_list = []
nodes_inspect = self.get_all_nodes_inspect() nodes_inspect = self.get_all_nodes_inspect()
if nodes_inspect is None:
return None
if output == "short": if output == "short":
nodes_list = []
for node in nodes_inspect: for node in nodes_inspect:
nodes_list.append(node["Description"]["Hostname"]) nodes_list.append(node["Description"]["Hostname"])
elif output == "long": return nodes_list
if output == "long":
nodes_info_list = []
for node in nodes_inspect: for node in nodes_inspect:
node_property = {} node_property: dict[str, t.Any] = {}
node_property.update({"ID": node["ID"]}) node_property["ID"] = node["ID"]
node_property.update({"Hostname": node["Description"]["Hostname"]}) node_property["Hostname"] = node["Description"]["Hostname"]
node_property.update({"Status": node["Status"]["State"]}) node_property["Status"] = node["Status"]["State"]
node_property.update({"Availability": node["Spec"]["Availability"]}) node_property["Availability"] = node["Spec"]["Availability"]
if "ManagerStatus" in node: if "ManagerStatus" in node:
if node["ManagerStatus"]["Leader"] is True: if node["ManagerStatus"]["Leader"] is True:
node_property.update({"Leader": True}) node_property["Leader"] = True
node_property.update( node_property["ManagerStatus"] = node["ManagerStatus"][
{"ManagerStatus": node["ManagerStatus"]["Reachability"]} "Reachability"
) ]
node_property.update( node_property["EngineVersion"] = node["Description"]["Engine"][
{"EngineVersion": node["Description"]["Engine"]["EngineVersion"]} "EngineVersion"
) ]
nodes_list.append(node_property) nodes_info_list.append(node_property)
else: return nodes_info_list
return None
return nodes_list def get_node_name_by_id(self, nodeid: str) -> str:
def get_node_name_by_id(self, nodeid):
return self.get_node_inspect(nodeid)["Description"]["Hostname"] return self.get_node_inspect(nodeid)["Description"]["Hostname"]
def get_unlock_key(self): def get_unlock_key(self) -> str | None:
if self.docker_py_version < LooseVersion("2.7.0"): if self.docker_py_version < LooseVersion("2.7.0"):
return None return None
return super().get_unlock_key() return super().get_unlock_key()
def get_service_inspect(self, service_id, skip_missing=False): def get_service_inspect(
self, service_id: str, skip_missing: bool = False
) -> dict[str, t.Any] | None:
""" """
Returns Swarm service info as in 'docker service inspect' command about single service Returns Swarm service info as in 'docker service inspect' command about single service

View File

@ -9,6 +9,7 @@ from __future__ import annotations
import json import json
import re import re
import typing as t
from datetime import timedelta from datetime import timedelta
from urllib.parse import urlparse from urllib.parse import urlparse
@ -17,6 +18,12 @@ from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock" DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock"
DEFAULT_TLS = False DEFAULT_TLS = False
DEFAULT_TLS_VERIFY = False DEFAULT_TLS_VERIFY = False
@ -69,22 +76,24 @@ DOCKER_COMMON_ARGS_VARS = {
if option_name != "debug" if option_name != "debug"
} }
DOCKER_MUTUALLY_EXCLUSIVE = [] DOCKER_MUTUALLY_EXCLUSIVE: list[tuple[str, ...] | list[str]] = []
DOCKER_REQUIRED_TOGETHER = [["client_cert", "client_key"]] DOCKER_REQUIRED_TOGETHER: list[tuple[str, ...] | list[str]] = [
["client_cert", "client_key"]
]
DEFAULT_DOCKER_REGISTRY = "https://index.docker.io/v1/" DEFAULT_DOCKER_REGISTRY = "https://index.docker.io/v1/"
BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"] BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"]
def is_image_name_id(name): def is_image_name_id(name: str) -> bool:
"""Check whether the given image name is in fact an image ID (hash).""" """Check whether the given image name is in fact an image ID (hash)."""
if re.match("^sha256:[0-9a-fA-F]{64}$", name): if re.match("^sha256:[0-9a-fA-F]{64}$", name):
return True return True
return False return False
def is_valid_tag(tag, allow_empty=False): def is_valid_tag(tag: str, allow_empty: bool = False) -> bool:
"""Check whether the given string is a valid docker tag name.""" """Check whether the given string is a valid docker tag name."""
if not tag: if not tag:
return allow_empty return allow_empty
@ -93,7 +102,7 @@ def is_valid_tag(tag, allow_empty=False):
return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag)) return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag))
def sanitize_result(data): def sanitize_result(data: t.Any) -> t.Any:
"""Sanitize data object for return to Ansible. """Sanitize data object for return to Ansible.
When the data object contains types such as docker.types.containers.HostConfig, When the data object contains types such as docker.types.containers.HostConfig,
@ -110,7 +119,7 @@ def sanitize_result(data):
return data return data
def log_debug(msg, pretty_print=False): def log_debug(msg: t.Any, pretty_print: bool = False):
"""Write a log message to docker.log. """Write a log message to docker.log.
If ``pretty_print=True``, the message will be pretty-printed as JSON. If ``pretty_print=True``, the message will be pretty-printed as JSON.
@ -126,25 +135,28 @@ def log_debug(msg, pretty_print=False):
class DockerBaseClass: class DockerBaseClass:
def __init__(self): def __init__(self) -> None:
self.debug = False self.debug = False
def log(self, msg, pretty_print=False): def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# log_debug(msg, pretty_print=pretty_print) # log_debug(msg, pretty_print=pretty_print)
def update_tls_hostname( def update_tls_hostname(
result, old_behavior=False, deprecate_function=None, uses_tls=True result: dict[str, t.Any],
): old_behavior: bool = False,
deprecate_function: Callable[[str], None] | None = None,
uses_tls: bool = True,
) -> None:
if result["tls_hostname"] is None: if result["tls_hostname"] is None:
# get default machine name from the url # get default machine name from the url
parsed_url = urlparse(result["docker_host"]) parsed_url = urlparse(result["docker_host"])
result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0] result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0]
def compare_dict_allow_more_present(av, bv): def compare_dict_allow_more_present(av: dict, bv: dict) -> bool:
""" """
Compare two dictionaries for whether every entry of the first is in the second. Compare two dictionaries for whether every entry of the first is in the second.
""" """
@ -156,7 +168,12 @@ def compare_dict_allow_more_present(av, bv):
return True return True
def compare_generic(a, b, method, datatype): def compare_generic(
a: t.Any,
b: t.Any,
method: t.Literal["ignore", "strict", "allow_more_present"],
datatype: t.Literal["value", "list", "set", "set(dict)", "dict"],
) -> bool:
""" """
Compare values a and b as described by method and datatype. Compare values a and b as described by method and datatype.
@ -247,10 +264,10 @@ def compare_generic(a, b, method, datatype):
class DifferenceTracker: class DifferenceTracker:
def __init__(self): def __init__(self) -> None:
self._diff = [] self._diff: list[dict[str, t.Any]] = []
def add(self, name, parameter=None, active=None): def add(self, name: str, parameter: t.Any = None, active: t.Any = None) -> None:
self._diff.append( self._diff.append(
{ {
"name": name, "name": name,
@ -259,14 +276,14 @@ class DifferenceTracker:
} }
) )
def merge(self, other_tracker): def merge(self, other_tracker: DifferenceTracker) -> None:
self._diff.extend(other_tracker._diff) self._diff.extend(other_tracker._diff)
@property @property
def empty(self): def empty(self) -> bool:
return len(self._diff) == 0 return len(self._diff) == 0
def get_before_after(self): def get_before_after(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
""" """
Return texts ``before`` and ``after``. Return texts ``before`` and ``after``.
""" """
@ -277,13 +294,13 @@ class DifferenceTracker:
after[item["name"]] = item["parameter"] after[item["name"]] = item["parameter"]
return before, after return before, after
def has_difference_for(self, name): def has_difference_for(self, name: str) -> bool:
""" """
Returns a boolean if a difference exists for name Returns a boolean if a difference exists for name
""" """
return any(diff for diff in self._diff if diff["name"] == name) return any(diff for diff in self._diff if diff["name"] == name)
def get_legacy_docker_container_diffs(self): def get_legacy_docker_container_diffs(self) -> list[dict[str, t.Any]]:
""" """
Return differences in the docker_container legacy format. Return differences in the docker_container legacy format.
""" """
@ -297,7 +314,7 @@ class DifferenceTracker:
result.append(item) result.append(item)
return result return result
def get_legacy_docker_diffs(self): def get_legacy_docker_diffs(self) -> list[str]:
""" """
Return differences in the docker_container legacy format. Return differences in the docker_container legacy format.
""" """
@ -305,8 +322,13 @@ class DifferenceTracker:
return result return result
def sanitize_labels(labels, labels_field, client=None, module=None): def sanitize_labels(
def fail(msg): labels: dict[str, t.Any] | None,
labels_field: str,
client=None,
module: AnsibleModule | None = None,
) -> None:
def fail(msg: str) -> t.NoReturn:
if client is not None: if client is not None:
client.fail(msg) client.fail(msg)
if module is not None: if module is not None:
@ -325,7 +347,21 @@ def sanitize_labels(labels, labels_field, client=None, module=None):
labels[k] = to_text(v) labels[k] = to_text(v)
def clean_dict_booleans_for_docker_api(data, allow_sequences=False): @t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: t.Literal[False] = False
) -> dict[str, str]: ...
@t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: bool
) -> dict[str, str | list[str]]: ...
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any] | None, *, allow_sequences: bool = False
) -> dict[str, str] | dict[str, str | list[str]]:
""" """
Go does not like Python booleans 'True' or 'False', while Ansible is just Go does not like Python booleans 'True' or 'False', while Ansible is just
fine with them in YAML. As such, they need to be converted in cases where fine with them in YAML. As such, they need to be converted in cases where
@ -353,7 +389,7 @@ def clean_dict_booleans_for_docker_api(data, allow_sequences=False):
return result return result
def convert_duration_to_nanosecond(time_str): def convert_duration_to_nanosecond(time_str: str) -> int:
""" """
Return time duration in nanosecond. Return time duration in nanosecond.
""" """
@ -372,9 +408,9 @@ def convert_duration_to_nanosecond(time_str):
if not parts: if not parts:
raise ValueError(f"Invalid time duration - {time_str}") raise ValueError(f"Invalid time duration - {time_str}")
parts = parts.groupdict() parts_dict = parts.groupdict()
time_params = {} time_params = {}
for name, value in parts.items(): for name, value in parts_dict.items():
if value: if value:
time_params[name] = int(value) time_params[name] = int(value)
@ -386,13 +422,15 @@ def convert_duration_to_nanosecond(time_str):
return time_in_nanoseconds return time_in_nanoseconds
def normalize_healthcheck_test(test): def normalize_healthcheck_test(test: t.Any) -> list[str]:
if isinstance(test, (tuple, list)): if isinstance(test, (tuple, list)):
return [str(e) for e in test] return [str(e) for e in test]
return ["CMD-SHELL", str(test)] return ["CMD-SHELL", str(test)]
def normalize_healthcheck(healthcheck, normalize_test=False): def normalize_healthcheck(
healthcheck: dict[str, t.Any], normalize_test: bool = False
) -> dict[str, t.Any]:
""" """
Return dictionary of healthcheck parameters. Return dictionary of healthcheck parameters.
""" """
@ -438,7 +476,9 @@ def normalize_healthcheck(healthcheck, normalize_test=False):
return result return result
def parse_healthcheck(healthcheck): def parse_healthcheck(
healthcheck: dict[str, t.Any] | None,
) -> tuple[dict[str, t.Any] | None, bool | None]:
""" """
Return dictionary of healthcheck parameters and boolean if Return dictionary of healthcheck parameters and boolean if
healthcheck defined in image was requested to be disabled. healthcheck defined in image was requested to be disabled.
@ -456,8 +496,8 @@ def parse_healthcheck(healthcheck):
return result, False return result, False
def omit_none_from_dict(d): def omit_none_from_dict(d: dict[str, t.Any]) -> dict[str, t.Any]:
""" """
Return a copy of the dictionary with all keys with value None omitted. Return a copy of the dictionary with all keys with value None omitted.
""" """
return dict((k, v) for (k, v) in d.items() if v is not None) return {k: v for (k, v) in d.items() if v is not None}

View File

@ -80,7 +80,7 @@ import re
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
def main(): def main() -> None:
module = AnsibleModule({}, supports_check_mode=True) module = AnsibleModule({}, supports_check_mode=True)
cpuset_path = "/proc/self/cpuset" cpuset_path = "/proc/self/cpuset"

View File

@ -437,6 +437,7 @@ actions:
""" """
import traceback import traceback
import typing as t
from ansible.module_utils.common.validation import check_type_int from ansible.module_utils.common.validation import check_type_int
@ -455,26 +456,32 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class ServicesManager(BaseComposeManager): class ServicesManager(BaseComposeManager):
def __init__(self, client): def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client) super().__init__(client)
parameters = self.client.module.params parameters = self.client.module.params
self.state = parameters["state"] self.state: t.Literal["absent", "present", "stopped", "restarted"] = parameters[
self.dependencies = parameters["dependencies"] "state"
self.pull = parameters["pull"] ]
self.build = parameters["build"] self.dependencies: bool = parameters["dependencies"]
self.ignore_build_events = parameters["ignore_build_events"] self.pull: t.Literal["always", "missing", "never", "policy"] = parameters[
self.recreate = parameters["recreate"] "pull"
self.remove_images = parameters["remove_images"] ]
self.remove_volumes = parameters["remove_volumes"] self.build: t.Literal["always", "never", "policy"] = parameters["build"]
self.remove_orphans = parameters["remove_orphans"] self.ignore_build_events: bool = parameters["ignore_build_events"]
self.renew_anon_volumes = parameters["renew_anon_volumes"] self.recreate: t.Literal["always", "never", "auto"] = parameters["recreate"]
self.timeout = parameters["timeout"] self.remove_images: t.Literal["all", "local"] | None = parameters[
self.services = parameters["services"] or [] "remove_images"
self.scale = parameters["scale"] or {} ]
self.wait = parameters["wait"] self.remove_volumes: bool = parameters["remove_volumes"]
self.wait_timeout = parameters["wait_timeout"] self.remove_orphans: bool = parameters["remove_orphans"]
self.yes = parameters["assume_yes"] self.renew_anon_volumes: bool = parameters["renew_anon_volumes"]
self.timeout: int | None = parameters["timeout"]
self.services: list[str] = parameters["services"] or []
self.scale: dict[str, t.Any] = parameters["scale"] or {}
self.wait: bool = parameters["wait"]
self.wait_timeout: int | None = parameters["wait_timeout"]
self.yes: bool = parameters["assume_yes"]
if self.compose_version < LooseVersion("2.32.0") and self.yes: if self.compose_version < LooseVersion("2.32.0") and self.yes:
self.fail( self.fail(
f"assume_yes=true needs Docker Compose 2.32.0 or newer, not version {self.compose_version}" f"assume_yes=true needs Docker Compose 2.32.0 or newer, not version {self.compose_version}"
@ -491,7 +498,7 @@ class ServicesManager(BaseComposeManager):
self.fail(f"The value {value!r} for `scale[{key!r}]` is negative") self.fail(f"The value {value!r} for `scale[{key!r}]` is negative")
self.scale[key] = value self.scale[key] = value
def run(self): def run(self) -> dict[str, t.Any]:
if self.state == "present": if self.state == "present":
result = self.cmd_up() result = self.cmd_up()
elif self.state == "stopped": elif self.state == "stopped":
@ -508,7 +515,7 @@ class ServicesManager(BaseComposeManager):
self.cleanup_result(result) self.cleanup_result(result)
return result return result
def get_up_cmd(self, dry_run, no_start=False): def get_up_cmd(self, dry_run: bool, no_start: bool = False) -> list[str]:
args = self.get_base_args() + ["up", "--detach", "--no-color", "--quiet-pull"] args = self.get_base_args() + ["up", "--detach", "--no-color", "--quiet-pull"]
if self.pull != "policy": if self.pull != "policy":
args.extend(["--pull", self.pull]) args.extend(["--pull", self.pull])
@ -549,8 +556,8 @@ class ServicesManager(BaseComposeManager):
args.append(service) args.append(service)
return args return args
def cmd_up(self): def cmd_up(self) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
args = self.get_up_cmd(self.check_mode) args = self.get_up_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -566,7 +573,7 @@ class ServicesManager(BaseComposeManager):
self.update_failed(result, events, args, stdout, stderr, rc) self.update_failed(result, events, args, stdout, stderr, rc)
return result return result
def get_stop_cmd(self, dry_run): def get_stop_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["stop"] args = self.get_base_args() + ["stop"]
if self.timeout is not None: if self.timeout is not None:
args.extend(["--timeout", f"{self.timeout}"]) args.extend(["--timeout", f"{self.timeout}"])
@ -577,17 +584,17 @@ class ServicesManager(BaseComposeManager):
args.append(service) args.append(service)
return args return args
def _are_containers_stopped(self): def _are_containers_stopped(self) -> bool:
for container in self.list_containers_raw(): for container in self.list_containers_raw():
if container["State"] not in ("created", "exited", "stopped", "killed"): if container["State"] not in ("created", "exited", "stopped", "killed"):
return False return False
return True return True
def cmd_stop(self): def cmd_stop(self) -> dict[str, t.Any]:
# Since 'docker compose stop' **always** claims it is stopping containers, even if they are already # Since 'docker compose stop' **always** claims it is stopping containers, even if they are already
# stopped, we have to do this a bit more complicated. # stopped, we have to do this a bit more complicated.
result = {} result: dict[str, t.Any] = {}
# Make sure all containers are created # Make sure all containers are created
args_1 = self.get_up_cmd(self.check_mode, no_start=True) args_1 = self.get_up_cmd(self.check_mode, no_start=True)
rc_1, stdout_1, stderr_1 = self.client.call_cli(*args_1, cwd=self.project_src) rc_1, stdout_1, stderr_1 = self.client.call_cli(*args_1, cwd=self.project_src)
@ -630,7 +637,7 @@ class ServicesManager(BaseComposeManager):
) )
return result return result
def get_restart_cmd(self, dry_run): def get_restart_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["restart"] args = self.get_base_args() + ["restart"]
if not self.dependencies: if not self.dependencies:
args.append("--no-deps") args.append("--no-deps")
@ -643,8 +650,8 @@ class ServicesManager(BaseComposeManager):
args.append(service) args.append(service)
return args return args
def cmd_restart(self): def cmd_restart(self) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
args = self.get_restart_cmd(self.check_mode) args = self.get_restart_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -653,7 +660,7 @@ class ServicesManager(BaseComposeManager):
self.update_failed(result, events, args, stdout, stderr, rc) self.update_failed(result, events, args, stdout, stderr, rc)
return result return result
def get_down_cmd(self, dry_run): def get_down_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["down"] args = self.get_base_args() + ["down"]
if self.remove_orphans: if self.remove_orphans:
args.append("--remove-orphans") args.append("--remove-orphans")
@ -670,8 +677,8 @@ class ServicesManager(BaseComposeManager):
args.append(service) args.append(service)
return args return args
def cmd_down(self): def cmd_down(self) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
args = self.get_down_cmd(self.check_mode) args = self.get_down_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -681,7 +688,7 @@ class ServicesManager(BaseComposeManager):
return result return result
def main(): def main() -> None:
argument_spec = { argument_spec = {
"state": { "state": {
"type": "str", "type": "str",

View File

@ -166,6 +166,7 @@ rc:
import shlex import shlex
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
@ -180,29 +181,32 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
class ExecManager(BaseComposeManager): class ExecManager(BaseComposeManager):
def __init__(self, client): def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client) super().__init__(client)
parameters = self.client.module.params parameters = self.client.module.params
self.service = parameters["service"] self.service: str = parameters["service"]
self.index = parameters["index"] self.index: int | None = parameters["index"]
self.chdir = parameters["chdir"] self.chdir: str | None = parameters["chdir"]
self.detach = parameters["detach"] self.detach: bool = parameters["detach"]
self.user = parameters["user"] self.user: str | None = parameters["user"]
self.stdin = parameters["stdin"] self.stdin: str | None = parameters["stdin"]
self.strip_empty_ends = parameters["strip_empty_ends"] self.strip_empty_ends: bool = parameters["strip_empty_ends"]
self.privileged = parameters["privileged"] self.privileged: bool = parameters["privileged"]
self.tty = parameters["tty"] self.tty: bool = parameters["tty"]
self.env = parameters["env"] self.env: dict[str, t.Any] = parameters["env"]
self.argv = parameters["argv"] self.argv: list[str]
if parameters["command"] is not None: if parameters["command"] is not None:
self.argv = shlex.split(parameters["command"]) self.argv = shlex.split(parameters["command"])
else:
self.argv = parameters["argv"]
if self.detach and self.stdin is not None: if self.detach and self.stdin is not None:
self.fail("If detach=true, stdin cannot be provided.") self.fail("If detach=true, stdin cannot be provided.")
if self.stdin is not None and parameters["stdin_add_newline"]: stdin_add_newline: bool = parameters["stdin_add_newline"]
if self.stdin is not None and stdin_add_newline:
self.stdin += "\n" self.stdin += "\n"
if self.env is not None: if self.env is not None:
@ -214,7 +218,7 @@ class ExecManager(BaseComposeManager):
) )
self.env[name] = to_text(value, errors="surrogate_or_strict") self.env[name] = to_text(value, errors="surrogate_or_strict")
def get_exec_cmd(self, dry_run, no_start=False): def get_exec_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args(plain_progress=True) + ["exec"] args = self.get_base_args(plain_progress=True) + ["exec"]
if self.index is not None: if self.index is not None:
args.extend(["--index", str(self.index)]) args.extend(["--index", str(self.index)])
@ -237,9 +241,9 @@ class ExecManager(BaseComposeManager):
args.extend(self.argv) args.extend(self.argv)
return args return args
def run(self): def run(self) -> dict[str, t.Any]:
args = self.get_exec_cmd(self.check_mode) args = self.get_exec_cmd(self.check_mode)
kwargs = { kwargs: dict[str, t.Any] = {
"cwd": self.project_src, "cwd": self.project_src,
} }
if self.stdin is not None: if self.stdin is not None:
@ -262,7 +266,7 @@ class ExecManager(BaseComposeManager):
} }
def main(): def main() -> None:
argument_spec = { argument_spec = {
"service": {"type": "str", "required": True}, "service": {"type": "str", "required": True},
"index": {"type": "int"}, "index": {"type": "int"},

View File

@ -111,6 +111,7 @@ actions:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._common_cli import ( from ansible_collections.community.docker.plugins.module_utils._common_cli import (
AnsibleModuleDockerClient, AnsibleModuleDockerClient,
@ -126,14 +127,14 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class PullManager(BaseComposeManager): class PullManager(BaseComposeManager):
def __init__(self, client): def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client) super().__init__(client)
parameters = self.client.module.params parameters = self.client.module.params
self.policy = parameters["policy"] self.policy: t.Literal["always", "missing"] = parameters["policy"]
self.ignore_buildable = parameters["ignore_buildable"] self.ignore_buildable: bool = parameters["ignore_buildable"]
self.include_deps = parameters["include_deps"] self.include_deps: bool = parameters["include_deps"]
self.services = parameters["services"] or [] self.services: list[str] = parameters["services"] or []
if self.policy != "always" and self.compose_version < LooseVersion("2.22.0"): if self.policy != "always" and self.compose_version < LooseVersion("2.22.0"):
# https://github.com/docker/compose/pull/10981 - 2.22.0 # https://github.com/docker/compose/pull/10981 - 2.22.0
@ -146,7 +147,7 @@ class PullManager(BaseComposeManager):
f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}" f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}"
) )
def get_pull_cmd(self, dry_run, no_start=False): def get_pull_cmd(self, dry_run: bool):
args = self.get_base_args() + ["pull"] args = self.get_base_args() + ["pull"]
if self.policy != "always": if self.policy != "always":
args.extend(["--policy", self.policy]) args.extend(["--policy", self.policy])
@ -161,8 +162,8 @@ class PullManager(BaseComposeManager):
args.append(service) args.append(service)
return args return args
def run(self): def run(self) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
args = self.get_pull_cmd(self.check_mode) args = self.get_pull_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src) rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0) events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -179,7 +180,7 @@ class PullManager(BaseComposeManager):
return result return result
def main(): def main() -> None:
argument_spec = { argument_spec = {
"policy": { "policy": {
"type": "str", "type": "str",

View File

@ -239,6 +239,7 @@ rc:
import shlex import shlex
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
@ -253,42 +254,45 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
class ExecManager(BaseComposeManager): class ExecManager(BaseComposeManager):
def __init__(self, client): def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client) super().__init__(client)
parameters = self.client.module.params parameters = self.client.module.params
self.service = parameters["service"] self.service: str = parameters["service"]
self.build = parameters["build"] self.build: bool = parameters["build"]
self.cap_add = parameters["cap_add"] self.cap_add: list[str] | None = parameters["cap_add"]
self.cap_drop = parameters["cap_drop"] self.cap_drop: list[str] | None = parameters["cap_drop"]
self.entrypoint = parameters["entrypoint"] self.entrypoint: str | None = parameters["entrypoint"]
self.interactive = parameters["interactive"] self.interactive: bool = parameters["interactive"]
self.labels = parameters["labels"] self.labels: list[str] | None = parameters["labels"]
self.name = parameters["name"] self.name: str | None = parameters["name"]
self.no_deps = parameters["no_deps"] self.no_deps: bool = parameters["no_deps"]
self.publish = parameters["publish"] self.publish: list[str] | None = parameters["publish"]
self.quiet_pull = parameters["quiet_pull"] self.quiet_pull: bool = parameters["quiet_pull"]
self.remove_orphans = parameters["remove_orphans"] self.remove_orphans: bool = parameters["remove_orphans"]
self.do_cleanup = parameters["cleanup"] self.do_cleanup: bool = parameters["cleanup"]
self.service_ports = parameters["service_ports"] self.service_ports: bool = parameters["service_ports"]
self.use_aliases = parameters["use_aliases"] self.use_aliases: bool = parameters["use_aliases"]
self.volumes = parameters["volumes"] self.volumes: list[str] | None = parameters["volumes"]
self.chdir = parameters["chdir"] self.chdir: str | None = parameters["chdir"]
self.detach = parameters["detach"] self.detach: bool = parameters["detach"]
self.user = parameters["user"] self.user: str | None = parameters["user"]
self.stdin = parameters["stdin"] self.stdin: str | None = parameters["stdin"]
self.strip_empty_ends = parameters["strip_empty_ends"] self.strip_empty_ends: bool = parameters["strip_empty_ends"]
self.tty = parameters["tty"] self.tty: bool = parameters["tty"]
self.env = parameters["env"] self.env: dict[str, t.Any] | None = parameters["env"]
self.argv = parameters["argv"] self.argv: list[str]
if parameters["command"] is not None: if parameters["command"] is not None:
self.argv = shlex.split(parameters["command"]) self.argv = shlex.split(parameters["command"])
else:
self.argv = parameters["argv"]
if self.detach and self.stdin is not None: if self.detach and self.stdin is not None:
self.fail("If detach=true, stdin cannot be provided.") self.fail("If detach=true, stdin cannot be provided.")
if self.stdin is not None and parameters["stdin_add_newline"]: stdin_add_newline: bool = parameters["stdin_add_newline"]
if self.stdin is not None and stdin_add_newline:
self.stdin += "\n" self.stdin += "\n"
if self.env is not None: if self.env is not None:
@ -300,7 +304,7 @@ class ExecManager(BaseComposeManager):
) )
self.env[name] = to_text(value, errors="surrogate_or_strict") self.env[name] = to_text(value, errors="surrogate_or_strict")
def get_run_cmd(self, dry_run, no_start=False): def get_run_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args(plain_progress=True) + ["run"] args = self.get_base_args(plain_progress=True) + ["run"]
if self.build: if self.build:
args.append("--build") args.append("--build")
@ -355,9 +359,9 @@ class ExecManager(BaseComposeManager):
args.extend(self.argv) args.extend(self.argv)
return args return args
def run(self): def run(self) -> dict[str, t.Any]:
args = self.get_run_cmd(self.check_mode) args = self.get_run_cmd(self.check_mode)
kwargs = { kwargs: dict[str, t.Any] = {
"cwd": self.project_src, "cwd": self.project_src,
} }
if self.stdin is not None: if self.stdin is not None:
@ -382,7 +386,7 @@ class ExecManager(BaseComposeManager):
} }
def main(): def main() -> None:
argument_spec = { argument_spec = {
"service": {"type": "str", "required": True}, "service": {"type": "str", "required": True},
"argv": {"type": "list", "elements": "str"}, "argv": {"type": "list", "elements": "str"},

View File

@ -1355,7 +1355,7 @@ from ansible_collections.community.docker.plugins.module_utils._module_container
) )
def main(): def main() -> None:
engine_driver = DockerAPIEngineDriver() engine_driver = DockerAPIEngineDriver()
run_module(engine_driver) run_module(engine_driver)

View File

@ -169,6 +169,7 @@ import io
import os import os
import stat import stat
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.module_utils.common.validation import check_type_int from ansible.module_utils.common.validation import check_type_int
@ -198,23 +199,29 @@ from ansible_collections.community.docker.plugins.module_utils._scramble import
) )
def are_fileobjs_equal(f1, f2): if t.TYPE_CHECKING:
import tarfile
def are_fileobjs_equal(f1: t.IO[bytes], f2: t.IO[bytes]) -> bool:
"""Given two (buffered) file objects, compare their contents.""" """Given two (buffered) file objects, compare their contents."""
f1on: t.IO[bytes] | None = f1
f2on: t.IO[bytes] | None = f2
blocksize = 65536 blocksize = 65536
b1buf = b"" b1buf = b""
b2buf = b"" b2buf = b""
while True: while True:
if f1 and len(b1buf) < blocksize: if f1on and len(b1buf) < blocksize:
f1b = f1.read(blocksize) f1b = f1on.read(blocksize)
if not f1b: if not f1b:
# f1 is EOF, so stop reading from it # f1 is EOF, so stop reading from it
f1 = None f1on = None
b1buf += f1b b1buf += f1b
if f2 and len(b2buf) < blocksize: if f2on and len(b2buf) < blocksize:
f2b = f2.read(blocksize) f2b = f2on.read(blocksize)
if not f2b: if not f2b:
# f2 is EOF, so stop reading from it # f2 is EOF, so stop reading from it
f2 = None f2on = None
b2buf += f2b b2buf += f2b
if not b1buf or not b2buf: if not b1buf or not b2buf:
# At least one of f1 and f2 is EOF and all its data has # At least one of f1 and f2 is EOF and all its data has
@ -229,29 +236,33 @@ def are_fileobjs_equal(f1, f2):
b2buf = b2buf[buflen:] b2buf = b2buf[buflen:]
def are_fileobjs_equal_read_first(f1, f2): def are_fileobjs_equal_read_first(
f1: t.IO[bytes], f2: t.IO[bytes]
) -> tuple[bool, bytes]:
"""Given two (buffered) file objects, compare their contents. """Given two (buffered) file objects, compare their contents.
Returns a tuple (is_equal, content_of_f1), where the first element indicates Returns a tuple (is_equal, content_of_f1), where the first element indicates
whether the two file objects have the same content, and the second element is whether the two file objects have the same content, and the second element is
the content of the first file object.""" the content of the first file object."""
f1on: t.IO[bytes] | None = f1
f2on: t.IO[bytes] | None = f2
blocksize = 65536 blocksize = 65536
b1buf = b"" b1buf = b""
b2buf = b"" b2buf = b""
is_equal = True is_equal = True
content = [] content = []
while True: while True:
if f1 and len(b1buf) < blocksize: if f1on and len(b1buf) < blocksize:
f1b = f1.read(blocksize) f1b = f1on.read(blocksize)
if not f1b: if not f1b:
# f1 is EOF, so stop reading from it # f1 is EOF, so stop reading from it
f1 = None f1on = None
b1buf += f1b b1buf += f1b
if f2 and len(b2buf) < blocksize: if f2on and len(b2buf) < blocksize:
f2b = f2.read(blocksize) f2b = f2on.read(blocksize)
if not f2b: if not f2b:
# f2 is EOF, so stop reading from it # f2 is EOF, so stop reading from it
f2 = None f2on = None
b2buf += f2b b2buf += f2b
if not b1buf or not b2buf: if not b1buf or not b2buf:
# At least one of f1 and f2 is EOF and all its data has # At least one of f1 and f2 is EOF and all its data has
@ -269,13 +280,13 @@ def are_fileobjs_equal_read_first(f1, f2):
b2buf = b2buf[buflen:] b2buf = b2buf[buflen:]
content.append(b1buf) content.append(b1buf)
if f1: if f1on:
content.append(f1.read()) content.append(f1on.read())
return is_equal, b"".join(content) return is_equal, b"".join(content)
def is_container_file_not_regular_file(container_stat): def is_container_file_not_regular_file(container_stat: dict[str, t.Any]) -> bool:
for bit in ( for bit in (
# https://pkg.go.dev/io/fs#FileMode # https://pkg.go.dev/io/fs#FileMode
32 - 1, # ModeDir 32 - 1, # ModeDir
@ -292,7 +303,7 @@ def is_container_file_not_regular_file(container_stat):
return False return False
def get_container_file_mode(container_stat): def get_container_file_mode(container_stat: dict[str, t.Any]) -> int:
mode = container_stat["mode"] & 0xFFF mode = container_stat["mode"] & 0xFFF
if container_stat["mode"] & (1 << (32 - 9)) != 0: # ModeSetuid if container_stat["mode"] & (1 << (32 - 9)) != 0: # ModeSetuid
mode |= stat.S_ISUID # set UID bit mode |= stat.S_ISUID # set UID bit
@ -303,7 +314,9 @@ def get_container_file_mode(container_stat):
return mode return mode
def add_other_diff(diff, in_path, member): def add_other_diff(
diff: dict[str, t.Any] | None, in_path: str, member: tarfile.TarInfo
) -> None:
if diff is None: if diff is None:
return return
diff["before_header"] = in_path diff["before_header"] = in_path
@ -326,14 +339,14 @@ def add_other_diff(diff, in_path, member):
def retrieve_diff( def retrieve_diff(
client, client: AnsibleDockerClient,
container, container: str,
container_path, container_path: str,
follow_links, follow_links: bool,
diff, diff: dict[str, t.Any] | None,
max_file_size_for_diff, max_file_size_for_diff: int,
regular_stat=None, regular_stat: dict[str, t.Any] | None = None,
link_target=None, link_target: str | None = None,
): ):
if diff is None: if diff is None:
return return
@ -377,19 +390,21 @@ def retrieve_diff(
return return
# We need to get hold of the content # We need to get hold of the content
def process_none(in_path): def process_none(in_path: str) -> None:
diff["before"] = "" diff["before"] = ""
def process_regular(in_path, tar, member): def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> None:
add_diff_dst_from_regular_member( add_diff_dst_from_regular_member(
diff, max_file_size_for_diff, in_path, tar, member diff, max_file_size_for_diff, in_path, tar, member
) )
def process_symlink(in_path, member): def process_symlink(in_path: str, member: tarfile.TarInfo) -> None:
diff["before_header"] = in_path diff["before_header"] = in_path
diff["before"] = member.linkname diff["before"] = member.linkname
def process_other(in_path, member): def process_other(in_path: str, member: tarfile.TarInfo) -> None:
add_other_diff(diff, in_path, member) add_other_diff(diff, in_path, member)
fetch_file_ex( fetch_file_ex(
@ -404,7 +419,7 @@ def retrieve_diff(
) )
def is_binary(content): def is_binary(content: bytes) -> bool:
if b"\x00" in content: if b"\x00" in content:
return True return True
# TODO: better detection # TODO: better detection
@ -413,8 +428,13 @@ def is_binary(content):
def are_fileobjs_equal_with_diff_of_first( def are_fileobjs_equal_with_diff_of_first(
f1, f2, size, diff, max_file_size_for_diff, container_path f1: t.IO[bytes],
): f2: t.IO[bytes],
size: int,
diff: dict[str, t.Any] | None,
max_file_size_for_diff: int,
container_path: str,
) -> bool:
if diff is None: if diff is None:
return are_fileobjs_equal(f1, f2) return are_fileobjs_equal(f1, f2)
if size > max_file_size_for_diff > 0: if size > max_file_size_for_diff > 0:
@ -430,15 +450,22 @@ def are_fileobjs_equal_with_diff_of_first(
def add_diff_dst_from_regular_member( def add_diff_dst_from_regular_member(
diff, max_file_size_for_diff, container_path, tar, member diff: dict[str, t.Any] | None,
): max_file_size_for_diff: int,
container_path: str,
tar: tarfile.TarFile,
member: tarfile.TarInfo,
) -> None:
if diff is None: if diff is None:
return return
if member.size > max_file_size_for_diff > 0: if member.size > max_file_size_for_diff > 0:
diff["dst_larger"] = max_file_size_for_diff diff["dst_larger"] = max_file_size_for_diff
return return
with tar.extractfile(member) as tar_f: mf = tar.extractfile(member)
if not mf:
raise AssertionError("Member should be present for regular file")
with mf as tar_f:
content = tar_f.read() content = tar_f.read()
if is_binary(content): if is_binary(content):
@ -448,35 +475,35 @@ def add_diff_dst_from_regular_member(
diff["before"] = to_text(content) diff["before"] = to_text(content)
def copy_dst_to_src(diff): def copy_dst_to_src(diff: dict[str, t.Any] | None) -> None:
if diff is None: if diff is None:
return return
for f, t in [ for frm, to in [
("dst_size", "src_size"), ("dst_size", "src_size"),
("dst_binary", "src_binary"), ("dst_binary", "src_binary"),
("before_header", "after_header"), ("before_header", "after_header"),
("before", "after"), ("before", "after"),
]: ]:
if f in diff: if frm in diff:
diff[t] = diff[f] diff[to] = diff[frm]
elif t in diff: elif to in diff:
diff.pop(t) diff.pop(to)
def is_file_idempotent( def is_file_idempotent(
client, client: AnsibleDockerClient,
container, container: str,
managed_path, managed_path: str,
container_path, container_path: str,
follow_links, follow_links: bool,
local_follow_links, local_follow_links: bool,
owner_id, owner_id,
group_id, group_id,
mode, mode,
force=False, force: bool | None = False,
diff=None, diff: dict[str, t.Any] | None = None,
max_file_size_for_diff=1, max_file_size_for_diff: int = 1,
): ) -> tuple[str, int, bool]:
# Retrieve information of local file # Retrieve information of local file
try: try:
file_stat = ( file_stat = (
@ -644,10 +671,12 @@ def is_file_idempotent(
return container_path, mode, False return container_path, mode, False
# Fetch file from container # Fetch file from container
def process_none(in_path): def process_none(in_path: str) -> tuple[str, int, bool]:
return container_path, mode, False return container_path, mode, False
def process_regular(in_path, tar, member): def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> tuple[str, int, bool]:
# Check things like user/group ID and mode # Check things like user/group ID and mode
if any( if any(
[ [
@ -663,14 +692,17 @@ def is_file_idempotent(
) )
return container_path, mode, False return container_path, mode, False
with tar.extractfile(member) as tar_f: mf = tar.extractfile(member)
if mf is None:
raise AssertionError("Member should be present for regular file")
with mf as tar_f:
with open(managed_path, "rb") as local_f: with open(managed_path, "rb") as local_f:
is_equal = are_fileobjs_equal_with_diff_of_first( is_equal = are_fileobjs_equal_with_diff_of_first(
tar_f, local_f, member.size, diff, max_file_size_for_diff, in_path tar_f, local_f, member.size, diff, max_file_size_for_diff, in_path
) )
return container_path, mode, is_equal return container_path, mode, is_equal
def process_symlink(in_path, member): def process_symlink(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
if diff is not None: if diff is not None:
diff["before_header"] = in_path diff["before_header"] = in_path
diff["before"] = member.linkname diff["before"] = member.linkname
@ -689,7 +721,7 @@ def is_file_idempotent(
local_link_target = os.readlink(managed_path) local_link_target = os.readlink(managed_path)
return container_path, mode, member.linkname == local_link_target return container_path, mode, member.linkname == local_link_target
def process_other(in_path, member): def process_other(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
add_other_diff(diff, in_path, member) add_other_diff(diff, in_path, member)
return container_path, mode, False return container_path, mode, False
@ -706,23 +738,21 @@ def is_file_idempotent(
def copy_file_into_container( def copy_file_into_container(
client, client: AnsibleDockerClient,
container, container: str,
managed_path, managed_path: str,
container_path, container_path: str,
follow_links, follow_links: bool,
local_follow_links, local_follow_links: bool,
owner_id, owner_id,
group_id, group_id,
mode, mode,
force=False, force: bool | None = False,
diff=False, do_diff: bool = False,
max_file_size_for_diff=1, max_file_size_for_diff: int = 1,
): ) -> t.NoReturn:
if diff: diff: dict[str, t.Any] | None
diff = {} diff = {} if do_diff else None
else:
diff = None
container_path, mode, idempotent = is_file_idempotent( container_path, mode, idempotent = is_file_idempotent(
client, client,
@ -762,18 +792,18 @@ def copy_file_into_container(
def is_content_idempotent( def is_content_idempotent(
client, client: AnsibleDockerClient,
container, container: str,
content, content: bytes,
container_path, container_path: str,
follow_links, follow_links: bool,
owner_id, owner_id,
group_id, group_id,
mode, mode,
force=False, force: bool | None = False,
diff=None, diff: dict[str, t.Any] | None = None,
max_file_size_for_diff=1, max_file_size_for_diff: int = 1,
): ) -> tuple[str, int, bool]:
if diff is not None: if diff is not None:
if len(content) > max_file_size_for_diff > 0: if len(content) > max_file_size_for_diff > 0:
diff["src_larger"] = max_file_size_for_diff diff["src_larger"] = max_file_size_for_diff
@ -894,12 +924,14 @@ def is_content_idempotent(
return container_path, mode, False return container_path, mode, False
# Fetch file from container # Fetch file from container
def process_none(in_path): def process_none(in_path: str) -> tuple[str, int, bool]:
if diff is not None: if diff is not None:
diff["before"] = "" diff["before"] = ""
return container_path, mode, False return container_path, mode, False
def process_regular(in_path, tar, member): def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> tuple[str, int, bool]:
# Check things like user/group ID and mode # Check things like user/group ID and mode
if any( if any(
[ [
@ -914,7 +946,10 @@ def is_content_idempotent(
) )
return container_path, mode, False return container_path, mode, False
with tar.extractfile(member) as tar_f: mf = tar.extractfile(member)
if mf is None:
raise AssertionError("Member should be present for regular file")
with mf as tar_f:
is_equal = are_fileobjs_equal_with_diff_of_first( is_equal = are_fileobjs_equal_with_diff_of_first(
tar_f, tar_f,
io.BytesIO(content), io.BytesIO(content),
@ -925,14 +960,14 @@ def is_content_idempotent(
) )
return container_path, mode, is_equal return container_path, mode, is_equal
def process_symlink(in_path, member): def process_symlink(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
if diff is not None: if diff is not None:
diff["before_header"] = in_path diff["before_header"] = in_path
diff["before"] = member.linkname diff["before"] = member.linkname
return container_path, mode, False return container_path, mode, False
def process_other(in_path, member): def process_other(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
add_other_diff(diff, in_path, member) add_other_diff(diff, in_path, member)
return container_path, mode, False return container_path, mode, False
@ -949,22 +984,19 @@ def is_content_idempotent(
def copy_content_into_container( def copy_content_into_container(
client, client: AnsibleDockerClient,
container, container: str,
content, content: bytes,
container_path, container_path: str,
follow_links, follow_links: bool,
owner_id, owner_id,
group_id, group_id,
mode, mode,
force=False, force: bool | None = False,
diff=False, do_diff: bool = False,
max_file_size_for_diff=1, max_file_size_for_diff: int = 1,
): ) -> t.NoReturn:
if diff: diff: dict[str, t.Any] | None = {} if do_diff else None
diff = {}
else:
diff = None
container_path, mode, idempotent = is_content_idempotent( container_path, mode, idempotent = is_content_idempotent(
client, client,
@ -1007,7 +1039,7 @@ def copy_content_into_container(
client.module.exit_json(**result) client.module.exit_json(**result)
def parse_modern(mode): def parse_modern(mode: str | int) -> int:
if isinstance(mode, str): if isinstance(mode, str):
return int(to_native(mode), 8) return int(to_native(mode), 8)
if isinstance(mode, int): if isinstance(mode, int):
@ -1015,13 +1047,13 @@ def parse_modern(mode):
raise TypeError(f"must be an octal string or an integer, got {mode!r}") raise TypeError(f"must be an octal string or an integer, got {mode!r}")
def parse_octal_string_only(mode): def parse_octal_string_only(mode: str) -> int:
if isinstance(mode, str): if isinstance(mode, str):
return int(to_native(mode), 8) return int(to_native(mode), 8)
raise TypeError(f"must be an octal string, got {mode!r}") raise TypeError(f"must be an octal string, got {mode!r}")
def main(): def main() -> None:
argument_spec = { argument_spec = {
"container": {"type": "str", "required": True}, "container": {"type": "str", "required": True},
"path": {"type": "path"}, "path": {"type": "path"},
@ -1054,20 +1086,22 @@ def main():
}, },
) )
container = client.module.params["container"] container: str = client.module.params["container"]
managed_path = client.module.params["path"] managed_path: str | None = client.module.params["path"]
container_path = client.module.params["container_path"] container_path: str = client.module.params["container_path"]
follow = client.module.params["follow"] follow: bool = client.module.params["follow"]
local_follow = client.module.params["local_follow"] local_follow: bool = client.module.params["local_follow"]
owner_id = client.module.params["owner_id"] owner_id: int | None = client.module.params["owner_id"]
group_id = client.module.params["group_id"] group_id: int | None = client.module.params["group_id"]
mode = client.module.params["mode"] mode: t.Any = client.module.params["mode"]
force = client.module.params["force"] force: bool | None = client.module.params["force"]
content = client.module.params["content"] content_str: str | None = client.module.params["content"]
max_file_size_for_diff = client.module.params["_max_file_size_for_diff"] or 1 max_file_size_for_diff: int = client.module.params["_max_file_size_for_diff"] or 1
if mode is not None: if mode is not None:
mode_parse = client.module.params["mode_parse"] mode_parse: t.Literal["legacy", "modern", "octal_string_only"] = (
client.module.params["mode_parse"]
)
try: try:
if mode_parse == "legacy": if mode_parse == "legacy":
mode = check_type_int(mode) mode = check_type_int(mode)
@ -1080,14 +1114,15 @@ def main():
if mode < 0: if mode < 0:
client.fail(f"'mode' must not be negative; got {mode}") client.fail(f"'mode' must not be negative; got {mode}")
if content is not None: content: bytes | None = None
if content_str is not None:
if client.module.params["content_is_b64"]: if client.module.params["content_is_b64"]:
try: try:
content = base64.b64decode(content) content = base64.b64decode(content_str)
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e: # pylint: disable=broad-exception-caught
client.fail(f"Cannot Base64 decode the content option: {e}") client.fail(f"Cannot Base64 decode the content option: {e}")
else: else:
content = to_bytes(content) content = to_bytes(content_str)
if not container_path.startswith(os.path.sep): if not container_path.startswith(os.path.sep):
container_path = os.path.join(os.path.sep, container_path) container_path = os.path.join(os.path.sep, container_path)
@ -1108,7 +1143,7 @@ def main():
group_id=group_id, group_id=group_id,
mode=mode, mode=mode,
force=force, force=force,
diff=client.module._diff, do_diff=client.module._diff,
max_file_size_for_diff=max_file_size_for_diff, max_file_size_for_diff=max_file_size_for_diff,
) )
elif managed_path is not None: elif managed_path is not None:
@ -1123,7 +1158,7 @@ def main():
group_id=group_id, group_id=group_id,
mode=mode, mode=mode,
force=force, force=force,
diff=client.module._diff, do_diff=client.module._diff,
max_file_size_for_diff=max_file_size_for_diff, max_file_size_for_diff=max_file_size_for_diff,
) )
else: else:

View File

@ -165,6 +165,7 @@ exec_id:
import shlex import shlex
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
@ -185,7 +186,7 @@ from ansible_collections.community.docker.plugins.module_utils._socket_handler i
) )
def main(): def main() -> None:
argument_spec = { argument_spec = {
"container": {"type": "str", "required": True}, "container": {"type": "str", "required": True},
"argv": {"type": "list", "elements": "str"}, "argv": {"type": "list", "elements": "str"},
@ -211,16 +212,16 @@ def main():
required_one_of=[("argv", "command")], required_one_of=[("argv", "command")],
) )
container = client.module.params["container"] container: str = client.module.params["container"]
argv = client.module.params["argv"] argv: list[str] | None = client.module.params["argv"]
command = client.module.params["command"] command: str | None = client.module.params["command"]
chdir = client.module.params["chdir"] chdir: str | None = client.module.params["chdir"]
detach = client.module.params["detach"] detach: bool = client.module.params["detach"]
user = client.module.params["user"] user: str | None = client.module.params["user"]
stdin = client.module.params["stdin"] stdin: str | None = client.module.params["stdin"]
strip_empty_ends = client.module.params["strip_empty_ends"] strip_empty_ends: bool = client.module.params["strip_empty_ends"]
tty = client.module.params["tty"] tty: bool = client.module.params["tty"]
env = client.module.params["env"] env: dict[str, t.Any] = client.module.params["env"]
if env is not None: if env is not None:
for name, value in list(env.items()): for name, value in list(env.items()):
@ -233,6 +234,7 @@ def main():
if command is not None: if command is not None:
argv = shlex.split(command) argv = shlex.split(command)
assert argv is not None
if detach and stdin is not None: if detach and stdin is not None:
client.module.fail_json(msg="If detach=true, stdin cannot be provided.") client.module.fail_json(msg="If detach=true, stdin cannot be provided.")
@ -258,7 +260,7 @@ def main():
exec_data = client.post_json_to_json( exec_data = client.post_json_to_json(
"/containers/{0}/exec", container, data=data "/containers/{0}/exec", container, data=data
) )
exec_id = exec_data["Id"] exec_id: str = exec_data["Id"]
data = { data = {
"Tty": tty, "Tty": tty,
@ -269,6 +271,8 @@ def main():
client.module.exit_json(changed=True, exec_id=exec_id) client.module.exit_json(changed=True, exec_id=exec_id)
else: else:
stdout: bytes | None
stderr: bytes | None
if stdin and not detach: if stdin and not detach:
exec_socket = client.post_json_to_stream_socket( exec_socket = client.post_json_to_stream_socket(
"/exec/{0}/start", exec_id, data=data "/exec/{0}/start", exec_id, data=data
@ -283,28 +287,37 @@ def main():
stdout, stderr = exec_socket_handler.consume() stdout, stderr = exec_socket_handler.consume()
finally: finally:
exec_socket.close() exec_socket.close()
elif tty:
stdout, stderr = client.post_json_to_stream(
"/exec/{0}/start",
exec_id,
data=data,
stream=False,
tty=True,
demux=True,
)
else: else:
stdout, stderr = client.post_json_to_stream( stdout, stderr = client.post_json_to_stream(
"/exec/{0}/start", "/exec/{0}/start",
exec_id, exec_id,
data=data, data=data,
stream=False, stream=False,
tty=tty, tty=False,
demux=True, demux=True,
) )
result = client.get_json("/exec/{0}/json", exec_id) result = client.get_json("/exec/{0}/json", exec_id)
stdout = to_text(stdout or b"") stdout_t = to_text(stdout or b"")
stderr = to_text(stderr or b"") stderr_t = to_text(stderr or b"")
if strip_empty_ends: if strip_empty_ends:
stdout = stdout.rstrip("\r\n") stdout_t = stdout_t.rstrip("\r\n")
stderr = stderr.rstrip("\r\n") stderr_t = stderr_t.rstrip("\r\n")
client.module.exit_json( client.module.exit_json(
changed=True, changed=True,
stdout=stdout, stdout=stdout_t,
stderr=stderr, stderr=stderr_t,
rc=result.get("ExitCode") or 0, rc=result.get("ExitCode") or 0,
) )
except NotFound: except NotFound:

View File

@ -86,7 +86,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor
) )
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
} }
@ -96,8 +96,9 @@ def main():
supports_check_mode=True, supports_check_mode=True,
) )
container_id: str = client.module.params["name"]
try: try:
container = client.get_container(client.module.params["name"]) container = client.get_container(container_id)
client.module.exit_json( client.module.exit_json(
changed=False, changed=False,

View File

@ -173,6 +173,7 @@ current_context_name:
""" """
import traceback import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
@ -185,6 +186,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.context.conf
) )
from ansible_collections.community.docker.plugins.module_utils._api.context.context import ( from ansible_collections.community.docker.plugins.module_utils._api.context.context import (
IN_MEMORY, IN_MEMORY,
Context,
) )
from ansible_collections.community.docker.plugins.module_utils._api.errors import ( from ansible_collections.community.docker.plugins.module_utils._api.errors import (
ContextException, ContextException,
@ -192,7 +194,13 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
) )
def tls_context_to_json(context): if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.tls import (
TLSConfig,
)
def tls_context_to_json(context: TLSConfig | None) -> dict[str, t.Any] | None:
if context is None: if context is None:
return None return None
return { return {
@ -204,8 +212,8 @@ def tls_context_to_json(context):
} }
def context_to_json(context, current): def context_to_json(context: Context, current: bool) -> dict[str, t.Any]:
module_config = {} module_config: dict[str, t.Any] = {}
if "docker" in context.endpoints: if "docker" in context.endpoints:
endpoint = context.endpoints["docker"] endpoint = context.endpoints["docker"]
if isinstance(endpoint.get("Host"), str): if isinstance(endpoint.get("Host"), str):
@ -247,7 +255,7 @@ def context_to_json(context, current):
} }
def main(): def main() -> None:
argument_spec = { argument_spec = {
"only_current": {"type": "bool", "default": False}, "only_current": {"type": "bool", "default": False},
"name": {"type": "str"}, "name": {"type": "str"},
@ -262,28 +270,31 @@ def main():
], ],
) )
only_current: bool = module.params["only_current"]
name: str | None = module.params["name"]
cli_context: str | None = module.params["cli_context"]
try: try:
if module.params["cli_context"]: if cli_context:
current_context_name, current_context_source = ( current_context_name, current_context_source = (
module.params["cli_context"], cli_context,
"cli_context module option", "cli_context module option",
) )
else: else:
current_context_name, current_context_source = ( current_context_name, current_context_source = (
get_current_context_name_with_source() get_current_context_name_with_source()
) )
if module.params["name"]: if name:
contexts = [ContextAPI.get_context(module.params["name"])] context_or_none = ContextAPI.get_context(name)
if not contexts[0]: if not context_or_none:
module.fail_json( module.fail_json(msg=f"There is no context of name {name!r}")
msg=f"There is no context of name {module.params['name']!r}" contexts = [context_or_none]
) elif only_current:
elif module.params["only_current"]: context_or_none = ContextAPI.get_context(current_context_name)
contexts = [ContextAPI.get_context(current_context_name)] if not context_or_none:
if not contexts[0]:
module.fail_json( module.fail_json(
msg=f"There is no context of name {current_context_name!r}, which is configured as the default context ({current_context_source})", msg=f"There is no context of name {current_context_name!r}, which is configured as the default context ({current_context_source})",
) )
contexts = [context_or_none]
else: else:
contexts = ContextAPI.contexts() contexts = ContextAPI.contexts()

View File

@ -212,6 +212,7 @@ disk_usage:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import ( from ansible_collections.community.docker.plugins.module_utils._api.errors import (
APIError, APIError,
@ -231,9 +232,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class DockerHostManager(DockerBaseClass): class DockerHostManager(DockerBaseClass):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
def __init__(self, client, results):
super().__init__() super().__init__()
self.client = client self.client = client
@ -253,21 +252,21 @@ class DockerHostManager(DockerBaseClass):
for docker_object in listed_objects: for docker_object in listed_objects:
if self.client.module.params[docker_object]: if self.client.module.params[docker_object]:
returned_name = docker_object returned_name = docker_object
filter_name = docker_object + "_filters" filter_name = f"{docker_object}_filters"
filters = clean_dict_booleans_for_docker_api( filters = clean_dict_booleans_for_docker_api(
client.module.params.get(filter_name), True client.module.params.get(filter_name), allow_sequences=True
) )
self.results[returned_name] = self.get_docker_items_list( self.results[returned_name] = self.get_docker_items_list(
docker_object, filters docker_object, filters
) )
def get_docker_host_info(self): def get_docker_host_info(self) -> dict[str, t.Any]:
try: try:
return self.client.info() return self.client.info()
except APIError as exc: except APIError as exc:
self.client.fail(f"Error inspecting docker host: {exc}") self.client.fail(f"Error inspecting docker host: {exc}")
def get_docker_disk_usage_facts(self): def get_docker_disk_usage_facts(self) -> dict[str, t.Any]:
try: try:
if self.verbose_output: if self.verbose_output:
return self.client.df() return self.client.df()
@ -275,9 +274,13 @@ class DockerHostManager(DockerBaseClass):
except APIError as exc: except APIError as exc:
self.client.fail(f"Error inspecting docker host: {exc}") self.client.fail(f"Error inspecting docker host: {exc}")
def get_docker_items_list(self, docker_object=None, filters=None, verbose=False): def get_docker_items_list(
items = None self,
items_list = [] docker_object: str,
filters: dict[str, t.Any] | None = None,
verbose: bool = False,
) -> list[dict[str, t.Any]]:
items = []
header_containers = [ header_containers = [
"Id", "Id",
@ -329,6 +332,7 @@ class DockerHostManager(DockerBaseClass):
if self.verbose_output: if self.verbose_output:
return items return items
items_list = []
for item in items: for item in items:
item_record = {} item_record = {}
@ -349,7 +353,7 @@ class DockerHostManager(DockerBaseClass):
return items_list return items_list
def main(): def main() -> None:
argument_spec = { argument_spec = {
"containers": {"type": "bool", "default": False}, "containers": {"type": "bool", "default": False},
"containers_all": {"type": "bool", "default": False}, "containers_all": {"type": "bool", "default": False},

View File

@ -367,6 +367,7 @@ import errno
import json import json
import os import os
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.text.formatters import human_to_bytes from ansible.module_utils.common.text.formatters import human_to_bytes
@ -411,7 +412,18 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
) )
def convert_to_bytes(value, module, name, unlimited_value=None): if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
def convert_to_bytes(
value: str | None,
module: AnsibleModule,
name: str,
unlimited_value: int | None = None,
) -> int | None:
if value is None: if value is None:
return value return value
try: try:
@ -423,8 +435,7 @@ def convert_to_bytes(value, module, name, unlimited_value=None):
class ImageManager(DockerBaseClass): class ImageManager(DockerBaseClass):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
def __init__(self, client, results):
""" """
Configure a docker_image task. Configure a docker_image task.
@ -441,12 +452,14 @@ class ImageManager(DockerBaseClass):
parameters = self.client.module.params parameters = self.client.module.params
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
self.source = parameters["source"] self.source: t.Literal["build", "load", "pull", "local"] | None = parameters[
build = parameters["build"] or {} "source"
pull = parameters["pull"] or {} ]
self.archive_path = parameters["archive_path"] build: dict[str, t.Any] = parameters["build"] or {}
self.cache_from = build.get("cache_from") pull: dict[str, t.Any] = parameters["pull"] or {}
self.container_limits = build.get("container_limits") self.archive_path: str | None = parameters["archive_path"]
self.cache_from: list[str] | None = build.get("cache_from")
self.container_limits: dict[str, t.Any] | None = build.get("container_limits")
if self.container_limits and "memory" in self.container_limits: if self.container_limits and "memory" in self.container_limits:
self.container_limits["memory"] = convert_to_bytes( self.container_limits["memory"] = convert_to_bytes(
self.container_limits["memory"], self.container_limits["memory"],
@ -460,32 +473,36 @@ class ImageManager(DockerBaseClass):
"build.container_limits.memswap", "build.container_limits.memswap",
unlimited_value=-1, unlimited_value=-1,
) )
self.dockerfile = build.get("dockerfile") self.dockerfile: str | None = build.get("dockerfile")
self.force_source = parameters["force_source"] self.force_source: bool = parameters["force_source"]
self.force_absent = parameters["force_absent"] self.force_absent: bool = parameters["force_absent"]
self.force_tag = parameters["force_tag"] self.force_tag: bool = parameters["force_tag"]
self.load_path = parameters["load_path"] self.load_path: str | None = parameters["load_path"]
self.name = parameters["name"] self.name: str = parameters["name"]
self.network = build.get("network") self.network: str | None = build.get("network")
self.extra_hosts = clean_dict_booleans_for_docker_api(build.get("etc_hosts")) self.extra_hosts: dict[str, str] = clean_dict_booleans_for_docker_api(
self.nocache = build.get("nocache", False) build.get("etc_hosts") # type: ignore
self.build_path = build.get("path") )
self.pull = build.get("pull") self.nocache: bool = build.get("nocache", False)
self.target = build.get("target") self.build_path: str | None = build.get("path")
self.repository = parameters["repository"] self.pull: bool | None = build.get("pull")
self.rm = build.get("rm", True) self.target: str | None = build.get("target")
self.state = parameters["state"] self.repository: str | None = parameters["repository"]
self.tag = parameters["tag"] self.rm: bool = build.get("rm", True)
self.http_timeout = build.get("http_timeout") self.state: t.Literal["absent", "present"] = parameters["state"]
self.pull_platform = pull.get("platform") self.tag: str = parameters["tag"]
self.push = parameters["push"] self.http_timeout: int | None = build.get("http_timeout")
self.buildargs = build.get("args") self.pull_platform: str | None = pull.get("platform")
self.build_platform = build.get("platform") self.push: bool = parameters["push"]
self.use_config_proxy = build.get("use_config_proxy") self.buildargs: dict[str, t.Any] | None = build.get("args")
self.shm_size = convert_to_bytes( self.build_platform: str | None = build.get("platform")
self.use_config_proxy: bool | None = build.get("use_config_proxy")
self.shm_size: int | None = convert_to_bytes(
build.get("shm_size"), self.client.module, "build.shm_size" build.get("shm_size"), self.client.module, "build.shm_size"
) )
self.labels = clean_dict_booleans_for_docker_api(build.get("labels")) self.labels: dict[str, str] = clean_dict_booleans_for_docker_api(
build.get("labels") # type: ignore
)
# If name contains a tag, it takes precedence over tag parameter. # If name contains a tag, it takes precedence over tag parameter.
if not is_image_name_id(self.name): if not is_image_name_id(self.name):
@ -507,10 +524,10 @@ class ImageManager(DockerBaseClass):
elif self.state == "absent": elif self.state == "absent":
self.absent() self.absent()
def fail(self, msg): def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg) self.client.fail(msg)
def present(self): def present(self) -> None:
""" """
Handles state = 'present', which includes building, loading or pulling an image, Handles state = 'present', which includes building, loading or pulling an image,
depending on user provided parameters. depending on user provided parameters.
@ -530,6 +547,7 @@ class ImageManager(DockerBaseClass):
) )
# Build the image # Build the image
assert self.build_path is not None
if not os.path.isdir(self.build_path): if not os.path.isdir(self.build_path):
self.fail( self.fail(
f"Requested build path {self.build_path} could not be found or you do not have access." f"Requested build path {self.build_path} could not be found or you do not have access."
@ -546,6 +564,7 @@ class ImageManager(DockerBaseClass):
self.results.update(self.build_image()) self.results.update(self.build_image())
elif self.source == "load": elif self.source == "load":
assert self.load_path is not None
# Load the image from an archive # Load the image from an archive
if not os.path.isfile(self.load_path): if not os.path.isfile(self.load_path):
self.fail( self.fail(
@ -596,7 +615,7 @@ class ImageManager(DockerBaseClass):
elif self.repository: elif self.repository:
self.tag_image(self.name, self.tag, self.repository, push=self.push) self.tag_image(self.name, self.tag, self.repository, push=self.push)
def absent(self): def absent(self) -> None:
""" """
Handles state = 'absent', which removes an image. Handles state = 'absent', which removes an image.
@ -627,8 +646,11 @@ class ImageManager(DockerBaseClass):
@staticmethod @staticmethod
def archived_image_action( def archived_image_action(
failure_logger, archive_path, current_image_name, current_image_id failure_logger: Callable[[str], None],
): archive_path: str,
current_image_name: str,
current_image_id: str,
) -> str | None:
""" """
If the archive is missing or requires replacement, return an action message. If the archive is missing or requires replacement, return an action message.
@ -667,7 +689,7 @@ class ImageManager(DockerBaseClass):
f"overwriting archive with image {archived.image_id} named {name}" f"overwriting archive with image {archived.image_id} named {name}"
) )
def archive_image(self, name, tag): def archive_image(self, name: str, tag: str | None) -> None:
""" """
Archive an image to a .tar file. Called when archive_path is passed. Archive an image to a .tar file. Called when archive_path is passed.
@ -676,6 +698,7 @@ class ImageManager(DockerBaseClass):
:param tag: Optional image tag; assumed to be "latest" if None :param tag: Optional image tag; assumed to be "latest" if None
:type tag: str | None :type tag: str | None
""" """
assert self.archive_path is not None
if not tag: if not tag:
tag = "latest" tag = "latest"
@ -710,8 +733,8 @@ class ImageManager(DockerBaseClass):
self.client._get( self.client._get(
self.client._url("/images/{0}/get", image_name), stream=True self.client._url("/images/{0}/get", image_name), stream=True
), ),
DEFAULT_DATA_CHUNK_SIZE, chunk_size=DEFAULT_DATA_CHUNK_SIZE,
False, decode=False,
) )
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error getting image {image_name} - {exc}") self.fail(f"Error getting image {image_name} - {exc}")
@ -725,7 +748,7 @@ class ImageManager(DockerBaseClass):
self.results["image"] = image self.results["image"] = image
def push_image(self, name, tag=None): def push_image(self, name: str, tag: str | None = None) -> None:
""" """
If the name of the image contains a repository path, then push the image. If the name of the image contains a repository path, then push the image.
@ -799,7 +822,9 @@ class ImageManager(DockerBaseClass):
self.results["image"] = {} self.results["image"] = {}
self.results["image"]["push_status"] = status self.results["image"]["push_status"] = status
def tag_image(self, name, tag, repository, push=False): def tag_image(
self, name: str, tag: str, repository: str, push: bool = False
) -> None:
""" """
Tag an image into a repository. Tag an image into a repository.
@ -852,7 +877,7 @@ class ImageManager(DockerBaseClass):
self.push_image(repo, repo_tag) self.push_image(repo, repo_tag)
@staticmethod @staticmethod
def _extract_output_line(line, output): def _extract_output_line(line: dict[str, t.Any], output: list[str]):
""" """
Extract text line from stream output and, if found, adds it to output. Extract text line from stream output and, if found, adds it to output.
""" """
@ -862,14 +887,15 @@ class ImageManager(DockerBaseClass):
text_line = line.get("stream") or line.get("status") or "" text_line = line.get("stream") or line.get("status") or ""
output.extend(text_line.splitlines()) output.extend(text_line.splitlines())
def build_image(self): def build_image(self) -> dict[str, t.Any]:
""" """
Build an image Build an image
:return: image dict :return: image dict
""" """
assert self.build_path is not None
remote = context = None remote = context = None
headers = {} headers: dict[str, str | bytes] = {}
buildargs = {} buildargs = {}
if self.buildargs: if self.buildargs:
for key, value in self.buildargs.items(): for key, value in self.buildargs.items():
@ -898,12 +924,12 @@ class ImageManager(DockerBaseClass):
[line.strip() for line in f.read().splitlines()], [line.strip() for line in f.read().splitlines()],
) )
) )
dockerfile = process_dockerfile(dockerfile, self.build_path) dockerfile_data = process_dockerfile(dockerfile, self.build_path)
context = tar( context = tar(
self.build_path, exclude=exclude, dockerfile=dockerfile, gzip=False self.build_path, exclude=exclude, dockerfile=dockerfile_data, gzip=False
) )
params = { params: dict[str, t.Any] = {
"t": f"{self.name}:{self.tag}" if self.tag else self.name, "t": f"{self.name}:{self.tag}" if self.tag else self.name,
"remote": remote, "remote": remote,
"q": False, "q": False,
@ -960,7 +986,7 @@ class ImageManager(DockerBaseClass):
if context is not None: if context is not None:
context.close() context.close()
build_output = [] build_output: list[str] = []
for line in self.client._stream_helper(response, decode=True): for line in self.client._stream_helper(response, decode=True):
# line = json.loads(line) # line = json.loads(line)
self.log(line, pretty_print=True) self.log(line, pretty_print=True)
@ -982,14 +1008,15 @@ class ImageManager(DockerBaseClass):
"image": self.client.find_image(name=self.name, tag=self.tag), "image": self.client.find_image(name=self.name, tag=self.tag),
} }
def load_image(self): def load_image(self) -> dict[str, t.Any] | None:
""" """
Load an image from a .tar archive Load an image from a .tar archive
:return: image dict :return: image dict
""" """
# Load image(s) from file # Load image(s) from file
load_output = [] assert self.load_path is not None
load_output: list[str] = []
has_output = False has_output = False
try: try:
self.log(f"Opening image {self.load_path}") self.log(f"Opening image {self.load_path}")
@ -1078,7 +1105,7 @@ class ImageManager(DockerBaseClass):
return self.client.find_image(self.name, self.tag) return self.client.find_image(self.name, self.tag)
def main(): def main() -> None:
argument_spec = { argument_spec = {
"source": {"type": "str", "choices": ["build", "load", "pull", "local"]}, "source": {"type": "str", "choices": ["build", "load", "pull", "local"]},
"build": { "build": {

View File

@ -282,6 +282,7 @@ command:
import base64 import base64
import os import os
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.text.formatters import human_to_bytes from ansible.module_utils.common.text.formatters import human_to_bytes
@ -304,7 +305,16 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
) )
def convert_to_bytes(value, module, name, unlimited_value=None): if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def convert_to_bytes(
value: str | None,
module: AnsibleModule,
name: str,
unlimited_value: int | None = None,
) -> int | None:
if value is None: if value is None:
return value return value
try: try:
@ -315,11 +325,11 @@ def convert_to_bytes(value, module, name, unlimited_value=None):
module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}") module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")
def dict_to_list(dictionary, concat="="): def dict_to_list(dictionary: dict[str, t.Any], concat: str = "=") -> list[str]:
return [f"{k}{concat}{v}" for k, v in sorted(dictionary.items())] return [f"{k}{concat}{v}" for k, v in sorted(dictionary.items())]
def _quote_csv(text): def _quote_csv(text: str) -> str:
if text.strip() == text and all(i not in text for i in '",\r\n'): if text.strip() == text and all(i not in text for i in '",\r\n'):
return text return text
text = text.replace('"', '""') text = text.replace('"', '""')
@ -327,7 +337,7 @@ def _quote_csv(text):
class ImageBuilder(DockerBaseClass): class ImageBuilder(DockerBaseClass):
def __init__(self, client): def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
@ -420,14 +430,14 @@ class ImageBuilder(DockerBaseClass):
f" buildx plugin has version {buildx_version} which only supports one output." f" buildx plugin has version {buildx_version} which only supports one output."
) )
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.client.fail(msg, **kwargs) self.client.fail(msg, **kwargs)
def add_list_arg(self, args, option, values): def add_list_arg(self, args: list[str], option: str, values: list[str]) -> None:
for value in values: for value in values:
args.extend([option, value]) args.extend([option, value])
def add_args(self, args): def add_args(self, args: list[str]) -> dict[str, t.Any]:
environ_update = {} environ_update = {}
if not self.outputs: if not self.outputs:
args.extend(["--tag", f"{self.name}:{self.tag}"]) args.extend(["--tag", f"{self.name}:{self.tag}"])
@ -512,9 +522,9 @@ class ImageBuilder(DockerBaseClass):
) )
return environ_update return environ_update
def build_image(self): def build_image(self) -> dict[str, t.Any]:
image = self.client.find_image(self.name, self.tag) image = self.client.find_image(self.name, self.tag)
results = { results: dict[str, t.Any] = {
"changed": False, "changed": False,
"actions": [], "actions": [],
"image": image or {}, "image": image or {},
@ -547,7 +557,7 @@ class ImageBuilder(DockerBaseClass):
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"}, "tag": {"type": "str", "default": "latest"},

View File

@ -94,6 +94,7 @@ images:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.constants import ( from ansible_collections.community.docker.plugins.module_utils._api.constants import (
DEFAULT_DATA_CHUNK_SIZE, DEFAULT_DATA_CHUNK_SIZE,
@ -121,7 +122,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageExportManager(DockerBaseClass): class ImageExportManager(DockerBaseClass):
def __init__(self, client): def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
@ -151,10 +152,10 @@ class ImageExportManager(DockerBaseClass):
if not self.names: if not self.names:
self.fail("At least one image name must be specified") self.fail("At least one image name must be specified")
def fail(self, msg): def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg) self.client.fail(msg)
def get_export_reason(self): def get_export_reason(self) -> str | None:
if self.force: if self.force:
return "Exporting since force=true" return "Exporting since force=true"
@ -178,13 +179,13 @@ class ImageExportManager(DockerBaseClass):
found = True found = True
break break
if not found: if not found:
return f'Overwriting archive since it contains unexpected image {archived_image.image_id} named {", ".join(archived_image.repo_tags)}' return f"Overwriting archive since it contains unexpected image {archived_image.image_id} named {', '.join(archived_image.repo_tags)}"
if left_names: if left_names:
return f"Overwriting archive since it is missing image(s) {', '.join([name['joined'] for name in left_names])}" return f"Overwriting archive since it is missing image(s) {', '.join([name['joined'] for name in left_names])}"
return None return None
def write_chunks(self, chunks): def write_chunks(self, chunks: t.Generator[bytes]) -> None:
try: try:
with open(self.path, "wb") as fd: with open(self.path, "wb") as fd:
for chunk in chunks: for chunk in chunks:
@ -192,7 +193,7 @@ class ImageExportManager(DockerBaseClass):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error writing image archive {self.path} - {exc}") self.fail(f"Error writing image archive {self.path} - {exc}")
def export_images(self): def export_images(self) -> None:
image_names = [name["joined"] for name in self.names] image_names = [name["joined"] for name in self.names]
image_names_str = ", ".join(image_names) image_names_str = ", ".join(image_names)
if len(image_names) == 1: if len(image_names) == 1:
@ -202,8 +203,8 @@ class ImageExportManager(DockerBaseClass):
self.client._get( self.client._get(
self.client._url("/images/{0}/get", image_names[0]), stream=True self.client._url("/images/{0}/get", image_names[0]), stream=True
), ),
DEFAULT_DATA_CHUNK_SIZE, chunk_size=DEFAULT_DATA_CHUNK_SIZE,
False, decode=False,
) )
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error getting image {image_names[0]} - {exc}") self.fail(f"Error getting image {image_names[0]} - {exc}")
@ -216,15 +217,15 @@ class ImageExportManager(DockerBaseClass):
stream=True, stream=True,
params={"names": image_names}, params={"names": image_names},
), ),
DEFAULT_DATA_CHUNK_SIZE, chunk_size=DEFAULT_DATA_CHUNK_SIZE,
False, decode=False,
) )
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error getting images {image_names_str} - {exc}") self.fail(f"Error getting images {image_names_str} - {exc}")
self.write_chunks(chunks) self.write_chunks(chunks)
def run(self): def run(self) -> dict[str, t.Any]:
tag = self.tag tag = self.tag
if not tag: if not tag:
tag = "latest" tag = "latest"
@ -260,7 +261,7 @@ class ImageExportManager(DockerBaseClass):
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"path": {"type": "path"}, "path": {"type": "path"},
"force": {"type": "bool", "default": False}, "force": {"type": "bool", "default": False},

View File

@ -136,6 +136,7 @@ images:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import ( from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException, DockerException,
@ -155,9 +156,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageManager(DockerBaseClass): class ImageManager(DockerBaseClass):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
def __init__(self, client, results):
super().__init__() super().__init__()
self.client = client self.client = client
@ -170,10 +169,10 @@ class ImageManager(DockerBaseClass):
else: else:
self.results["images"] = self.get_all_images() self.results["images"] = self.get_all_images()
def fail(self, msg): def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg) self.client.fail(msg)
def get_facts(self): def get_facts(self) -> list[dict[str, t.Any]]:
""" """
Lookup and inspect each image name found in the names parameter. Lookup and inspect each image name found in the names parameter.
@ -200,7 +199,7 @@ class ImageManager(DockerBaseClass):
results.append(image) results.append(image)
return results return results
def get_all_images(self): def get_all_images(self) -> list[dict[str, t.Any]]:
results = [] results = []
params = { params = {
"only_ids": 0, "only_ids": 0,
@ -218,7 +217,7 @@ class ImageManager(DockerBaseClass):
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "list", "elements": "str"}, "name": {"type": "list", "elements": "str"},
} }

View File

@ -80,6 +80,7 @@ images:
import errno import errno
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import ( from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException, DockerException,
@ -95,7 +96,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageManager(DockerBaseClass): class ImageManager(DockerBaseClass):
def __init__(self, client, results): def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
@ -108,7 +109,7 @@ class ImageManager(DockerBaseClass):
self.load_images() self.load_images()
@staticmethod @staticmethod
def _extract_output_line(line, output): def _extract_output_line(line: dict[str, t.Any], output: list[str]) -> None:
""" """
Extract text line from stream output and, if found, adds it to output. Extract text line from stream output and, if found, adds it to output.
""" """
@ -118,12 +119,12 @@ class ImageManager(DockerBaseClass):
text_line = line.get("stream") or line.get("status") or "" text_line = line.get("stream") or line.get("status") or ""
output.extend(text_line.splitlines()) output.extend(text_line.splitlines())
def load_images(self): def load_images(self) -> None:
""" """
Load images from a .tar archive Load images from a .tar archive
""" """
# Load image(s) from file # Load image(s) from file
load_output = [] load_output: list[str] = []
try: try:
self.log(f"Opening image {self.path}") self.log(f"Opening image {self.path}")
with open(self.path, "rb") as image_tar: with open(self.path, "rb") as image_tar:
@ -179,7 +180,7 @@ class ImageManager(DockerBaseClass):
self.results["stdout"] = "\n".join(load_output) self.results["stdout"] = "\n".join(load_output)
def main(): def main() -> None:
client = AnsibleDockerClient( client = AnsibleDockerClient(
argument_spec={ argument_spec={
"path": {"type": "path", "required": True}, "path": {"type": "path", "required": True},
@ -188,7 +189,7 @@ def main():
) )
try: try:
results = { results: dict[str, t.Any] = {
"image_names": [], "image_names": [],
"images": [], "images": [],
} }

View File

@ -91,6 +91,7 @@ image:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import ( from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException, DockerException,
@ -114,7 +115,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
def image_info(image): def image_info(image: dict[str, t.Any] | None) -> dict[str, t.Any]:
result = {} result = {}
if image: if image:
result["id"] = image["Id"] result["id"] = image["Id"]
@ -124,17 +125,17 @@ def image_info(image):
class ImagePuller(DockerBaseClass): class ImagePuller(DockerBaseClass):
def __init__(self, client): def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
parameters = self.client.module.params parameters = self.client.module.params
self.name = parameters["name"] self.name: str = parameters["name"]
self.tag = parameters["tag"] self.tag: str = parameters["tag"]
self.platform = parameters["platform"] self.platform: str | None = parameters["platform"]
self.pull_mode = parameters["pull"] self.pull_mode: t.Literal["always", "not_present"] = parameters["pull"]
if is_image_name_id(self.name): if is_image_name_id(self.name):
self.client.fail("Cannot pull an image by ID") self.client.fail("Cannot pull an image by ID")
@ -147,13 +148,15 @@ class ImagePuller(DockerBaseClass):
self.name = repo self.name = repo
self.tag = repo_tag self.tag = repo_tag
def pull(self): def pull(self) -> dict[str, t.Any]:
image = self.client.find_image(name=self.name, tag=self.tag) image = self.client.find_image(name=self.name, tag=self.tag)
actions: list[str] = []
diff = {"before": image_info(image), "after": image_info(image)}
results = { results = {
"changed": False, "changed": False,
"actions": [], "actions": actions,
"image": image or {}, "image": image or {},
"diff": {"before": image_info(image), "after": image_info(image)}, "diff": diff,
} }
if image and self.pull_mode == "not_present": if image and self.pull_mode == "not_present":
@ -175,21 +178,22 @@ class ImagePuller(DockerBaseClass):
if compare_platform_strings(wanted_platform, image_platform): if compare_platform_strings(wanted_platform, image_platform):
return results return results
results["actions"].append(f"Pulled image {self.name}:{self.tag}") actions.append(f"Pulled image {self.name}:{self.tag}")
if self.check_mode: if self.check_mode:
results["changed"] = True results["changed"] = True
results["diff"]["after"] = image_info({"Id": "unknown"}) diff["after"] = image_info({"Id": "unknown"})
else: else:
results["image"], not_changed = self.client.pull_image( image, not_changed = self.client.pull_image(
self.name, tag=self.tag, image_platform=self.platform self.name, tag=self.tag, image_platform=self.platform
) )
results["image"] = image
results["changed"] = not not_changed results["changed"] = not not_changed
results["diff"]["after"] = image_info(results["image"]) diff["after"] = image_info(image)
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"}, "tag": {"type": "str", "default": "latest"},

View File

@ -73,6 +73,7 @@ image:
import base64 import base64
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.auth import ( from ansible_collections.community.docker.plugins.module_utils._api.auth import (
get_config_header, get_config_header,
@ -96,15 +97,15 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImagePusher(DockerBaseClass): class ImagePusher(DockerBaseClass):
def __init__(self, client): def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
parameters = self.client.module.params parameters = self.client.module.params
self.name = parameters["name"] self.name: str = parameters["name"]
self.tag = parameters["tag"] self.tag: str = parameters["tag"]
if is_image_name_id(self.name): if is_image_name_id(self.name):
self.client.fail("Cannot push an image by ID") self.client.fail("Cannot push an image by ID")
@ -122,20 +123,21 @@ class ImagePusher(DockerBaseClass):
if not is_valid_tag(self.tag, allow_empty=False): if not is_valid_tag(self.tag, allow_empty=False):
self.client.fail(f'"{self.tag}" is not a valid docker tag!') self.client.fail(f'"{self.tag}" is not a valid docker tag!')
def push(self): def push(self) -> dict[str, t.Any]:
image = self.client.find_image(name=self.name, tag=self.tag) image = self.client.find_image(name=self.name, tag=self.tag)
if not image: if not image:
self.client.fail(f"Cannot find image {self.name}:{self.tag}") self.client.fail(f"Cannot find image {self.name}:{self.tag}")
results = { actions: list[str] = []
results: dict[str, t.Any] = {
"changed": False, "changed": False,
"actions": [], "actions": actions,
"image": image, "image": image,
} }
push_registry, push_repo = resolve_repository_name(self.name) push_registry, push_repo = resolve_repository_name(self.name)
try: try:
results["actions"].append(f"Pushed image {self.name}:{self.tag}") actions.append(f"Pushed image {self.name}:{self.tag}")
headers = {} headers = {}
header = get_config_header(self.client, push_registry) header = get_config_header(self.client, push_registry)
@ -174,7 +176,7 @@ class ImagePusher(DockerBaseClass):
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"}, "tag": {"type": "str", "default": "latest"},

View File

@ -98,6 +98,7 @@ untagged:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import ( from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException, DockerException,
@ -118,8 +119,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageRemover(DockerBaseClass): class ImageRemover(DockerBaseClass):
def __init__(self, client: AnsibleDockerClient) -> None:
def __init__(self, client):
super().__init__() super().__init__()
self.client = client self.client = client
@ -142,10 +142,10 @@ class ImageRemover(DockerBaseClass):
self.name = repo self.name = repo
self.tag = repo_tag self.tag = repo_tag
def fail(self, msg): def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg) self.client.fail(msg)
def get_diff_state(self, image): def get_diff_state(self, image: dict[str, t.Any] | None) -> dict[str, t.Any]:
if not image: if not image:
return {"exists": False} return {"exists": False}
return { return {
@ -155,13 +155,16 @@ class ImageRemover(DockerBaseClass):
"digests": sorted(image.get("RepoDigests") or []), "digests": sorted(image.get("RepoDigests") or []),
} }
def absent(self): def absent(self) -> dict[str, t.Any]:
results = { actions: list[str] = []
deleted: list[str] = []
untagged: list[str] = []
results: dict[str, t.Any] = {
"changed": False, "changed": False,
"actions": [], "actions": actions,
"image": {}, "image": {},
"deleted": [], "deleted": deleted,
"untagged": [], "untagged": untagged,
} }
name = self.name name = self.name
@ -172,16 +175,18 @@ class ImageRemover(DockerBaseClass):
if self.tag: if self.tag:
name = f"{self.name}:{self.tag}" name = f"{self.name}:{self.tag}"
diff: dict[str, t.Any] = {}
if self.diff: if self.diff:
results["diff"] = {"before": self.get_diff_state(image)} results["diff"] = diff
diff["before"] = self.get_diff_state(image)
if not image: if not image:
if self.diff: if self.diff:
results["diff"]["after"] = self.get_diff_state(image) diff["after"] = self.get_diff_state(image)
return results return results
results["changed"] = True results["changed"] = True
results["actions"].append(f"Removed image {name}") actions.append(f"Removed image {name}")
results["image"] = image results["image"] = image
if not self.check_mode: if not self.check_mode:
@ -199,22 +204,22 @@ class ImageRemover(DockerBaseClass):
for entry in res: for entry in res:
if entry.get("Untagged"): if entry.get("Untagged"):
results["untagged"].append(entry["Untagged"]) untagged.append(entry["Untagged"])
if entry.get("Deleted"): if entry.get("Deleted"):
results["deleted"].append(entry["Deleted"]) deleted.append(entry["Deleted"])
results["untagged"] = sorted(results["untagged"]) untagged[:] = sorted(untagged)
results["deleted"] = sorted(results["deleted"]) deleted[:] = sorted(deleted)
if self.diff: if self.diff:
image_after = self.client.find_image_by_id( image_after = self.client.find_image_by_id(
image["Id"], accept_missing_image=True image["Id"], accept_missing_image=True
) )
results["diff"]["after"] = self.get_diff_state(image_after) diff["after"] = self.get_diff_state(image_after)
elif is_image_name_id(name): elif is_image_name_id(name):
results["deleted"].append(image["Id"]) deleted.append(image["Id"])
results["untagged"] = sorted( untagged[:] = sorted(
(image.get("RepoTags") or []) + (image.get("RepoDigests") or []) (image.get("RepoTags") or []) + (image.get("RepoDigests") or [])
) )
if not self.force and results["untagged"]: if not self.force and results["untagged"]:
@ -222,40 +227,40 @@ class ImageRemover(DockerBaseClass):
"Cannot delete image by ID that is still in use - use force=true" "Cannot delete image by ID that is still in use - use force=true"
) )
if self.diff: if self.diff:
results["diff"]["after"] = self.get_diff_state({}) diff["after"] = self.get_diff_state({})
elif is_image_name_id(self.tag): elif is_image_name_id(self.tag):
results["untagged"].append(name) untagged.append(name)
if ( if (
len(image.get("RepoTags") or []) < 1 len(image.get("RepoTags") or []) < 1
and len(image.get("RepoDigests") or []) < 2 and len(image.get("RepoDigests") or []) < 2
): ):
results["deleted"].append(image["Id"]) deleted.append(image["Id"])
if self.diff: if self.diff:
results["diff"]["after"] = self.get_diff_state(image) diff["after"] = self.get_diff_state(image)
try: try:
results["diff"]["after"]["digests"].remove(name) diff["after"]["digests"].remove(name)
except ValueError: except ValueError:
pass pass
else: else:
results["untagged"].append(name) untagged.append(name)
if ( if (
len(image.get("RepoTags") or []) < 2 len(image.get("RepoTags") or []) < 2
and len(image.get("RepoDigests") or []) < 1 and len(image.get("RepoDigests") or []) < 1
): ):
results["deleted"].append(image["Id"]) deleted.append(image["Id"])
if self.diff: if self.diff:
results["diff"]["after"] = self.get_diff_state(image) diff["after"] = self.get_diff_state(image)
try: try:
results["diff"]["after"]["tags"].remove(name) diff["after"]["tags"].remove(name)
except ValueError: except ValueError:
pass pass
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"}, "tag": {"type": "str", "default": "latest"},

View File

@ -101,6 +101,7 @@ tagged_images:
""" """
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.formatters import human_to_bytes from ansible.module_utils.common.text.formatters import human_to_bytes
@ -121,7 +122,16 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
def convert_to_bytes(value, module, name, unlimited_value=None): if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def convert_to_bytes(
value: str | None,
module: AnsibleModule,
name: str,
unlimited_value: int | None = None,
) -> int | None:
if value is None: if value is None:
return value return value
try: try:
@ -132,8 +142,8 @@ def convert_to_bytes(value, module, name, unlimited_value=None):
module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}") module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")
def image_info(name, tag, image): def image_info(name: str, tag: str, image: dict[str, t.Any] | None) -> dict[str, t.Any]:
result = {"name": name, "tag": tag} result: dict[str, t.Any] = {"name": name, "tag": tag}
if image: if image:
result["id"] = image["Id"] result["id"] = image["Id"]
else: else:
@ -142,7 +152,7 @@ def image_info(name, tag, image):
class ImageTagger(DockerBaseClass): class ImageTagger(DockerBaseClass):
def __init__(self, client): def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
@ -179,10 +189,12 @@ class ImageTagger(DockerBaseClass):
) )
self.repositories.append((repo, repo_tag)) self.repositories.append((repo, repo_tag))
def fail(self, msg): def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg) self.client.fail(msg)
def tag_image(self, image, name, tag): def tag_image(
self, image: dict[str, t.Any], name: str, tag: str
) -> tuple[bool, str, dict[str, t.Any] | None]:
tagged_image = self.client.find_image(name=name, tag=tag) tagged_image = self.client.find_image(name=name, tag=tag)
if tagged_image: if tagged_image:
# Idempotency checks # Idempotency checks
@ -220,20 +232,22 @@ class ImageTagger(DockerBaseClass):
return True, msg, tagged_image return True, msg, tagged_image
def tag_images(self): def tag_images(self) -> dict[str, t.Any]:
if is_image_name_id(self.name): if is_image_name_id(self.name):
image = self.client.find_image_by_id(self.name, accept_missing_image=False) image = self.client.find_image_by_id(self.name, accept_missing_image=False)
else: else:
image = self.client.find_image(name=self.name, tag=self.tag) image = self.client.find_image(name=self.name, tag=self.tag)
if not image: if not image:
self.fail(f"Cannot find image {self.name}:{self.tag}") self.fail(f"Cannot find image {self.name}:{self.tag}")
assert image is not None
before = [] before: list[dict[str, t.Any]] = []
after = [] after: list[dict[str, t.Any]] = []
tagged_images = [] tagged_images: list[str] = []
results = { actions: list[str] = []
results: dict[str, t.Any] = {
"changed": False, "changed": False,
"actions": [], "actions": actions,
"image": image, "image": image,
"tagged_images": tagged_images, "tagged_images": tagged_images,
"diff": {"before": {"images": before}, "after": {"images": after}}, "diff": {"before": {"images": before}, "after": {"images": after}},
@ -244,19 +258,19 @@ class ImageTagger(DockerBaseClass):
after.append(image_info(repository, tag, image if tagged else old_image)) after.append(image_info(repository, tag, image if tagged else old_image))
if tagged: if tagged:
results["changed"] = True results["changed"] = True
results["actions"].append( actions.append(
f"Tagged image {image['Id']} as {repository}:{tag}: {msg}" f"Tagged image {image['Id']} as {repository}:{tag}: {msg}"
) )
tagged_images.append(f"{repository}:{tag}") tagged_images.append(f"{repository}:{tag}")
else: else:
results["actions"].append( actions.append(
f"Not tagged image {image['Id']} as {repository}:{tag}: {msg}" f"Not tagged image {image['Id']} as {repository}:{tag}: {msg}"
) )
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"}, "tag": {"type": "str", "default": "latest"},

View File

@ -120,6 +120,7 @@ import base64
import json import json
import os import os
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
@ -154,11 +155,11 @@ class DockerFileStore:
program = "<legacy config>" program = "<legacy config>"
def __init__(self, config_path): def __init__(self, config_path: str) -> None:
self._config_path = config_path self._config_path = config_path
# Make sure we have a minimal config if none is available. # Make sure we have a minimal config if none is available.
self._config = {"auths": {}} self._config: dict[str, t.Any] = {"auths": {}}
try: try:
# Attempt to read the existing config. # Attempt to read the existing config.
@ -172,14 +173,14 @@ class DockerFileStore:
self._config.update(config) self._config.update(config)
@property @property
def config_path(self): def config_path(self) -> str:
""" """
Return the config path configured in this DockerFileStore instance. Return the config path configured in this DockerFileStore instance.
""" """
return self._config_path return self._config_path
def get(self, server): def get(self, server: str) -> dict[str, t.Any]:
""" """
Retrieve credentials for `server` if there are any in the config file. Retrieve credentials for `server` if there are any in the config file.
Otherwise raise a `StoreError` Otherwise raise a `StoreError`
@ -193,7 +194,7 @@ class DockerFileStore:
return {"Username": username, "Secret": password} return {"Username": username, "Secret": password}
def _write(self): def _write(self) -> None:
""" """
Write config back out to disk. Write config back out to disk.
""" """
@ -209,7 +210,7 @@ class DockerFileStore:
finally: finally:
os.close(f) os.close(f)
def store(self, server, username, password): def store(self, server: str, username: str, password: str) -> None:
""" """
Add a credentials for `server` to the current configuration. Add a credentials for `server` to the current configuration.
""" """
@ -225,7 +226,7 @@ class DockerFileStore:
self._write() self._write()
def erase(self, server): def erase(self, server: str) -> None:
""" """
Remove credentials for the given server from the configuration. Remove credentials for the given server from the configuration.
""" """
@ -236,9 +237,7 @@ class DockerFileStore:
class LoginManager(DockerBaseClass): class LoginManager(DockerBaseClass):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
def __init__(self, client, results):
super().__init__() super().__init__()
self.client = client self.client = client
@ -246,14 +245,14 @@ class LoginManager(DockerBaseClass):
parameters = self.client.module.params parameters = self.client.module.params
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
self.registry_url = parameters.get("registry_url") self.registry_url: str = parameters.get("registry_url")
self.username = parameters.get("username") self.username: str | None = parameters.get("username")
self.password = parameters.get("password") self.password: str | None = parameters.get("password")
self.reauthorize = parameters.get("reauthorize") self.reauthorize: bool = parameters.get("reauthorize")
self.config_path = parameters.get("config_path") self.config_path: str = parameters.get("config_path")
self.state = parameters.get("state") self.state: t.Literal["present", "absent"] = parameters.get("state")
def run(self): def run(self) -> None:
""" """
Do the actual work of this task here. This allows instantiation for partial Do the actual work of this task here. This allows instantiation for partial
testing. testing.
@ -264,10 +263,10 @@ class LoginManager(DockerBaseClass):
else: else:
self.logout() self.logout()
def fail(self, msg): def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg) self.client.fail(msg)
def _login(self, reauth): def _login(self, reauth: bool) -> dict[str, t.Any]:
if self.config_path and os.path.exists(self.config_path): if self.config_path and os.path.exists(self.config_path):
self.client._auth_configs = auth.load_config( self.client._auth_configs = auth.load_config(
self.config_path, credstore_env=self.client.credstore_env self.config_path, credstore_env=self.client.credstore_env
@ -297,7 +296,7 @@ class LoginManager(DockerBaseClass):
) )
return self.client._result(response, get_json=True) return self.client._result(response, get_json=True)
def login(self): def login(self) -> None:
""" """
Log into the registry with provided username/password. On success update the config Log into the registry with provided username/password. On success update the config
file with the new authorization. file with the new authorization.
@ -331,7 +330,7 @@ class LoginManager(DockerBaseClass):
self.update_credentials() self.update_credentials()
def logout(self): def logout(self) -> None:
""" """
Log out of the registry. On success update the config file. Log out of the registry. On success update the config file.
@ -353,13 +352,16 @@ class LoginManager(DockerBaseClass):
store.erase(self.registry_url) store.erase(self.registry_url)
self.results["changed"] = True self.results["changed"] = True
def update_credentials(self): def update_credentials(self) -> None:
""" """
If the authorization is not stored attempt to store authorization values via If the authorization is not stored attempt to store authorization values via
the appropriate credential helper or to the config file. the appropriate credential helper or to the config file.
:return: None :return: None
""" """
# This is only called from login()
assert self.username is not None
assert self.password is not None
# Check to see if credentials already exist. # Check to see if credentials already exist.
store = self.get_credential_store_instance(self.registry_url, self.config_path) store = self.get_credential_store_instance(self.registry_url, self.config_path)
@ -385,7 +387,9 @@ class LoginManager(DockerBaseClass):
) )
self.results["changed"] = True self.results["changed"] = True
def get_credential_store_instance(self, registry, dockercfg_path): def get_credential_store_instance(
self, registry: str, dockercfg_path: str
) -> Store | DockerFileStore:
""" """
Return an instance of docker.credentials.Store used by the given registry. Return an instance of docker.credentials.Store used by the given registry.
@ -408,8 +412,7 @@ class LoginManager(DockerBaseClass):
return DockerFileStore(dockercfg_path) return DockerFileStore(dockercfg_path)
def main(): def main() -> None:
argument_spec = { argument_spec = {
"registry_url": { "registry_url": {
"type": "str", "type": "str",

View File

@ -284,6 +284,7 @@ network:
import re import re
import time import time
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
@ -303,29 +304,31 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass): class TaskParameters(DockerBaseClass):
def __init__(self, client): name: str
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.name = None self.connected: list[str] = []
self.connected = None self.config_from: str | None = None
self.config_from = None self.config_only: bool | None = None
self.config_only = None self.driver: str = "bridge"
self.driver = None self.driver_options: dict[str, t.Any] = {}
self.driver_options = None self.ipam_driver: str | None = None
self.ipam_driver = None self.ipam_driver_options: dict[str, t.Any] | None = None
self.ipam_driver_options = None self.ipam_config: list[dict[str, t.Any]] | None = None
self.ipam_config = None self.appends: bool = False
self.appends = None self.force: bool = False
self.force = None self.internal: bool | None = None
self.internal = None self.labels: dict[str, t.Any] = {}
self.labels = None self.debug: bool = False
self.debug = None self.enable_ipv4: bool | None = None
self.enable_ipv4 = None self.enable_ipv6: bool | None = None
self.enable_ipv6 = None self.scope: t.Literal["local", "global", "swarm"] | None = None
self.scope = None self.attachable: bool | None = None
self.attachable = None self.ingress: bool | None = None
self.ingress = None self.state: t.Literal["present", "absent"] = "present"
for key, value in client.module.params.items(): for key, value in client.module.params.items():
setattr(self, key, value) setattr(self, key, value)
@ -333,10 +336,10 @@ class TaskParameters(DockerBaseClass):
# config_only sets driver to 'null' (and scope to 'local') so force that here. Otherwise we get # config_only sets driver to 'null' (and scope to 'local') so force that here. Otherwise we get
# diffs of 'null' --> 'bridge' given that the driver option defaults to 'bridge'. # diffs of 'null' --> 'bridge' given that the driver option defaults to 'bridge'.
if self.config_only: if self.config_only:
self.driver = "null" self.driver = "null" # type: ignore[unreachable]
def container_names_in_network(network): def container_names_in_network(network: dict[str, t.Any]) -> list[str]:
return ( return (
[c["Name"] for c in network["Containers"].values()] [c["Name"] for c in network["Containers"].values()]
if network["Containers"] if network["Containers"]
@ -348,7 +351,7 @@ CIDR_IPV4 = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}/([0-9]|[1-2][0-9]|3[0-2])$
CIDR_IPV6 = re.compile(r"^[0-9a-fA-F:]+/([0-9]|[1-9][0-9]|1[0-2][0-9])$") CIDR_IPV6 = re.compile(r"^[0-9a-fA-F:]+/([0-9]|[1-9][0-9]|1[0-2][0-9])$")
def validate_cidr(cidr): def validate_cidr(cidr: str) -> t.Literal["ipv4", "ipv6"]:
"""Validate CIDR. Return IP version of a CIDR string on success. """Validate CIDR. Return IP version of a CIDR string on success.
:param cidr: Valid CIDR :param cidr: Valid CIDR
@ -364,7 +367,7 @@ def validate_cidr(cidr):
raise ValueError(f'"{cidr}" is not a valid CIDR') raise ValueError(f'"{cidr}" is not a valid CIDR')
def normalize_ipam_config_key(key): def normalize_ipam_config_key(key: str) -> str:
"""Normalizes IPAM config keys returned by Docker API to match Ansible keys. """Normalizes IPAM config keys returned by Docker API to match Ansible keys.
:param key: Docker API key :param key: Docker API key
@ -376,7 +379,7 @@ def normalize_ipam_config_key(key):
return special_cases.get(key, key.lower()) return special_cases.get(key, key.lower())
def dicts_are_essentially_equal(a, b): def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]):
"""Make sure that a is a subset of b, where None entries of a are ignored.""" """Make sure that a is a subset of b, where None entries of a are ignored."""
for k, v in a.items(): for k, v in a.items():
if v is None: if v is None:
@ -387,15 +390,15 @@ def dicts_are_essentially_equal(a, b):
class DockerNetworkManager: class DockerNetworkManager:
def __init__(self, client: AnsibleDockerClient) -> None:
def __init__(self, client):
self.client = client self.client = client
self.parameters = TaskParameters(client) self.parameters = TaskParameters(client)
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
self.results = {"changed": False, "actions": []} self.actions: list[str] = []
self.results: dict[str, t.Any] = {"changed": False, "actions": self.actions}
self.diff = self.client.module._diff self.diff = self.client.module._diff
self.diff_tracker = DifferenceTracker() self.diff_tracker = DifferenceTracker()
self.diff_result = {} self.diff_result: dict[str, t.Any] = {}
self.existing_network = self.get_existing_network() self.existing_network = self.get_existing_network()
@ -429,10 +432,12 @@ class DockerNetworkManager:
) )
self.results["diff"] = self.diff_result self.results["diff"] = self.diff_result
def get_existing_network(self): def get_existing_network(self) -> dict[str, t.Any] | None:
return self.client.get_network(name=self.parameters.name) return self.client.get_network(name=self.parameters.name)
def has_different_config(self, net): def has_different_config(
self, net: dict[str, t.Any]
) -> tuple[bool, DifferenceTracker]:
""" """
Evaluates an existing network and returns a tuple containing a boolean Evaluates an existing network and returns a tuple containing a boolean
indicating if the configuration is different and a list of differences. indicating if the configuration is different and a list of differences.
@ -601,9 +606,9 @@ class DockerNetworkManager:
return not differences.empty, differences return not differences.empty, differences
def create_network(self): def create_network(self) -> None:
if not self.existing_network: if not self.existing_network:
data = { data: dict[str, t.Any] = {
"Name": self.parameters.name, "Name": self.parameters.name,
"Driver": self.parameters.driver, "Driver": self.parameters.driver,
"Options": self.parameters.driver_options, "Options": self.parameters.driver_options,
@ -661,12 +666,12 @@ class DockerNetworkManager:
resp = self.client.post_json_to_json("/networks/create", data=data) resp = self.client.post_json_to_json("/networks/create", data=data)
self.client.report_warnings(resp, ["Warning"]) self.client.report_warnings(resp, ["Warning"])
self.existing_network = self.client.get_network(network_id=resp["Id"]) self.existing_network = self.client.get_network(network_id=resp["Id"])
self.results["actions"].append( self.actions.append(
f"Created network {self.parameters.name} with driver {self.parameters.driver}" f"Created network {self.parameters.name} with driver {self.parameters.driver}"
) )
self.results["changed"] = True self.results["changed"] = True
def remove_network(self): def remove_network(self) -> None:
if self.existing_network: if self.existing_network:
self.disconnect_all_containers() self.disconnect_all_containers()
if not self.check_mode: if not self.check_mode:
@ -674,15 +679,15 @@ class DockerNetworkManager:
if self.existing_network.get("Scope", "local") == "swarm": if self.existing_network.get("Scope", "local") == "swarm":
while self.get_existing_network(): while self.get_existing_network():
time.sleep(0.1) time.sleep(0.1)
self.results["actions"].append(f"Removed network {self.parameters.name}") self.actions.append(f"Removed network {self.parameters.name}")
self.results["changed"] = True self.results["changed"] = True
def is_container_connected(self, container_name): def is_container_connected(self, container_name: str) -> bool:
if not self.existing_network: if not self.existing_network:
return False return False
return container_name in container_names_in_network(self.existing_network) return container_name in container_names_in_network(self.existing_network)
def is_container_exist(self, container_name): def is_container_exist(self, container_name: str) -> bool:
try: try:
container = self.client.get_container(container_name) container = self.client.get_container(container_name)
return bool(container) return bool(container)
@ -698,7 +703,7 @@ class DockerNetworkManager:
exception=traceback.format_exc(), exception=traceback.format_exc(),
) )
def connect_containers(self): def connect_containers(self) -> None:
for name in self.parameters.connected: for name in self.parameters.connected:
if not self.is_container_connected(name) and self.is_container_exist(name): if not self.is_container_connected(name) and self.is_container_exist(name):
if not self.check_mode: if not self.check_mode:
@ -709,11 +714,11 @@ class DockerNetworkManager:
self.client.post_json( self.client.post_json(
"/networks/{0}/connect", self.parameters.name, data=data "/networks/{0}/connect", self.parameters.name, data=data
) )
self.results["actions"].append(f"Connected container {name}") self.actions.append(f"Connected container {name}")
self.results["changed"] = True self.results["changed"] = True
self.diff_tracker.add(f"connected.{name}", parameter=True, active=False) self.diff_tracker.add(f"connected.{name}", parameter=True, active=False)
def disconnect_missing(self): def disconnect_missing(self) -> None:
if not self.existing_network: if not self.existing_network:
return return
containers = self.existing_network["Containers"] containers = self.existing_network["Containers"]
@ -724,26 +729,29 @@ class DockerNetworkManager:
if name not in self.parameters.connected: if name not in self.parameters.connected:
self.disconnect_container(name) self.disconnect_container(name)
def disconnect_all_containers(self): def disconnect_all_containers(self) -> None:
containers = self.client.get_network(name=self.parameters.name)["Containers"] network = self.client.get_network(name=self.parameters.name)
if not network:
return
containers = network["Containers"]
if not containers: if not containers:
return return
for cont in containers.values(): for cont in containers.values():
self.disconnect_container(cont["Name"]) self.disconnect_container(cont["Name"])
def disconnect_container(self, container_name): def disconnect_container(self, container_name: str) -> None:
if not self.check_mode: if not self.check_mode:
data = {"Container": container_name, "Force": True} data = {"Container": container_name, "Force": True}
self.client.post_json( self.client.post_json(
"/networks/{0}/disconnect", self.parameters.name, data=data "/networks/{0}/disconnect", self.parameters.name, data=data
) )
self.results["actions"].append(f"Disconnected container {container_name}") self.actions.append(f"Disconnected container {container_name}")
self.results["changed"] = True self.results["changed"] = True
self.diff_tracker.add( self.diff_tracker.add(
f"connected.{container_name}", parameter=False, active=True f"connected.{container_name}", parameter=False, active=True
) )
def present(self): def present(self) -> None:
different = False different = False
differences = DifferenceTracker() differences = DifferenceTracker()
if self.existing_network: if self.existing_network:
@ -771,14 +779,14 @@ class DockerNetworkManager:
network_facts = self.get_existing_network() network_facts = self.get_existing_network()
self.results["network"] = network_facts self.results["network"] = network_facts
def absent(self): def absent(self) -> None:
self.diff_tracker.add( self.diff_tracker.add(
"exists", parameter=False, active=self.existing_network is not None "exists", parameter=False, active=self.existing_network is not None
) )
self.remove_network() self.remove_network()
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True, "aliases": ["network_name"]}, "name": {"type": "str", "required": True, "aliases": ["network_name"]},
"config_from": {"type": "str"}, "config_from": {"type": "str"},

View File

@ -107,7 +107,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor
) )
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
} }

View File

@ -129,6 +129,7 @@ actions:
""" """
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
@ -149,35 +150,36 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass): class TaskParameters(DockerBaseClass):
def __init__(self, client): plugin_name: str
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.plugin_name = None self.alias: str | None = None
self.alias = None self.plugin_options: dict[str, t.Any] = {}
self.plugin_options = None self.debug: bool = False
self.debug = None self.force_remove: bool = False
self.force_remove = None self.enable_timeout: int = 0
self.enable_timeout = None self.state: t.Literal["present", "absent", "enable", "disable"] = "present"
for key, value in client.module.params.items(): for key, value in client.module.params.items():
setattr(self, key, value) setattr(self, key, value)
def prepare_options(options): def prepare_options(options: dict[str, t.Any] | None) -> list[str]:
return ( return (
[f'{k}={v if v is not None else ""}' for k, v in options.items()] [f"{k}={v if v is not None else ''}" for k, v in options.items()]
if options if options
else [] else []
) )
def parse_options(options_list): def parse_options(options_list: list[str] | None) -> dict[str, str]:
return dict(x.split("=", 1) for x in options_list) if options_list else {} return dict(x.split("=", 1) for x in options_list) if options_list else {}
class DockerPluginManager: class DockerPluginManager:
def __init__(self, client: AnsibleDockerClient) -> None:
def __init__(self, client):
self.client = client self.client = client
self.parameters = TaskParameters(client) self.parameters = TaskParameters(client)
@ -185,9 +187,9 @@ class DockerPluginManager:
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
self.diff = self.client.module._diff self.diff = self.client.module._diff
self.diff_tracker = DifferenceTracker() self.diff_tracker = DifferenceTracker()
self.diff_result = {} self.diff_result: dict[str, t.Any] = {}
self.actions = [] self.actions: list[str] = []
self.changed = False self.changed = False
self.existing_plugin = self.get_existing_plugin() self.existing_plugin = self.get_existing_plugin()
@ -209,7 +211,7 @@ class DockerPluginManager:
) )
self.diff = self.diff_result self.diff = self.diff_result
def get_existing_plugin(self): def get_existing_plugin(self) -> dict[str, t.Any] | None:
try: try:
return self.client.get_json("/plugins/{0}/json", self.preferred_name) return self.client.get_json("/plugins/{0}/json", self.preferred_name)
except NotFound: except NotFound:
@ -217,12 +219,13 @@ class DockerPluginManager:
except APIError as e: except APIError as e:
self.client.fail(to_native(e)) self.client.fail(to_native(e))
def has_different_config(self): def has_different_config(self) -> DifferenceTracker:
""" """
Return the list of differences between the current parameters and the existing plugin. Return the list of differences between the current parameters and the existing plugin.
:return: list of options that differ :return: list of options that differ
""" """
assert self.existing_plugin is not None
differences = DifferenceTracker() differences = DifferenceTracker()
if self.parameters.plugin_options: if self.parameters.plugin_options:
settings = self.existing_plugin.get("Settings") settings = self.existing_plugin.get("Settings")
@ -249,7 +252,7 @@ class DockerPluginManager:
return differences return differences
def install_plugin(self): def install_plugin(self) -> None:
if not self.existing_plugin: if not self.existing_plugin:
if not self.check_mode: if not self.check_mode:
try: try:
@ -297,7 +300,7 @@ class DockerPluginManager:
self.actions.append(f"Installed plugin {self.preferred_name}") self.actions.append(f"Installed plugin {self.preferred_name}")
self.changed = True self.changed = True
def remove_plugin(self): def remove_plugin(self) -> None:
force = self.parameters.force_remove force = self.parameters.force_remove
if self.existing_plugin: if self.existing_plugin:
if not self.check_mode: if not self.check_mode:
@ -311,7 +314,7 @@ class DockerPluginManager:
self.actions.append(f"Removed plugin {self.preferred_name}") self.actions.append(f"Removed plugin {self.preferred_name}")
self.changed = True self.changed = True
def update_plugin(self): def update_plugin(self) -> None:
if self.existing_plugin: if self.existing_plugin:
differences = self.has_different_config() differences = self.has_different_config()
if not differences.empty: if not differences.empty:
@ -328,7 +331,7 @@ class DockerPluginManager:
else: else:
self.client.fail("Cannot update the plugin: Plugin does not exist") self.client.fail("Cannot update the plugin: Plugin does not exist")
def present(self): def present(self) -> None:
differences = DifferenceTracker() differences = DifferenceTracker()
if self.existing_plugin: if self.existing_plugin:
differences = self.has_different_config() differences = self.has_different_config()
@ -345,13 +348,10 @@ class DockerPluginManager:
if self.diff or self.check_mode or self.parameters.debug: if self.diff or self.check_mode or self.parameters.debug:
self.diff_tracker.merge(differences) self.diff_tracker.merge(differences)
if not self.check_mode and not self.parameters.debug: def absent(self) -> None:
self.actions = None
def absent(self):
self.remove_plugin() self.remove_plugin()
def enable(self): def enable(self) -> None:
timeout = self.parameters.enable_timeout timeout = self.parameters.enable_timeout
if self.existing_plugin: if self.existing_plugin:
if not self.existing_plugin.get("Enabled"): if not self.existing_plugin.get("Enabled"):
@ -380,7 +380,7 @@ class DockerPluginManager:
self.actions.append(f"Enabled plugin {self.preferred_name}") self.actions.append(f"Enabled plugin {self.preferred_name}")
self.changed = True self.changed = True
def disable(self): def disable(self) -> None:
if self.existing_plugin: if self.existing_plugin:
if self.existing_plugin.get("Enabled"): if self.existing_plugin.get("Enabled"):
if not self.check_mode: if not self.check_mode:
@ -396,7 +396,7 @@ class DockerPluginManager:
self.client.fail("Plugin not found: Plugin does not exist.") self.client.fail("Plugin not found: Plugin does not exist.")
@property @property
def result(self): def result(self) -> dict[str, t.Any]:
plugin_data = {} plugin_data = {}
if self.parameters.state != "absent": if self.parameters.state != "absent":
try: try:
@ -406,16 +406,22 @@ class DockerPluginManager:
except NotFound: except NotFound:
# This can happen in check mode # This can happen in check mode
pass pass
result = { result: dict[str, t.Any] = {
"actions": self.actions, "actions": self.actions,
"changed": self.changed, "changed": self.changed,
"diff": self.diff, "diff": self.diff,
"plugin": plugin_data, "plugin": plugin_data,
} }
return dict((k, v) for k, v in result.items() if v is not None) if (
self.parameters.state == "present"
and not self.check_mode
and not self.parameters.debug
):
result["actions"] = None
return {k: v for k, v in result.items() if v is not None}
def main(): def main() -> None:
argument_spec = { argument_spec = {
"alias": {"type": "str"}, "alias": {"type": "str"},
"plugin_name": {"type": "str", "required": True}, "plugin_name": {"type": "str", "required": True},

View File

@ -247,7 +247,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
def main(): def main() -> None:
argument_spec = { argument_spec = {
"containers": {"type": "bool", "default": False}, "containers": {"type": "bool", "default": False},
"containers_filters": {"type": "dict"}, "containers_filters": {"type": "dict"},

View File

@ -118,6 +118,7 @@ volume:
""" """
import traceback import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
@ -137,31 +138,33 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass): class TaskParameters(DockerBaseClass):
def __init__(self, client): volume_name: str
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.volume_name = None self.driver: str = "local"
self.driver = None self.driver_options: dict[str, t.Any] = {}
self.driver_options = None self.labels: dict[str, t.Any] | None = None
self.labels = None self.recreate: t.Literal["always", "never", "options-changed"] = "never"
self.recreate = None self.debug: bool = False
self.debug = None self.state: t.Literal["present", "absent"] = "present"
for key, value in client.module.params.items(): for key, value in client.module.params.items():
setattr(self, key, value) setattr(self, key, value)
class DockerVolumeManager: class DockerVolumeManager:
def __init__(self, client: AnsibleDockerClient) -> None:
def __init__(self, client):
self.client = client self.client = client
self.parameters = TaskParameters(client) self.parameters = TaskParameters(client)
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
self.results = {"changed": False, "actions": []} self.actions: list[str] = []
self.results: dict[str, t.Any] = {"changed": False, "actions": self.actions}
self.diff = self.client.module._diff self.diff = self.client.module._diff
self.diff_tracker = DifferenceTracker() self.diff_tracker = DifferenceTracker()
self.diff_result = {} self.diff_result: dict[str, t.Any] = {}
self.existing_volume = self.get_existing_volume() self.existing_volume = self.get_existing_volume()
@ -178,7 +181,7 @@ class DockerVolumeManager:
) )
self.results["diff"] = self.diff_result self.results["diff"] = self.diff_result
def get_existing_volume(self): def get_existing_volume(self) -> dict[str, t.Any] | None:
try: try:
volumes = self.client.get_json("/volumes") volumes = self.client.get_json("/volumes")
except APIError as e: except APIError as e:
@ -193,12 +196,13 @@ class DockerVolumeManager:
return None return None
def has_different_config(self): def has_different_config(self) -> DifferenceTracker:
""" """
Return the list of differences between the current parameters and the existing volume. Return the list of differences between the current parameters and the existing volume.
:return: list of options that differ :return: list of options that differ
""" """
assert self.existing_volume is not None
differences = DifferenceTracker() differences = DifferenceTracker()
if ( if (
self.parameters.driver self.parameters.driver
@ -239,7 +243,7 @@ class DockerVolumeManager:
return differences return differences
def create_volume(self): def create_volume(self) -> None:
if not self.existing_volume: if not self.existing_volume:
if not self.check_mode: if not self.check_mode:
try: try:
@ -257,12 +261,12 @@ class DockerVolumeManager:
except APIError as e: except APIError as e:
self.client.fail(to_native(e)) self.client.fail(to_native(e))
self.results["actions"].append( self.actions.append(
f"Created volume {self.parameters.volume_name} with driver {self.parameters.driver}" f"Created volume {self.parameters.volume_name} with driver {self.parameters.driver}"
) )
self.results["changed"] = True self.results["changed"] = True
def remove_volume(self): def remove_volume(self) -> None:
if self.existing_volume: if self.existing_volume:
if not self.check_mode: if not self.check_mode:
try: try:
@ -270,12 +274,10 @@ class DockerVolumeManager:
except APIError as e: except APIError as e:
self.client.fail(to_native(e)) self.client.fail(to_native(e))
self.results["actions"].append( self.actions.append(f"Removed volume {self.parameters.volume_name}")
f"Removed volume {self.parameters.volume_name}"
)
self.results["changed"] = True self.results["changed"] = True
def present(self): def present(self) -> None:
differences = DifferenceTracker() differences = DifferenceTracker()
if self.existing_volume: if self.existing_volume:
differences = self.has_different_config() differences = self.has_different_config()
@ -301,14 +303,14 @@ class DockerVolumeManager:
volume_facts = self.get_existing_volume() volume_facts = self.get_existing_volume()
self.results["volume"] = volume_facts self.results["volume"] = volume_facts
def absent(self): def absent(self) -> None:
self.diff_tracker.add( self.diff_tracker.add(
"exists", parameter=False, active=self.existing_volume is not None "exists", parameter=False, active=self.existing_volume is not None
) )
self.remove_volume() self.remove_volume()
def main(): def main() -> None:
argument_spec = { argument_spec = {
"volume_name": {"type": "str", "required": True, "aliases": ["name"]}, "volume_name": {"type": "str", "required": True, "aliases": ["name"]},
"state": { "state": {

View File

@ -71,6 +71,7 @@ volume:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import ( from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException, DockerException,
@ -82,7 +83,9 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor
) )
def get_existing_volume(client, volume_name): def get_existing_volume(
client: AnsibleDockerClient, volume_name: str
) -> dict[str, t.Any] | None:
try: try:
return client.get_json("/volumes/{0}", volume_name) return client.get_json("/volumes/{0}", volume_name)
except NotFound: except NotFound:
@ -91,7 +94,7 @@ def get_existing_volume(client, volume_name):
client.fail(f"Error inspecting volume: {exc}") client.fail(f"Error inspecting volume: {exc}")
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True, "aliases": ["volume_name"]}, "name": {"type": "str", "required": True, "aliases": ["volume_name"]},
} }

View File

@ -7,6 +7,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible.errors import AnsibleConnectionFailure from ansible.errors import AnsibleConnectionFailure
from ansible.utils.display import Display from ansible.utils.display import Display
@ -18,8 +20,17 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
if t.TYPE_CHECKING:
from ansible.plugins import AnsiblePlugin
class AnsibleDockerClient(AnsibleDockerClientBase): class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(self, plugin, min_docker_version=None, min_docker_api_version=None): def __init__(
self,
plugin: AnsiblePlugin,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
) -> None:
self.plugin = plugin self.plugin = plugin
self.display = Display() self.display = Display()
super().__init__( super().__init__(
@ -27,17 +38,23 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
min_docker_api_version=min_docker_api_version, min_docker_api_version=min_docker_api_version,
) )
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
if kwargs: if kwargs:
msg += "\nContext:\n" + "\n".join( msg += "\nContext:\n" + "\n".join(
f" {k} = {v!r}" for (k, v) in kwargs.items() f" {k} = {v!r}" for (k, v) in kwargs.items()
) )
raise AnsibleConnectionFailure(msg) raise AnsibleConnectionFailure(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.display.deprecated( self.display.deprecated(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS} return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS}

View File

@ -7,6 +7,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible.errors import AnsibleConnectionFailure from ansible.errors import AnsibleConnectionFailure
from ansible.utils.display import Display from ansible.utils.display import Display
@ -18,23 +20,35 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
if t.TYPE_CHECKING:
from ansible.plugins import AnsiblePlugin
class AnsibleDockerClient(AnsibleDockerClientBase): class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(self, plugin, min_docker_api_version=None): def __init__(
self, plugin: AnsiblePlugin, min_docker_api_version: str | None = None
) -> None:
self.plugin = plugin self.plugin = plugin
self.display = Display() self.display = Display()
super().__init__(min_docker_api_version=min_docker_api_version) super().__init__(min_docker_api_version=min_docker_api_version)
def fail(self, msg, **kwargs): def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
if kwargs: if kwargs:
msg += "\nContext:\n" + "\n".join( msg += "\nContext:\n" + "\n".join(
f" {k} = {v!r}" for (k, v) in kwargs.items() f" {k} = {v!r}" for (k, v) in kwargs.items()
) )
raise AnsibleConnectionFailure(msg) raise AnsibleConnectionFailure(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None): def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.display.deprecated( self.display.deprecated(
msg, version=version, date=date, collection_name=collection_name msg, version=version, date=date, collection_name=collection_name
) )
def _get_params(self): def _get_params(self) -> dict[str, t.Any]:
return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS} return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS}

View File

@ -7,11 +7,23 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible_collections.community.docker.plugins.module_utils._socket_handler import ( from ansible_collections.community.docker.plugins.module_utils._socket_handler import (
DockerSocketHandlerBase, DockerSocketHandlerBase,
) )
if t.TYPE_CHECKING:
from ansible.utils.display import Display
from ansible_collections.community.docker.plugins.module_utils._socket_helper import (
SocketLike,
)
class DockerSocketHandler(DockerSocketHandlerBase): class DockerSocketHandler(DockerSocketHandlerBase):
def __init__(self, display, sock, log=None, container=None): def __init__(
self, display: Display, sock: SocketLike, container: str | None = None
) -> None:
super().__init__(sock, log=lambda msg: display.vvvv(msg, host=container)) super().__init__(sock, log=lambda msg: display.vvvv(msg, host=container))

View File

@ -8,6 +8,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
from collections.abc import Mapping, Set from collections.abc import Mapping, Set
from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.collections import is_sequence
@ -21,7 +22,7 @@ _RE_TEMPLATE_CHARS = re.compile("[{}]")
_RE_TEMPLATE_CHARS_BYTES = re.compile(b"[{}]") _RE_TEMPLATE_CHARS_BYTES = re.compile(b"[{}]")
def make_unsafe(value): def make_unsafe(value: t.Any) -> t.Any:
if value is None or isinstance(value, AnsibleUnsafe): if value is None or isinstance(value, AnsibleUnsafe):
return value return value

View File

@ -1 +1,22 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/api/client.py pep8:E704
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_api/utils/ports.py pep8:E704
plugins/module_utils/_api/utils/socket.py pep8:E704
plugins/module_utils/_common_cli.py pep8:E704
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/module_utils/_swarm.py pep8:E704
plugins/module_utils/_util.py pep8:E704
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_container_exec.py pylint:unpacking-non-sequence
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,21 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/api/client.py pep8:E704
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_api/utils/ports.py pep8:E704
plugins/module_utils/_api/utils/socket.py pep8:E704
plugins/module_utils/_common_cli.py pep8:E704
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/module_utils/_swarm.py pep8:E704
plugins/module_utils/_util.py pep8:E704
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,15 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,15 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,15 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -19,7 +19,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.transport im
try: try:
from ssl import CertificateError, match_hostname from ssl import CertificateError, match_hostname # type: ignore
except ImportError: except ImportError:
HAS_MATCH_HOSTNAME = False # pylint: disable=invalid-name HAS_MATCH_HOSTNAME = False # pylint: disable=invalid-name
else: else:

View File

@ -12,8 +12,8 @@ import json
import os import os
import shutil import shutil
import tempfile import tempfile
import typing as t
import unittest import unittest
from collections.abc import Callable
from unittest import mock from unittest import mock
from pytest import fixture, mark from pytest import fixture, mark
@ -22,7 +22,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils import
class FindConfigFileTest(unittest.TestCase): class FindConfigFileTest(unittest.TestCase):
mkdir: t.Callable[[str], os.PathLike[str]] mkdir: Callable[[str], os.PathLike[str]]
@fixture(autouse=True) @fixture(autouse=True)
def tmpdir(self, tmpdir): def tmpdir(self, tmpdir):

View File

@ -11,7 +11,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
) )
EVENT_TEST_CASES = [ EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[Event]]] = [
# ####################################################################################################################### # #######################################################################################################################
# ## Docker Compose 2.18.1 ############################################################################################## # ## Docker Compose 2.18.1 ##############################################################################################
# ####################################################################################################################### # #######################################################################################################################