Add typing information, 1/2 (#1176)

* Re-enable typing and improve config.

* Make mypy pass.

* Improve settings.

* First batch of types.

* Add more type hints.

* Fixes.

* Format.

* Fix split_port() without returning to previous type chaos.

* Continue with type hints (and ignores).
This commit is contained in:
Felix Fontein 2025-10-23 07:05:42 +02:00 committed by GitHub
parent 24f35644e3
commit 3350283bcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
92 changed files with 4366 additions and 2272 deletions

View File

@ -7,17 +7,25 @@
# disallow_untyped_defs = True -- for later
# strict = True -- only try to enable once everything (including dependencies!) is typed
# strict_equality = True -- for later
# strict_bytes = True -- for later
strict_equality = True
strict_bytes = True
# warn_redundant_casts = True -- for later
warn_redundant_casts = True
# warn_return_any = True -- for later
# warn_unreachable = True -- for later
warn_unreachable = True
[mypy-ansible.*]
# ansible-core has partial typing information
follow_untyped_imports = True
[mypy-docker.*]
# Docker SDK for Python has partial typing information
follow_untyped_imports = True
[mypy-ansible_collections.community.internal_test_tools.*]
# community.internal_test_tools has no typing information
ignore_missing_imports = True
[mypy-jsondiff.*]
# jsondiff has no typing information
ignore_missing_imports = True

View File

@ -27,7 +27,7 @@ run_yamllint = true
yamllint_config = ".yamllint"
yamllint_config_plugins = ".yamllint-docs"
yamllint_config_plugins_examples = ".yamllint-examples"
run_mypy = false
run_mypy = true
mypy_ansible_core_package = "ansible-core>=2.19.0"
mypy_config = ".mypy.ini"
mypy_extra_deps = [
@ -35,7 +35,11 @@ mypy_extra_deps = [
"paramiko",
"urllib3",
"requests",
"types-mock",
"types-paramiko",
"types-pywin32",
"types-PyYAML",
"types-requests",
]
[sessions.docs_check]

View File

@ -5,6 +5,7 @@
from __future__ import annotations
import base64
import typing as t
from ansible import constants as C
from ansible.plugins.action import ActionBase
@ -19,14 +20,17 @@ class ActionModule(ActionBase):
# Set to True when transferring files to the remote
TRANSFERS_FILES = False
def run(self, tmp=None, task_vars=None):
def run(
self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None
) -> dict[str, t.Any]:
self._supports_check_mode = True
self._supports_async = True
result = super().run(tmp, task_vars)
del tmp # tmp no longer has any effect
self._task.args["_max_file_size_for_diff"] = C.MAX_FILE_SIZE_FOR_DIFF
max_file_size_for_diff: int = C.MAX_FILE_SIZE_FOR_DIFF # type: ignore
self._task.args["_max_file_size_for_diff"] = max_file_size_for_diff
result = merge_hash(
result,

View File

@ -118,6 +118,7 @@ import os.path
import re
import selectors
import subprocess
import typing as t
from shlex import quote
from ansible.errors import AnsibleConnectionFailure, AnsibleError, AnsibleFileNotFound
@ -140,8 +141,8 @@ class Connection(ConnectionBase):
transport = "community.docker.docker"
has_pipelining = True
def __init__(self, play_context, new_stdin, *args, **kwargs):
super().__init__(play_context, new_stdin, *args, **kwargs)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Note: docker supports running as non-root in some configurations.
# (For instance, setting the UNIX socket file to be readable and
@ -152,11 +153,11 @@ class Connection(ConnectionBase):
# configured to be connected to by root and they are not running as
# root.
self._docker_args = []
self._container_user_cache = {}
self._version = None
self.remote_user = None
self.timeout = None
self._docker_args: list[bytes | str] = []
self._container_user_cache: dict[str, str | None] = {}
self._version: str | None = None
self.remote_user: str | None = None
self.timeout: int | float | None = None
# Windows uses Powershell modules
if getattr(self._shell, "_IS_WINDOWS", False):
@ -171,12 +172,12 @@ class Connection(ConnectionBase):
raise AnsibleError("docker command not found in PATH") from exc
@staticmethod
def _sanitize_version(version):
def _sanitize_version(version: str) -> str:
version = re.sub("[^0-9a-zA-Z.]", "", version)
version = re.sub("^v", "", version)
return version
def _old_docker_version(self):
def _old_docker_version(self) -> tuple[list[str], str, bytes, int]:
cmd_args = self._docker_args
old_version_subcommand = ["version"]
@ -189,7 +190,7 @@ class Connection(ConnectionBase):
return old_docker_cmd, to_native(cmd_output), err, p.returncode
def _new_docker_version(self):
def _new_docker_version(self) -> tuple[list[str], str, bytes, int]:
# no result yet, must be newer Docker version
cmd_args = self._docker_args
@ -202,8 +203,7 @@ class Connection(ConnectionBase):
cmd_output, err = p.communicate()
return new_docker_cmd, to_native(cmd_output), err, p.returncode
def _get_docker_version(self):
def _get_docker_version(self) -> str:
cmd, cmd_output, err, returncode = self._old_docker_version()
if returncode == 0:
for line in to_text(cmd_output, errors="surrogate_or_strict").split("\n"):
@ -218,7 +218,7 @@ class Connection(ConnectionBase):
return self._sanitize_version(to_text(cmd_output, errors="surrogate_or_strict"))
def _get_docker_remote_user(self):
def _get_docker_remote_user(self) -> str | None:
"""Get the default user configured in the docker container"""
container = self.get_option("remote_addr")
if container in self._container_user_cache:
@ -243,7 +243,7 @@ class Connection(ConnectionBase):
self._container_user_cache[container] = user
return user
def _build_exec_cmd(self, cmd):
def _build_exec_cmd(self, cmd: list[bytes | str]) -> list[bytes | str]:
"""Build the local docker exec command to run cmd on remote_host
If remote_user is available and is supported by the docker
@ -298,7 +298,7 @@ class Connection(ConnectionBase):
return local_cmd
def _set_docker_args(self):
def _set_docker_args(self) -> None:
# TODO: this is mostly for backwards compatibility, play_context is used as fallback for older versions
# docker arguments
del self._docker_args[:]
@ -308,7 +308,7 @@ class Connection(ConnectionBase):
if extra_args:
self._docker_args += extra_args.split(" ")
def _set_conn_data(self):
def _set_conn_data(self) -> None:
"""initialize for the connection, cannot do only in init since all data is not ready at that point"""
self._set_docker_args()
@ -323,8 +323,7 @@ class Connection(ConnectionBase):
self.timeout = self._play_context.timeout
@property
def docker_version(self):
def docker_version(self) -> str:
if not self._version:
self._set_docker_args()
@ -341,7 +340,7 @@ class Connection(ConnectionBase):
)
return self._version
def _get_actual_user(self):
def _get_actual_user(self) -> str | None:
if self.remote_user is not None:
# An explicit user is provided
if self.docker_version == "dev" or LooseVersion(
@ -353,7 +352,7 @@ class Connection(ConnectionBase):
actual_user = self._get_docker_remote_user()
if actual_user != self.get_option("remote_user"):
display.warning(
f'docker {self.docker_version} does not support remote_user, using container default: {actual_user or "?"}'
f"docker {self.docker_version} does not support remote_user, using container default: {actual_user or '?'}"
)
return actual_user
if self._display.verbosity > 2:
@ -363,9 +362,9 @@ class Connection(ConnectionBase):
return self._get_docker_remote_user()
return None
def _connect(self, port=None):
def _connect(self) -> t.Self:
"""Connect to the container. Nothing to do"""
super()._connect()
super()._connect() # type: ignore[safe-super]
if not self._connected:
self._set_conn_data()
actual_user = self._get_actual_user()
@ -374,13 +373,16 @@ class Connection(ConnectionBase):
host=self.get_option("remote_addr"),
)
self._connected = True
return self
def exec_command(self, cmd, in_data=None, sudoable=False):
def exec_command(
self, cmd: str, in_data: bytes | None = None, sudoable: bool = False
) -> tuple[int, bytes, bytes]:
"""Run a command on the docker host"""
self._set_conn_data()
super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
super().exec_command(cmd, in_data=in_data, sudoable=sudoable) # type: ignore[safe-super]
local_cmd = self._build_exec_cmd([self._play_context.executable, "-c", cmd])
@ -395,6 +397,9 @@ class Connection(ConnectionBase):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as p:
assert p.stdin is not None
assert p.stdout is not None
assert p.stderr is not None
display.debug("done running command with Popen()")
if self.become and self.become.expect_prompt() and sudoable:
@ -489,10 +494,10 @@ class Connection(ConnectionBase):
remote_path = os.path.join(os.path.sep, remote_path)
return os.path.normpath(remote_path)
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
"""Transfer a file from local to docker container"""
self._set_conn_data()
super().put_file(in_path, out_path)
super().put_file(in_path, out_path) # type: ignore[safe-super]
display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr"))
out_path = self._prefix_login_path(out_path)
@ -534,10 +539,10 @@ class Connection(ConnectionBase):
f"failed to transfer file {to_native(in_path)} to {to_native(out_path)}:\n{to_native(stdout)}\n{to_native(stderr)}"
)
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
"""Fetch a file from container to local."""
self._set_conn_data()
super().fetch_file(in_path, out_path)
super().fetch_file(in_path, out_path) # type: ignore[safe-super]
display.vvv(
f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr")
)
@ -596,7 +601,7 @@ class Connection(ConnectionBase):
if pp.returncode != 0:
raise AnsibleError(
f"failed to fetch file {in_path} to {out_path}:\n{stdout}\n{stderr}"
f"failed to fetch file {in_path} to {out_path}:\n{stdout!r}\n{stderr!r}"
)
# Rename if needed
@ -606,11 +611,11 @@ class Connection(ConnectionBase):
to_bytes(out_path, errors="strict"),
)
def close(self):
def close(self) -> None:
"""Terminate the connection. Nothing to do for Docker"""
super().close()
super().close() # type: ignore[safe-super]
self._connected = False
def reset(self):
def reset(self) -> None:
# Clear container user cache
self._container_user_cache = {}

View File

@ -107,6 +107,7 @@ options:
import os
import os.path
import typing as t
from ansible.errors import AnsibleConnectionFailure, AnsibleFileNotFound
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@ -138,6 +139,12 @@ from ansible_collections.community.docker.plugins.plugin_utils._socket_handler i
)
if t.TYPE_CHECKING:
from collections.abc import Callable
_T = t.TypeVar("_T")
MIN_DOCKER_API = None
@ -150,10 +157,16 @@ class Connection(ConnectionBase):
transport = "community.docker.docker_api"
has_pipelining = True
def _call_client(self, f, not_found_can_be_resource=False):
def _call_client(
self,
f: Callable[[AnsibleDockerClient], _T],
not_found_can_be_resource: bool = False,
) -> _T:
if self.client is None:
raise AssertionError("Client must be present")
remote_addr = self.get_option("remote_addr")
try:
return f()
return f(self.client)
except NotFound as e:
if not_found_can_be_resource:
raise AnsibleConnectionFailure(
@ -179,21 +192,21 @@ class Connection(ConnectionBase):
f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}'
)
def __init__(self, play_context, new_stdin, *args, **kwargs):
super().__init__(play_context, new_stdin, *args, **kwargs)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.client = None
self.ids = {}
self.client: AnsibleDockerClient | None = None
self.ids: dict[str | None, tuple[int, int]] = {}
# Windows uses Powershell modules
if getattr(self._shell, "_IS_WINDOWS", False):
self.module_implementation_preferences = (".ps1", ".exe", "")
self.actual_user = None
self.actual_user: str | None = None
def _connect(self, port=None):
def _connect(self) -> Connection:
"""Connect to the container. Nothing to do"""
super()._connect()
super()._connect() # type: ignore[safe-super]
if not self._connected:
self.actual_user = self.get_option("remote_user")
display.vvv(
@ -212,7 +225,7 @@ class Connection(ConnectionBase):
# This saves overhead from calling into docker when we do not need to
display.vvv("Trying to determine actual user")
result = self._call_client(
lambda: self.client.get_json(
lambda client: client.get_json(
"/containers/{0}/json", self.get_option("remote_addr")
)
)
@ -221,12 +234,19 @@ class Connection(ConnectionBase):
if self.actual_user is not None:
display.vvv(f"Actual user is '{self.actual_user}'")
def exec_command(self, cmd, in_data=None, sudoable=False):
return self
def exec_command(
self, cmd: str, in_data: bytes | None = None, sudoable: bool = False
) -> tuple[int, bytes, bytes]:
"""Run a command on the docker host"""
super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
super().exec_command(cmd, in_data=in_data, sudoable=sudoable) # type: ignore[safe-super]
command = [self._play_context.executable, "-c", to_text(cmd)]
if self.client is None:
raise AssertionError("Client must be present")
command = [self._play_context.executable, "-c", cmd]
do_become = self.become and self.become.expect_prompt() and sudoable
@ -277,7 +297,7 @@ class Connection(ConnectionBase):
)
exec_data = self._call_client(
lambda: self.client.post_json_to_json(
lambda client: client.post_json_to_json(
"/containers/{0}/exec", self.get_option("remote_addr"), data=data
)
)
@ -286,7 +306,7 @@ class Connection(ConnectionBase):
data = {"Tty": False, "Detach": False}
if need_stdin:
exec_socket = self._call_client(
lambda: self.client.post_json_to_stream_socket(
lambda client: client.post_json_to_stream_socket(
"/exec/{0}/start", exec_id, data=data
)
)
@ -295,6 +315,8 @@ class Connection(ConnectionBase):
display, exec_socket, container=self.get_option("remote_addr")
) as exec_socket_handler:
if do_become:
assert self.become is not None
become_output = [b""]
def append_become_output(stream_id, data):
@ -339,7 +361,7 @@ class Connection(ConnectionBase):
exec_socket.close()
else:
stdout, stderr = self._call_client(
lambda: self.client.post_json_to_stream(
lambda client: client.post_json_to_stream(
"/exec/{0}/start",
exec_id,
stream=False,
@ -350,12 +372,12 @@ class Connection(ConnectionBase):
)
result = self._call_client(
lambda: self.client.get_json("/exec/{0}/json", exec_id)
lambda client: client.get_json("/exec/{0}/json", exec_id)
)
return result.get("ExitCode") or 0, stdout or b"", stderr or b""
def _prefix_login_path(self, remote_path):
def _prefix_login_path(self, remote_path: str) -> str:
"""Make sure that we put files into a standard path
If a path is relative, then we need to choose where to put it.
@ -373,19 +395,23 @@ class Connection(ConnectionBase):
remote_path = os.path.join(os.path.sep, remote_path)
return os.path.normpath(remote_path)
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
"""Transfer a file from local to docker container"""
super().put_file(in_path, out_path)
super().put_file(in_path, out_path) # type: ignore[safe-super]
display.vvv(f"PUT {in_path} TO {out_path}", host=self.get_option("remote_addr"))
if self.client is None:
raise AssertionError("Client must be present")
out_path = self._prefix_login_path(out_path)
if self.actual_user not in self.ids:
dummy, ids, dummy = self.exec_command(b"id -u && id -g")
dummy, ids, dummy2 = self.exec_command("id -u && id -g")
remote_addr = self.get_option("remote_addr")
try:
user_id, group_id = ids.splitlines()
self.ids[self.actual_user] = int(user_id), int(group_id)
b_user_id, b_group_id = ids.splitlines()
user_id, group_id = int(b_user_id), int(b_group_id)
self.ids[self.actual_user] = user_id, group_id
display.vvvv(
f'PUT: Determined uid={user_id} and gid={group_id} for user "{self.actual_user}"',
host=remote_addr,
@ -398,8 +424,8 @@ class Connection(ConnectionBase):
user_id, group_id = self.ids[self.actual_user]
try:
self._call_client(
lambda: put_file(
self.client,
lambda client: put_file(
client,
container=self.get_option("remote_addr"),
in_path=in_path,
out_path=out_path,
@ -415,19 +441,22 @@ class Connection(ConnectionBase):
except DockerFileCopyError as exc:
raise AnsibleConnectionFailure(to_native(exc)) from exc
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
"""Fetch a file from container to local."""
super().fetch_file(in_path, out_path)
super().fetch_file(in_path, out_path) # type: ignore[safe-super]
display.vvv(
f"FETCH {in_path} TO {out_path}", host=self.get_option("remote_addr")
)
if self.client is None:
raise AssertionError("Client must be present")
in_path = self._prefix_login_path(in_path)
try:
self._call_client(
lambda: fetch_file(
self.client,
lambda client: fetch_file(
client,
container=self.get_option("remote_addr"),
in_path=in_path,
out_path=out_path,
@ -443,10 +472,10 @@ class Connection(ConnectionBase):
except DockerFileCopyError as exc:
raise AnsibleConnectionFailure(to_native(exc)) from exc
def close(self):
def close(self) -> None:
"""Terminate the connection. Nothing to do for Docker"""
super().close()
super().close() # type: ignore[safe-super]
self._connected = False
def reset(self):
def reset(self) -> None:
self.ids.clear()

View File

@ -44,7 +44,9 @@ import fcntl
import os
import pty
import selectors
import shlex
import subprocess
import typing as t
import ansible.constants as C
from ansible.errors import AnsibleError
@ -63,12 +65,12 @@ class Connection(ConnectionBase):
transport = "community.docker.nsenter"
has_pipelining = False
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.cwd = None
self._nsenter_pid = None
def _connect(self):
def _connect(self) -> t.Self:
self._nsenter_pid = self.get_option("nsenter_pid")
# Because nsenter requires very high privileges, our remote user
@ -83,12 +85,15 @@ class Connection(ConnectionBase):
self._connected = True
return self
def exec_command(self, cmd, in_data=None, sudoable=True):
super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
def exec_command(
self, cmd: str, in_data: bytes | None = None, sudoable: bool = True
) -> tuple[int, bytes, bytes]:
super().exec_command(cmd, in_data=in_data, sudoable=sudoable) # type: ignore[safe-super]
display.debug("in nsenter.exec_command()")
executable = C.DEFAULT_EXECUTABLE.split()[0] if C.DEFAULT_EXECUTABLE else None
def_executable: str | None = C.DEFAULT_EXECUTABLE # type: ignore[attr-defined]
executable = def_executable.split()[0] if def_executable else None
if not os.path.exists(to_bytes(executable, errors="surrogate_or_strict")):
raise AnsibleError(
@ -109,12 +114,8 @@ class Connection(ConnectionBase):
"--",
]
if isinstance(cmd, (str, bytes)):
cmd_parts = nsenter_cmd_parts + [cmd]
cmd = to_bytes(" ".join(cmd_parts))
else:
cmd_parts = nsenter_cmd_parts + cmd
cmd = [to_bytes(arg) for arg in cmd_parts]
cmd_parts = nsenter_cmd_parts + [cmd]
cmd = to_bytes(" ".join(cmd_parts))
display.vvv(f"EXEC {to_text(cmd)}", host=self._play_context.remote_addr)
display.debug("opening command with Popen()")
@ -143,6 +144,9 @@ class Connection(ConnectionBase):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as p:
assert p.stderr is not None
assert p.stdin is not None
assert p.stdout is not None
# if we created a master, we can close the other half of the pty now, otherwise master is stdin
if master is not None:
os.close(stdin)
@ -234,8 +238,8 @@ class Connection(ConnectionBase):
display.debug("done with nsenter.exec_command()")
return (p.returncode, stdout, stderr)
def put_file(self, in_path, out_path):
super().put_file(in_path, out_path)
def put_file(self, in_path: str, out_path: str) -> None:
super().put_file(in_path, out_path) # type: ignore[safe-super]
in_path = unfrackpath(in_path, basedir=self.cwd)
out_path = unfrackpath(out_path, basedir=self.cwd)
@ -245,26 +249,30 @@ class Connection(ConnectionBase):
with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as in_file:
in_data = in_file.read()
rc, dummy_out, err = self.exec_command(
cmd=["tee", out_path], in_data=in_data
cmd=f"tee {shlex.quote(out_path)}", in_data=in_data
)
if rc != 0:
raise AnsibleError(f"failed to transfer file to {out_path}: {err}")
raise AnsibleError(
f"failed to transfer file to {out_path}: {to_text(err)}"
)
except IOError as e:
raise AnsibleError(f"failed to transfer file to {out_path}: {e}") from e
def fetch_file(self, in_path, out_path):
super().fetch_file(in_path, out_path)
def fetch_file(self, in_path: str, out_path: str) -> None:
super().fetch_file(in_path, out_path) # type: ignore[safe-super]
in_path = unfrackpath(in_path, basedir=self.cwd)
out_path = unfrackpath(out_path, basedir=self.cwd)
try:
rc, out, err = self.exec_command(cmd=["cat", in_path])
rc, out, err = self.exec_command(cmd=f"cat {shlex.quote(in_path)}")
display.vvv(
f"FETCH {in_path} TO {out_path}", host=self._play_context.remote_addr
)
if rc != 0:
raise AnsibleError(f"failed to transfer file to {in_path}: {err}")
raise AnsibleError(
f"failed to transfer file to {in_path}: {to_text(err)}"
)
with open(
to_bytes(out_path, errors="surrogate_or_strict"), "wb"
) as out_file:
@ -274,6 +282,6 @@ class Connection(ConnectionBase):
f"failed to transfer file to {to_native(out_path)}: {e}"
) from e
def close(self):
def close(self) -> None:
"""terminate the connection; nothing to do here"""
self._connected = False

View File

@ -169,6 +169,7 @@ filters:
"""
import re
import typing as t
from ansible.errors import AnsibleError
from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
@ -195,6 +196,11 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
)
if t.TYPE_CHECKING:
from ansible.inventory.data import InventoryData
from ansible.parsing.dataloader import DataLoader
MIN_DOCKER_API = None
@ -203,11 +209,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
NAME = "community.docker.docker_containers"
def _slugify(self, value):
def _slugify(self, value: str) -> str:
slug = re.sub(r"[^\w-]", "_", value).lower().lstrip("_")
return f"docker_{slug}"
def _populate(self, client):
def _populate(self, client: AnsibleDockerClient) -> None:
strict = self.get_option("strict")
ssh_port = self.get_option("private_ssh_port")
@ -217,6 +223,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
connection_type = self.get_option("connection_type")
add_legacy_groups = self.get_option("add_legacy_groups")
if self.inventory is None:
raise AssertionError("Inventory must be there")
try:
params = {
"limit": -1,
@ -298,7 +307,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
# Lookup the public facing port Nat'ed to ssh port.
network_settings = inspect.get("NetworkSettings") or {}
port_settings = network_settings.get("Ports") or {}
port = port_settings.get(f"{ssh_port}/tcp")[0]
port = port_settings.get(f"{ssh_port}/tcp")[0] # type: ignore[index]
except (IndexError, AttributeError, TypeError):
port = {}
@ -387,16 +396,22 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
else:
self.inventory.add_host(name, group="stopped")
def verify_file(self, path):
def verify_file(self, path: str) -> bool:
"""Return the possibly of a file being consumable by this plugin."""
return super().verify_file(path) and path.endswith(
("docker.yaml", "docker.yml")
)
def _create_client(self):
def _create_client(self) -> AnsibleDockerClient:
return AnsibleDockerClient(self, min_docker_api_version=MIN_DOCKER_API)
def parse(self, inventory, loader, path, cache=True):
def parse(
self,
inventory: InventoryData,
loader: DataLoader,
path: str,
cache: bool = True,
) -> None:
super().parse(inventory, loader, path, cache)
self._read_config_data(path)
client = self._create_client()

View File

@ -101,6 +101,7 @@ compose:
import json
import re
import subprocess
import typing as t
from ansible.errors import AnsibleError
from ansible.module_utils.common.process import get_bin_path
@ -117,6 +118,15 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
)
if t.TYPE_CHECKING:
from ansible.inventory.data import InventoryData
from ansible.parsing.dataloader import DataLoader
DaemonEnv = t.Literal[
"require", "require-silently", "optional", "optional-silently", "skip"
]
display = Display()
@ -125,9 +135,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
NAME = "community.docker.docker_machine"
docker_machine_path = None
docker_machine_path: str | None = None
def _run_command(self, args):
def _run_command(self, args: list[str]) -> str:
if not self.docker_machine_path:
try:
self.docker_machine_path = get_bin_path("docker-machine")
@ -147,7 +157,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return to_text(result).strip()
def _get_docker_daemon_variables(self, machine_name):
def _get_docker_daemon_variables(self, machine_name: str) -> list[tuple[str, str]]:
"""
Capture settings from Docker Machine that would be needed to connect to the remote Docker daemon installed on
the Docker Machine remote host. Note: passing '--shell=sh' is a workaround for 'Error: Unknown shell'.
@ -180,7 +190,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return env_vars
def _get_machine_names(self):
def _get_machine_names(self) -> list[str]:
# Filter out machines that are not in the Running state as we probably cannot do anything useful actions
# with them.
ls_command = ["ls", "-q"]
@ -194,7 +204,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return ls_lines.splitlines()
def _inspect_docker_machine_host(self, node):
def _inspect_docker_machine_host(self, node: str) -> t.Any | None:
try:
inspect_lines = self._run_command(["inspect", node])
except subprocess.CalledProcessError:
@ -202,7 +212,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return json.loads(inspect_lines)
def _ip_addr_docker_machine_host(self, node):
def _ip_addr_docker_machine_host(self, node: str) -> t.Any | None:
try:
ip_addr = self._run_command(["ip", node])
except subprocess.CalledProcessError:
@ -210,7 +220,9 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return ip_addr
def _should_skip_host(self, machine_name, env_var_tuples, daemon_env):
def _should_skip_host(
self, machine_name: str, env_var_tuples, daemon_env: DaemonEnv
) -> bool:
if not env_var_tuples:
warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}"
if daemon_env in ("require", "require-silently"):
@ -224,8 +236,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
# daemon_env is 'optional-silently'
return False
def _populate(self):
daemon_env = self.get_option("daemon_env")
def _populate(self) -> None:
if self.inventory is None:
raise AssertionError("Inventory must be there")
daemon_env: DaemonEnv = self.get_option("daemon_env")
filters = parse_filters(self.get_option("filters"))
try:
for node in self._get_machine_names():
@ -325,13 +340,19 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
f"Unable to fetch hosts from Docker Machine, this was the original exception: {e}"
) from e
def verify_file(self, path):
def verify_file(self, path: str) -> bool:
"""Return the possibility of a file being consumable by this plugin."""
return super().verify_file(path) and path.endswith(
("docker_machine.yaml", "docker_machine.yml")
)
def parse(self, inventory, loader, path, cache=True):
def parse(
self,
inventory: InventoryData,
loader: DataLoader,
path: str,
cache: bool = True,
) -> None:
super().parse(inventory, loader, path, cache)
self._read_config_data(path)
self._populate()

View File

@ -148,6 +148,8 @@ keyed_groups:
prefix: label
"""
import typing as t
from ansible.errors import AnsibleError
from ansible.parsing.utils.addresses import parse_address
from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
@ -167,6 +169,11 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
)
if t.TYPE_CHECKING:
from ansible.inventory.data import InventoryData
from ansible.parsing.dataloader import DataLoader
try:
import docker
@ -180,10 +187,13 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
NAME = "community.docker.docker_swarm"
def _fail(self, msg):
def _fail(self, msg: str) -> t.NoReturn:
raise AnsibleError(msg)
def _populate(self):
def _populate(self) -> None:
if self.inventory is None:
raise AssertionError("Inventory must be there")
raw_params = {
"docker_host": self.get_option("docker_host"),
"tls": self.get_option("tls"),
@ -307,13 +317,19 @@ class InventoryModule(BaseInventoryPlugin, Constructable):
f"Unable to fetch hosts from Docker swarm API, this was the original exception: {e}"
) from e
def verify_file(self, path):
def verify_file(self, path: str) -> bool:
"""Return the possibly of a file being consumable by this plugin."""
return super().verify_file(path) and path.endswith(
("docker_swarm.yaml", "docker_swarm.yml")
)
def parse(self, inventory, loader, path, cache=True):
def parse(
self,
inventory: InventoryData,
loader: DataLoader,
path: str,
cache: bool = True,
) -> None:
if not HAS_DOCKER:
raise AnsibleError(
"The Docker swarm dynamic inventory plugin requires the Docker SDK for Python: "

View File

@ -12,8 +12,10 @@
from __future__ import annotations
import traceback
import typing as t
REQUESTS_IMPORT_ERROR: str | None # pylint: disable=invalid-name
try:
from requests import Session # noqa: F401, pylint: disable=unused-import
from requests.adapters import ( # noqa: F401, pylint: disable=unused-import
@ -26,28 +28,29 @@ try:
except ImportError:
REQUESTS_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name
class Session:
__attrs__ = []
class Session: # type: ignore
__attrs__: list[t.Never] = []
class HTTPAdapter:
__attrs__ = []
class HTTPAdapter: # type: ignore
__attrs__: list[t.Never] = []
class HTTPError(Exception):
class HTTPError(Exception): # type: ignore
pass
class InvalidSchema(Exception):
class InvalidSchema(Exception): # type: ignore
pass
else:
REQUESTS_IMPORT_ERROR = None # pylint: disable=invalid-name
URLLIB3_IMPORT_ERROR = None # pylint: disable=invalid-name
URLLIB3_IMPORT_ERROR: str | None = None # pylint: disable=invalid-name
try:
from requests.packages import urllib3 # pylint: disable=unused-import
# pylint: disable-next=unused-import
from requests.packages.urllib3 import connection as urllib3_connection
from requests.packages.urllib3 import ( # type: ignore # pylint: disable=unused-import # isort: skip
connection as urllib3_connection,
)
except ImportError:
try:
import urllib3 # pylint: disable=unused-import

View File

@ -13,8 +13,9 @@ from __future__ import annotations
import json
import logging
import os
import struct
from functools import partial
import typing as t
from urllib.parse import quote
from .. import auth
@ -47,16 +48,21 @@ from ..transport.sshconn import PARAMIKO_IMPORT_ERROR, SSHHTTPAdapter
from ..transport.ssladapter import SSLHTTPAdapter
from ..transport.unixconn import UnixHTTPAdapter
from ..utils import config, json_stream, utils
from ..utils.decorators import update_headers
from ..utils.decorators import minimum_version, update_headers
from ..utils.proxy import ProxyConfig
from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
from .daemon import DaemonApiMixin
if t.TYPE_CHECKING:
from requests import Response
from ..._socket_helper import SocketLike
log = logging.getLogger(__name__)
class APIClient(_Session, DaemonApiMixin):
class APIClient(_Session):
"""
A low-level client for the Docker Engine API.
@ -105,16 +111,16 @@ class APIClient(_Session, DaemonApiMixin):
def __init__(
self,
base_url=None,
version=None,
timeout=DEFAULT_TIMEOUT_SECONDS,
tls=False,
user_agent=DEFAULT_USER_AGENT,
num_pools=None,
credstore_env=None,
use_ssh_client=False,
max_pool_size=DEFAULT_MAX_POOL_SIZE,
):
base_url: str | None = None,
version: str | None = None,
timeout: int | float = DEFAULT_TIMEOUT_SECONDS,
tls: bool | TLSConfig = False,
user_agent: str = DEFAULT_USER_AGENT,
num_pools: int | None = None,
credstore_env: dict[str, str] | None = None,
use_ssh_client: bool = False,
max_pool_size: int = DEFAULT_MAX_POOL_SIZE,
) -> None:
super().__init__()
fail_on_missing_imports()
@ -124,7 +130,6 @@ class APIClient(_Session, DaemonApiMixin):
"If using TLS, the base_url argument must be provided."
)
self.base_url = base_url
self.timeout = timeout
self.headers["User-Agent"] = user_agent
@ -145,6 +150,7 @@ class APIClient(_Session, DaemonApiMixin):
self.credstore_env = credstore_env
base_url = utils.parse_host(base_url, IS_WINDOWS_PLATFORM, tls=bool(tls))
self.base_url = base_url
# SSH has a different default for num_pools to all other adapters
num_pools = (
num_pools or DEFAULT_NUM_POOLS_SSH
@ -152,6 +158,9 @@ class APIClient(_Session, DaemonApiMixin):
else DEFAULT_NUM_POOLS
)
self._custom_adapter: (
UnixHTTPAdapter | NpipeHTTPAdapter | SSHHTTPAdapter | SSLHTTPAdapter | None
) = None
if base_url.startswith("http+unix://"):
self._custom_adapter = UnixHTTPAdapter(
base_url,
@ -223,7 +232,7 @@ class APIClient(_Session, DaemonApiMixin):
f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library."
)
def _retrieve_server_version(self):
def _retrieve_server_version(self) -> str:
try:
version_result = self.version(api_version=False)
except Exception as e:
@ -242,54 +251,87 @@ class APIClient(_Session, DaemonApiMixin):
f"Error while fetching server API version: {e}. Response seems to be broken."
) from e
def _set_request_timeout(self, kwargs):
def _set_request_timeout(self, kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
"""Prepare the kwargs for an HTTP request by inserting the timeout
parameter, if not already present."""
kwargs.setdefault("timeout", self.timeout)
return kwargs
@update_headers
def _post(self, url, **kwargs):
def _post(self, url: str, **kwargs):
return self.post(url, **self._set_request_timeout(kwargs))
@update_headers
def _get(self, url, **kwargs):
def _get(self, url: str, **kwargs):
return self.get(url, **self._set_request_timeout(kwargs))
@update_headers
def _head(self, url, **kwargs):
def _head(self, url: str, **kwargs):
return self.head(url, **self._set_request_timeout(kwargs))
@update_headers
def _put(self, url, **kwargs):
def _put(self, url: str, **kwargs):
return self.put(url, **self._set_request_timeout(kwargs))
@update_headers
def _delete(self, url, **kwargs):
def _delete(self, url: str, **kwargs):
return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt, *args, **kwargs):
def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
for arg in args:
if not isinstance(arg, str):
raise ValueError(
f"Expected a string but found {arg} ({type(arg)}) instead"
)
quote_f = partial(quote, safe="/:")
args = map(quote_f, args)
q_args = [quote(arg, safe="/:") for arg in args]
if kwargs.get("versioned_api", True):
return f"{self.base_url}/v{self._version}{pathfmt.format(*args)}"
return f"{self.base_url}{pathfmt.format(*args)}"
if versioned_api:
return f"{self.base_url}/v{self._version}{pathfmt.format(*q_args)}"
return f"{self.base_url}{pathfmt.format(*q_args)}"
def _raise_for_status(self, response):
def _raise_for_status(self, response: Response) -> None:
"""Raises stored :class:`APIError`, if one occurred."""
try:
response.raise_for_status()
except _HTTPError as e:
create_api_error_from_http_exception(e)
def _result(self, response, get_json=False, get_binary=False):
@t.overload
def _result(
self,
response: Response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[False] = False,
) -> str: ...
@t.overload
def _result(
self,
response: Response,
*,
get_json: t.Literal[True],
get_binary: t.Literal[False] = False,
) -> t.Any: ...
@t.overload
def _result(
self,
response: Response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[True],
) -> bytes: ...
@t.overload
def _result(
self, response: Response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes: ...
def _result(
self, response: Response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes:
if get_json and get_binary:
raise AssertionError("json and binary must not be both True")
self._raise_for_status(response)
@ -300,10 +342,12 @@ class APIClient(_Session, DaemonApiMixin):
return response.content
return response.text
def _post_json(self, url, data, **kwargs):
def _post_json(
self, url: str, data: dict[str, str | None] | t.Any, **kwargs
) -> Response:
# Go <1.1 cannot unserialize null to a string
# so we do this disgusting thing here.
data2 = {}
data2: dict[str, t.Any] = {}
if data is not None and isinstance(data, dict):
for k, v in data.items():
if v is not None:
@ -316,19 +360,19 @@ class APIClient(_Session, DaemonApiMixin):
kwargs["headers"]["Content-Type"] = "application/json"
return self._post(url, data=json.dumps(data2), **kwargs)
def _attach_params(self, override=None):
def _attach_params(self, override: dict[str, int] | None = None) -> dict[str, int]:
return override or {"stdout": 1, "stderr": 1, "stream": 1}
def _get_raw_response_socket(self, response):
def _get_raw_response_socket(self, response: Response) -> SocketLike:
self._raise_for_status(response)
if self.base_url == "http+docker://localnpipe":
sock = response.raw._fp.fp.raw.sock
sock = response.raw._fp.fp.raw.sock # type: ignore[union-attr]
elif self.base_url.startswith("http+docker://ssh"):
sock = response.raw._fp.fp.channel
sock = response.raw._fp.fp.channel # type: ignore[union-attr]
else:
sock = response.raw._fp.fp.raw
sock = response.raw._fp.fp.raw # type: ignore[union-attr]
if self.base_url.startswith("https://"):
sock = sock._sock
sock = sock._sock # type: ignore[union-attr]
try:
# Keep a reference to the response to stop it being garbage
# collected. If the response is garbage collected, it will
@ -341,12 +385,26 @@ class APIClient(_Session, DaemonApiMixin):
return sock
def _stream_helper(self, response, decode=False):
@t.overload
def _stream_helper(
self, response: Response, *, decode: t.Literal[False] = False
) -> t.Generator[bytes]: ...
@t.overload
def _stream_helper(
self, response: Response, *, decode: t.Literal[True]
) -> t.Generator[t.Any]: ...
def _stream_helper(
self, response: Response, *, decode: bool = False
) -> t.Generator[t.Any]:
"""Generator for data coming from a chunked-encoded HTTP response."""
if response.raw._fp.chunked:
if response.raw._fp.chunked: # type: ignore[union-attr]
if decode:
yield from json_stream.json_stream(self._stream_helper(response, False))
yield from json_stream.json_stream(
self._stream_helper(response, decode=False)
)
else:
reader = response.raw
while not reader.closed:
@ -354,15 +412,15 @@ class APIClient(_Session, DaemonApiMixin):
data = reader.read(1)
if not data:
break
if reader._fp.chunk_left:
data += reader.read(reader._fp.chunk_left)
if reader._fp.chunk_left: # type: ignore[union-attr]
data += reader.read(reader._fp.chunk_left) # type: ignore[union-attr]
yield data
else:
# Response is not chunked, meaning we probably
# encountered an error immediately
yield self._result(response, get_json=decode)
def _multiplexed_buffer_helper(self, response):
def _multiplexed_buffer_helper(self, response: Response) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks read from a buffered
response."""
buf = self._result(response, get_binary=True)
@ -378,7 +436,9 @@ class APIClient(_Session, DaemonApiMixin):
walker = end
yield buf[start:end]
def _multiplexed_response_stream_helper(self, response):
def _multiplexed_response_stream_helper(
self, response: Response
) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks coming from a response
stream."""
@ -399,7 +459,19 @@ class APIClient(_Session, DaemonApiMixin):
break
yield data
def _stream_raw_result(self, response, chunk_size=1, decode=True):
@t.overload
def _stream_raw_result(
self, response: Response, *, chunk_size: int = 1, decode: t.Literal[True] = True
) -> t.Generator[str]: ...
@t.overload
def _stream_raw_result(
self, response: Response, *, chunk_size: int = 1, decode: t.Literal[False]
) -> t.Generator[bytes]: ...
def _stream_raw_result(
self, response: Response, *, chunk_size: int = 1, decode: bool = True
) -> t.Generator[str | bytes]:
"""Stream result for TTY-enabled container and raw binary data"""
self._raise_for_status(response)
@ -410,14 +482,81 @@ class APIClient(_Session, DaemonApiMixin):
yield from response.iter_content(chunk_size, decode)
def _read_from_socket(self, response, stream, tty=True, demux=False):
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
) -> t.Generator[bytes]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
) -> bytes: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> tuple[bytes, None]: ...
@t.overload
def _read_from_socket(
self,
response: Response,
*,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
) -> tuple[bytes, bytes]: ...
@t.overload
def _read_from_socket(
self, response: Response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any: ...
def _read_from_socket(
self, response: Response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any:
"""Consume all data from the socket, close the response and return the
data. If stream=True, then a generator is returned instead and the
caller is responsible for closing the response.
"""
socket = self._get_raw_response_socket(response)
gen = frames_iter(socket, tty)
gen: t.Generator = frames_iter(socket, tty)
if demux:
# The generator will output tuples (stdout, stderr)
@ -434,7 +573,7 @@ class APIClient(_Session, DaemonApiMixin):
finally:
response.close()
def _disable_socket_timeout(self, socket):
def _disable_socket_timeout(self, socket: SocketLike) -> None:
"""Depending on the combination of python version and whether we are
connecting over http or https, we might need to access _sock, which
may or may not exist; or we may need to just settimeout on socket
@ -451,18 +590,38 @@ class APIClient(_Session, DaemonApiMixin):
if not hasattr(s, "settimeout"):
continue
timeout = -1
timeout: int | float | None = -1
if hasattr(s, "gettimeout"):
timeout = s.gettimeout()
timeout = s.gettimeout() # type: ignore[union-attr]
# Do not change the timeout if it is already disabled.
if timeout is None or timeout == 0.0:
continue
s.settimeout(None)
s.settimeout(None) # type: ignore[union-attr]
def _get_result_tty(self, stream, res, is_tty):
@t.overload
def _get_result_tty(
self, stream: t.Literal[True], res: Response, is_tty: t.Literal[True]
) -> t.Generator[str]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[True], res: Response, is_tty: t.Literal[False]
) -> t.Generator[bytes]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res: Response, is_tty: t.Literal[True]
) -> bytes: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res: Response, is_tty: t.Literal[False]
) -> bytes: ...
def _get_result_tty(self, stream: bool, res: Response, is_tty: bool) -> t.Any:
# We should also use raw streaming (without keep-alive)
# if we are dealing with a tty-enabled container.
if is_tty:
@ -478,11 +637,11 @@ class APIClient(_Session, DaemonApiMixin):
return self._multiplexed_response_stream_helper(res)
return sep.join(list(self._multiplexed_buffer_helper(res)))
def _unmount(self, *args):
def _unmount(self, *args) -> None:
for proto in args:
self.adapters.pop(proto)
def get_adapter(self, url):
def get_adapter(self, url: str):
try:
return super().get_adapter(url)
except _InvalidSchema as e:
@ -491,10 +650,10 @@ class APIClient(_Session, DaemonApiMixin):
raise e
@property
def api_version(self):
def api_version(self) -> str:
return self._version
def reload_config(self, dockercfg_path=None):
def reload_config(self, dockercfg_path: str | None = None) -> None:
"""
Force a reload of the auth configuration
@ -510,7 +669,7 @@ class APIClient(_Session, DaemonApiMixin):
dockercfg_path, credstore_env=self.credstore_env
)
def _set_auth_headers(self, headers):
def _set_auth_headers(self, headers: dict[str, str | bytes]) -> None:
log.debug("Looking for auth config")
# If we do not have any auth data so far, try reloading the config
@ -537,57 +696,62 @@ class APIClient(_Session, DaemonApiMixin):
else:
log.debug("No auth config found")
def get_binary(self, pathfmt, *args, **kwargs):
def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_binary=True,
)
def get_json(self, pathfmt, *args, **kwargs):
def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
def get_text(self, pathfmt, *args, **kwargs):
def get_text(self, pathfmt: str, *args: str, **kwargs) -> str:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def get_raw_stream(self, pathfmt, *args, **kwargs):
chunk_size = kwargs.pop("chunk_size", DEFAULT_DATA_CHUNK_SIZE)
def get_raw_stream(
self,
pathfmt: str,
*args: str,
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
**kwargs,
) -> t.Generator[bytes]:
res = self._get(
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
)
self._raise_for_status(res)
return self._stream_raw_result(res, chunk_size, False)
return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
def delete_call(self, pathfmt, *args, **kwargs):
def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def delete_json(self, pathfmt, *args, **kwargs):
def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
def post_call(self, pathfmt, *args, **kwargs):
def post_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def post_json(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None:
self._raise_for_status(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
)
)
def post_json_to_binary(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json_to_binary(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> bytes:
return self._result(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -595,8 +759,9 @@ class APIClient(_Session, DaemonApiMixin):
get_binary=True,
)
def post_json_to_json(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json_to_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> t.Any:
return self._result(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -604,17 +769,24 @@ class APIClient(_Session, DaemonApiMixin):
get_json=True,
)
def post_json_to_text(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json_to_text(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> str:
return self._result(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
),
)
def post_json_to_stream_socket(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
headers = (kwargs.pop("headers", None) or {}).copy()
def post_json_to_stream_socket(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
**kwargs,
) -> SocketLike:
headers = headers.copy() if headers else {}
headers.update(
{
"Connection": "Upgrade",
@ -631,18 +803,102 @@ class APIClient(_Session, DaemonApiMixin):
)
)
def post_json_to_stream(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
headers = (kwargs.pop("headers", None) or {}).copy()
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> t.Generator[bytes]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> bytes: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, None]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, bytes]: ...
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: bool = False,
demux: bool = False,
tty: bool = False,
**kwargs,
) -> t.Any:
headers = headers.copy() if headers else {}
headers.update(
{
"Connection": "Upgrade",
"Upgrade": "tcp",
}
)
stream = kwargs.pop("stream", False)
demux = kwargs.pop("demux", False)
tty = kwargs.pop("tty", False)
return self._read_from_socket(
self._post_json(
self._url(pathfmt, *args, versioned_api=True),
@ -651,13 +907,133 @@ class APIClient(_Session, DaemonApiMixin):
stream=True,
**kwargs,
),
stream,
stream=stream,
tty=tty,
demux=demux,
)
def post_to_json(self, pathfmt, *args, **kwargs):
def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
@minimum_version("1.25")
def df(self) -> dict[str, t.Any]:
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self) -> dict[str, t.Any]:
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username: str,
password: str | None = None,
email: str | None = None,
registry: str | None = None,
reauth: bool = False,
dockercfg_path: str | None = None,
) -> dict[str, t.Any]:
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self) -> bool:
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version: bool = True) -> dict[str, t.Any]:
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -1,139 +0,0 @@
# This code is part of the Ansible collection community.docker, but is an independent component.
# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
#
# Copyright (c) 2016-2022 Docker, Inc.
#
# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
# SPDX-License-Identifier: Apache-2.0
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
from .. import auth
from ..utils.decorators import minimum_version
class DaemonApiMixin:
@minimum_version("1.25")
def df(self):
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self):
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username,
password=None,
email=None,
registry=None,
reauth=False,
dockercfg_path=None,
):
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self):
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version=True):
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import base64
import json
import logging
import typing as t
from . import errors
from .credentials.errors import CredentialsNotFound, StoreError
@ -21,6 +22,12 @@ from .credentials.store import Store
from .utils import config
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
INDEX_NAME = "docker.io"
INDEX_URL = f"https://index.{INDEX_NAME}/v1/"
TOKEN_USERNAME = "<token>"
@ -28,7 +35,7 @@ TOKEN_USERNAME = "<token>"
log = logging.getLogger(__name__)
def resolve_repository_name(repo_name):
def resolve_repository_name(repo_name: str) -> tuple[str, str]:
if "://" in repo_name:
raise errors.InvalidRepository(
f"Repository name cannot contain a scheme ({repo_name})"
@ -42,14 +49,14 @@ def resolve_repository_name(repo_name):
return resolve_index_name(index_name), remote_name
def resolve_index_name(index_name):
def resolve_index_name(index_name: str) -> str:
index_name = convert_to_hostname(index_name)
if index_name == "index." + INDEX_NAME:
index_name = INDEX_NAME
return index_name
def get_config_header(client, registry):
def get_config_header(client: APIClient, registry: str) -> bytes | None:
log.debug("Looking for auth config")
if not client._auth_configs or client._auth_configs.is_empty:
log.debug("No auth config in memory - loading from filesystem")
@ -69,32 +76,38 @@ def get_config_header(client, registry):
return None
def split_repo_name(repo_name):
def split_repo_name(repo_name: str) -> tuple[str, str]:
parts = repo_name.split("/", 1)
if len(parts) == 1 or (
"." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost"
):
# This is a docker index repo (ex: username/foobar or ubuntu)
return INDEX_NAME, repo_name
return tuple(parts)
return tuple(parts) # type: ignore
def get_credential_store(authconfig, registry):
def get_credential_store(
authconfig: dict[str, t.Any] | AuthConfig, registry: str
) -> str | None:
if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig)
return authconfig.get_credential_store(registry)
class AuthConfig(dict):
def __init__(self, dct, credstore_env=None):
def __init__(
self, dct: dict[str, t.Any], credstore_env: dict[str, str] | None = None
):
if "auths" not in dct:
dct["auths"] = {}
self.update(dct)
self._credstore_env = credstore_env
self._stores = {}
self._stores: dict[str, Store] = {}
@classmethod
def parse_auth(cls, entries, raise_on_error=False):
def parse_auth(
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False
) -> dict[str, dict[str, t.Any]]:
"""
Parses authentication entries
@ -107,10 +120,10 @@ class AuthConfig(dict):
Authentication registry.
"""
conf = {}
conf: dict[str, dict[str, t.Any]] = {}
for registry, entry in entries.items():
if not isinstance(entry, dict):
log.debug("Config entry for key %s is not auth config", registry)
log.debug("Config entry for key %s is not auth config", registry) # type: ignore
# We sometimes fall back to parsing the whole config as if it
# was the auth config by itself, for legacy purposes. In that
# case, we fail silently and return an empty conf if any of the
@ -150,7 +163,12 @@ class AuthConfig(dict):
return conf
@classmethod
def load_config(cls, config_path, config_dict, credstore_env=None):
def load_config(
cls,
config_path: str | None,
config_dict: dict[str, t.Any] | None,
credstore_env: dict[str, str] | None = None,
) -> t.Self:
"""
Loads authentication data from a Docker configuration file in the given
root directory or if config_path is passed use given path.
@ -196,22 +214,24 @@ class AuthConfig(dict):
return cls({"auths": cls.parse_auth(config_dict)}, credstore_env)
@property
def auths(self):
def auths(self) -> dict[str, dict[str, t.Any]]:
return self.get("auths", {})
@property
def creds_store(self):
def creds_store(self) -> str | None:
return self.get("credsStore", None)
@property
def cred_helpers(self):
def cred_helpers(self) -> dict[str, t.Any]:
return self.get("credHelpers", {})
@property
def is_empty(self):
def is_empty(self) -> bool:
return not self.auths and not self.creds_store and not self.cred_helpers
def resolve_authconfig(self, registry=None):
def resolve_authconfig(
self, registry: str | None = None
) -> dict[str, t.Any] | None:
"""
Returns the authentication data from the given auth configuration for a
specific registry. As with the Docker client, legacy entries in the
@ -244,7 +264,9 @@ class AuthConfig(dict):
log.debug("No entry found")
return None
def _resolve_authconfig_credstore(self, registry, credstore_name):
def _resolve_authconfig_credstore(
self, registry: str | None, credstore_name: str
) -> dict[str, t.Any] | None:
if not registry or registry == INDEX_NAME:
# The ecosystem is a little schizophrenic with index.docker.io VS
# docker.io - in that case, it seems the full URL is necessary.
@ -272,19 +294,19 @@ class AuthConfig(dict):
except StoreError as e:
raise errors.DockerException(f"Credentials store error: {e}")
def _get_store_instance(self, name):
def _get_store_instance(self, name: str):
if name not in self._stores:
self._stores[name] = Store(name, environment=self._credstore_env)
return self._stores[name]
def get_credential_store(self, registry):
def get_credential_store(self, registry: str | None) -> str | None:
if not registry or registry == INDEX_NAME:
registry = INDEX_URL
return self.cred_helpers.get(registry) or self.creds_store
def get_all_credentials(self):
auth_data = self.auths.copy()
def get_all_credentials(self) -> dict[str, dict[str, t.Any] | None]:
auth_data: dict[str, dict[str, t.Any] | None] = self.auths.copy() # type: ignore
if self.creds_store:
# Retrieve all credentials from the default store
store = self._get_store_instance(self.creds_store)
@ -299,21 +321,23 @@ class AuthConfig(dict):
return auth_data
def add_auth(self, reg, data):
def add_auth(self, reg: str, data: dict[str, t.Any]) -> None:
self["auths"][reg] = data
def resolve_authconfig(authconfig, registry=None, credstore_env=None):
def resolve_authconfig(
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None
):
if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig, credstore_env)
return authconfig.resolve_authconfig(registry)
def convert_to_hostname(url):
def convert_to_hostname(url: str) -> str:
return url.replace("http://", "").replace("https://", "").split("/", 1)[0]
def decode_auth(auth):
def decode_auth(auth: str | bytes) -> tuple[str, str]:
if isinstance(auth, str):
auth = auth.encode("ascii")
s = base64.b64decode(auth)
@ -321,12 +345,14 @@ def decode_auth(auth):
return login.decode("utf8"), pwd.decode("utf8")
def encode_header(auth):
def encode_header(auth: dict[str, t.Any]) -> bytes:
auth_json = json.dumps(auth).encode("ascii")
return base64.urlsafe_b64encode(auth_json)
def parse_auth(entries, raise_on_error=False):
def parse_auth(
entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
) -> dict[str, dict[str, t.Any]]:
"""
Parses authentication entries
@ -342,11 +368,15 @@ def parse_auth(entries, raise_on_error=False):
return AuthConfig.parse_auth(entries, raise_on_error)
def load_config(config_path=None, config_dict=None, credstore_env=None):
def load_config(
config_path: str | None = None,
config_dict: dict[str, t.Any] | None = None,
credstore_env: dict[str, str] | None = None,
) -> AuthConfig:
return AuthConfig.load_config(config_path, config_dict, credstore_env)
def _load_legacy_config(config_file):
def _load_legacy_config(config_file: str) -> dict[str, dict[str, t.Any]]:
log.debug("Attempting to parse legacy auth file format")
try:
data = []

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json
import os
import typing as t
from .. import errors
from .config import (
@ -24,7 +25,11 @@ from .config import (
from .context import Context
def create_default_context():
if t.TYPE_CHECKING:
from ..tls import TLSConfig
def create_default_context() -> Context:
host = None
if os.environ.get("DOCKER_HOST"):
host = os.environ.get("DOCKER_HOST")
@ -42,7 +47,7 @@ class ContextAPI:
DEFAULT_CONTEXT = None
@classmethod
def get_default_context(cls):
def get_default_context(cls) -> Context:
context = cls.DEFAULT_CONTEXT
if context is None:
context = create_default_context()
@ -52,13 +57,13 @@ class ContextAPI:
@classmethod
def create_context(
cls,
name,
orchestrator=None,
host=None,
tls_cfg=None,
default_namespace=None,
skip_tls_verify=False,
):
name: str,
orchestrator: str | None = None,
host: str | None = None,
tls_cfg: TLSConfig | None = None,
default_namespace: str | None = None,
skip_tls_verify: bool = False,
) -> Context:
"""Creates a new context.
Returns:
(Context): a Context object.
@ -108,7 +113,7 @@ class ContextAPI:
return ctx
@classmethod
def get_context(cls, name=None):
def get_context(cls, name: str | None = None) -> Context | None:
"""Retrieves a context object.
Args:
name (str): The name of the context
@ -136,7 +141,7 @@ class ContextAPI:
return Context.load_context(name)
@classmethod
def contexts(cls):
def contexts(cls) -> list[Context]:
"""Context list.
Returns:
(Context): List of context objects.
@ -170,7 +175,7 @@ class ContextAPI:
return contexts
@classmethod
def get_current_context(cls):
def get_current_context(cls) -> Context | None:
"""Get current context.
Returns:
(Context): current context object.
@ -178,7 +183,7 @@ class ContextAPI:
return cls.get_context()
@classmethod
def set_current_context(cls, name="default"):
def set_current_context(cls, name: str = "default") -> None:
ctx = cls.get_context(name)
if not ctx:
raise errors.ContextNotFound(name)
@ -188,7 +193,7 @@ class ContextAPI:
raise errors.ContextException(f"Failed to set current context: {err}")
@classmethod
def remove_context(cls, name):
def remove_context(cls, name: str) -> None:
"""Remove a context. Similar to the ``docker context rm`` command.
Args:
@ -220,7 +225,7 @@ class ContextAPI:
ctx.remove()
@classmethod
def inspect_context(cls, name="default"):
def inspect_context(cls, name: str = "default") -> dict[str, t.Any]:
"""Inspect a context. Similar to the ``docker context inspect`` command.
Args:

View File

@ -23,7 +23,7 @@ from ..utils.utils import parse_host
METAFILE = "meta.json"
def get_current_context_name_with_source():
def get_current_context_name_with_source() -> tuple[str, str]:
if os.environ.get("DOCKER_HOST"):
return "default", "DOCKER_HOST environment variable set"
if os.environ.get("DOCKER_CONTEXT"):
@ -41,11 +41,11 @@ def get_current_context_name_with_source():
return "default", "fallback value"
def get_current_context_name():
def get_current_context_name() -> str:
return get_current_context_name_with_source()[0]
def write_context_name_to_docker_config(name=None):
def write_context_name_to_docker_config(name: str | None = None) -> Exception | None:
if name == "default":
name = None
docker_cfg_path = find_config_file()
@ -62,44 +62,45 @@ def write_context_name_to_docker_config(name=None):
elif name:
config["currentContext"] = name
else:
return
return None
if not docker_cfg_path:
docker_cfg_path = get_default_config_file()
try:
with open(docker_cfg_path, "wt", encoding="utf-8") as f:
json.dump(config, f, indent=4)
return None
except Exception as e: # pylint: disable=broad-exception-caught
return e
def get_context_id(name):
def get_context_id(name: str) -> str:
return hashlib.sha256(name.encode("utf-8")).hexdigest()
def get_context_dir():
def get_context_dir() -> str:
docker_cfg_path = find_config_file() or get_default_config_file()
return os.path.join(os.path.dirname(docker_cfg_path), "contexts")
def get_meta_dir(name=None):
def get_meta_dir(name: str | None = None) -> str:
meta_dir = os.path.join(get_context_dir(), "meta")
if name:
return os.path.join(meta_dir, get_context_id(name))
return meta_dir
def get_meta_file(name):
def get_meta_file(name) -> str:
return os.path.join(get_meta_dir(name), METAFILE)
def get_tls_dir(name=None, endpoint=""):
def get_tls_dir(name: str | None = None, endpoint: str = "") -> str:
context_dir = get_context_dir()
if name:
return os.path.join(context_dir, "tls", get_context_id(name), endpoint)
return os.path.join(context_dir, "tls")
def get_context_host(path=None, tls=False):
def get_context_host(path: str | None = None, tls: bool = False) -> str:
host = parse_host(path, IS_WINDOWS_PLATFORM, tls)
if host == DEFAULT_UNIX_SOCKET:
# remove http+ from default docker socket url

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json
import os
import typing as t
from shutil import copyfile, rmtree
from ..errors import ContextException
@ -33,21 +34,21 @@ class Context:
def __init__(
self,
name,
orchestrator=None,
host=None,
endpoints=None,
skip_tls_verify=False,
tls=False,
description=None,
):
name: str,
orchestrator: str | None = None,
host: str | None = None,
endpoints: dict[str, dict[str, t.Any]] | None = None,
skip_tls_verify: bool = False,
tls: bool = False,
description: str | None = None,
) -> None:
if not name:
raise ValueError("Name not provided")
self.name = name
self.context_type = None
self.orchestrator = orchestrator
self.endpoints = {}
self.tls_cfg = {}
self.tls_cfg: dict[str, TLSConfig] = {}
self.meta_path = IN_MEMORY
self.tls_path = IN_MEMORY
self.description = description
@ -89,12 +90,12 @@ class Context:
def set_endpoint(
self,
name="docker",
host=None,
tls_cfg=None,
skip_tls_verify=False,
def_namespace=None,
):
name: str = "docker",
host: str | None = None,
tls_cfg: TLSConfig | None = None,
skip_tls_verify: bool = False,
def_namespace: str | None = None,
) -> None:
self.endpoints[name] = {
"Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None),
"SkipTLSVerify": skip_tls_verify,
@ -105,11 +106,11 @@ class Context:
if tls_cfg:
self.tls_cfg[name] = tls_cfg
def inspect(self):
def inspect(self) -> dict[str, t.Any]:
return self()
@classmethod
def load_context(cls, name):
def load_context(cls, name: str) -> t.Self | None:
meta = Context._load_meta(name)
if meta:
instance = cls(
@ -125,12 +126,12 @@ class Context:
return None
@classmethod
def _load_meta(cls, name):
def _load_meta(cls, name: str) -> dict[str, t.Any] | None:
meta_file = get_meta_file(name)
if not os.path.isfile(meta_file):
return None
metadata = {}
metadata: dict[str, t.Any] = {}
try:
with open(meta_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
@ -154,7 +155,7 @@ class Context:
return metadata
def _load_certs(self):
def _load_certs(self) -> None:
certs = {}
tls_dir = get_tls_dir(self.name)
for endpoint in self.endpoints:
@ -184,7 +185,7 @@ class Context:
self.tls_cfg = certs
self.tls_path = tls_dir
def save(self):
def save(self) -> None:
meta_dir = get_meta_dir(self.name)
if not os.path.isdir(meta_dir):
os.makedirs(meta_dir)
@ -216,54 +217,54 @@ class Context:
self.meta_path = get_meta_dir(self.name)
self.tls_path = get_tls_dir(self.name)
def remove(self):
def remove(self) -> None:
if os.path.isdir(self.meta_path):
rmtree(self.meta_path)
if os.path.isdir(self.tls_path):
rmtree(self.tls_path)
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: '{self.name}'>"
def __str__(self):
def __str__(self) -> str:
return json.dumps(self.__call__(), indent=2)
def __call__(self):
def __call__(self) -> dict[str, t.Any]:
result = self.Metadata
result.update(self.TLSMaterial)
result.update(self.Storage)
return result
def is_docker_host(self):
def is_docker_host(self) -> bool:
return self.context_type is None
@property
def Name(self): # pylint: disable=invalid-name
def Name(self) -> str: # pylint: disable=invalid-name
return self.name
@property
def Host(self): # pylint: disable=invalid-name
def Host(self) -> str | None: # pylint: disable=invalid-name
if not self.orchestrator or self.orchestrator == "swarm":
endpoint = self.endpoints.get("docker", None)
if endpoint:
return endpoint.get("Host", None)
return endpoint.get("Host", None) # type: ignore
return None
return self.endpoints[self.orchestrator].get("Host", None)
return self.endpoints[self.orchestrator].get("Host", None) # type: ignore
@property
def Orchestrator(self): # pylint: disable=invalid-name
def Orchestrator(self) -> str | None: # pylint: disable=invalid-name
return self.orchestrator
@property
def Metadata(self): # pylint: disable=invalid-name
meta = {}
def Metadata(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
meta: dict[str, t.Any] = {}
if self.orchestrator:
meta = {"StackOrchestrator": self.orchestrator}
return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints}
@property
def TLSConfig(self): # pylint: disable=invalid-name
def TLSConfig(self) -> TLSConfig | None: # pylint: disable=invalid-name
key = self.orchestrator
if not key or key == "swarm":
key = "docker"
@ -272,13 +273,15 @@ class Context:
return None
@property
def TLSMaterial(self): # pylint: disable=invalid-name
certs = {}
def TLSMaterial(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
certs: dict[str, t.Any] = {}
for endpoint, tls in self.tls_cfg.items():
cert, key = tls.cert
certs[endpoint] = list(map(os.path.basename, [tls.ca_cert, cert, key]))
paths = [tls.ca_cert, *tls.cert] if tls.cert else [tls.ca_cert]
certs[endpoint] = [
os.path.basename(path) if path else None for path in paths
]
return {"TLSMaterial": certs}
@property
def Storage(self): # pylint: disable=invalid-name
def Storage(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}}

View File

@ -11,6 +11,12 @@
from __future__ import annotations
import typing as t
if t.TYPE_CHECKING:
from subprocess import CalledProcessError
class StoreError(RuntimeError):
pass
@ -24,7 +30,7 @@ class InitializationError(StoreError):
pass
def process_store_error(cpe, program):
def process_store_error(cpe: CalledProcessError, program: str) -> StoreError:
message = cpe.output.decode("utf-8")
if "credentials not found in native keychain" in message:
return CredentialsNotFound(f"No matching credentials in {program}")

View File

@ -14,13 +14,14 @@ from __future__ import annotations
import errno
import json
import subprocess
import typing as t
from . import constants, errors
from .utils import create_environment_dict, find_executable
class Store:
def __init__(self, program, environment=None):
def __init__(self, program: str, environment: dict[str, str] | None = None) -> None:
"""Create a store object that acts as an interface to
perform the basic operations for storing, retrieving
and erasing credentials using `program`.
@ -33,7 +34,7 @@ class Store:
f"{self.program} not installed or not available in PATH"
)
def get(self, server):
def get(self, server: str | bytes) -> dict[str, t.Any]:
"""Retrieve credentials for `server`. If no credentials are found,
a `StoreError` will be raised.
"""
@ -53,7 +54,7 @@ class Store:
return result
def store(self, server, username, secret):
def store(self, server: str, username: str, secret: str) -> bytes:
"""Store credentials for `server`. Raises a `StoreError` if an error
occurs.
"""
@ -62,7 +63,7 @@ class Store:
).encode("utf-8")
return self._execute("store", data_input)
def erase(self, server):
def erase(self, server: str | bytes) -> None:
"""Erase credentials for `server`. Raises a `StoreError` if an error
occurs.
"""
@ -70,12 +71,16 @@ class Store:
server = server.encode("utf-8")
self._execute("erase", server)
def list(self):
def list(self) -> t.Any:
"""List stored credentials. Requires v0.4.0+ of the helper."""
data = self._execute("list", None)
return json.loads(data.decode("utf-8"))
def _execute(self, subcmd, data_input):
def _execute(self, subcmd: str, data_input: bytes | None) -> bytes:
if self.exe is None:
raise errors.StoreError(
f"{self.program} not installed or not available in PATH"
)
output = None
env = create_environment_dict(self.environment)
try:

View File

@ -15,7 +15,7 @@ import os
from shutil import which
def find_executable(executable, path=None):
def find_executable(executable: str, path: str | None = None) -> str | None:
"""
As distutils.spawn.find_executable, but on Windows, look up
every extension declared in PATHEXT instead of just `.exe`
@ -26,7 +26,7 @@ def find_executable(executable, path=None):
return which(executable, path=path)
def create_environment_dict(overrides):
def create_environment_dict(overrides: dict[str, str] | None) -> dict[str, str]:
"""
Create and return a copy of os.environ with the specified overrides
"""

View File

@ -11,6 +11,8 @@
from __future__ import annotations
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ._import_helper import HTTPError as _HTTPError
@ -25,7 +27,7 @@ class DockerException(Exception):
"""
def create_api_error_from_http_exception(e):
def create_api_error_from_http_exception(e: _HTTPError) -> t.NoReturn:
"""
Create a suitable APIError from requests.exceptions.HTTPError.
"""
@ -52,14 +54,16 @@ class APIError(_HTTPError, DockerException):
An HTTP error from the API.
"""
def __init__(self, message, response=None, explanation=None):
def __init__(
self, message: str | Exception, response=None, explanation: str | None = None
) -> None:
# requests 1.2 supports response as a keyword argument, but
# requests 1.1 does not
super().__init__(message)
self.response = response
self.explanation = explanation
self.explanation = explanation or ""
def __str__(self):
def __str__(self) -> str:
message = super().__str__()
if self.is_client_error():
@ -74,19 +78,20 @@ class APIError(_HTTPError, DockerException):
return message
@property
def status_code(self):
def status_code(self) -> int | None:
if self.response is not None:
return self.response.status_code
return None
def is_error(self):
def is_error(self) -> bool:
return self.is_client_error() or self.is_server_error()
def is_client_error(self):
def is_client_error(self) -> bool:
if self.status_code is None:
return False
return 400 <= self.status_code < 500
def is_server_error(self):
def is_server_error(self) -> bool:
if self.status_code is None:
return False
return 500 <= self.status_code < 600
@ -121,10 +126,10 @@ class DeprecatedMethod(DockerException):
class TLSParameterError(DockerException):
def __init__(self, msg):
def __init__(self, msg: str) -> None:
self.msg = msg
def __str__(self):
def __str__(self) -> str:
return self.msg + (
". TLS configurations should map the Docker CLI "
"client configurations. See "
@ -142,7 +147,14 @@ class ContainerError(DockerException):
Represents a container that has exited with a non-zero exit code.
"""
def __init__(self, container, exit_status, command, image, stderr):
def __init__(
self,
container: str,
exit_status: int,
command: list[str],
image: str,
stderr: str | None,
):
self.container = container
self.exit_status = exit_status
self.command = command
@ -156,12 +168,12 @@ class ContainerError(DockerException):
class StreamParseError(RuntimeError):
def __init__(self, reason):
def __init__(self, reason: Exception) -> None:
self.msg = reason
class BuildError(DockerException):
def __init__(self, reason, build_log):
def __init__(self, reason: str, build_log: str) -> None:
super().__init__(reason)
self.msg = reason
self.build_log = build_log
@ -171,7 +183,7 @@ class ImageLoadError(DockerException):
pass
def create_unexpected_kwargs_error(name, kwargs):
def create_unexpected_kwargs_error(name: str, kwargs: dict[str, t.Any]) -> TypeError:
quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)]
text = [f"{name}() "]
if len(quoted_kwargs) == 1:
@ -183,42 +195,44 @@ def create_unexpected_kwargs_error(name, kwargs):
class MissingContextParameter(DockerException):
def __init__(self, param):
def __init__(self, param: str) -> None:
self.param = param
def __str__(self):
def __str__(self) -> str:
return f"missing parameter: {self.param}"
class ContextAlreadyExists(DockerException):
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
def __str__(self):
def __str__(self) -> str:
return f"context {self.name} already exists"
class ContextException(DockerException):
def __init__(self, msg):
def __init__(self, msg: str) -> None:
self.msg = msg
def __str__(self):
def __str__(self) -> str:
return self.msg
class ContextNotFound(DockerException):
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
def __str__(self):
def __str__(self) -> str:
return f"context '{self.name}' not found"
class MissingRequirementException(DockerException):
def __init__(self, msg, requirement, import_exception):
def __init__(
self, msg: str, requirement: str, import_exception: ImportError | str
) -> None:
self.msg = msg
self.requirement = requirement
self.import_exception = import_exception
def __str__(self):
def __str__(self) -> str:
return self.msg

View File

@ -12,12 +12,18 @@
from __future__ import annotations
import os
import ssl
import typing as t
from . import errors
from .transport.ssladapter import SSLHTTPAdapter
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class TLSConfig:
"""
TLS configuration.
@ -27,25 +33,22 @@ class TLSConfig:
ca_cert (str): Path to CA cert file.
verify (bool or str): This can be ``False`` or a path to a CA cert
file.
ssl_version (int): A valid `SSL version`_.
assert_hostname (bool): Verify the hostname of the server.
.. _`SSL version`:
https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1
"""
cert = None
ca_cert = None
verify = None
ssl_version = None
cert: tuple[str, str] | None = None
ca_cert: str | None = None
verify: bool | None = None
def __init__(
self,
client_cert=None,
ca_cert=None,
verify=None,
ssl_version=None,
assert_hostname=None,
client_cert: tuple[str, str] | None = None,
ca_cert: str | None = None,
verify: bool | None = None,
assert_hostname: bool | None = None,
):
# Argument compatibility/mapping with
# https://docs.docker.com/engine/articles/https/
@ -55,12 +58,6 @@ class TLSConfig:
self.assert_hostname = assert_hostname
# If the user provides an SSL version, we should use their preference
if ssl_version:
self.ssl_version = ssl_version
else:
self.ssl_version = ssl.PROTOCOL_TLS_CLIENT
# "client_cert" must have both or neither cert/key files. In
# either case, Alert the user when both are expected, but any are
# missing.
@ -90,11 +87,10 @@ class TLSConfig:
"Invalid CA certificate provided for `ca_cert`."
)
def configure_client(self, client):
def configure_client(self, client: APIClient) -> None:
"""
Configure a client with these TLS options.
"""
client.ssl_version = self.ssl_version
if self.verify and self.ca_cert:
client.verify = self.ca_cert
@ -107,7 +103,6 @@ class TLSConfig:
client.mount(
"https://",
SSLHTTPAdapter(
ssl_version=self.ssl_version,
assert_hostname=self.assert_hostname,
),
)

View File

@ -15,7 +15,7 @@ from .._import_helper import HTTPAdapter as _HTTPAdapter
class BaseHTTPAdapter(_HTTPAdapter):
def close(self):
def close(self) -> None:
super().close()
if hasattr(self, "pools"):
self.pools.clear()
@ -24,10 +24,10 @@ class BaseHTTPAdapter(_HTTPAdapter):
# https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356
# changes requests.adapters.HTTPAdapter to no longer call get_connection() from
# send(), but instead call _get_connection().
def _get_connection(self, request, *args, **kwargs):
def _get_connection(self, request, *args, **kwargs): # type: ignore
return self.get_connection(request.url, kwargs.get("proxies"))
# Fix for requests 2.32.2+:
# https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None):
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): # type: ignore
return self.get_connection(request.url, proxies)

View File

@ -23,12 +23,12 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(self, npipe_path, timeout=60):
def __init__(self, npipe_path: str, timeout: int | float = 60) -> None:
super().__init__("localhost", timeout=timeout)
self.npipe_path = npipe_path
self.timeout = timeout
def connect(self):
def connect(self) -> None:
sock = NpipeSocket()
sock.settimeout(self.timeout)
sock.connect(self.npipe_path)
@ -36,18 +36,20 @@ class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, npipe_path, timeout=60, maxsize=10):
def __init__(
self, npipe_path: str, timeout: int | float = 60, maxsize: int = 10
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.npipe_path = npipe_path
self.timeout = timeout
def _new_conn(self):
def _new_conn(self) -> NpipeHTTPConnection:
return NpipeHTTPConnection(self.npipe_path, self.timeout)
# When re-using connections, urllib3 tries to call select() on our
# NpipeSocket instance, causing a crash. To circumvent this, we override
# _get_conn, where that check happens.
def _get_conn(self, timeout):
def _get_conn(self, timeout: int | float) -> NpipeHTTPConnection:
conn = None
try:
conn = self.pool.get(block=self.block, timeout=timeout)
@ -67,7 +69,6 @@ class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class NpipeHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [
"npipe_path",
"pools",
@ -77,11 +78,11 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
def __init__(
self,
base_url,
timeout=60,
pool_connections=constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
):
base_url: str,
timeout: int | float = 60,
pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
) -> None:
self.npipe_path = base_url.replace("npipe://", "")
self.timeout = timeout
self.max_pool_size = max_pool_size
@ -90,7 +91,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
)
super().__init__()
def get_connection(self, url, proxies=None):
def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool:
with self.pools.lock:
pool = self.pools.get(url)
if pool:
@ -103,7 +104,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
return pool
def request_url(self, request, proxies):
def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -15,8 +15,10 @@ import functools
import io
import time
import traceback
import typing as t
PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name
try:
import pywintypes
import win32api
@ -28,6 +30,13 @@ except ImportError:
else:
PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer, Callable
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
ERROR_PIPE_BUSY = 0xE7
SECURITY_SQOS_PRESENT = 0x100000
@ -36,10 +45,12 @@ SECURITY_ANONYMOUS = 0
MAXIMUM_RETRY_COUNT = 10
def check_closed(f):
def check_closed(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if self._closed:
def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if self._closed: # type: ignore
raise RuntimeError("Can not reuse socket after connection was closed.")
return f(self, *args, **kwargs)
@ -53,25 +64,25 @@ class NpipeSocket:
implemented.
"""
def __init__(self, handle=None):
def __init__(self, handle=None) -> None:
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
self._handle = handle
self._address = None
self._address: str | None = None
self._closed = False
self.flags = None
self.flags: int | None = None
def accept(self):
def accept(self) -> t.NoReturn:
raise NotImplementedError()
def bind(self, address):
def bind(self, address) -> t.NoReturn:
raise NotImplementedError()
def close(self):
def close(self) -> None:
self._handle.Close()
self._closed = True
@check_closed
def connect(self, address, retry_count=0):
def connect(self, address, retry_count: int = 0) -> None:
try:
handle = win32file.CreateFile(
address,
@ -99,14 +110,14 @@ class NpipeSocket:
return self.connect(address, retry_count)
raise e
self.flags = win32pipe.GetNamedPipeInfo(handle)[0]
self.flags = win32pipe.GetNamedPipeInfo(handle)[0] # type: ignore
self._handle = handle
self._address = address
@check_closed
def connect_ex(self, address):
return self.connect(address)
def connect_ex(self, address) -> None:
self.connect(address)
@check_closed
def detach(self):
@ -114,25 +125,25 @@ class NpipeSocket:
return self._handle
@check_closed
def dup(self):
def dup(self) -> NpipeSocket:
return NpipeSocket(self._handle)
def getpeername(self):
def getpeername(self) -> str | None:
return self._address
def getsockname(self):
def getsockname(self) -> str | None:
return self._address
def getsockopt(self, level, optname, buflen=None):
def getsockopt(self, level, optname, buflen=None) -> t.NoReturn:
raise NotImplementedError()
def ioctl(self, control, option):
def ioctl(self, control, option) -> t.NoReturn:
raise NotImplementedError()
def listen(self, backlog):
def listen(self, backlog) -> t.NoReturn:
raise NotImplementedError()
def makefile(self, mode=None, bufsize=None):
def makefile(self, mode: str, bufsize: int | None = None):
if mode.strip("b") != "r":
raise NotImplementedError()
rawio = NpipeFileIOBase(self)
@ -141,30 +152,30 @@ class NpipeSocket:
return io.BufferedReader(rawio, buffer_size=bufsize)
@check_closed
def recv(self, bufsize, flags=0):
def recv(self, bufsize: int, flags: int = 0) -> str:
dummy_err, data = win32file.ReadFile(self._handle, bufsize)
return data
@check_closed
def recvfrom(self, bufsize, flags=0):
def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[str, str | None]:
data = self.recv(bufsize, flags)
return (data, self._address)
@check_closed
def recvfrom_into(self, buf, nbytes=0, flags=0):
return self.recv_into(buf, nbytes, flags), self._address
def recvfrom_into(
self, buf: Buffer, nbytes: int = 0, flags: int = 0
) -> tuple[int, str | None]:
return self.recv_into(buf, nbytes), self._address
@check_closed
def recv_into(self, buf, nbytes=0):
readbuf = buf
if not isinstance(buf, memoryview):
readbuf = memoryview(buf)
def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
event = win32event.CreateEvent(None, True, True, None)
try:
overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event
dummy_err, dummy_data = win32file.ReadFile(
dummy_err, dummy_data = win32file.ReadFile( # type: ignore
self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped
)
wait_result = win32event.WaitForSingleObject(event, self._timeout)
@ -176,12 +187,12 @@ class NpipeSocket:
win32api.CloseHandle(event)
@check_closed
def send(self, string, flags=0):
def send(self, string: Buffer, flags: int = 0) -> int:
event = win32event.CreateEvent(None, True, True, None)
try:
overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event
win32file.WriteFile(self._handle, string, overlapped)
win32file.WriteFile(self._handle, string, overlapped) # type: ignore
wait_result = win32event.WaitForSingleObject(event, self._timeout)
if wait_result == win32event.WAIT_TIMEOUT:
win32file.CancelIo(self._handle)
@ -191,20 +202,20 @@ class NpipeSocket:
win32api.CloseHandle(event)
@check_closed
def sendall(self, string, flags=0):
def sendall(self, string: Buffer, flags: int = 0) -> int:
return self.send(string, flags)
@check_closed
def sendto(self, string, address):
def sendto(self, string: Buffer, address: str) -> int:
self.connect(address)
return self.send(string)
def setblocking(self, flag):
def setblocking(self, flag: bool):
if flag:
return self.settimeout(None)
return self.settimeout(0)
def settimeout(self, value):
def settimeout(self, value: int | float | None) -> None:
if value is None:
# Blocking mode
self._timeout = win32event.INFINITE
@ -214,39 +225,39 @@ class NpipeSocket:
# Timeout mode - Value converted to milliseconds
self._timeout = int(value * 1000)
def gettimeout(self):
def gettimeout(self) -> int | float | None:
return self._timeout
def setsockopt(self, level, optname, value):
def setsockopt(self, level, optname, value) -> t.NoReturn:
raise NotImplementedError()
@check_closed
def shutdown(self, how):
def shutdown(self, how) -> None:
return self.close()
class NpipeFileIOBase(io.RawIOBase):
def __init__(self, npipe_socket):
def __init__(self, npipe_socket) -> None:
self.sock = npipe_socket
def close(self):
def close(self) -> None:
super().close()
self.sock = None
def fileno(self):
def fileno(self) -> int:
return self.sock.fileno()
def isatty(self):
def isatty(self) -> bool:
return False
def readable(self):
def readable(self) -> bool:
return True
def readinto(self, buf):
def readinto(self, buf: Buffer) -> int:
return self.sock.recv_into(buf)
def seekable(self):
def seekable(self) -> bool:
return False
def writable(self):
def writable(self) -> bool:
return False

View File

@ -17,6 +17,7 @@ import signal
import socket
import subprocess
import traceback
import typing as t
from queue import Empty
from urllib.parse import urlparse
@ -25,6 +26,7 @@ from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
from .basehttpadapter import BaseHTTPAdapter
PARAMIKO_IMPORT_ERROR: str | None # pylint: disable=invalid-name
try:
import paramiko
except ImportError:
@ -32,12 +34,15 @@ except ImportError:
else:
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class SSHSocket(socket.socket):
def __init__(self, host):
def __init__(self, host: str) -> None:
super().__init__(socket.AF_INET, socket.SOCK_STREAM)
self.host = host
self.port = None
@ -47,9 +52,9 @@ class SSHSocket(socket.socket):
if "@" in self.host:
self.user, self.host = self.host.split("@")
self.proc = None
self.proc: subprocess.Popen | None = None
def connect(self, **kwargs):
def connect(self, *args_: t.Any, **kwargs: t.Any) -> None:
args = ["ssh"]
if self.user:
args = args + ["-l", self.user]
@ -81,37 +86,48 @@ class SSHSocket(socket.socket):
preexec_fn=preexec_func,
)
def _write(self, data):
if not self.proc or self.proc.stdin.closed:
def _write(self, data: Buffer) -> int:
if not self.proc:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first."
)
assert self.proc.stdin is not None
if self.proc.stdin.closed:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first after close()."
)
written = self.proc.stdin.write(data)
self.proc.stdin.flush()
return written
def sendall(self, data):
def sendall(self, data: Buffer, *args, **kwargs) -> None:
self._write(data)
def send(self, data):
def send(self, data: Buffer, *args, **kwargs) -> int:
return self._write(data)
def recv(self, n):
def recv(self, n: int, *args, **kwargs) -> bytes:
if not self.proc:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first."
)
assert self.proc.stdout is not None
return self.proc.stdout.read(n)
def makefile(self, mode):
def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore
if not self.proc:
self.connect()
self.proc.stdout.channel = self
assert self.proc is not None
assert self.proc.stdout is not None
self.proc.stdout.channel = self # type: ignore
return self.proc.stdout
def close(self):
if not self.proc or self.proc.stdin.closed:
def close(self) -> None:
if not self.proc:
return
assert self.proc.stdin is not None
if self.proc.stdin.closed:
return
self.proc.stdin.write(b"\n\n")
self.proc.stdin.flush()
@ -119,13 +135,19 @@ class SSHSocket(socket.socket):
class SSHConnection(urllib3_connection.HTTPConnection):
def __init__(self, ssh_transport=None, timeout=60, host=None):
def __init__(
self,
*,
ssh_transport=None,
timeout: int | float = 60,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout)
self.ssh_transport = ssh_transport
self.timeout = timeout
self.ssh_host = host
def connect(self):
def connect(self) -> None:
if self.ssh_transport:
sock = self.ssh_transport.open_session()
sock.settimeout(self.timeout)
@ -141,7 +163,14 @@ class SSHConnection(urllib3_connection.HTTPConnection):
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
scheme = "ssh"
def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
def __init__(
self,
*,
ssh_client: paramiko.SSHClient | None = None,
timeout: int | float = 60,
maxsize: int = 10,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.ssh_transport = None
self.timeout = timeout
@ -149,13 +178,17 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
self.ssh_transport = ssh_client.get_transport()
self.ssh_host = host
def _new_conn(self):
return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host)
def _new_conn(self) -> SSHConnection:
return SSHConnection(
ssh_transport=self.ssh_transport,
timeout=self.timeout,
host=self.ssh_host,
)
# When re-using connections, urllib3 calls fileno() on our
# SSH channel instance, quickly overloading our fd limit. To avoid this,
# we override _get_conn
def _get_conn(self, timeout):
def _get_conn(self, timeout: int | float) -> SSHConnection:
conn = None
try:
conn = self.pool.get(block=self.block, timeout=timeout)
@ -175,7 +208,6 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class SSHHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [
"pools",
"timeout",
@ -186,13 +218,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
def __init__(
self,
base_url,
timeout=60,
pool_connections=constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
shell_out=False,
):
self.ssh_client = None
base_url: str,
timeout: int | float = 60,
pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
shell_out: bool = False,
) -> None:
self.ssh_client: paramiko.SSHClient | None = None
if not shell_out:
self._create_paramiko_client(base_url)
self._connect()
@ -208,30 +240,31 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
)
super().__init__()
def _create_paramiko_client(self, base_url):
def _create_paramiko_client(self, base_url: str) -> None:
logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh_client = paramiko.SSHClient()
base_url = urlparse(base_url)
self.ssh_params = {
"hostname": base_url.hostname,
"port": base_url.port,
"username": base_url.username,
base_url_p = urlparse(base_url)
assert base_url_p.hostname is not None
self.ssh_params: dict[str, t.Any] = {
"hostname": base_url_p.hostname,
"port": base_url_p.port,
"username": base_url_p.username,
}
ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig()
with open(ssh_config_file, "rt", encoding="utf-8") as f:
conf.parse(f)
host_config = conf.lookup(base_url.hostname)
host_config = conf.lookup(base_url_p.hostname)
if "proxycommand" in host_config:
self.ssh_params["sock"] = paramiko.ProxyCommand(
host_config["proxycommand"]
)
if "hostname" in host_config:
self.ssh_params["hostname"] = host_config["hostname"]
if base_url.port is None and "port" in host_config:
if base_url_p.port is None and "port" in host_config:
self.ssh_params["port"] = host_config["port"]
if base_url.username is None and "user" in host_config:
if base_url_p.username is None and "user" in host_config:
self.ssh_params["username"] = host_config["user"]
if "identityfile" in host_config:
self.ssh_params["key_filename"] = host_config["identityfile"]
@ -239,11 +272,11 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
self.ssh_client.load_system_host_keys()
self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())
def _connect(self):
def _connect(self) -> None:
if self.ssh_client:
self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url, proxies=None):
def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool:
if not self.ssh_client:
return SSHConnectionPool(
ssh_client=self.ssh_client,
@ -270,7 +303,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
return pool
def close(self):
def close(self) -> None:
super().close()
if self.ssh_client:
self.ssh_client.close()

View File

@ -11,9 +11,7 @@
from __future__ import annotations
from ansible_collections.community.docker.plugins.module_utils._version import (
LooseVersion,
)
import typing as t
from .._import_helper import HTTPAdapter, urllib3
from .basehttpadapter import BaseHTTPAdapter
@ -30,14 +28,19 @@ PoolManager = urllib3.poolmanager.PoolManager
class SSLHTTPAdapter(BaseHTTPAdapter):
"""An HTTPS Transport Adapter that uses an arbitrary SSL version."""
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname", "ssl_version"]
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname"]
def __init__(self, ssl_version=None, assert_hostname=None, **kwargs):
self.ssl_version = ssl_version
def __init__(
self,
assert_hostname: bool | None = None,
**kwargs,
) -> None:
self.assert_hostname = assert_hostname
super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False):
def init_poolmanager(
self, connections: int, maxsize: int, block: bool = False, **kwargs: t.Any
) -> None:
kwargs = {
"num_pools": connections,
"maxsize": maxsize,
@ -45,12 +48,10 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
}
if self.assert_hostname is not None:
kwargs["assert_hostname"] = self.assert_hostname
if self.ssl_version and self.can_override_ssl_version():
kwargs["ssl_version"] = self.ssl_version
self.poolmanager = PoolManager(**kwargs)
def get_connection(self, *args, **kwargs):
def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool:
"""
Ensure assert_hostname is set correctly on our pool
@ -61,15 +62,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
conn = super().get_connection(*args, **kwargs)
if (
self.assert_hostname is not None
and conn.assert_hostname != self.assert_hostname
and conn.assert_hostname != self.assert_hostname # type: ignore
):
conn.assert_hostname = self.assert_hostname
conn.assert_hostname = self.assert_hostname # type: ignore
return conn
def can_override_ssl_version(self):
urllib_ver = urllib3.__version__.split("-")[0]
if urllib_ver is None:
return False
if urllib_ver == "dev":
return True
return LooseVersion(urllib_ver) > LooseVersion("1.5")

View File

@ -12,6 +12,7 @@
from __future__ import annotations
import socket
import typing as t
from .. import constants
from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
@ -22,26 +23,27 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class UnixHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(self, base_url, unix_socket, timeout=60):
def __init__(
self, base_url: str | bytes, unix_socket, timeout: int | float = 60
) -> None:
super().__init__("localhost", timeout=timeout)
self.base_url = base_url
self.unix_socket = unix_socket
self.timeout = timeout
self.disable_buffering = False
def connect(self):
def connect(self) -> None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect(self.unix_socket)
self.sock = sock
def putheader(self, header, *values):
def putheader(self, header: str, *values: str) -> None:
super().putheader(header, *values)
if header == "Connection" and "Upgrade" in values:
self.disable_buffering = True
def response_class(self, sock, *args, **kwargs):
def response_class(self, sock, *args, **kwargs) -> t.Any:
# FIXME: We may need to disable buffering on Py3,
# but there's no clear way to do it at the moment. See:
# https://github.com/docker/docker-py/issues/1799
@ -49,18 +51,23 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, base_url, socket_path, timeout=60, maxsize=10):
def __init__(
self,
base_url: str | bytes,
socket_path: str,
timeout: int | float = 60,
maxsize: int = 10,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.base_url = base_url
self.socket_path = socket_path
self.timeout = timeout
def _new_conn(self):
def _new_conn(self) -> UnixHTTPConnection:
return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout)
class UnixHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [
"pools",
"socket_path",
@ -70,11 +77,11 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
def __init__(
self,
socket_url,
timeout=60,
pool_connections=constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
):
socket_url: str,
timeout: int | float = 60,
pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
) -> None:
socket_path = socket_url.replace("http+unix://", "")
if not socket_path.startswith("/"):
socket_path = "/" + socket_path
@ -86,7 +93,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
)
super().__init__()
def get_connection(self, url, proxies=None):
def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool:
with self.pools.lock:
pool = self.pools.get(url)
if pool:
@ -99,7 +106,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
return pool
def request_url(self, request, proxies):
def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -12,6 +12,7 @@
from __future__ import annotations
import socket
import typing as t
from .._import_helper import urllib3
from ..errors import DockerException
@ -29,11 +30,11 @@ class CancellableStream:
>>> events.close()
"""
def __init__(self, stream, response):
def __init__(self, stream, response) -> None:
self._stream = stream
self._response = response
def __iter__(self):
def __iter__(self) -> t.Self:
return self
def __next__(self):
@ -46,7 +47,7 @@ class CancellableStream:
next = __next__
def close(self):
def close(self) -> None:
"""
Closes the event streaming.
"""

View File

@ -17,26 +17,38 @@ import random
import re
import tarfile
import tempfile
import typing as t
from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX
from . import fnmatch
if t.TYPE_CHECKING:
from collections.abc import Sequence
_SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/")
def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
def tar(
path: str,
exclude: list[str] | None = None,
dockerfile: tuple[str, str | None] | tuple[None, None] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
) -> t.IO[bytes]:
root = os.path.abspath(path)
exclude = exclude or []
dockerfile = dockerfile or (None, None)
extra_files = []
extra_files: list[tuple[str, str]] = []
if dockerfile[1] is not None:
assert dockerfile[0] is not None
dockerignore_contents = "\n".join(
(exclude or [".dockerignore"]) + [dockerfile[0]]
)
extra_files = [
(".dockerignore", dockerignore_contents),
dockerfile,
dockerfile, # type: ignore
]
return create_archive(
files=sorted(exclude_paths(root, exclude, dockerfile=dockerfile[0])),
@ -47,7 +59,9 @@ def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
)
def exclude_paths(root, patterns, dockerfile=None):
def exclude_paths(
root: str, patterns: list[str], dockerfile: str | None = None
) -> set[str]:
"""
Given a root directory path and a list of .dockerignore patterns, return
an iterator of all paths (both regular files and directories) in the root
@ -64,7 +78,7 @@ def exclude_paths(root, patterns, dockerfile=None):
return set(pm.walk(root))
def build_file_list(root):
def build_file_list(root: str) -> list[str]:
files = []
for dirname, dirnames, fnames in os.walk(root):
for filename in fnames + dirnames:
@ -74,7 +88,13 @@ def build_file_list(root):
return files
def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None):
def create_archive(
root: str,
files: Sequence[str] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
extra_files: Sequence[tuple[str, str]] | None = None,
) -> t.IO[bytes]:
extra_files = extra_files or []
if not fileobj:
fileobj = tempfile.NamedTemporaryFile()
@ -92,7 +112,7 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
if i is None:
# This happens when we encounter a socket file. We can safely
# ignore it and proceed.
continue
continue # type: ignore
# Workaround https://bugs.python.org/issue32713
if i.mtime < 0 or i.mtime > 8**11 - 1:
@ -124,11 +144,11 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
return fileobj
def mkbuildcontext(dockerfile):
def mkbuildcontext(dockerfile: io.BytesIO | t.IO[bytes]) -> t.IO[bytes]:
f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with
try:
with tarfile.open(mode="w", fileobj=f) as t:
if isinstance(dockerfile, io.StringIO):
if isinstance(dockerfile, io.StringIO): # type: ignore
raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles")
if isinstance(dockerfile, io.BytesIO):
dfinfo = tarfile.TarInfo("Dockerfile")
@ -144,17 +164,17 @@ def mkbuildcontext(dockerfile):
return f
def split_path(p):
def split_path(p: str) -> list[str]:
return [pt for pt in re.split(_SEP, p) if pt and pt != "."]
def normalize_slashes(p):
def normalize_slashes(p: str) -> str:
if IS_WINDOWS_PLATFORM:
return "/".join(split_path(p))
return p
def walk(root, patterns, default=True):
def walk(root: str, patterns: Sequence[str], default: bool = True) -> t.Generator[str]:
pm = PatternMatcher(patterns)
return pm.walk(root)
@ -162,11 +182,11 @@ def walk(root, patterns, default=True):
# Heavily based on
# https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go
class PatternMatcher:
def __init__(self, patterns):
def __init__(self, patterns: Sequence[str]) -> None:
self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns]))
self.patterns.append(Pattern("!.dockerignore"))
def matches(self, filepath):
def matches(self, filepath: str) -> bool:
matched = False
parent_path = os.path.dirname(filepath)
parent_path_dirs = split_path(parent_path)
@ -185,8 +205,8 @@ class PatternMatcher:
return matched
def walk(self, root):
def rec_walk(current_dir):
def walk(self, root: str) -> t.Generator[str]:
def rec_walk(current_dir: str) -> t.Generator[str]:
for f in os.listdir(current_dir):
fpath = os.path.join(os.path.relpath(current_dir, root), f)
if fpath.startswith("." + os.path.sep):
@ -220,7 +240,7 @@ class PatternMatcher:
class Pattern:
def __init__(self, pattern_str):
def __init__(self, pattern_str: str) -> None:
self.exclusion = False
if pattern_str.startswith("!"):
self.exclusion = True
@ -230,8 +250,7 @@ class Pattern:
self.cleaned_pattern = "/".join(self.dirs)
@classmethod
def normalize(cls, p):
def normalize(cls, p: str) -> list[str]:
# Remove trailing spaces
p = p.strip()
@ -256,11 +275,13 @@ class Pattern:
i += 1
return split
def match(self, filepath):
def match(self, filepath: str) -> bool:
return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern)
def process_dockerfile(dockerfile, path):
def process_dockerfile(
dockerfile: str | None, path: str
) -> tuple[str, str | None] | tuple[None, None]:
if not dockerfile:
return (None, None)
@ -268,7 +289,7 @@ def process_dockerfile(dockerfile, path):
if not os.path.isabs(dockerfile):
abs_dockerfile = os.path.join(path, dockerfile)
if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX):
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX):])}"
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX) :])}"
if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[
0
] or os.path.relpath(abs_dockerfile, path).startswith(".."):

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import json
import logging
import os
import typing as t
from ..constants import IS_WINDOWS_PLATFORM
@ -24,11 +25,11 @@ LEGACY_DOCKER_CONFIG_FILENAME = ".dockercfg"
log = logging.getLogger(__name__)
def get_default_config_file():
def get_default_config_file() -> str:
return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME)
def find_config_file(config_path=None):
def find_config_file(config_path: str | None = None) -> str | None:
homedir = home_dir()
paths = list(
filter(
@ -54,14 +55,14 @@ def find_config_file(config_path=None):
return None
def config_path_from_environment():
def config_path_from_environment() -> str | None:
config_dir = os.environ.get("DOCKER_CONFIG")
if not config_dir:
return None
return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME))
def home_dir():
def home_dir() -> str:
"""
Get the user's home directory, using the same logic as the Docker Engine
client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX.
@ -71,7 +72,7 @@ def home_dir():
return os.path.expanduser("~")
def load_general_config(config_path=None):
def load_general_config(config_path: str | None = None) -> dict[str, t.Any]:
config_file = find_config_file(config_path)
if not config_file:

View File

@ -12,16 +12,37 @@
from __future__ import annotations
import functools
import typing as t
from .. import errors
from . import utils
def minimum_version(version):
def decorator(f):
if t.TYPE_CHECKING:
from collections.abc import Callable
from ..api.client import APIClient
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
def minimum_version(
version: str,
) -> Callable[
[Callable[t.Concatenate[_Self, _P], _R]],
Callable[t.Concatenate[_Self, _P], _R],
]:
def decorator(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
if utils.version_lt(self._version, version):
def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
# We use _Self instead of APIClient since this is used for mixins for APIClient.
# This unfortunately means that self._version does not exist in the mixin,
# it only exists after mixing in. This is why we ignore types here.
if utils.version_lt(self._version, version): # type: ignore
raise errors.InvalidVersion(
f"{f.__name__} is not available for version < {version}"
)
@ -32,13 +53,16 @@ def minimum_version(version):
return decorator
def update_headers(f):
def inner(self, *args, **kwargs):
def update_headers(
f: Callable[t.Concatenate[APIClient, _P], _R],
) -> Callable[t.Concatenate[APIClient, _P], _R]:
def inner(self: APIClient, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if "HttpHeaders" in self._general_configs:
if not kwargs.get("headers"):
kwargs["headers"] = self._general_configs["HttpHeaders"]
else:
kwargs["headers"].update(self._general_configs["HttpHeaders"])
# We cannot (yet) model that kwargs["headers"] should be a dictionary
kwargs["headers"].update(self._general_configs["HttpHeaders"]) # type: ignore
return f(self, *args, **kwargs)
return inner

View File

@ -28,16 +28,16 @@ import re
__all__ = ["fnmatch", "fnmatchcase", "translate"]
_cache = {}
_cache: dict[str, re.Pattern] = {}
_MAXCACHE = 100
def _purge():
def _purge() -> None:
"""Clear the pattern cache"""
_cache.clear()
def fnmatch(name, pat):
def fnmatch(name: str, pat: str):
"""Test whether FILENAME matches PATTERN.
Patterns are Unix shell style:
@ -58,7 +58,7 @@ def fnmatch(name, pat):
return fnmatchcase(name, pat)
def fnmatchcase(name, pat):
def fnmatchcase(name: str, pat: str) -> bool:
"""Test whether FILENAME matches PATTERN, including case.
This is a version of fnmatch() which does not case-normalize
its arguments.
@ -74,7 +74,7 @@ def fnmatchcase(name, pat):
return re_pat.match(name) is not None
def translate(pat):
def translate(pat: str) -> str:
"""Translate a shell PATTERN to a regular expression.
There is no way to quote meta-characters.

View File

@ -13,14 +13,22 @@ from __future__ import annotations
import json
import json.decoder
import typing as t
from ..errors import StreamParseError
if t.TYPE_CHECKING:
import re
from collections.abc import Callable
_T = t.TypeVar("_T")
json_decoder = json.JSONDecoder()
def stream_as_text(stream):
def stream_as_text(stream: t.Generator[bytes | str]) -> t.Generator[str]:
"""
Given a stream of bytes or text, if any of the items in the stream
are bytes convert them to text.
@ -33,20 +41,22 @@ def stream_as_text(stream):
yield data
def json_splitter(buffer):
def json_splitter(buffer: str) -> tuple[t.Any, str] | None:
"""Attempt to parse a json object from a buffer. If there is at least one
object, return it and the rest of the buffer, otherwise return None.
"""
buffer = buffer.strip()
try:
obj, index = json_decoder.raw_decode(buffer)
rest = buffer[json.decoder.WHITESPACE.match(buffer, index).end() :]
ws: re.Pattern = json.decoder.WHITESPACE # type: ignore[attr-defined]
m = ws.match(buffer, index)
rest = buffer[m.end() :] if m else buffer[index:]
return obj, rest
except ValueError:
return None
def json_stream(stream):
def json_stream(stream: t.Generator[str | bytes]) -> t.Generator[t.Any]:
"""Given a stream of text, return a stream of json objects.
This handles streams which are inconsistently buffered (some entries may
be newline delimited, and others are not).
@ -54,21 +64,24 @@ def json_stream(stream):
return split_buffer(stream, json_splitter, json_decoder.decode)
def line_splitter(buffer, separator="\n"):
def line_splitter(buffer: str, separator: str = "\n") -> tuple[str, str] | None:
index = buffer.find(str(separator))
if index == -1:
return None
return buffer[: index + 1], buffer[index + 1 :]
def split_buffer(stream, splitter=None, decoder=lambda a: a):
def split_buffer(
stream: t.Generator[str | bytes],
splitter: Callable[[str], tuple[_T, str] | None],
decoder: Callable[[str], _T],
) -> t.Generator[_T | str]:
"""Given a generator which yields strings and a splitter function,
joins all input, splits on the separator and yields each chunk.
Unlike string.split(), each chunk includes the trailing
separator, except for the last one if none was found on the end
of the input.
"""
splitter = splitter or line_splitter
buffered = ""
for data in stream_as_text(stream):

View File

@ -12,6 +12,11 @@
from __future__ import annotations
import re
import typing as t
if t.TYPE_CHECKING:
from collections.abc import Collection, Sequence
PORT_SPEC = re.compile(
@ -26,32 +31,42 @@ PORT_SPEC = re.compile(
)
def add_port_mapping(port_bindings, internal_port, external):
def add_port_mapping(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port: str,
external: str | tuple[str, str | None] | None,
) -> None:
if internal_port in port_bindings:
port_bindings[internal_port].append(external)
else:
port_bindings[internal_port] = [external]
def add_port(port_bindings, internal_port_range, external_range):
def add_port(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port_range: list[str],
external_range: list[str] | list[tuple[str, str | None]] | None,
) -> None:
if external_range is None:
for internal_port in internal_port_range:
add_port_mapping(port_bindings, internal_port, None)
else:
ports = zip(internal_port_range, external_range)
for internal_port, external_port in ports:
add_port_mapping(port_bindings, internal_port, external_port)
for internal_port, external_port in zip(internal_port_range, external_range):
# mypy loses the exact type of eternal_port elements for some reason...
add_port_mapping(port_bindings, internal_port, external_port) # type: ignore
def build_port_bindings(ports):
port_bindings = {}
def build_port_bindings(
ports: Collection[str],
) -> dict[str, list[str | tuple[str, str | None] | None]]:
port_bindings: dict[str, list[str | tuple[str, str | None] | None]] = {}
for port in ports:
internal_port_range, external_range = split_port(port)
add_port(port_bindings, internal_port_range, external_range)
return port_bindings
def _raise_invalid_port(port):
def _raise_invalid_port(port: str) -> t.NoReturn:
raise ValueError(
f'Invalid port "{port}", should be '
"[[remote_ip:]remote_port[-remote_port]:]"
@ -59,39 +74,64 @@ def _raise_invalid_port(port):
)
def port_range(start, end, proto, randomly_available_port=False):
if not start:
@t.overload
def port_range(
start: str,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str]: ...
@t.overload
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None: ...
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None:
if start is None:
return start
if not end:
if end is None:
return [f"{start}{proto}"]
if randomly_available_port:
return [f"{start}-{end}{proto}"]
return [f"{port}{proto}" for port in range(int(start), int(end) + 1)]
def split_port(port):
if hasattr(port, "legacy_repr"):
# This is the worst hack, but it prevents a bug in Compose 1.14.0
# https://github.com/docker/docker-py/issues/1668
# TODO: remove once fixed in Compose stable
port = port.legacy_repr()
def split_port(
port: str,
) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]:
port = str(port)
match = PORT_SPEC.match(port)
if match is None:
_raise_invalid_port(port)
parts = match.groupdict()
host = parts["host"]
proto = parts["proto"] or ""
internal = port_range(parts["int"], parts["int_end"], proto)
external = port_range(parts["ext"], parts["ext_end"], "", len(internal) == 1)
host: str | None = parts["host"]
proto: str = parts["proto"] or ""
int_p: str = parts["int"]
ext_p: str = parts["ext"]
internal: list[str] = port_range(int_p, parts["int_end"], proto) # type: ignore
external = port_range(ext_p or None, parts["ext_end"], "", len(internal) == 1)
if host is None:
if external is not None and len(internal) != len(external):
if (external is not None and len(internal) != len(external)) or ext_p == "":
raise ValueError("Port ranges don't match in length")
return internal, external
external_or_none: Sequence[str | None]
if not external:
external = [None] * len(internal)
elif len(internal) != len(external):
raise ValueError("Port ranges don't match in length")
return internal, [(host, ext_port) for ext_port in external]
external_or_none = [None] * len(internal)
else:
external_or_none = external
if len(internal) != len(external_or_none):
raise ValueError("Port ranges don't match in length")
return internal, [(host, ext_port) for ext_port in external_or_none]

View File

@ -20,23 +20,23 @@ class ProxyConfig(dict):
"""
@property
def http(self):
def http(self) -> str | None:
return self.get("http")
@property
def https(self):
def https(self) -> str | None:
return self.get("https")
@property
def ftp(self):
def ftp(self) -> str | None:
return self.get("ftp")
@property
def no_proxy(self):
def no_proxy(self) -> str | None:
return self.get("no_proxy")
@staticmethod
def from_dict(config):
def from_dict(config: dict[str, str]) -> ProxyConfig:
"""
Instantiate a new ProxyConfig from a dictionary that represents a
client configuration, as described in `the documentation`_.
@ -51,7 +51,7 @@ class ProxyConfig(dict):
no_proxy=config.get("noProxy"),
)
def get_environment(self):
def get_environment(self) -> dict[str, str]:
"""
Return a dictionary representing the environment variables used to
set the proxy settings.
@ -67,7 +67,7 @@ class ProxyConfig(dict):
env["no_proxy"] = env["NO_PROXY"] = self.no_proxy
return env
def inject_proxy_environment(self, environment):
def inject_proxy_environment(self, environment: list[str]) -> list[str]:
"""
Given a list of strings representing environment variables, prepend the
environment variables corresponding to the proxy settings.
@ -82,5 +82,5 @@ class ProxyConfig(dict):
# variables defined in "environment" to take precedence.
return proxy_env + environment
def __str__(self):
def __str__(self) -> str:
return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})"

View File

@ -16,10 +16,15 @@ import os
import select
import socket as pysocket
import struct
import typing as t
from ..transport.npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Iterable
STDOUT = 1
STDERR = 2
@ -33,7 +38,7 @@ class SocketError(Exception):
NPIPE_ENDED = 109
def read(socket, n=4096):
def read(socket, n: int = 4096) -> bytes | None:
"""
Reads at most n bytes from socket
"""
@ -58,6 +63,7 @@ def read(socket, n=4096):
except EnvironmentError as e:
if e.errno not in recoverable_errors:
raise
return None # TODO ???
except Exception as e:
is_pipe_ended = (
isinstance(socket, NpipeSocket)
@ -67,11 +73,11 @@ def read(socket, n=4096):
if is_pipe_ended:
# npipes do not support duplex sockets, so we interpret
# a PIPE_ENDED error as a close operation (0-length read).
return ""
return b""
raise
def read_exactly(socket, n):
def read_exactly(socket, n: int) -> bytes:
"""
Reads exactly n bytes from socket
Raises SocketError if there is not enough data
@ -85,7 +91,7 @@ def read_exactly(socket, n):
return data
def next_frame_header(socket):
def next_frame_header(socket) -> tuple[int, int]:
"""
Returns the stream and size of the next frame of data waiting to be read
from socket, according to the protocol defined here:
@ -101,7 +107,7 @@ def next_frame_header(socket):
return (stream, actual)
def frames_iter(socket, tty):
def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
"""
Return a generator of frames read from socket. A frame is a tuple where
the first item is the stream number and the second item is a chunk of data.
@ -114,7 +120,7 @@ def frames_iter(socket, tty):
return frames_iter_no_tty(socket)
def frames_iter_no_tty(socket):
def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
"""
Returns a generator of data read from the socket when the tty setting is
not enabled.
@ -135,20 +141,34 @@ def frames_iter_no_tty(socket):
yield (stream, result)
def frames_iter_tty(socket):
def frames_iter_tty(socket) -> t.Generator[bytes]:
"""
Return a generator of data read from the socket when the tty setting is
enabled.
"""
while True:
result = read(socket)
if len(result) == 0:
if not result:
# We have reached EOF
return
yield result
def consume_socket_output(frames, demux=False):
@t.overload
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ...
@t.overload
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
@t.overload
def consume_socket_output(
frames, demux: bool = False
) -> bytes | tuple[bytes, bytes]: ...
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]:
"""
Iterate through frames read from the socket and return the result.
@ -167,7 +187,7 @@ def consume_socket_output(frames, demux=False):
# If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr)
out = [None, None]
out: list[bytes | None] = [None, None]
for frame in frames:
# It is guaranteed that for each frame, one and only one stream
# is not None.
@ -183,10 +203,10 @@ def consume_socket_output(frames, demux=False):
out[1] = frame[1]
else:
out[1] += frame[1]
return tuple(out)
return tuple(out) # type: ignore
def demux_adaptor(stream_id, data):
def demux_adaptor(stream_id: int, data: bytes) -> tuple[bytes | None, bytes | None]:
"""
Utility to demultiplex stdout and stderr when reading frames from the
socket.

View File

@ -18,6 +18,7 @@ import os
import os.path
import shlex
import string
import typing as t
from urllib.parse import urlparse, urlunparse
from ansible_collections.community.docker.plugins.module_utils._version import (
@ -34,32 +35,23 @@ from ..constants import (
from ..tls import TLSConfig
if t.TYPE_CHECKING:
import ssl
from collections.abc import Mapping, Sequence
URLComponents = collections.namedtuple(
"URLComponents",
"scheme netloc url params query fragment",
)
def create_ipam_pool(*args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_ipam_pool has been removed. Please use a "
"docker.types.IPAMPool object instead."
)
def create_ipam_config(*args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_ipam_config has been removed. Please use a "
"docker.types.IPAMConfig object instead."
)
def decode_json_header(header):
def decode_json_header(header: str) -> dict[str, t.Any]:
data = base64.b64decode(header).decode("utf-8")
return json.loads(data)
def compare_version(v1, v2):
def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]:
"""Compare docker versions
>>> v1 = '1.9'
@ -80,43 +72,64 @@ def compare_version(v1, v2):
return 1
def version_lt(v1, v2):
def version_lt(v1: str, v2: str) -> bool:
return compare_version(v1, v2) > 0
def version_gte(v1, v2):
def version_gte(v1: str, v2: str) -> bool:
return not version_lt(v1, v2)
def _convert_port_binding(binding):
def _convert_port_binding(
binding: (
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
),
) -> dict[str, str]:
result = {"HostIp": "", "HostPort": ""}
host_port: str | int | None = ""
if isinstance(binding, tuple):
if len(binding) == 2:
result["HostPort"] = binding[1]
host_port = binding[1] # type: ignore
result["HostIp"] = binding[0]
elif isinstance(binding[0], str):
result["HostIp"] = binding[0]
else:
result["HostPort"] = binding[0]
host_port = binding[0]
elif isinstance(binding, dict):
if "HostPort" in binding:
result["HostPort"] = binding["HostPort"]
host_port = binding["HostPort"]
if "HostIp" in binding:
result["HostIp"] = binding["HostIp"]
else:
raise ValueError(binding)
else:
result["HostPort"] = binding
if result["HostPort"] is None:
result["HostPort"] = ""
else:
result["HostPort"] = str(result["HostPort"])
host_port = binding
result["HostPort"] = str(host_port) if host_port is not None else ""
return result
def convert_port_bindings(port_bindings):
def convert_port_bindings(
port_bindings: dict[
str | int,
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
| list[
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
],
],
) -> dict[str, list[dict[str, str]]]:
result = {}
for k, v in port_bindings.items():
key = str(k)
@ -129,9 +142,11 @@ def convert_port_bindings(port_bindings):
return result
def convert_volume_binds(binds):
def convert_volume_binds(
binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int],
) -> list[str]:
if isinstance(binds, list):
return binds
return binds # type: ignore
result = []
for k, v in binds.items():
@ -149,7 +164,7 @@ def convert_volume_binds(binds):
if "ro" in v:
mode = "ro" if v["ro"] else "rw"
elif "mode" in v:
mode = v["mode"]
mode = v["mode"] # type: ignore # TODO
else:
mode = "rw"
@ -165,9 +180,9 @@ def convert_volume_binds(binds):
]
if "propagation" in v and v["propagation"] in propagation_modes:
if mode:
mode = ",".join([mode, v["propagation"]])
mode = ",".join([mode, v["propagation"]]) # type: ignore # TODO
else:
mode = v["propagation"]
mode = v["propagation"] # type: ignore # TODO
result.append(f"{k}:{bind}:{mode}")
else:
@ -177,7 +192,7 @@ def convert_volume_binds(binds):
return result
def convert_tmpfs_mounts(tmpfs):
def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]:
if isinstance(tmpfs, dict):
return tmpfs
@ -204,9 +219,11 @@ def convert_tmpfs_mounts(tmpfs):
return result
def convert_service_networks(networks):
def convert_service_networks(
networks: list[str | dict[str, str]],
) -> list[dict[str, str]]:
if not networks:
return networks
return networks # type: ignore
if not isinstance(networks, list):
raise TypeError("networks parameter must be a list.")
@ -218,17 +235,17 @@ def convert_service_networks(networks):
return result
def parse_repository_tag(repo_name):
def parse_repository_tag(repo_name: str) -> tuple[str, str | None]:
parts = repo_name.rsplit("@", 1)
if len(parts) == 2:
return tuple(parts)
return tuple(parts) # type: ignore
parts = repo_name.rsplit(":", 1)
if len(parts) == 2 and "/" not in parts[1]:
return tuple(parts)
return tuple(parts) # type: ignore
return repo_name, None
def parse_host(addr, is_win32=False, tls=False):
def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str:
# Sensible defaults
if not addr and is_win32:
return DEFAULT_NPIPE
@ -308,7 +325,7 @@ def parse_host(addr, is_win32=False, tls=False):
).rstrip("/")
def parse_devices(devices):
def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]:
device_list = []
for device in devices:
if isinstance(device, dict):
@ -337,7 +354,10 @@ def parse_devices(devices):
return device_list
def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
def kwargs_from_env(
assert_hostname: bool | None = None,
environment: Mapping[str, str] | None = None,
) -> dict[str, t.Any]:
if not environment:
environment = os.environ
host = environment.get("DOCKER_HOST")
@ -347,14 +367,14 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
# empty string for tls verify counts as "false".
# Any value or 'unset' counts as true.
tls_verify = environment.get("DOCKER_TLS_VERIFY")
if tls_verify == "":
tls_verify_str = environment.get("DOCKER_TLS_VERIFY")
if tls_verify_str == "":
tls_verify = False
else:
tls_verify = tls_verify is not None
tls_verify = tls_verify_str is not None
enable_tls = cert_path or tls_verify
params = {}
params: dict[str, t.Any] = {}
if host:
params["base_url"] = host
@ -377,14 +397,13 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
),
ca_cert=os.path.join(cert_path, "ca.pem"),
verify=tls_verify,
ssl_version=ssl_version,
assert_hostname=assert_hostname,
)
return params
def convert_filters(filters):
def convert_filters(filters: Mapping[str, bool | str | list[str]]) -> str:
result = {}
for k, v in filters.items():
if isinstance(v, bool):
@ -397,7 +416,7 @@ def convert_filters(filters):
return json.dumps(result)
def parse_bytes(s):
def parse_bytes(s: int | float | str) -> int | float:
if isinstance(s, (int, float)):
return s
if len(s) == 0:
@ -435,14 +454,16 @@ def parse_bytes(s):
return s
def normalize_links(links):
def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]:
if isinstance(links, dict):
links = links.items()
sorted_links = sorted(links.items())
else:
sorted_links = sorted(links)
return [f"{k}:{v}" if v else k for k, v in sorted(links)]
return [f"{k}:{v}" if v else k for k, v in sorted_links]
def parse_env_file(env_file):
def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]:
"""
Reads a line-separated environment file.
The format of each line should be "key=value".
@ -451,7 +472,6 @@ def parse_env_file(env_file):
with open(env_file, "rt", encoding="utf-8") as f:
for line in f:
if line[0] == "#":
continue
@ -471,11 +491,11 @@ def parse_env_file(env_file):
return environment
def split_command(command):
def split_command(command: str) -> list[str]:
return shlex.split(command)
def format_environment(environment):
def format_environment(environment: Mapping[str, str | bytes]) -> list[str]:
def format_env(key, value):
if value is None:
return key
@ -487,16 +507,9 @@ def format_environment(environment):
return [format_env(*var) for var in environment.items()]
def format_extra_hosts(extra_hosts, task=False):
def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]:
# Use format dictated by Swarm API if container is part of a task
if task:
return [f"{v} {k}" for k, v in sorted(extra_hosts.items())]
return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())]
def create_host_config(self, *args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_host_config has been removed. Please use a "
"docker.types.HostConfig object instead."
)

View File

@ -13,6 +13,7 @@ import platform
import re
import sys
import traceback
import typing as t
from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -36,6 +37,9 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
HAS_DOCKER_PY_2 = False # pylint: disable=invalid-name
HAS_DOCKER_PY_3 = False # pylint: disable=invalid-name
HAS_DOCKER_ERROR: None | str # pylint: disable=invalid-name
HAS_DOCKER_TRACEBACK: None | str # pylint: disable=invalid-name
docker_version: str | None # pylint: disable=invalid-name
try:
from docker import __version__ as docker_version
@ -51,12 +55,13 @@ try:
HAS_DOCKER_PY_2 = True # pylint: disable=invalid-name
from docker import APIClient as Client
else:
from docker import Client
from docker import Client # type: ignore
except ImportError as exc:
HAS_DOCKER_ERROR = str(exc) # pylint: disable=invalid-name
HAS_DOCKER_TRACEBACK = traceback.format_exc() # pylint: disable=invalid-name
HAS_DOCKER_PY = False # pylint: disable=invalid-name
docker_version = None # pylint: disable=invalid-name
else:
HAS_DOCKER_PY = True # pylint: disable=invalid-name
HAS_DOCKER_ERROR = None # pylint: disable=invalid-name
@ -71,30 +76,34 @@ except ImportError:
# Either Docker SDK for Python is no longer using requests, or Docker SDK for Python is not around either,
# or Docker SDK for Python's dependency requests is missing. In any case, define an exception
# class RequestException so that our code does not break.
class RequestException(Exception):
class RequestException(Exception): # type: ignore
pass
if t.TYPE_CHECKING:
from collections.abc import Callable
MIN_DOCKER_VERSION = "2.0.0"
if not HAS_DOCKER_PY:
docker_version = None # pylint: disable=invalid-name
# No Docker SDK for Python. Create a place holder client to allow
# instantiation of AnsibleModule and proper error handing
class Client: # noqa: F811, pylint: disable=function-redefined
class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined
def __init__(self, **kwargs):
pass
class APIError(Exception): # noqa: F811, pylint: disable=function-redefined
class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined
pass
class NotFound(Exception): # noqa: F811, pylint: disable=function-redefined
class NotFound(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined
pass
def _get_tls_config(fail_function, **kwargs):
def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
if "assert_hostname" in kwargs and LooseVersion(docker_version) >= LooseVersion(
"7.0.0b1"
):
@ -109,17 +118,18 @@ def _get_tls_config(fail_function, **kwargs):
# Filter out all None parameters
kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)
try:
tls_config = TLSConfig(**kwargs)
return tls_config
return TLSConfig(**kwargs)
except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data):
def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function):
def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://"
@ -171,7 +181,11 @@ DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade.
class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_version=None, min_docker_api_version=None):
def __init__(
self,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
) -> None:
if min_docker_version is None:
min_docker_version = MIN_DOCKER_VERSION
@ -212,23 +226,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
)
def log(self, msg, pretty_print=False):
def log(self, msg: t.Any, pretty_print: bool = False):
pass
# if self.debug:
# from .util import log_debug
# log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass
def deprecate(self, msg, version=None, date=None, collection_name=None):
@abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass
@staticmethod
def _get_value(
param_name, param_value, env_variable, default_value, value_type="str"
):
param_name: str,
param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None:
# take module parameter value
if value_type == "bool":
@ -265,11 +290,11 @@ class AnsibleDockerClientBase(Client):
return default_value
@abc.abstractmethod
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
pass
@property
def auth_params(self):
def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults.
@ -354,7 +379,7 @@ class AnsibleDockerClientBase(Client):
return result
def _handle_ssl_error(self, error):
def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match:
hostname = self.auth_params["tls_hostname"]
@ -366,7 +391,7 @@ class AnsibleDockerClientBase(Client):
)
self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id):
def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try:
self.log(f"Inspecting container Id {container_id}")
result = self.inspect_container(container=container_id)
@ -377,7 +402,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None):
def get_container(self, name: str | None) -> dict[str, t.Any] | None:
"""
Lookup a container and return the inspection results.
"""
@ -414,7 +439,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None):
def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
"""
Lookup a network and return the inspection results.
"""
@ -453,7 +480,7 @@ class AnsibleDockerClientBase(Client):
return result
def find_image(self, name, tag):
def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None:
"""
Lookup an image (by name and tag) and return the inspection results.
"""
@ -505,7 +532,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.")
return None
def find_image_by_id(self, image_id, accept_missing_image=False):
def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
"""
Lookup an image (by ID) and return the inspection results.
"""
@ -524,7 +553,7 @@ class AnsibleDockerClientBase(Client):
self.fail(f"Error inspecting image ID {image_id} - {exc}")
return inspection
def _image_lookup(self, name, tag):
def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
"""
Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check
@ -547,7 +576,9 @@ class AnsibleDockerClientBase(Client):
break
return images
def pull_image(self, name, tag="latest", image_platform=None):
def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
"""
Pull an image
"""
@ -578,7 +609,7 @@ class AnsibleDockerClientBase(Client):
return new_tag, old_tag == new_tag
def inspect_distribution(self, image, **kwargs):
def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]:
"""
Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
since prior versions did not support accessing private repositories.
@ -592,7 +623,7 @@ class AnsibleDockerClientBase(Client):
self._url("/distribution/{0}/json", image),
headers={"X-Registry-Auth": header},
),
get_json=True,
json=True,
)
return super().inspect_distribution(image, **kwargs)
@ -601,18 +632,24 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(
self,
argument_spec=None,
supports_check_mode=False,
mutually_exclusive=None,
required_together=None,
required_if=None,
required_one_of=None,
required_by=None,
min_docker_version=None,
min_docker_api_version=None,
option_minimal_versions=None,
option_minimal_versions_ignore_params=None,
fail_results=None,
argument_spec: dict[str, t.Any] | None = None,
supports_check_mode: bool = False,
mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together: Sequence[Sequence[str]] | None = None,
required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
):
# Modules can put information in here which will always be returned
@ -625,12 +662,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec
mutually_exclusive_params = []
mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive
required_together_params = []
required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together:
required_together_params += required_together
@ -658,20 +695,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params
)
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name
)
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None):
self.option_minimal_versions = {}
def _get_minimal_versions(
self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec:
if ignore_params is not None:
if option in ignore_params:
@ -722,7 +769,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration."
self.fail(msg)
def report_warnings(self, result, warnings_key=None):
def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
"""
Checks result of client operation for warnings, and if present, outputs them.

View File

@ -11,6 +11,7 @@ from __future__ import annotations
import abc
import os
import re
import typing as t
from collections.abc import Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule, missing_required_lib
@ -28,7 +29,7 @@ try:
)
except ImportError:
# Define an exception class RequestException so that our code does not break.
class RequestException(Exception):
class RequestException(Exception): # type: ignore
pass
@ -60,19 +61,26 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
def _get_tls_config(fail_function, **kwargs):
if t.TYPE_CHECKING:
from collections.abc import Callable
def _get_tls_config(
fail_function: Callable[[str], t.NoReturn], **kwargs: t.Any
) -> TLSConfig:
try:
tls_config = TLSConfig(**kwargs)
return tls_config
return TLSConfig(**kwargs)
except TLSParameterError as exc:
fail_function(f"TLS config error: {exc}")
def is_using_tls(auth_data):
def is_using_tls(auth_data: dict[str, t.Any]) -> bool:
return auth_data["tls_verify"] or auth_data["tls"]
def get_connect_params(auth_data, fail_function):
def get_connect_params(
auth_data: dict[str, t.Any], fail_function: Callable[[str], t.NoReturn]
) -> dict[str, t.Any]:
if is_using_tls(auth_data):
auth_data["docker_host"] = auth_data["docker_host"].replace(
"tcp://", "https://"
@ -114,7 +122,7 @@ def get_connect_params(auth_data, fail_function):
class AnsibleDockerClientBase(Client):
def __init__(self, min_docker_api_version=None):
def __init__(self, min_docker_api_version: str | None = None) -> None:
self._connect_params = get_connect_params(
self.auth_params, fail_function=self.fail
)
@ -138,23 +146,34 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
)
def log(self, msg, pretty_print=False):
def log(self, msg: t.Any, pretty_print: bool = False):
pass
# if self.debug:
# from .util import log_debug
# log_debug(msg, pretty_print=pretty_print)
@abc.abstractmethod
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass
def deprecate(self, msg, version=None, date=None, collection_name=None):
@abc.abstractmethod
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass
@staticmethod
def _get_value(
param_name, param_value, env_variable, default_value, value_type="str"
):
param_name: str,
param_value: t.Any,
env_variable: str | None,
default_value: t.Any | None,
value_type: t.Literal["str", "bool", "int"] = "str",
) -> t.Any:
if param_value is not None:
# take module parameter value
if value_type == "bool":
@ -191,11 +210,11 @@ class AnsibleDockerClientBase(Client):
return default_value
@abc.abstractmethod
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
pass
@property
def auth_params(self):
def auth_params(self) -> dict[str, t.Any]:
# Get authentication credentials.
# Precedence: module parameters-> environment variables-> defaults.
@ -288,7 +307,7 @@ class AnsibleDockerClientBase(Client):
return result
def _handle_ssl_error(self, error):
def _handle_ssl_error(self, error: Exception) -> t.NoReturn:
match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error))
if match:
hostname = self.auth_params["tls_hostname"]
@ -300,7 +319,7 @@ class AnsibleDockerClientBase(Client):
)
self.fail(f"SSL Exception: {error}")
def get_container_by_id(self, container_id):
def get_container_by_id(self, container_id: str) -> dict[str, t.Any] | None:
try:
self.log(f"Inspecting container Id {container_id}")
result = self.get_json("/containers/{0}/json", container_id)
@ -311,7 +330,7 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting container: {exc}")
def get_container(self, name=None):
def get_container(self, name: str | None) -> dict[str, t.Any] | None:
"""
Lookup a container and return the inspection results.
"""
@ -355,7 +374,9 @@ class AnsibleDockerClientBase(Client):
return self.get_container_by_id(result["Id"])
def get_network(self, name=None, network_id=None):
def get_network(
self, name: str | None = None, network_id: str | None = None
) -> dict[str, t.Any] | None:
"""
Lookup a network and return the inspection results.
"""
@ -395,14 +416,14 @@ class AnsibleDockerClientBase(Client):
return result
def _image_lookup(self, name, tag):
def _image_lookup(self, name: str, tag: str | None) -> list[dict[str, t.Any]]:
"""
Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check
if the tag exists.
"""
try:
params = {
params: dict[str, t.Any] = {
"only_ids": 0,
"all": 0,
}
@ -427,7 +448,7 @@ class AnsibleDockerClientBase(Client):
break
return images
def find_image(self, name, tag):
def find_image(self, name: str, tag: str | None) -> dict[str, t.Any] | None:
"""
Lookup an image (by name and tag) and return the inspection results.
"""
@ -478,7 +499,9 @@ class AnsibleDockerClientBase(Client):
self.log(f"Image {name}:{tag} not found.")
return None
def find_image_by_id(self, image_id, accept_missing_image=False):
def find_image_by_id(
self, image_id: str, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
"""
Lookup an image (by ID) and return the inspection results.
"""
@ -496,7 +519,9 @@ class AnsibleDockerClientBase(Client):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error inspecting image ID {image_id} - {exc}")
def pull_image(self, name, tag="latest", image_platform=None):
def pull_image(
self, name: str, tag: str = "latest", image_platform: str | None = None
) -> tuple[dict[str, t.Any] | None, bool]:
"""
Pull an image
"""
@ -544,22 +569,26 @@ class AnsibleDockerClientBase(Client):
class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(
self,
argument_spec=None,
supports_check_mode=False,
mutually_exclusive=None,
required_together=None,
required_if=None,
required_one_of=None,
required_by=None,
min_docker_api_version=None,
option_minimal_versions=None,
option_minimal_versions_ignore_params=None,
fail_results=None,
argument_spec: dict[str, t.Any] | None = None,
supports_check_mode: bool = False,
mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together: Sequence[Sequence[str]] | None = None,
required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: dict[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
option_minimal_versions: dict[str, t.Any] | None = None,
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
):
# Modules can put information in here which will always be returned
# in case client.fail() is called.
self.fail_results = fail_results or {}
@ -570,12 +599,12 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec
mutually_exclusive_params = []
mutually_exclusive_params: list[Sequence[str]] = []
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive
required_together_params = []
required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together:
required_together_params += required_together
@ -600,20 +629,30 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions, option_minimal_versions_ignore_params
)
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name
)
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
return self.module.params
def _get_minimal_versions(self, option_minimal_versions, ignore_params=None):
self.option_minimal_versions = {}
def _get_minimal_versions(
self,
option_minimal_versions: dict[str, t.Any],
ignore_params: Sequence[str] | None = None,
) -> None:
self.option_minimal_versions: dict[str, dict[str, t.Any]] = {}
for option in self.module.argument_spec:
if ignore_params is not None:
if option in ignore_params:
@ -654,7 +693,9 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
msg = f"Cannot {usg} with your configuration."
self.fail(msg)
def report_warnings(self, result, warnings_key=None):
def report_warnings(
self, result: t.Any, warnings_key: Sequence[str] | None = None
) -> None:
"""
Checks result of client operation for warnings, and if present, outputs them.

View File

@ -10,6 +10,7 @@ from __future__ import annotations
import abc
import json
import shlex
import typing as t
from ansible.module_utils.basic import AnsibleModule, env_fallback
from ansible.module_utils.common.process import get_bin_path
@ -31,6 +32,10 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
)
if t.TYPE_CHECKING:
from collections.abc import Mapping, Sequence
DOCKER_COMMON_ARGS = {
"docker_cli": {"type": "path"},
"docker_host": {
@ -72,10 +77,16 @@ class DockerException(Exception):
class AnsibleDockerClientBase:
docker_api_version_str: str | None
docker_api_version: LooseVersion | None
def __init__(
self, common_args, min_docker_api_version=None, needs_api_version=True
):
self._environment = {}
self,
common_args,
min_docker_api_version: str | None = None,
needs_api_version: bool = True,
) -> None:
self._environment: dict[str, str] = {}
if common_args["tls_hostname"]:
self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"]
if common_args["api_version"] and common_args["api_version"] != "auto":
@ -109,10 +120,10 @@ class AnsibleDockerClientBase:
self._cli_base.extend(["--context", common_args["cli_context"]])
# `--format json` was only added as a shorthand for `--format {{ json . }}` in Docker 23.0
dummy, self._version, dummy = self.call_cli_json(
dummy, self._version, dummy2 = self.call_cli_json(
"version", "--format", "{{ json . }}", check_rc=True
)
self._info = None
self._info: dict[str, t.Any] | None = None
if needs_api_version:
if not isinstance(self._version.get("Server"), dict) or not isinstance(
@ -138,32 +149,47 @@ class AnsibleDockerClientBase:
"Internal error: cannot have needs_api_version=False with min_docker_api_version not None"
)
def log(self, msg, pretty_print=False):
def log(self, msg: str, pretty_print: bool = False):
pass
# if self.debug:
# from .util import log_debug
# log_debug(msg, pretty_print=pretty_print)
def get_cli(self):
def get_cli(self) -> str:
return self._cli
def get_version_info(self):
def get_version_info(self) -> str:
return self._version
def _compose_cmd(self, args):
def _compose_cmd(self, args: t.Sequence[str]) -> list[str]:
return self._cli_base + list(args)
def _compose_cmd_str(self, args):
def _compose_cmd_str(self, args: t.Sequence[str]) -> str:
return " ".join(shlex.quote(a) for a in self._compose_cmd(args))
@abc.abstractmethod
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None):
def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
pass
# def call_cli_json(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False):
def call_cli_json(self, *args, **kwargs):
warn_on_stderr = kwargs.pop("warn_on_stderr", False)
rc, stdout, stderr = self.call_cli(*args, **kwargs)
def call_cli_json(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, t.Any, bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr:
self.warn(to_native(stderr))
try:
@ -174,10 +200,18 @@ class AnsibleDockerClientBase:
)
return rc, data, stderr
# def call_cli_json_stream(self, *args, check_rc=False, data=None, cwd=None, environ_update=None, warn_on_stderr=False):
def call_cli_json_stream(self, *args, **kwargs):
warn_on_stderr = kwargs.pop("warn_on_stderr", False)
rc, stdout, stderr = self.call_cli(*args, **kwargs)
def call_cli_json_stream(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
warn_on_stderr: bool = False,
) -> tuple[int, list[t.Any], bytes]:
rc, stdout, stderr = self.call_cli(
*args, check_rc=check_rc, data=data, cwd=cwd, environ_update=environ_update
)
if warn_on_stderr and stderr:
self.warn(to_native(stderr))
result = []
@ -193,25 +227,31 @@ class AnsibleDockerClientBase:
return rc, result, stderr
@abc.abstractmethod
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs) -> t.NoReturn:
pass
@abc.abstractmethod
def warn(self, msg):
def warn(self, msg: str) -> None:
pass
@abc.abstractmethod
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
pass
def get_cli_info(self):
def get_cli_info(self) -> dict[str, t.Any]:
if self._info is None:
dummy, self._info, dummy = self.call_cli_json(
dummy, self._info, dummy2 = self.call_cli_json(
"info", "--format", "{{ json . }}", check_rc=True
)
return self._info
def get_client_plugin_info(self, component):
def get_client_plugin_info(self, component: str) -> dict[str, t.Any] | None:
cli_info = self.get_cli_info()
if not isinstance(cli_info.get("ClientInfo"), dict):
self.fail(
@ -222,13 +262,13 @@ class AnsibleDockerClientBase:
return plugin
return None
def _image_lookup(self, name, tag):
def _image_lookup(self, name: str, tag: str) -> list[dict[str, t.Any]]:
"""
Including a tag in the name parameter sent to the Docker SDK for Python images method
does not work consistently. Instead, get the result set for name and manually check
if the tag exists.
"""
dummy, images, dummy = self.call_cli_json_stream(
dummy, images, dummy2 = self.call_cli_json_stream(
"image",
"ls",
"--format",
@ -247,7 +287,13 @@ class AnsibleDockerClientBase:
break
return images
def find_image(self, name, tag):
@t.overload
def find_image(self, name: None, tag: str) -> None: ...
@t.overload
def find_image(self, name: str, tag: str) -> dict[str, t.Any] | None: ...
def find_image(self, name: str | None, tag: str) -> dict[str, t.Any] | None:
"""
Lookup an image (by name and tag) and return the inspection results.
"""
@ -298,7 +344,19 @@ class AnsibleDockerClientBase:
self.log(f"Image {name}:{tag} not found.")
return None
def find_image_by_id(self, image_id, accept_missing_image=False):
@t.overload
def find_image_by_id(
self, image_id: None, accept_missing_image: bool = False
) -> None: ...
@t.overload
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None: ...
def find_image_by_id(
self, image_id: str | None, accept_missing_image: bool = False
) -> dict[str, t.Any] | None:
"""
Lookup an image (by ID) and return the inspection results.
"""
@ -320,17 +378,23 @@ class AnsibleDockerClientBase:
class AnsibleModuleDockerClient(AnsibleDockerClientBase):
def __init__(
self,
argument_spec=None,
supports_check_mode=False,
mutually_exclusive=None,
required_together=None,
required_if=None,
required_one_of=None,
required_by=None,
min_docker_api_version=None,
fail_results=None,
needs_api_version=True,
):
argument_spec: dict[str, t.Any] | None = None,
supports_check_mode: bool = False,
mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together: Sequence[Sequence[str]] | None = None,
required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_by: Mapping[str, Sequence[str]] | None = None,
min_docker_api_version: str | None = None,
fail_results: dict[str, t.Any] | None = None,
needs_api_version: bool = True,
) -> None:
# Modules can put information in here which will always be returned
# in case client.fail() is called.
@ -342,12 +406,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
merged_arg_spec.update(argument_spec)
self.arg_spec = merged_arg_spec
mutually_exclusive_params = [("docker_host", "cli_context")]
mutually_exclusive_params: list[Sequence[str]] = [
("docker_host", "cli_context")
]
mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE
if mutually_exclusive:
mutually_exclusive_params += mutually_exclusive
required_together_params = []
required_together_params: list[Sequence[str]] = []
required_together_params += DOCKER_REQUIRED_TOGETHER
if required_together:
required_together_params += required_together
@ -373,7 +439,14 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
needs_api_version=needs_api_version,
)
def call_cli(self, *args, check_rc=False, data=None, cwd=None, environ_update=None):
def call_cli(
self,
*args: str,
check_rc: bool = False,
data: bytes | None = None,
cwd: str | None = None,
environ_update: dict[str, str] | None = None,
) -> tuple[int, bytes, bytes]:
environment = self._environment.copy()
if environ_update:
environment.update(environ_update)
@ -390,14 +463,20 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
)
return rc, stdout, stderr
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs) -> t.NoReturn:
self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))
def warn(self, msg):
def warn(self, msg: str) -> None:
self.module.warn(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.module.deprecate(
msg, version=version, date=date, collection_name=collection_name
)

View File

@ -14,6 +14,7 @@ import re
import shutil
import tempfile
import traceback
import typing as t
from collections import namedtuple
from shlex import quote
@ -34,6 +35,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
)
PYYAML_IMPORT_ERROR: None | str # pylint: disable=invalid-name
try:
import yaml
@ -41,7 +43,7 @@ try:
# use C version if possible for speedup
from yaml import CSafeDumper as _SafeDumper
except ImportError:
from yaml import SafeDumper as _SafeDumper
from yaml import SafeDumper as _SafeDumper # type: ignore
except ImportError:
HAS_PYYAML = False
PYYAML_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name
@ -49,6 +51,13 @@ else:
HAS_PYYAML = True
PYYAML_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ansible_collections.community.docker.plugins.module_utils._common_cli import (
AnsibleModuleDockerClient as _Client,
)
DOCKER_COMPOSE_FILES = (
"compose.yaml",
@ -144,8 +153,7 @@ class ResourceType:
SERVICE = "service"
@classmethod
def from_docker_compose_event(cls, resource_type):
# type: (Type[ResourceType], Text) -> Any
def from_docker_compose_event(cls, resource_type: str) -> t.Any:
return {
"Network": cls.NETWORK,
"Image": cls.IMAGE,
@ -240,7 +248,9 @@ _RE_BUILD_PROGRESS_EVENT = re.compile(r"^\s*==>\s+(?P<msg>.*)$")
MINIMUM_COMPOSE_VERSION = "2.18.0"
def _extract_event(line, warn_function=None):
def _extract_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
match = _RE_RESOURCE_EVENT.match(line)
if match is not None:
status = match.group("status")
@ -323,7 +333,9 @@ def _extract_event(line, warn_function=None):
return None, False
def _extract_logfmt_event(line, warn_function=None):
def _extract_logfmt_event(
line: str, warn_function: Callable[[str], None] | None = None
) -> tuple[Event | None, bool]:
try:
result = _parse_logfmt_line(line, logrus_mode=True)
except _InvalidLogFmt:
@ -338,7 +350,11 @@ def _extract_logfmt_event(line, warn_function=None):
return None, False
def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_function):
def _warn_missing_dry_run_prefix(
line: str,
warn_missing_dry_run_prefix: bool,
warn_function: Callable[[str], None] | None,
) -> None:
if warn_missing_dry_run_prefix and warn_function:
# This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-)
@ -349,7 +365,9 @@ def _warn_missing_dry_run_prefix(line, warn_missing_dry_run_prefix, warn_functio
)
def _warn_unparsable_line(line, warn_function):
def _warn_unparsable_line(
line: str, warn_function: Callable[[str], None] | None
) -> None:
# This could be a bug, a change of docker compose's output format, ...
# Tell the user to report it to us :-)
if warn_function:
@ -360,14 +378,16 @@ def _warn_unparsable_line(line, warn_function):
)
def _find_last_event_for(events, resource_id):
def _find_last_event_for(
events: list[Event], resource_id: str
) -> tuple[int, Event] | None:
for index, event in enumerate(reversed(events)):
if event.resource_id == resource_id:
return len(events) - 1 - index, event
return None
def _concat_event_msg(event, append_msg):
def _concat_event_msg(event: Event, append_msg: str) -> Event:
return Event(
event.resource_type,
event.resource_id,
@ -382,7 +402,9 @@ _JSON_LEVEL_TO_STATUS_MAP = {
}
def parse_json_events(stderr, warn_function=None):
def parse_json_events(
stderr: bytes, warn_function: Callable[[str], None] | None = None
) -> list[Event]:
events = []
stderr_lines = stderr.splitlines()
if stderr_lines and stderr_lines[-1] == b"":
@ -523,7 +545,12 @@ def parse_json_events(stderr, warn_function=None):
return events
def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False):
def parse_events(
stderr: bytes,
dry_run: bool = False,
warn_function: Callable[[str], None] | None = None,
nonzero_rc: bool = False,
) -> list[Event]:
events = []
error_event = None
stderr_lines = stderr.splitlines()
@ -597,7 +624,11 @@ def parse_events(stderr, dry_run=False, warn_function=None, nonzero_rc=False):
return events
def has_changes(events, ignore_service_pull_events=False, ignore_build_events=False):
def has_changes(
events: Sequence[Event],
ignore_service_pull_events: bool = False,
ignore_build_events: bool = False,
) -> bool:
for event in events:
if event.status in DOCKER_STATUS_WORKING:
if ignore_service_pull_events and event.status in DOCKER_STATUS_PULL:
@ -613,7 +644,7 @@ def has_changes(events, ignore_service_pull_events=False, ignore_build_events=Fa
return False
def extract_actions(events):
def extract_actions(events: Sequence[Event]) -> list[dict[str, t.Any]]:
actions = []
pull_actions = set()
for event in events:
@ -645,7 +676,9 @@ def extract_actions(events):
return actions
def emit_warnings(events, warn_function):
def emit_warnings(
events: Sequence[Event], warn_function: Callable[[str], None]
) -> None:
for event in events:
# If a message is present, assume it is a warning
if (
@ -656,13 +689,21 @@ def emit_warnings(events, warn_function):
)
def is_failed(events, rc):
def is_failed(events: Sequence[Event], rc: int) -> bool:
if rc:
return True
return False
def update_failed(result, events, args, stdout, stderr, rc, cli):
def update_failed(
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: str | bytes,
rc: int,
cli: str,
) -> bool:
if not rc:
return False
errors = []
@ -696,7 +737,7 @@ def update_failed(result, events, args, stdout, stderr, rc, cli):
return True
def common_compose_argspec():
def common_compose_argspec() -> dict[str, t.Any]:
return {
"project_src": {"type": "path"},
"project_name": {"type": "str"},
@ -708,7 +749,7 @@ def common_compose_argspec():
}
def common_compose_argspec_ex():
def common_compose_argspec_ex() -> dict[str, t.Any]:
return {
"argspec": common_compose_argspec(),
"mutually_exclusive": [("definition", "project_src"), ("definition", "files")],
@ -721,16 +762,18 @@ def common_compose_argspec_ex():
}
def combine_binary_output(*outputs):
def combine_binary_output(*outputs: bytes | None) -> bytes:
return b"\n".join(out for out in outputs if out)
def combine_text_output(*outputs):
def combine_text_output(*outputs: str | None) -> str:
return "\n".join(out for out in outputs if out)
class BaseComposeManager(DockerBaseClass):
def __init__(self, client, min_version=MINIMUM_COMPOSE_VERSION):
def __init__(
self, client: _Client, min_version: str = MINIMUM_COMPOSE_VERSION
) -> None:
super().__init__()
self.client = client
self.check_mode = self.client.check_mode
@ -794,12 +837,12 @@ class BaseComposeManager(DockerBaseClass):
# more precisely in https://github.com/docker/compose/pull/11478
self.use_json_events = self.compose_version >= LooseVersion("2.29.0")
def get_compose_version(self):
def get_compose_version(self) -> str:
return (
self.get_compose_version_from_cli() or self.get_compose_version_from_api()
)
def get_compose_version_from_cli(self):
def get_compose_version_from_cli(self) -> str | None:
rc, version_info, dummy_stderr = self.client.call_cli(
"compose", "version", "--format", "json"
)
@ -813,7 +856,7 @@ class BaseComposeManager(DockerBaseClass):
except Exception: # pylint: disable=broad-exception-caught
return None
def get_compose_version_from_api(self):
def get_compose_version_from_api(self) -> str:
compose = self.client.get_client_plugin_info("compose")
if compose is None:
self.fail(
@ -826,11 +869,11 @@ class BaseComposeManager(DockerBaseClass):
)
return compose["Version"].lstrip("v")
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.cleanup()
self.client.fail(msg, **kwargs)
def get_base_args(self, plain_progress=False):
def get_base_args(self, plain_progress: bool = False) -> list[str]:
args = ["compose", "--ansi", "never"]
if self.use_json_events and not plain_progress:
args.extend(["--progress", "json"])
@ -848,28 +891,33 @@ class BaseComposeManager(DockerBaseClass):
args.extend(["--profile", profile])
return args
def _handle_failed_cli_call(self, args, rc, stdout, stderr):
def _handle_failed_cli_call(
self, args: list[str], rc: int, stdout: str | bytes, stderr: bytes
) -> t.NoReturn:
events = parse_json_events(stderr, warn_function=self.client.warn)
result = {}
result: dict[str, t.Any] = {}
self.update_failed(result, events, args, stdout, stderr, rc)
self.client.module.exit_json(**result)
def list_containers_raw(self):
def list_containers_raw(self) -> list[dict[str, t.Any]]:
args = self.get_base_args() + ["ps", "--format", "json", "--all"]
if self.compose_version >= LooseVersion("2.23.0"):
# https://github.com/docker/compose/pull/11038
args.append("--no-trunc")
kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events}
if self.compose_version >= LooseVersion("2.21.0"):
# Breaking change in 2.21.0: https://github.com/docker/compose/pull/10918
rc, containers, stderr = self.client.call_cli_json_stream(*args, **kwargs)
rc, containers, stderr = self.client.call_cli_json_stream(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
else:
rc, containers, stderr = self.client.call_cli_json(*args, **kwargs)
rc, containers, stderr = self.client.call_cli_json(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
if self.use_json_events and rc != 0:
self._handle_failed_cli_call(args, rc, containers, stderr)
self._handle_failed_cli_call(args, rc, json.dumps(containers), stderr)
return containers
def list_containers(self):
def list_containers(self) -> list[dict[str, t.Any]]:
result = []
for container in self.list_containers_raw():
labels = {}
@ -886,10 +934,11 @@ class BaseComposeManager(DockerBaseClass):
result.append(container)
return result
def list_images(self):
def list_images(self) -> list[str]:
args = self.get_base_args() + ["images", "--format", "json"]
kwargs = {"cwd": self.project_src, "check_rc": not self.use_json_events}
rc, images, stderr = self.client.call_cli_json(*args, **kwargs)
rc, images, stderr = self.client.call_cli_json(
*args, cwd=self.project_src, check_rc=not self.use_json_events
)
if self.use_json_events and rc != 0:
self._handle_failed_cli_call(args, rc, images, stderr)
if isinstance(images, dict):
@ -899,7 +948,9 @@ class BaseComposeManager(DockerBaseClass):
images = list(images.values())
return images
def parse_events(self, stderr, dry_run=False, nonzero_rc=False):
def parse_events(
self, stderr: bytes, dry_run: bool = False, nonzero_rc: bool = False
) -> list[Event]:
if self.use_json_events:
return parse_json_events(stderr, warn_function=self.client.warn)
return parse_events(
@ -909,17 +960,17 @@ class BaseComposeManager(DockerBaseClass):
nonzero_rc=nonzero_rc,
)
def emit_warnings(self, events):
def emit_warnings(self, events: Sequence[Event]) -> None:
emit_warnings(events, warn_function=self.client.warn)
def update_result(
self,
result,
events,
stdout,
stderr,
ignore_service_pull_events=False,
ignore_build_events=False,
result: dict[str, t.Any],
events: Sequence[Event],
stdout: str | bytes,
stderr: str | bytes,
ignore_service_pull_events: bool = False,
ignore_build_events: bool = False,
):
result["changed"] = result.get("changed", False) or has_changes(
events,
@ -930,7 +981,15 @@ class BaseComposeManager(DockerBaseClass):
result["stdout"] = combine_text_output(result.get("stdout"), to_native(stdout))
result["stderr"] = combine_text_output(result.get("stderr"), to_native(stderr))
def update_failed(self, result, events, args, stdout, stderr, rc):
def update_failed(
self,
result: dict[str, t.Any],
events: Sequence[Event],
args: list[str],
stdout: str | bytes,
stderr: bytes,
rc: int,
):
return update_failed(
result,
events,
@ -941,14 +1000,14 @@ class BaseComposeManager(DockerBaseClass):
cli=self.client.get_cli(),
)
def cleanup_result(self, result):
def cleanup_result(self, result: dict[str, t.Any]) -> None:
if not result.get("failed"):
# Only return stdout and stderr if it is not empty
for res in ("stdout", "stderr"):
if result.get(res) == "":
result.pop(res)
def cleanup(self):
def cleanup(self) -> None:
for directory in self.cleanup_dirs:
try:
shutil.rmtree(directory, True)

View File

@ -16,6 +16,7 @@ import os.path
import shutil
import stat
import tarfile
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@ -25,6 +26,16 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
)
if t.TYPE_CHECKING:
from collections.abc import Callable
from _typeshed import WriteableBuffer
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class DockerFileCopyError(Exception):
pass
@ -37,7 +48,9 @@ class DockerFileNotFound(DockerFileCopyError):
pass
def _put_archive(client, container, path, data):
def _put_archive(
client: APIClient, container: str, path: str, data: bytes | t.Generator[bytes]
) -> bool:
# data can also be file object for streaming. This is because _put uses requests's put().
# See https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads
url = client._url("/containers/{0}/archive", container)
@ -47,8 +60,14 @@ def _put_archive(client, container, path, data):
def _symlink_tar_creator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None
):
b_in_path: bytes,
file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> bytes:
if not stat.S_ISLNK(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a symlink")
bio = io.BytesIO()
@ -75,16 +94,28 @@ def _symlink_tar_creator(
def _symlink_tar_generator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None
):
b_in_path: bytes,
file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> t.Generator[bytes]:
yield _symlink_tar_creator(
b_in_path, file_stat, out_file, user_id, group_id, mode, user_name
)
def _regular_file_tar_generator(
b_in_path, file_stat, out_file, user_id, group_id, mode=None, user_name=None
):
b_in_path: bytes,
file_stat: os.stat_result,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
) -> t.Generator[bytes]:
if not stat.S_ISREG(file_stat.st_mode):
raise DockerUnexpectedError("stat information is not for a regular file")
tarinfo = tarfile.TarInfo()
@ -136,8 +167,13 @@ def _regular_file_tar_generator(
def _regular_content_tar_generator(
content, out_file, user_id, group_id, mode, user_name=None
):
content: bytes,
out_file: str | bytes,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> t.Generator[bytes]:
tarinfo = tarfile.TarInfo()
tarinfo.name = (
os.path.splitdrive(to_text(out_file))[1].replace(os.sep, "/").lstrip("/")
@ -175,16 +211,16 @@ def _regular_content_tar_generator(
def put_file(
client,
container,
in_path,
out_path,
user_id,
group_id,
mode=None,
user_name=None,
follow_links=False,
):
client: APIClient,
container: str,
in_path: str,
out_path: str,
user_id: int,
group_id: int,
mode: int | None = None,
user_name: str | None = None,
follow_links: bool = False,
) -> None:
"""Transfer a file from local to Docker container."""
if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")):
raise DockerFileNotFound(f"file or module does not exist: {to_native(in_path)}")
@ -232,8 +268,15 @@ def put_file(
def put_file_content(
client, container, content, out_path, user_id, group_id, mode, user_name=None
):
client: APIClient,
container: str,
content: bytes,
out_path: str,
user_id: int,
group_id: int,
mode: int,
user_name: str | None = None,
) -> None:
"""Transfer a file from local to Docker container."""
out_dir, out_file = os.path.split(out_path)
@ -248,7 +291,13 @@ def put_file_content(
)
def stat_file(client, container, in_path, follow_links=False, log=None):
def stat_file(
client: APIClient,
container: str,
in_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> tuple[str, dict[str, t.Any] | None, str | None]:
"""Fetch information on a file from a Docker container to local.
Return a tuple ``(path, stat_data, link_target)`` where:
@ -265,12 +314,12 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
while True:
if in_path in considered_in_paths:
raise DockerFileCopyError(
f'Found infinite symbolic link loop when trying to stating "{in_path}"'
f"Found infinite symbolic link loop when trying to stating {in_path!r}"
)
considered_in_paths.add(in_path)
if log:
log(f'FETCH: Stating "{in_path}"')
log(f"FETCH: Stating {in_path!r}")
response = client._head(
client._url("/containers/{0}/archive", container),
@ -299,24 +348,24 @@ def stat_file(client, container, in_path, follow_links=False, log=None):
class _RawGeneratorFileobj(io.RawIOBase):
def __init__(self, stream):
def __init__(self, stream: t.Generator[bytes]):
self._stream = stream
self._buf = b""
def readable(self):
def readable(self) -> bool:
return True
def _readinto_from_buf(self, b, index, length):
def _readinto_from_buf(self, b: WriteableBuffer, index: int, length: int) -> int:
cpy = min(length - index, len(self._buf))
if cpy:
b[index : index + cpy] = self._buf[:cpy]
b[index : index + cpy] = self._buf[:cpy] # type: ignore # TODO!
self._buf = self._buf[cpy:]
index += cpy
return index
def readinto(self, b):
def readinto(self, b: WriteableBuffer) -> int:
index = 0
length = len(b)
length = len(b) # type: ignore # TODO!
index = self._readinto_from_buf(b, index, length)
if index == length:
@ -330,25 +379,28 @@ class _RawGeneratorFileobj(io.RawIOBase):
return self._readinto_from_buf(b, index, length)
def _stream_generator_to_fileobj(stream):
def _stream_generator_to_fileobj(stream: t.Generator[bytes]) -> io.BufferedReader:
"""Given a generator that generates chunks of bytes, create a readable buffered stream."""
raw = _RawGeneratorFileobj(stream)
return io.BufferedReader(raw)
_T = t.TypeVar("_T")
def fetch_file_ex(
client,
container,
in_path,
process_none,
process_regular,
process_symlink,
process_other,
follow_links=False,
log=None,
):
client: APIClient,
container: str,
in_path: str,
process_none: Callable[[str], _T],
process_regular: Callable[[str, tarfile.TarFile, tarfile.TarInfo], _T],
process_symlink: Callable[[str, tarfile.TarInfo], _T],
process_other: Callable[[str, tarfile.TarInfo], _T],
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> _T:
"""Fetch a file (as a tar file entry) from a Docker container to local."""
considered_in_paths = set()
considered_in_paths: set[str] = set()
while True:
if in_path in considered_in_paths:
@ -372,8 +424,8 @@ def fetch_file_ex(
with tarfile.open(
fileobj=_stream_generator_to_fileobj(stream), mode="r|"
) as tar:
symlink_member = None
result = None
symlink_member: tarfile.TarInfo | None = None
result: _T | None = None
found = False
for member in tar:
if found:
@ -398,35 +450,46 @@ def fetch_file_ex(
log(f'FETCH: Following symbolic link to "{in_path}"')
continue
if found:
return result
return result # type: ignore
raise DockerUnexpectedError("Received tarfile is empty!")
def fetch_file(client, container, in_path, out_path, follow_links=False, log=None):
def fetch_file(
client: APIClient,
container: str,
in_path: str,
out_path: str,
follow_links: bool = False,
log: Callable[[str], None] | None = None,
) -> str:
b_out_path = to_bytes(out_path, errors="surrogate_or_strict")
def process_none(in_path):
def process_none(in_path: str) -> str:
raise DockerFileNotFound(
f"File {in_path} does not exist in container {container}"
)
def process_regular(in_path, tar, member):
def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> str:
if not follow_links and os.path.exists(b_out_path):
os.unlink(b_out_path)
with tar.extractfile(member) as in_f:
with open(b_out_path, "wb") as out_f:
shutil.copyfileobj(in_f, out_f)
reader = tar.extractfile(member)
if reader:
with reader as in_f:
with open(b_out_path, "wb") as out_f:
shutil.copyfileobj(in_f, out_f)
return in_path
def process_symlink(in_path, member):
def process_symlink(in_path, member) -> str:
if os.path.exists(b_out_path):
os.unlink(b_out_path)
os.symlink(member.linkname, b_out_path)
return in_path
def process_other(in_path, member):
def process_other(in_path, member) -> str:
raise DockerFileCopyError(
f'Remote file "{in_path}" is not a regular file or a symbolic link'
)
@ -444,7 +507,13 @@ def fetch_file(client, container, in_path, out_path, follow_links=False, log=Non
)
def _execute_command(client, container, command, log=None, check_rc=False):
def _execute_command(
client: APIClient,
container: str,
command: list[str],
log: Callable[[str], None] | None = None,
check_rc: bool = False,
) -> tuple[int, bytes, bytes]:
if log:
log(f"Executing {command} in {container}")
@ -483,7 +552,7 @@ def _execute_command(client, container, command, log=None, check_rc=False):
result = client.get_json("/exec/{0}/json", exec_id)
rc = result.get("ExitCode") or 0
rc: int = result.get("ExitCode") or 0
stdout = stdout or b""
stderr = stderr or b""
@ -493,13 +562,15 @@ def _execute_command(client, container, command, log=None, check_rc=False):
if check_rc and rc != 0:
command_str = " ".join(command)
raise DockerUnexpectedError(
f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout}\nSTDERR: {stderr}'
f'Obtained unexpected exit code {rc} when running "{command_str}" in {container}.\nSTDOUT: {stdout!r}\nSTDERR: {stderr!r}'
)
return rc, stdout, stderr
def determine_user_group(client, container, log=None):
def determine_user_group(
client: APIClient, container: str, log: Callable[[str], None] | None = None
) -> tuple[int, int]:
dummy_rc, stdout, dummy_stderr = _execute_command(
client, container, ["/bin/sh", "-c", "id -u && id -g"], check_rc=True, log=log
)
@ -507,7 +578,7 @@ def determine_user_group(client, container, log=None):
stdout_lines = stdout.splitlines()
if len(stdout_lines) != 2:
raise DockerUnexpectedError(
f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout}"
f"Expected two-line output to obtain user and group ID for container {container}, but got {len(stdout_lines)} lines:\n{stdout!r}"
)
user_id, group_id = stdout_lines
@ -515,5 +586,5 @@ def determine_user_group(client, container, log=None):
return int(user_id), int(group_id)
except ValueError as exc:
raise DockerUnexpectedError(
f'Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got "{user_id}" and "{group_id}" instead'
f"Expected two-line output with numeric IDs to obtain user and group ID for container {container}, but got {user_id!r} and {group_id!r} instead"
) from exc

View File

@ -18,12 +18,10 @@ class ImageArchiveManifestSummary:
"docker image save some:tag > some.tar" command.
"""
def __init__(self, image_id, repo_tags):
def __init__(self, image_id: str, repo_tags: list[str]) -> None:
"""
:param image_id: File name portion of Config entry, e.g. abcde12345 from abcde12345.json
:type image_id: str
:param repo_tags Docker image names, e.g. ["hello-world:latest"]
:type repo_tags: list[str]
"""
self.image_id = image_id
@ -34,22 +32,21 @@ class ImageArchiveInvalidException(Exception):
pass
def api_image_id(archive_image_id):
def api_image_id(archive_image_id: str) -> str:
"""
Accepts an image hash in the format stored in manifest.json, and returns an equivalent identifier
that represents the same image hash, but in the format presented by the Docker Engine API.
:param archive_image_id: plain image hash
:type archive_image_id: str
:returns: Prefixed hash used by REST api
:rtype: str
"""
return f"sha256:{archive_image_id}"
def load_archived_image_manifest(archive_path):
def load_archived_image_manifest(
archive_path: str,
) -> list[ImageArchiveManifestSummary] | None:
"""
Attempts to get image IDs and image names from metadata stored in the image
archive tar file.
@ -62,10 +59,7 @@ def load_archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or a list of ImageArchiveManifestSummary objects.
:rtype: ImageArchiveManifestSummary
"""
try:
@ -76,8 +70,15 @@ def load_archived_image_manifest(archive_path):
with tarfile.open(archive_path, "r") as tf:
try:
try:
with tf.extractfile("manifest.json") as ef:
reader = tf.extractfile("manifest.json")
if reader is None:
raise ImageArchiveInvalidException(
"Failed to read manifest.json"
)
with reader as ef:
manifest = json.load(ef)
except ImageArchiveInvalidException:
raise
except Exception as exc:
raise ImageArchiveInvalidException(
f"Failed to decode and deserialize manifest.json: {exc}"
@ -139,7 +140,7 @@ def load_archived_image_manifest(archive_path):
) from exc
def archived_image_manifest(archive_path):
def archived_image_manifest(archive_path: str) -> ImageArchiveManifestSummary | None:
"""
Attempts to get Image.Id and image name from metadata stored in the image
archive tar file.
@ -152,10 +153,7 @@ def archived_image_manifest(archive_path):
ImageArchiveInvalidException: A file already exists at archive_path, but could not extract an image ID from it.
:param archive_path: Tar file to read
:type archive_path: str
:return: None, if no file at archive_path, or the extracted image ID, which will not have a sha256: prefix.
:rtype: ImageArchiveManifestSummary
"""
results = load_archived_image_manifest(archive_path)

View File

@ -13,6 +13,9 @@ See https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc for information on
from __future__ import annotations
import typing as t
from enum import Enum
# The format is defined in https://pkg.go.dev/github.com/kr/logfmt?utm_source=godoc
# (look for "EBNFish")
@ -22,7 +25,7 @@ class InvalidLogFmt(Exception):
pass
class _Mode:
class _Mode(Enum):
GARBAGE = 0
KEY = 1
EQUAL = 2
@ -68,29 +71,29 @@ _HEX_DICT = {
}
def _is_ident(cur):
def _is_ident(cur: str) -> bool:
return cur > " " and cur not in ('"', "=")
class _Parser:
def __init__(self, line):
def __init__(self, line: str) -> None:
self.line = line
self.index = 0
self.length = len(line)
def done(self):
def done(self) -> bool:
return self.index >= self.length
def cur(self):
def cur(self) -> str:
return self.line[self.index]
def next(self):
def next(self) -> None:
self.index += 1
def prev(self):
def prev(self) -> None:
self.index -= 1
def parse_unicode_sequence(self):
def parse_unicode_sequence(self) -> str:
if self.index + 6 > self.length:
raise InvalidLogFmt("Not enough space for unicode escape")
if self.line[self.index : self.index + 2] != "\\u":
@ -108,27 +111,27 @@ class _Parser:
return chr(v)
def parse_line(line, logrus_mode=False):
result = {}
def parse_line(line: str, logrus_mode: bool = False) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
parser = _Parser(line)
key = []
value = []
key: list[str] = []
value: list[str] = []
mode = _Mode.GARBAGE
def handle_kv(has_no_value=False):
def handle_kv(has_no_value: bool = False) -> None:
k = "".join(key)
v = None if has_no_value else "".join(value)
result[k] = v
del key[:]
del value[:]
def parse_garbage(cur):
def parse_garbage(cur: str) -> _Mode:
if _is_ident(cur):
return _Mode.KEY
parser.next()
return _Mode.GARBAGE
def parse_key(cur):
def parse_key(cur: str) -> _Mode:
if _is_ident(cur):
key.append(cur)
parser.next()
@ -142,7 +145,7 @@ def parse_line(line, logrus_mode=False):
parser.next()
return _Mode.GARBAGE
def parse_equal(cur):
def parse_equal(cur: str) -> _Mode:
if _is_ident(cur):
value.append(cur)
parser.next()
@ -154,7 +157,7 @@ def parse_line(line, logrus_mode=False):
parser.next()
return _Mode.GARBAGE
def parse_ident_value(cur):
def parse_ident_value(cur: str) -> _Mode:
if _is_ident(cur):
value.append(cur)
parser.next()
@ -163,7 +166,7 @@ def parse_line(line, logrus_mode=False):
parser.next()
return _Mode.GARBAGE
def parse_quoted_value(cur):
def parse_quoted_value(cur: str) -> _Mode:
if cur == "\\":
parser.next()
if parser.done():

View File

@ -12,6 +12,7 @@ import abc
import os
import re
import shlex
import typing as t
from functools import partial
from ansible.module_utils.common.text.converters import to_text
@ -32,6 +33,23 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
if t.TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.docker.plugins.module_utils._version import (
LooseVersion,
)
ValueType = t.Literal["set", "list", "dict", "bool", "int", "float", "str"]
AnsibleType = t.Literal["list", "dict", "bool", "int", "float", "str"]
ComparisonMode = t.Literal["ignore", "strict", "allow_more_present"]
ComparisonType = t.Literal["set", "set(dict)", "list", "dict", "value"]
Client = t.TypeVar("Client")
_DEFAULT_IP_REPLACEMENT_STRING = (
"[[DEFAULT_IP:iewahhaeB4Sae6Aen8IeShairoh4zeph7xaekoh8Geingunaesaeweiy3ooleiwi]]"
)
@ -54,7 +72,9 @@ _MOUNT_OPTION_TYPES = {
}
def _get_ansible_type(value_type):
def _get_ansible_type(
value_type: ValueType,
) -> AnsibleType:
if value_type == "set":
return "list"
if value_type not in ("list", "dict", "bool", "int", "float", "str"):
@ -65,21 +85,22 @@ def _get_ansible_type(value_type):
class Option:
def __init__(
self,
name,
value_type,
owner,
ansible_type=None,
elements=None,
ansible_elements=None,
ansible_suboptions=None,
ansible_aliases=None,
ansible_choices=None,
needs_no_suboptions=False,
default_comparison=None,
not_a_container_option=False,
not_an_ansible_option=False,
copy_comparison_from=None,
compare=None,
name: str,
*,
value_type: ValueType,
owner: OptionGroup,
ansible_type: AnsibleType | None = None,
elements: ValueType | None = None,
ansible_elements: AnsibleType | None = None,
ansible_suboptions: dict[str, t.Any] | None = None,
ansible_aliases: Sequence[str] | None = None,
ansible_choices: Sequence[str] | None = None,
needs_no_suboptions: bool = False,
default_comparison: ComparisonMode | None = None,
not_a_container_option: bool = False,
not_an_ansible_option: bool = False,
copy_comparison_from: str | None = None,
compare: Callable[[Option, t.Any, t.Any], bool] | None = None,
):
self.name = name
self.value_type = value_type
@ -95,8 +116,8 @@ class Option:
if (elements is None and ansible_elements is None) and needs_ansible_elements:
raise ValueError("Ansible elements required for Ansible lists")
self.elements = elements if needs_elements else None
self.ansible_elements = (
(ansible_elements or _get_ansible_type(elements))
self.ansible_elements: AnsibleType | None = (
(ansible_elements or _get_ansible_type(elements or "str"))
if needs_ansible_elements
else None
)
@ -119,10 +140,12 @@ class Option:
self.ansible_suboptions = ansible_suboptions if needs_suboptions else None
self.ansible_aliases = ansible_aliases or []
self.ansible_choices = ansible_choices
comparison_type = self.value_type
if comparison_type == "set" and self.elements == "dict":
comparison_type: ComparisonType
if self.value_type == "set" and self.elements == "dict":
comparison_type = "set(dict)"
elif comparison_type not in ("set", "list", "dict"):
elif self.value_type in ("set", "list", "dict"):
comparison_type = self.value_type # type: ignore
else:
comparison_type = "value"
self.comparison_type = comparison_type
if default_comparison is not None:
@ -152,36 +175,45 @@ class Option:
class OptionGroup:
def __init__(
self,
preprocess=None,
ansible_mutually_exclusive=None,
ansible_required_together=None,
ansible_required_one_of=None,
ansible_required_if=None,
ansible_required_by=None,
):
*,
preprocess: (
Callable[[AnsibleModule, dict[str, t.Any]], dict[str, t.Any]] | None
) = None,
ansible_mutually_exclusive: Sequence[Sequence[str]] | None = None,
ansible_required_together: Sequence[Sequence[str]] | None = None,
ansible_required_one_of: Sequence[Sequence[str]] | None = None,
ansible_required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
ansible_required_by: dict[str, Sequence[str]] | None = None,
) -> None:
if preprocess is None:
def preprocess(module, values):
return values
self.preprocess = preprocess
self.options = []
self.all_options = []
self.engines = {}
self.options: list[Option] = []
self.all_options: list[Option] = []
self.engines: dict[str, Engine] = {}
self.ansible_mutually_exclusive = ansible_mutually_exclusive or []
self.ansible_required_together = ansible_required_together or []
self.ansible_required_one_of = ansible_required_one_of or []
self.ansible_required_if = ansible_required_if or []
self.ansible_required_by = ansible_required_by or {}
self.argument_spec = {}
self.argument_spec: dict[str, t.Any] = {}
def add_option(self, *args, **kwargs):
def add_option(self, *args, **kwargs) -> OptionGroup:
option = Option(*args, owner=self, **kwargs)
if not option.not_a_container_option:
self.options.append(option)
self.all_options.append(option)
if not option.not_an_ansible_option:
ansible_option = {
ansible_option: dict[str, t.Any] = {
"type": option.ansible_type,
}
if option.ansible_elements is not None:
@ -195,213 +227,297 @@ class OptionGroup:
self.argument_spec[option.name] = ansible_option
return self
def supports_engine(self, engine_name):
def supports_engine(self, engine_name: str) -> bool:
return engine_name in self.engines
def get_engine(self, engine_name):
def get_engine(self, engine_name: str) -> Engine:
return self.engines[engine_name]
def add_engine(self, engine_name, engine):
def add_engine(self, engine_name: str, engine: Engine) -> OptionGroup:
self.engines[engine_name] = engine
return self
class Engine:
min_api_version = None # string or None
min_api_version_obj = None # LooseVersion object or None
extra_option_minimal_versions = None # dict[str, dict[str, Any]] or None
class Engine(t.Generic[Client]):
min_api_version: str | None = None
min_api_version_obj: LooseVersion | None = None
extra_option_minimal_versions: dict[str, dict[str, t.Any]] | None = None
@abc.abstractmethod
def get_value(self, module, container, api_version, options, image, host_info):
def get_value(
self,
module: AnsibleModule,
container: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> dict[str, t.Any]:
pass
def compare_value(self, option, param_value, container_value):
def compare_value(
self, option: Option, param_value: t.Any, container_value: t.Any
) -> bool:
return option.compare(param_value, container_value)
@abc.abstractmethod
def set_value(self, module, data, api_version, options, values):
def set_value(
self,
module: AnsibleModule,
data: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
values: dict[str, t.Any],
) -> None:
pass
@abc.abstractmethod
def get_expected_values(
self, module, client, api_version, options, image, values, host_info
):
self,
module: AnsibleModule,
client: Client,
api_version: LooseVersion,
options: list[Option],
image: dict[str, t.Any] | None,
values: dict[str, t.Any],
host_info: dict[str, t.Any] | None,
) -> dict[str, t.Any]:
pass
@abc.abstractmethod
def ignore_mismatching_result(
self,
module,
client,
api_version,
option,
image,
container_value,
expected_value,
host_info,
):
module: AnsibleModule,
client: Client,
api_version: LooseVersion,
option: Option,
image: dict[str, t.Any] | None,
container_value: t.Any,
expected_value: t.Any,
host_info: dict[str, t.Any] | None,
) -> bool:
pass
@abc.abstractmethod
def preprocess_value(self, module, client, api_version, options, values):
def preprocess_value(
self,
module: AnsibleModule,
client: Client,
api_version: LooseVersion,
options: list[Option],
values: dict[str, t.Any],
) -> dict[str, t.Any]:
pass
@abc.abstractmethod
def update_value(self, module, data, api_version, options, values):
def update_value(
self,
module: AnsibleModule,
data: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
values: dict[str, t.Any],
) -> None:
pass
@abc.abstractmethod
def can_set_value(self, api_version):
def can_set_value(self, api_version: LooseVersion) -> bool:
pass
@abc.abstractmethod
def can_update_value(self, api_version):
def can_update_value(self, api_version: LooseVersion) -> bool:
pass
@abc.abstractmethod
def needs_container_image(self, values):
def needs_container_image(self, values: dict[str, t.Any]) -> bool:
pass
@abc.abstractmethod
def needs_host_info(self, values):
def needs_host_info(self, values: dict[str, t.Any]) -> bool:
pass
class EngineDriver:
name = None # string
class EngineDriver(t.Generic[Client]):
name: str
@abc.abstractmethod
def setup(
self,
argument_spec,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
required_if=None,
required_by=None,
):
# Return (module, active_options, client)
argument_spec: dict[str, t.Any],
mutually_exclusive: Sequence[Sequence[str]] | None = None,
required_together: Sequence[Sequence[str]] | None = None,
required_one_of: Sequence[Sequence[str]] | None = None,
required_if: (
Sequence[
tuple[str, t.Any, Sequence[str]]
| tuple[str, t.Any, Sequence[str], bool]
]
| None
) = None,
required_by: dict[str, Sequence[str]] | None = None,
) -> tuple[AnsibleModule, list[OptionGroup], Client]:
pass
@abc.abstractmethod
def get_host_info(self, client):
def get_host_info(self, client: Client) -> dict[str, t.Any]:
pass
@abc.abstractmethod
def get_api_version(self, client):
def get_api_version(self, client: Client) -> LooseVersion:
pass
@abc.abstractmethod
def get_container_id(self, container):
def get_container_id(self, container: dict[str, t.Any]) -> str:
pass
@abc.abstractmethod
def get_image_from_container(self, container):
def get_image_from_container(self, container: dict[str, t.Any]) -> str:
pass
@abc.abstractmethod
def get_image_name_from_container(self, container):
def get_image_name_from_container(self, container: dict[str, t.Any]) -> str | None:
pass
@abc.abstractmethod
def is_container_removing(self, container):
def is_container_removing(self, container: dict[str, t.Any]) -> bool:
pass
@abc.abstractmethod
def is_container_running(self, container):
def is_container_running(self, container: dict[str, t.Any]) -> bool:
pass
@abc.abstractmethod
def is_container_paused(self, container):
def is_container_paused(self, container: dict[str, t.Any]) -> bool:
pass
@abc.abstractmethod
def inspect_container_by_name(self, client, container_name):
def inspect_container_by_name(
self, client: Client, container_name: str
) -> dict[str, t.Any] | None:
pass
@abc.abstractmethod
def inspect_container_by_id(self, client, container_id):
def inspect_container_by_id(
self, client: Client, container_id: str
) -> dict[str, t.Any] | None:
pass
@abc.abstractmethod
def inspect_image_by_id(self, client, image_id):
def inspect_image_by_id(
self, client: Client, image_id: str
) -> dict[str, t.Any] | None:
pass
@abc.abstractmethod
def inspect_image_by_name(self, client, repository, tag):
def inspect_image_by_name(
self, client: Client, repository: str, tag: str
) -> dict[str, t.Any] | None:
pass
@abc.abstractmethod
def pull_image(self, client, repository, tag, image_platform=None):
def pull_image(
self,
client: Client,
repository: str,
tag: str,
image_platform: str | None = None,
) -> tuple[dict[str, t.Any] | None, bool]:
pass
@abc.abstractmethod
def pause_container(self, client, container_id):
def pause_container(self, client: Client, container_id: str) -> None:
pass
@abc.abstractmethod
def unpause_container(self, client, container_id):
def unpause_container(self, client: Client, container_id: str) -> None:
pass
@abc.abstractmethod
def disconnect_container_from_network(self, client, container_id, network_id):
def disconnect_container_from_network(
self, client: Client, container_id: str, network_id: str
) -> None:
pass
@abc.abstractmethod
def connect_container_to_network(
self, client, container_id, network_id, parameters=None
):
self,
client: Client,
container_id: str,
network_id: str,
parameters: dict[str, t.Any] | None = None,
) -> None:
pass
def create_container_supports_more_than_one_network(self, client):
def create_container_supports_more_than_one_network(self, client: Client) -> bool:
return False
@abc.abstractmethod
def create_container(
self, client, container_name, create_parameters, networks=None
):
self,
client: Client,
container_name: str,
create_parameters: dict[str, t.Any],
networks: dict[str, dict[str, t.Any]] | None = None,
) -> str:
pass
@abc.abstractmethod
def start_container(self, client, container_id):
def start_container(self, client: Client, container_id: str) -> None:
pass
@abc.abstractmethod
def wait_for_container(self, client, container_id, timeout=None):
def wait_for_container(
self, client: Client, container_id: str, timeout: int | float | None = None
) -> int | None:
pass
@abc.abstractmethod
def get_container_output(self, client, container_id):
def get_container_output(
self, client: Client, container_id: str
) -> tuple[bytes, t.Literal[True]] | tuple[str, t.Literal[False]]:
pass
@abc.abstractmethod
def update_container(self, client, container_id, update_parameters):
def update_container(
self, client: Client, container_id: str, update_parameters: dict[str, t.Any]
) -> None:
pass
@abc.abstractmethod
def restart_container(self, client, container_id, timeout=None):
def restart_container(
self, client: Client, container_id: str, timeout: int | float | None = None
) -> None:
pass
@abc.abstractmethod
def kill_container(self, client, container_id, kill_signal=None):
def kill_container(
self, client: Client, container_id: str, kill_signal: str | None = None
) -> None:
pass
@abc.abstractmethod
def stop_container(self, client, container_id, timeout=None):
def stop_container(
self, client: Client, container_id: str, timeout: int | float | None = None
) -> None:
pass
@abc.abstractmethod
def remove_container(
self, client, container_id, remove_volumes=False, link=False, force=False
):
self,
client: Client,
container_id: str,
remove_volumes: bool = False,
link: bool = False,
force: bool = False,
) -> None:
pass
@abc.abstractmethod
def run(self, runner, client):
def run(self, runner: Callable[[], None], client: Client) -> None:
pass
def _is_volume_permissions(mode):
def _is_volume_permissions(mode: str) -> bool:
for part in mode.split(","):
if part not in (
"rw",
@ -423,7 +539,7 @@ def _is_volume_permissions(mode):
return True
def _parse_port_range(range_or_port, module):
def _parse_port_range(range_or_port: str, module: AnsibleModule) -> list[int]:
"""
Parses a string containing either a single port or a range of ports.
@ -443,7 +559,7 @@ def _parse_port_range(range_or_port, module):
module.fail_json(msg=f'Invalid port: "{range_or_port}"')
def _split_colon_ipv6(text, module):
def _split_colon_ipv6(text: str, module: AnsibleModule) -> list[str]:
"""
Split string by ':', while keeping IPv6 addresses in square brackets in one component.
"""
@ -475,7 +591,9 @@ def _split_colon_ipv6(text, module):
return result
def _preprocess_command(module, values):
def _preprocess_command(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "command" not in values:
return values
value = values["command"]
@ -502,7 +620,9 @@ def _preprocess_command(module, values):
}
def _preprocess_entrypoint(module, values):
def _preprocess_entrypoint(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "entrypoint" not in values:
return values
value = values["entrypoint"]
@ -522,7 +642,9 @@ def _preprocess_entrypoint(module, values):
}
def _preprocess_env(module, values):
def _preprocess_env(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if not values:
return {}
final_env = {}
@ -546,7 +668,9 @@ def _preprocess_env(module, values):
}
def _preprocess_healthcheck(module, values):
def _preprocess_healthcheck(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if not values:
return {}
return {
@ -556,7 +680,12 @@ def _preprocess_healthcheck(module, values):
}
def _preprocess_convert_to_bytes(module, values, name, unlimited_value=None):
def _preprocess_convert_to_bytes(
module: AnsibleModule,
values: dict[str, t.Any],
name: str,
unlimited_value: int | None = None,
) -> dict[str, t.Any]:
if name not in values:
return values
try:
@ -571,7 +700,9 @@ def _preprocess_convert_to_bytes(module, values, name, unlimited_value=None):
module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")
def _preprocess_mac_address(module, values):
def _preprocess_mac_address(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "mac_address" not in values:
return values
return {
@ -579,7 +710,9 @@ def _preprocess_mac_address(module, values):
}
def _preprocess_networks(module, values):
def _preprocess_networks(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if (
module.params["networks_cli_compatible"] is True
and values.get("networks")
@ -605,14 +738,18 @@ def _preprocess_networks(module, values):
return values
def _preprocess_sysctls(module, values):
def _preprocess_sysctls(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "sysctls" in values:
for key, value in values["sysctls"].items():
values["sysctls"][key] = to_text(value, errors="surrogate_or_strict")
return values
def _preprocess_tmpfs(module, values):
def _preprocess_tmpfs(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "tmpfs" not in values:
return values
result = {}
@ -625,7 +762,9 @@ def _preprocess_tmpfs(module, values):
return {"tmpfs": result}
def _preprocess_ulimits(module, values):
def _preprocess_ulimits(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "ulimits" not in values:
return values
result = []
@ -644,8 +783,10 @@ def _preprocess_ulimits(module, values):
}
def _preprocess_mounts(module, values):
last = {}
def _preprocess_mounts(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
last: dict[str, str] = {}
def check_collision(t, name):
if t in last:
@ -776,7 +917,9 @@ def _preprocess_mounts(module, values):
return values
def _preprocess_labels(module, values):
def _preprocess_labels(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
result = {}
if "labels" in values:
labels = values["labels"]
@ -787,13 +930,15 @@ def _preprocess_labels(module, values):
return result
def _preprocess_log(module, values):
result = {}
def _preprocess_log(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
if "log_driver" not in values:
return result
result["log_driver"] = values["log_driver"]
if "log_options" in values:
options = {}
options: dict[str, str] = {}
for k, v in values["log_options"].items():
if not isinstance(v, str):
value = to_text(v, errors="surrogate_or_strict")
@ -807,7 +952,9 @@ def _preprocess_log(module, values):
return result
def _preprocess_ports(module, values):
def _preprocess_ports(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
if "published_ports" in values:
if "all" in values["published_ports"]:
module.fail_json(
@ -815,7 +962,12 @@ def _preprocess_ports(module, values):
"to randomly assign port mappings for those not specified by published_ports."
)
binds = {}
binds: dict[
str | int,
tuple[str]
| tuple[str, str | int]
| list[tuple[str] | tuple[str, str | int]],
] = {}
for port in values["published_ports"]:
parts = _split_colon_ipv6(
to_text(port, errors="surrogate_or_strict"), module
@ -827,6 +979,7 @@ def _preprocess_ports(module, values):
container_ports = _parse_port_range(container_port, module)
p_len = len(parts)
port_binds: Sequence[tuple[str] | tuple[str, str | int]]
if p_len == 1:
port_binds = len(container_ports) * [(_DEFAULT_IP_REPLACEMENT_STRING,)]
elif p_len == 2:
@ -865,8 +1018,12 @@ def _preprocess_ports(module, values):
"Maybe you forgot to use square brackets ([...]) around an IPv6 address?"
)
for bind, container_port in zip(port_binds, container_ports):
idx = f"{container_port}/{protocol}" if protocol else container_port
for bind, container_port_val in zip(port_binds, container_ports):
idx = (
f"{container_port_val}/{protocol}"
if protocol
else container_port_val
)
if idx in binds:
old_bind = binds[idx]
if isinstance(old_bind, list):
@ -882,9 +1039,9 @@ def _preprocess_ports(module, values):
for port in values["exposed_ports"]:
port = to_text(port, errors="surrogate_or_strict").strip()
protocol = "tcp"
match = re.search(r"(/.+$)", port)
if match:
protocol = match.group(1).replace("/", "")
matcher = re.search(r"(/.+$)", port)
if matcher:
protocol = matcher.group(1).replace("/", "")
port = re.sub(r"/.+$", "", port)
exposed.append((port, protocol))
if "published_ports" in values:
@ -912,7 +1069,7 @@ def _preprocess_ports(module, values):
return values
def _compare_platform(option, param_value, container_value):
def _compare_platform(option: Option, param_value: t.Any, container_value: t.Any):
if option.comparison == "ignore":
return True
try:

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,7 @@
from __future__ import annotations
import re
import typing as t
from time import sleep
from ansible.module_utils.common.text.converters import to_text
@ -16,6 +17,9 @@ from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.docker.plugins.module_utils._api.utils.utils import (
parse_repository_tag,
)
from ansible_collections.community.docker.plugins.module_utils._module_container.base import (
Client,
)
from ansible_collections.community.docker.plugins.module_utils._util import (
DifferenceTracker,
DockerBaseClass,
@ -25,13 +29,23 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
if t.TYPE_CHECKING:
from collections.abc import Sequence
from ansible.module_utils.basic import AnsibleModule
from .base import Engine, EngineDriver, Option, OptionGroup
class Container(DockerBaseClass):
def __init__(self, container, engine_driver):
def __init__(
self, container: dict[str, t.Any] | None, engine_driver: EngineDriver
) -> None:
super().__init__()
self.raw = container
self.id = None
self.image = None
self.image_name = None
self.id: str | None = None
self.image: str | None = None
self.image_name: str | None = None
self.container = container
self.engine_driver = engine_driver
if container:
@ -41,11 +55,11 @@ class Container(DockerBaseClass):
self.log(self.container, pretty_print=True)
@property
def exists(self):
def exists(self) -> bool:
return bool(self.container)
@property
def removing(self):
def removing(self) -> bool:
return (
self.engine_driver.is_container_removing(self.container)
if self.container
@ -53,7 +67,7 @@ class Container(DockerBaseClass):
)
@property
def running(self):
def running(self) -> bool:
return (
self.engine_driver.is_container_running(self.container)
if self.container
@ -61,7 +75,7 @@ class Container(DockerBaseClass):
)
@property
def paused(self):
def paused(self) -> bool:
return (
self.engine_driver.is_container_paused(self.container)
if self.container
@ -69,8 +83,14 @@ class Container(DockerBaseClass):
)
class ContainerManager(DockerBaseClass):
def __init__(self, module, engine_driver, client, active_options):
class ContainerManager(DockerBaseClass, t.Generic[Client]):
def __init__(
self,
module: AnsibleModule,
engine_driver: EngineDriver,
client: Client,
active_options: list[OptionGroup],
) -> None:
super().__init__()
self.module = module
self.engine_driver = engine_driver
@ -78,46 +98,64 @@ class ContainerManager(DockerBaseClass):
self.options = active_options
self.all_options = self._collect_all_options(active_options)
self.check_mode = self.module.check_mode
self.param_cleanup = self.module.params["cleanup"]
self.param_container_default_behavior = self.module.params[
"container_default_behavior"
]
self.param_default_host_ip = self.module.params["default_host_ip"]
self.param_debug = self.module.params["debug"]
self.param_force_kill = self.module.params["force_kill"]
self.param_image = self.module.params["image"]
self.param_image_comparison = self.module.params["image_comparison"]
self.param_image_label_mismatch = self.module.params["image_label_mismatch"]
self.param_image_name_mismatch = self.module.params["image_name_mismatch"]
self.param_keep_volumes = self.module.params["keep_volumes"]
self.param_kill_signal = self.module.params["kill_signal"]
self.param_name = self.module.params["name"]
self.param_networks_cli_compatible = self.module.params[
self.param_cleanup: bool = self.module.params["cleanup"]
self.param_container_default_behavior: t.Literal[
"compatibility", "no_defaults"
] = self.module.params["container_default_behavior"]
self.param_default_host_ip: str | None = self.module.params["default_host_ip"]
self.param_debug: bool = self.module.params["debug"]
self.param_force_kill: bool = self.module.params["force_kill"]
self.param_image: str | None = self.module.params["image"]
self.param_image_comparison: t.Literal["desired-image", "current-image"] = (
self.module.params["image_comparison"]
)
self.param_image_label_mismatch: t.Literal["ignore", "fail"] = (
self.module.params["image_label_mismatch"]
)
self.param_image_name_mismatch: t.Literal["ignore", "recreate"] = (
self.module.params["image_name_mismatch"]
)
self.param_keep_volumes: bool = self.module.params["keep_volumes"]
self.param_kill_signal: str | None = self.module.params["kill_signal"]
self.param_name: str = self.module.params["name"]
self.param_networks_cli_compatible: bool = self.module.params[
"networks_cli_compatible"
]
self.param_output_logs = self.module.params["output_logs"]
self.param_paused = self.module.params["paused"]
self.param_pull = self.module.params["pull"]
if self.param_pull is True:
self.param_pull = "always"
if self.param_pull is False:
self.param_pull = "missing"
self.param_pull_check_mode_behavior = self.module.params[
"pull_check_mode_behavior"
self.param_output_logs: bool = self.module.params["output_logs"]
self.param_paused: bool | None = self.module.params["paused"]
param_pull: t.Literal["never", "missing", "always", True, False] = (
self.module.params["pull"]
)
if param_pull is True:
param_pull = "always"
if param_pull is False:
param_pull = "missing"
self.param_pull: t.Literal["never", "missing", "always"] = param_pull
self.param_pull_check_mode_behavior: t.Literal[
"image_not_present", "always"
] = self.module.params["pull_check_mode_behavior"]
self.param_recreate: bool = self.module.params["recreate"]
self.param_removal_wait_timeout: int | float | None = self.module.params[
"removal_wait_timeout"
]
self.param_recreate = self.module.params["recreate"]
self.param_removal_wait_timeout = self.module.params["removal_wait_timeout"]
self.param_healthy_wait_timeout = self.module.params["healthy_wait_timeout"]
if self.param_healthy_wait_timeout <= 0:
self.param_healthy_wait_timeout: int | float | None = self.module.params[
"healthy_wait_timeout"
]
if (
self.param_healthy_wait_timeout is not None
and self.param_healthy_wait_timeout <= 0
):
self.param_healthy_wait_timeout = None
self.param_restart = self.module.params["restart"]
self.param_state = self.module.params["state"]
self.param_restart: bool = self.module.params["restart"]
self.param_state: t.Literal[
"absent", "present", "healthy", "started", "stopped"
] = self.module.params["state"]
self._parse_comparisons()
self._update_params()
self.results = {"changed": False, "actions": []}
self.diff = {}
self.diff: dict[str, t.Any] = {}
self.diff_tracker = DifferenceTracker()
self.facts = {}
self.facts: dict[str, t.Any] | None = {}
if self.param_default_host_ip:
valid_ip = False
if re.match(
@ -134,16 +172,22 @@ class ContainerManager(DockerBaseClass):
"The value of default_host_ip must be an empty string, an IPv4 address, "
f'or an IPv6 address. Got "{self.param_default_host_ip}" instead.'
)
self.parameters = None
self.parameters: list[tuple[OptionGroup, dict[str, t.Any]]] | None = None
def _collect_all_options(self, active_options):
def _add_action(self, action: dict[str, t.Any]) -> None:
actions: list[dict[str, t.Any]] = self.results["actions"] # type: ignore
actions.append(action)
def _collect_all_options(
self, active_options: list[OptionGroup]
) -> dict[str, Option]:
all_options = {}
for options in active_options:
for option in options.options:
all_options[option.name] = option
return all_options
def _collect_all_module_params(self):
def _collect_all_module_params(self) -> set[str]:
all_module_options = set()
for option, data in self.module.argument_spec.items():
all_module_options.add(option)
@ -152,7 +196,7 @@ class ContainerManager(DockerBaseClass):
all_module_options.add(alias)
return all_module_options
def _parse_comparisons(self):
def _parse_comparisons(self) -> None:
# Keep track of all module params and all option aliases
all_module_options = self._collect_all_module_params()
comp_aliases = {}
@ -163,10 +207,11 @@ class ContainerManager(DockerBaseClass):
for alias in option.ansible_aliases:
comp_aliases[alias] = option_name
# Process comparisons specified by user
if self.module.params.get("comparisons"):
comparisons: dict[str, t.Any] | None = self.module.params.get("comparisons")
if comparisons:
# If '*' appears in comparisons, process it first
if "*" in self.module.params["comparisons"]:
value = self.module.params["comparisons"]["*"]
if "*" in comparisons:
value = comparisons["*"]
if value not in ("strict", "ignore"):
self.fail(
"The wildcard can only be used with comparison modes 'strict' and 'ignore'!"
@ -179,8 +224,8 @@ class ContainerManager(DockerBaseClass):
continue
option.comparison = value
# Now process all other comparisons.
comp_aliases_used = {}
for key, value in self.module.params["comparisons"].items():
comp_aliases_used: dict[str, str] = {}
for key, value in comparisons.items():
if key == "*":
continue
# Find main key
@ -220,7 +265,7 @@ class ContainerManager(DockerBaseClass):
option.copy_comparison_from
].comparison
def _update_params(self):
def _update_params(self) -> None:
if (
self.param_networks_cli_compatible is True
and self.module.params["networks"]
@ -247,12 +292,14 @@ class ContainerManager(DockerBaseClass):
if self.module.params[param] is None:
self.module.params[param] = value
def fail(self, *args, **kwargs):
self.client.fail(*args, **kwargs)
def fail(self, *args, **kwargs) -> t.NoReturn:
# mypy doesn't know that Client has fail() method
raise self.client.fail(*args, **kwargs) # type: ignore
def run(self):
def run(self) -> None:
if self.param_state in ("stopped", "started", "present", "healthy"):
self.present(self.param_state)
# mypy doesn't get that self.param_state has only one of the above values
self.present(self.param_state) # type: ignore
elif self.param_state == "absent":
self.absent()
@ -270,15 +317,16 @@ class ContainerManager(DockerBaseClass):
def wait_for_state(
self,
container_id,
complete_states=None,
wait_states=None,
accept_removal=False,
max_wait=None,
health_state=False,
):
container_id: str,
*,
complete_states: Sequence[str | None] | None = None,
wait_states: Sequence[str | None] | None = None,
accept_removal: bool = False,
max_wait: int | float | None = None,
health_state: bool = False,
) -> dict[str, t.Any] | None:
delay = 1.0
total_wait = 0
total_wait = 0.0
while True:
# Inspect container
result = self.engine_driver.inspect_container_by_id(
@ -314,7 +362,9 @@ class ContainerManager(DockerBaseClass):
# code will have slept for ~1.5 minutes.)
delay = min(delay * 1.1, 10)
def _collect_params(self, active_options):
def _collect_params(
self, active_options: list[OptionGroup]
) -> list[tuple[OptionGroup, dict[str, t.Any]]]:
parameters = []
for options in active_options:
values = {}
@ -336,21 +386,25 @@ class ContainerManager(DockerBaseClass):
parameters.append((options, values))
return parameters
def _needs_container_image(self):
def _needs_container_image(self) -> bool:
assert self.parameters is not None
for options, values in self.parameters:
engine = options.get_engine(self.engine_driver.name)
if engine.needs_container_image(values):
return True
return False
def _needs_host_info(self):
def _needs_host_info(self) -> bool:
assert self.parameters is not None
for options, values in self.parameters:
engine = options.get_engine(self.engine_driver.name)
if engine.needs_host_info(values):
return True
return False
def present(self, state):
def present(
self, state: t.Literal["stopped", "started", "present", "healthy"]
) -> None:
self.parameters = self._collect_params(self.options)
container = self._get_container(self.param_name)
was_running = container.running
@ -382,6 +436,7 @@ class ContainerManager(DockerBaseClass):
self.diff_tracker.add("exists", parameter=True, active=False)
if container.removing and not self.check_mode:
# Wait for container to be removed before trying to create it
assert container.id is not None
self.wait_for_state(
container.id,
wait_states=["removing"],
@ -394,6 +449,7 @@ class ContainerManager(DockerBaseClass):
container_created = True
else:
# Existing container
assert container.id is not None
different, differences = self.has_different_configuration(
container, container_image, comparison_image, host_info
)
@ -453,13 +509,16 @@ class ContainerManager(DockerBaseClass):
if state in ("started", "healthy") and not container.running:
self.diff_tracker.add("running", parameter=True, active=was_running)
assert container.id is not None
container = self.container_start(container.id)
elif state in ("started", "healthy") and self.param_restart:
self.diff_tracker.add("running", parameter=True, active=was_running)
self.diff_tracker.add("restarted", parameter=True, active=False)
assert container.id is not None
container = self.container_restart(container.id)
elif state == "stopped" and container.running:
self.diff_tracker.add("running", parameter=False, active=was_running)
assert container.id is not None
self.container_stop(container.id)
container = self._get_container(container.id)
@ -472,6 +531,7 @@ class ContainerManager(DockerBaseClass):
"paused", parameter=self.param_paused, active=was_paused
)
if not self.check_mode:
assert container.id is not None
try:
if self.param_paused:
self.engine_driver.pause_container(
@ -487,12 +547,13 @@ class ContainerManager(DockerBaseClass):
)
container = self._get_container(container.id)
self.results["changed"] = True
self.results["actions"].append({"set_paused": self.param_paused})
self._add_action({"set_paused": self.param_paused})
self.facts = container.raw
if state == "healthy" and not self.check_mode:
# `None` means that no health check enabled; simply treat this as 'healthy'
assert container.id is not None
inspect_result = self.wait_for_state(
container.id,
wait_states=["starting", "unhealthy"],
@ -504,41 +565,51 @@ class ContainerManager(DockerBaseClass):
# Return the latest inspection results retrieved
self.facts = inspect_result
def absent(self):
def absent(self) -> None:
container = self._get_container(self.param_name)
if container.exists:
assert container.id is not None
if container.running:
self.diff_tracker.add("running", parameter=False, active=True)
self.container_stop(container.id)
self.diff_tracker.add("exists", parameter=False, active=True)
self.container_remove(container.id)
def _output_logs(self, msg):
def _output_logs(self, msg: str | bytes) -> None:
self.module.log(msg=msg)
def _get_container(self, container):
def _get_container(self, container: str) -> Container:
"""
Expects container ID or Name. Returns a container object
"""
container = self.engine_driver.inspect_container_by_name(self.client, container)
return Container(container, self.engine_driver)
container_data = self.engine_driver.inspect_container_by_name(
self.client, container
)
return Container(container_data, self.engine_driver)
def _get_container_image(self, container, fallback=None):
def _get_container_image(
self, container: Container, fallback: dict[str, t.Any] | None = None
) -> dict[str, t.Any] | None:
if not container.exists or container.removing:
return fallback
image = container.image
assert image is not None
if is_image_name_id(image):
image = self.engine_driver.inspect_image_by_id(self.client, image)
image_data = self.engine_driver.inspect_image_by_id(self.client, image)
else:
repository, tag = parse_repository_tag(image)
if not tag:
tag = "latest"
image = self.engine_driver.inspect_image_by_name(
image_data = self.engine_driver.inspect_image_by_name(
self.client, repository, tag
)
return image or fallback
return image_data or fallback
def _get_image(self, container, needs_container_image=False):
def _get_image(
self, container: Container, needs_container_image: bool = False
) -> tuple[
dict[str, t.Any] | None, dict[str, t.Any] | None, dict[str, t.Any] | None
]:
image_parameter = self.param_image
get_container_image = needs_container_image or not image_parameter
container_image = (
@ -553,7 +624,7 @@ class ContainerManager(DockerBaseClass):
if is_image_name_id(image_parameter):
image = self.engine_driver.inspect_image_by_id(self.client, image_parameter)
if image is None:
self.client.fail(f"Cannot find image with ID {image_parameter}")
self.fail(f"Cannot find image with ID {image_parameter}")
else:
repository, tag = parse_repository_tag(image_parameter)
if not tag:
@ -562,7 +633,7 @@ class ContainerManager(DockerBaseClass):
self.client, repository, tag
)
if not image and self.param_pull == "never":
self.client.fail(
self.fail(
f"Cannot find image with name {repository}:{tag}, and pull=never"
)
if not image or self.param_pull == "always":
@ -576,12 +647,12 @@ class ContainerManager(DockerBaseClass):
)
if already_to_latest:
self.results["changed"] = False
self.results["actions"].append(
self._add_action(
{"pulled_image": f"{repository}:{tag}", "changed": False}
)
else:
self.results["changed"] = True
self.results["actions"].append(
self._add_action(
{"pulled_image": f"{repository}:{tag}", "changed": True}
)
elif not image or self.param_pull_check_mode_behavior == "always":
@ -589,10 +660,10 @@ class ContainerManager(DockerBaseClass):
# pull. (Implicitly: if the image is there, claim it already was latest unless
# pull_check_mode_behavior == 'always'.)
self.results["changed"] = True
action = {"pulled_image": f"{repository}:{tag}"}
action: dict[str, t.Any] = {"pulled_image": f"{repository}:{tag}"}
if not image:
action["changed"] = True
self.results["actions"].append(action)
self._add_action(action)
self.log("image")
self.log(image, pretty_print=True)
@ -605,7 +676,9 @@ class ContainerManager(DockerBaseClass):
return image, container_image, comparison_image
def _image_is_different(self, image, container):
def _image_is_different(
self, image: dict[str, t.Any] | None, container: Container
) -> bool:
if image and image.get("Id"):
if container and container.image:
if image.get("Id") != container.image:
@ -615,8 +688,9 @@ class ContainerManager(DockerBaseClass):
return True
return False
def _compose_create_parameters(self, image):
params = {}
def _compose_create_parameters(self, image: str) -> dict[str, t.Any]:
params: dict[str, t.Any] = {}
assert self.parameters is not None
for options, values in self.parameters:
engine = options.get_engine(self.engine_driver.name)
if engine.can_set_value(self.engine_driver.get_api_version(self.client)):
@ -632,15 +706,16 @@ class ContainerManager(DockerBaseClass):
def _record_differences(
self,
differences,
options,
param_values,
engine,
container,
container_image,
image,
host_info,
differences: DifferenceTracker,
options: OptionGroup,
param_values: dict[str, t.Any],
engine: Engine,
container: Container,
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
):
assert container.raw is not None
container_values = engine.get_value(
self.module,
container.raw,
@ -709,9 +784,16 @@ class ContainerManager(DockerBaseClass):
c = sorted(c, key=sort_key_fn)
differences.add(option.name, parameter=p, active=c)
def has_different_configuration(self, container, container_image, image, host_info):
def has_different_configuration(
self,
container: Container,
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> tuple[bool, DifferenceTracker]:
differences = DifferenceTracker()
update_differences = DifferenceTracker()
assert self.parameters is not None
for options, param_values in self.parameters:
engine = options.get_engine(self.engine_driver.name)
if engine.can_update_value(self.engine_driver.get_api_version(self.client)):
@ -743,9 +825,14 @@ class ContainerManager(DockerBaseClass):
return has_differences, differences
def has_different_resource_limits(
self, container, container_image, image, host_info
):
self,
container: Container,
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> tuple[bool, DifferenceTracker]:
differences = DifferenceTracker()
assert self.parameters is not None
for options, param_values in self.parameters:
engine = options.get_engine(self.engine_driver.name)
if not engine.can_update_value(
@ -765,8 +852,9 @@ class ContainerManager(DockerBaseClass):
has_differences = not differences.empty
return has_differences, differences
def _compose_update_parameters(self):
result = {}
def _compose_update_parameters(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
assert self.parameters is not None
for options, values in self.parameters:
engine = options.get_engine(self.engine_driver.name)
if not engine.can_update_value(
@ -782,7 +870,13 @@ class ContainerManager(DockerBaseClass):
)
return result
def update_limits(self, container, container_image, image, host_info):
def update_limits(
self,
container: Container,
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> Container:
limits_differ, different_limits = self.has_different_resource_limits(
container, container_image, image, host_info
)
@ -793,20 +887,24 @@ class ContainerManager(DockerBaseClass):
)
self.diff_tracker.merge(different_limits)
if limits_differ and not self.check_mode:
assert container.id is not None
self.container_update(container.id, self._compose_update_parameters())
return self._get_container(container.id)
return container
def has_network_differences(self, container):
def has_network_differences(
self, container: Container
) -> tuple[bool, list[dict[str, t.Any]]]:
"""
Check if the container is connected to requested networks with expected options: links, aliases, ipv4, ipv6
"""
different = False
differences = []
differences: list[dict[str, t.Any]] = []
if not self.module.params["networks"]:
return different, differences
assert container.container is not None
if not container.container.get("NetworkSettings"):
self.fail(
"has_missing_networks: Error parsing container properties. NetworkSettings missing."
@ -869,13 +967,16 @@ class ContainerManager(DockerBaseClass):
)
return different, differences
def has_extra_networks(self, container):
def has_extra_networks(
self, container: Container
) -> tuple[bool, list[dict[str, t.Any]]]:
"""
Check if the container is connected to non-requested networks
"""
extra_networks = []
extra_networks: list[dict[str, t.Any]] = []
extra = False
assert container.container is not None
if not container.container.get("NetworkSettings"):
self.fail(
"has_extra_networks: Error parsing container properties. NetworkSettings missing."
@ -896,7 +997,9 @@ class ContainerManager(DockerBaseClass):
)
return extra, extra_networks
def update_networks(self, container, container_created):
def update_networks(
self, container: Container, container_created: bool
) -> Container:
updated_container = container
if self.all_options["networks"].comparison != "ignore" or container_created:
has_network_differences, network_differences = self.has_network_differences(
@ -939,13 +1042,14 @@ class ContainerManager(DockerBaseClass):
updated_container = self._purge_networks(container, extra_networks)
return updated_container
def _add_networks(self, container, differences):
def _add_networks(
self, container: Container, differences: list[dict[str, t.Any]]
) -> Container:
assert container.id is not None
for diff in differences:
# remove the container from the network, if connected
if diff.get("container"):
self.results["actions"].append(
{"removed_from_network": diff["parameter"]["name"]}
)
self._add_action({"removed_from_network": diff["parameter"]["name"]})
if not self.check_mode:
try:
self.engine_driver.disconnect_container_from_network(
@ -956,7 +1060,7 @@ class ContainerManager(DockerBaseClass):
f"Error disconnecting container from network {diff['parameter']['name']} - {exc}"
)
# connect to the network
self.results["actions"].append(
self._add_action(
{
"added_to_network": diff["parameter"]["name"],
"network_parameters": diff["parameter"],
@ -982,9 +1086,12 @@ class ContainerManager(DockerBaseClass):
)
return self._get_container(container.id)
def _purge_networks(self, container, networks):
def _purge_networks(
self, container: Container, networks: list[dict[str, t.Any]]
) -> Container:
assert container.id is not None
for network in networks:
self.results["actions"].append({"removed_from_network": network["name"]})
self._add_action({"removed_from_network": network["name"]})
if not self.check_mode:
try:
self.engine_driver.disconnect_container_from_network(
@ -996,7 +1103,7 @@ class ContainerManager(DockerBaseClass):
)
return self._get_container(container.id)
def container_create(self, image):
def container_create(self, image: str) -> Container | None:
create_parameters = self._compose_create_parameters(image)
self.log("create container")
self.log(f"image: {image} parameters:")
@ -1014,7 +1121,7 @@ class ContainerManager(DockerBaseClass):
for key, value in network.items()
if key not in ("name", "id")
}
self.results["actions"].append(
self._add_action(
{
"created": "Created container",
"create_parameters": create_parameters,
@ -1022,7 +1129,6 @@ class ContainerManager(DockerBaseClass):
}
)
self.results["changed"] = True
new_container = None
if not self.check_mode:
try:
container_id = self.engine_driver.create_container(
@ -1031,11 +1137,11 @@ class ContainerManager(DockerBaseClass):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error creating container: {exc}")
return self._get_container(container_id)
return new_container
return None
def container_start(self, container_id):
def container_start(self, container_id: str) -> Container:
self.log(f"start container {container_id}")
self.results["actions"].append({"started": container_id})
self._add_action({"started": container_id})
self.results["changed"] = True
if not self.check_mode:
try:
@ -1047,9 +1153,11 @@ class ContainerManager(DockerBaseClass):
status = self.engine_driver.wait_for_container(
self.client, container_id
)
self.client.fail_results["status"] = status
# mypy doesn't know that Client has fail_results property
self.client.fail_results["status"] = status # type: ignore
self.results["status"] = status
output: str | bytes
if self.module.params["auto_remove"]:
output = "Cannot retrieve result as auto_remove is enabled"
if self.param_output_logs:
@ -1077,12 +1185,14 @@ class ContainerManager(DockerBaseClass):
return insp
return self._get_container(container_id)
def container_remove(self, container_id, link=False, force=False):
def container_remove(
self, container_id: str, link: bool = False, force: bool = False
) -> None:
volume_state = not self.param_keep_volumes
self.log(
f"remove container container:{container_id} v:{volume_state} link:{link} force{force}"
)
self.results["actions"].append(
self._add_action(
{
"removed": container_id,
"volume_state": volume_state,
@ -1101,13 +1211,15 @@ class ContainerManager(DockerBaseClass):
force=force,
)
except Exception as exc: # pylint: disable=broad-exception-caught
self.client.fail(f"Error removing container {container_id}: {exc}")
self.fail(f"Error removing container {container_id}: {exc}")
def container_update(self, container_id, update_parameters):
def container_update(
self, container_id: str, update_parameters: dict[str, t.Any]
) -> Container:
if update_parameters:
self.log(f"update container {container_id}")
self.log(update_parameters, pretty_print=True)
self.results["actions"].append(
self._add_action(
{"updated": container_id, "update_parameters": update_parameters}
)
self.results["changed"] = True
@ -1120,10 +1232,8 @@ class ContainerManager(DockerBaseClass):
self.fail(f"Error updating container {container_id}: {exc}")
return self._get_container(container_id)
def container_kill(self, container_id):
self.results["actions"].append(
{"killed": container_id, "signal": self.param_kill_signal}
)
def container_kill(self, container_id: str) -> None:
self._add_action({"killed": container_id, "signal": self.param_kill_signal})
self.results["changed"] = True
if not self.check_mode:
try:
@ -1133,8 +1243,8 @@ class ContainerManager(DockerBaseClass):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error killing container {container_id}: {exc}")
def container_restart(self, container_id):
self.results["actions"].append(
def container_restart(self, container_id: str) -> Container:
self._add_action(
{"restarted": container_id, "timeout": self.module.params["stop_timeout"]}
)
self.results["changed"] = True
@ -1147,11 +1257,11 @@ class ContainerManager(DockerBaseClass):
self.fail(f"Error restarting container {container_id}: {exc}")
return self._get_container(container_id)
def container_stop(self, container_id):
def container_stop(self, container_id: str) -> None:
if self.param_force_kill:
self.container_kill(container_id)
return
self.results["actions"].append(
self._add_action(
{"stopped": container_id, "timeout": self.module.params["stop_timeout"]}
)
self.results["changed"] = True
@ -1164,7 +1274,7 @@ class ContainerManager(DockerBaseClass):
self.fail(f"Error stopping container {container_id}: {exc}")
def run_module(engine_driver):
def run_module(engine_driver: EngineDriver) -> None:
module, active_options, client = engine_driver.setup(
argument_spec={
"cleanup": {"type": "bool", "default": False},
@ -1228,7 +1338,7 @@ def run_module(engine_driver):
],
)
def execute():
def execute() -> t.NoReturn:
cm = ContainerManager(module, engine_driver, client, active_options)
cm.run()
module.exit_json(**sanitize_result(cm.results))

View File

@ -14,12 +14,13 @@
from __future__ import annotations
import re
import typing as t
_VALID_STR = re.compile("^[A-Za-z0-9_-]+$")
def _validate_part(string, part, part_name):
def _validate_part(string: str, part: str, part_name: str) -> str:
if not part:
raise ValueError(f'Invalid platform string "{string}": {part_name} is empty')
if not _VALID_STR.match(part):
@ -79,7 +80,7 @@ _KNOWN_ARCH = (
)
def _normalize_os(os_str):
def _normalize_os(os_str: str) -> str:
# See normalizeOS() in https://github.com/containerd/containerd/blob/main/platforms/database.go
os_str = os_str.lower()
if os_str == "macos":
@ -112,7 +113,7 @@ _NORMALIZE_ARCH = {
}
def _normalize_arch(arch_str, variant_str):
def _normalize_arch(arch_str: str, variant_str: str) -> tuple[str, str]:
# See normalizeArch() in https://github.com/containerd/containerd/blob/main/platforms/database.go
arch_str = arch_str.lower()
variant_str = variant_str.lower()
@ -121,15 +122,16 @@ def _normalize_arch(arch_str, variant_str):
res = _NORMALIZE_ARCH.get((arch_str, None))
if res is None:
return arch_str, variant_str
if res is not None:
arch_str = res[0]
if res[1] is not None:
variant_str = res[1]
return arch_str, variant_str
arch_str = res[0]
if res[1] is not None:
variant_str = res[1]
return arch_str, variant_str
class _Platform:
def __init__(self, os=None, arch=None, variant=None):
def __init__(
self, os: str | None = None, arch: str | None = None, variant: str | None = None
) -> None:
self.os = os
self.arch = arch
self.variant = variant
@ -140,7 +142,12 @@ class _Platform:
raise ValueError("If variant is given, os must be given too")
@classmethod
def parse_platform_string(cls, string, daemon_os=None, daemon_arch=None):
def parse_platform_string(
cls,
string: str | None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> t.Self:
# See Parse() in https://github.com/containerd/containerd/blob/main/platforms/platforms.go
if string is None:
return cls()
@ -182,6 +189,7 @@ class _Platform:
)
if variant is not None and not variant:
raise ValueError(f'Invalid platform string "{string}": variant is empty')
assert arch is not None # otherwise variant would be None as well
arch, variant = _normalize_arch(arch, variant or "")
if len(parts) == 2 and arch == "arm" and variant == "v7":
variant = None
@ -189,9 +197,12 @@ class _Platform:
variant = "v8"
return cls(os=_normalize_os(os), arch=arch, variant=variant or None)
def __str__(self):
def __str__(self) -> str:
if self.variant:
parts = [self.os, self.arch, self.variant]
assert (
self.os is not None and self.arch is not None
) # ensured in constructor
parts: list[str] = [self.os, self.arch, self.variant]
elif self.os:
if self.arch:
parts = [self.os, self.arch]
@ -203,12 +214,14 @@ class _Platform:
parts = []
return "/".join(parts)
def __repr__(self):
def __repr__(self) -> str:
return (
f"_Platform(os={self.os!r}, arch={self.arch!r}, variant={self.variant!r})"
)
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, _Platform):
return NotImplemented
return (
self.os == other.os
and self.arch == other.arch
@ -216,7 +229,9 @@ class _Platform:
)
def normalize_platform_string(string, daemon_os=None, daemon_arch=None):
def normalize_platform_string(
string: str, daemon_os: str | None = None, daemon_arch: str | None = None
) -> str:
return str(
_Platform.parse_platform_string(
string, daemon_os=daemon_os, daemon_arch=daemon_arch
@ -225,8 +240,12 @@ def normalize_platform_string(string, daemon_os=None, daemon_arch=None):
def compose_platform_string(
os=None, arch=None, variant=None, daemon_os=None, daemon_arch=None
):
os: str | None = None,
arch: str | None = None,
variant: str | None = None,
daemon_os: str | None = None,
daemon_arch: str | None = None,
) -> str:
if os is None and daemon_os is not None:
os = _normalize_os(daemon_os)
if arch is None and daemon_arch is not None:
@ -235,7 +254,7 @@ def compose_platform_string(
return str(_Platform(os=os, arch=arch, variant=variant or None))
def compare_platform_strings(string1, string2):
def compare_platform_strings(string1: str, string2: str) -> bool:
return _Platform.parse_platform_string(string1) == _Platform.parse_platform_string(
string2
)

View File

@ -13,7 +13,7 @@ import random
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
def generate_insecure_key():
def generate_insecure_key() -> bytes:
"""Do NOT use this for cryptographic purposes!"""
while True:
# Generate a one-byte key. Right now the functions below do not use more
@ -24,23 +24,23 @@ def generate_insecure_key():
return key
def scramble(value, key):
def scramble(value: str, key: bytes) -> str:
"""Do NOT use this for cryptographic purposes!"""
if len(key) < 1:
raise ValueError("Key must be at least one byte")
value = to_bytes(value)
b_value = to_bytes(value)
k = key[0]
value = bytes([k ^ b for b in value])
return "=S=" + to_native(base64.b64encode(value))
b_value = bytes([k ^ b for b in b_value])
return f"=S={to_native(base64.b64encode(b_value))}"
def unscramble(value, key):
def unscramble(value: str, key: bytes) -> str:
"""Do NOT use this for cryptographic purposes!"""
if len(key) < 1:
raise ValueError("Key must be at least one byte")
if not value.startswith("=S="):
raise ValueError("Value does not start with indicator")
value = base64.b64decode(value[3:])
b_value = base64.b64decode(value[3:])
k = key[0]
value = bytes([k ^ b for b in value])
return to_text(value)
b_value = bytes([k ^ b for b in b_value])
return to_text(b_value)

View File

@ -12,6 +12,7 @@ import os.path
import selectors
import socket as pysocket
import struct
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.utils import (
socket as docker_socket,
@ -23,58 +24,74 @@ from ansible_collections.community.docker.plugins.module_utils._socket_helper im
)
if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.docker.plugins.module_utils._socket_helper import (
SocketLike,
)
PARAMIKO_POLL_TIMEOUT = 0.01 # 10 milliseconds
def _empty_writer(msg: str) -> None:
pass
class DockerSocketHandlerBase:
def __init__(self, sock, log=None):
def __init__(
self, sock: SocketLike, log: Callable[[str], None] | None = None
) -> None:
make_unblocking(sock)
if log is not None:
self._log = log
else:
self._log = lambda msg: True
self._log = log or _empty_writer
self._paramiko_read_workaround = hasattr(
sock, "send_ready"
) and "paramiko" in str(type(sock))
self._sock = sock
self._block_done_callback = None
self._block_buffer = []
self._block_done_callback: Callable[[int, bytes], None] | None = None
self._block_buffer: list[tuple[int, bytes]] = []
self._eof = False
self._read_buffer = b""
self._write_buffer = b""
self._end_of_writing = False
self._current_stream = None
self._current_stream: int | None = None
self._current_missing = 0
self._current_buffer = b""
self._selector = selectors.DefaultSelector()
self._selector.register(self._sock, selectors.EVENT_READ)
def __enter__(self):
def __enter__(self) -> t.Self:
return self
def __exit__(self, type_, value, tb):
def __exit__(self, type_, value, tb) -> None:
self._selector.close()
def set_block_done_callback(self, block_done_callback):
def set_block_done_callback(
self, block_done_callback: Callable[[int, bytes], None]
) -> None:
self._block_done_callback = block_done_callback
if self._block_done_callback is not None:
while self._block_buffer:
elt = self._block_buffer.pop(0)
self._block_done_callback(*elt)
def _add_block(self, stream_id, data):
def _add_block(self, stream_id: int, data: bytes) -> None:
if self._block_done_callback is not None:
self._block_done_callback(stream_id, data)
else:
self._block_buffer.append((stream_id, data))
def _read(self):
def _read(self) -> None:
if self._eof:
return
data: bytes | None
if hasattr(self._sock, "recv"):
try:
data = self._sock.recv(262144)
@ -86,13 +103,13 @@ class DockerSocketHandlerBase:
self._eof = True
return
raise
elif isinstance(self._sock, getattr(pysocket, "SocketIO")):
data = self._sock.read()
elif isinstance(self._sock, pysocket.SocketIO): # type: ignore[unreachable]
data = self._sock.read() # type: ignore
else:
data = os.read(self._sock.fileno())
data = os.read(self._sock.fileno()) # type: ignore # TODO does this really work?!
if data is None:
# no data available
return
return # type: ignore[unreachable]
self._log(f"read {len(data)} bytes")
if len(data) == 0:
# Stream EOF
@ -106,6 +123,7 @@ class DockerSocketHandlerBase:
self._read_buffer = self._read_buffer[n:]
self._current_missing -= n
if self._current_missing == 0:
assert self._current_stream is not None
self._add_block(self._current_stream, self._current_buffer)
self._current_buffer = b""
if len(self._read_buffer) < 8:
@ -119,13 +137,13 @@ class DockerSocketHandlerBase:
self._eof = True
break
def _handle_end_of_writing(self):
def _handle_end_of_writing(self) -> None:
if self._end_of_writing and len(self._write_buffer) == 0:
self._end_of_writing = False
self._log("Shutting socket down for writing")
shutdown_writing(self._sock, self._log)
def _write(self):
def _write(self) -> None:
if len(self._write_buffer) > 0:
written = write_to_socket(self._sock, self._write_buffer)
self._write_buffer = self._write_buffer[written:]
@ -138,7 +156,9 @@ class DockerSocketHandlerBase:
self._selector.modify(self._sock, selectors.EVENT_READ)
self._handle_end_of_writing()
def select(self, timeout=None, _internal_recursion=False):
def select(
self, timeout: int | float | None = None, _internal_recursion: bool = False
) -> bool:
if (
not _internal_recursion
and self._paramiko_read_workaround
@ -147,12 +167,14 @@ class DockerSocketHandlerBase:
# When the SSH transport is used, Docker SDK for Python internally uses Paramiko, whose
# Channel object supports select(), but only for reading
# (https://github.com/paramiko/paramiko/issues/695).
if self._sock.send_ready():
if self._sock.send_ready(): # type: ignore
self._write()
return True
while timeout is None or timeout > PARAMIKO_POLL_TIMEOUT:
result = self.select(PARAMIKO_POLL_TIMEOUT, _internal_recursion=True)
if self._sock.send_ready():
result = int(
self.select(PARAMIKO_POLL_TIMEOUT, _internal_recursion=True)
)
if self._sock.send_ready(): # type: ignore
self._read()
result += 1
if result > 0:
@ -172,19 +194,19 @@ class DockerSocketHandlerBase:
self._write()
result = len(events)
if self._paramiko_read_workaround and len(self._write_buffer) > 0:
if self._sock.send_ready():
if self._sock.send_ready(): # type: ignore
self._write()
result += 1
return result > 0
def is_eof(self):
def is_eof(self) -> bool:
return self._eof
def end_of_writing(self):
def end_of_writing(self) -> None:
self._end_of_writing = True
self._handle_end_of_writing()
def consume(self):
def consume(self) -> tuple[bytes, bytes]:
stdout = []
stderr = []
@ -203,12 +225,12 @@ class DockerSocketHandlerBase:
self.select()
return b"".join(stdout), b"".join(stderr)
def write(self, str_to_write):
def write(self, str_to_write: bytes) -> None:
self._write_buffer += str_to_write
if len(self._write_buffer) == len(str_to_write):
self._write()
class DockerSocketHandlerModule(DockerSocketHandlerBase):
def __init__(self, sock, module):
def __init__(self, sock: SocketLike, module: AnsibleModule) -> None:
super().__init__(sock, module.debug)

View File

@ -12,9 +12,14 @@ import os
import os.path
import socket as pysocket
import typing as t
from collections.abc import Callable
def make_file_unblocking(file) -> None:
if t.TYPE_CHECKING:
SocketLike = pysocket.socket
def make_file_unblocking(file: SocketLike) -> None:
fcntl.fcntl(
file.fileno(),
fcntl.F_SETFL,
@ -22,7 +27,7 @@ def make_file_unblocking(file) -> None:
)
def make_file_blocking(file) -> None:
def make_file_blocking(file: SocketLike) -> None:
fcntl.fcntl(
file.fileno(),
fcntl.F_SETFL,
@ -30,11 +35,11 @@ def make_file_blocking(file) -> None:
)
def make_unblocking(sock) -> None:
def make_unblocking(sock: SocketLike) -> None:
if hasattr(sock, "_sock"):
sock._sock.setblocking(0)
elif hasattr(sock, "setblocking"):
sock.setblocking(0)
sock.setblocking(0) # type: ignore # TODO: CHECK!
else:
make_file_unblocking(sock)
@ -43,7 +48,9 @@ def _empty_writer(msg: str) -> None:
pass
def shutdown_writing(sock, log: t.Callable[[str], None] = _empty_writer) -> None:
def shutdown_writing(
sock: SocketLike, log: Callable[[str], None] = _empty_writer
) -> None:
# FIXME: This does **not work with SSLSocket**! Apparently SSLSocket does not allow to send
# a close_notify TLS alert without completely shutting down the connection.
# Calling sock.shutdown(pysocket.SHUT_WR) simply turns of TLS encryption and from that
@ -56,14 +63,14 @@ def shutdown_writing(sock, log: t.Callable[[str], None] = _empty_writer) -> None
except TypeError as e:
# probably: "TypeError: shutdown() takes 1 positional argument but 2 were given"
log(f"Shutting down for writing not possible; trying shutdown instead: {e}")
sock.shutdown()
sock.shutdown() # type: ignore
elif isinstance(sock, getattr(pysocket, "SocketIO")):
sock._sock.shutdown(pysocket.SHUT_WR)
else:
log("No idea how to signal end of writing")
def write_to_socket(sock, data: bytes) -> None:
def write_to_socket(sock: SocketLike, data: bytes) -> int:
if hasattr(sock, "_send_until_done"):
# WrappedSocket (urllib3/contrib/pyopenssl) does not have `send`, but
# only `sendall`, which uses `_send_until_done` under the hood.

View File

@ -9,6 +9,7 @@
from __future__ import annotations
import json
import typing as t
from time import sleep
@ -28,10 +29,7 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class AnsibleDockerSwarmClient(AnsibleDockerClient):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_swarm_node_id(self):
def get_swarm_node_id(self) -> str | None:
"""
Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID
of Docker host the module is executed on
@ -51,7 +49,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return swarm_info["Swarm"]["NodeID"]
return None
def check_if_swarm_node(self, node_id=None):
def check_if_swarm_node(self, node_id: str | None = None) -> bool | None:
"""
Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host
system information looking if specific key in output exists. If 'node_id' is provided then it tries to
@ -83,11 +81,11 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
try:
node_info = self.get_node_inspect(node_id=node_id)
except APIError:
return
return None
return node_info["ID"] is not None
def check_if_swarm_manager(self):
def check_if_swarm_manager(self) -> bool:
"""
Checks if node role is set as Manager in Swarm. The node is the docker host on which module action
is performed. The inspect_swarm() will fail if node is not a manager
@ -101,7 +99,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
except APIError:
return False
def fail_task_if_not_swarm_manager(self):
def fail_task_if_not_swarm_manager(self) -> None:
"""
If host is not a swarm manager then Ansible task on this host should end with 'failed' state
"""
@ -110,7 +108,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
"Error running docker swarm module: must run on swarm manager node"
)
def check_if_swarm_worker(self):
def check_if_swarm_worker(self) -> bool:
"""
Checks if node role is set as Worker in Swarm. The node is the docker host on which module action
is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node()
@ -122,7 +120,9 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True
return False
def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1):
def check_if_swarm_node_is_down(
self, node_id: str | None = None, repeat_check: int = 1
) -> bool:
"""
Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about
node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or
@ -147,7 +147,19 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
return True
return False
def get_node_inspect(self, node_id=None, skip_missing=False):
@t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: t.Literal[False] = False
) -> dict[str, t.Any]: ...
@t.overload
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None: ...
def get_node_inspect(
self, node_id: str | None = None, skip_missing: bool = False
) -> dict[str, t.Any] | None:
"""
Returns Swarm node info as in 'docker node inspect' command about single node
@ -195,7 +207,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info["Status"]["Addr"] = swarm_leader_ip
return node_info
def get_all_nodes_inspect(self):
def get_all_nodes_inspect(self) -> list[dict[str, t.Any]]:
"""
Returns Swarm node info as in 'docker node inspect' command about all registered nodes
@ -217,7 +229,17 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
node_info = json.loads(json_str)
return node_info
def get_all_nodes_list(self, output="short"):
@t.overload
def get_all_nodes_list(self, output: t.Literal["short"] = "short") -> list[str]: ...
@t.overload
def get_all_nodes_list(
self, output: t.Literal["long"]
) -> list[dict[str, t.Any]]: ...
def get_all_nodes_list(
self, output: t.Literal["short", "long"] = "short"
) -> list[str] | list[dict[str, t.Any]]:
"""
Returns list of nodes registered in Swarm
@ -227,48 +249,46 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
if 'output' is 'long' then returns data is list of dict containing the attributes as in
output of command 'docker node ls'
"""
nodes_list = []
nodes_inspect = self.get_all_nodes_inspect()
if nodes_inspect is None:
return None
if output == "short":
nodes_list = []
for node in nodes_inspect:
nodes_list.append(node["Description"]["Hostname"])
elif output == "long":
return nodes_list
if output == "long":
nodes_info_list = []
for node in nodes_inspect:
node_property = {}
node_property: dict[str, t.Any] = {}
node_property.update({"ID": node["ID"]})
node_property.update({"Hostname": node["Description"]["Hostname"]})
node_property.update({"Status": node["Status"]["State"]})
node_property.update({"Availability": node["Spec"]["Availability"]})
node_property["ID"] = node["ID"]
node_property["Hostname"] = node["Description"]["Hostname"]
node_property["Status"] = node["Status"]["State"]
node_property["Availability"] = node["Spec"]["Availability"]
if "ManagerStatus" in node:
if node["ManagerStatus"]["Leader"] is True:
node_property.update({"Leader": True})
node_property.update(
{"ManagerStatus": node["ManagerStatus"]["Reachability"]}
)
node_property.update(
{"EngineVersion": node["Description"]["Engine"]["EngineVersion"]}
)
node_property["Leader"] = True
node_property["ManagerStatus"] = node["ManagerStatus"][
"Reachability"
]
node_property["EngineVersion"] = node["Description"]["Engine"][
"EngineVersion"
]
nodes_list.append(node_property)
else:
return None
nodes_info_list.append(node_property)
return nodes_info_list
return nodes_list
def get_node_name_by_id(self, nodeid):
def get_node_name_by_id(self, nodeid: str) -> str:
return self.get_node_inspect(nodeid)["Description"]["Hostname"]
def get_unlock_key(self):
def get_unlock_key(self) -> str | None:
if self.docker_py_version < LooseVersion("2.7.0"):
return None
return super().get_unlock_key()
def get_service_inspect(self, service_id, skip_missing=False):
def get_service_inspect(
self, service_id: str, skip_missing: bool = False
) -> dict[str, t.Any] | None:
"""
Returns Swarm service info as in 'docker service inspect' command about single service

View File

@ -9,6 +9,7 @@ from __future__ import annotations
import json
import re
import typing as t
from datetime import timedelta
from urllib.parse import urlparse
@ -17,6 +18,12 @@ from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.text.converters import to_text
if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock"
DEFAULT_TLS = False
DEFAULT_TLS_VERIFY = False
@ -69,22 +76,24 @@ DOCKER_COMMON_ARGS_VARS = {
if option_name != "debug"
}
DOCKER_MUTUALLY_EXCLUSIVE = []
DOCKER_MUTUALLY_EXCLUSIVE: list[tuple[str, ...] | list[str]] = []
DOCKER_REQUIRED_TOGETHER = [["client_cert", "client_key"]]
DOCKER_REQUIRED_TOGETHER: list[tuple[str, ...] | list[str]] = [
["client_cert", "client_key"]
]
DEFAULT_DOCKER_REGISTRY = "https://index.docker.io/v1/"
BYTE_SUFFIXES = ["B", "KB", "MB", "GB", "TB", "PB"]
def is_image_name_id(name):
def is_image_name_id(name: str) -> bool:
"""Check whether the given image name is in fact an image ID (hash)."""
if re.match("^sha256:[0-9a-fA-F]{64}$", name):
return True
return False
def is_valid_tag(tag, allow_empty=False):
def is_valid_tag(tag: str, allow_empty: bool = False) -> bool:
"""Check whether the given string is a valid docker tag name."""
if not tag:
return allow_empty
@ -93,7 +102,7 @@ def is_valid_tag(tag, allow_empty=False):
return bool(re.match("^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$", tag))
def sanitize_result(data):
def sanitize_result(data: t.Any) -> t.Any:
"""Sanitize data object for return to Ansible.
When the data object contains types such as docker.types.containers.HostConfig,
@ -110,7 +119,7 @@ def sanitize_result(data):
return data
def log_debug(msg, pretty_print=False):
def log_debug(msg: t.Any, pretty_print: bool = False):
"""Write a log message to docker.log.
If ``pretty_print=True``, the message will be pretty-printed as JSON.
@ -126,25 +135,28 @@ def log_debug(msg, pretty_print=False):
class DockerBaseClass:
def __init__(self):
def __init__(self) -> None:
self.debug = False
def log(self, msg, pretty_print=False):
def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass
# if self.debug:
# log_debug(msg, pretty_print=pretty_print)
def update_tls_hostname(
result, old_behavior=False, deprecate_function=None, uses_tls=True
):
result: dict[str, t.Any],
old_behavior: bool = False,
deprecate_function: Callable[[str], None] | None = None,
uses_tls: bool = True,
) -> None:
if result["tls_hostname"] is None:
# get default machine name from the url
parsed_url = urlparse(result["docker_host"])
result["tls_hostname"] = parsed_url.netloc.rsplit(":", 1)[0]
def compare_dict_allow_more_present(av, bv):
def compare_dict_allow_more_present(av: dict, bv: dict) -> bool:
"""
Compare two dictionaries for whether every entry of the first is in the second.
"""
@ -156,7 +168,12 @@ def compare_dict_allow_more_present(av, bv):
return True
def compare_generic(a, b, method, datatype):
def compare_generic(
a: t.Any,
b: t.Any,
method: t.Literal["ignore", "strict", "allow_more_present"],
datatype: t.Literal["value", "list", "set", "set(dict)", "dict"],
) -> bool:
"""
Compare values a and b as described by method and datatype.
@ -247,10 +264,10 @@ def compare_generic(a, b, method, datatype):
class DifferenceTracker:
def __init__(self):
self._diff = []
def __init__(self) -> None:
self._diff: list[dict[str, t.Any]] = []
def add(self, name, parameter=None, active=None):
def add(self, name: str, parameter: t.Any = None, active: t.Any = None) -> None:
self._diff.append(
{
"name": name,
@ -259,14 +276,14 @@ class DifferenceTracker:
}
)
def merge(self, other_tracker):
def merge(self, other_tracker: DifferenceTracker) -> None:
self._diff.extend(other_tracker._diff)
@property
def empty(self):
def empty(self) -> bool:
return len(self._diff) == 0
def get_before_after(self):
def get_before_after(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
"""
Return texts ``before`` and ``after``.
"""
@ -277,13 +294,13 @@ class DifferenceTracker:
after[item["name"]] = item["parameter"]
return before, after
def has_difference_for(self, name):
def has_difference_for(self, name: str) -> bool:
"""
Returns a boolean if a difference exists for name
"""
return any(diff for diff in self._diff if diff["name"] == name)
def get_legacy_docker_container_diffs(self):
def get_legacy_docker_container_diffs(self) -> list[dict[str, t.Any]]:
"""
Return differences in the docker_container legacy format.
"""
@ -297,7 +314,7 @@ class DifferenceTracker:
result.append(item)
return result
def get_legacy_docker_diffs(self):
def get_legacy_docker_diffs(self) -> list[str]:
"""
Return differences in the docker_container legacy format.
"""
@ -305,8 +322,13 @@ class DifferenceTracker:
return result
def sanitize_labels(labels, labels_field, client=None, module=None):
def fail(msg):
def sanitize_labels(
labels: dict[str, t.Any] | None,
labels_field: str,
client=None,
module: AnsibleModule | None = None,
) -> None:
def fail(msg: str) -> t.NoReturn:
if client is not None:
client.fail(msg)
if module is not None:
@ -325,7 +347,21 @@ def sanitize_labels(labels, labels_field, client=None, module=None):
labels[k] = to_text(v)
def clean_dict_booleans_for_docker_api(data, allow_sequences=False):
@t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: t.Literal[False] = False
) -> dict[str, str]: ...
@t.overload
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any], *, allow_sequences: bool
) -> dict[str, str | list[str]]: ...
def clean_dict_booleans_for_docker_api(
data: dict[str, t.Any] | None, *, allow_sequences: bool = False
) -> dict[str, str] | dict[str, str | list[str]]:
"""
Go does not like Python booleans 'True' or 'False', while Ansible is just
fine with them in YAML. As such, they need to be converted in cases where
@ -353,7 +389,7 @@ def clean_dict_booleans_for_docker_api(data, allow_sequences=False):
return result
def convert_duration_to_nanosecond(time_str):
def convert_duration_to_nanosecond(time_str: str) -> int:
"""
Return time duration in nanosecond.
"""
@ -372,9 +408,9 @@ def convert_duration_to_nanosecond(time_str):
if not parts:
raise ValueError(f"Invalid time duration - {time_str}")
parts = parts.groupdict()
parts_dict = parts.groupdict()
time_params = {}
for name, value in parts.items():
for name, value in parts_dict.items():
if value:
time_params[name] = int(value)
@ -386,13 +422,15 @@ def convert_duration_to_nanosecond(time_str):
return time_in_nanoseconds
def normalize_healthcheck_test(test):
def normalize_healthcheck_test(test: t.Any) -> list[str]:
if isinstance(test, (tuple, list)):
return [str(e) for e in test]
return ["CMD-SHELL", str(test)]
def normalize_healthcheck(healthcheck, normalize_test=False):
def normalize_healthcheck(
healthcheck: dict[str, t.Any], normalize_test: bool = False
) -> dict[str, t.Any]:
"""
Return dictionary of healthcheck parameters.
"""
@ -438,7 +476,9 @@ def normalize_healthcheck(healthcheck, normalize_test=False):
return result
def parse_healthcheck(healthcheck):
def parse_healthcheck(
healthcheck: dict[str, t.Any] | None,
) -> tuple[dict[str, t.Any] | None, bool | None]:
"""
Return dictionary of healthcheck parameters and boolean if
healthcheck defined in image was requested to be disabled.
@ -456,8 +496,8 @@ def parse_healthcheck(healthcheck):
return result, False
def omit_none_from_dict(d):
def omit_none_from_dict(d: dict[str, t.Any]) -> dict[str, t.Any]:
"""
Return a copy of the dictionary with all keys with value None omitted.
"""
return dict((k, v) for (k, v) in d.items() if v is not None)
return {k: v for (k, v) in d.items() if v is not None}

View File

@ -80,7 +80,7 @@ import re
from ansible.module_utils.basic import AnsibleModule
def main():
def main() -> None:
module = AnsibleModule({}, supports_check_mode=True)
cpuset_path = "/proc/self/cpuset"

View File

@ -437,6 +437,7 @@ actions:
"""
import traceback
import typing as t
from ansible.module_utils.common.validation import check_type_int
@ -455,26 +456,32 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class ServicesManager(BaseComposeManager):
def __init__(self, client):
def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client)
parameters = self.client.module.params
self.state = parameters["state"]
self.dependencies = parameters["dependencies"]
self.pull = parameters["pull"]
self.build = parameters["build"]
self.ignore_build_events = parameters["ignore_build_events"]
self.recreate = parameters["recreate"]
self.remove_images = parameters["remove_images"]
self.remove_volumes = parameters["remove_volumes"]
self.remove_orphans = parameters["remove_orphans"]
self.renew_anon_volumes = parameters["renew_anon_volumes"]
self.timeout = parameters["timeout"]
self.services = parameters["services"] or []
self.scale = parameters["scale"] or {}
self.wait = parameters["wait"]
self.wait_timeout = parameters["wait_timeout"]
self.yes = parameters["assume_yes"]
self.state: t.Literal["absent", "present", "stopped", "restarted"] = parameters[
"state"
]
self.dependencies: bool = parameters["dependencies"]
self.pull: t.Literal["always", "missing", "never", "policy"] = parameters[
"pull"
]
self.build: t.Literal["always", "never", "policy"] = parameters["build"]
self.ignore_build_events: bool = parameters["ignore_build_events"]
self.recreate: t.Literal["always", "never", "auto"] = parameters["recreate"]
self.remove_images: t.Literal["all", "local"] | None = parameters[
"remove_images"
]
self.remove_volumes: bool = parameters["remove_volumes"]
self.remove_orphans: bool = parameters["remove_orphans"]
self.renew_anon_volumes: bool = parameters["renew_anon_volumes"]
self.timeout: int | None = parameters["timeout"]
self.services: list[str] = parameters["services"] or []
self.scale: dict[str, t.Any] = parameters["scale"] or {}
self.wait: bool = parameters["wait"]
self.wait_timeout: int | None = parameters["wait_timeout"]
self.yes: bool = parameters["assume_yes"]
if self.compose_version < LooseVersion("2.32.0") and self.yes:
self.fail(
f"assume_yes=true needs Docker Compose 2.32.0 or newer, not version {self.compose_version}"
@ -491,7 +498,7 @@ class ServicesManager(BaseComposeManager):
self.fail(f"The value {value!r} for `scale[{key!r}]` is negative")
self.scale[key] = value
def run(self):
def run(self) -> dict[str, t.Any]:
if self.state == "present":
result = self.cmd_up()
elif self.state == "stopped":
@ -508,7 +515,7 @@ class ServicesManager(BaseComposeManager):
self.cleanup_result(result)
return result
def get_up_cmd(self, dry_run, no_start=False):
def get_up_cmd(self, dry_run: bool, no_start: bool = False) -> list[str]:
args = self.get_base_args() + ["up", "--detach", "--no-color", "--quiet-pull"]
if self.pull != "policy":
args.extend(["--pull", self.pull])
@ -549,8 +556,8 @@ class ServicesManager(BaseComposeManager):
args.append(service)
return args
def cmd_up(self):
result = {}
def cmd_up(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
args = self.get_up_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -566,7 +573,7 @@ class ServicesManager(BaseComposeManager):
self.update_failed(result, events, args, stdout, stderr, rc)
return result
def get_stop_cmd(self, dry_run):
def get_stop_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["stop"]
if self.timeout is not None:
args.extend(["--timeout", f"{self.timeout}"])
@ -577,17 +584,17 @@ class ServicesManager(BaseComposeManager):
args.append(service)
return args
def _are_containers_stopped(self):
def _are_containers_stopped(self) -> bool:
for container in self.list_containers_raw():
if container["State"] not in ("created", "exited", "stopped", "killed"):
return False
return True
def cmd_stop(self):
def cmd_stop(self) -> dict[str, t.Any]:
# Since 'docker compose stop' **always** claims it is stopping containers, even if they are already
# stopped, we have to do this a bit more complicated.
result = {}
result: dict[str, t.Any] = {}
# Make sure all containers are created
args_1 = self.get_up_cmd(self.check_mode, no_start=True)
rc_1, stdout_1, stderr_1 = self.client.call_cli(*args_1, cwd=self.project_src)
@ -630,7 +637,7 @@ class ServicesManager(BaseComposeManager):
)
return result
def get_restart_cmd(self, dry_run):
def get_restart_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["restart"]
if not self.dependencies:
args.append("--no-deps")
@ -643,8 +650,8 @@ class ServicesManager(BaseComposeManager):
args.append(service)
return args
def cmd_restart(self):
result = {}
def cmd_restart(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
args = self.get_restart_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -653,7 +660,7 @@ class ServicesManager(BaseComposeManager):
self.update_failed(result, events, args, stdout, stderr, rc)
return result
def get_down_cmd(self, dry_run):
def get_down_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["down"]
if self.remove_orphans:
args.append("--remove-orphans")
@ -670,8 +677,8 @@ class ServicesManager(BaseComposeManager):
args.append(service)
return args
def cmd_down(self):
result = {}
def cmd_down(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
args = self.get_down_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -681,7 +688,7 @@ class ServicesManager(BaseComposeManager):
return result
def main():
def main() -> None:
argument_spec = {
"state": {
"type": "str",

View File

@ -166,6 +166,7 @@ rc:
import shlex
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_text
@ -180,29 +181,32 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
class ExecManager(BaseComposeManager):
def __init__(self, client):
def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client)
parameters = self.client.module.params
self.service = parameters["service"]
self.index = parameters["index"]
self.chdir = parameters["chdir"]
self.detach = parameters["detach"]
self.user = parameters["user"]
self.stdin = parameters["stdin"]
self.strip_empty_ends = parameters["strip_empty_ends"]
self.privileged = parameters["privileged"]
self.tty = parameters["tty"]
self.env = parameters["env"]
self.service: str = parameters["service"]
self.index: int | None = parameters["index"]
self.chdir: str | None = parameters["chdir"]
self.detach: bool = parameters["detach"]
self.user: str | None = parameters["user"]
self.stdin: str | None = parameters["stdin"]
self.strip_empty_ends: bool = parameters["strip_empty_ends"]
self.privileged: bool = parameters["privileged"]
self.tty: bool = parameters["tty"]
self.env: dict[str, t.Any] = parameters["env"]
self.argv = parameters["argv"]
self.argv: list[str]
if parameters["command"] is not None:
self.argv = shlex.split(parameters["command"])
else:
self.argv = parameters["argv"]
if self.detach and self.stdin is not None:
self.fail("If detach=true, stdin cannot be provided.")
if self.stdin is not None and parameters["stdin_add_newline"]:
stdin_add_newline: bool = parameters["stdin_add_newline"]
if self.stdin is not None and stdin_add_newline:
self.stdin += "\n"
if self.env is not None:
@ -214,7 +218,7 @@ class ExecManager(BaseComposeManager):
)
self.env[name] = to_text(value, errors="surrogate_or_strict")
def get_exec_cmd(self, dry_run, no_start=False):
def get_exec_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args(plain_progress=True) + ["exec"]
if self.index is not None:
args.extend(["--index", str(self.index)])
@ -237,9 +241,9 @@ class ExecManager(BaseComposeManager):
args.extend(self.argv)
return args
def run(self):
def run(self) -> dict[str, t.Any]:
args = self.get_exec_cmd(self.check_mode)
kwargs = {
kwargs: dict[str, t.Any] = {
"cwd": self.project_src,
}
if self.stdin is not None:
@ -262,7 +266,7 @@ class ExecManager(BaseComposeManager):
}
def main():
def main() -> None:
argument_spec = {
"service": {"type": "str", "required": True},
"index": {"type": "int"},

View File

@ -111,6 +111,7 @@ actions:
"""
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._common_cli import (
AnsibleModuleDockerClient,
@ -126,14 +127,14 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class PullManager(BaseComposeManager):
def __init__(self, client):
def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client)
parameters = self.client.module.params
self.policy = parameters["policy"]
self.ignore_buildable = parameters["ignore_buildable"]
self.include_deps = parameters["include_deps"]
self.services = parameters["services"] or []
self.policy: t.Literal["always", "missing"] = parameters["policy"]
self.ignore_buildable: bool = parameters["ignore_buildable"]
self.include_deps: bool = parameters["include_deps"]
self.services: list[str] = parameters["services"] or []
if self.policy != "always" and self.compose_version < LooseVersion("2.22.0"):
# https://github.com/docker/compose/pull/10981 - 2.22.0
@ -146,7 +147,7 @@ class PullManager(BaseComposeManager):
f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}"
)
def get_pull_cmd(self, dry_run, no_start=False):
def get_pull_cmd(self, dry_run: bool):
args = self.get_base_args() + ["pull"]
if self.policy != "always":
args.extend(["--policy", self.policy])
@ -161,8 +162,8 @@ class PullManager(BaseComposeManager):
args.append(service)
return args
def run(self):
result = {}
def run(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
args = self.get_pull_cmd(self.check_mode)
rc, stdout, stderr = self.client.call_cli(*args, cwd=self.project_src)
events = self.parse_events(stderr, dry_run=self.check_mode, nonzero_rc=rc != 0)
@ -179,7 +180,7 @@ class PullManager(BaseComposeManager):
return result
def main():
def main() -> None:
argument_spec = {
"policy": {
"type": "str",

View File

@ -239,6 +239,7 @@ rc:
import shlex
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_text
@ -253,42 +254,45 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
class ExecManager(BaseComposeManager):
def __init__(self, client):
def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__(client)
parameters = self.client.module.params
self.service = parameters["service"]
self.build = parameters["build"]
self.cap_add = parameters["cap_add"]
self.cap_drop = parameters["cap_drop"]
self.entrypoint = parameters["entrypoint"]
self.interactive = parameters["interactive"]
self.labels = parameters["labels"]
self.name = parameters["name"]
self.no_deps = parameters["no_deps"]
self.publish = parameters["publish"]
self.quiet_pull = parameters["quiet_pull"]
self.remove_orphans = parameters["remove_orphans"]
self.do_cleanup = parameters["cleanup"]
self.service_ports = parameters["service_ports"]
self.use_aliases = parameters["use_aliases"]
self.volumes = parameters["volumes"]
self.chdir = parameters["chdir"]
self.detach = parameters["detach"]
self.user = parameters["user"]
self.stdin = parameters["stdin"]
self.strip_empty_ends = parameters["strip_empty_ends"]
self.tty = parameters["tty"]
self.env = parameters["env"]
self.service: str = parameters["service"]
self.build: bool = parameters["build"]
self.cap_add: list[str] | None = parameters["cap_add"]
self.cap_drop: list[str] | None = parameters["cap_drop"]
self.entrypoint: str | None = parameters["entrypoint"]
self.interactive: bool = parameters["interactive"]
self.labels: list[str] | None = parameters["labels"]
self.name: str | None = parameters["name"]
self.no_deps: bool = parameters["no_deps"]
self.publish: list[str] | None = parameters["publish"]
self.quiet_pull: bool = parameters["quiet_pull"]
self.remove_orphans: bool = parameters["remove_orphans"]
self.do_cleanup: bool = parameters["cleanup"]
self.service_ports: bool = parameters["service_ports"]
self.use_aliases: bool = parameters["use_aliases"]
self.volumes: list[str] | None = parameters["volumes"]
self.chdir: str | None = parameters["chdir"]
self.detach: bool = parameters["detach"]
self.user: str | None = parameters["user"]
self.stdin: str | None = parameters["stdin"]
self.strip_empty_ends: bool = parameters["strip_empty_ends"]
self.tty: bool = parameters["tty"]
self.env: dict[str, t.Any] | None = parameters["env"]
self.argv = parameters["argv"]
self.argv: list[str]
if parameters["command"] is not None:
self.argv = shlex.split(parameters["command"])
else:
self.argv = parameters["argv"]
if self.detach and self.stdin is not None:
self.fail("If detach=true, stdin cannot be provided.")
if self.stdin is not None and parameters["stdin_add_newline"]:
stdin_add_newline: bool = parameters["stdin_add_newline"]
if self.stdin is not None and stdin_add_newline:
self.stdin += "\n"
if self.env is not None:
@ -300,7 +304,7 @@ class ExecManager(BaseComposeManager):
)
self.env[name] = to_text(value, errors="surrogate_or_strict")
def get_run_cmd(self, dry_run, no_start=False):
def get_run_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args(plain_progress=True) + ["run"]
if self.build:
args.append("--build")
@ -355,9 +359,9 @@ class ExecManager(BaseComposeManager):
args.extend(self.argv)
return args
def run(self):
def run(self) -> dict[str, t.Any]:
args = self.get_run_cmd(self.check_mode)
kwargs = {
kwargs: dict[str, t.Any] = {
"cwd": self.project_src,
}
if self.stdin is not None:
@ -382,7 +386,7 @@ class ExecManager(BaseComposeManager):
}
def main():
def main() -> None:
argument_spec = {
"service": {"type": "str", "required": True},
"argv": {"type": "list", "elements": "str"},

View File

@ -1355,7 +1355,7 @@ from ansible_collections.community.docker.plugins.module_utils._module_container
)
def main():
def main() -> None:
engine_driver = DockerAPIEngineDriver()
run_module(engine_driver)

View File

@ -169,6 +169,7 @@ import io
import os
import stat
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.module_utils.common.validation import check_type_int
@ -198,23 +199,29 @@ from ansible_collections.community.docker.plugins.module_utils._scramble import
)
def are_fileobjs_equal(f1, f2):
if t.TYPE_CHECKING:
import tarfile
def are_fileobjs_equal(f1: t.IO[bytes], f2: t.IO[bytes]) -> bool:
"""Given two (buffered) file objects, compare their contents."""
f1on: t.IO[bytes] | None = f1
f2on: t.IO[bytes] | None = f2
blocksize = 65536
b1buf = b""
b2buf = b""
while True:
if f1 and len(b1buf) < blocksize:
f1b = f1.read(blocksize)
if f1on and len(b1buf) < blocksize:
f1b = f1on.read(blocksize)
if not f1b:
# f1 is EOF, so stop reading from it
f1 = None
f1on = None
b1buf += f1b
if f2 and len(b2buf) < blocksize:
f2b = f2.read(blocksize)
if f2on and len(b2buf) < blocksize:
f2b = f2on.read(blocksize)
if not f2b:
# f2 is EOF, so stop reading from it
f2 = None
f2on = None
b2buf += f2b
if not b1buf or not b2buf:
# At least one of f1 and f2 is EOF and all its data has
@ -229,29 +236,33 @@ def are_fileobjs_equal(f1, f2):
b2buf = b2buf[buflen:]
def are_fileobjs_equal_read_first(f1, f2):
def are_fileobjs_equal_read_first(
f1: t.IO[bytes], f2: t.IO[bytes]
) -> tuple[bool, bytes]:
"""Given two (buffered) file objects, compare their contents.
Returns a tuple (is_equal, content_of_f1), where the first element indicates
whether the two file objects have the same content, and the second element is
the content of the first file object."""
f1on: t.IO[bytes] | None = f1
f2on: t.IO[bytes] | None = f2
blocksize = 65536
b1buf = b""
b2buf = b""
is_equal = True
content = []
while True:
if f1 and len(b1buf) < blocksize:
f1b = f1.read(blocksize)
if f1on and len(b1buf) < blocksize:
f1b = f1on.read(blocksize)
if not f1b:
# f1 is EOF, so stop reading from it
f1 = None
f1on = None
b1buf += f1b
if f2 and len(b2buf) < blocksize:
f2b = f2.read(blocksize)
if f2on and len(b2buf) < blocksize:
f2b = f2on.read(blocksize)
if not f2b:
# f2 is EOF, so stop reading from it
f2 = None
f2on = None
b2buf += f2b
if not b1buf or not b2buf:
# At least one of f1 and f2 is EOF and all its data has
@ -269,13 +280,13 @@ def are_fileobjs_equal_read_first(f1, f2):
b2buf = b2buf[buflen:]
content.append(b1buf)
if f1:
content.append(f1.read())
if f1on:
content.append(f1on.read())
return is_equal, b"".join(content)
def is_container_file_not_regular_file(container_stat):
def is_container_file_not_regular_file(container_stat: dict[str, t.Any]) -> bool:
for bit in (
# https://pkg.go.dev/io/fs#FileMode
32 - 1, # ModeDir
@ -292,7 +303,7 @@ def is_container_file_not_regular_file(container_stat):
return False
def get_container_file_mode(container_stat):
def get_container_file_mode(container_stat: dict[str, t.Any]) -> int:
mode = container_stat["mode"] & 0xFFF
if container_stat["mode"] & (1 << (32 - 9)) != 0: # ModeSetuid
mode |= stat.S_ISUID # set UID bit
@ -303,7 +314,9 @@ def get_container_file_mode(container_stat):
return mode
def add_other_diff(diff, in_path, member):
def add_other_diff(
diff: dict[str, t.Any] | None, in_path: str, member: tarfile.TarInfo
) -> None:
if diff is None:
return
diff["before_header"] = in_path
@ -326,14 +339,14 @@ def add_other_diff(diff, in_path, member):
def retrieve_diff(
client,
container,
container_path,
follow_links,
diff,
max_file_size_for_diff,
regular_stat=None,
link_target=None,
client: AnsibleDockerClient,
container: str,
container_path: str,
follow_links: bool,
diff: dict[str, t.Any] | None,
max_file_size_for_diff: int,
regular_stat: dict[str, t.Any] | None = None,
link_target: str | None = None,
):
if diff is None:
return
@ -377,19 +390,21 @@ def retrieve_diff(
return
# We need to get hold of the content
def process_none(in_path):
def process_none(in_path: str) -> None:
diff["before"] = ""
def process_regular(in_path, tar, member):
def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> None:
add_diff_dst_from_regular_member(
diff, max_file_size_for_diff, in_path, tar, member
)
def process_symlink(in_path, member):
def process_symlink(in_path: str, member: tarfile.TarInfo) -> None:
diff["before_header"] = in_path
diff["before"] = member.linkname
def process_other(in_path, member):
def process_other(in_path: str, member: tarfile.TarInfo) -> None:
add_other_diff(diff, in_path, member)
fetch_file_ex(
@ -404,7 +419,7 @@ def retrieve_diff(
)
def is_binary(content):
def is_binary(content: bytes) -> bool:
if b"\x00" in content:
return True
# TODO: better detection
@ -413,8 +428,13 @@ def is_binary(content):
def are_fileobjs_equal_with_diff_of_first(
f1, f2, size, diff, max_file_size_for_diff, container_path
):
f1: t.IO[bytes],
f2: t.IO[bytes],
size: int,
diff: dict[str, t.Any] | None,
max_file_size_for_diff: int,
container_path: str,
) -> bool:
if diff is None:
return are_fileobjs_equal(f1, f2)
if size > max_file_size_for_diff > 0:
@ -430,15 +450,22 @@ def are_fileobjs_equal_with_diff_of_first(
def add_diff_dst_from_regular_member(
diff, max_file_size_for_diff, container_path, tar, member
):
diff: dict[str, t.Any] | None,
max_file_size_for_diff: int,
container_path: str,
tar: tarfile.TarFile,
member: tarfile.TarInfo,
) -> None:
if diff is None:
return
if member.size > max_file_size_for_diff > 0:
diff["dst_larger"] = max_file_size_for_diff
return
with tar.extractfile(member) as tar_f:
mf = tar.extractfile(member)
if not mf:
raise AssertionError("Member should be present for regular file")
with mf as tar_f:
content = tar_f.read()
if is_binary(content):
@ -448,35 +475,35 @@ def add_diff_dst_from_regular_member(
diff["before"] = to_text(content)
def copy_dst_to_src(diff):
def copy_dst_to_src(diff: dict[str, t.Any] | None) -> None:
if diff is None:
return
for f, t in [
for frm, to in [
("dst_size", "src_size"),
("dst_binary", "src_binary"),
("before_header", "after_header"),
("before", "after"),
]:
if f in diff:
diff[t] = diff[f]
elif t in diff:
diff.pop(t)
if frm in diff:
diff[to] = diff[frm]
elif to in diff:
diff.pop(to)
def is_file_idempotent(
client,
container,
managed_path,
container_path,
follow_links,
local_follow_links,
client: AnsibleDockerClient,
container: str,
managed_path: str,
container_path: str,
follow_links: bool,
local_follow_links: bool,
owner_id,
group_id,
mode,
force=False,
diff=None,
max_file_size_for_diff=1,
):
force: bool | None = False,
diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1,
) -> tuple[str, int, bool]:
# Retrieve information of local file
try:
file_stat = (
@ -644,10 +671,12 @@ def is_file_idempotent(
return container_path, mode, False
# Fetch file from container
def process_none(in_path):
def process_none(in_path: str) -> tuple[str, int, bool]:
return container_path, mode, False
def process_regular(in_path, tar, member):
def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> tuple[str, int, bool]:
# Check things like user/group ID and mode
if any(
[
@ -663,14 +692,17 @@ def is_file_idempotent(
)
return container_path, mode, False
with tar.extractfile(member) as tar_f:
mf = tar.extractfile(member)
if mf is None:
raise AssertionError("Member should be present for regular file")
with mf as tar_f:
with open(managed_path, "rb") as local_f:
is_equal = are_fileobjs_equal_with_diff_of_first(
tar_f, local_f, member.size, diff, max_file_size_for_diff, in_path
)
return container_path, mode, is_equal
def process_symlink(in_path, member):
def process_symlink(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
if diff is not None:
diff["before_header"] = in_path
diff["before"] = member.linkname
@ -689,7 +721,7 @@ def is_file_idempotent(
local_link_target = os.readlink(managed_path)
return container_path, mode, member.linkname == local_link_target
def process_other(in_path, member):
def process_other(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
add_other_diff(diff, in_path, member)
return container_path, mode, False
@ -706,23 +738,21 @@ def is_file_idempotent(
def copy_file_into_container(
client,
container,
managed_path,
container_path,
follow_links,
local_follow_links,
client: AnsibleDockerClient,
container: str,
managed_path: str,
container_path: str,
follow_links: bool,
local_follow_links: bool,
owner_id,
group_id,
mode,
force=False,
diff=False,
max_file_size_for_diff=1,
):
if diff:
diff = {}
else:
diff = None
force: bool | None = False,
do_diff: bool = False,
max_file_size_for_diff: int = 1,
) -> t.NoReturn:
diff: dict[str, t.Any] | None
diff = {} if do_diff else None
container_path, mode, idempotent = is_file_idempotent(
client,
@ -762,18 +792,18 @@ def copy_file_into_container(
def is_content_idempotent(
client,
container,
content,
container_path,
follow_links,
client: AnsibleDockerClient,
container: str,
content: bytes,
container_path: str,
follow_links: bool,
owner_id,
group_id,
mode,
force=False,
diff=None,
max_file_size_for_diff=1,
):
force: bool | None = False,
diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1,
) -> tuple[str, int, bool]:
if diff is not None:
if len(content) > max_file_size_for_diff > 0:
diff["src_larger"] = max_file_size_for_diff
@ -894,12 +924,14 @@ def is_content_idempotent(
return container_path, mode, False
# Fetch file from container
def process_none(in_path):
def process_none(in_path: str) -> tuple[str, int, bool]:
if diff is not None:
diff["before"] = ""
return container_path, mode, False
def process_regular(in_path, tar, member):
def process_regular(
in_path: str, tar: tarfile.TarFile, member: tarfile.TarInfo
) -> tuple[str, int, bool]:
# Check things like user/group ID and mode
if any(
[
@ -914,7 +946,10 @@ def is_content_idempotent(
)
return container_path, mode, False
with tar.extractfile(member) as tar_f:
mf = tar.extractfile(member)
if mf is None:
raise AssertionError("Member should be present for regular file")
with mf as tar_f:
is_equal = are_fileobjs_equal_with_diff_of_first(
tar_f,
io.BytesIO(content),
@ -925,14 +960,14 @@ def is_content_idempotent(
)
return container_path, mode, is_equal
def process_symlink(in_path, member):
def process_symlink(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
if diff is not None:
diff["before_header"] = in_path
diff["before"] = member.linkname
return container_path, mode, False
def process_other(in_path, member):
def process_other(in_path: str, member: tarfile.TarInfo) -> tuple[str, int, bool]:
add_other_diff(diff, in_path, member)
return container_path, mode, False
@ -949,22 +984,19 @@ def is_content_idempotent(
def copy_content_into_container(
client,
container,
content,
container_path,
follow_links,
client: AnsibleDockerClient,
container: str,
content: bytes,
container_path: str,
follow_links: bool,
owner_id,
group_id,
mode,
force=False,
diff=False,
max_file_size_for_diff=1,
):
if diff:
diff = {}
else:
diff = None
force: bool | None = False,
do_diff: bool = False,
max_file_size_for_diff: int = 1,
) -> t.NoReturn:
diff: dict[str, t.Any] | None = {} if do_diff else None
container_path, mode, idempotent = is_content_idempotent(
client,
@ -1007,7 +1039,7 @@ def copy_content_into_container(
client.module.exit_json(**result)
def parse_modern(mode):
def parse_modern(mode: str | int) -> int:
if isinstance(mode, str):
return int(to_native(mode), 8)
if isinstance(mode, int):
@ -1015,13 +1047,13 @@ def parse_modern(mode):
raise TypeError(f"must be an octal string or an integer, got {mode!r}")
def parse_octal_string_only(mode):
def parse_octal_string_only(mode: str) -> int:
if isinstance(mode, str):
return int(to_native(mode), 8)
raise TypeError(f"must be an octal string, got {mode!r}")
def main():
def main() -> None:
argument_spec = {
"container": {"type": "str", "required": True},
"path": {"type": "path"},
@ -1054,20 +1086,22 @@ def main():
},
)
container = client.module.params["container"]
managed_path = client.module.params["path"]
container_path = client.module.params["container_path"]
follow = client.module.params["follow"]
local_follow = client.module.params["local_follow"]
owner_id = client.module.params["owner_id"]
group_id = client.module.params["group_id"]
mode = client.module.params["mode"]
force = client.module.params["force"]
content = client.module.params["content"]
max_file_size_for_diff = client.module.params["_max_file_size_for_diff"] or 1
container: str = client.module.params["container"]
managed_path: str | None = client.module.params["path"]
container_path: str = client.module.params["container_path"]
follow: bool = client.module.params["follow"]
local_follow: bool = client.module.params["local_follow"]
owner_id: int | None = client.module.params["owner_id"]
group_id: int | None = client.module.params["group_id"]
mode: t.Any = client.module.params["mode"]
force: bool | None = client.module.params["force"]
content_str: str | None = client.module.params["content"]
max_file_size_for_diff: int = client.module.params["_max_file_size_for_diff"] or 1
if mode is not None:
mode_parse = client.module.params["mode_parse"]
mode_parse: t.Literal["legacy", "modern", "octal_string_only"] = (
client.module.params["mode_parse"]
)
try:
if mode_parse == "legacy":
mode = check_type_int(mode)
@ -1080,14 +1114,15 @@ def main():
if mode < 0:
client.fail(f"'mode' must not be negative; got {mode}")
if content is not None:
content: bytes | None = None
if content_str is not None:
if client.module.params["content_is_b64"]:
try:
content = base64.b64decode(content)
content = base64.b64decode(content_str)
except Exception as e: # pylint: disable=broad-exception-caught
client.fail(f"Cannot Base64 decode the content option: {e}")
else:
content = to_bytes(content)
content = to_bytes(content_str)
if not container_path.startswith(os.path.sep):
container_path = os.path.join(os.path.sep, container_path)
@ -1108,7 +1143,7 @@ def main():
group_id=group_id,
mode=mode,
force=force,
diff=client.module._diff,
do_diff=client.module._diff,
max_file_size_for_diff=max_file_size_for_diff,
)
elif managed_path is not None:
@ -1123,7 +1158,7 @@ def main():
group_id=group_id,
mode=mode,
force=force,
diff=client.module._diff,
do_diff=client.module._diff,
max_file_size_for_diff=max_file_size_for_diff,
)
else:

View File

@ -165,6 +165,7 @@ exec_id:
import shlex
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_text
@ -185,7 +186,7 @@ from ansible_collections.community.docker.plugins.module_utils._socket_handler i
)
def main():
def main() -> None:
argument_spec = {
"container": {"type": "str", "required": True},
"argv": {"type": "list", "elements": "str"},
@ -211,16 +212,16 @@ def main():
required_one_of=[("argv", "command")],
)
container = client.module.params["container"]
argv = client.module.params["argv"]
command = client.module.params["command"]
chdir = client.module.params["chdir"]
detach = client.module.params["detach"]
user = client.module.params["user"]
stdin = client.module.params["stdin"]
strip_empty_ends = client.module.params["strip_empty_ends"]
tty = client.module.params["tty"]
env = client.module.params["env"]
container: str = client.module.params["container"]
argv: list[str] | None = client.module.params["argv"]
command: str | None = client.module.params["command"]
chdir: str | None = client.module.params["chdir"]
detach: bool = client.module.params["detach"]
user: str | None = client.module.params["user"]
stdin: str | None = client.module.params["stdin"]
strip_empty_ends: bool = client.module.params["strip_empty_ends"]
tty: bool = client.module.params["tty"]
env: dict[str, t.Any] = client.module.params["env"]
if env is not None:
for name, value in list(env.items()):
@ -233,6 +234,7 @@ def main():
if command is not None:
argv = shlex.split(command)
assert argv is not None
if detach and stdin is not None:
client.module.fail_json(msg="If detach=true, stdin cannot be provided.")
@ -258,7 +260,7 @@ def main():
exec_data = client.post_json_to_json(
"/containers/{0}/exec", container, data=data
)
exec_id = exec_data["Id"]
exec_id: str = exec_data["Id"]
data = {
"Tty": tty,
@ -269,6 +271,8 @@ def main():
client.module.exit_json(changed=True, exec_id=exec_id)
else:
stdout: bytes | None
stderr: bytes | None
if stdin and not detach:
exec_socket = client.post_json_to_stream_socket(
"/exec/{0}/start", exec_id, data=data
@ -283,28 +287,37 @@ def main():
stdout, stderr = exec_socket_handler.consume()
finally:
exec_socket.close()
elif tty:
stdout, stderr = client.post_json_to_stream(
"/exec/{0}/start",
exec_id,
data=data,
stream=False,
tty=True,
demux=True,
)
else:
stdout, stderr = client.post_json_to_stream(
"/exec/{0}/start",
exec_id,
data=data,
stream=False,
tty=tty,
tty=False,
demux=True,
)
result = client.get_json("/exec/{0}/json", exec_id)
stdout = to_text(stdout or b"")
stderr = to_text(stderr or b"")
stdout_t = to_text(stdout or b"")
stderr_t = to_text(stderr or b"")
if strip_empty_ends:
stdout = stdout.rstrip("\r\n")
stderr = stderr.rstrip("\r\n")
stdout_t = stdout_t.rstrip("\r\n")
stderr_t = stderr_t.rstrip("\r\n")
client.module.exit_json(
changed=True,
stdout=stdout,
stderr=stderr,
stdout=stdout_t,
stderr=stderr_t,
rc=result.get("ExitCode") or 0,
)
except NotFound:

View File

@ -86,7 +86,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor
)
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True},
}
@ -96,8 +96,9 @@ def main():
supports_check_mode=True,
)
container_id: str = client.module.params["name"]
try:
container = client.get_container(client.module.params["name"])
container = client.get_container(container_id)
client.module.exit_json(
changed=False,

View File

@ -173,6 +173,7 @@ current_context_name:
"""
import traceback
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_text
@ -185,6 +186,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.context.conf
)
from ansible_collections.community.docker.plugins.module_utils._api.context.context import (
IN_MEMORY,
Context,
)
from ansible_collections.community.docker.plugins.module_utils._api.errors import (
ContextException,
@ -192,7 +194,13 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
)
def tls_context_to_json(context):
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.tls import (
TLSConfig,
)
def tls_context_to_json(context: TLSConfig | None) -> dict[str, t.Any] | None:
if context is None:
return None
return {
@ -204,8 +212,8 @@ def tls_context_to_json(context):
}
def context_to_json(context, current):
module_config = {}
def context_to_json(context: Context, current: bool) -> dict[str, t.Any]:
module_config: dict[str, t.Any] = {}
if "docker" in context.endpoints:
endpoint = context.endpoints["docker"]
if isinstance(endpoint.get("Host"), str):
@ -247,7 +255,7 @@ def context_to_json(context, current):
}
def main():
def main() -> None:
argument_spec = {
"only_current": {"type": "bool", "default": False},
"name": {"type": "str"},
@ -262,28 +270,31 @@ def main():
],
)
only_current: bool = module.params["only_current"]
name: str | None = module.params["name"]
cli_context: str | None = module.params["cli_context"]
try:
if module.params["cli_context"]:
if cli_context:
current_context_name, current_context_source = (
module.params["cli_context"],
cli_context,
"cli_context module option",
)
else:
current_context_name, current_context_source = (
get_current_context_name_with_source()
)
if module.params["name"]:
contexts = [ContextAPI.get_context(module.params["name"])]
if not contexts[0]:
module.fail_json(
msg=f"There is no context of name {module.params['name']!r}"
)
elif module.params["only_current"]:
contexts = [ContextAPI.get_context(current_context_name)]
if not contexts[0]:
if name:
context_or_none = ContextAPI.get_context(name)
if not context_or_none:
module.fail_json(msg=f"There is no context of name {name!r}")
contexts = [context_or_none]
elif only_current:
context_or_none = ContextAPI.get_context(current_context_name)
if not context_or_none:
module.fail_json(
msg=f"There is no context of name {current_context_name!r}, which is configured as the default context ({current_context_source})",
)
contexts = [context_or_none]
else:
contexts = ContextAPI.contexts()

View File

@ -212,6 +212,7 @@ disk_usage:
"""
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import (
APIError,
@ -231,9 +232,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class DockerHostManager(DockerBaseClass):
def __init__(self, client, results):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
super().__init__()
self.client = client
@ -253,21 +252,21 @@ class DockerHostManager(DockerBaseClass):
for docker_object in listed_objects:
if self.client.module.params[docker_object]:
returned_name = docker_object
filter_name = docker_object + "_filters"
filter_name = f"{docker_object}_filters"
filters = clean_dict_booleans_for_docker_api(
client.module.params.get(filter_name), True
client.module.params.get(filter_name), allow_sequences=True
)
self.results[returned_name] = self.get_docker_items_list(
docker_object, filters
)
def get_docker_host_info(self):
def get_docker_host_info(self) -> dict[str, t.Any]:
try:
return self.client.info()
except APIError as exc:
self.client.fail(f"Error inspecting docker host: {exc}")
def get_docker_disk_usage_facts(self):
def get_docker_disk_usage_facts(self) -> dict[str, t.Any]:
try:
if self.verbose_output:
return self.client.df()
@ -275,9 +274,13 @@ class DockerHostManager(DockerBaseClass):
except APIError as exc:
self.client.fail(f"Error inspecting docker host: {exc}")
def get_docker_items_list(self, docker_object=None, filters=None, verbose=False):
items = None
items_list = []
def get_docker_items_list(
self,
docker_object: str,
filters: dict[str, t.Any] | None = None,
verbose: bool = False,
) -> list[dict[str, t.Any]]:
items = []
header_containers = [
"Id",
@ -329,6 +332,7 @@ class DockerHostManager(DockerBaseClass):
if self.verbose_output:
return items
items_list = []
for item in items:
item_record = {}
@ -349,7 +353,7 @@ class DockerHostManager(DockerBaseClass):
return items_list
def main():
def main() -> None:
argument_spec = {
"containers": {"type": "bool", "default": False},
"containers_all": {"type": "bool", "default": False},

View File

@ -367,6 +367,7 @@ import errno
import json
import os
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.text.formatters import human_to_bytes
@ -411,7 +412,18 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
)
def convert_to_bytes(value, module, name, unlimited_value=None):
if t.TYPE_CHECKING:
from collections.abc import Callable
from ansible.module_utils.basic import AnsibleModule
def convert_to_bytes(
value: str | None,
module: AnsibleModule,
name: str,
unlimited_value: int | None = None,
) -> int | None:
if value is None:
return value
try:
@ -423,8 +435,7 @@ def convert_to_bytes(value, module, name, unlimited_value=None):
class ImageManager(DockerBaseClass):
def __init__(self, client, results):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
"""
Configure a docker_image task.
@ -441,12 +452,14 @@ class ImageManager(DockerBaseClass):
parameters = self.client.module.params
self.check_mode = self.client.check_mode
self.source = parameters["source"]
build = parameters["build"] or {}
pull = parameters["pull"] or {}
self.archive_path = parameters["archive_path"]
self.cache_from = build.get("cache_from")
self.container_limits = build.get("container_limits")
self.source: t.Literal["build", "load", "pull", "local"] | None = parameters[
"source"
]
build: dict[str, t.Any] = parameters["build"] or {}
pull: dict[str, t.Any] = parameters["pull"] or {}
self.archive_path: str | None = parameters["archive_path"]
self.cache_from: list[str] | None = build.get("cache_from")
self.container_limits: dict[str, t.Any] | None = build.get("container_limits")
if self.container_limits and "memory" in self.container_limits:
self.container_limits["memory"] = convert_to_bytes(
self.container_limits["memory"],
@ -460,32 +473,36 @@ class ImageManager(DockerBaseClass):
"build.container_limits.memswap",
unlimited_value=-1,
)
self.dockerfile = build.get("dockerfile")
self.force_source = parameters["force_source"]
self.force_absent = parameters["force_absent"]
self.force_tag = parameters["force_tag"]
self.load_path = parameters["load_path"]
self.name = parameters["name"]
self.network = build.get("network")
self.extra_hosts = clean_dict_booleans_for_docker_api(build.get("etc_hosts"))
self.nocache = build.get("nocache", False)
self.build_path = build.get("path")
self.pull = build.get("pull")
self.target = build.get("target")
self.repository = parameters["repository"]
self.rm = build.get("rm", True)
self.state = parameters["state"]
self.tag = parameters["tag"]
self.http_timeout = build.get("http_timeout")
self.pull_platform = pull.get("platform")
self.push = parameters["push"]
self.buildargs = build.get("args")
self.build_platform = build.get("platform")
self.use_config_proxy = build.get("use_config_proxy")
self.shm_size = convert_to_bytes(
self.dockerfile: str | None = build.get("dockerfile")
self.force_source: bool = parameters["force_source"]
self.force_absent: bool = parameters["force_absent"]
self.force_tag: bool = parameters["force_tag"]
self.load_path: str | None = parameters["load_path"]
self.name: str = parameters["name"]
self.network: str | None = build.get("network")
self.extra_hosts: dict[str, str] = clean_dict_booleans_for_docker_api(
build.get("etc_hosts") # type: ignore
)
self.nocache: bool = build.get("nocache", False)
self.build_path: str | None = build.get("path")
self.pull: bool | None = build.get("pull")
self.target: str | None = build.get("target")
self.repository: str | None = parameters["repository"]
self.rm: bool = build.get("rm", True)
self.state: t.Literal["absent", "present"] = parameters["state"]
self.tag: str = parameters["tag"]
self.http_timeout: int | None = build.get("http_timeout")
self.pull_platform: str | None = pull.get("platform")
self.push: bool = parameters["push"]
self.buildargs: dict[str, t.Any] | None = build.get("args")
self.build_platform: str | None = build.get("platform")
self.use_config_proxy: bool | None = build.get("use_config_proxy")
self.shm_size: int | None = convert_to_bytes(
build.get("shm_size"), self.client.module, "build.shm_size"
)
self.labels = clean_dict_booleans_for_docker_api(build.get("labels"))
self.labels: dict[str, str] = clean_dict_booleans_for_docker_api(
build.get("labels") # type: ignore
)
# If name contains a tag, it takes precedence over tag parameter.
if not is_image_name_id(self.name):
@ -507,10 +524,10 @@ class ImageManager(DockerBaseClass):
elif self.state == "absent":
self.absent()
def fail(self, msg):
def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg)
def present(self):
def present(self) -> None:
"""
Handles state = 'present', which includes building, loading or pulling an image,
depending on user provided parameters.
@ -530,6 +547,7 @@ class ImageManager(DockerBaseClass):
)
# Build the image
assert self.build_path is not None
if not os.path.isdir(self.build_path):
self.fail(
f"Requested build path {self.build_path} could not be found or you do not have access."
@ -546,6 +564,7 @@ class ImageManager(DockerBaseClass):
self.results.update(self.build_image())
elif self.source == "load":
assert self.load_path is not None
# Load the image from an archive
if not os.path.isfile(self.load_path):
self.fail(
@ -596,7 +615,7 @@ class ImageManager(DockerBaseClass):
elif self.repository:
self.tag_image(self.name, self.tag, self.repository, push=self.push)
def absent(self):
def absent(self) -> None:
"""
Handles state = 'absent', which removes an image.
@ -627,8 +646,11 @@ class ImageManager(DockerBaseClass):
@staticmethod
def archived_image_action(
failure_logger, archive_path, current_image_name, current_image_id
):
failure_logger: Callable[[str], None],
archive_path: str,
current_image_name: str,
current_image_id: str,
) -> str | None:
"""
If the archive is missing or requires replacement, return an action message.
@ -667,7 +689,7 @@ class ImageManager(DockerBaseClass):
f"overwriting archive with image {archived.image_id} named {name}"
)
def archive_image(self, name, tag):
def archive_image(self, name: str, tag: str | None) -> None:
"""
Archive an image to a .tar file. Called when archive_path is passed.
@ -676,6 +698,7 @@ class ImageManager(DockerBaseClass):
:param tag: Optional image tag; assumed to be "latest" if None
:type tag: str | None
"""
assert self.archive_path is not None
if not tag:
tag = "latest"
@ -710,8 +733,8 @@ class ImageManager(DockerBaseClass):
self.client._get(
self.client._url("/images/{0}/get", image_name), stream=True
),
DEFAULT_DATA_CHUNK_SIZE,
False,
chunk_size=DEFAULT_DATA_CHUNK_SIZE,
decode=False,
)
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error getting image {image_name} - {exc}")
@ -725,7 +748,7 @@ class ImageManager(DockerBaseClass):
self.results["image"] = image
def push_image(self, name, tag=None):
def push_image(self, name: str, tag: str | None = None) -> None:
"""
If the name of the image contains a repository path, then push the image.
@ -799,7 +822,9 @@ class ImageManager(DockerBaseClass):
self.results["image"] = {}
self.results["image"]["push_status"] = status
def tag_image(self, name, tag, repository, push=False):
def tag_image(
self, name: str, tag: str, repository: str, push: bool = False
) -> None:
"""
Tag an image into a repository.
@ -852,7 +877,7 @@ class ImageManager(DockerBaseClass):
self.push_image(repo, repo_tag)
@staticmethod
def _extract_output_line(line, output):
def _extract_output_line(line: dict[str, t.Any], output: list[str]):
"""
Extract text line from stream output and, if found, adds it to output.
"""
@ -862,14 +887,15 @@ class ImageManager(DockerBaseClass):
text_line = line.get("stream") or line.get("status") or ""
output.extend(text_line.splitlines())
def build_image(self):
def build_image(self) -> dict[str, t.Any]:
"""
Build an image
:return: image dict
"""
assert self.build_path is not None
remote = context = None
headers = {}
headers: dict[str, str | bytes] = {}
buildargs = {}
if self.buildargs:
for key, value in self.buildargs.items():
@ -898,12 +924,12 @@ class ImageManager(DockerBaseClass):
[line.strip() for line in f.read().splitlines()],
)
)
dockerfile = process_dockerfile(dockerfile, self.build_path)
dockerfile_data = process_dockerfile(dockerfile, self.build_path)
context = tar(
self.build_path, exclude=exclude, dockerfile=dockerfile, gzip=False
self.build_path, exclude=exclude, dockerfile=dockerfile_data, gzip=False
)
params = {
params: dict[str, t.Any] = {
"t": f"{self.name}:{self.tag}" if self.tag else self.name,
"remote": remote,
"q": False,
@ -960,7 +986,7 @@ class ImageManager(DockerBaseClass):
if context is not None:
context.close()
build_output = []
build_output: list[str] = []
for line in self.client._stream_helper(response, decode=True):
# line = json.loads(line)
self.log(line, pretty_print=True)
@ -982,14 +1008,15 @@ class ImageManager(DockerBaseClass):
"image": self.client.find_image(name=self.name, tag=self.tag),
}
def load_image(self):
def load_image(self) -> dict[str, t.Any] | None:
"""
Load an image from a .tar archive
:return: image dict
"""
# Load image(s) from file
load_output = []
assert self.load_path is not None
load_output: list[str] = []
has_output = False
try:
self.log(f"Opening image {self.load_path}")
@ -1078,7 +1105,7 @@ class ImageManager(DockerBaseClass):
return self.client.find_image(self.name, self.tag)
def main():
def main() -> None:
argument_spec = {
"source": {"type": "str", "choices": ["build", "load", "pull", "local"]},
"build": {

View File

@ -282,6 +282,7 @@ command:
import base64
import os
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.text.formatters import human_to_bytes
@ -304,7 +305,16 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
)
def convert_to_bytes(value, module, name, unlimited_value=None):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def convert_to_bytes(
value: str | None,
module: AnsibleModule,
name: str,
unlimited_value: int | None = None,
) -> int | None:
if value is None:
return value
try:
@ -315,11 +325,11 @@ def convert_to_bytes(value, module, name, unlimited_value=None):
module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")
def dict_to_list(dictionary, concat="="):
def dict_to_list(dictionary: dict[str, t.Any], concat: str = "=") -> list[str]:
return [f"{k}{concat}{v}" for k, v in sorted(dictionary.items())]
def _quote_csv(text):
def _quote_csv(text: str) -> str:
if text.strip() == text and all(i not in text for i in '",\r\n'):
return text
text = text.replace('"', '""')
@ -327,7 +337,7 @@ def _quote_csv(text):
class ImageBuilder(DockerBaseClass):
def __init__(self, client):
def __init__(self, client: AnsibleModuleDockerClient) -> None:
super().__init__()
self.client = client
self.check_mode = self.client.check_mode
@ -420,14 +430,14 @@ class ImageBuilder(DockerBaseClass):
f" buildx plugin has version {buildx_version} which only supports one output."
)
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.client.fail(msg, **kwargs)
def add_list_arg(self, args, option, values):
def add_list_arg(self, args: list[str], option: str, values: list[str]) -> None:
for value in values:
args.extend([option, value])
def add_args(self, args):
def add_args(self, args: list[str]) -> dict[str, t.Any]:
environ_update = {}
if not self.outputs:
args.extend(["--tag", f"{self.name}:{self.tag}"])
@ -512,9 +522,9 @@ class ImageBuilder(DockerBaseClass):
)
return environ_update
def build_image(self):
def build_image(self) -> dict[str, t.Any]:
image = self.client.find_image(self.name, self.tag)
results = {
results: dict[str, t.Any] = {
"changed": False,
"actions": [],
"image": image or {},
@ -547,7 +557,7 @@ class ImageBuilder(DockerBaseClass):
return results
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"},

View File

@ -94,6 +94,7 @@ images:
"""
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.constants import (
DEFAULT_DATA_CHUNK_SIZE,
@ -121,7 +122,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageExportManager(DockerBaseClass):
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
@ -151,10 +152,10 @@ class ImageExportManager(DockerBaseClass):
if not self.names:
self.fail("At least one image name must be specified")
def fail(self, msg):
def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg)
def get_export_reason(self):
def get_export_reason(self) -> str | None:
if self.force:
return "Exporting since force=true"
@ -178,13 +179,13 @@ class ImageExportManager(DockerBaseClass):
found = True
break
if not found:
return f'Overwriting archive since it contains unexpected image {archived_image.image_id} named {", ".join(archived_image.repo_tags)}'
return f"Overwriting archive since it contains unexpected image {archived_image.image_id} named {', '.join(archived_image.repo_tags)}"
if left_names:
return f"Overwriting archive since it is missing image(s) {', '.join([name['joined'] for name in left_names])}"
return None
def write_chunks(self, chunks):
def write_chunks(self, chunks: t.Generator[bytes]) -> None:
try:
with open(self.path, "wb") as fd:
for chunk in chunks:
@ -192,7 +193,7 @@ class ImageExportManager(DockerBaseClass):
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error writing image archive {self.path} - {exc}")
def export_images(self):
def export_images(self) -> None:
image_names = [name["joined"] for name in self.names]
image_names_str = ", ".join(image_names)
if len(image_names) == 1:
@ -202,8 +203,8 @@ class ImageExportManager(DockerBaseClass):
self.client._get(
self.client._url("/images/{0}/get", image_names[0]), stream=True
),
DEFAULT_DATA_CHUNK_SIZE,
False,
chunk_size=DEFAULT_DATA_CHUNK_SIZE,
decode=False,
)
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error getting image {image_names[0]} - {exc}")
@ -216,15 +217,15 @@ class ImageExportManager(DockerBaseClass):
stream=True,
params={"names": image_names},
),
DEFAULT_DATA_CHUNK_SIZE,
False,
chunk_size=DEFAULT_DATA_CHUNK_SIZE,
decode=False,
)
except Exception as exc: # pylint: disable=broad-exception-caught
self.fail(f"Error getting images {image_names_str} - {exc}")
self.write_chunks(chunks)
def run(self):
def run(self) -> dict[str, t.Any]:
tag = self.tag
if not tag:
tag = "latest"
@ -260,7 +261,7 @@ class ImageExportManager(DockerBaseClass):
return results
def main():
def main() -> None:
argument_spec = {
"path": {"type": "path"},
"force": {"type": "bool", "default": False},

View File

@ -136,6 +136,7 @@ images:
"""
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException,
@ -155,9 +156,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageManager(DockerBaseClass):
def __init__(self, client, results):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
super().__init__()
self.client = client
@ -170,10 +169,10 @@ class ImageManager(DockerBaseClass):
else:
self.results["images"] = self.get_all_images()
def fail(self, msg):
def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg)
def get_facts(self):
def get_facts(self) -> list[dict[str, t.Any]]:
"""
Lookup and inspect each image name found in the names parameter.
@ -200,7 +199,7 @@ class ImageManager(DockerBaseClass):
results.append(image)
return results
def get_all_images(self):
def get_all_images(self) -> list[dict[str, t.Any]]:
results = []
params = {
"only_ids": 0,
@ -218,7 +217,7 @@ class ImageManager(DockerBaseClass):
return results
def main():
def main() -> None:
argument_spec = {
"name": {"type": "list", "elements": "str"},
}

View File

@ -80,6 +80,7 @@ images:
import errno
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException,
@ -95,7 +96,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageManager(DockerBaseClass):
def __init__(self, client, results):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
super().__init__()
self.client = client
@ -108,7 +109,7 @@ class ImageManager(DockerBaseClass):
self.load_images()
@staticmethod
def _extract_output_line(line, output):
def _extract_output_line(line: dict[str, t.Any], output: list[str]) -> None:
"""
Extract text line from stream output and, if found, adds it to output.
"""
@ -118,12 +119,12 @@ class ImageManager(DockerBaseClass):
text_line = line.get("stream") or line.get("status") or ""
output.extend(text_line.splitlines())
def load_images(self):
def load_images(self) -> None:
"""
Load images from a .tar archive
"""
# Load image(s) from file
load_output = []
load_output: list[str] = []
try:
self.log(f"Opening image {self.path}")
with open(self.path, "rb") as image_tar:
@ -179,7 +180,7 @@ class ImageManager(DockerBaseClass):
self.results["stdout"] = "\n".join(load_output)
def main():
def main() -> None:
client = AnsibleDockerClient(
argument_spec={
"path": {"type": "path", "required": True},
@ -188,7 +189,7 @@ def main():
)
try:
results = {
results: dict[str, t.Any] = {
"image_names": [],
"images": [],
}

View File

@ -91,6 +91,7 @@ image:
"""
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException,
@ -114,7 +115,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
def image_info(image):
def image_info(image: dict[str, t.Any] | None) -> dict[str, t.Any]:
result = {}
if image:
result["id"] = image["Id"]
@ -124,17 +125,17 @@ def image_info(image):
class ImagePuller(DockerBaseClass):
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
self.check_mode = self.client.check_mode
parameters = self.client.module.params
self.name = parameters["name"]
self.tag = parameters["tag"]
self.platform = parameters["platform"]
self.pull_mode = parameters["pull"]
self.name: str = parameters["name"]
self.tag: str = parameters["tag"]
self.platform: str | None = parameters["platform"]
self.pull_mode: t.Literal["always", "not_present"] = parameters["pull"]
if is_image_name_id(self.name):
self.client.fail("Cannot pull an image by ID")
@ -147,13 +148,15 @@ class ImagePuller(DockerBaseClass):
self.name = repo
self.tag = repo_tag
def pull(self):
def pull(self) -> dict[str, t.Any]:
image = self.client.find_image(name=self.name, tag=self.tag)
actions: list[str] = []
diff = {"before": image_info(image), "after": image_info(image)}
results = {
"changed": False,
"actions": [],
"actions": actions,
"image": image or {},
"diff": {"before": image_info(image), "after": image_info(image)},
"diff": diff,
}
if image and self.pull_mode == "not_present":
@ -175,21 +178,22 @@ class ImagePuller(DockerBaseClass):
if compare_platform_strings(wanted_platform, image_platform):
return results
results["actions"].append(f"Pulled image {self.name}:{self.tag}")
actions.append(f"Pulled image {self.name}:{self.tag}")
if self.check_mode:
results["changed"] = True
results["diff"]["after"] = image_info({"Id": "unknown"})
diff["after"] = image_info({"Id": "unknown"})
else:
results["image"], not_changed = self.client.pull_image(
image, not_changed = self.client.pull_image(
self.name, tag=self.tag, image_platform=self.platform
)
results["image"] = image
results["changed"] = not not_changed
results["diff"]["after"] = image_info(results["image"])
diff["after"] = image_info(image)
return results
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"},

View File

@ -73,6 +73,7 @@ image:
import base64
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.auth import (
get_config_header,
@ -96,15 +97,15 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImagePusher(DockerBaseClass):
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
self.check_mode = self.client.check_mode
parameters = self.client.module.params
self.name = parameters["name"]
self.tag = parameters["tag"]
self.name: str = parameters["name"]
self.tag: str = parameters["tag"]
if is_image_name_id(self.name):
self.client.fail("Cannot push an image by ID")
@ -122,20 +123,21 @@ class ImagePusher(DockerBaseClass):
if not is_valid_tag(self.tag, allow_empty=False):
self.client.fail(f'"{self.tag}" is not a valid docker tag!')
def push(self):
def push(self) -> dict[str, t.Any]:
image = self.client.find_image(name=self.name, tag=self.tag)
if not image:
self.client.fail(f"Cannot find image {self.name}:{self.tag}")
results = {
actions: list[str] = []
results: dict[str, t.Any] = {
"changed": False,
"actions": [],
"actions": actions,
"image": image,
}
push_registry, push_repo = resolve_repository_name(self.name)
try:
results["actions"].append(f"Pushed image {self.name}:{self.tag}")
actions.append(f"Pushed image {self.name}:{self.tag}")
headers = {}
header = get_config_header(self.client, push_registry)
@ -174,7 +176,7 @@ class ImagePusher(DockerBaseClass):
return results
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"},

View File

@ -98,6 +98,7 @@ untagged:
"""
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException,
@ -118,8 +119,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ImageRemover(DockerBaseClass):
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
@ -142,10 +142,10 @@ class ImageRemover(DockerBaseClass):
self.name = repo
self.tag = repo_tag
def fail(self, msg):
def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg)
def get_diff_state(self, image):
def get_diff_state(self, image: dict[str, t.Any] | None) -> dict[str, t.Any]:
if not image:
return {"exists": False}
return {
@ -155,13 +155,16 @@ class ImageRemover(DockerBaseClass):
"digests": sorted(image.get("RepoDigests") or []),
}
def absent(self):
results = {
def absent(self) -> dict[str, t.Any]:
actions: list[str] = []
deleted: list[str] = []
untagged: list[str] = []
results: dict[str, t.Any] = {
"changed": False,
"actions": [],
"actions": actions,
"image": {},
"deleted": [],
"untagged": [],
"deleted": deleted,
"untagged": untagged,
}
name = self.name
@ -172,16 +175,18 @@ class ImageRemover(DockerBaseClass):
if self.tag:
name = f"{self.name}:{self.tag}"
diff: dict[str, t.Any] = {}
if self.diff:
results["diff"] = {"before": self.get_diff_state(image)}
results["diff"] = diff
diff["before"] = self.get_diff_state(image)
if not image:
if self.diff:
results["diff"]["after"] = self.get_diff_state(image)
diff["after"] = self.get_diff_state(image)
return results
results["changed"] = True
results["actions"].append(f"Removed image {name}")
actions.append(f"Removed image {name}")
results["image"] = image
if not self.check_mode:
@ -199,22 +204,22 @@ class ImageRemover(DockerBaseClass):
for entry in res:
if entry.get("Untagged"):
results["untagged"].append(entry["Untagged"])
untagged.append(entry["Untagged"])
if entry.get("Deleted"):
results["deleted"].append(entry["Deleted"])
deleted.append(entry["Deleted"])
results["untagged"] = sorted(results["untagged"])
results["deleted"] = sorted(results["deleted"])
untagged[:] = sorted(untagged)
deleted[:] = sorted(deleted)
if self.diff:
image_after = self.client.find_image_by_id(
image["Id"], accept_missing_image=True
)
results["diff"]["after"] = self.get_diff_state(image_after)
diff["after"] = self.get_diff_state(image_after)
elif is_image_name_id(name):
results["deleted"].append(image["Id"])
results["untagged"] = sorted(
deleted.append(image["Id"])
untagged[:] = sorted(
(image.get("RepoTags") or []) + (image.get("RepoDigests") or [])
)
if not self.force and results["untagged"]:
@ -222,40 +227,40 @@ class ImageRemover(DockerBaseClass):
"Cannot delete image by ID that is still in use - use force=true"
)
if self.diff:
results["diff"]["after"] = self.get_diff_state({})
diff["after"] = self.get_diff_state({})
elif is_image_name_id(self.tag):
results["untagged"].append(name)
untagged.append(name)
if (
len(image.get("RepoTags") or []) < 1
and len(image.get("RepoDigests") or []) < 2
):
results["deleted"].append(image["Id"])
deleted.append(image["Id"])
if self.diff:
results["diff"]["after"] = self.get_diff_state(image)
diff["after"] = self.get_diff_state(image)
try:
results["diff"]["after"]["digests"].remove(name)
diff["after"]["digests"].remove(name)
except ValueError:
pass
else:
results["untagged"].append(name)
untagged.append(name)
if (
len(image.get("RepoTags") or []) < 2
and len(image.get("RepoDigests") or []) < 1
):
results["deleted"].append(image["Id"])
deleted.append(image["Id"])
if self.diff:
results["diff"]["after"] = self.get_diff_state(image)
diff["after"] = self.get_diff_state(image)
try:
results["diff"]["after"]["tags"].remove(name)
diff["after"]["tags"].remove(name)
except ValueError:
pass
return results
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"},

View File

@ -101,6 +101,7 @@ tagged_images:
"""
import traceback
import typing as t
from ansible.module_utils.common.text.formatters import human_to_bytes
@ -121,7 +122,16 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
def convert_to_bytes(value, module, name, unlimited_value=None):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def convert_to_bytes(
value: str | None,
module: AnsibleModule,
name: str,
unlimited_value: int | None = None,
) -> int | None:
if value is None:
return value
try:
@ -132,8 +142,8 @@ def convert_to_bytes(value, module, name, unlimited_value=None):
module.fail_json(msg=f"Failed to convert {name} to bytes: {exc}")
def image_info(name, tag, image):
result = {"name": name, "tag": tag}
def image_info(name: str, tag: str, image: dict[str, t.Any] | None) -> dict[str, t.Any]:
result: dict[str, t.Any] = {"name": name, "tag": tag}
if image:
result["id"] = image["Id"]
else:
@ -142,7 +152,7 @@ def image_info(name, tag, image):
class ImageTagger(DockerBaseClass):
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
@ -179,10 +189,12 @@ class ImageTagger(DockerBaseClass):
)
self.repositories.append((repo, repo_tag))
def fail(self, msg):
def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg)
def tag_image(self, image, name, tag):
def tag_image(
self, image: dict[str, t.Any], name: str, tag: str
) -> tuple[bool, str, dict[str, t.Any] | None]:
tagged_image = self.client.find_image(name=name, tag=tag)
if tagged_image:
# Idempotency checks
@ -220,20 +232,22 @@ class ImageTagger(DockerBaseClass):
return True, msg, tagged_image
def tag_images(self):
def tag_images(self) -> dict[str, t.Any]:
if is_image_name_id(self.name):
image = self.client.find_image_by_id(self.name, accept_missing_image=False)
else:
image = self.client.find_image(name=self.name, tag=self.tag)
if not image:
self.fail(f"Cannot find image {self.name}:{self.tag}")
assert image is not None
before = []
after = []
tagged_images = []
results = {
before: list[dict[str, t.Any]] = []
after: list[dict[str, t.Any]] = []
tagged_images: list[str] = []
actions: list[str] = []
results: dict[str, t.Any] = {
"changed": False,
"actions": [],
"actions": actions,
"image": image,
"tagged_images": tagged_images,
"diff": {"before": {"images": before}, "after": {"images": after}},
@ -244,19 +258,19 @@ class ImageTagger(DockerBaseClass):
after.append(image_info(repository, tag, image if tagged else old_image))
if tagged:
results["changed"] = True
results["actions"].append(
actions.append(
f"Tagged image {image['Id']} as {repository}:{tag}: {msg}"
)
tagged_images.append(f"{repository}:{tag}")
else:
results["actions"].append(
actions.append(
f"Not tagged image {image['Id']} as {repository}:{tag}: {msg}"
)
return results
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True},
"tag": {"type": "str", "default": "latest"},

View File

@ -120,6 +120,7 @@ import base64
import json
import os
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_bytes, to_text
@ -154,11 +155,11 @@ class DockerFileStore:
program = "<legacy config>"
def __init__(self, config_path):
def __init__(self, config_path: str) -> None:
self._config_path = config_path
# Make sure we have a minimal config if none is available.
self._config = {"auths": {}}
self._config: dict[str, t.Any] = {"auths": {}}
try:
# Attempt to read the existing config.
@ -172,14 +173,14 @@ class DockerFileStore:
self._config.update(config)
@property
def config_path(self):
def config_path(self) -> str:
"""
Return the config path configured in this DockerFileStore instance.
"""
return self._config_path
def get(self, server):
def get(self, server: str) -> dict[str, t.Any]:
"""
Retrieve credentials for `server` if there are any in the config file.
Otherwise raise a `StoreError`
@ -193,7 +194,7 @@ class DockerFileStore:
return {"Username": username, "Secret": password}
def _write(self):
def _write(self) -> None:
"""
Write config back out to disk.
"""
@ -209,7 +210,7 @@ class DockerFileStore:
finally:
os.close(f)
def store(self, server, username, password):
def store(self, server: str, username: str, password: str) -> None:
"""
Add a credentials for `server` to the current configuration.
"""
@ -225,7 +226,7 @@ class DockerFileStore:
self._write()
def erase(self, server):
def erase(self, server: str) -> None:
"""
Remove credentials for the given server from the configuration.
"""
@ -236,9 +237,7 @@ class DockerFileStore:
class LoginManager(DockerBaseClass):
def __init__(self, client, results):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
super().__init__()
self.client = client
@ -246,14 +245,14 @@ class LoginManager(DockerBaseClass):
parameters = self.client.module.params
self.check_mode = self.client.check_mode
self.registry_url = parameters.get("registry_url")
self.username = parameters.get("username")
self.password = parameters.get("password")
self.reauthorize = parameters.get("reauthorize")
self.config_path = parameters.get("config_path")
self.state = parameters.get("state")
self.registry_url: str = parameters.get("registry_url")
self.username: str | None = parameters.get("username")
self.password: str | None = parameters.get("password")
self.reauthorize: bool = parameters.get("reauthorize")
self.config_path: str = parameters.get("config_path")
self.state: t.Literal["present", "absent"] = parameters.get("state")
def run(self):
def run(self) -> None:
"""
Do the actual work of this task here. This allows instantiation for partial
testing.
@ -264,10 +263,10 @@ class LoginManager(DockerBaseClass):
else:
self.logout()
def fail(self, msg):
def fail(self, msg: str) -> t.NoReturn:
self.client.fail(msg)
def _login(self, reauth):
def _login(self, reauth: bool) -> dict[str, t.Any]:
if self.config_path and os.path.exists(self.config_path):
self.client._auth_configs = auth.load_config(
self.config_path, credstore_env=self.client.credstore_env
@ -297,7 +296,7 @@ class LoginManager(DockerBaseClass):
)
return self.client._result(response, get_json=True)
def login(self):
def login(self) -> None:
"""
Log into the registry with provided username/password. On success update the config
file with the new authorization.
@ -331,7 +330,7 @@ class LoginManager(DockerBaseClass):
self.update_credentials()
def logout(self):
def logout(self) -> None:
"""
Log out of the registry. On success update the config file.
@ -353,13 +352,16 @@ class LoginManager(DockerBaseClass):
store.erase(self.registry_url)
self.results["changed"] = True
def update_credentials(self):
def update_credentials(self) -> None:
"""
If the authorization is not stored attempt to store authorization values via
the appropriate credential helper or to the config file.
:return: None
"""
# This is only called from login()
assert self.username is not None
assert self.password is not None
# Check to see if credentials already exist.
store = self.get_credential_store_instance(self.registry_url, self.config_path)
@ -385,7 +387,9 @@ class LoginManager(DockerBaseClass):
)
self.results["changed"] = True
def get_credential_store_instance(self, registry, dockercfg_path):
def get_credential_store_instance(
self, registry: str, dockercfg_path: str
) -> Store | DockerFileStore:
"""
Return an instance of docker.credentials.Store used by the given registry.
@ -408,8 +412,7 @@ class LoginManager(DockerBaseClass):
return DockerFileStore(dockercfg_path)
def main():
def main() -> None:
argument_spec = {
"registry_url": {
"type": "str",

View File

@ -284,6 +284,7 @@ network:
import re
import time
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native
@ -303,29 +304,31 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass):
def __init__(self, client):
name: str
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
self.name = None
self.connected = None
self.config_from = None
self.config_only = None
self.driver = None
self.driver_options = None
self.ipam_driver = None
self.ipam_driver_options = None
self.ipam_config = None
self.appends = None
self.force = None
self.internal = None
self.labels = None
self.debug = None
self.enable_ipv4 = None
self.enable_ipv6 = None
self.scope = None
self.attachable = None
self.ingress = None
self.connected: list[str] = []
self.config_from: str | None = None
self.config_only: bool | None = None
self.driver: str = "bridge"
self.driver_options: dict[str, t.Any] = {}
self.ipam_driver: str | None = None
self.ipam_driver_options: dict[str, t.Any] | None = None
self.ipam_config: list[dict[str, t.Any]] | None = None
self.appends: bool = False
self.force: bool = False
self.internal: bool | None = None
self.labels: dict[str, t.Any] = {}
self.debug: bool = False
self.enable_ipv4: bool | None = None
self.enable_ipv6: bool | None = None
self.scope: t.Literal["local", "global", "swarm"] | None = None
self.attachable: bool | None = None
self.ingress: bool | None = None
self.state: t.Literal["present", "absent"] = "present"
for key, value in client.module.params.items():
setattr(self, key, value)
@ -333,10 +336,10 @@ class TaskParameters(DockerBaseClass):
# config_only sets driver to 'null' (and scope to 'local') so force that here. Otherwise we get
# diffs of 'null' --> 'bridge' given that the driver option defaults to 'bridge'.
if self.config_only:
self.driver = "null"
self.driver = "null" # type: ignore[unreachable]
def container_names_in_network(network):
def container_names_in_network(network: dict[str, t.Any]) -> list[str]:
return (
[c["Name"] for c in network["Containers"].values()]
if network["Containers"]
@ -348,7 +351,7 @@ CIDR_IPV4 = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}/([0-9]|[1-2][0-9]|3[0-2])$
CIDR_IPV6 = re.compile(r"^[0-9a-fA-F:]+/([0-9]|[1-9][0-9]|1[0-2][0-9])$")
def validate_cidr(cidr):
def validate_cidr(cidr: str) -> t.Literal["ipv4", "ipv6"]:
"""Validate CIDR. Return IP version of a CIDR string on success.
:param cidr: Valid CIDR
@ -364,7 +367,7 @@ def validate_cidr(cidr):
raise ValueError(f'"{cidr}" is not a valid CIDR')
def normalize_ipam_config_key(key):
def normalize_ipam_config_key(key: str) -> str:
"""Normalizes IPAM config keys returned by Docker API to match Ansible keys.
:param key: Docker API key
@ -376,7 +379,7 @@ def normalize_ipam_config_key(key):
return special_cases.get(key, key.lower())
def dicts_are_essentially_equal(a, b):
def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]):
"""Make sure that a is a subset of b, where None entries of a are ignored."""
for k, v in a.items():
if v is None:
@ -387,15 +390,15 @@ def dicts_are_essentially_equal(a, b):
class DockerNetworkManager:
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
self.client = client
self.parameters = TaskParameters(client)
self.check_mode = self.client.check_mode
self.results = {"changed": False, "actions": []}
self.actions: list[str] = []
self.results: dict[str, t.Any] = {"changed": False, "actions": self.actions}
self.diff = self.client.module._diff
self.diff_tracker = DifferenceTracker()
self.diff_result = {}
self.diff_result: dict[str, t.Any] = {}
self.existing_network = self.get_existing_network()
@ -429,10 +432,12 @@ class DockerNetworkManager:
)
self.results["diff"] = self.diff_result
def get_existing_network(self):
def get_existing_network(self) -> dict[str, t.Any] | None:
return self.client.get_network(name=self.parameters.name)
def has_different_config(self, net):
def has_different_config(
self, net: dict[str, t.Any]
) -> tuple[bool, DifferenceTracker]:
"""
Evaluates an existing network and returns a tuple containing a boolean
indicating if the configuration is different and a list of differences.
@ -601,9 +606,9 @@ class DockerNetworkManager:
return not differences.empty, differences
def create_network(self):
def create_network(self) -> None:
if not self.existing_network:
data = {
data: dict[str, t.Any] = {
"Name": self.parameters.name,
"Driver": self.parameters.driver,
"Options": self.parameters.driver_options,
@ -661,12 +666,12 @@ class DockerNetworkManager:
resp = self.client.post_json_to_json("/networks/create", data=data)
self.client.report_warnings(resp, ["Warning"])
self.existing_network = self.client.get_network(network_id=resp["Id"])
self.results["actions"].append(
self.actions.append(
f"Created network {self.parameters.name} with driver {self.parameters.driver}"
)
self.results["changed"] = True
def remove_network(self):
def remove_network(self) -> None:
if self.existing_network:
self.disconnect_all_containers()
if not self.check_mode:
@ -674,15 +679,15 @@ class DockerNetworkManager:
if self.existing_network.get("Scope", "local") == "swarm":
while self.get_existing_network():
time.sleep(0.1)
self.results["actions"].append(f"Removed network {self.parameters.name}")
self.actions.append(f"Removed network {self.parameters.name}")
self.results["changed"] = True
def is_container_connected(self, container_name):
def is_container_connected(self, container_name: str) -> bool:
if not self.existing_network:
return False
return container_name in container_names_in_network(self.existing_network)
def is_container_exist(self, container_name):
def is_container_exist(self, container_name: str) -> bool:
try:
container = self.client.get_container(container_name)
return bool(container)
@ -698,7 +703,7 @@ class DockerNetworkManager:
exception=traceback.format_exc(),
)
def connect_containers(self):
def connect_containers(self) -> None:
for name in self.parameters.connected:
if not self.is_container_connected(name) and self.is_container_exist(name):
if not self.check_mode:
@ -709,11 +714,11 @@ class DockerNetworkManager:
self.client.post_json(
"/networks/{0}/connect", self.parameters.name, data=data
)
self.results["actions"].append(f"Connected container {name}")
self.actions.append(f"Connected container {name}")
self.results["changed"] = True
self.diff_tracker.add(f"connected.{name}", parameter=True, active=False)
def disconnect_missing(self):
def disconnect_missing(self) -> None:
if not self.existing_network:
return
containers = self.existing_network["Containers"]
@ -724,26 +729,29 @@ class DockerNetworkManager:
if name not in self.parameters.connected:
self.disconnect_container(name)
def disconnect_all_containers(self):
containers = self.client.get_network(name=self.parameters.name)["Containers"]
def disconnect_all_containers(self) -> None:
network = self.client.get_network(name=self.parameters.name)
if not network:
return
containers = network["Containers"]
if not containers:
return
for cont in containers.values():
self.disconnect_container(cont["Name"])
def disconnect_container(self, container_name):
def disconnect_container(self, container_name: str) -> None:
if not self.check_mode:
data = {"Container": container_name, "Force": True}
self.client.post_json(
"/networks/{0}/disconnect", self.parameters.name, data=data
)
self.results["actions"].append(f"Disconnected container {container_name}")
self.actions.append(f"Disconnected container {container_name}")
self.results["changed"] = True
self.diff_tracker.add(
f"connected.{container_name}", parameter=False, active=True
)
def present(self):
def present(self) -> None:
different = False
differences = DifferenceTracker()
if self.existing_network:
@ -771,14 +779,14 @@ class DockerNetworkManager:
network_facts = self.get_existing_network()
self.results["network"] = network_facts
def absent(self):
def absent(self) -> None:
self.diff_tracker.add(
"exists", parameter=False, active=self.existing_network is not None
)
self.remove_network()
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True, "aliases": ["network_name"]},
"config_from": {"type": "str"},

View File

@ -107,7 +107,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor
)
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True},
}

View File

@ -129,6 +129,7 @@ actions:
"""
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native
@ -149,35 +150,36 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass):
def __init__(self, client):
plugin_name: str
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
self.plugin_name = None
self.alias = None
self.plugin_options = None
self.debug = None
self.force_remove = None
self.enable_timeout = None
self.alias: str | None = None
self.plugin_options: dict[str, t.Any] = {}
self.debug: bool = False
self.force_remove: bool = False
self.enable_timeout: int = 0
self.state: t.Literal["present", "absent", "enable", "disable"] = "present"
for key, value in client.module.params.items():
setattr(self, key, value)
def prepare_options(options):
def prepare_options(options: dict[str, t.Any] | None) -> list[str]:
return (
[f'{k}={v if v is not None else ""}' for k, v in options.items()]
[f"{k}={v if v is not None else ''}" for k, v in options.items()]
if options
else []
)
def parse_options(options_list):
def parse_options(options_list: list[str] | None) -> dict[str, str]:
return dict(x.split("=", 1) for x in options_list) if options_list else {}
class DockerPluginManager:
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
self.client = client
self.parameters = TaskParameters(client)
@ -185,9 +187,9 @@ class DockerPluginManager:
self.check_mode = self.client.check_mode
self.diff = self.client.module._diff
self.diff_tracker = DifferenceTracker()
self.diff_result = {}
self.diff_result: dict[str, t.Any] = {}
self.actions = []
self.actions: list[str] = []
self.changed = False
self.existing_plugin = self.get_existing_plugin()
@ -209,7 +211,7 @@ class DockerPluginManager:
)
self.diff = self.diff_result
def get_existing_plugin(self):
def get_existing_plugin(self) -> dict[str, t.Any] | None:
try:
return self.client.get_json("/plugins/{0}/json", self.preferred_name)
except NotFound:
@ -217,12 +219,13 @@ class DockerPluginManager:
except APIError as e:
self.client.fail(to_native(e))
def has_different_config(self):
def has_different_config(self) -> DifferenceTracker:
"""
Return the list of differences between the current parameters and the existing plugin.
:return: list of options that differ
"""
assert self.existing_plugin is not None
differences = DifferenceTracker()
if self.parameters.plugin_options:
settings = self.existing_plugin.get("Settings")
@ -249,7 +252,7 @@ class DockerPluginManager:
return differences
def install_plugin(self):
def install_plugin(self) -> None:
if not self.existing_plugin:
if not self.check_mode:
try:
@ -297,7 +300,7 @@ class DockerPluginManager:
self.actions.append(f"Installed plugin {self.preferred_name}")
self.changed = True
def remove_plugin(self):
def remove_plugin(self) -> None:
force = self.parameters.force_remove
if self.existing_plugin:
if not self.check_mode:
@ -311,7 +314,7 @@ class DockerPluginManager:
self.actions.append(f"Removed plugin {self.preferred_name}")
self.changed = True
def update_plugin(self):
def update_plugin(self) -> None:
if self.existing_plugin:
differences = self.has_different_config()
if not differences.empty:
@ -328,7 +331,7 @@ class DockerPluginManager:
else:
self.client.fail("Cannot update the plugin: Plugin does not exist")
def present(self):
def present(self) -> None:
differences = DifferenceTracker()
if self.existing_plugin:
differences = self.has_different_config()
@ -345,13 +348,10 @@ class DockerPluginManager:
if self.diff or self.check_mode or self.parameters.debug:
self.diff_tracker.merge(differences)
if not self.check_mode and not self.parameters.debug:
self.actions = None
def absent(self):
def absent(self) -> None:
self.remove_plugin()
def enable(self):
def enable(self) -> None:
timeout = self.parameters.enable_timeout
if self.existing_plugin:
if not self.existing_plugin.get("Enabled"):
@ -380,7 +380,7 @@ class DockerPluginManager:
self.actions.append(f"Enabled plugin {self.preferred_name}")
self.changed = True
def disable(self):
def disable(self) -> None:
if self.existing_plugin:
if self.existing_plugin.get("Enabled"):
if not self.check_mode:
@ -396,7 +396,7 @@ class DockerPluginManager:
self.client.fail("Plugin not found: Plugin does not exist.")
@property
def result(self):
def result(self) -> dict[str, t.Any]:
plugin_data = {}
if self.parameters.state != "absent":
try:
@ -406,16 +406,22 @@ class DockerPluginManager:
except NotFound:
# This can happen in check mode
pass
result = {
result: dict[str, t.Any] = {
"actions": self.actions,
"changed": self.changed,
"diff": self.diff,
"plugin": plugin_data,
}
return dict((k, v) for k, v in result.items() if v is not None)
if (
self.parameters.state == "present"
and not self.check_mode
and not self.parameters.debug
):
result["actions"] = None
return {k: v for k, v in result.items() if v is not None}
def main():
def main() -> None:
argument_spec = {
"alias": {"type": "str"},
"plugin_name": {"type": "str", "required": True},

View File

@ -247,7 +247,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
def main():
def main() -> None:
argument_spec = {
"containers": {"type": "bool", "default": False},
"containers_filters": {"type": "dict"},

View File

@ -118,6 +118,7 @@ volume:
"""
import traceback
import typing as t
from ansible.module_utils.common.text.converters import to_native
@ -137,31 +138,33 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass):
def __init__(self, client):
volume_name: str
def __init__(self, client: AnsibleDockerClient) -> None:
super().__init__()
self.client = client
self.volume_name = None
self.driver = None
self.driver_options = None
self.labels = None
self.recreate = None
self.debug = None
self.driver: str = "local"
self.driver_options: dict[str, t.Any] = {}
self.labels: dict[str, t.Any] | None = None
self.recreate: t.Literal["always", "never", "options-changed"] = "never"
self.debug: bool = False
self.state: t.Literal["present", "absent"] = "present"
for key, value in client.module.params.items():
setattr(self, key, value)
class DockerVolumeManager:
def __init__(self, client):
def __init__(self, client: AnsibleDockerClient) -> None:
self.client = client
self.parameters = TaskParameters(client)
self.check_mode = self.client.check_mode
self.results = {"changed": False, "actions": []}
self.actions: list[str] = []
self.results: dict[str, t.Any] = {"changed": False, "actions": self.actions}
self.diff = self.client.module._diff
self.diff_tracker = DifferenceTracker()
self.diff_result = {}
self.diff_result: dict[str, t.Any] = {}
self.existing_volume = self.get_existing_volume()
@ -178,7 +181,7 @@ class DockerVolumeManager:
)
self.results["diff"] = self.diff_result
def get_existing_volume(self):
def get_existing_volume(self) -> dict[str, t.Any] | None:
try:
volumes = self.client.get_json("/volumes")
except APIError as e:
@ -193,12 +196,13 @@ class DockerVolumeManager:
return None
def has_different_config(self):
def has_different_config(self) -> DifferenceTracker:
"""
Return the list of differences between the current parameters and the existing volume.
:return: list of options that differ
"""
assert self.existing_volume is not None
differences = DifferenceTracker()
if (
self.parameters.driver
@ -239,7 +243,7 @@ class DockerVolumeManager:
return differences
def create_volume(self):
def create_volume(self) -> None:
if not self.existing_volume:
if not self.check_mode:
try:
@ -257,12 +261,12 @@ class DockerVolumeManager:
except APIError as e:
self.client.fail(to_native(e))
self.results["actions"].append(
self.actions.append(
f"Created volume {self.parameters.volume_name} with driver {self.parameters.driver}"
)
self.results["changed"] = True
def remove_volume(self):
def remove_volume(self) -> None:
if self.existing_volume:
if not self.check_mode:
try:
@ -270,12 +274,10 @@ class DockerVolumeManager:
except APIError as e:
self.client.fail(to_native(e))
self.results["actions"].append(
f"Removed volume {self.parameters.volume_name}"
)
self.actions.append(f"Removed volume {self.parameters.volume_name}")
self.results["changed"] = True
def present(self):
def present(self) -> None:
differences = DifferenceTracker()
if self.existing_volume:
differences = self.has_different_config()
@ -301,14 +303,14 @@ class DockerVolumeManager:
volume_facts = self.get_existing_volume()
self.results["volume"] = volume_facts
def absent(self):
def absent(self) -> None:
self.diff_tracker.add(
"exists", parameter=False, active=self.existing_volume is not None
)
self.remove_volume()
def main():
def main() -> None:
argument_spec = {
"volume_name": {"type": "str", "required": True, "aliases": ["name"]},
"state": {

View File

@ -71,6 +71,7 @@ volume:
"""
import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.errors import (
DockerException,
@ -82,7 +83,9 @@ from ansible_collections.community.docker.plugins.module_utils._common_api impor
)
def get_existing_volume(client, volume_name):
def get_existing_volume(
client: AnsibleDockerClient, volume_name: str
) -> dict[str, t.Any] | None:
try:
return client.get_json("/volumes/{0}", volume_name)
except NotFound:
@ -91,7 +94,7 @@ def get_existing_volume(client, volume_name):
client.fail(f"Error inspecting volume: {exc}")
def main():
def main() -> None:
argument_spec = {
"name": {"type": "str", "required": True, "aliases": ["volume_name"]},
}

View File

@ -7,6 +7,8 @@
from __future__ import annotations
import typing as t
from ansible.errors import AnsibleConnectionFailure
from ansible.utils.display import Display
@ -18,8 +20,17 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
if t.TYPE_CHECKING:
from ansible.plugins import AnsiblePlugin
class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(self, plugin, min_docker_version=None, min_docker_api_version=None):
def __init__(
self,
plugin: AnsiblePlugin,
min_docker_version: str | None = None,
min_docker_api_version: str | None = None,
) -> None:
self.plugin = plugin
self.display = Display()
super().__init__(
@ -27,17 +38,23 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
min_docker_api_version=min_docker_api_version,
)
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
if kwargs:
msg += "\nContext:\n" + "\n".join(
f" {k} = {v!r}" for (k, v) in kwargs.items()
)
raise AnsibleConnectionFailure(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.display.deprecated(
msg, version=version, date=date, collection_name=collection_name
)
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS}

View File

@ -7,6 +7,8 @@
from __future__ import annotations
import typing as t
from ansible.errors import AnsibleConnectionFailure
from ansible.utils.display import Display
@ -18,23 +20,35 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
)
if t.TYPE_CHECKING:
from ansible.plugins import AnsiblePlugin
class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(self, plugin, min_docker_api_version=None):
def __init__(
self, plugin: AnsiblePlugin, min_docker_api_version: str | None = None
) -> None:
self.plugin = plugin
self.display = Display()
super().__init__(min_docker_api_version=min_docker_api_version)
def fail(self, msg, **kwargs):
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
if kwargs:
msg += "\nContext:\n" + "\n".join(
f" {k} = {v!r}" for (k, v) in kwargs.items()
)
raise AnsibleConnectionFailure(msg)
def deprecate(self, msg, version=None, date=None, collection_name=None):
def deprecate(
self,
msg: str,
version: str | None = None,
date: str | None = None,
collection_name: str | None = None,
) -> None:
self.display.deprecated(
msg, version=version, date=date, collection_name=collection_name
)
def _get_params(self):
def _get_params(self) -> dict[str, t.Any]:
return {option: self.plugin.get_option(option) for option in DOCKER_COMMON_ARGS}

View File

@ -7,11 +7,23 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.docker.plugins.module_utils._socket_handler import (
DockerSocketHandlerBase,
)
if t.TYPE_CHECKING:
from ansible.utils.display import Display
from ansible_collections.community.docker.plugins.module_utils._socket_helper import (
SocketLike,
)
class DockerSocketHandler(DockerSocketHandlerBase):
def __init__(self, display, sock, log=None, container=None):
def __init__(
self, display: Display, sock: SocketLike, container: str | None = None
) -> None:
super().__init__(sock, log=lambda msg: display.vvvv(msg, host=container))

View File

@ -8,6 +8,7 @@
from __future__ import annotations
import re
import typing as t
from collections.abc import Mapping, Set
from ansible.module_utils.common.collections import is_sequence
@ -21,7 +22,7 @@ _RE_TEMPLATE_CHARS = re.compile("[{}]")
_RE_TEMPLATE_CHARS_BYTES = re.compile(b"[{}]")
def make_unsafe(value):
def make_unsafe(value: t.Any) -> t.Any:
if value is None or isinstance(value, AnsibleUnsafe):
return value

View File

@ -1 +1,22 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/api/client.py pep8:E704
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_api/utils/ports.py pep8:E704
plugins/module_utils/_api/utils/socket.py pep8:E704
plugins/module_utils/_common_cli.py pep8:E704
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/module_utils/_swarm.py pep8:E704
plugins/module_utils/_util.py pep8:E704
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_container_exec.py pylint:unpacking-non-sequence
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,21 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/api/client.py pep8:E704
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_api/utils/ports.py pep8:E704
plugins/module_utils/_api/utils/socket.py pep8:E704
plugins/module_utils/_common_cli.py pep8:E704
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/module_utils/_swarm.py pep8:E704
plugins/module_utils/_util.py pep8:E704
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,15 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,15 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -1 +1,15 @@
plugins/connection/docker.py no-assert
plugins/connection/docker_api.py no-assert
plugins/connection/nsenter.py no-assert
plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_volume.py no-assert

View File

@ -19,7 +19,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.transport im
try:
from ssl import CertificateError, match_hostname
from ssl import CertificateError, match_hostname # type: ignore
except ImportError:
HAS_MATCH_HOSTNAME = False # pylint: disable=invalid-name
else:

View File

@ -12,8 +12,8 @@ import json
import os
import shutil
import tempfile
import typing as t
import unittest
from collections.abc import Callable
from unittest import mock
from pytest import fixture, mark
@ -22,7 +22,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils import
class FindConfigFileTest(unittest.TestCase):
mkdir: t.Callable[[str], os.PathLike[str]]
mkdir: Callable[[str], os.PathLike[str]]
@fixture(autouse=True)
def tmpdir(self, tmpdir):

View File

@ -11,7 +11,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
)
EVENT_TEST_CASES = [
EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[Event]]] = [
# #######################################################################################################################
# ## Docker Compose 2.18.1 ##############################################################################################
# #######################################################################################################################