diff --git a/.mypy.ini b/.mypy.ini index 95dc41a5..28e42654 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -3,15 +3,15 @@ # SPDX-License-Identifier: GPL-3.0-or-later [mypy] -# check_untyped_defs = True -- for later -# disallow_untyped_defs = True -- for later +check_untyped_defs = True +disallow_untyped_defs = True # strict = True -- only try to enable once everything (including dependencies!) is typed strict_equality = True strict_bytes = True warn_redundant_casts = True -# warn_return_any = True -- for later +# warn_return_any = True warn_unreachable = True [mypy-ansible.*] diff --git a/plugins/connection/docker.py b/plugins/connection/docker.py index fd721388..fe66e653 100644 --- a/plugins/connection/docker.py +++ b/plugins/connection/docker.py @@ -141,7 +141,7 @@ class Connection(ConnectionBase): transport = "community.docker.docker" has_pipelining = True - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) # Note: docker supports running as non-root in some configurations. @@ -476,7 +476,7 @@ class Connection(ConnectionBase): display.debug("done with docker.exec_command()") return (p.returncode, stdout, stderr) - def _prefix_login_path(self, remote_path): + def _prefix_login_path(self, remote_path: str) -> str: """Make sure that we put files into a standard path If a path is relative, then we need to choose where to put it. 
diff --git a/plugins/connection/docker_api.py b/plugins/connection/docker_api.py index dd0cc479..bce216ab 100644 --- a/plugins/connection/docker_api.py +++ b/plugins/connection/docker_api.py @@ -192,7 +192,7 @@ class Connection(ConnectionBase): f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}' ) - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) self.client: AnsibleDockerClient | None = None @@ -319,7 +319,7 @@ class Connection(ConnectionBase): become_output = [b""] - def append_become_output(stream_id, data): + def append_become_output(stream_id: int, data: bytes) -> None: become_output[0] += data exec_socket_handler.set_block_done_callback( diff --git a/plugins/connection/nsenter.py b/plugins/connection/nsenter.py index b65803c3..57b39c3a 100644 --- a/plugins/connection/nsenter.py +++ b/plugins/connection/nsenter.py @@ -65,7 +65,7 @@ class Connection(ConnectionBase): transport = "community.docker.nsenter" has_pipelining = False - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) self.cwd = None self._nsenter_pid = None diff --git a/plugins/inventory/docker_machine.py b/plugins/inventory/docker_machine.py index 0fcd08e9..e9ccca1b 100644 --- a/plugins/inventory/docker_machine.py +++ b/plugins/inventory/docker_machine.py @@ -221,7 +221,10 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): return ip_addr def _should_skip_host( - self, machine_name: str, env_var_tuples, daemon_env: DaemonEnv + self, + machine_name: str, + env_var_tuples: list[tuple[str, str]], + daemon_env: DaemonEnv, ) -> bool: if not env_var_tuples: warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}" diff --git a/plugins/module_utils/_api/_import_helper.py 
b/plugins/module_utils/_api/_import_helper.py index b2f7bc38..3891ca10 100644 --- a/plugins/module_utils/_api/_import_helper.py +++ b/plugins/module_utils/_api/_import_helper.py @@ -67,7 +67,7 @@ except ImportError: pass class FakeURLLIB3: - def __init__(self): + def __init__(self) -> None: self._collections = self self.poolmanager = self self.connection = self @@ -81,14 +81,14 @@ except ImportError: ) class FakeURLLIB3Connection: - def __init__(self): + def __init__(self) -> None: self.HTTPConnection = _HTTPConnection # pylint: disable=invalid-name urllib3 = FakeURLLIB3() urllib3_connection = FakeURLLIB3Connection() -def fail_on_missing_imports(): +def fail_on_missing_imports() -> None: if REQUESTS_IMPORT_ERROR is not None: from .errors import MissingRequirementException # pylint: disable=cyclic-import diff --git a/plugins/module_utils/_api/api/client.py b/plugins/module_utils/_api/api/client.py index 3393812c..622b4539 100644 --- a/plugins/module_utils/_api/api/client.py +++ b/plugins/module_utils/_api/api/client.py @@ -55,6 +55,7 @@ from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter if t.TYPE_CHECKING: from requests import Response + from requests.adapters import BaseAdapter from ..._socket_helper import SocketLike @@ -258,23 +259,23 @@ class APIClient(_Session): return kwargs @update_headers - def _post(self, url: str, **kwargs): + def _post(self, url: str, **kwargs: t.Any) -> Response: return self.post(url, **self._set_request_timeout(kwargs)) @update_headers - def _get(self, url: str, **kwargs): + def _get(self, url: str, **kwargs: t.Any) -> Response: return self.get(url, **self._set_request_timeout(kwargs)) @update_headers - def _head(self, url: str, **kwargs): + def _head(self, url: str, **kwargs: t.Any) -> Response: return self.head(url, **self._set_request_timeout(kwargs)) @update_headers - def _put(self, url: str, **kwargs): + def _put(self, url: str, **kwargs: t.Any) -> Response: return self.put(url, 
**self._set_request_timeout(kwargs)) @update_headers - def _delete(self, url: str, **kwargs): + def _delete(self, url: str, **kwargs: t.Any) -> Response: return self.delete(url, **self._set_request_timeout(kwargs)) def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str: @@ -343,7 +344,7 @@ class APIClient(_Session): return response.text def _post_json( - self, url: str, data: dict[str, str | None] | t.Any, **kwargs + self, url: str, data: dict[str, str | None] | t.Any, **kwargs: t.Any ) -> Response: # Go <1.1 cannot unserialize null to a string # so we do this disgusting thing here. @@ -556,22 +557,30 @@ class APIClient(_Session): """ socket = self._get_raw_response_socket(response) - gen: t.Generator = frames_iter(socket, tty) + gen = frames_iter(socket, tty) if demux: # The generator will output tuples (stdout, stderr) - gen = (demux_adaptor(*frame) for frame in gen) + demux_gen: t.Generator[tuple[bytes | None, bytes | None]] = ( + demux_adaptor(*frame) for frame in gen + ) + if stream: + return demux_gen + try: + # Wait for all the frames, concatenate them, and return the result + return consume_socket_output(demux_gen, demux=True) + finally: + response.close() else: # The generator will output strings - gen = (data for (dummy, data) in gen) - - if stream: - return gen - try: - # Wait for all the frames, concatenate them, and return the result - return consume_socket_output(gen, demux=demux) - finally: - response.close() + mux_gen: t.Generator[bytes] = (data for (dummy, data) in gen) + if stream: + return mux_gen + try: + # Wait for all the frames, concatenate them, and return the result + return consume_socket_output(mux_gen, demux=False) + finally: + response.close() def _disable_socket_timeout(self, socket: SocketLike) -> None: """Depending on the combination of python version and whether we are @@ -637,11 +646,11 @@ class APIClient(_Session): return self._multiplexed_response_stream_helper(res) return 
sep.join(list(self._multiplexed_buffer_helper(res))) - def _unmount(self, *args) -> None: + def _unmount(self, *args: t.Any) -> None: for proto in args: self.adapters.pop(proto) - def get_adapter(self, url: str): + def get_adapter(self, url: str) -> BaseAdapter: try: return super().get_adapter(url) except _InvalidSchema as e: @@ -696,19 +705,19 @@ class APIClient(_Session): else: log.debug("No auth config found") - def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes: + def get_binary(self, pathfmt: str, *args: str, **kwargs: t.Any) -> bytes: return self._result( self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), get_binary=True, ) - def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: + def get_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any: return self._result( self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), get_json=True, ) - def get_text(self, pathfmt: str, *args: str, **kwargs) -> str: + def get_text(self, pathfmt: str, *args: str, **kwargs: t.Any) -> str: return self._result( self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs) ) @@ -718,7 +727,7 @@ class APIClient(_Session): pathfmt: str, *args: str, chunk_size: int = DEFAULT_DATA_CHUNK_SIZE, - **kwargs, + **kwargs: t.Any, ) -> t.Generator[bytes]: res = self._get( self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs @@ -726,23 +735,25 @@ class APIClient(_Session): self._raise_for_status(res) return self._stream_raw_result(res, chunk_size=chunk_size, decode=False) - def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None: + def delete_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None: self._raise_for_status( self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs) ) - def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: + def delete_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any: return self._result( self._delete(self._url(pathfmt, *args, 
versioned_api=True), **kwargs), get_json=True, ) - def post_call(self, pathfmt: str, *args: str, **kwargs) -> None: + def post_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None: self._raise_for_status( self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs) ) - def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None: + def post_json( + self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any + ) -> None: self._raise_for_status( self._post_json( self._url(pathfmt, *args, versioned_api=True), data, **kwargs @@ -750,7 +761,7 @@ class APIClient(_Session): ) def post_json_to_binary( - self, pathfmt: str, *args: str, data: t.Any = None, **kwargs + self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any ) -> bytes: return self._result( self._post_json( @@ -760,7 +771,7 @@ class APIClient(_Session): ) def post_json_to_json( - self, pathfmt: str, *args: str, data: t.Any = None, **kwargs + self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any ) -> t.Any: return self._result( self._post_json( @@ -770,7 +781,7 @@ class APIClient(_Session): ) def post_json_to_text( - self, pathfmt: str, *args: str, data: t.Any = None, **kwargs + self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any ) -> str: return self._result( self._post_json( @@ -784,7 +795,7 @@ class APIClient(_Session): *args: str, data: t.Any = None, headers: dict[str, str] | None = None, - **kwargs, + **kwargs: t.Any, ) -> SocketLike: headers = headers.copy() if headers else {} headers.update( @@ -813,7 +824,7 @@ class APIClient(_Session): stream: t.Literal[True], tty: bool = True, demux: t.Literal[False] = False, - **kwargs, + **kwargs: t.Any, ) -> t.Generator[bytes]: ... @t.overload @@ -826,7 +837,7 @@ class APIClient(_Session): stream: t.Literal[True], tty: t.Literal[True] = True, demux: t.Literal[True], - **kwargs, + **kwargs: t.Any, ) -> t.Generator[tuple[bytes, None]]: ... 
@t.overload @@ -839,7 +850,7 @@ class APIClient(_Session): stream: t.Literal[True], tty: t.Literal[False], demux: t.Literal[True], - **kwargs, + **kwargs: t.Any, ) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ... @t.overload @@ -852,7 +863,7 @@ class APIClient(_Session): stream: t.Literal[False], tty: bool = True, demux: t.Literal[False] = False, - **kwargs, + **kwargs: t.Any, ) -> bytes: ... @t.overload @@ -865,7 +876,7 @@ class APIClient(_Session): stream: t.Literal[False], tty: t.Literal[True] = True, demux: t.Literal[True], - **kwargs, + **kwargs: t.Any, ) -> tuple[bytes, None]: ... @t.overload @@ -878,7 +889,7 @@ class APIClient(_Session): stream: t.Literal[False], tty: t.Literal[False], demux: t.Literal[True], - **kwargs, + **kwargs: t.Any, ) -> tuple[bytes, bytes]: ... def post_json_to_stream( @@ -890,7 +901,7 @@ class APIClient(_Session): stream: bool = False, demux: bool = False, tty: bool = False, - **kwargs, + **kwargs: t.Any, ) -> t.Any: headers = headers.copy() if headers else {} headers.update( @@ -912,7 +923,7 @@ class APIClient(_Session): demux=demux, ) - def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: + def post_to_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any: return self._result( self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs), get_json=True, diff --git a/plugins/module_utils/_api/auth.py b/plugins/module_utils/_api/auth.py index 0c6cff00..78271ef1 100644 --- a/plugins/module_utils/_api/auth.py +++ b/plugins/module_utils/_api/auth.py @@ -106,7 +106,7 @@ class AuthConfig(dict): @classmethod def parse_auth( - cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False + cls, entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False ) -> dict[str, dict[str, t.Any]]: """ Parses authentication entries @@ -294,7 +294,7 @@ class AuthConfig(dict): except StoreError as e: raise errors.DockerException(f"Credentials store error: {e}") - def _get_store_instance(self, name: 
str): + def _get_store_instance(self, name: str) -> Store: if name not in self._stores: self._stores[name] = Store(name, environment=self._credstore_env) return self._stores[name] @@ -326,8 +326,10 @@ class AuthConfig(dict): def resolve_authconfig( - authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None -): + authconfig: AuthConfig | dict[str, t.Any], + registry: str | None = None, + credstore_env: dict[str, str] | None = None, +) -> dict[str, t.Any] | None: if not isinstance(authconfig, AuthConfig): authconfig = AuthConfig(authconfig, credstore_env) return authconfig.resolve_authconfig(registry) diff --git a/plugins/module_utils/_api/context/config.py b/plugins/module_utils/_api/context/config.py index 6ab07b0d..04fddc12 100644 --- a/plugins/module_utils/_api/context/config.py +++ b/plugins/module_utils/_api/context/config.py @@ -89,7 +89,7 @@ def get_meta_dir(name: str | None = None) -> str: return meta_dir -def get_meta_file(name) -> str: +def get_meta_file(name: str) -> str: return os.path.join(get_meta_dir(name), METAFILE) diff --git a/plugins/module_utils/_api/errors.py b/plugins/module_utils/_api/errors.py index 12b197cb..cd62ba5e 100644 --- a/plugins/module_utils/_api/errors.py +++ b/plugins/module_utils/_api/errors.py @@ -18,6 +18,10 @@ from ansible.module_utils.common.text.converters import to_native from ._import_helper import HTTPError as _HTTPError +if t.TYPE_CHECKING: + from requests import Response + + class DockerException(Exception): """ A base class from which all other exceptions inherit. 
@@ -55,7 +59,10 @@ class APIError(_HTTPError, DockerException): """ def __init__( - self, message: str | Exception, response=None, explanation: str | None = None + self, + message: str | Exception, + response: Response | None = None, + explanation: str | None = None, ) -> None: # requests 1.2 supports response as a keyword argument, but # requests 1.1 does not diff --git a/plugins/module_utils/_api/transport/npipeconn.py b/plugins/module_utils/_api/transport/npipeconn.py index f50cb91b..8c89b98b 100644 --- a/plugins/module_utils/_api/transport/npipeconn.py +++ b/plugins/module_utils/_api/transport/npipeconn.py @@ -11,6 +11,7 @@ from __future__ import annotations +import typing as t from queue import Empty from .. import constants @@ -19,6 +20,12 @@ from .basehttpadapter import BaseHTTPAdapter from .npipesocket import NpipeSocket +if t.TYPE_CHECKING: + from collections.abc import Mapping + + from requests import PreparedRequest + + RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer @@ -91,7 +98,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): ) super().__init__() - def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool: + def get_connection( + self, url: str | bytes, proxies: Mapping[str, str] | None = None + ) -> NpipeHTTPConnectionPool: with self.pools.lock: pool = self.pools.get(url) if pool: @@ -104,7 +113,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): return pool - def request_url(self, request, proxies) -> str: + def request_url( + self, request: PreparedRequest, proxies: Mapping[str, str] | None + ) -> str: # The select_proxy utility in requests errors out when the provided URL # does not have a hostname, like is the case when using a UNIX socket. 
# Since proxies are an irrelevant notion in the case of UNIX sockets diff --git a/plugins/module_utils/_api/transport/npipesocket.py b/plugins/module_utils/_api/transport/npipesocket.py index e4473f49..2f4e8ab4 100644 --- a/plugins/module_utils/_api/transport/npipesocket.py +++ b/plugins/module_utils/_api/transport/npipesocket.py @@ -64,7 +64,7 @@ class NpipeSocket: implemented. """ - def __init__(self, handle=None) -> None: + def __init__(self, handle: t.Any | None = None) -> None: self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT self._handle = handle self._address: str | None = None @@ -74,15 +74,17 @@ class NpipeSocket: def accept(self) -> t.NoReturn: raise NotImplementedError() - def bind(self, address) -> t.NoReturn: + def bind(self, address: t.Any) -> t.NoReturn: raise NotImplementedError() def close(self) -> None: + if self._handle is None: + raise ValueError("Handle not present") self._handle.Close() self._closed = True @check_closed - def connect(self, address, retry_count: int = 0) -> None: + def connect(self, address: str, retry_count: int = 0) -> None: try: handle = win32file.CreateFile( address, @@ -116,11 +118,11 @@ class NpipeSocket: self._address = address @check_closed - def connect_ex(self, address) -> None: + def connect_ex(self, address: str) -> None: self.connect(address) @check_closed - def detach(self): + def detach(self) -> t.Any: self._closed = True return self._handle @@ -134,16 +136,18 @@ class NpipeSocket: def getsockname(self) -> str | None: return self._address - def getsockopt(self, level, optname, buflen=None) -> t.NoReturn: + def getsockopt( + self, level: t.Any, optname: t.Any, buflen: t.Any = None + ) -> t.NoReturn: raise NotImplementedError() - def ioctl(self, control, option) -> t.NoReturn: + def ioctl(self, control: t.Any, option: t.Any) -> t.NoReturn: raise NotImplementedError() - def listen(self, backlog) -> t.NoReturn: + def listen(self, backlog: t.Any) -> t.NoReturn: raise NotImplementedError() - def makefile(self, mode: 
str, bufsize: int | None = None): + def makefile(self, mode: str, bufsize: int | None = None) -> t.IO[bytes]: if mode.strip("b") != "r": raise NotImplementedError() rawio = NpipeFileIOBase(self) @@ -153,6 +157,8 @@ class NpipeSocket: @check_closed def recv(self, bufsize: int, flags: int = 0) -> str: + if self._handle is None: + raise ValueError("Handle not present") dummy_err, data = win32file.ReadFile(self._handle, bufsize) return data @@ -169,6 +175,8 @@ class NpipeSocket: @check_closed def recv_into(self, buf: Buffer, nbytes: int = 0) -> int: + if self._handle is None: + raise ValueError("Handle not present") readbuf = buf if isinstance(buf, memoryview) else memoryview(buf) event = win32event.CreateEvent(None, True, True, None) @@ -188,6 +196,8 @@ class NpipeSocket: @check_closed def send(self, string: Buffer, flags: int = 0) -> int: + if self._handle is None: + raise ValueError("Handle not present") event = win32event.CreateEvent(None, True, True, None) try: overlapped = pywintypes.OVERLAPPED() @@ -210,7 +220,7 @@ class NpipeSocket: self.connect(address) return self.send(string) - def setblocking(self, flag: bool): + def setblocking(self, flag: bool) -> None: if flag: return self.settimeout(None) return self.settimeout(0) @@ -228,16 +238,16 @@ class NpipeSocket: def gettimeout(self) -> int | float | None: return self._timeout - def setsockopt(self, level, optname, value) -> t.NoReturn: + def setsockopt(self, level: t.Any, optname: t.Any, value: t.Any) -> t.NoReturn: raise NotImplementedError() @check_closed - def shutdown(self, how) -> None: + def shutdown(self, how: t.Any) -> None: return self.close() class NpipeFileIOBase(io.RawIOBase): - def __init__(self, npipe_socket) -> None: + def __init__(self, npipe_socket: NpipeSocket | None) -> None: self.sock = npipe_socket def close(self) -> None: @@ -245,7 +255,10 @@ class NpipeFileIOBase(io.RawIOBase): self.sock = None def fileno(self) -> int: - return self.sock.fileno() + if self.sock is None: + raise 
RuntimeError("socket is closed") + # TODO: This is definitely a bug, NpipeSocket.fileno() does not exist! + return self.sock.fileno() # type: ignore def isatty(self) -> bool: return False @@ -254,6 +267,8 @@ class NpipeFileIOBase(io.RawIOBase): return True def readinto(self, buf: Buffer) -> int: + if self.sock is None: + raise RuntimeError("socket is closed") return self.sock.recv_into(buf) def seekable(self) -> bool: diff --git a/plugins/module_utils/_api/transport/sshconn.py b/plugins/module_utils/_api/transport/sshconn.py index 6bafa06d..876a430c 100644 --- a/plugins/module_utils/_api/transport/sshconn.py +++ b/plugins/module_utils/_api/transport/sshconn.py @@ -35,7 +35,7 @@ else: PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name if t.TYPE_CHECKING: - from collections.abc import Buffer + from collections.abc import Buffer, Mapping RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer @@ -67,7 +67,7 @@ class SSHSocket(socket.socket): preexec_func = None if not constants.IS_WINDOWS_PLATFORM: - def f(): + def f() -> None: signal.signal(signal.SIGINT, signal.SIG_IGN) preexec_func = f @@ -100,13 +100,13 @@ class SSHSocket(socket.socket): self.proc.stdin.flush() return written - def sendall(self, data: Buffer, *args, **kwargs) -> None: + def sendall(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> None: self._write(data) - def send(self, data: Buffer, *args, **kwargs) -> int: + def send(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> int: return self._write(data) - def recv(self, n: int, *args, **kwargs) -> bytes: + def recv(self, n: int, *args: t.Any, **kwargs: t.Any) -> bytes: if not self.proc: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first." 
@@ -114,7 +114,7 @@ class SSHSocket(socket.socket): assert self.proc.stdout is not None return self.proc.stdout.read(n) - def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore + def makefile(self, mode: str, *args: t.Any, **kwargs: t.Any) -> t.IO: # type: ignore if not self.proc: self.connect() assert self.proc is not None @@ -138,7 +138,7 @@ class SSHConnection(urllib3_connection.HTTPConnection): def __init__( self, *, - ssh_transport=None, + ssh_transport: paramiko.Transport | None = None, timeout: int | float = 60, host: str, ) -> None: @@ -146,18 +146,19 @@ class SSHConnection(urllib3_connection.HTTPConnection): self.ssh_transport = ssh_transport self.timeout = timeout self.ssh_host = host + self.sock: paramiko.Channel | SSHSocket | None = None def connect(self) -> None: if self.ssh_transport: - sock = self.ssh_transport.open_session() - sock.settimeout(self.timeout) - sock.exec_command("docker system dial-stdio") + channel = self.ssh_transport.open_session() + channel.settimeout(self.timeout) + channel.exec_command("docker system dial-stdio") + self.sock = channel else: sock = SSHSocket(self.ssh_host) sock.settimeout(self.timeout) sock.connect() - - self.sock = sock + self.sock = sock class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): @@ -172,7 +173,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): host: str, ) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) - self.ssh_transport = None + self.ssh_transport: paramiko.Transport | None = None self.timeout = timeout if ssh_client: self.ssh_transport = ssh_client.get_transport() @@ -276,7 +277,9 @@ class SSHHTTPAdapter(BaseHTTPAdapter): if self.ssh_client: self.ssh_client.connect(**self.ssh_params) - def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool: + def get_connection( + self, url: str | bytes, proxies: Mapping[str, str] | None = None + ) -> SSHConnectionPool: if not self.ssh_client: return 
SSHConnectionPool( ssh_client=self.ssh_client, diff --git a/plugins/module_utils/_api/transport/ssladapter.py b/plugins/module_utils/_api/transport/ssladapter.py index 2cad6cea..3d1d674e 100644 --- a/plugins/module_utils/_api/transport/ssladapter.py +++ b/plugins/module_utils/_api/transport/ssladapter.py @@ -33,7 +33,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter): def __init__( self, assert_hostname: bool | None = None, - **kwargs, + **kwargs: t.Any, ) -> None: self.assert_hostname = assert_hostname super().__init__(**kwargs) @@ -51,7 +51,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter): self.poolmanager = PoolManager(**kwargs) - def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool: + def get_connection(self, *args: t.Any, **kwargs: t.Any) -> urllib3.ConnectionPool: """ Ensure assert_hostname is set correctly on our pool diff --git a/plugins/module_utils/_api/transport/unixconn.py b/plugins/module_utils/_api/transport/unixconn.py index 4d3b5679..4f4c05f2 100644 --- a/plugins/module_utils/_api/transport/unixconn.py +++ b/plugins/module_utils/_api/transport/unixconn.py @@ -19,12 +19,20 @@ from .._import_helper import HTTPAdapter, urllib3, urllib3_connection from .basehttpadapter import BaseHTTPAdapter +if t.TYPE_CHECKING: + from collections.abc import Mapping + + from requests import PreparedRequest + + from ..._socket_helper import SocketLike + + RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class UnixHTTPConnection(urllib3_connection.HTTPConnection): def __init__( - self, base_url: str | bytes, unix_socket, timeout: int | float = 60 + self, base_url: str | bytes, unix_socket: str, timeout: int | float = 60 ) -> None: super().__init__("localhost", timeout=timeout) self.base_url = base_url @@ -43,7 +51,7 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection): if header == "Connection" and "Upgrade" in values: self.disable_buffering = True - def response_class(self, sock, *args, **kwargs) -> t.Any: + def response_class(self, sock: 
SocketLike, *args: t.Any, **kwargs: t.Any) -> t.Any: # FIXME: We may need to disable buffering on Py3, # but there's no clear way to do it at the moment. See: # https://github.com/docker/docker-py/issues/1799 @@ -88,12 +96,16 @@ class UnixHTTPAdapter(BaseHTTPAdapter): self.socket_path = socket_path self.timeout = timeout self.max_pool_size = max_pool_size - self.pools = RecentlyUsedContainer( - pool_connections, dispose_func=lambda p: p.close() - ) + + def f(p: t.Any) -> None: + p.close() + + self.pools = RecentlyUsedContainer(pool_connections, dispose_func=f) super().__init__() - def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool: + def get_connection( + self, url: str | bytes, proxies: Mapping[str, str] | None = None + ) -> UnixHTTPConnectionPool: with self.pools.lock: pool = self.pools.get(url) if pool: @@ -106,7 +118,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter): return pool - def request_url(self, request, proxies) -> str: + def request_url(self, request: PreparedRequest, proxies: Mapping[str, str]) -> str: # The select_proxy utility in requests errors out when the provided URL # does not have a hostname, like is the case when using a UNIX socket. # Since proxies are an irrelevant notion in the case of UNIX sockets diff --git a/plugins/module_utils/_api/types/daemon.py b/plugins/module_utils/_api/types/daemon.py index eb386169..4d9591d6 100644 --- a/plugins/module_utils/_api/types/daemon.py +++ b/plugins/module_utils/_api/types/daemon.py @@ -18,7 +18,13 @@ from .._import_helper import urllib3 from ..errors import DockerException -class CancellableStream: +if t.TYPE_CHECKING: + from requests import Response + +_T = t.TypeVar("_T") + + +class CancellableStream(t.Generic[_T]): """ Stream wrapper for real-time events, logs, etc. from the server. 
@@ -30,14 +36,14 @@ class CancellableStream: >>> events.close() """ - def __init__(self, stream, response) -> None: + def __init__(self, stream: t.Generator[_T], response: Response) -> None: self._stream = stream self._response = response def __iter__(self) -> t.Self: return self - def __next__(self): + def __next__(self) -> _T: try: return next(self._stream) except urllib3.exceptions.ProtocolError as exc: @@ -56,7 +62,7 @@ class CancellableStream: # find the underlying socket object # based on api.client._get_raw_response_socket - sock_fp = self._response.raw._fp.fp + sock_fp = self._response.raw._fp.fp # type: ignore if hasattr(sock_fp, "raw"): sock_raw = sock_fp.raw @@ -74,7 +80,7 @@ class CancellableStream: "Cancellable streams not supported for the SSH protocol" ) else: - sock = sock_fp._sock + sock = sock_fp._sock # type: ignore if hasattr(urllib3.contrib, "pyopenssl") and isinstance( sock, urllib3.contrib.pyopenssl.WrappedSocket diff --git a/plugins/module_utils/_api/utils/fnmatch.py b/plugins/module_utils/_api/utils/fnmatch.py index 525cf84a..2761585b 100644 --- a/plugins/module_utils/_api/utils/fnmatch.py +++ b/plugins/module_utils/_api/utils/fnmatch.py @@ -37,7 +37,7 @@ def _purge() -> None: _cache.clear() -def fnmatch(name: str, pat: str): +def fnmatch(name: str, pat: str) -> bool: """Test whether FILENAME matches PATTERN. 
Patterns are Unix shell style: diff --git a/plugins/module_utils/_api/utils/ports.py b/plugins/module_utils/_api/utils/ports.py index eab15bd0..e5eb28a3 100644 --- a/plugins/module_utils/_api/utils/ports.py +++ b/plugins/module_utils/_api/utils/ports.py @@ -108,7 +108,7 @@ def port_range( def split_port( - port: str, + port: str | int, ) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]: port = str(port) match = PORT_SPEC.match(port) diff --git a/plugins/module_utils/_api/utils/proxy.py b/plugins/module_utils/_api/utils/proxy.py index 0f5fa9f3..6228be6a 100644 --- a/plugins/module_utils/_api/utils/proxy.py +++ b/plugins/module_utils/_api/utils/proxy.py @@ -11,6 +11,8 @@ from __future__ import annotations +import typing as t + from .utils import format_environment @@ -67,7 +69,17 @@ class ProxyConfig(dict): env["no_proxy"] = env["NO_PROXY"] = self.no_proxy return env - def inject_proxy_environment(self, environment: list[str]) -> list[str]: + @t.overload + def inject_proxy_environment(self, environment: list[str]) -> list[str]: ... + + @t.overload + def inject_proxy_environment( + self, environment: list[str] | None + ) -> list[str] | None: ... + + def inject_proxy_environment( + self, environment: list[str] | None + ) -> list[str] | None: """ Given a list of strings representing environment variables, prepend the environment variables corresponding to the proxy settings. 
diff --git a/plugins/module_utils/_api/utils/socket.py b/plugins/module_utils/_api/utils/socket.py index 6619e0ff..642a3997 100644 --- a/plugins/module_utils/_api/utils/socket.py +++ b/plugins/module_utils/_api/utils/socket.py @@ -22,7 +22,9 @@ from ..transport.npipesocket import NpipeSocket if t.TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Sequence + + from ..._socket_helper import SocketLike STDOUT = 1 @@ -38,14 +40,14 @@ class SocketError(Exception): NPIPE_ENDED = 109 -def read(socket, n: int = 4096) -> bytes | None: +def read(socket: SocketLike, n: int = 4096) -> bytes | None: """ Reads at most n bytes from socket """ recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK) - if not isinstance(socket, NpipeSocket): + if not isinstance(socket, NpipeSocket): # type: ignore[unreachable] if not hasattr(select, "poll"): # Limited to 1024 select.select([socket], [], []) @@ -66,7 +68,7 @@ def read(socket, n: int = 4096) -> bytes | None: return None # TODO ??? 
except Exception as e: is_pipe_ended = ( - isinstance(socket, NpipeSocket) + isinstance(socket, NpipeSocket) # type: ignore[unreachable] and len(e.args) > 0 and e.args[0] == NPIPE_ENDED ) @@ -77,7 +79,7 @@ def read(socket, n: int = 4096) -> bytes | None: raise -def read_exactly(socket, n: int) -> bytes: +def read_exactly(socket: SocketLike, n: int) -> bytes: """ Reads exactly n bytes from socket Raises SocketError if there is not enough data @@ -91,7 +93,7 @@ def read_exactly(socket, n: int) -> bytes: return data -def next_frame_header(socket) -> tuple[int, int]: +def next_frame_header(socket: SocketLike) -> tuple[int, int]: """ Returns the stream and size of the next frame of data waiting to be read from socket, according to the protocol defined here: @@ -107,7 +109,7 @@ def next_frame_header(socket) -> tuple[int, int]: return (stream, actual) -def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]: +def frames_iter(socket: SocketLike, tty: bool) -> t.Generator[tuple[int, bytes]]: """ Return a generator of frames read from socket. A frame is a tuple where the first item is the stream number and the second item is a chunk of data. @@ -120,7 +122,7 @@ def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]: return frames_iter_no_tty(socket) -def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]: +def frames_iter_no_tty(socket: SocketLike) -> t.Generator[tuple[int, bytes]]: """ Returns a generator of data read from the socket when the tty setting is not enabled. @@ -141,7 +143,7 @@ def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]: yield (stream, result) -def frames_iter_tty(socket) -> t.Generator[bytes]: +def frames_iter_tty(socket: SocketLike) -> t.Generator[bytes]: """ Return a generator of data read from the socket when the tty setting is enabled. @@ -155,20 +157,42 @@ def frames_iter_tty(socket) -> t.Generator[bytes]: @t.overload -def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ... 
- - -@t.overload -def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ... +def consume_socket_output( + frames: Sequence[bytes] | t.Generator[bytes], demux: t.Literal[False] = False +) -> bytes: ... @t.overload def consume_socket_output( - frames, demux: bool = False + frames: ( + Sequence[tuple[bytes | None, bytes | None]] + | t.Generator[tuple[bytes | None, bytes | None]] + ), + demux: t.Literal[True], +) -> tuple[bytes, bytes]: ... + + +@t.overload +def consume_socket_output( + frames: ( + Sequence[bytes] + | Sequence[tuple[bytes | None, bytes | None]] + | t.Generator[bytes] + | t.Generator[tuple[bytes | None, bytes | None]] + ), + demux: bool = False, ) -> bytes | tuple[bytes, bytes]: ... -def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]: +def consume_socket_output( + frames: ( + Sequence[bytes] + | Sequence[tuple[bytes | None, bytes | None]] + | t.Generator[bytes] + | t.Generator[tuple[bytes | None, bytes | None]] + ), + demux: bool = False, +) -> bytes | tuple[bytes, bytes]: """ Iterate through frames read from the socket and return the result. @@ -183,12 +207,13 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b if demux is False: # If the streams are multiplexed, the generator returns strings, that # we just need to concatenate. - return b"".join(frames) + return b"".join(frames) # type: ignore # If the streams are demultiplexed, the generator yields tuples # (stdout, stderr) out: list[bytes | None] = [None, None] - for frame in frames: + frame: tuple[bytes | None, bytes | None] + for frame in frames: # type: ignore # It is guaranteed that for each frame, one and only one stream # is not None. 
if frame == (None, None): @@ -202,7 +227,7 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b if out[1] is None: out[1] = frame[1] else: - out[1] += frame[1] + out[1] += frame[1] # type: ignore[operator] return tuple(out) # type: ignore diff --git a/plugins/module_utils/_api/utils/utils.py b/plugins/module_utils/_api/utils/utils.py index 136b0f23..0ff758ae 100644 --- a/plugins/module_utils/_api/utils/utils.py +++ b/plugins/module_utils/_api/utils/utils.py @@ -46,7 +46,7 @@ URLComponents = collections.namedtuple( ) -def decode_json_header(header: str) -> dict[str, t.Any]: +def decode_json_header(header: str | bytes) -> dict[str, t.Any]: data = base64.b64decode(header).decode("utf-8") return json.loads(data) @@ -143,7 +143,12 @@ def convert_port_bindings( def convert_volume_binds( - binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int], + binds: ( + list[str] + | Mapping[ + str | bytes, dict[str, str | bytes] | dict[str, str] | bytes | str | int + ] + ), ) -> list[str]: if isinstance(binds, list): return binds # type: ignore @@ -403,7 +408,9 @@ def kwargs_from_env( return params -def convert_filters(filters: Mapping[str, bool | str | list[str]]) -> str: +def convert_filters( + filters: Mapping[str, bool | str | int | list[int] | list[str] | list[str | int]], +) -> str: result = {} for k, v in filters.items(): if isinstance(v, bool): @@ -495,8 +502,8 @@ def split_command(command: str) -> list[str]: return shlex.split(command) -def format_environment(environment: Mapping[str, str | bytes]) -> list[str]: - def format_env(key, value): +def format_environment(environment: Mapping[str, str | bytes | None]) -> list[str]: + def format_env(key: str, value: str | bytes | None) -> str: if value is None: return key if isinstance(value, bytes): diff --git a/plugins/module_utils/_common.py b/plugins/module_utils/_common.py index 04588837..55a6cd3e 100644 --- a/plugins/module_utils/_common.py +++ 
b/plugins/module_utils/_common.py @@ -91,7 +91,7 @@ if not HAS_DOCKER_PY: # No Docker SDK for Python. Create a place holder client to allow # instantiation of AnsibleModule and proper error handing class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined - def __init__(self, **kwargs): + def __init__(self, **kwargs: t.Any) -> None: pass class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined @@ -226,7 +226,7 @@ class AnsibleDockerClientBase(Client): f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." ) - def log(self, msg: t.Any, pretty_print: bool = False): + def log(self, msg: t.Any, pretty_print: bool = False) -> None: pass # if self.debug: # from .util import log_debug @@ -609,7 +609,7 @@ class AnsibleDockerClientBase(Client): return new_tag, old_tag == new_tag - def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]: + def inspect_distribution(self, image: str, **kwargs: t.Any) -> dict[str, t.Any]: """ Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 since prior versions did not support accessing private repositories. @@ -629,7 +629,6 @@ class AnsibleDockerClientBase(Client): class AnsibleDockerClient(AnsibleDockerClientBase): - def __init__( self, argument_spec: dict[str, t.Any] | None = None, @@ -651,7 +650,6 @@ class AnsibleDockerClient(AnsibleDockerClientBase): option_minimal_versions_ignore_params: Sequence[str] | None = None, fail_results: dict[str, t.Any] | None = None, ): - # Modules can put information in here which will always be returned # in case client.fail() is called. 
self.fail_results = fail_results or {} diff --git a/plugins/module_utils/_common_api.py b/plugins/module_utils/_common_api.py index 7617d157..e2738a38 100644 --- a/plugins/module_utils/_common_api.py +++ b/plugins/module_utils/_common_api.py @@ -146,7 +146,7 @@ class AnsibleDockerClientBase(Client): f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." ) - def log(self, msg: t.Any, pretty_print: bool = False): + def log(self, msg: t.Any, pretty_print: bool = False) -> None: pass # if self.debug: # from .util import log_debug @@ -295,7 +295,7 @@ class AnsibleDockerClientBase(Client): ), } - def depr(*args, **kwargs): + def depr(*args: t.Any, **kwargs: t.Any) -> None: self.deprecate(*args, **kwargs) update_tls_hostname( diff --git a/plugins/module_utils/_common_cli.py b/plugins/module_utils/_common_cli.py index 166d0e41..26bd3a70 100644 --- a/plugins/module_utils/_common_cli.py +++ b/plugins/module_utils/_common_cli.py @@ -82,7 +82,7 @@ class AnsibleDockerClientBase: def __init__( self, - common_args, + common_args: dict[str, t.Any], min_docker_api_version: str | None = None, needs_api_version: bool = True, ) -> None: @@ -91,15 +91,15 @@ class AnsibleDockerClientBase: self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"] if common_args["api_version"] and common_args["api_version"] != "auto": self._environment["DOCKER_API_VERSION"] = common_args["api_version"] - self._cli = common_args.get("docker_cli") - if self._cli is None: + cli = common_args.get("docker_cli") + if cli is None: try: - self._cli = get_bin_path("docker") + cli = get_bin_path("docker") except ValueError: self.fail( "Cannot find docker CLI in path. 
Please provide it explicitly with the docker_cli parameter" ) - + self._cli = cli self._cli_base = [self._cli] docker_host = common_args["docker_host"] if not docker_host and not common_args["cli_context"]: @@ -149,7 +149,7 @@ class AnsibleDockerClientBase: "Internal error: cannot have needs_api_version=False with min_docker_api_version not None" ) - def log(self, msg: str, pretty_print: bool = False): + def log(self, msg: str, pretty_print: bool = False) -> None: pass # if self.debug: # from .util import log_debug @@ -227,7 +227,7 @@ class AnsibleDockerClientBase: return rc, result, stderr @abc.abstractmethod - def fail(self, msg: str, **kwargs) -> t.NoReturn: + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: pass @abc.abstractmethod @@ -395,7 +395,6 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): fail_results: dict[str, t.Any] | None = None, needs_api_version: bool = True, ) -> None: - # Modules can put information in here which will always be returned # in case client.fail() is called. 
self.fail_results = fail_results or {} @@ -463,7 +462,7 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase): ) return rc, stdout, stderr - def fail(self, msg: str, **kwargs) -> t.NoReturn: + def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn: self.fail_results.update(kwargs) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) diff --git a/plugins/module_utils/_compose_v2.py b/plugins/module_utils/_compose_v2.py index 253e9db9..4fdb1014 100644 --- a/plugins/module_utils/_compose_v2.py +++ b/plugins/module_utils/_compose_v2.py @@ -971,7 +971,7 @@ class BaseComposeManager(DockerBaseClass): stderr: str | bytes, ignore_service_pull_events: bool = False, ignore_build_events: bool = False, - ): + ) -> None: result["changed"] = result.get("changed", False) or has_changes( events, ignore_service_pull_events=ignore_service_pull_events, @@ -989,7 +989,7 @@ class BaseComposeManager(DockerBaseClass): stdout: str | bytes, stderr: bytes, rc: int, - ): + ) -> bool: return update_failed( result, events, diff --git a/plugins/module_utils/_copy.py b/plugins/module_utils/_copy.py index 11df3403..2e3ee6ab 100644 --- a/plugins/module_utils/_copy.py +++ b/plugins/module_utils/_copy.py @@ -330,6 +330,8 @@ def stat_file( client._raise_for_status(response) header = response.headers.get("x-docker-container-path-stat") try: + if header is None: + raise ValueError("x-docker-container-path-stat header not present") stat_data = json.loads(base64.b64decode(header)) except Exception as exc: raise DockerUnexpectedError( @@ -482,14 +484,14 @@ def fetch_file( shutil.copyfileobj(in_f, out_f) return in_path - def process_symlink(in_path, member) -> str: + def process_symlink(in_path: str, member: tarfile.TarInfo) -> str: if os.path.exists(b_out_path): os.unlink(b_out_path) os.symlink(member.linkname, b_out_path) return in_path - def process_other(in_path, member) -> str: + def process_other(in_path: str, member: tarfile.TarInfo) -> str: raise DockerFileCopyError( f'Remote 
file "{in_path}" is not a regular file or a symbolic link' ) diff --git a/plugins/module_utils/_module_container/base.py b/plugins/module_utils/_module_container/base.py index c0d1906b..82842789 100644 --- a/plugins/module_utils/_module_container/base.py +++ b/plugins/module_utils/_module_container/base.py @@ -193,7 +193,9 @@ class OptionGroup: ) -> None: if preprocess is None: - def preprocess(module, values): + def preprocess( + module: AnsibleModule, values: dict[str, t.Any] + ) -> dict[str, t.Any]: return values self.preprocess = preprocess @@ -207,8 +209,8 @@ class OptionGroup: self.ansible_required_by = ansible_required_by or {} self.argument_spec: dict[str, t.Any] = {} - def add_option(self, *args, **kwargs) -> OptionGroup: - option = Option(*args, owner=self, **kwargs) + def add_option(self, name: str, **kwargs: t.Any) -> OptionGroup: + option = Option(name, owner=self, **kwargs) if not option.not_a_container_option: self.options.append(option) self.all_options.append(option) @@ -788,7 +790,7 @@ def _preprocess_mounts( ) -> dict[str, t.Any]: last: dict[str, str] = {} - def check_collision(t, name): + def check_collision(t: str, name: str) -> None: if t in last: if name == last[t]: module.fail_json( @@ -1069,7 +1071,9 @@ def _preprocess_ports( return values -def _compare_platform(option: Option, param_value: t.Any, container_value: t.Any): +def _compare_platform( + option: Option, param_value: t.Any, container_value: t.Any +) -> bool: if option.comparison == "ignore": return True try: diff --git a/plugins/module_utils/_module_container/docker_api.py b/plugins/module_utils/_module_container/docker_api.py index 28776607..388e7f50 100644 --- a/plugins/module_utils/_module_container/docker_api.py +++ b/plugins/module_utils/_module_container/docker_api.py @@ -872,7 +872,7 @@ class DockerAPIEngine(Engine[AnsibleDockerClient]): image: dict[str, t.Any] | None, values: dict[str, t.Any], host_info: dict[str, t.Any] | None, - ): + ) -> dict[str, t.Any]: if len(options) 
!= 1: raise AssertionError( "host_config_value can only be used for a single option" @@ -1961,7 +1961,14 @@ def _update_value_restart( } -def _get_values_ports(module, container, api_version, options, image, host_info): +def _get_values_ports( + module: AnsibleModule, + container: dict[str, t.Any], + api_version: LooseVersion, + options: list[Option], + image: dict[str, t.Any] | None, + host_info: dict[str, t.Any] | None, +) -> dict[str, t.Any]: host_config = container["HostConfig"] config = container["Config"] diff --git a/plugins/module_utils/_module_container/module.py b/plugins/module_utils/_module_container/module.py index c234ed55..86748795 100644 --- a/plugins/module_utils/_module_container/module.py +++ b/plugins/module_utils/_module_container/module.py @@ -292,7 +292,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]): if self.module.params[param] is None: self.module.params[param] = value - def fail(self, *args, **kwargs) -> t.NoReturn: + def fail(self, *args: str, **kwargs: t.Any) -> t.NoReturn: # mypy doesn't know that Client has fail() method raise self.client.fail(*args, **kwargs) # type: ignore @@ -714,7 +714,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]): container_image: dict[str, t.Any] | None, image: dict[str, t.Any] | None, host_info: dict[str, t.Any] | None, - ): + ) -> None: assert container.raw is not None container_values = engine.get_value( self.module, @@ -767,12 +767,12 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]): # Since the order does not matter, sort so that the diff output is better. if option.name == "expected_mounts": # For selected values, use one entry as key - def sort_key_fn(x): + def sort_key_fn(x: dict[str, t.Any]) -> t.Any: return x["target"] else: # We sort the list of dictionaries by using the sorted items of a dict as its key. 
- def sort_key_fn(x): + def sort_key_fn(x: dict[str, t.Any]) -> t.Any: return sorted( (a, to_text(b, errors="surrogate_or_strict")) for a, b in x.items() diff --git a/plugins/module_utils/_socket_handler.py b/plugins/module_utils/_socket_handler.py index 67e0afd4..e0c3c1ef 100644 --- a/plugins/module_utils/_socket_handler.py +++ b/plugins/module_utils/_socket_handler.py @@ -26,6 +26,7 @@ from ansible_collections.community.docker.plugins.module_utils._socket_helper im if t.TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from ansible.module_utils.basic import AnsibleModule @@ -70,7 +71,12 @@ class DockerSocketHandlerBase: def __enter__(self) -> t.Self: return self - def __exit__(self, type_, value, tb) -> None: + def __exit__( + self, + type_: t.Type[BaseException] | None, + value: BaseException | None, + tb: TracebackType | None, + ) -> None: self._selector.close() def set_block_done_callback( @@ -210,7 +216,7 @@ class DockerSocketHandlerBase: stdout = [] stderr = [] - def append_block(stream_id, data): + def append_block(stream_id: int, data: bytes) -> None: if stream_id == docker_socket.STDOUT: stdout.append(data) elif stream_id == docker_socket.STDERR: diff --git a/plugins/module_utils/_swarm.py b/plugins/module_utils/_swarm.py index 699f5821..53fd2c7c 100644 --- a/plugins/module_utils/_swarm.py +++ b/plugins/module_utils/_swarm.py @@ -28,7 +28,6 @@ from ansible_collections.community.docker.plugins.module_utils._version import ( class AnsibleDockerSwarmClient(AnsibleDockerClient): - def get_swarm_node_id(self) -> str | None: """ Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. 
It returns the NodeID @@ -281,7 +280,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient): def get_node_name_by_id(self, nodeid: str) -> str: return self.get_node_inspect(nodeid)["Description"]["Hostname"] - def get_unlock_key(self) -> str | None: + def get_unlock_key(self) -> dict[str, t.Any] | None: if self.docker_py_version < LooseVersion("2.7.0"): return None return super().get_unlock_key() diff --git a/plugins/module_utils/_util.py b/plugins/module_utils/_util.py index 171796ea..62f793c7 100644 --- a/plugins/module_utils/_util.py +++ b/plugins/module_utils/_util.py @@ -23,6 +23,12 @@ if t.TYPE_CHECKING: from ansible.module_utils.basic import AnsibleModule + from ._common import AnsibleDockerClientBase as CADCB + from ._common_api import AnsibleDockerClientBase as CAPIADCB + from ._common_cli import AnsibleDockerClientBase as CCLIADCB + + Client = t.Union[CADCB, CAPIADCB, CCLIADCB] + DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock" DEFAULT_TLS = False @@ -119,7 +125,7 @@ def sanitize_result(data: t.Any) -> t.Any: return data -def log_debug(msg: t.Any, pretty_print: bool = False): +def log_debug(msg: t.Any, pretty_print: bool = False) -> None: """Write a log message to docker.log. If ``pretty_print=True``, the message will be pretty-printed as JSON. @@ -325,7 +331,7 @@ class DifferenceTracker: def sanitize_labels( labels: dict[str, t.Any] | None, labels_field: str, - client=None, + client: Client | None = None, module: AnsibleModule | None = None, ) -> None: def fail(msg: str) -> t.NoReturn: @@ -371,7 +377,7 @@ def clean_dict_booleans_for_docker_api( which is the expected format of filters which accept lists such as labels. 
""" - def sanitize(value): + def sanitize(value: t.Any) -> str: if value is True: return "true" if value is False: diff --git a/plugins/modules/docker_compose_v2_pull.py b/plugins/modules/docker_compose_v2_pull.py index 1e00af40..4bb91148 100644 --- a/plugins/modules/docker_compose_v2_pull.py +++ b/plugins/modules/docker_compose_v2_pull.py @@ -147,7 +147,7 @@ class PullManager(BaseComposeManager): f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}" ) - def get_pull_cmd(self, dry_run: bool): + def get_pull_cmd(self, dry_run: bool) -> list[str]: args = self.get_base_args() + ["pull"] if self.policy != "always": args.extend(["--policy", self.policy]) diff --git a/plugins/modules/docker_config.py b/plugins/modules/docker_config.py index e33a47de..46e5ee6a 100644 --- a/plugins/modules/docker_config.py +++ b/plugins/modules/docker_config.py @@ -198,6 +198,7 @@ config_name: import base64 import hashlib import traceback +import typing as t try: @@ -220,9 +221,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class ConfigManager(DockerBaseClass): - - def __init__(self, client, results): - + def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None: super().__init__() self.client = client @@ -253,10 +252,10 @@ class ConfigManager(DockerBaseClass): if self.rolling_versions: self.version = 0 - self.data_key = None - self.configs = [] + self.data_key: str | None = None + self.configs: list[dict[str, t.Any]] = [] - def __call__(self): + def __call__(self) -> None: self.get_config() if self.state == "present": self.data_key = hashlib.sha224(self.data).hexdigest() @@ -265,7 +264,7 @@ class ConfigManager(DockerBaseClass): elif self.state == "absent": self.absent() - def get_version(self, config): + def get_version(self, config: dict[str, t.Any]) -> int: try: return int( config.get("Spec", {}).get("Labels", {}).get("ansible_version", 0) @@ -273,14 
+272,14 @@ class ConfigManager(DockerBaseClass): except ValueError: return 0 - def remove_old_versions(self): + def remove_old_versions(self) -> None: if not self.rolling_versions or self.versions_to_keep < 0: return if not self.check_mode: while len(self.configs) > max(self.versions_to_keep, 1): self.remove_config(self.configs.pop(0)) - def get_config(self): + def get_config(self) -> None: """Find an existing config.""" try: configs = self.client.configs(filters={"name": self.name}) @@ -299,9 +298,9 @@ class ConfigManager(DockerBaseClass): config for config in configs if config["Spec"]["Name"] == self.name ] - def create_config(self): + def create_config(self) -> str | None: """Create a new config""" - config_id = None + config_id: str | dict[str, t.Any] | None = None # We ca not see the data after creation, so adding a label we can use for idempotency check labels = {"ansible_key": self.data_key} if self.rolling_versions: @@ -325,18 +324,18 @@ class ConfigManager(DockerBaseClass): self.client.fail(f"Error creating config: {exc}") if isinstance(config_id, dict): - config_id = config_id["ID"] + return config_id["ID"] return config_id - def remove_config(self, config): + def remove_config(self, config: dict[str, t.Any]) -> None: try: if not self.check_mode: self.client.remove_config(config["ID"]) except APIError as exc: self.client.fail(f"Error removing config {config['Spec']['Name']}: {exc}") - def present(self): + def present(self) -> None: """Handles state == 'present', creating or updating the config""" if self.configs: config = self.configs[-1] @@ -378,7 +377,7 @@ class ConfigManager(DockerBaseClass): self.results["config_id"] = self.create_config() self.results["config_name"] = self.name - def absent(self): + def absent(self) -> None: """Handles state == 'absent', removing the config""" if self.configs: for config in self.configs: @@ -386,7 +385,7 @@ class ConfigManager(DockerBaseClass): self.results["changed"] = True -def main(): +def main() -> None: 
argument_spec = { "name": {"type": "str", "required": True}, "state": { diff --git a/plugins/modules/docker_container_copy_into.py b/plugins/modules/docker_container_copy_into.py index 9c7575ba..f52bf098 100644 --- a/plugins/modules/docker_container_copy_into.py +++ b/plugins/modules/docker_container_copy_into.py @@ -347,7 +347,7 @@ def retrieve_diff( max_file_size_for_diff: int, regular_stat: dict[str, t.Any] | None = None, link_target: str | None = None, -): +) -> None: if diff is None: return if regular_stat is not None: @@ -497,9 +497,9 @@ def is_file_idempotent( container_path: str, follow_links: bool, local_follow_links: bool, - owner_id, - group_id, - mode, + owner_id: int, + group_id: int, + mode: int | None, force: bool | None = False, diff: dict[str, t.Any] | None = None, max_file_size_for_diff: int = 1, @@ -744,9 +744,9 @@ def copy_file_into_container( container_path: str, follow_links: bool, local_follow_links: bool, - owner_id, - group_id, - mode, + owner_id: int, + group_id: int, + mode: int | None, force: bool | None = False, do_diff: bool = False, max_file_size_for_diff: int = 1, @@ -797,9 +797,9 @@ def is_content_idempotent( content: bytes, container_path: str, follow_links: bool, - owner_id, - group_id, - mode, + owner_id: int, + group_id: int, + mode: int, force: bool | None = False, diff: dict[str, t.Any] | None = None, max_file_size_for_diff: int = 1, @@ -989,9 +989,9 @@ def copy_content_into_container( content: bytes, container_path: str, follow_links: bool, - owner_id, - group_id, - mode, + owner_id: int, + group_id: int, + mode: int, force: bool | None = False, do_diff: bool = False, max_file_size_for_diff: int = 1, @@ -1133,6 +1133,7 @@ def main() -> None: owner_id, group_id = determine_user_group(client, container) if content is not None: + assert mode is not None # see required_by above copy_content_into_container( client, container, diff --git a/plugins/modules/docker_image.py b/plugins/modules/docker_image.py index 4605ba46..d2f42d6a 
100644 --- a/plugins/modules/docker_image.py +++ b/plugins/modules/docker_image.py @@ -667,7 +667,7 @@ class ImageManager(DockerBaseClass): :rtype: str """ - def build_msg(reason): + def build_msg(reason: str) -> str: return f"Archived image {current_image_name} to {archive_path}, {reason}" try: @@ -877,7 +877,7 @@ class ImageManager(DockerBaseClass): self.push_image(repo, repo_tag) @staticmethod - def _extract_output_line(line: dict[str, t.Any], output: list[str]): + def _extract_output_line(line: dict[str, t.Any], output: list[str]) -> None: """ Extract text line from stream output and, if found, adds it to output. """ @@ -1165,18 +1165,18 @@ def main() -> None: ("source", "load", ["load_path"]), ] - def detect_etc_hosts(client): + def detect_etc_hosts(client: AnsibleDockerClient) -> bool: return client.module.params["build"] and bool( client.module.params["build"].get("etc_hosts") ) - def detect_build_platform(client): + def detect_build_platform(client: AnsibleDockerClient) -> bool: return ( client.module.params["build"] and client.module.params["build"].get("platform") is not None ) - def detect_pull_platform(client): + def detect_pull_platform(client: AnsibleDockerClient) -> bool: return ( client.module.params["pull"] and client.module.params["pull"].get("platform") is not None diff --git a/plugins/modules/docker_network.py b/plugins/modules/docker_network.py index 2386e0ef..4df3ef66 100644 --- a/plugins/modules/docker_network.py +++ b/plugins/modules/docker_network.py @@ -379,7 +379,7 @@ def normalize_ipam_config_key(key: str) -> str: return special_cases.get(key, key.lower()) -def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]): +def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]) -> bool: """Make sure that a is a subset of b, where None entries of a are ignored.""" for k, v in a.items(): if v is None: diff --git a/plugins/modules/docker_node.py b/plugins/modules/docker_node.py index b0cff03a..524c240e 100644 --- 
a/plugins/modules/docker_node.py +++ b/plugins/modules/docker_node.py @@ -134,6 +134,7 @@ node: """ import traceback +import typing as t try: @@ -157,18 +158,19 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class TaskParameters(DockerBaseClass): - def __init__(self, client): + hostname: str + + def __init__(self, client: AnsibleDockerSwarmClient) -> None: super().__init__() # Spec - self.name = None - self.labels = None - self.labels_state = None - self.labels_to_remove = None + self.labels: dict[str, t.Any] | None = None + self.labels_state: t.Literal["merge", "replace"] = "merge" + self.labels_to_remove: list[str] | None = None # Node - self.availability = None - self.role = None + self.availability: t.Literal["active", "pause", "drain"] | None = None + self.role: t.Literal["worker", "manager"] | None = None for key, value in client.module.params.items(): setattr(self, key, value) @@ -177,9 +179,9 @@ class TaskParameters(DockerBaseClass): class SwarmNodeManager(DockerBaseClass): - - def __init__(self, client, results): - + def __init__( + self, client: AnsibleDockerSwarmClient, results: dict[str, t.Any] + ) -> None: super().__init__() self.client = client @@ -192,10 +194,9 @@ class SwarmNodeManager(DockerBaseClass): self.node_update() - def node_update(self): + def node_update(self) -> None: if not (self.client.check_if_swarm_node(node_id=self.parameters.hostname)): self.client.fail("This node is not part of a swarm.") - return if self.client.check_if_swarm_node_is_down(): self.client.fail("Can not update the node. 
The node is down.") @@ -206,7 +207,7 @@ class SwarmNodeManager(DockerBaseClass): self.client.fail(f"Failed to get node information for {exc}") changed = False - node_spec = { + node_spec: dict[str, t.Any] = { "Availability": self.parameters.availability, "Role": self.parameters.role, "Labels": self.parameters.labels, @@ -277,7 +278,7 @@ class SwarmNodeManager(DockerBaseClass): self.results["changed"] = changed -def main(): +def main() -> None: argument_spec = { "hostname": {"type": "str", "required": True}, "labels": {"type": "dict"}, diff --git a/plugins/modules/docker_node_info.py b/plugins/modules/docker_node_info.py index 32be09e4..2ed33b76 100644 --- a/plugins/modules/docker_node_info.py +++ b/plugins/modules/docker_node_info.py @@ -87,6 +87,7 @@ nodes: """ import traceback +import typing as t from ansible_collections.community.docker.plugins.module_utils._common import ( RequestException, @@ -103,9 +104,8 @@ except ImportError: pass -def get_node_facts(client): - - results = [] +def get_node_facts(client: AnsibleDockerSwarmClient) -> list[dict[str, t.Any]]: + results: list[dict[str, t.Any]] = [] if client.module.params["self"] is True: self_node_id = client.get_swarm_node_id() @@ -114,8 +114,8 @@ def get_node_facts(client): return results if client.module.params["name"] is None: - node_info = client.get_all_nodes_inspect() - return node_info + node_info_list = client.get_all_nodes_inspect() + return node_info_list nodes = client.module.params["name"] if not isinstance(nodes, list): @@ -130,7 +130,7 @@ def get_node_facts(client): return results -def main(): +def main() -> None: argument_spec = { "name": {"type": "list", "elements": "str"}, "self": {"type": "bool", "default": False}, diff --git a/plugins/modules/docker_plugin.py b/plugins/modules/docker_plugin.py index 024cd282..e480845c 100644 --- a/plugins/modules/docker_plugin.py +++ b/plugins/modules/docker_plugin.py @@ -204,12 +204,10 @@ class DockerPluginManager: elif state == "disable": self.disable() - 
if self.diff or self.check_mode or self.parameters.debug: - if self.diff: - self.diff_result["before"], self.diff_result["after"] = ( - self.diff_tracker.get_before_after() - ) - self.diff = self.diff_result + if self.diff: + self.diff_result["before"], self.diff_result["after"] = ( + self.diff_tracker.get_before_after() + ) def get_existing_plugin(self) -> dict[str, t.Any] | None: try: @@ -409,7 +407,7 @@ class DockerPluginManager: result: dict[str, t.Any] = { "actions": self.actions, "changed": self.changed, - "diff": self.diff, + "diff": self.diff_result, "plugin": plugin_data, } if ( diff --git a/plugins/modules/docker_secret.py b/plugins/modules/docker_secret.py index 211e83d2..afe3d12b 100644 --- a/plugins/modules/docker_secret.py +++ b/plugins/modules/docker_secret.py @@ -190,6 +190,7 @@ secret_name: import base64 import hashlib import traceback +import typing as t try: @@ -212,9 +213,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class SecretManager(DockerBaseClass): - - def __init__(self, client, results): - + def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None: super().__init__() self.client = client @@ -244,10 +243,10 @@ class SecretManager(DockerBaseClass): if self.rolling_versions: self.version = 0 - self.data_key = None - self.secrets = [] + self.data_key: str | None = None + self.secrets: list[dict[str, t.Any]] = [] - def __call__(self): + def __call__(self) -> None: self.get_secret() if self.state == "present": self.data_key = hashlib.sha224(self.data).hexdigest() @@ -256,7 +255,7 @@ class SecretManager(DockerBaseClass): elif self.state == "absent": self.absent() - def get_version(self, secret): + def get_version(self, secret: dict[str, t.Any]) -> int: try: return int( secret.get("Spec", {}).get("Labels", {}).get("ansible_version", 0) @@ -264,14 +263,14 @@ class SecretManager(DockerBaseClass): except ValueError: return 0 - def remove_old_versions(self): + def 
remove_old_versions(self) -> None: if not self.rolling_versions or self.versions_to_keep < 0: return if not self.check_mode: while len(self.secrets) > max(self.versions_to_keep, 1): self.remove_secret(self.secrets.pop(0)) - def get_secret(self): + def get_secret(self) -> None: """Find an existing secret.""" try: secrets = self.client.secrets(filters={"name": self.name}) @@ -290,9 +289,9 @@ class SecretManager(DockerBaseClass): secret for secret in secrets if secret["Spec"]["Name"] == self.name ] - def create_secret(self): + def create_secret(self) -> str | None: """Create a new secret""" - secret_id = None + secret_id: str | dict[str, t.Any] | None = None # We cannot see the data after creation, so adding a label we can use for idempotency check labels = {"ansible_key": self.data_key} if self.rolling_versions: @@ -312,18 +311,18 @@ class SecretManager(DockerBaseClass): self.client.fail(f"Error creating secret: {exc}") if isinstance(secret_id, dict): - secret_id = secret_id["ID"] + return secret_id["ID"] return secret_id - def remove_secret(self, secret): + def remove_secret(self, secret: dict[str, t.Any]) -> None: try: if not self.check_mode: self.client.remove_secret(secret["ID"]) except APIError as exc: self.client.fail(f"Error removing secret {secret['Spec']['Name']}: {exc}") - def present(self): + def present(self) -> None: """Handles state == 'present', creating or updating the secret""" if self.secrets: secret = self.secrets[-1] @@ -357,7 +356,7 @@ class SecretManager(DockerBaseClass): self.results["secret_id"] = self.create_secret() self.results["secret_name"] = self.name - def absent(self): + def absent(self) -> None: """Handles state == 'absent', removing the secret""" if self.secrets: for secret in self.secrets: @@ -365,7 +364,7 @@ class SecretManager(DockerBaseClass): self.results["changed"] = True -def main(): +def main() -> None: argument_spec = { "name": {"type": "str", "required": True}, "state": { diff --git a/plugins/modules/docker_stack.py 
b/plugins/modules/docker_stack.py index 8a9f12ef..413f155d 100644 --- a/plugins/modules/docker_stack.py +++ b/plugins/modules/docker_stack.py @@ -158,6 +158,7 @@ import json import os import tempfile import traceback +import typing as t from time import sleep from ansible.module_utils.common.text.converters import to_native @@ -183,7 +184,9 @@ except ImportError: HAS_YAML = False -def docker_stack_services(client, stack_name): +def docker_stack_services( + client: AnsibleModuleDockerClient, stack_name: str +) -> list[str]: dummy_rc, out, err = client.call_cli( "stack", "services", stack_name, "--format", "{{.Name}}" ) @@ -192,7 +195,9 @@ def docker_stack_services(client, stack_name): return to_native(out).strip().split("\n") -def docker_service_inspect(client, service_name): +def docker_service_inspect( + client: AnsibleModuleDockerClient, service_name: str +) -> dict[str, t.Any] | None: rc, out, dummy_err = client.call_cli("service", "inspect", service_name) if rc != 0: return None @@ -200,7 +205,9 @@ def docker_service_inspect(client, service_name): return ret -def docker_stack_deploy(client, stack_name, compose_files): +def docker_stack_deploy( + client: AnsibleModuleDockerClient, stack_name: str, compose_files: list[str] +) -> tuple[int, str, str]: command = ["stack", "deploy"] if client.module.params["prune"]: command += ["--prune"] @@ -217,14 +224,21 @@ def docker_stack_deploy(client, stack_name, compose_files): return rc, to_native(out), to_native(err) -def docker_stack_inspect(client, stack_name): - ret = {} +def docker_stack_inspect( + client: AnsibleModuleDockerClient, stack_name: str +) -> dict[str, dict[str, t.Any] | None]: + ret: dict[str, dict[str, t.Any] | None] = {} for service_name in docker_stack_services(client, stack_name): ret[service_name] = docker_service_inspect(client, service_name) return ret -def docker_stack_rm(client, stack_name, retries, interval): +def docker_stack_rm( + client: AnsibleModuleDockerClient, + stack_name: str, + retries: 
int, + interval: int | float, +) -> tuple[int, str, str]: command = ["stack", "rm", stack_name] if not client.module.params["detach"]: command += ["--detach=false"] @@ -237,7 +251,7 @@ def docker_stack_rm(client, stack_name, retries, interval): return rc, to_native(out), to_native(err) -def main(): +def main() -> None: client = AnsibleModuleDockerClient( argument_spec={ "name": {"type": "str", "required": True}, @@ -258,10 +272,10 @@ def main(): ) if not HAS_JSONDIFF: - return client.fail("jsondiff is not installed, try 'pip install jsondiff'") + client.fail("jsondiff is not installed, try 'pip install jsondiff'") if not HAS_YAML: - return client.fail("yaml is not installed, try 'pip install pyyaml'") + client.fail("yaml is not installed, try 'pip install pyyaml'") try: state = client.module.params["state"] diff --git a/plugins/modules/docker_stack_info.py b/plugins/modules/docker_stack_info.py index 8ae6f09f..c9117a3d 100644 --- a/plugins/modules/docker_stack_info.py +++ b/plugins/modules/docker_stack_info.py @@ -85,16 +85,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_cli impor ) -def docker_stack_list(module): - docker_bin = module.get_bin_path("docker", required=True) - rc, out, err = module.run_command( - [docker_bin, "stack", "ls", "--format={{json .}}"] - ) - - return rc, out.strip(), err.strip() - - -def main(): +def main() -> None: client = AnsibleModuleDockerClient( argument_spec={}, supports_check_mode=True, diff --git a/plugins/modules/docker_stack_task_info.py b/plugins/modules/docker_stack_task_info.py index f67d26b1..2eba3305 100644 --- a/plugins/modules/docker_stack_task_info.py +++ b/plugins/modules/docker_stack_task_info.py @@ -93,16 +93,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_cli impor ) -def docker_stack_task(module, stack_name): - docker_bin = module.get_bin_path("docker", required=True) - rc, out, err = module.run_command( - [docker_bin, "stack", "ps", stack_name, 
"--format={{json .}}"] - ) - - return rc, out.strip(), err.strip() - - -def main(): +def main() -> None: client = AnsibleModuleDockerClient( argument_spec={"name": {"type": "str", "required": True}}, supports_check_mode=True, diff --git a/plugins/modules/docker_swarm.py b/plugins/modules/docker_swarm.py index ee68b59e..46b58377 100644 --- a/plugins/modules/docker_swarm.py +++ b/plugins/modules/docker_swarm.py @@ -292,6 +292,7 @@ actions: import json import traceback +import typing as t try: @@ -314,40 +315,40 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class TaskParameters(DockerBaseClass): - def __init__(self): + def __init__(self) -> None: super().__init__() - self.advertise_addr = None - self.listen_addr = None - self.remote_addrs = None - self.join_token = None - self.data_path_addr = None - self.data_path_port = None + self.advertise_addr: str | None = None + self.listen_addr: str | None = None + self.remote_addrs: list[str] | None = None + self.join_token: str | None = None + self.data_path_addr: str | None = None + self.data_path_port: int | None = None self.spec = None # Spec - self.snapshot_interval = None - self.task_history_retention_limit = None - self.keep_old_snapshots = None - self.log_entries_for_slow_followers = None - self.heartbeat_tick = None - self.election_tick = None - self.dispatcher_heartbeat_period = None - self.node_cert_expiry = None - self.name = None - self.labels = None + self.snapshot_interval: int | None = None + self.task_history_retention_limit: int | None = None + self.keep_old_snapshots: int | None = None + self.log_entries_for_slow_followers: int | None = None + self.heartbeat_tick: int | None = None + self.election_tick: int | None = None + self.dispatcher_heartbeat_period: int | None = None + self.node_cert_expiry: int | None = None + self.name: str | None = None + self.labels: dict[str, t.Any] | None = None self.log_driver = None - self.signing_ca_cert = None - self.signing_ca_key = None 
- self.ca_force_rotate = None - self.autolock_managers = None - self.rotate_worker_token = None - self.rotate_manager_token = None - self.default_addr_pool = None - self.subnet_size = None + self.signing_ca_cert: str | None = None + self.signing_ca_key: str | None = None + self.ca_force_rotate: int | None = None + self.autolock_managers: bool | None = None + self.rotate_worker_token: bool | None = None + self.rotate_manager_token: bool | None = None + self.default_addr_pool: list[str] | None = None + self.subnet_size: int | None = None @staticmethod - def from_ansible_params(client): + def from_ansible_params(client: AnsibleDockerSwarmClient) -> TaskParameters: result = TaskParameters() for key, value in client.module.params.items(): if key in result.__dict__: @@ -356,7 +357,7 @@ class TaskParameters(DockerBaseClass): result.update_parameters(client) return result - def update_from_swarm_info(self, swarm_info): + def update_from_swarm_info(self, swarm_info: dict[str, t.Any]) -> None: spec = swarm_info["Spec"] ca_config = spec.get("CAConfig") or {} @@ -400,7 +401,7 @@ class TaskParameters(DockerBaseClass): if "LogDriver" in spec["TaskDefaults"]: self.log_driver = spec["TaskDefaults"]["LogDriver"] - def update_parameters(self, client): + def update_parameters(self, client: AnsibleDockerSwarmClient) -> None: assign = { "snapshot_interval": "snapshot_interval", "task_history_retention_limit": "task_history_retention_limit", @@ -427,7 +428,12 @@ class TaskParameters(DockerBaseClass): params[dest] = value self.spec = client.create_swarm_spec(**params) - def compare_to_active(self, other, client, differences): + def compare_to_active( + self, + other: TaskParameters, + client: AnsibleDockerSwarmClient, + differences: DifferenceTracker, + ) -> DifferenceTracker: for k in self.__dict__: if k in ( "advertise_addr", @@ -459,26 +465,28 @@ class TaskParameters(DockerBaseClass): class SwarmManager(DockerBaseClass): - - def __init__(self, client, results): - + def __init__( + 
self, client: AnsibleDockerSwarmClient, results: dict[str, t.Any] + ) -> None: super().__init__() self.client = client self.results = results self.check_mode = self.client.check_mode - self.swarm_info = {} + self.swarm_info: dict[str, t.Any] = {} - self.state = client.module.params["state"] - self.force = client.module.params["force"] - self.node_id = client.module.params["node_id"] + self.state: t.Literal["present", "join", "absent", "remove"] = ( + client.module.params["state"] + ) + self.force: bool = client.module.params["force"] + self.node_id: str | None = client.module.params["node_id"] self.differences = DifferenceTracker() self.parameters = TaskParameters.from_ansible_params(client) self.created = False - def __call__(self): + def __call__(self) -> None: choice_map = { "present": self.init_swarm, "join": self.join, @@ -486,14 +494,14 @@ class SwarmManager(DockerBaseClass): "remove": self.remove, } - choice_map.get(self.state)() + choice_map[self.state]() if self.client.module._diff or self.parameters.debug: diff = {} diff["before"], diff["after"] = self.differences.get_before_after() self.results["diff"] = diff - def inspect_swarm(self): + def inspect_swarm(self) -> None: try: data = self.client.inspect_swarm() json_str = json.dumps(data, ensure_ascii=False) @@ -507,7 +515,7 @@ class SwarmManager(DockerBaseClass): except APIError: pass - def get_unlock_key(self): + def get_unlock_key(self) -> dict[str, t.Any]: default = {"UnlockKey": None} if not self.has_swarm_lock_changed(): return default @@ -516,18 +524,18 @@ class SwarmManager(DockerBaseClass): except APIError: return default - def has_swarm_lock_changed(self): - return self.parameters.autolock_managers and ( + def has_swarm_lock_changed(self) -> bool: + return bool(self.parameters.autolock_managers) and ( self.created or self.differences.has_difference_for("autolock_managers") ) - def init_swarm(self): + def init_swarm(self) -> None: if not self.force and self.client.check_if_swarm_manager(): 
self.__update_swarm() return if not self.check_mode: - init_arguments = { + init_arguments: dict[str, t.Any] = { "advertise_addr": self.parameters.advertise_addr, "listen_addr": self.parameters.listen_addr, "force_new_cluster": self.force, @@ -562,7 +570,7 @@ class SwarmManager(DockerBaseClass): "UnlockKey": self.swarm_info.get("UnlockKey"), } - def __update_swarm(self): + def __update_swarm(self) -> None: try: self.inspect_swarm() version = self.swarm_info["Version"]["Index"] @@ -587,13 +595,12 @@ class SwarmManager(DockerBaseClass): ) except APIError as exc: self.client.fail(f"Can not update a Swarm Cluster: {exc}") - return self.inspect_swarm() self.results["actions"].append("Swarm cluster updated") self.results["changed"] = True - def join(self): + def join(self) -> None: if self.client.check_if_swarm_node(): self.results["actions"].append("This node is already part of a swarm.") return @@ -614,7 +621,7 @@ class SwarmManager(DockerBaseClass): self.differences.add("joined", parameter=True, active=False) self.results["changed"] = True - def leave(self): + def leave(self) -> None: if not self.client.check_if_swarm_node(): self.results["actions"].append("This node is not part of a swarm.") return @@ -627,7 +634,7 @@ class SwarmManager(DockerBaseClass): self.differences.add("joined", parameter="absent", active="present") self.results["changed"] = True - def remove(self): + def remove(self) -> None: if not self.client.check_if_swarm_manager(): self.client.fail("This node is not a manager.") @@ -655,11 +662,12 @@ class SwarmManager(DockerBaseClass): self.results["changed"] = True -def _detect_remove_operation(client): +def _detect_remove_operation(client: AnsibleDockerSwarmClient) -> bool: return client.module.params["state"] == "remove" -def main(): +def main() -> None: + # TODO: missing option log_driver? 
argument_spec = { "advertise_addr": {"type": "str"}, "data_path_addr": {"type": "str"}, diff --git a/plugins/modules/docker_swarm_info.py b/plugins/modules/docker_swarm_info.py index be157886..a3d106d5 100644 --- a/plugins/modules/docker_swarm_info.py +++ b/plugins/modules/docker_swarm_info.py @@ -186,6 +186,7 @@ tasks: """ import traceback +import typing as t try: @@ -207,16 +208,20 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( class DockerSwarmManager(DockerBaseClass): - - def __init__(self, client, results): - + def __init__( + self, client: AnsibleDockerSwarmClient, results: dict[str, t.Any] + ) -> None: super().__init__() self.client = client self.results = results self.verbose_output = self.client.module.params["verbose_output"] - listed_objects = ["tasks", "services", "nodes"] + listed_objects: list[t.Literal["nodes", "tasks", "services"]] = [ + "tasks", + "services", + "nodes", + ] self.client.fail_task_if_not_swarm_manager() @@ -235,15 +240,18 @@ class DockerSwarmManager(DockerBaseClass): if self.client.module.params["unlock_key"]: self.results["swarm_unlock_key"] = self.get_docker_swarm_unlock_key() - def get_docker_swarm_facts(self): + def get_docker_swarm_facts(self) -> dict[str, t.Any]: try: return self.client.inspect_swarm() except APIError as exc: self.client.fail(f"Error inspecting docker swarm: {exc}") - def get_docker_items_list(self, docker_object=None, filters=None): - items = None - items_list = [] + def get_docker_items_list( + self, + docker_object: t.Literal["nodes", "tasks", "services"], + filters: dict[str, str], + ) -> list[dict[str, t.Any]]: + items_list: list[dict[str, t.Any]] = [] try: if docker_object == "nodes": @@ -252,6 +260,8 @@ class DockerSwarmManager(DockerBaseClass): items = self.client.tasks(filters=filters) elif docker_object == "services": items = self.client.services(filters=filters) + else: + raise ValueError(f"Invalid docker_object {docker_object}") except APIError as exc: 
self.client.fail( f"Error inspecting docker swarm for object '{docker_object}': {exc}" @@ -276,7 +286,7 @@ class DockerSwarmManager(DockerBaseClass): return items_list @staticmethod - def get_essential_facts_nodes(item): + def get_essential_facts_nodes(item: dict[str, t.Any]) -> dict[str, t.Any]: object_essentials = {} object_essentials["ID"] = item.get("ID") @@ -298,7 +308,7 @@ class DockerSwarmManager(DockerBaseClass): return object_essentials - def get_essential_facts_tasks(self, item): + def get_essential_facts_tasks(self, item: dict[str, t.Any]) -> dict[str, t.Any]: object_essentials = {} object_essentials["ID"] = item["ID"] @@ -319,7 +329,7 @@ class DockerSwarmManager(DockerBaseClass): return object_essentials @staticmethod - def get_essential_facts_services(item): + def get_essential_facts_services(item: dict[str, t.Any]) -> dict[str, t.Any]: object_essentials = {} object_essentials["ID"] = item["ID"] @@ -343,12 +353,12 @@ class DockerSwarmManager(DockerBaseClass): return object_essentials - def get_docker_swarm_unlock_key(self): + def get_docker_swarm_unlock_key(self) -> str | None: unlock_key = self.client.get_unlock_key() or {} return unlock_key.get("UnlockKey") or None -def main(): +def main() -> None: argument_spec = { "nodes": {"type": "bool", "default": False}, "nodes_filters": {"type": "dict"}, diff --git a/plugins/modules/docker_swarm_service.py b/plugins/modules/docker_swarm_service.py index eb6cc1cb..9183349a 100644 --- a/plugins/modules/docker_swarm_service.py +++ b/plugins/modules/docker_swarm_service.py @@ -853,6 +853,7 @@ EXAMPLES = r""" import shlex import time import traceback +import typing as t from ansible.module_utils.basic import human_to_bytes from ansible.module_utils.common.text.converters import to_text @@ -891,7 +892,9 @@ except ImportError: pass -def get_docker_environment(env, env_files): +def get_docker_environment( + env: str | dict[str, t.Any] | list[t.Any] | None, env_files: list[str] | None +) -> list[str] | None: """ Will 
return a list of "KEY=VALUE" items. Supplied env variable can be either a list or a dictionary. @@ -899,7 +902,7 @@ def get_docker_environment(env, env_files): If environment files are combined with explicit environment variables, the explicit environment variables take precedence. """ - env_dict = {} + env_dict: dict[str, str] = {} if env_files: for env_file in env_files: parsed_env_file = parse_env_file(env_file) @@ -936,7 +939,21 @@ def get_docker_environment(env, env_files): return sorted(env_list) -def get_docker_networks(networks, network_ids): +@t.overload +def get_docker_networks( + networks: list[str | dict[str, t.Any]], network_ids: dict[str, str] +) -> list[dict[str, t.Any]]: ... + + +@t.overload +def get_docker_networks( + networks: list[str | dict[str, t.Any]] | None, network_ids: dict[str, str] +) -> list[dict[str, t.Any]] | None: ... + + +def get_docker_networks( + networks: list[str | dict[str, t.Any]] | None, network_ids: dict[str, str] +) -> list[dict[str, t.Any]] | None: """ Validate a list of network names or a list of network dictionaries. Network names will be resolved to ids by using the network_ids mapping. 
@@ -945,6 +962,7 @@ def get_docker_networks(networks, network_ids): return None parsed_networks = [] for network in networks: + parsed_network: dict[str, t.Any] if isinstance(network, str): parsed_network = {"name": network} elif isinstance(network, dict): @@ -988,7 +1006,7 @@ def get_docker_networks(networks, network_ids): return parsed_networks or [] -def get_nanoseconds_from_raw_option(name, value): +def get_nanoseconds_from_raw_option(name: str, value: t.Any) -> int | None: if value is None: return None if isinstance(value, int): @@ -1003,12 +1021,14 @@ def get_nanoseconds_from_raw_option(name, value): ) -def get_value(key, values, default=None): +def get_value(key: str, values: dict[str, t.Any], default: t.Any = None) -> t.Any: value = values.get(key) return value if value is not None else default -def has_dict_changed(new_dict, old_dict): +def has_dict_changed( + new_dict: dict[str, t.Any] | None, old_dict: dict[str, t.Any] | None +) -> bool: """ Check if new_dict has differences compared to old_dict while ignoring keys in old_dict which are None in new_dict. @@ -1019,6 +1039,9 @@ def has_dict_changed(new_dict, old_dict): return True if not old_dict and new_dict: return True + if old_dict is None: + # in this case new_dict is empty, only the type checker didn't notice + return False defined_options = { option: value for option, value in new_dict.items() if value is not None } @@ -1031,12 +1054,17 @@ def has_dict_changed(new_dict, old_dict): return False -def has_list_changed(new_list, old_list, sort_lists=True, sort_key=None): +def has_list_changed( + new_list: list[t.Any] | None, + old_list: list[t.Any] | None, + sort_lists: bool = True, + sort_key: str | None = None, +) -> bool: """ Check two lists have differences. Sort lists by default. """ - def sort_list(unsorted_list): + def sort_list(unsorted_list: list[t.Any]) -> list[t.Any]: """ Sort a given list. The list may contain dictionaries, so use the sort key to handle them. 
@@ -1093,7 +1121,10 @@ def has_list_changed(new_list, old_list, sort_lists=True, sort_key=None): return False -def have_networks_changed(new_networks, old_networks): +def have_networks_changed( + new_networks: list[dict[str, t.Any]] | None, + old_networks: list[dict[str, t.Any]] | None, +) -> bool: """Special case list checking for networks to sort aliases""" if new_networks is None: @@ -1123,68 +1154,72 @@ def have_networks_changed(new_networks, old_networks): class DockerService(DockerBaseClass): - def __init__(self, docker_api_version, docker_py_version): + def __init__( + self, docker_api_version: LooseVersion, docker_py_version: LooseVersion + ) -> None: super().__init__() - self.image = "" - self.command = None - self.args = None - self.endpoint_mode = None - self.dns = None - self.healthcheck = None - self.healthcheck_disabled = None - self.hostname = None - self.hosts = None - self.tty = None - self.dns_search = None - self.dns_options = None - self.env = None - self.force_update = None - self.groups = None - self.log_driver = None - self.log_driver_options = None - self.labels = None - self.container_labels = None - self.sysctls = None - self.limit_cpu = None - self.limit_memory = None - self.reserve_cpu = None - self.reserve_memory = None - self.mode = "replicated" - self.user = None - self.mounts = None - self.configs = None - self.secrets = None - self.constraints = None - self.replicas_max_per_node = None - self.networks = None - self.stop_grace_period = None - self.stop_signal = None - self.publish = None - self.placement_preferences = None - self.replicas = -1 + self.image: str | None = "" + self.command: t.Any = None + self.args: list[str] | None = None + self.endpoint_mode: t.Literal["vip", "dnsrr"] | None = None + self.dns: list[str] | None = None + self.healthcheck: dict[str, t.Any] | None = None + self.healthcheck_disabled: bool | None = None + self.hostname: str | None = None + self.hosts: dict[str, t.Any] | None = None + self.tty: bool | None 
= None + self.dns_search: list[str] | None = None + self.dns_options: list[str] | None = None + self.env: t.Any = None + self.force_update: int | None = None + self.groups: list[str] | None = None + self.log_driver: str | None = None + self.log_driver_options: dict[str, t.Any] | None = None + self.labels: dict[str, t.Any] | None = None + self.container_labels: dict[str, t.Any] | None = None + self.sysctls: dict[str, t.Any] | None = None + self.limit_cpu: float | None = None + self.limit_memory: int | None = None + self.reserve_cpu: float | None = None + self.reserve_memory: int | None = None + self.mode: t.Literal["replicated", "global", "replicated-job"] = "replicated" + self.user: str | None = None + self.mounts: list[dict[str, t.Any]] | None = None + self.configs: list[dict[str, t.Any]] | None = None + self.secrets: list[dict[str, t.Any]] | None = None + self.constraints: list[str] | None = None + self.replicas_max_per_node: int | None = None + self.networks: list[t.Any] | None = None + self.stop_grace_period: int | None = None + self.stop_signal: str | None = None + self.publish: list[dict[str, t.Any]] | None = None + self.placement_preferences: list[dict[str, t.Any]] | None = None + self.replicas: int | None = -1 self.service_id = False self.service_version = False - self.read_only = None - self.restart_policy = None - self.restart_policy_attempts = None - self.restart_policy_delay = None - self.restart_policy_window = None - self.rollback_config = None - self.update_delay = None - self.update_parallelism = None - self.update_failure_action = None - self.update_monitor = None - self.update_max_failure_ratio = None - self.update_order = None - self.working_dir = None - self.init = None - self.cap_add = None - self.cap_drop = None + self.read_only: bool | None = None + self.restart_policy: t.Literal["none", "on-failure", "any"] | None = None + self.restart_policy_attempts: int | None = None + self.restart_policy_delay: str | None = None + 
self.restart_policy_window: str | None = None + self.rollback_config: dict[str, t.Any] | None = None + self.update_delay: str | None = None + self.update_parallelism: int | None = None + self.update_failure_action: ( + t.Literal["continue", "pause", "rollback"] | None + ) = None + self.update_monitor: str | None = None + self.update_max_failure_ratio: float | None = None + self.update_order: str | None = None + self.working_dir: str | None = None + self.init: bool | None = None + self.cap_add: list[str] | None = None + self.cap_drop: list[str] | None = None self.docker_api_version = docker_api_version self.docker_py_version = docker_py_version - def get_facts(self): + def get_facts(self) -> dict[str, t.Any]: return { "image": self.image, "mounts": self.mounts, @@ -1242,19 +1277,21 @@ class DockerService(DockerBaseClass): } @property - def can_update_networks(self): + def can_update_networks(self) -> bool: # Before Docker API 1.29 adding/removing networks was not supported return self.docker_api_version >= LooseVersion( "1.29" ) and self.docker_py_version >= LooseVersion("2.7") @property - def can_use_task_template_networks(self): + def can_use_task_template_networks(self) -> bool: # In Docker API 1.25 attaching networks to TaskTemplate is preferred over Spec return self.docker_py_version >= LooseVersion("2.7") @staticmethod - def get_restart_config_from_ansible_params(params): + def get_restart_config_from_ansible_params( + params: dict[str, t.Any], + ) -> dict[str, t.Any]: restart_config = params["restart_config"] or {} condition = get_value( "condition", @@ -1282,7 +1319,9 @@ class DockerService(DockerBaseClass): } @staticmethod - def get_update_config_from_ansible_params(params): + def get_update_config_from_ansible_params( + params: dict[str, t.Any], + ) -> dict[str, t.Any]: update_config = params["update_config"] or {} parallelism = get_value( "parallelism", @@ -1320,7 +1359,9 @@ class DockerService(DockerBaseClass): } @staticmethod - def 
get_rollback_config_from_ansible_params(params): + def get_rollback_config_from_ansible_params( + params: dict[str, t.Any], + ) -> dict[str, t.Any] | None: if params["rollback_config"] is None: return None rollback_config = params["rollback_config"] or {} @@ -1340,7 +1381,7 @@ class DockerService(DockerBaseClass): } @staticmethod - def get_logging_from_ansible_params(params): + def get_logging_from_ansible_params(params: dict[str, t.Any]) -> dict[str, t.Any]: logging_config = params["logging"] or {} driver = get_value( "driver", @@ -1356,7 +1397,7 @@ class DockerService(DockerBaseClass): } @staticmethod - def get_limits_from_ansible_params(params): + def get_limits_from_ansible_params(params: dict[str, t.Any]) -> dict[str, t.Any]: limits = params["limits"] or {} cpus = get_value( "cpus", @@ -1379,7 +1420,9 @@ class DockerService(DockerBaseClass): } @staticmethod - def get_reservations_from_ansible_params(params): + def get_reservations_from_ansible_params( + params: dict[str, t.Any], + ) -> dict[str, t.Any]: reservations = params["reservations"] or {} cpus = get_value( "cpus", @@ -1403,7 +1446,7 @@ class DockerService(DockerBaseClass): } @staticmethod - def get_placement_from_ansible_params(params): + def get_placement_from_ansible_params(params: dict[str, t.Any]) -> dict[str, t.Any]: placement = params["placement"] or {} constraints = get_value("constraints", placement) @@ -1419,14 +1462,14 @@ class DockerService(DockerBaseClass): @classmethod def from_ansible_params( cls, - ap, - old_service, - image_digest, - secret_ids, - config_ids, - network_ids, - client, - ): + ap: dict[str, t.Any], + old_service: DockerService | None, + image_digest: str, + secret_ids: dict[str, str], + config_ids: dict[str, str], + network_ids: dict[str, str], + client: AnsibleDockerClient, + ) -> DockerService: s = DockerService(client.docker_api_version, client.docker_py_version) s.image = image_digest s.args = ap["args"] @@ -1596,7 +1639,7 @@ class DockerService(DockerBaseClass): 
return s - def compare(self, os): + def compare(self, os: DockerService) -> tuple[bool, DifferenceTracker, bool, bool]: differences = DifferenceTracker() needs_rebuild = False force_update = False @@ -1784,7 +1827,7 @@ class DockerService(DockerBaseClass): differences.add( "update_order", parameter=self.update_order, active=os.update_order ) - has_image_changed, change = self.has_image_changed(os.image) + has_image_changed, change = self.has_image_changed(os.image or "") if has_image_changed: differences.add("image", parameter=self.image, active=change) if self.user and self.user != os.user: @@ -1828,7 +1871,7 @@ class DockerService(DockerBaseClass): force_update, ) - def has_healthcheck_changed(self, old_publish): + def has_healthcheck_changed(self, old_publish: DockerService) -> bool: if self.healthcheck_disabled is False and self.healthcheck is None: return False if self.healthcheck_disabled: @@ -1838,14 +1881,14 @@ class DockerService(DockerBaseClass): return False return self.healthcheck != old_publish.healthcheck - def has_publish_changed(self, old_publish): + def has_publish_changed(self, old_publish: list[dict[str, t.Any]] | None) -> bool: if self.publish is None: return False old_publish = old_publish or [] if len(self.publish) != len(old_publish): return True - def publish_sorter(item): + def publish_sorter(item: dict[str, t.Any]) -> tuple[int, int, str]: return ( item.get("published_port") or 0, item.get("target_port") or 0, @@ -1869,12 +1912,13 @@ class DockerService(DockerBaseClass): return True return False - def has_image_changed(self, old_image): + def has_image_changed(self, old_image: str) -> tuple[bool, str]: + assert self.image is not None if "@" not in self.image: old_image = old_image.split("@")[0] return self.image != old_image, old_image - def build_container_spec(self): + def build_container_spec(self) -> types.ContainerSpec: mounts = None if self.mounts is not None: mounts = [] @@ -1945,7 +1989,7 @@ class DockerService(DockerBaseClass): 
secrets.append(types.SecretReference(**secret_args)) - dns_config_args = {} + dns_config_args: dict[str, t.Any] = {} if self.dns is not None: dns_config_args["nameservers"] = self.dns if self.dns_search is not None: @@ -1954,7 +1998,7 @@ class DockerService(DockerBaseClass): dns_config_args["options"] = self.dns_options dns_config = types.DNSConfig(**dns_config_args) if dns_config_args else None - container_spec_args = {} + container_spec_args: dict[str, t.Any] = {} if self.command is not None: container_spec_args["command"] = self.command if self.args is not None: @@ -2004,8 +2048,8 @@ class DockerService(DockerBaseClass): return types.ContainerSpec(self.image, **container_spec_args) - def build_placement(self): - placement_args = {} + def build_placement(self) -> types.Placement | None: + placement_args: dict[str, t.Any] = {} if self.constraints is not None: placement_args["constraints"] = self.constraints if self.replicas_max_per_node is not None: @@ -2018,8 +2062,8 @@ class DockerService(DockerBaseClass): ] return types.Placement(**placement_args) if placement_args else None - def build_update_config(self): - update_config_args = {} + def build_update_config(self) -> types.UpdateConfig | None: + update_config_args: dict[str, t.Any] = {} if self.update_parallelism is not None: update_config_args["parallelism"] = self.update_parallelism if self.update_delay is not None: @@ -2034,16 +2078,16 @@ class DockerService(DockerBaseClass): update_config_args["order"] = self.update_order return types.UpdateConfig(**update_config_args) if update_config_args else None - def build_log_driver(self): - log_driver_args = {} + def build_log_driver(self) -> types.DriverConfig | None: + log_driver_args: dict[str, t.Any] = {} if self.log_driver is not None: log_driver_args["name"] = self.log_driver if self.log_driver_options is not None: log_driver_args["options"] = self.log_driver_options return types.DriverConfig(**log_driver_args) if log_driver_args else None - def 
build_restart_policy(self): - restart_policy_args = {} + def build_restart_policy(self) -> types.RestartPolicy | None: + restart_policy_args: dict[str, t.Any] = {} if self.restart_policy is not None: restart_policy_args["condition"] = self.restart_policy if self.restart_policy_delay is not None: @@ -2056,7 +2100,7 @@ class DockerService(DockerBaseClass): types.RestartPolicy(**restart_policy_args) if restart_policy_args else None ) - def build_rollback_config(self): + def build_rollback_config(self) -> types.RollbackConfig | None: if self.rollback_config is None: return None rollback_config_options = [ @@ -2078,8 +2122,8 @@ class DockerService(DockerBaseClass): else None ) - def build_resources(self): - resources_args = {} + def build_resources(self) -> types.Resources | None: + resources_args: dict[str, t.Any] = {} if self.limit_cpu is not None: resources_args["cpu_limit"] = int(self.limit_cpu * 1000000000.0) if self.limit_memory is not None: @@ -2090,12 +2134,16 @@ class DockerService(DockerBaseClass): resources_args["mem_reservation"] = self.reserve_memory return types.Resources(**resources_args) if resources_args else None - def build_task_template(self, container_spec, placement=None): + def build_task_template( + self, + container_spec: types.ContainerSpec, + placement: types.Placement | None = None, + ) -> types.TaskTemplate: log_driver = self.build_log_driver() restart_policy = self.build_restart_policy() resources = self.build_resources() - task_template_args = {} + task_template_args: dict[str, t.Any] = {} if placement is not None: task_template_args["placement"] = placement if log_driver is not None: @@ -2112,12 +2160,12 @@ class DockerService(DockerBaseClass): task_template_args["networks"] = networks return types.TaskTemplate(container_spec=container_spec, **task_template_args) - def build_service_mode(self): + def build_service_mode(self) -> types.ServiceMode: if self.mode == "global": self.replicas = None return types.ServiceMode(self.mode, 
replicas=self.replicas) - def build_networks(self): + def build_networks(self) -> list[dict[str, t.Any]] | None: networks = None if self.networks is not None: networks = [] @@ -2130,8 +2178,8 @@ class DockerService(DockerBaseClass): networks.append(docker_network) return networks - def build_endpoint_spec(self): - endpoint_spec_args = {} + def build_endpoint_spec(self) -> types.EndpointSpec | None: + endpoint_spec_args: dict[str, t.Any] = {} if self.publish is not None: ports = [] for port in self.publish: @@ -2149,7 +2197,7 @@ class DockerService(DockerBaseClass): endpoint_spec_args["mode"] = self.endpoint_mode return types.EndpointSpec(**endpoint_spec_args) if endpoint_spec_args else None - def build_docker_service(self): + def build_docker_service(self) -> dict[str, t.Any]: container_spec = self.build_container_spec() placement = self.build_placement() task_template = self.build_task_template(container_spec, placement) @@ -2159,7 +2207,10 @@ class DockerService(DockerBaseClass): service_mode = self.build_service_mode() endpoint_spec = self.build_endpoint_spec() - service = {"task_template": task_template, "mode": service_mode} + service: dict[str, t.Any] = { + "task_template": task_template, + "mode": service_mode, + } if update_config: service["update_config"] = update_config if rollback_config: @@ -2176,13 +2227,12 @@ class DockerService(DockerBaseClass): class DockerServiceManager: - - def __init__(self, client): + def __init__(self, client: AnsibleDockerClient): self.client = client self.retries = 2 - self.diff_tracker = None + self.diff_tracker: DifferenceTracker | None = None - def get_service(self, name): + def get_service(self, name: str) -> DockerService | None: try: raw_data = self.client.inspect_service(name) except NotFound: @@ -2415,7 +2465,9 @@ class DockerServiceManager: ds.init = task_template_data["ContainerSpec"].get("Init", False) return ds - def update_service(self, name, old_service, new_service): + def update_service( + self, name: str, 
old_service: DockerService, new_service: DockerService + ) -> None: service_data = new_service.build_docker_service() result = self.client.update_service( old_service.service_id, @@ -2427,15 +2479,15 @@ class DockerServiceManager: # (see https://github.com/docker/docker-py/pull/2272) self.client.report_warnings(result, ["Warning"]) - def create_service(self, name, service): + def create_service(self, name: str, service: DockerService) -> None: service_data = service.build_docker_service() result = self.client.create_service(name=name, **service_data) self.client.report_warnings(result, ["Warning"]) - def remove_service(self, name): + def remove_service(self, name: str) -> None: self.client.remove_service(name) - def get_image_digest(self, name, resolve=False): + def get_image_digest(self, name: str, resolve: bool = False) -> str: if not name or not resolve: return name repo, tag = parse_repository_tag(name) @@ -2446,10 +2498,10 @@ class DockerServiceManager: digest = distribution_data["Descriptor"]["digest"] return f"{name}@{digest}" - def get_networks_names_ids(self): + def get_networks_names_ids(self) -> dict[str, str]: return {network["Name"]: network["Id"] for network in self.client.networks()} - def get_missing_secret_ids(self): + def get_missing_secret_ids(self) -> dict[str, str]: """ Resolve missing secret ids by looking them up by name """ @@ -2471,7 +2523,7 @@ class DockerServiceManager: self.client.fail(f'Could not find a secret named "{secret_name}"') return secrets - def get_missing_config_ids(self): + def get_missing_config_ids(self) -> dict[str, str]: """ Resolve missing config ids by looking them up by name """ @@ -2493,7 +2545,7 @@ class DockerServiceManager: self.client.fail(f'Could not find a config named "{config_name}"') return configs - def run(self): + def run(self) -> tuple[str, bool, bool, list[str], dict[str, t.Any]]: self.diff_tracker = DifferenceTracker() module = self.client.module @@ -2582,7 +2634,7 @@ class DockerServiceManager: return 
msg, changed, rebuilt, differences.get_legacy_docker_diffs(), facts - def run_safe(self): + def run_safe(self) -> tuple[str, bool, bool, list[str], dict[str, t.Any]]: while True: try: return self.run() @@ -2596,20 +2648,20 @@ class DockerServiceManager: raise -def _detect_publish_mode_usage(client): +def _detect_publish_mode_usage(client: AnsibleDockerClient) -> bool: for publish_def in client.module.params["publish"] or []: if publish_def.get("mode"): return True return False -def _detect_healthcheck_start_period(client): +def _detect_healthcheck_start_period(client: AnsibleDockerClient) -> bool: if client.module.params["healthcheck"]: return client.module.params["healthcheck"]["start_period"] is not None return False -def _detect_mount_tmpfs_usage(client): +def _detect_mount_tmpfs_usage(client: AnsibleDockerClient) -> bool: for mount in client.module.params["mounts"] or []: if mount.get("type") == "tmpfs": return True @@ -2620,14 +2672,14 @@ def _detect_mount_tmpfs_usage(client): return False -def _detect_update_config_failure_action_rollback(client): +def _detect_update_config_failure_action_rollback(client: AnsibleDockerClient) -> bool: rollback_config_failure_action = (client.module.params["update_config"] or {}).get( "failure_action" ) return rollback_config_failure_action == "rollback" -def main(): +def main() -> None: argument_spec = { "name": {"type": "str", "required": True}, "image": {"type": "str"}, @@ -2948,6 +3000,7 @@ def main(): "swarm_service": facts, } if client.module._diff: + assert dsm.diff_tracker is not None before, after = dsm.diff_tracker.get_before_after() results["diff"] = {"before": before, "after": after} diff --git a/plugins/modules/docker_swarm_service_info.py b/plugins/modules/docker_swarm_service_info.py index 710a8ae4..d338c12e 100644 --- a/plugins/modules/docker_swarm_service_info.py +++ b/plugins/modules/docker_swarm_service_info.py @@ -63,6 +63,7 @@ service: """ import traceback +import typing as t try: @@ -79,12 +80,12 @@ from 
ansible_collections.community.docker.plugins.module_utils._swarm import ( ) -def get_service_info(client): +def get_service_info(client: AnsibleDockerSwarmClient) -> dict[str, t.Any] | None: service = client.module.params["name"] return client.get_service_inspect(service_id=service, skip_missing=True) -def main(): +def main() -> None: argument_spec = { "name": {"type": "str", "required": True}, } diff --git a/tests/sanity/ignore-2.17.txt b/tests/sanity/ignore-2.17.txt index b2c9b3ff..1dbff5a7 100644 --- a/tests/sanity/ignore-2.17.txt +++ b/tests/sanity/ignore-2.17.txt @@ -5,6 +5,7 @@ plugins/module_utils/_api/api/client.py pep8:E704 plugins/module_utils/_api/transport/sshconn.py no-assert plugins/module_utils/_api/utils/build.py no-assert plugins/module_utils/_api/utils/ports.py pep8:E704 +plugins/module_utils/_api/utils/proxy.py pep8:E704 plugins/module_utils/_api/utils/socket.py pep8:E704 plugins/module_utils/_common_cli.py pep8:E704 plugins/module_utils/_module_container/module.py no-assert @@ -12,6 +13,7 @@ plugins/module_utils/_platform.py no-assert plugins/module_utils/_socket_handler.py no-assert plugins/module_utils/_swarm.py pep8:E704 plugins/module_utils/_util.py pep8:E704 +plugins/modules/docker_container_copy_into.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_container_exec.py pylint:unpacking-non-sequence @@ -19,4 +21,6 @@ plugins/modules/docker_image.py no-assert plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_login.py no-assert plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_swarm_service.py no-assert +plugins/modules/docker_swarm_service.py pep8:E704 plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.18.txt b/tests/sanity/ignore-2.18.txt index 65be094d..88e39985 100644 --- a/tests/sanity/ignore-2.18.txt 
+++ b/tests/sanity/ignore-2.18.txt @@ -5,6 +5,7 @@ plugins/module_utils/_api/api/client.py pep8:E704 plugins/module_utils/_api/transport/sshconn.py no-assert plugins/module_utils/_api/utils/build.py no-assert plugins/module_utils/_api/utils/ports.py pep8:E704 +plugins/module_utils/_api/utils/proxy.py pep8:E704 plugins/module_utils/_api/utils/socket.py pep8:E704 plugins/module_utils/_common_cli.py pep8:E704 plugins/module_utils/_module_container/module.py no-assert @@ -12,10 +13,13 @@ plugins/module_utils/_platform.py no-assert plugins/module_utils/_socket_handler.py no-assert plugins/module_utils/_swarm.py pep8:E704 plugins/module_utils/_util.py pep8:E704 +plugins/modules/docker_container_copy_into.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_image.py no-assert plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_login.py no-assert plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_swarm_service.py no-assert +plugins/modules/docker_swarm_service.py pep8:E704 plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.19.txt b/tests/sanity/ignore-2.19.txt index b47a6747..dea97dcf 100644 --- a/tests/sanity/ignore-2.19.txt +++ b/tests/sanity/ignore-2.19.txt @@ -6,10 +6,12 @@ plugins/module_utils/_api/utils/build.py no-assert plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_platform.py no-assert plugins/module_utils/_socket_handler.py no-assert +plugins/modules/docker_container_copy_into.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_image.py no-assert plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_login.py 
no-assert plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_swarm_service.py no-assert plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.20.txt b/tests/sanity/ignore-2.20.txt index b47a6747..dea97dcf 100644 --- a/tests/sanity/ignore-2.20.txt +++ b/tests/sanity/ignore-2.20.txt @@ -6,10 +6,12 @@ plugins/module_utils/_api/utils/build.py no-assert plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_platform.py no-assert plugins/module_utils/_socket_handler.py no-assert +plugins/modules/docker_container_copy_into.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_image.py no-assert plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_login.py no-assert plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_swarm_service.py no-assert plugins/modules/docker_volume.py no-assert diff --git a/tests/sanity/ignore-2.21.txt b/tests/sanity/ignore-2.21.txt index b47a6747..dea97dcf 100644 --- a/tests/sanity/ignore-2.21.txt +++ b/tests/sanity/ignore-2.21.txt @@ -6,10 +6,12 @@ plugins/module_utils/_api/utils/build.py no-assert plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_platform.py no-assert plugins/module_utils/_socket_handler.py no-assert +plugins/modules/docker_container_copy_into.py no-assert plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_image.py no-assert plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_login.py no-assert plugins/modules/docker_plugin.py no-assert +plugins/modules/docker_swarm_service.py no-assert plugins/modules/docker_volume.py no-assert diff --git 
a/tests/unit/plugins/connection/test_docker.py b/tests/unit/plugins/connection/test_docker.py index e9a804af..680aad1e 100644 --- a/tests/unit/plugins/connection/test_docker.py +++ b/tests/unit/plugins/connection/test_docker.py @@ -4,6 +4,7 @@ from __future__ import annotations +import typing as t import unittest from io import StringIO from unittest import mock @@ -14,8 +15,7 @@ from ansible.plugins.loader import connection_loader class TestDockerConnectionClass(unittest.TestCase): - - def setUp(self): + def setUp(self) -> None: self.play_context = PlayContext() self.play_context.prompt = ( "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " @@ -29,7 +29,7 @@ class TestDockerConnectionClass(unittest.TestCase): "community.docker.docker", self.play_context, self.in_stream ) - def tearDown(self): + def tearDown(self) -> None: pass @mock.patch( @@ -41,8 +41,8 @@ class TestDockerConnectionClass(unittest.TestCase): return_value=("docker version", "1.2.3", "", 0), ) def test_docker_connection_module_too_old( - self, mock_new_docker_version, mock_old_docker_version - ): + self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any + ) -> None: self.dc._version = None self.dc.remote_user = "foo" self.assertRaisesRegex( @@ -60,8 +60,8 @@ class TestDockerConnectionClass(unittest.TestCase): return_value=("docker version", "1.7.0", "", 0), ) def test_docker_connection_module( - self, mock_new_docker_version, mock_old_docker_version - ): + self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any + ) -> None: self.dc._version = None # old version and new version fail @@ -74,8 +74,8 @@ class TestDockerConnectionClass(unittest.TestCase): return_value=("false", "garbage", "", 1), ) def test_docker_connection_module_wrong_cmd( - self, mock_new_docker_version, mock_old_docker_version - ): + self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any + ) -> None: self.dc._version = None self.dc.remote_user = "foo" self.assertRaisesRegex( 
diff --git a/tests/unit/plugins/inventory/test_docker_containers.py b/tests/unit/plugins/inventory/test_docker_containers.py index ca96b81e..9f1d5ee6 100644 --- a/tests/unit/plugins/inventory/test_docker_containers.py +++ b/tests/unit/plugins/inventory/test_docker_containers.py @@ -4,6 +4,7 @@ from __future__ import annotations +import typing as t from unittest.mock import create_autospec import pytest @@ -19,14 +20,18 @@ from ansible_collections.community.docker.plugins.inventory.docker_containers im ) +if t.TYPE_CHECKING: + from collections.abc import Callable + + @pytest.fixture(scope="module") -def templar(): +def templar() -> Templar: dataloader = create_autospec(DataLoader, instance=True) return Templar(loader=dataloader) @pytest.fixture(scope="module") -def inventory(templar): +def inventory(templar: Templar) -> InventoryModule: r = InventoryModule() r.inventory = InventoryData() r.templar = templar @@ -83,8 +88,10 @@ LOVING_THARP_SERVICE = { } -def create_get_option(options, default=False): - def get_option(option): +def create_get_option( + options: dict[str, t.Any], default: t.Any = False +) -> Callable[[str], t.Any]: + def get_option(option: str) -> t.Any: if option in options: return options[option] return default @@ -93,9 +100,9 @@ def create_get_option(options, default=False): class FakeClient: - def __init__(self, *hosts): - self.get_results = {} - list_reply = [] + def __init__(self, *hosts: dict[str, t.Any]) -> None: + self.get_results: dict[str, t.Any] = {} + list_reply: list[dict[str, t.Any]] = [] for host in hosts: list_reply.append( { @@ -109,15 +116,16 @@ class FakeClient: self.get_results[f"/containers/{host['Id']}/json"] = host self.get_results["/containers/json"] = list_reply - def get_json(self, url, *param, **kwargs): + def get_json(self, url: str, *param: str, **kwargs: t.Any) -> t.Any: url = url.format(*param) return self.get_results[url] -def test_populate(inventory, mocker): +def test_populate(inventory: InventoryModule, mocker: 
t.Any) -> None: + assert inventory.inventory is not None client = FakeClient(LOVING_THARP) - inventory.get_option = mocker.MagicMock( + inventory.get_option = mocker.MagicMock( # type: ignore[method-assign] side_effect=create_get_option( { "verbose_output": True, @@ -130,9 +138,10 @@ def test_populate(inventory, mocker): } ) ) - inventory._populate(client) + inventory._populate(client) # type: ignore host_1 = inventory.inventory.get_host("loving_tharp") + assert host_1 is not None host_1_vars = host_1.get_vars() assert host_1_vars["ansible_host"] == "loving_tharp" @@ -149,10 +158,11 @@ def test_populate(inventory, mocker): assert len(inventory.inventory.hosts) == 1 -def test_populate_service(inventory, mocker): +def test_populate_service(inventory: InventoryModule, mocker: t.Any) -> None: + assert inventory.inventory is not None client = FakeClient(LOVING_THARP_SERVICE) - inventory.get_option = mocker.MagicMock( + inventory.get_option = mocker.MagicMock( # type: ignore[method-assign] side_effect=create_get_option( { "verbose_output": False, @@ -166,9 +176,10 @@ def test_populate_service(inventory, mocker): } ) ) - inventory._populate(client) + inventory._populate(client) # type: ignore host_1 = inventory.inventory.get_host("loving_tharp") + assert host_1 is not None host_1_vars = host_1.get_vars() assert host_1_vars["ansible_host"] == "loving_tharp" @@ -207,10 +218,11 @@ def test_populate_service(inventory, mocker): assert len(inventory.inventory.hosts) == 1 -def test_populate_stack(inventory, mocker): +def test_populate_stack(inventory: InventoryModule, mocker: t.Any) -> None: + assert inventory.inventory is not None client = FakeClient(LOVING_THARP_STACK) - inventory.get_option = mocker.MagicMock( + inventory.get_option = mocker.MagicMock( # type: ignore[method-assign] side_effect=create_get_option( { "verbose_output": False, @@ -226,9 +238,10 @@ def test_populate_stack(inventory, mocker): } ) ) - inventory._populate(client) + inventory._populate(client) # type: 
ignore host_1 = inventory.inventory.get_host("loving_tharp") + assert host_1 is not None host_1_vars = host_1.get_vars() assert host_1_vars["ansible_ssh_host"] == "127.0.0.1" @@ -267,10 +280,11 @@ def test_populate_stack(inventory, mocker): assert len(inventory.inventory.hosts) == 1 -def test_populate_filter_none(inventory, mocker): +def test_populate_filter_none(inventory: InventoryModule, mocker: t.Any) -> None: + assert inventory.inventory is not None client = FakeClient(LOVING_THARP) - inventory.get_option = mocker.MagicMock( + inventory.get_option = mocker.MagicMock( # type: ignore[method-assign] side_effect=create_get_option( { "verbose_output": True, @@ -285,15 +299,16 @@ def test_populate_filter_none(inventory, mocker): } ) ) - inventory._populate(client) + inventory._populate(client) # type: ignore assert len(inventory.inventory.hosts) == 0 -def test_populate_filter(inventory, mocker): +def test_populate_filter(inventory: InventoryModule, mocker: t.Any) -> None: + assert inventory.inventory is not None client = FakeClient(LOVING_THARP) - inventory.get_option = mocker.MagicMock( + inventory.get_option = mocker.MagicMock( # type: ignore[method-assign] side_effect=create_get_option( { "verbose_output": True, @@ -309,9 +324,10 @@ def test_populate_filter(inventory, mocker): } ) ) - inventory._populate(client) + inventory._populate(client) # type: ignore host_1 = inventory.inventory.get_host("loving_tharp") + assert host_1 is not None host_1_vars = host_1.get_vars() assert host_1_vars["ansible_host"] == "loving_tharp" diff --git a/tests/unit/plugins/module_utils/_api/api/test_client.py b/tests/unit/plugins/module_utils/_api/api/test_client.py index 20ce3cbe..ce75f2bd 100644 --- a/tests/unit/plugins/module_utils/_api/api/test_client.py +++ b/tests/unit/plugins/module_utils/_api/api/test_client.py @@ -19,6 +19,7 @@ import struct import tempfile import threading import time +import typing as t import unittest from http.server import BaseHTTPRequestHandler from 
socketserver import ThreadingTCPServer @@ -42,18 +43,24 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c from .. import fake_api +if t.TYPE_CHECKING: + from ansible_collections.community.docker.plugins.module_utils._api.auth import ( + AuthConfig, + ) + + DEFAULT_TIMEOUT_SECONDS = constants.DEFAULT_TIMEOUT_SECONDS def response( - status_code=200, - content="", - headers=None, - reason=None, - elapsed=0, - request=None, - raw=None, -): + status_code: int = 200, + content: bytes | dict[str, t.Any] | list[dict[str, t.Any]] = b"", + headers: dict[str, str] | None = None, + reason: str = "", + elapsed: int = 0, + request: requests.PreparedRequest | None = None, + raw: urllib3.HTTPResponse | None = None, +) -> requests.Response: res = requests.Response() res.status_code = status_code if not isinstance(content, bytes): @@ -62,23 +69,25 @@ def response( res.headers = requests.structures.CaseInsensitiveDict(headers or {}) res.reason = reason res.elapsed = datetime.timedelta(elapsed) - res.request = request + res.request = request # type: ignore res.raw = raw return res -def fake_resolve_authconfig( - authconfig, registry=None, *args, **kwargs -): # pylint: disable=keyword-arg-before-vararg +def fake_resolve_authconfig( # pylint: disable=keyword-arg-before-vararg + authconfig: AuthConfig, *args: t.Any, registry: str | None = None, **kwargs: t.Any +) -> None: return None -def fake_inspect_container(self, container, tty=False): +def fake_inspect_container(self: object, container: str, tty: bool = False) -> t.Any: return fake_api.get_fake_inspect_container(tty=tty)[1] -def fake_resp(method, url, *args, **kwargs): - key = None +def fake_resp( + method: str, url: str, *args: t.Any, **kwargs: t.Any +) -> requests.Response: + key: str | tuple[str, str] | None = None if url in fake_api.fake_responses: key = url elif (url, method) in fake_api.fake_responses: @@ -92,23 +101,37 @@ def fake_resp(method, url, *args, **kwargs): fake_request = 
mock.Mock(side_effect=fake_resp) -def fake_get(self, url, *args, **kwargs): +def fake_get( + self: APIClient, url: str, *args: str, **kwargs: t.Any +) -> requests.Response: return fake_request("GET", url, *args, **kwargs) -def fake_post(self, url, *args, **kwargs): +def fake_post( + self: APIClient, url: str, *args: str, **kwargs: t.Any +) -> requests.Response: return fake_request("POST", url, *args, **kwargs) -def fake_put(self, url, *args, **kwargs): +def fake_put( + self: APIClient, url: str, *args: str, **kwargs: t.Any +) -> requests.Response: return fake_request("PUT", url, *args, **kwargs) -def fake_delete(self, url, *args, **kwargs): +def fake_delete( + self: APIClient, url: str, *args: str, **kwargs: t.Any +) -> requests.Response: return fake_request("DELETE", url, *args, **kwargs) -def fake_read_from_socket(self, response, stream, tty=False, demux=False): +def fake_read_from_socket( + self: APIClient, + response: requests.Response, + stream: bool, + tty: bool = False, + demux: bool = False, +) -> bytes: return b"" @@ -117,7 +140,7 @@ url_prefix = f"{url_base}v{DEFAULT_DOCKER_API_VERSION}/" # pylint: disable=inva class BaseAPIClientTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.patcher = mock.patch.multiple( "ansible_collections.community.docker.plugins.module_utils._api.api.client.APIClient", get=fake_get, @@ -129,11 +152,13 @@ class BaseAPIClientTest(unittest.TestCase): self.patcher.start() self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION) - def tearDown(self): + def tearDown(self) -> None: self.client.close() self.patcher.stop() - def base_create_payload(self, img="busybox", cmd=None): + def base_create_payload( + self, img: str = "busybox", cmd: list[str] | None = None + ) -> dict[str, t.Any]: if not cmd: cmd = ["true"] return { @@ -150,16 +175,16 @@ class BaseAPIClientTest(unittest.TestCase): class DockerApiTest(BaseAPIClientTest): - def test_ctor(self): + def test_ctor(self) -> None: with 
pytest.raises(errors.DockerException) as excinfo: - APIClient(version=1.12) + APIClient(version=1.12) # type: ignore assert ( str(excinfo.value) == "Version parameter must be a string or None. Found float" ) - def test_url_valid_resource(self): + def test_url_valid_resource(self) -> None: url = self.client._url("/hello/{0}/world", "somename") assert url == f"{url_prefix}hello/somename/world" @@ -172,50 +197,50 @@ class DockerApiTest(BaseAPIClientTest): url = self.client._url("/images/{0}/push", "localhost:5000/image") assert url == f"{url_prefix}images/localhost:5000/image/push" - def test_url_invalid_resource(self): + def test_url_invalid_resource(self) -> None: with pytest.raises(ValueError): - self.client._url("/hello/{0}/world", ["sakuya", "izayoi"]) + self.client._url("/hello/{0}/world", ["sakuya", "izayoi"]) # type: ignore - def test_url_no_resource(self): + def test_url_no_resource(self) -> None: url = self.client._url("/simple") assert url == f"{url_prefix}simple" - def test_url_unversioned_api(self): + def test_url_unversioned_api(self) -> None: url = self.client._url("/hello/{0}/world", "somename", versioned_api=False) assert url == f"{url_base}hello/somename/world" - def test_version(self): + def test_version(self) -> None: self.client.version() fake_request.assert_called_with( "GET", url_prefix + "version", timeout=DEFAULT_TIMEOUT_SECONDS ) - def test_version_no_api_version(self): + def test_version_no_api_version(self) -> None: self.client.version(False) fake_request.assert_called_with( "GET", url_base + "version", timeout=DEFAULT_TIMEOUT_SECONDS ) - def test_retrieve_server_version(self): + def test_retrieve_server_version(self) -> None: client = APIClient(version="auto") assert isinstance(client._version, str) assert not (client._version == "auto") client.close() - def test_auto_retrieve_server_version(self): + def test_auto_retrieve_server_version(self) -> None: version = self.client._retrieve_server_version() assert isinstance(version, str) - def 
test_info(self): + def test_info(self) -> None: self.client.info() fake_request.assert_called_with( "GET", url_prefix + "info", timeout=DEFAULT_TIMEOUT_SECONDS ) - def test_search(self): + def test_search(self) -> None: self.client.get_json("/images/search", params={"term": "busybox"}) fake_request.assert_called_with( @@ -225,7 +250,7 @@ class DockerApiTest(BaseAPIClientTest): timeout=DEFAULT_TIMEOUT_SECONDS, ) - def test_login(self): + def test_login(self) -> None: self.client.login("sakuya", "izayoi") args = fake_request.call_args assert args[0][0] == "POST" @@ -242,42 +267,42 @@ class DockerApiTest(BaseAPIClientTest): "serveraddress": None, } - def _socket_path_for_client_session(self, client): + def _socket_path_for_client_session(self, client: APIClient) -> str: socket_adapter = client.get_adapter("http+docker://") - return socket_adapter.socket_path + return socket_adapter.socket_path # type: ignore[attr-defined] - def test_url_compatibility_unix(self): + def test_url_compatibility_unix(self) -> None: c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION) assert self._socket_path_for_client_session(c) == "/socket" - def test_url_compatibility_unix_triple_slash(self): + def test_url_compatibility_unix_triple_slash(self) -> None: c = APIClient(base_url="unix:///socket", version=DEFAULT_DOCKER_API_VERSION) assert self._socket_path_for_client_session(c) == "/socket" - def test_url_compatibility_http_unix_triple_slash(self): + def test_url_compatibility_http_unix_triple_slash(self) -> None: c = APIClient( base_url="http+unix:///socket", version=DEFAULT_DOCKER_API_VERSION ) assert self._socket_path_for_client_session(c) == "/socket" - def test_url_compatibility_http(self): + def test_url_compatibility_http(self) -> None: c = APIClient( base_url="http://hostname:1234", version=DEFAULT_DOCKER_API_VERSION ) assert c.base_url == "http://hostname:1234" - def test_url_compatibility_tcp(self): + def test_url_compatibility_tcp(self) -> None: c = 
APIClient( base_url="tcp://hostname:1234", version=DEFAULT_DOCKER_API_VERSION ) assert c.base_url == "http://hostname:1234" - def test_remove_link(self): + def test_remove_link(self) -> None: self.client.delete_call( "/containers/{0}", "3cc2351ab11b", @@ -291,7 +316,7 @@ class DockerApiTest(BaseAPIClientTest): timeout=DEFAULT_TIMEOUT_SECONDS, ) - def test_stream_helper_decoding(self): + def test_stream_helper_decoding(self) -> None: status_code, content = fake_api.fake_responses[url_prefix + "events"]() content_str = json.dumps(content).encode("utf-8") body = io.BytesIO(content_str) @@ -318,7 +343,7 @@ class DockerApiTest(BaseAPIClientTest): raw_resp._fp.seek(0) resp = response(status_code=status_code, content=content, raw=raw_resp) result = next(self.client._stream_helper(resp)) - assert result == content_str.decode("utf-8") + assert result == content_str.decode("utf-8") # type: ignore # non-chunked response, pass `decode=True` to the helper raw_resp._fp.seek(0) @@ -328,7 +353,7 @@ class DockerApiTest(BaseAPIClientTest): class UnixSocketStreamTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: socket_dir = tempfile.mkdtemp() self.build_context = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, socket_dir) @@ -339,23 +364,23 @@ class UnixSocketStreamTest(unittest.TestCase): server_thread = threading.Thread(target=self.run_server) server_thread.daemon = True server_thread.start() - self.response = None - self.request_handler = None + self.response: t.Any = None + self.request_handler: t.Any = None self.addCleanup(server_thread.join) self.addCleanup(self.stop) - def stop(self): + def stop(self) -> None: self.stop_server = True - def _setup_socket(self): + def _setup_socket(self) -> socket.socket: server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) server_sock.bind(self.socket_file) # Non-blocking mode so that we can shut the test down easily - server_sock.setblocking(0) + server_sock.setblocking(0) # type: ignore server_sock.listen(5) 
return server_sock - def run_server(self): + def run_server(self) -> None: try: while not self.stop_server: try: @@ -365,7 +390,7 @@ class UnixSocketStreamTest(unittest.TestCase): time.sleep(0.01) continue - connection.setblocking(1) + connection.setblocking(1) # type: ignore try: self.request_handler(connection) finally: @@ -373,7 +398,7 @@ class UnixSocketStreamTest(unittest.TestCase): finally: self.server_socket.close() - def early_response_sending_handler(self, connection): + def early_response_sending_handler(self, connection: socket.socket) -> None: data = b"" headers = None @@ -395,7 +420,7 @@ class UnixSocketStreamTest(unittest.TestCase): data += connection.recv(2048) @pytest.mark.skipif(constants.IS_WINDOWS_PLATFORM, reason="Unix only") - def test_early_stream_response(self): + def test_early_stream_response(self) -> None: self.request_handler = self.early_response_sending_handler lines = [] for i in range(0, 50): @@ -405,7 +430,7 @@ class UnixSocketStreamTest(unittest.TestCase): lines.append(b"") self.response = ( - b"HTTP/1.1 200 OK\r\n" b"Transfer-Encoding: chunked\r\n" b"\r\n" + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" ) + b"\r\n".join(lines) with APIClient( @@ -459,8 +484,12 @@ class TCPSocketStreamTest(unittest.TestCase): built on these islands for generations past? Now shall what of Him? 
""" + server: ThreadingTCPServer + thread: threading.Thread + address: str + @classmethod - def setup_class(cls): + def setup_class(cls) -> None: cls.server = ThreadingTCPServer(("", 0), cls.get_handler_class()) cls.thread = threading.Thread(target=cls.server.serve_forever) cls.thread.daemon = True @@ -468,18 +497,18 @@ class TCPSocketStreamTest(unittest.TestCase): cls.address = f"http://{socket.gethostname()}:{cls.server.server_address[1]}" @classmethod - def teardown_class(cls): + def teardown_class(cls) -> None: cls.server.shutdown() cls.server.server_close() cls.thread.join() @classmethod - def get_handler_class(cls): + def get_handler_class(cls) -> t.Type[BaseHTTPRequestHandler]: stdout_data = cls.stdout_data stderr_data = cls.stderr_data class Handler(BaseHTTPRequestHandler): - def do_POST(self): # pylint: disable=invalid-name + def do_POST(self) -> None: # pylint: disable=invalid-name resp_data = self.get_resp_data() self.send_response(101) self.send_header("Content-Type", "application/vnd.docker.raw-stream") @@ -491,7 +520,7 @@ class TCPSocketStreamTest(unittest.TestCase): self.wfile.write(resp_data) self.wfile.flush() - def get_resp_data(self): + def get_resp_data(self) -> bytes: path = self.path.split("/")[-1] if path == "tty": return stdout_data + stderr_data @@ -505,12 +534,17 @@ class TCPSocketStreamTest(unittest.TestCase): raise NotImplementedError(f"Unknown path {path}") @staticmethod - def frame_header(stream, data): + def frame_header(stream: int, data: bytes) -> bytes: return struct.pack(">BxxxL", stream, len(data)) return Handler - def request(self, stream=None, tty=None, demux=None): + def request( + self, + stream: bool | None = None, + tty: bool | None = None, + demux: bool | None = None, + ) -> t.Any: assert stream is not None and tty is not None and demux is not None with APIClient( base_url=self.address, @@ -523,51 +557,51 @@ class TCPSocketStreamTest(unittest.TestCase): resp = client._post(url, stream=True) return 
client._read_from_socket(resp, stream=stream, tty=tty, demux=demux) - def test_read_from_socket_tty(self): + def test_read_from_socket_tty(self) -> None: res = self.request(stream=True, tty=True, demux=False) assert next(res) == self.stdout_data + self.stderr_data with self.assertRaises(StopIteration): next(res) - def test_read_from_socket_tty_demux(self): + def test_read_from_socket_tty_demux(self) -> None: res = self.request(stream=True, tty=True, demux=True) assert next(res) == (self.stdout_data + self.stderr_data, None) with self.assertRaises(StopIteration): next(res) - def test_read_from_socket_no_tty(self): + def test_read_from_socket_no_tty(self) -> None: res = self.request(stream=True, tty=False, demux=False) assert next(res) == self.stdout_data assert next(res) == self.stderr_data with self.assertRaises(StopIteration): next(res) - def test_read_from_socket_no_tty_demux(self): + def test_read_from_socket_no_tty_demux(self) -> None: res = self.request(stream=True, tty=False, demux=True) assert (self.stdout_data, None) == next(res) assert (None, self.stderr_data) == next(res) with self.assertRaises(StopIteration): next(res) - def test_read_from_socket_no_stream_tty(self): + def test_read_from_socket_no_stream_tty(self) -> None: res = self.request(stream=False, tty=True, demux=False) assert res == self.stdout_data + self.stderr_data - def test_read_from_socket_no_stream_tty_demux(self): + def test_read_from_socket_no_stream_tty_demux(self) -> None: res = self.request(stream=False, tty=True, demux=True) assert res == (self.stdout_data + self.stderr_data, None) - def test_read_from_socket_no_stream_no_tty(self): + def test_read_from_socket_no_stream_no_tty(self) -> None: res = self.request(stream=False, tty=False, demux=False) assert res == self.stdout_data + self.stderr_data - def test_read_from_socket_no_stream_no_tty_demux(self): + def test_read_from_socket_no_stream_no_tty_demux(self) -> None: res = self.request(stream=False, tty=False, demux=True) assert 
res == (self.stdout_data, self.stderr_data) class UserAgentTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.patcher = mock.patch.object( APIClient, "send", @@ -575,10 +609,10 @@ class UserAgentTest(unittest.TestCase): ) self.mock_send = self.patcher.start() - def tearDown(self): + def tearDown(self) -> None: self.patcher.stop() - def test_default_user_agent(self): + def test_default_user_agent(self) -> None: client = APIClient(version=DEFAULT_DOCKER_API_VERSION) client.version() @@ -587,7 +621,7 @@ class UserAgentTest(unittest.TestCase): expected = "ansible-community.docker" assert headers["User-Agent"] == expected - def test_custom_user_agent(self): + def test_custom_user_agent(self) -> None: client = APIClient(user_agent="foo/bar", version=DEFAULT_DOCKER_API_VERSION) client.version() @@ -598,44 +632,44 @@ class UserAgentTest(unittest.TestCase): class DisableSocketTest(unittest.TestCase): class DummySocket: - def __init__(self, timeout=60): + def __init__(self, timeout: int | float | None = 60) -> None: self.timeout = timeout - self._sock = None + self._sock: t.Any = None - def settimeout(self, timeout): + def settimeout(self, timeout: int | float | None) -> None: self.timeout = timeout - def gettimeout(self): + def gettimeout(self) -> int | float | None: return self.timeout - def setUp(self): + def setUp(self) -> None: self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION) - def test_disable_socket_timeout(self): + def test_disable_socket_timeout(self) -> None: """Test that the timeout is disabled on a generic socket object.""" the_socket = self.DummySocket() - self.client._disable_socket_timeout(the_socket) + self.client._disable_socket_timeout(the_socket) # type: ignore assert the_socket.timeout is None - def test_disable_socket_timeout2(self): + def test_disable_socket_timeout2(self) -> None: """Test that the timeouts are disabled on a generic socket object and it's _sock object if present.""" the_socket = self.DummySocket() - 
the_socket._sock = self.DummySocket() + the_socket._sock = self.DummySocket() # type: ignore - self.client._disable_socket_timeout(the_socket) + self.client._disable_socket_timeout(the_socket) # type: ignore assert the_socket.timeout is None assert the_socket._sock.timeout is None - def test_disable_socket_timout_non_blocking(self): + def test_disable_socket_timout_non_blocking(self) -> None: """Test that a non-blocking socket does not get set to blocking.""" the_socket = self.DummySocket() - the_socket._sock = self.DummySocket(0.0) + the_socket._sock = self.DummySocket(0.0) # type: ignore - self.client._disable_socket_timeout(the_socket) + self.client._disable_socket_timeout(the_socket) # type: ignore assert the_socket.timeout is None assert the_socket._sock.timeout == 0.0 diff --git a/tests/unit/plugins/module_utils/_api/fake_api.py b/tests/unit/plugins/module_utils/_api/fake_api.py index 41c5ed4c..006f97fc 100644 --- a/tests/unit/plugins/module_utils/_api/fake_api.py +++ b/tests/unit/plugins/module_utils/_api/fake_api.py @@ -8,6 +8,8 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.docker.plugins.module_utils._api import constants from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.constants import ( DEFAULT_DOCKER_API_VERSION, @@ -16,6 +18,10 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c from . 
import fake_stat +if t.TYPE_CHECKING: + from collections.abc import Callable + + CURRENT_VERSION = f"v{DEFAULT_DOCKER_API_VERSION}" FAKE_CONTAINER_ID = "3cc2351ab11b" @@ -38,7 +44,7 @@ FAKE_SECRET_NAME = "super_secret" # for clarity and readability -def get_fake_version(): +def get_fake_version() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = { "ApiVersion": "1.35", @@ -73,7 +79,7 @@ def get_fake_version(): return status_code, response -def get_fake_info(): +def get_fake_info() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = { "Containers": 1, @@ -86,23 +92,23 @@ def get_fake_info(): return status_code, response -def post_fake_auth(): +def post_fake_auth() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Status": "Login Succeeded", "IdentityToken": "9cbaf023786cd7"} return status_code, response -def get_fake_ping(): +def get_fake_ping() -> tuple[int, str]: return 200, "OK" -def get_fake_search(): +def get_fake_search() -> tuple[int, list[dict[str, t.Any]]]: status_code = 200 response = [{"Name": "busybox", "Description": "Fake Description"}] return status_code, response -def get_fake_images(): +def get_fake_images() -> tuple[int, list[dict[str, t.Any]]]: status_code = 200 response = [ { @@ -115,7 +121,7 @@ def get_fake_images(): return status_code, response -def get_fake_image_history(): +def get_fake_image_history() -> tuple[int, list[dict[str, t.Any]]]: status_code = 200 response = [ {"Id": "b750fe79269d", "Created": 1364102658, "CreatedBy": "/bin/bash"}, @@ -125,14 +131,14 @@ def get_fake_image_history(): return status_code, response -def post_fake_import_image(): +def post_fake_import_image() -> tuple[int, str]: status_code = 200 response = "Import messages..." 
return status_code, response -def get_fake_containers(): +def get_fake_containers() -> tuple[int, list[dict[str, t.Any]]]: status_code = 200 response = [ { @@ -146,25 +152,25 @@ def get_fake_containers(): return status_code, response -def post_fake_start_container(): +def post_fake_start_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_resize_container(): +def post_fake_resize_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_create_container(): +def post_fake_create_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def get_fake_inspect_container(tty=False): +def get_fake_inspect_container(tty: bool = False) -> tuple[int, dict[str, t.Any]]: status_code = 200 response = { "Id": FAKE_CONTAINER_ID, @@ -188,7 +194,7 @@ def get_fake_inspect_container(tty=False): return status_code, response -def get_fake_inspect_image(): +def get_fake_inspect_image() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = { "Id": FAKE_IMAGE_ID, @@ -221,19 +227,19 @@ def get_fake_inspect_image(): return status_code, response -def get_fake_insert_image(): +def get_fake_insert_image() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"StatusCode": 0} return status_code, response -def get_fake_wait(): +def get_fake_wait() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"StatusCode": 0} return status_code, response -def get_fake_logs(): +def get_fake_logs() -> tuple[int, bytes]: status_code = 200 response = ( b"\x01\x00\x00\x00\x00\x00\x00\x00" @@ -244,13 +250,13 @@ def get_fake_logs(): return status_code, response -def get_fake_diff(): +def get_fake_diff() -> tuple[int, list[dict[str, t.Any]]]: status_code = 200 response = [{"Path": "/test", "Kind": 1}] return status_code, response -def 
get_fake_events(): +def get_fake_events() -> tuple[int, list[dict[str, t.Any]]]: status_code = 200 response = [ { @@ -263,19 +269,19 @@ def get_fake_events(): return status_code, response -def get_fake_export(): +def get_fake_export() -> tuple[int, str]: status_code = 200 response = "Byte Stream...." return status_code, response -def post_fake_exec_create(): +def post_fake_exec_create() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_EXEC_ID} return status_code, response -def post_fake_exec_start(): +def post_fake_exec_start() -> tuple[int, bytes]: status_code = 200 response = ( b"\x01\x00\x00\x00\x00\x00\x00\x11bin\nboot\ndev\netc\n" @@ -285,12 +291,12 @@ def post_fake_exec_start(): return status_code, response -def post_fake_exec_resize(): +def post_fake_exec_resize() -> tuple[int, str]: status_code = 201 return status_code, "" -def get_fake_exec_inspect(): +def get_fake_exec_inspect() -> tuple[int, dict[str, t.Any]]: return 200, { "OpenStderr": True, "OpenStdout": True, @@ -309,102 +315,102 @@ def get_fake_exec_inspect(): } -def post_fake_stop_container(): +def post_fake_stop_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_kill_container(): +def post_fake_kill_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_pause_container(): +def post_fake_pause_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_unpause_container(): +def post_fake_unpause_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_restart_container(): +def post_fake_restart_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def 
post_fake_rename_container(): +def post_fake_rename_container() -> tuple[int, None]: status_code = 204 return status_code, None -def delete_fake_remove_container(): +def delete_fake_remove_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_image_create(): +def post_fake_image_create() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_IMAGE_ID} return status_code, response -def delete_fake_remove_image(): +def delete_fake_remove_image() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_IMAGE_ID} return status_code, response -def get_fake_get_image(): +def get_fake_get_image() -> tuple[int, str]: status_code = 200 response = "Byte Stream...." return status_code, response -def post_fake_load_image(): +def post_fake_load_image() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_IMAGE_ID} return status_code, response -def post_fake_commit(): +def post_fake_commit() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_push(): +def post_fake_push() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_IMAGE_ID} return status_code, response -def post_fake_build_container(): +def post_fake_build_container() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_CONTAINER_ID} return status_code, response -def post_fake_tag_image(): +def post_fake_tag_image() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"Id": FAKE_IMAGE_ID} return status_code, response -def get_fake_stats(): +def get_fake_stats() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = fake_stat.OBJ return status_code, response -def get_fake_top(): +def get_fake_top() -> tuple[int, dict[str, t.Any]]: return 200, { "Processes": [ [ @@ -431,7 +437,7 @@ def get_fake_top(): } -def get_fake_volume_list(): +def 
get_fake_volume_list() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = { "Volumes": [ @@ -452,7 +458,7 @@ def get_fake_volume_list(): return status_code, response -def get_fake_volume(): +def get_fake_volume() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = { "Name": "perfectcherryblossom", @@ -464,23 +470,23 @@ def get_fake_volume(): return status_code, response -def fake_remove_volume(): +def fake_remove_volume() -> tuple[int, None]: return 204, None -def post_fake_update_container(): +def post_fake_update_container() -> tuple[int, dict[str, t.Any]]: return 200, {"Warnings": []} -def post_fake_update_node(): +def post_fake_update_node() -> tuple[int, None]: return 200, None -def post_fake_join_swarm(): +def post_fake_join_swarm() -> tuple[int, None]: return 200, None -def get_fake_network_list(): +def get_fake_network_list() -> tuple[int, list[dict[str, t.Any]]]: return 200, [ { "Name": "bridge", @@ -510,27 +516,27 @@ def get_fake_network_list(): ] -def get_fake_network(): +def get_fake_network() -> tuple[int, dict[str, t.Any]]: return 200, get_fake_network_list()[1][0] -def post_fake_network(): +def post_fake_network() -> tuple[int, dict[str, t.Any]]: return 201, {"Id": FAKE_NETWORK_ID, "Warnings": []} -def delete_fake_network(): +def delete_fake_network() -> tuple[int, None]: return 204, None -def post_fake_network_connect(): +def post_fake_network_connect() -> tuple[int, None]: return 200, None -def post_fake_network_disconnect(): +def post_fake_network_disconnect() -> tuple[int, None]: return 200, None -def post_fake_secret(): +def post_fake_secret() -> tuple[int, dict[str, t.Any]]: status_code = 200 response = {"ID": FAKE_SECRET_ID} return status_code, response @@ -541,7 +547,7 @@ prefix = "http+docker://localhost" # pylint: disable=invalid-name if constants.IS_WINDOWS_PLATFORM: prefix = "http+docker://localnpipe" # pylint: disable=invalid-name -fake_responses = { +fake_responses: dict[str | tuple[str, str], Callable] = { 
f"{prefix}/version": get_fake_version, f"{prefix}/{CURRENT_VERSION}/version": get_fake_version, f"{prefix}/{CURRENT_VERSION}/info": get_fake_info, @@ -574,6 +580,7 @@ fake_responses = { f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/unpause": post_fake_unpause_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/restart": post_fake_restart_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b": delete_fake_remove_container, + # TODO: the following is a duplicate of the import endpoint further above! f"{prefix}/{CURRENT_VERSION}/images/create": post_fake_image_create, f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128": delete_fake_remove_image, f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128/get": get_fake_get_image, diff --git a/tests/unit/plugins/module_utils/_api/test_auth.py b/tests/unit/plugins/module_utils/_api/test_auth.py index f2fbe4c2..38bba014 100644 --- a/tests/unit/plugins/module_utils/_api/test_auth.py +++ b/tests/unit/plugins/module_utils/_api/test_auth.py @@ -15,6 +15,7 @@ import os.path import random import shutil import tempfile +import typing as t import unittest from unittest import mock @@ -30,7 +31,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.credentials. 
class RegressionTest(unittest.TestCase): - def test_803_urlsafe_encode(self): + def test_803_urlsafe_encode(self) -> None: auth_data = {"username": "root", "password": "GR?XGR?XGR?XGR?X"} encoded = auth.encode_header(auth_data) assert b"/" not in encoded @@ -38,75 +39,75 @@ class RegressionTest(unittest.TestCase): class ResolveRepositoryNameTest(unittest.TestCase): - def test_resolve_repository_name_hub_library_image(self): + def test_resolve_repository_name_hub_library_image(self) -> None: assert auth.resolve_repository_name("image") == ("docker.io", "image") - def test_resolve_repository_name_dotted_hub_library_image(self): + def test_resolve_repository_name_dotted_hub_library_image(self) -> None: assert auth.resolve_repository_name("image.valid") == ( "docker.io", "image.valid", ) - def test_resolve_repository_name_hub_image(self): + def test_resolve_repository_name_hub_image(self) -> None: assert auth.resolve_repository_name("username/image") == ( "docker.io", "username/image", ) - def test_explicit_hub_index_library_image(self): + def test_explicit_hub_index_library_image(self) -> None: assert auth.resolve_repository_name("docker.io/image") == ("docker.io", "image") - def test_explicit_legacy_hub_index_library_image(self): + def test_explicit_legacy_hub_index_library_image(self) -> None: assert auth.resolve_repository_name("index.docker.io/image") == ( "docker.io", "image", ) - def test_resolve_repository_name_private_registry(self): + def test_resolve_repository_name_private_registry(self) -> None: assert auth.resolve_repository_name("my.registry.net/image") == ( "my.registry.net", "image", ) - def test_resolve_repository_name_private_registry_with_port(self): + def test_resolve_repository_name_private_registry_with_port(self) -> None: assert auth.resolve_repository_name("my.registry.net:5000/image") == ( "my.registry.net:5000", "image", ) - def test_resolve_repository_name_private_registry_with_username(self): + def 
test_resolve_repository_name_private_registry_with_username(self) -> None: assert auth.resolve_repository_name("my.registry.net/username/image") == ( "my.registry.net", "username/image", ) - def test_resolve_repository_name_no_dots_but_port(self): + def test_resolve_repository_name_no_dots_but_port(self) -> None: assert auth.resolve_repository_name("hostname:5000/image") == ( "hostname:5000", "image", ) - def test_resolve_repository_name_no_dots_but_port_and_username(self): + def test_resolve_repository_name_no_dots_but_port_and_username(self) -> None: assert auth.resolve_repository_name("hostname:5000/username/image") == ( "hostname:5000", "username/image", ) - def test_resolve_repository_name_localhost(self): + def test_resolve_repository_name_localhost(self) -> None: assert auth.resolve_repository_name("localhost/image") == ("localhost", "image") - def test_resolve_repository_name_localhost_with_username(self): + def test_resolve_repository_name_localhost_with_username(self) -> None: assert auth.resolve_repository_name("localhost/username/image") == ( "localhost", "username/image", ) - def test_invalid_index_name(self): + def test_invalid_index_name(self) -> None: with pytest.raises(errors.InvalidRepository): auth.resolve_repository_name("-gecko.com/image") -def encode_auth(auth_info): +def encode_auth(auth_info: dict[str, t.Any]) -> bytes: return base64.b64encode( auth_info.get("username", "").encode("utf-8") + b":" @@ -131,129 +132,105 @@ class ResolveAuthTest(unittest.TestCase): } ) - def test_resolve_authconfig_hostname_only(self): - assert ( - auth.resolve_authconfig(self.auth_config, "my.registry.net")["username"] - == "privateuser" - ) + def test_resolve_authconfig_hostname_only(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "my.registry.net") + assert ac is not None + assert ac["username"] == "privateuser" - def test_resolve_authconfig_no_protocol(self): - assert ( - auth.resolve_authconfig(self.auth_config, 
"my.registry.net/v1/")["username"] - == "privateuser" - ) + def test_resolve_authconfig_no_protocol(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/") + assert ac is not None + assert ac["username"] == "privateuser" - def test_resolve_authconfig_no_path(self): - assert ( - auth.resolve_authconfig(self.auth_config, "http://my.registry.net")[ - "username" - ] - == "privateuser" - ) + def test_resolve_authconfig_no_path(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net") + assert ac is not None + assert ac["username"] == "privateuser" - def test_resolve_authconfig_no_path_trailing_slash(self): - assert ( - auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")[ - "username" - ] - == "privateuser" - ) + def test_resolve_authconfig_no_path_trailing_slash(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/") + assert ac is not None + assert ac["username"] == "privateuser" - def test_resolve_authconfig_no_path_wrong_secure_proto(self): - assert ( - auth.resolve_authconfig(self.auth_config, "https://my.registry.net")[ - "username" - ] - == "privateuser" - ) + def test_resolve_authconfig_no_path_wrong_secure_proto(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net") + assert ac is not None + assert ac["username"] == "privateuser" - def test_resolve_authconfig_no_path_wrong_insecure_proto(self): - assert ( - auth.resolve_authconfig(self.auth_config, "http://index.docker.io")[ - "username" - ] - == "indexuser" - ) + def test_resolve_authconfig_no_path_wrong_insecure_proto(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "http://index.docker.io") + assert ac is not None + assert ac["username"] == "indexuser" - def test_resolve_authconfig_path_wrong_proto(self): - assert ( - auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")[ - "username" - ] - == "privateuser" - ) + def 
test_resolve_authconfig_path_wrong_proto(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/") + assert ac is not None + assert ac["username"] == "privateuser" - def test_resolve_authconfig_default_registry(self): - assert auth.resolve_authconfig(self.auth_config)["username"] == "indexuser" + def test_resolve_authconfig_default_registry(self) -> None: + ac = auth.resolve_authconfig(self.auth_config) + assert ac is not None + assert ac["username"] == "indexuser" - def test_resolve_authconfig_default_explicit_none(self): - assert ( - auth.resolve_authconfig(self.auth_config, None)["username"] == "indexuser" - ) + def test_resolve_authconfig_default_explicit_none(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, None) + assert ac is not None + assert ac["username"] == "indexuser" - def test_resolve_authconfig_fully_explicit(self): - assert ( - auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")[ - "username" - ] - == "privateuser" - ) + def test_resolve_authconfig_fully_explicit(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/") + assert ac is not None + assert ac["username"] == "privateuser" - def test_resolve_authconfig_legacy_config(self): - assert ( - auth.resolve_authconfig(self.auth_config, "legacy.registry.url")["username"] - == "legacyauth" - ) + def test_resolve_authconfig_legacy_config(self) -> None: + ac = auth.resolve_authconfig(self.auth_config, "legacy.registry.url") + assert ac is not None + assert ac["username"] == "legacyauth" - def test_resolve_authconfig_no_match(self): + def test_resolve_authconfig_no_match(self) -> None: assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None - def test_resolve_registry_and_auth_library_image(self): + def test_resolve_registry_and_auth_library_image(self) -> None: image = "image" - assert ( - auth.resolve_authconfig( - self.auth_config, auth.resolve_repository_name(image)[0] - 
)["username"] - == "indexuser" + ac = auth.resolve_authconfig( + self.auth_config, auth.resolve_repository_name(image)[0] ) + assert ac is not None + assert ac["username"] == "indexuser" - def test_resolve_registry_and_auth_hub_image(self): + def test_resolve_registry_and_auth_hub_image(self) -> None: image = "username/image" - assert ( - auth.resolve_authconfig( - self.auth_config, auth.resolve_repository_name(image)[0] - )["username"] - == "indexuser" + ac = auth.resolve_authconfig( + self.auth_config, auth.resolve_repository_name(image)[0] ) + assert ac is not None + assert ac["username"] == "indexuser" - def test_resolve_registry_and_auth_explicit_hub(self): + def test_resolve_registry_and_auth_explicit_hub(self) -> None: image = "docker.io/username/image" - assert ( - auth.resolve_authconfig( - self.auth_config, auth.resolve_repository_name(image)[0] - )["username"] - == "indexuser" + ac = auth.resolve_authconfig( + self.auth_config, auth.resolve_repository_name(image)[0] ) + assert ac is not None + assert ac["username"] == "indexuser" - def test_resolve_registry_and_auth_explicit_legacy_hub(self): + def test_resolve_registry_and_auth_explicit_legacy_hub(self) -> None: image = "index.docker.io/username/image" - assert ( - auth.resolve_authconfig( - self.auth_config, auth.resolve_repository_name(image)[0] - )["username"] - == "indexuser" + ac = auth.resolve_authconfig( + self.auth_config, auth.resolve_repository_name(image)[0] ) + assert ac is not None + assert ac["username"] == "indexuser" - def test_resolve_registry_and_auth_private_registry(self): + def test_resolve_registry_and_auth_private_registry(self) -> None: image = "my.registry.net/image" - assert ( - auth.resolve_authconfig( - self.auth_config, auth.resolve_repository_name(image)[0] - )["username"] - == "privateuser" + ac = auth.resolve_authconfig( + self.auth_config, auth.resolve_repository_name(image)[0] ) + assert ac is not None + assert ac["username"] == "privateuser" - def 
test_resolve_registry_and_auth_unauthenticated_registry(self): + def test_resolve_registry_and_auth_unauthenticated_registry(self) -> None: image = "other.registry.net/image" assert ( auth.resolve_authconfig( @@ -262,7 +239,7 @@ class ResolveAuthTest(unittest.TestCase): is None ) - def test_resolve_auth_with_empty_credstore_and_auth_dict(self): + def test_resolve_auth_with_empty_credstore_and_auth_dict(self) -> None: auth_config = auth.AuthConfig( { "auths": auth.parse_auth( @@ -277,17 +254,19 @@ class ResolveAuthTest(unittest.TestCase): "ansible_collections.community.docker.plugins.module_utils._api.auth.AuthConfig._resolve_authconfig_credstore" ) as m: m.return_value = None - assert "indexuser" == auth.resolve_authconfig(auth_config, None)["username"] + ac = auth.resolve_authconfig(auth_config, None) + assert ac is not None + assert "indexuser" == ac["username"] class LoadConfigTest(unittest.TestCase): - def test_load_config_no_file(self): + def test_load_config_no_file(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) cfg = auth.load_config(folder) assert cfg is not None - def test_load_legacy_config(self): + def test_load_legacy_config(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) cfg_path = os.path.join(folder, ".dockercfg") @@ -299,13 +278,13 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(cfg_path) assert auth.resolve_authconfig(cfg) is not None assert cfg.auths[auth.INDEX_NAME] is not None - cfg = cfg.auths[auth.INDEX_NAME] - assert cfg["username"] == "sakuya" - assert cfg["password"] == "izayoi" - assert cfg["email"] == "sakuya@scarlet.net" - assert cfg.get("Auth") is None + cfg2 = cfg.auths[auth.INDEX_NAME] + assert cfg2["username"] == "sakuya" + assert cfg2["password"] == "izayoi" + assert cfg2["email"] == "sakuya@scarlet.net" + assert cfg2.get("Auth") is None - def test_load_json_config(self): + def test_load_json_config(self) -> None: folder = tempfile.mkdtemp() 
self.addCleanup(shutil.rmtree, folder) cfg_path = os.path.join(folder, ".dockercfg") @@ -316,13 +295,13 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(cfg_path) assert auth.resolve_authconfig(cfg) is not None assert cfg.auths[auth.INDEX_URL] is not None - cfg = cfg.auths[auth.INDEX_URL] - assert cfg["username"] == "sakuya" - assert cfg["password"] == "izayoi" - assert cfg["email"] == email - assert cfg.get("Auth") is None + cfg2 = cfg.auths[auth.INDEX_URL] + assert cfg2["username"] == "sakuya" + assert cfg2["password"] == "izayoi" + assert cfg2["email"] == email + assert cfg2.get("Auth") is None - def test_load_modern_json_config(self): + def test_load_modern_json_config(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) cfg_path = os.path.join(folder, "config.json") @@ -333,12 +312,12 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(cfg_path) assert auth.resolve_authconfig(cfg) is not None assert cfg.auths[auth.INDEX_URL] is not None - cfg = cfg.auths[auth.INDEX_URL] - assert cfg["username"] == "sakuya" - assert cfg["password"] == "izayoi" - assert cfg["email"] == email + cfg2 = cfg.auths[auth.INDEX_URL] + assert cfg2["username"] == "sakuya" + assert cfg2["password"] == "izayoi" + assert cfg2["email"] == email - def test_load_config_with_random_name(self): + def test_load_config_with_random_name(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) @@ -353,13 +332,13 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(dockercfg_path).auths assert registry in cfg assert cfg[registry] is not None - cfg = cfg[registry] - assert cfg["username"] == "sakuya" - assert cfg["password"] == "izayoi" - assert cfg["email"] == "sakuya@scarlet.net" - assert cfg.get("auth") is None + cfg2 = cfg[registry] + assert cfg2["username"] == "sakuya" + assert cfg2["password"] == "izayoi" + assert cfg2["email"] == "sakuya@scarlet.net" + assert cfg2.get("auth") is None - def 
test_load_config_custom_config_env(self): + def test_load_config_custom_config_env(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) @@ -375,13 +354,13 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(None).auths assert registry in cfg assert cfg[registry] is not None - cfg = cfg[registry] - assert cfg["username"] == "sakuya" - assert cfg["password"] == "izayoi" - assert cfg["email"] == "sakuya@scarlet.net" - assert cfg.get("auth") is None + cfg2 = cfg[registry] + assert cfg2["username"] == "sakuya" + assert cfg2["password"] == "izayoi" + assert cfg2["email"] == "sakuya@scarlet.net" + assert cfg2.get("auth") is None - def test_load_config_custom_config_env_with_auths(self): + def test_load_config_custom_config_env_with_auths(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) @@ -398,13 +377,13 @@ class LoadConfigTest(unittest.TestCase): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}): cfg = auth.load_config(None) assert registry in cfg.auths - cfg = cfg.auths[registry] - assert cfg["username"] == "sakuya" - assert cfg["password"] == "izayoi" - assert cfg["email"] == "sakuya@scarlet.net" - assert cfg.get("auth") is None + cfg2 = cfg.auths[registry] + assert cfg2["username"] == "sakuya" + assert cfg2["password"] == "izayoi" + assert cfg2["email"] == "sakuya@scarlet.net" + assert cfg2.get("auth") is None - def test_load_config_custom_config_env_utf8(self): + def test_load_config_custom_config_env_utf8(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) @@ -421,13 +400,13 @@ class LoadConfigTest(unittest.TestCase): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}): cfg = auth.load_config(None) assert registry in cfg.auths - cfg = cfg.auths[registry] - assert cfg["username"] == b"sakuya\xc3\xa6".decode("utf8") - assert cfg["password"] == b"izayoi\xc3\xa6".decode("utf8") - assert cfg["email"] == "sakuya@scarlet.net" - assert cfg.get("auth") 
is None + cfg2 = cfg.auths[registry] + assert cfg2["username"] == b"sakuya\xc3\xa6".decode("utf8") + assert cfg2["password"] == b"izayoi\xc3\xa6".decode("utf8") + assert cfg2["email"] == "sakuya@scarlet.net" + assert cfg2.get("auth") is None - def test_load_config_unknown_keys(self): + def test_load_config_unknown_keys(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) dockercfg_path = os.path.join(folder, "config.json") @@ -438,7 +417,7 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(dockercfg_path) assert dict(cfg) == {"auths": {}} - def test_load_config_invalid_auth_dict(self): + def test_load_config_invalid_auth_dict(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) dockercfg_path = os.path.join(folder, "config.json") @@ -449,7 +428,7 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(dockercfg_path) assert dict(cfg) == {"auths": {"scarlet.net": {}}} - def test_load_config_identity_token(self): + def test_load_config_identity_token(self) -> None: folder = tempfile.mkdtemp() registry = "scarlet.net" token = "1ce1cebb-503e-7043-11aa-7feb8bd4a1ce" @@ -462,13 +441,13 @@ class LoadConfigTest(unittest.TestCase): cfg = auth.load_config(dockercfg_path) assert registry in cfg.auths - cfg = cfg.auths[registry] - assert "IdentityToken" in cfg - assert cfg["IdentityToken"] == token + cfg2 = cfg.auths[registry] + assert "IdentityToken" in cfg2 + assert cfg2["IdentityToken"] == token class CredstoreTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.authconfig = auth.AuthConfig({"credsStore": "default"}) self.default_store = InMemoryStore("default") self.authconfig._stores["default"] = self.default_store @@ -483,7 +462,7 @@ class CredstoreTest(unittest.TestCase): "hunter2", ) - def test_get_credential_store(self): + def test_get_credential_store(self) -> None: auth_config = auth.AuthConfig( { "credHelpers": { @@ -498,7 +477,7 @@ class 
CredstoreTest(unittest.TestCase): assert auth_config.get_credential_store("registry2.io") == "powerlock" assert auth_config.get_credential_store("registry3.io") == "blackbox" - def test_get_credential_store_no_default(self): + def test_get_credential_store_no_default(self) -> None: auth_config = auth.AuthConfig( { "credHelpers": { @@ -510,7 +489,7 @@ class CredstoreTest(unittest.TestCase): assert auth_config.get_credential_store("registry2.io") == "powerlock" assert auth_config.get_credential_store("registry3.io") is None - def test_get_credential_store_default_index(self): + def test_get_credential_store_default_index(self) -> None: auth_config = auth.AuthConfig( { "credHelpers": {"https://index.docker.io/v1/": "powerlock"}, @@ -522,7 +501,7 @@ class CredstoreTest(unittest.TestCase): assert auth_config.get_credential_store("docker.io") == "powerlock" assert auth_config.get_credential_store("images.io") == "truesecret" - def test_get_credential_store_with_plain_dict(self): + def test_get_credential_store_with_plain_dict(self) -> None: auth_config = { "credHelpers": {"registry1.io": "truesecret", "registry2.io": "powerlock"}, "credsStore": "blackbox", @@ -532,7 +511,7 @@ class CredstoreTest(unittest.TestCase): assert auth.get_credential_store(auth_config, "registry2.io") == "powerlock" assert auth.get_credential_store(auth_config, "registry3.io") == "blackbox" - def test_get_all_credentials_credstore_only(self): + def test_get_all_credentials_credstore_only(self) -> None: assert self.authconfig.get_all_credentials() == { "https://gensokyo.jp/v2": { "Username": "sakuya", @@ -556,7 +535,7 @@ class CredstoreTest(unittest.TestCase): }, } - def test_get_all_credentials_with_empty_credhelper(self): + def test_get_all_credentials_with_empty_credhelper(self) -> None: self.authconfig["credHelpers"] = { "registry1.io": "truesecret", } @@ -585,7 +564,7 @@ class CredstoreTest(unittest.TestCase): "registry1.io": None, } - def test_get_all_credentials_with_credhelpers_only(self): 
+ def test_get_all_credentials_with_credhelpers_only(self) -> None: del self.authconfig["credsStore"] assert self.authconfig.get_all_credentials() == {} @@ -617,7 +596,7 @@ class CredstoreTest(unittest.TestCase): }, } - def test_get_all_credentials_with_auths_entries(self): + def test_get_all_credentials_with_auths_entries(self) -> None: self.authconfig.add_auth( "registry1.io", { @@ -655,7 +634,7 @@ class CredstoreTest(unittest.TestCase): }, } - def test_get_all_credentials_with_empty_auths_entry(self): + def test_get_all_credentials_with_empty_auths_entry(self) -> None: self.authconfig.add_auth("default.com", {}) assert self.authconfig.get_all_credentials() == { @@ -681,7 +660,7 @@ class CredstoreTest(unittest.TestCase): }, } - def test_get_all_credentials_credstore_overrides_auth_entry(self): + def test_get_all_credentials_credstore_overrides_auth_entry(self) -> None: self.authconfig.add_auth( "default.com", { @@ -714,7 +693,7 @@ class CredstoreTest(unittest.TestCase): }, } - def test_get_all_credentials_helpers_override_default(self): + def test_get_all_credentials_helpers_override_default(self) -> None: self.authconfig["credHelpers"] = { "https://default.com/v2": "truesecret", } @@ -744,7 +723,7 @@ class CredstoreTest(unittest.TestCase): }, } - def test_get_all_credentials_3_sources(self): + def test_get_all_credentials_3_sources(self) -> None: self.authconfig["credHelpers"] = { "registry1.io": "truesecret", } @@ -795,24 +774,27 @@ class CredstoreTest(unittest.TestCase): class InMemoryStore(Store): - def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called - self.__store = {} + def __init__( # pylint: disable=super-init-not-called + self, *args: t.Any, **kwargs: t.Any + ) -> None: + self.__store: dict[str | bytes, dict[str, t.Any]] = {} - def get(self, server): + def get(self, server: str | bytes) -> dict[str, t.Any]: try: return self.__store[server] except KeyError: raise CredentialsNotFound() from None - def store(self, server, username, 
secret): + def store(self, server: str, username: str, secret: str) -> bytes: self.__store[server] = { "ServerURL": server, "Username": username, "Secret": secret, } + return b"" - def list(self): + def list(self) -> dict[str | bytes, str]: return dict((k, v["Username"]) for k, v in self.__store.items()) - def erase(self, server): + def erase(self, server: str | bytes) -> None: del self.__store[server] diff --git a/tests/unit/plugins/module_utils/_api/test_context.py b/tests/unit/plugins/module_utils/_api/test_context.py index bc1cc13e..56262236 100644 --- a/tests/unit/plugins/module_utils/_api/test_context.py +++ b/tests/unit/plugins/module_utils/_api/test_context.py @@ -28,20 +28,20 @@ from ansible_collections.community.docker.plugins.module_utils._api.context.cont class BaseContextTest(unittest.TestCase): @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="Linux specific path check") - def test_url_compatibility_on_linux(self): + def test_url_compatibility_on_linux(self) -> None: c = Context("test") assert c.Host == DEFAULT_UNIX_SOCKET[5:] @pytest.mark.skipif(not IS_WINDOWS_PLATFORM, reason="Windows specific path check") - def test_url_compatibility_on_windows(self): + def test_url_compatibility_on_windows(self) -> None: c = Context("test") assert c.Host == DEFAULT_NPIPE - def test_fail_on_default_context_create(self): + def test_fail_on_default_context_create(self) -> None: with pytest.raises(errors.ContextException): ContextAPI.create_context("default") - def test_default_in_context_list(self): + def test_default_in_context_list(self) -> None: found = False ctx = ContextAPI.contexts() for c in ctx: @@ -49,14 +49,16 @@ class BaseContextTest(unittest.TestCase): found = True assert found is True - def test_get_current_context(self): - assert ContextAPI.get_current_context().Name == "default" + def test_get_current_context(self) -> None: + context = ContextAPI.get_current_context() + assert context is not None + assert context.Name == "default" - def 
test_https_host(self): + def test_https_host(self) -> None: c = Context("test", host="tcp://testdomain:8080", tls=True) assert c.Host == "https://testdomain:8080" - def test_context_inspect_without_params(self): + def test_context_inspect_without_params(self) -> None: ctx = ContextAPI.inspect_context() assert ctx["Name"] == "default" assert ctx["Metadata"]["StackOrchestrator"] == "swarm" diff --git a/tests/unit/plugins/module_utils/_api/test_errors.py b/tests/unit/plugins/module_utils/_api/test_errors.py index cb9afd17..0b16a700 100644 --- a/tests/unit/plugins/module_utils/_api/test_errors.py +++ b/tests/unit/plugins/module_utils/_api/test_errors.py @@ -21,97 +21,97 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor class APIErrorTest(unittest.TestCase): - def test_api_error_is_caught_by_dockerexception(self): + def test_api_error_is_caught_by_dockerexception(self) -> None: try: raise APIError("this should be caught by DockerException") except DockerException: pass - def test_status_code_200(self): + def test_status_code_200(self) -> None: """The status_code property is present with 200 response.""" resp = requests.Response() resp.status_code = 200 err = APIError("", response=resp) assert err.status_code == 200 - def test_status_code_400(self): + def test_status_code_400(self) -> None: """The status_code property is present with 400 response.""" resp = requests.Response() resp.status_code = 400 err = APIError("", response=resp) assert err.status_code == 400 - def test_status_code_500(self): + def test_status_code_500(self) -> None: """The status_code property is present with 500 response.""" resp = requests.Response() resp.status_code = 500 err = APIError("", response=resp) assert err.status_code == 500 - def test_is_server_error_200(self): + def test_is_server_error_200(self) -> None: """Report not server error on 200 response.""" resp = requests.Response() resp.status_code = 200 err = APIError("", response=resp) assert 
err.is_server_error() is False - def test_is_server_error_300(self): + def test_is_server_error_300(self) -> None: """Report not server error on 300 response.""" resp = requests.Response() resp.status_code = 300 err = APIError("", response=resp) assert err.is_server_error() is False - def test_is_server_error_400(self): + def test_is_server_error_400(self) -> None: """Report not server error on 400 response.""" resp = requests.Response() resp.status_code = 400 err = APIError("", response=resp) assert err.is_server_error() is False - def test_is_server_error_500(self): + def test_is_server_error_500(self) -> None: """Report server error on 500 response.""" resp = requests.Response() resp.status_code = 500 err = APIError("", response=resp) assert err.is_server_error() is True - def test_is_client_error_500(self): + def test_is_client_error_500(self) -> None: """Report not client error on 500 response.""" resp = requests.Response() resp.status_code = 500 err = APIError("", response=resp) assert err.is_client_error() is False - def test_is_client_error_400(self): + def test_is_client_error_400(self) -> None: """Report client error on 400 response.""" resp = requests.Response() resp.status_code = 400 err = APIError("", response=resp) assert err.is_client_error() is True - def test_is_error_300(self): + def test_is_error_300(self) -> None: """Report no error on 300 response.""" resp = requests.Response() resp.status_code = 300 err = APIError("", response=resp) assert err.is_error() is False - def test_is_error_400(self): + def test_is_error_400(self) -> None: """Report error on 400 response.""" resp = requests.Response() resp.status_code = 400 err = APIError("", response=resp) assert err.is_error() is True - def test_is_error_500(self): + def test_is_error_500(self) -> None: """Report error on 500 response.""" resp = requests.Response() resp.status_code = 500 err = APIError("", response=resp) assert err.is_error() is True - def test_create_error_from_exception(self): + 
def test_create_error_from_exception(self) -> None: resp = requests.Response() resp.status_code = 500 err = APIError("") @@ -126,10 +126,10 @@ class APIErrorTest(unittest.TestCase): class CreateUnexpectedKwargsErrorTest(unittest.TestCase): - def test_create_unexpected_kwargs_error_single(self): + def test_create_unexpected_kwargs_error_single(self) -> None: e = create_unexpected_kwargs_error("f", {"foo": "bar"}) assert str(e) == "f() got an unexpected keyword argument 'foo'" - def test_create_unexpected_kwargs_error_multiple(self): + def test_create_unexpected_kwargs_error_multiple(self) -> None: e = create_unexpected_kwargs_error("f", {"foo": "bar", "baz": "bosh"}) assert str(e) == "f() got unexpected keyword arguments 'baz', 'foo'" diff --git a/tests/unit/plugins/module_utils/_api/transport/test_sshconn.py b/tests/unit/plugins/module_utils/_api/transport/test_sshconn.py index 4fb18ceb..ce97478f 100644 --- a/tests/unit/plugins/module_utils/_api/transport/test_sshconn.py +++ b/tests/unit/plugins/module_utils/_api/transport/test_sshconn.py @@ -18,33 +18,33 @@ from ansible_collections.community.docker.plugins.module_utils._api.transport.ss class SSHAdapterTest(unittest.TestCase): @staticmethod - def test_ssh_hostname_prefix_trim(): + def test_ssh_hostname_prefix_trim() -> None: conn = SSHHTTPAdapter(base_url="ssh://user@hostname:1234", shell_out=True) assert conn.ssh_host == "user@hostname:1234" @staticmethod - def test_ssh_parse_url(): + def test_ssh_parse_url() -> None: c = SSHSocket(host="user@hostname:1234") assert c.host == "hostname" assert c.port == "1234" assert c.user == "user" @staticmethod - def test_ssh_parse_hostname_only(): + def test_ssh_parse_hostname_only() -> None: c = SSHSocket(host="hostname") assert c.host == "hostname" assert c.port is None assert c.user is None @staticmethod - def test_ssh_parse_user_and_hostname(): + def test_ssh_parse_user_and_hostname() -> None: c = SSHSocket(host="user@hostname") assert c.host == "hostname" assert c.port is 
None assert c.user == "user" @staticmethod - def test_ssh_parse_hostname_and_port(): + def test_ssh_parse_hostname_and_port() -> None: c = SSHSocket(host="hostname:22") assert c.host == "hostname" assert c.port == "22" diff --git a/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py b/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py index f1ea023f..78f4dc1b 100644 --- a/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py +++ b/tests/unit/plugins/module_utils/_api/transport/test_ssladapter.py @@ -27,7 +27,7 @@ else: class SSLAdapterTest(unittest.TestCase): - def test_only_uses_tls(self): + def test_only_uses_tls(self) -> None: ssl_context = ssladapter.urllib3.util.ssl_.create_urllib3_context() assert ssl_context.options & OP_NO_SSLv3 @@ -68,19 +68,19 @@ class MatchHostnameTest(unittest.TestCase): "version": 3, } - def test_match_ip_address_success(self): + def test_match_ip_address_success(self) -> None: assert match_hostname(self.cert, "127.0.0.1") is None - def test_match_localhost_success(self): + def test_match_localhost_success(self) -> None: assert match_hostname(self.cert, "localhost") is None - def test_match_dns_success(self): + def test_match_dns_success(self) -> None: assert match_hostname(self.cert, "touhou.gensokyo.jp") is None - def test_match_ip_address_failure(self): + def test_match_ip_address_failure(self) -> None: with pytest.raises(CertificateError): match_hostname(self.cert, "192.168.0.25") - def test_match_dns_failure(self): + def test_match_dns_failure(self) -> None: with pytest.raises(CertificateError): match_hostname(self.cert, "foobar.co.uk") diff --git a/tests/unit/plugins/module_utils/_api/utils/test_build.py b/tests/unit/plugins/module_utils/_api/utils/test_build.py index 043bd958..0cca04fa 100644 --- a/tests/unit/plugins/module_utils/_api/utils/test_build.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_build.py @@ -14,6 +14,7 @@ import shutil import socket import tarfile import 
tempfile +import typing as t import unittest import pytest @@ -27,7 +28,11 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.build ) -def make_tree(dirs, files): +if t.TYPE_CHECKING: + from collections.abc import Collection + + +def make_tree(dirs: list[str], files: list[str]) -> str: base = tempfile.mkdtemp() for path in dirs: @@ -40,11 +45,11 @@ def make_tree(dirs, files): return base -def convert_paths(collection): +def convert_paths(collection: Collection[str]) -> set[str]: return set(map(convert_path, collection)) -def convert_path(path): +def convert_path(path: str) -> str: return path.replace("/", os.path.sep) @@ -88,26 +93,26 @@ class ExcludePathsTest(unittest.TestCase): all_paths = set(dirs + files) - def setUp(self): + def setUp(self) -> None: self.base = make_tree(self.dirs, self.files) - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.base) - def exclude(self, patterns, dockerfile=None): + def exclude(self, patterns: list[str], dockerfile: str | None = None) -> set[str]: return set(exclude_paths(self.base, patterns, dockerfile=dockerfile)) - def test_no_excludes(self): + def test_no_excludes(self) -> None: assert self.exclude([""]) == convert_paths(self.all_paths) - def test_no_dupes(self): + def test_no_dupes(self) -> None: paths = exclude_paths(self.base, ["!a.py"]) assert sorted(paths) == sorted(set(paths)) - def test_wildcard_exclude(self): + def test_wildcard_exclude(self) -> None: assert self.exclude(["*"]) == set(["Dockerfile", ".dockerignore"]) - def test_exclude_dockerfile_dockerignore(self): + def test_exclude_dockerfile_dockerignore(self) -> None: """ Even if the .dockerignore file explicitly says to exclude Dockerfile and/or .dockerignore, don't exclude them from @@ -117,7 +122,7 @@ class ExcludePathsTest(unittest.TestCase): self.all_paths ) - def test_exclude_custom_dockerfile(self): + def test_exclude_custom_dockerfile(self) -> None: """ If we're using a custom Dockerfile, make sure that's 
not excluded. @@ -135,33 +140,33 @@ class ExcludePathsTest(unittest.TestCase): set(["foo/Dockerfile3", ".dockerignore"]) ) - def test_exclude_dockerfile_child(self): + def test_exclude_dockerfile_child(self) -> None: includes = self.exclude(["foo/"], dockerfile="foo/Dockerfile3") assert convert_path("foo/Dockerfile3") in includes assert convert_path("foo/a.py") not in includes - def test_single_filename(self): + def test_single_filename(self) -> None: assert self.exclude(["a.py"]) == convert_paths(self.all_paths - set(["a.py"])) - def test_single_filename_leading_dot_slash(self): + def test_single_filename_leading_dot_slash(self) -> None: assert self.exclude(["./a.py"]) == convert_paths(self.all_paths - set(["a.py"])) # As odd as it sounds, a filename pattern with a trailing slash on the # end *will* result in that file being excluded. - def test_single_filename_trailing_slash(self): + def test_single_filename_trailing_slash(self) -> None: assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"])) - def test_wildcard_filename_start(self): + def test_wildcard_filename_start(self) -> None: assert self.exclude(["*.py"]) == convert_paths( self.all_paths - set(["a.py", "b.py", "cde.py"]) ) - def test_wildcard_with_exception(self): + def test_wildcard_with_exception(self) -> None: assert self.exclude(["*.py", "!b.py"]) == convert_paths( self.all_paths - set(["a.py", "cde.py"]) ) - def test_wildcard_with_wildcard_exception(self): + def test_wildcard_with_wildcard_exception(self) -> None: assert self.exclude(["*.*", "!*.go"]) == convert_paths( self.all_paths - set( @@ -174,51 +179,51 @@ class ExcludePathsTest(unittest.TestCase): ) ) - def test_wildcard_filename_end(self): + def test_wildcard_filename_end(self) -> None: assert self.exclude(["a.*"]) == convert_paths( self.all_paths - set(["a.py", "a.go"]) ) - def test_question_mark(self): + def test_question_mark(self) -> None: assert self.exclude(["?.py"]) == convert_paths( self.all_paths - set(["a.py", 
"b.py"]) ) - def test_single_subdir_single_filename(self): + def test_single_subdir_single_filename(self) -> None: assert self.exclude(["foo/a.py"]) == convert_paths( self.all_paths - set(["foo/a.py"]) ) - def test_single_subdir_single_filename_leading_slash(self): + def test_single_subdir_single_filename_leading_slash(self) -> None: assert self.exclude(["/foo/a.py"]) == convert_paths( self.all_paths - set(["foo/a.py"]) ) - def test_exclude_include_absolute_path(self): + def test_exclude_include_absolute_path(self) -> None: base = make_tree([], ["a.py", "b.py"]) assert exclude_paths(base, ["/*", "!/*.py"]) == set(["a.py", "b.py"]) - def test_single_subdir_with_path_traversal(self): + def test_single_subdir_with_path_traversal(self) -> None: assert self.exclude(["foo/whoops/../a.py"]) == convert_paths( self.all_paths - set(["foo/a.py"]) ) - def test_single_subdir_wildcard_filename(self): + def test_single_subdir_wildcard_filename(self) -> None: assert self.exclude(["foo/*.py"]) == convert_paths( self.all_paths - set(["foo/a.py", "foo/b.py"]) ) - def test_wildcard_subdir_single_filename(self): + def test_wildcard_subdir_single_filename(self) -> None: assert self.exclude(["*/a.py"]) == convert_paths( self.all_paths - set(["foo/a.py", "bar/a.py"]) ) - def test_wildcard_subdir_wildcard_filename(self): + def test_wildcard_subdir_wildcard_filename(self) -> None: assert self.exclude(["*/*.py"]) == convert_paths( self.all_paths - set(["foo/a.py", "foo/b.py", "bar/a.py"]) ) - def test_directory(self): + def test_directory(self) -> None: assert self.exclude(["foo"]) == convert_paths( self.all_paths - set( @@ -233,7 +238,7 @@ class ExcludePathsTest(unittest.TestCase): ) ) - def test_directory_with_trailing_slash(self): + def test_directory_with_trailing_slash(self) -> None: assert self.exclude(["foo"]) == convert_paths( self.all_paths - set( @@ -248,13 +253,13 @@ class ExcludePathsTest(unittest.TestCase): ) ) - def test_directory_with_single_exception(self): + def 
test_directory_with_single_exception(self) -> None: assert self.exclude(["foo", "!foo/bar/a.py"]) == convert_paths( self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/bar", "foo/Dockerfile3"]) ) - def test_directory_with_subdir_exception(self): + def test_directory_with_subdir_exception(self) -> None: assert self.exclude(["foo", "!foo/bar"]) == convert_paths( self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"]) ) @@ -262,17 +267,17 @@ class ExcludePathsTest(unittest.TestCase): @pytest.mark.skipif( not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows" ) - def test_directory_with_subdir_exception_win32_pathsep(self): + def test_directory_with_subdir_exception_win32_pathsep(self) -> None: assert self.exclude(["foo", "!foo\\bar"]) == convert_paths( self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"]) ) - def test_directory_with_wildcard_exception(self): + def test_directory_with_wildcard_exception(self) -> None: assert self.exclude(["foo", "!foo/*.py"]) == convert_paths( self.all_paths - set(["foo/bar", "foo/bar/a.py", "foo", "foo/Dockerfile3"]) ) - def test_subdirectory(self): + def test_subdirectory(self) -> None: assert self.exclude(["foo/bar"]) == convert_paths( self.all_paths - set(["foo/bar", "foo/bar/a.py"]) ) @@ -280,12 +285,12 @@ class ExcludePathsTest(unittest.TestCase): @pytest.mark.skipif( not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows" ) - def test_subdirectory_win32_pathsep(self): + def test_subdirectory_win32_pathsep(self) -> None: assert self.exclude(["foo\\bar"]) == convert_paths( self.all_paths - set(["foo/bar", "foo/bar/a.py"]) ) - def test_double_wildcard(self): + def test_double_wildcard(self) -> None: assert self.exclude(["**/a.py"]) == convert_paths( self.all_paths - set(["a.py", "foo/a.py", "foo/bar/a.py", "bar/a.py"]) ) @@ -294,7 +299,7 @@ class ExcludePathsTest(unittest.TestCase): self.all_paths - set(["foo/bar", "foo/bar/a.py"]) ) - def 
test_single_and_double_wildcard(self): + def test_single_and_double_wildcard(self) -> None: assert self.exclude(["**/target/*/*"]) == convert_paths( self.all_paths - set( @@ -306,7 +311,7 @@ class ExcludePathsTest(unittest.TestCase): ) ) - def test_trailing_double_wildcard(self): + def test_trailing_double_wildcard(self) -> None: assert self.exclude(["subdir/**"]) == convert_paths( self.all_paths - set( @@ -326,7 +331,7 @@ class ExcludePathsTest(unittest.TestCase): ) ) - def test_double_wildcard_with_exception(self): + def test_double_wildcard_with_exception(self) -> None: assert self.exclude(["**", "!bar", "!foo/bar"]) == convert_paths( set( [ @@ -340,13 +345,13 @@ class ExcludePathsTest(unittest.TestCase): ) ) - def test_include_wildcard(self): + def test_include_wildcard(self) -> None: # This may be surprising but it matches the CLI's behavior # (tested with 18.05.0-ce on linux) base = make_tree(["a"], ["a/b.py"]) assert exclude_paths(base, ["*", "!*/b.py"]) == set() - def test_last_line_precedence(self): + def test_last_line_precedence(self) -> None: base = make_tree( [], [ @@ -361,7 +366,7 @@ class ExcludePathsTest(unittest.TestCase): ["README.md", "README-bis.md"] ) - def test_parent_directory(self): + def test_parent_directory(self) -> None: base = make_tree([], ["a.py", "b.py", "c.py"]) # Dockerignore reference stipulates that absolute paths are # equivalent to relative paths, hence /../foo should be @@ -372,7 +377,7 @@ class ExcludePathsTest(unittest.TestCase): class TarTest(unittest.TestCase): - def test_tar_with_excludes(self): + def test_tar_with_excludes(self) -> None: dirs = [ "foo", "foo/bar", @@ -420,7 +425,7 @@ class TarTest(unittest.TestCase): with tarfile.open(fileobj=archive) as tar_data: assert sorted(tar_data.getnames()) == sorted(expected_names) - def test_tar_with_empty_directory(self): + def test_tar_with_empty_directory(self) -> None: base = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, base) for d in ["foo", "bar"]: @@ -433,7 +438,7 
@@ class TarTest(unittest.TestCase): IS_WINDOWS_PLATFORM or os.geteuid() == 0, reason="root user always has access ; no chmod on Windows", ) - def test_tar_with_inaccessible_file(self): + def test_tar_with_inaccessible_file(self) -> None: base = tempfile.mkdtemp() full_path = os.path.join(base, "foo") self.addCleanup(shutil.rmtree, base) @@ -446,7 +451,7 @@ class TarTest(unittest.TestCase): assert f"Can not read file in context: {full_path}" in ei.exconly() @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") - def test_tar_with_file_symlinks(self): + def test_tar_with_file_symlinks(self) -> None: base = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, base) with open(os.path.join(base, "foo"), "wt", encoding="utf-8") as f: @@ -458,7 +463,7 @@ class TarTest(unittest.TestCase): assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") - def test_tar_with_directory_symlinks(self): + def test_tar_with_directory_symlinks(self) -> None: base = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, base) for d in ["foo", "bar"]: @@ -469,7 +474,7 @@ class TarTest(unittest.TestCase): assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") - def test_tar_with_broken_symlinks(self): + def test_tar_with_broken_symlinks(self) -> None: base = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, base) for d in ["foo", "bar"]: @@ -481,7 +486,7 @@ class TarTest(unittest.TestCase): assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No UNIX sockets on Win32") - def test_tar_socket_file(self): + def test_tar_socket_file(self) -> None: base = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, base) for d in ["foo", "bar"]: @@ -493,7 +498,7 @@ class TarTest(unittest.TestCase): with tarfile.open(fileobj=archive) as tar_data: assert 
sorted(tar_data.getnames()) == ["bar", "foo"] - def tar_test_negative_mtime_bug(self): + def tar_test_negative_mtime_bug(self) -> None: base = tempfile.mkdtemp() filename = os.path.join(base, "th.txt") self.addCleanup(shutil.rmtree, base) @@ -506,7 +511,7 @@ class TarTest(unittest.TestCase): assert tar_data.getmember("th.txt").mtime == -3600 @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") - def test_tar_directory_link(self): + def test_tar_directory_link(self) -> None: dirs = ["a", "b", "a/c"] files = ["a/hello.py", "b/utils.py", "a/c/descend.py"] base = make_tree(dirs, files) diff --git a/tests/unit/plugins/module_utils/_api/utils/test_config.py b/tests/unit/plugins/module_utils/_api/utils/test_config.py index cbe8bf0d..48350e32 100644 --- a/tests/unit/plugins/module_utils/_api/utils/test_config.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_config.py @@ -12,6 +12,7 @@ import json import os import shutil import tempfile +import typing as t import unittest from collections.abc import Callable from unittest import mock @@ -25,55 +26,55 @@ class FindConfigFileTest(unittest.TestCase): mkdir: Callable[[str], os.PathLike[str]] @fixture(autouse=True) - def tmpdir(self, tmpdir): + def tmpdir(self, tmpdir: t.Any) -> None: self.mkdir = tmpdir.mkdir - def test_find_config_fallback(self): + def test_find_config_fallback(self) -> None: tmpdir = self.mkdir("test_find_config_fallback") with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): assert config.find_config_file() is None - def test_find_config_from_explicit_path(self): + def test_find_config_from_explicit_path(self) -> None: tmpdir = self.mkdir("test_find_config_from_explicit_path") - config_path = tmpdir.ensure("my-config-file.json") + config_path = tmpdir.ensure("my-config-file.json") # type: ignore[attr-defined] assert config.find_config_file(str(config_path)) == str(config_path) - def test_find_config_from_environment(self): + def test_find_config_from_environment(self) -> None: 
tmpdir = self.mkdir("test_find_config_from_environment") - config_path = tmpdir.ensure("config.json") + config_path = tmpdir.ensure("config.json") # type: ignore[attr-defined] with mock.patch.dict(os.environ, {"DOCKER_CONFIG": str(tmpdir)}): assert config.find_config_file() == str(config_path) @mark.skipif("sys.platform == 'win32'") - def test_find_config_from_home_posix(self): + def test_find_config_from_home_posix(self) -> None: tmpdir = self.mkdir("test_find_config_from_home_posix") - config_path = tmpdir.ensure(".docker", "config.json") + config_path = tmpdir.ensure(".docker", "config.json") # type: ignore[attr-defined] with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): assert config.find_config_file() == str(config_path) @mark.skipif("sys.platform == 'win32'") - def test_find_config_from_home_legacy_name(self): + def test_find_config_from_home_legacy_name(self) -> None: tmpdir = self.mkdir("test_find_config_from_home_legacy_name") - config_path = tmpdir.ensure(".dockercfg") + config_path = tmpdir.ensure(".dockercfg") # type: ignore[attr-defined] with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): assert config.find_config_file() == str(config_path) @mark.skipif("sys.platform != 'win32'") - def test_find_config_from_home_windows(self): + def test_find_config_from_home_windows(self) -> None: tmpdir = self.mkdir("test_find_config_from_home_windows") - config_path = tmpdir.ensure(".docker", "config.json") + config_path = tmpdir.ensure(".docker", "config.json") # type: ignore[attr-defined] with mock.patch.dict(os.environ, {"USERPROFILE": str(tmpdir)}): assert config.find_config_file() == str(config_path) class LoadConfigTest(unittest.TestCase): - def test_load_config_no_file(self): + def test_load_config_no_file(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) cfg = config.load_general_config(folder) @@ -81,7 +82,7 @@ class LoadConfigTest(unittest.TestCase): assert isinstance(cfg, dict) assert not cfg - def 
test_load_config_custom_headers(self): + def test_load_config_custom_headers(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) @@ -97,7 +98,7 @@ class LoadConfigTest(unittest.TestCase): assert "HttpHeaders" in cfg assert cfg["HttpHeaders"] == {"Name": "Spike", "Surname": "Spiegel"} - def test_load_config_detach_keys(self): + def test_load_config_detach_keys(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) dockercfg_path = os.path.join(folder, "config.json") @@ -108,7 +109,7 @@ class LoadConfigTest(unittest.TestCase): cfg = config.load_general_config(dockercfg_path) assert cfg == config_data - def test_load_config_from_env(self): + def test_load_config_from_env(self) -> None: folder = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, folder) dockercfg_path = os.path.join(folder, "config.json") diff --git a/tests/unit/plugins/module_utils/_api/utils/test_decorators.py b/tests/unit/plugins/module_utils/_api/utils/test_decorators.py index 93d2019c..a8f70fea 100644 --- a/tests/unit/plugins/module_utils/_api/utils/test_decorators.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_decorators.py @@ -8,6 +8,7 @@ from __future__ import annotations +import typing as t import unittest from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( @@ -22,12 +23,12 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c class DecoratorsTest(unittest.TestCase): - def test_update_headers(self): + def test_update_headers(self) -> None: sample_headers = { "X-Docker-Locale": "en-US", } - def f(self, headers=None): + def f(self: t.Any, headers: t.Any = None) -> t.Any: return headers client = APIClient(version=DEFAULT_DOCKER_API_VERSION) diff --git a/tests/unit/plugins/module_utils/_api/utils/test_json_stream.py b/tests/unit/plugins/module_utils/_api/utils/test_json_stream.py index cc9693aa..ea97acc1 100644 --- 
a/tests/unit/plugins/module_utils/_api/utils/test_json_stream.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_json_stream.py @@ -8,6 +8,8 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.docker.plugins.module_utils._api.utils.json_stream import ( json_splitter, json_stream, @@ -15,41 +17,48 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.json_s ) -class TestJsonSplitter: +if t.TYPE_CHECKING: + T = t.TypeVar("T") - def test_json_splitter_no_object(self): + +def create_generator(input_sequence: list[T]) -> t.Generator[T]: + yield from input_sequence + + +class TestJsonSplitter: + def test_json_splitter_no_object(self) -> None: data = '{"foo": "bar' assert json_splitter(data) is None - def test_json_splitter_with_object(self): + def test_json_splitter_with_object(self) -> None: data = '{"foo": "bar"}\n \n{"next": "obj"}' assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}') - def test_json_splitter_leading_whitespace(self): + def test_json_splitter_leading_whitespace(self) -> None: data = '\n \r{"foo": "bar"}\n\n {"next": "obj"}' assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}') class TestStreamAsText: - - def test_stream_with_non_utf_unicode_character(self): - stream = [b"\xed\xf3\xf3"] + def test_stream_with_non_utf_unicode_character(self) -> None: + stream = create_generator([b"\xed\xf3\xf3"]) (output,) = stream_as_text(stream) assert output == "���" - def test_stream_with_utf_character(self): - stream = ["ěĝ".encode("utf-8")] + def test_stream_with_utf_character(self) -> None: + stream = create_generator(["ěĝ".encode("utf-8")]) (output,) = stream_as_text(stream) assert output == "ěĝ" class TestJsonStream: - - def test_with_falsy_entries(self): - stream = [ - '{"one": "two"}\n{}\n', - "[1, 2, 3]\n[]\n", - ] + def test_with_falsy_entries(self) -> None: + stream = create_generator( + [ + '{"one": "two"}\n{}\n', + "[1, 2, 3]\n[]\n", + ] + ) output = 
list(json_stream(stream)) assert output == [ {"one": "two"}, @@ -58,7 +67,9 @@ class TestJsonStream: [], ] - def test_with_leading_whitespace(self): - stream = ['\n \r\n {"one": "two"}{"x": 1}', ' {"three": "four"}\t\t{"x": 2}'] + def test_with_leading_whitespace(self) -> None: + stream = create_generator( + ['\n \r\n {"one": "two"}{"x": 1}', ' {"three": "four"}\t\t{"x": 2}'] + ) output = list(json_stream(stream)) assert output == [{"one": "two"}, {"x": 1}, {"three": "four"}, {"x": 2}] diff --git a/tests/unit/plugins/module_utils/_api/utils/test_ports.py b/tests/unit/plugins/module_utils/_api/utils/test_ports.py index 3bce6125..77b67d04 100644 --- a/tests/unit/plugins/module_utils/_api/utils/test_ports.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_ports.py @@ -19,132 +19,132 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.ports class PortsTest(unittest.TestCase): - def test_split_port_with_host_ip(self): + def test_split_port_with_host_ip(self) -> None: internal_port, external_port = split_port("127.0.0.1:1000:2000") assert internal_port == ["2000"] assert external_port == [("127.0.0.1", "1000")] - def test_split_port_with_protocol(self): + def test_split_port_with_protocol(self) -> None: for protocol in ["tcp", "udp", "sctp"]: internal_port, external_port = split_port("127.0.0.1:1000:2000/" + protocol) assert internal_port == ["2000/" + protocol] assert external_port == [("127.0.0.1", "1000")] - def test_split_port_with_host_ip_no_port(self): + def test_split_port_with_host_ip_no_port(self) -> None: internal_port, external_port = split_port("127.0.0.1::2000") assert internal_port == ["2000"] assert external_port == [("127.0.0.1", None)] - def test_split_port_range_with_host_ip_no_port(self): + def test_split_port_range_with_host_ip_no_port(self) -> None: internal_port, external_port = split_port("127.0.0.1::2000-2001") assert internal_port == ["2000", "2001"] assert external_port == [("127.0.0.1", None), ("127.0.0.1", None)] 
- def test_split_port_with_host_port(self): + def test_split_port_with_host_port(self) -> None: internal_port, external_port = split_port("1000:2000") assert internal_port == ["2000"] assert external_port == ["1000"] - def test_split_port_range_with_host_port(self): + def test_split_port_range_with_host_port(self) -> None: internal_port, external_port = split_port("1000-1001:2000-2001") assert internal_port == ["2000", "2001"] assert external_port == ["1000", "1001"] - def test_split_port_random_port_range_with_host_port(self): + def test_split_port_random_port_range_with_host_port(self) -> None: internal_port, external_port = split_port("1000-1001:2000") assert internal_port == ["2000"] assert external_port == ["1000-1001"] - def test_split_port_no_host_port(self): + def test_split_port_no_host_port(self) -> None: internal_port, external_port = split_port("2000") assert internal_port == ["2000"] assert external_port is None - def test_split_port_range_no_host_port(self): + def test_split_port_range_no_host_port(self) -> None: internal_port, external_port = split_port("2000-2001") assert internal_port == ["2000", "2001"] assert external_port is None - def test_split_port_range_with_protocol(self): + def test_split_port_range_with_protocol(self) -> None: internal_port, external_port = split_port("127.0.0.1:1000-1001:2000-2001/udp") assert internal_port == ["2000/udp", "2001/udp"] assert external_port == [("127.0.0.1", "1000"), ("127.0.0.1", "1001")] - def test_split_port_with_ipv6_address(self): + def test_split_port_with_ipv6_address(self) -> None: internal_port, external_port = split_port("2001:abcd:ef00::2:1000:2000") assert internal_port == ["2000"] assert external_port == [("2001:abcd:ef00::2", "1000")] - def test_split_port_with_ipv6_square_brackets_address(self): + def test_split_port_with_ipv6_square_brackets_address(self) -> None: internal_port, external_port = split_port("[2001:abcd:ef00::2]:1000:2000") assert internal_port == ["2000"] assert external_port 
== [("2001:abcd:ef00::2", "1000")] - def test_split_port_invalid(self): + def test_split_port_invalid(self) -> None: with pytest.raises(ValueError): split_port("0.0.0.0:1000:2000:tcp") - def test_split_port_invalid_protocol(self): + def test_split_port_invalid_protocol(self) -> None: with pytest.raises(ValueError): split_port("0.0.0.0:1000:2000/ftp") - def test_non_matching_length_port_ranges(self): + def test_non_matching_length_port_ranges(self) -> None: with pytest.raises(ValueError): split_port("0.0.0.0:1000-1010:2000-2002/tcp") - def test_port_and_range_invalid(self): + def test_port_and_range_invalid(self) -> None: with pytest.raises(ValueError): split_port("0.0.0.0:1000:2000-2002/tcp") - def test_port_only_with_colon(self): + def test_port_only_with_colon(self) -> None: with pytest.raises(ValueError): split_port(":80") - def test_host_only_with_colon(self): + def test_host_only_with_colon(self) -> None: with pytest.raises(ValueError): split_port("localhost:") - def test_with_no_container_port(self): + def test_with_no_container_port(self) -> None: with pytest.raises(ValueError): split_port("localhost:80:") - def test_split_port_empty_string(self): + def test_split_port_empty_string(self) -> None: with pytest.raises(ValueError): split_port("") - def test_split_port_non_string(self): + def test_split_port_non_string(self) -> None: assert split_port(1243) == (["1243"], None) - def test_build_port_bindings_with_one_port(self): + def test_build_port_bindings_with_one_port(self) -> None: port_bindings = build_port_bindings(["127.0.0.1:1000:1000"]) assert port_bindings["1000"] == [("127.0.0.1", "1000")] - def test_build_port_bindings_with_matching_internal_ports(self): + def test_build_port_bindings_with_matching_internal_ports(self) -> None: port_bindings = build_port_bindings( ["127.0.0.1:1000:1000", "127.0.0.1:2000:1000"] ) assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")] - def 
test_build_port_bindings_with_nonmatching_internal_ports(self): + def test_build_port_bindings_with_nonmatching_internal_ports(self) -> None: port_bindings = build_port_bindings( ["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"] ) assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["2000"] == [("127.0.0.1", "2000")] - def test_build_port_bindings_with_port_range(self): + def test_build_port_bindings_with_port_range(self) -> None: port_bindings = build_port_bindings(["127.0.0.1:1000-1001:1000-1001"]) assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["1001"] == [("127.0.0.1", "1001")] - def test_build_port_bindings_with_matching_internal_port_ranges(self): + def test_build_port_bindings_with_matching_internal_port_ranges(self) -> None: port_bindings = build_port_bindings( ["127.0.0.1:1000-1001:1000-1001", "127.0.0.1:2000-2001:1000-1001"] ) assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")] assert port_bindings["1001"] == [("127.0.0.1", "1001"), ("127.0.0.1", "2001")] - def test_build_port_bindings_with_nonmatching_internal_port_ranges(self): + def test_build_port_bindings_with_nonmatching_internal_port_ranges(self) -> None: port_bindings = build_port_bindings( ["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"] ) diff --git a/tests/unit/plugins/module_utils/_api/utils/test_proxy.py b/tests/unit/plugins/module_utils/_api/utils/test_proxy.py index d5e12e50..9483b64e 100644 --- a/tests/unit/plugins/module_utils/_api/utils/test_proxy.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_proxy.py @@ -33,8 +33,7 @@ ENV = { class ProxyConfigTest(unittest.TestCase): - - def test_from_dict(self): + def test_from_dict(self) -> None: config = ProxyConfig.from_dict( { "httpProxy": HTTP, @@ -48,7 +47,7 @@ class ProxyConfigTest(unittest.TestCase): self.assertEqual(CONFIG.ftp, config.ftp) self.assertEqual(CONFIG.no_proxy, config.no_proxy) - def test_new(self): + def test_new(self) -> None: config = 
ProxyConfig() self.assertIsNone(config.http) self.assertIsNone(config.https) @@ -61,22 +60,24 @@ class ProxyConfigTest(unittest.TestCase): self.assertEqual(config.ftp, "c") self.assertEqual(config.no_proxy, "d") - def test_truthiness(self): + def test_truthiness(self) -> None: assert not ProxyConfig() assert ProxyConfig(http="non-zero") assert ProxyConfig(https="non-zero") assert ProxyConfig(ftp="non-zero") assert ProxyConfig(no_proxy="non-zero") - def test_environment(self): + def test_environment(self) -> None: self.assertDictEqual(CONFIG.get_environment(), ENV) empty = ProxyConfig() self.assertDictEqual(empty.get_environment(), {}) - def test_inject_proxy_environment(self): + def test_inject_proxy_environment(self) -> None: # Proxy config is non null, env is None. + envlist = CONFIG.inject_proxy_environment(None) + assert envlist is not None self.assertSetEqual( - set(CONFIG.inject_proxy_environment(None)), + set(envlist), set(f"{k}={v}" for k, v in ENV.items()), ) diff --git a/tests/unit/plugins/module_utils/_api/utils/test_utils.py b/tests/unit/plugins/module_utils/_api/utils/test_utils.py index 85b4619d..2b412b50 100644 --- a/tests/unit/plugins/module_utils/_api/utils/test_utils.py +++ b/tests/unit/plugins/module_utils/_api/utils/test_utils.py @@ -52,13 +52,15 @@ TEST_CERT_DIR = os.path.join( class KwargsFromEnvTest(unittest.TestCase): - def setUp(self): + os_environ: dict[str, str] + + def setUp(self) -> None: self.os_environ = os.environ.copy() - def tearDown(self): - os.environ = self.os_environ + def tearDown(self) -> None: + os.environ = self.os_environ # type: ignore - def test_kwargs_from_env_empty(self): + def test_kwargs_from_env_empty(self) -> None: os.environ.update(DOCKER_HOST="", DOCKER_CERT_PATH="") os.environ.pop("DOCKER_TLS_VERIFY", None) @@ -66,7 +68,7 @@ class KwargsFromEnvTest(unittest.TestCase): assert kwargs.get("base_url") is None assert kwargs.get("tls") is None - def test_kwargs_from_env_tls(self): + def test_kwargs_from_env_tls(self) 
-> None: os.environ.update( DOCKER_HOST="tcp://192.168.59.103:2376", DOCKER_CERT_PATH=TEST_CERT_DIR, @@ -90,7 +92,7 @@ class KwargsFromEnvTest(unittest.TestCase): except TypeError as e: self.fail(e) - def test_kwargs_from_env_tls_verify_false(self): + def test_kwargs_from_env_tls_verify_false(self) -> None: os.environ.update( DOCKER_HOST="tcp://192.168.59.103:2376", DOCKER_CERT_PATH=TEST_CERT_DIR, @@ -113,7 +115,7 @@ class KwargsFromEnvTest(unittest.TestCase): except TypeError as e: self.fail(e) - def test_kwargs_from_env_tls_verify_false_no_cert(self): + def test_kwargs_from_env_tls_verify_false_no_cert(self) -> None: temp_dir = tempfile.mkdtemp() cert_dir = os.path.join(temp_dir, ".docker") shutil.copytree(TEST_CERT_DIR, cert_dir) @@ -125,7 +127,7 @@ class KwargsFromEnvTest(unittest.TestCase): kwargs = kwargs_from_env(assert_hostname=True) assert "tcp://192.168.59.103:2376" == kwargs["base_url"] - def test_kwargs_from_env_no_cert_path(self): + def test_kwargs_from_env_no_cert_path(self) -> None: try: temp_dir = tempfile.mkdtemp() cert_dir = os.path.join(temp_dir, ".docker") @@ -142,7 +144,7 @@ class KwargsFromEnvTest(unittest.TestCase): if temp_dir: shutil.rmtree(temp_dir) - def test_kwargs_from_env_alternate_env(self): + def test_kwargs_from_env_alternate_env(self) -> None: # Values in os.environ are entirely ignored if an alternate is # provided os.environ.update( @@ -160,30 +162,32 @@ class KwargsFromEnvTest(unittest.TestCase): class ConverVolumeBindsTest(unittest.TestCase): - def test_convert_volume_binds_empty(self): + def test_convert_volume_binds_empty(self) -> None: assert convert_volume_binds({}) == [] assert convert_volume_binds([]) == [] - def test_convert_volume_binds_list(self): + def test_convert_volume_binds_list(self) -> None: data = ["/a:/a:ro", "/b:/c:z"] assert convert_volume_binds(data) == data - def test_convert_volume_binds_complete(self): - data = {"/mnt/vol1": {"bind": "/data", "mode": "ro"}} + def test_convert_volume_binds_complete(self) 
-> None: + data: dict[str | bytes, dict[str, str]] = { + "/mnt/vol1": {"bind": "/data", "mode": "ro"} + } assert convert_volume_binds(data) == ["/mnt/vol1:/data:ro"] - def test_convert_volume_binds_compact(self): - data = {"/mnt/vol1": "/data"} + def test_convert_volume_binds_compact(self) -> None: + data: dict[str | bytes, str] = {"/mnt/vol1": "/data"} assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"] - def test_convert_volume_binds_no_mode(self): - data = {"/mnt/vol1": {"bind": "/data"}} + def test_convert_volume_binds_no_mode(self) -> None: + data: dict[str | bytes, dict[str, str]] = {"/mnt/vol1": {"bind": "/data"}} assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"] - def test_convert_volume_binds_unicode_bytes_input(self): + def test_convert_volume_binds_unicode_bytes_input(self) -> None: expected = ["/mnt/지연:/unicode/박:rw"] - data = { + data: dict[str | bytes, dict[str, str | bytes]] = { "/mnt/지연".encode("utf-8"): { "bind": "/unicode/박".encode("utf-8"), "mode": "rw", @@ -191,15 +195,17 @@ class ConverVolumeBindsTest(unittest.TestCase): } assert convert_volume_binds(data) == expected - def test_convert_volume_binds_unicode_unicode_input(self): + def test_convert_volume_binds_unicode_unicode_input(self) -> None: expected = ["/mnt/지연:/unicode/박:rw"] - data = {"/mnt/지연": {"bind": "/unicode/박", "mode": "rw"}} + data: dict[str | bytes, dict[str, str]] = { + "/mnt/지연": {"bind": "/unicode/박", "mode": "rw"} + } assert convert_volume_binds(data) == expected class ParseEnvFileTest(unittest.TestCase): - def generate_tempfile(self, file_content=None): + def generate_tempfile(self, file_content: str) -> str: """ Generates a temporary file for tests with the content of 'file_content' and returns the filename. 
@@ -209,31 +215,31 @@ class ParseEnvFileTest(unittest.TestCase): local_tempfile.write(file_content.encode("UTF-8")) return local_tempfile.name - def test_parse_env_file_proper(self): + def test_parse_env_file_proper(self) -> None: env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=secret") get_parse_env_file = parse_env_file(env_file) assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"} os.unlink(env_file) - def test_parse_env_file_with_equals_character(self): + def test_parse_env_file_with_equals_character(self) -> None: env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=sec==ret") get_parse_env_file = parse_env_file(env_file) assert get_parse_env_file == {"USER": "jdoe", "PASS": "sec==ret"} os.unlink(env_file) - def test_parse_env_file_commented_line(self): + def test_parse_env_file_commented_line(self) -> None: env_file = self.generate_tempfile(file_content="USER=jdoe\n#PASS=secret") get_parse_env_file = parse_env_file(env_file) assert get_parse_env_file == {"USER": "jdoe"} os.unlink(env_file) - def test_parse_env_file_newline(self): + def test_parse_env_file_newline(self) -> None: env_file = self.generate_tempfile(file_content="\nUSER=jdoe\n\n\nPASS=secret") get_parse_env_file = parse_env_file(env_file) assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"} os.unlink(env_file) - def test_parse_env_file_invalid_line(self): + def test_parse_env_file_invalid_line(self) -> None: env_file = self.generate_tempfile(file_content="USER jdoe") with pytest.raises(DockerException): parse_env_file(env_file) @@ -241,7 +247,7 @@ class ParseEnvFileTest(unittest.TestCase): class ParseHostTest(unittest.TestCase): - def test_parse_host(self): + def test_parse_host(self) -> None: invalid_hosts = [ "foo://0.0.0.0", "tcp://", @@ -282,16 +288,16 @@ class ParseHostTest(unittest.TestCase): for host in invalid_hosts: msg = f"Should have failed to parse invalid host: {host}" with self.assertRaises(DockerException, msg=msg): - parse_host(host, 
None) + parse_host(host) for host, expected in valid_hosts.items(): self.assertEqual( - parse_host(host, None), + parse_host(host), expected, msg=f"Failed to parse valid host: {host}", ) - def test_parse_host_empty_value(self): + def test_parse_host_empty_value(self) -> None: unix_socket = "http+unix:///var/run/docker.sock" npipe = "npipe:////./pipe/docker_engine" @@ -299,17 +305,17 @@ class ParseHostTest(unittest.TestCase): assert parse_host(val, is_win32=False) == unix_socket assert parse_host(val, is_win32=True) == npipe - def test_parse_host_tls(self): + def test_parse_host_tls(self) -> None: host_value = "myhost.docker.net:3348" expected_result = "https://myhost.docker.net:3348" assert parse_host(host_value, tls=True) == expected_result - def test_parse_host_tls_tcp_proto(self): + def test_parse_host_tls_tcp_proto(self) -> None: host_value = "tcp://myhost.docker.net:3348" expected_result = "https://myhost.docker.net:3348" assert parse_host(host_value, tls=True) == expected_result - def test_parse_host_trailing_slash(self): + def test_parse_host_trailing_slash(self) -> None: host_value = "tcp://myhost.docker.net:2376/" expected_result = "http://myhost.docker.net:2376" assert parse_host(host_value) == expected_result @@ -318,31 +324,31 @@ class ParseHostTest(unittest.TestCase): class ParseRepositoryTagTest(unittest.TestCase): sha = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - def test_index_image_no_tag(self): + def test_index_image_no_tag(self) -> None: assert parse_repository_tag("root") == ("root", None) - def test_index_image_tag(self): + def test_index_image_tag(self) -> None: assert parse_repository_tag("root:tag") == ("root", "tag") - def test_index_user_image_no_tag(self): + def test_index_user_image_no_tag(self) -> None: assert parse_repository_tag("user/repo") == ("user/repo", None) - def test_index_user_image_tag(self): + def test_index_user_image_tag(self) -> None: assert parse_repository_tag("user/repo:tag") == ("user/repo", 
"tag") - def test_private_reg_image_no_tag(self): + def test_private_reg_image_no_tag(self) -> None: assert parse_repository_tag("url:5000/repo") == ("url:5000/repo", None) - def test_private_reg_image_tag(self): + def test_private_reg_image_tag(self) -> None: assert parse_repository_tag("url:5000/repo:tag") == ("url:5000/repo", "tag") - def test_index_image_sha(self): + def test_index_image_sha(self) -> None: assert parse_repository_tag(f"root@sha256:{self.sha}") == ( "root", f"sha256:{self.sha}", ) - def test_private_reg_image_sha(self): + def test_private_reg_image_sha(self) -> None: assert parse_repository_tag(f"url:5000/repo@sha256:{self.sha}") == ( "url:5000/repo", f"sha256:{self.sha}", @@ -350,7 +356,7 @@ class ParseRepositoryTagTest(unittest.TestCase): class ParseDeviceTest(unittest.TestCase): - def test_dict(self): + def test_dict(self) -> None: devices = parse_devices( [ { @@ -366,7 +372,7 @@ class ParseDeviceTest(unittest.TestCase): "CgroupPermissions": "r", } - def test_partial_string_definition(self): + def test_partial_string_definition(self) -> None: devices = parse_devices(["/dev/sda1"]) assert devices[0] == { "PathOnHost": "/dev/sda1", @@ -374,7 +380,7 @@ class ParseDeviceTest(unittest.TestCase): "CgroupPermissions": "rwm", } - def test_permissionless_string_definition(self): + def test_permissionless_string_definition(self) -> None: devices = parse_devices(["/dev/sda1:/dev/mnt1"]) assert devices[0] == { "PathOnHost": "/dev/sda1", @@ -382,7 +388,7 @@ class ParseDeviceTest(unittest.TestCase): "CgroupPermissions": "rwm", } - def test_full_string_definition(self): + def test_full_string_definition(self) -> None: devices = parse_devices(["/dev/sda1:/dev/mnt1:r"]) assert devices[0] == { "PathOnHost": "/dev/sda1", @@ -390,7 +396,7 @@ class ParseDeviceTest(unittest.TestCase): "CgroupPermissions": "r", } - def test_hybrid_list(self): + def test_hybrid_list(self) -> None: devices = parse_devices( [ "/dev/sda1:/dev/mnt1:rw", @@ -415,12 +421,12 @@ class 
ParseDeviceTest(unittest.TestCase): class ParseBytesTest(unittest.TestCase): - def test_parse_bytes_valid(self): + def test_parse_bytes_valid(self) -> None: assert parse_bytes("512MB") == 536870912 assert parse_bytes("512M") == 536870912 assert parse_bytes("512m") == 536870912 - def test_parse_bytes_invalid(self): + def test_parse_bytes_invalid(self) -> None: with pytest.raises(DockerException): parse_bytes("512MK") with pytest.raises(DockerException): @@ -428,15 +434,15 @@ class ParseBytesTest(unittest.TestCase): with pytest.raises(DockerException): parse_bytes("127.0.0.1K") - def test_parse_bytes_float(self): + def test_parse_bytes_float(self) -> None: assert parse_bytes("1.5k") == 1536 class UtilsTest(unittest.TestCase): longMessage = True - def test_convert_filters(self): - tests = [ + def test_convert_filters(self) -> None: + tests: list[tuple[dict[str, bool | str | int | list[str | int]], str]] = [ ({"dangling": True}, '{"dangling": ["true"]}'), ({"dangling": "true"}, '{"dangling": ["true"]}'), ({"exited": 0}, '{"exited": ["0"]}'), @@ -446,7 +452,7 @@ class UtilsTest(unittest.TestCase): for filters, expected in tests: assert convert_filters(filters) == expected - def test_decode_json_header(self): + def test_decode_json_header(self) -> None: obj = {"a": "b", "c": 1} data = base64.urlsafe_b64encode(bytes(json.dumps(obj), "utf-8")) decoded_data = decode_json_header(data) @@ -454,16 +460,16 @@ class UtilsTest(unittest.TestCase): class SplitCommandTest(unittest.TestCase): - def test_split_command_with_unicode(self): + def test_split_command_with_unicode(self) -> None: assert split_command("echo μμ") == ["echo", "μμ"] class FormatEnvironmentTest(unittest.TestCase): - def test_format_env_binary_unicode_value(self): + def test_format_env_binary_unicode_value(self) -> None: env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"} assert format_environment(env_dict) == ["ARTIST_NAME=송지은"] - def test_format_env_no_value(self): + def 
test_format_env_no_value(self) -> None: env_dict = { "FOO": None, "BAR": "", diff --git a/tests/unit/plugins/module_utils/compose_v2_test_cases.py b/tests/unit/plugins/module_utils/compose_v2_test_cases.py index be41ebd9..2cbcb124 100644 --- a/tests/unit/plugins/module_utils/compose_v2_test_cases.py +++ b/tests/unit/plugins/module_utils/compose_v2_test_cases.py @@ -11,7 +11,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor ) -EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[Event]]] = [ +EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[str]]] = [ # ####################################################################################################################### # ## Docker Compose 2.18.1 ############################################################################################## # ####################################################################################################################### diff --git a/tests/unit/plugins/module_utils/test__compose_v2.py b/tests/unit/plugins/module_utils/test__compose_v2.py index 4d855364..ebc6654f 100644 --- a/tests/unit/plugins/module_utils/test__compose_v2.py +++ b/tests/unit/plugins/module_utils/test__compose_v2.py @@ -14,7 +14,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor from .compose_v2_test_cases import EVENT_TEST_CASES -EXTRA_TEST_CASES = [ +EXTRA_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[str]]] = [ ( "2.24.2-manual-build-dry-run", "2.24.2", @@ -227,9 +227,7 @@ EXTRA_TEST_CASES = [ False, False, # fmt: off - " bash_1 Skipped \n" - " bash_2 Pulling \n" - " bash_2 Pulled \n", + " bash_1 Skipped \n bash_2 Pulling \n bash_2 Pulled \n", # fmt: on [ Event( @@ -361,15 +359,24 @@ _ALL_TEST_CASES = EVENT_TEST_CASES + EXTRA_TEST_CASES ids=[tc[0] for tc in _ALL_TEST_CASES], ) def test_parse_events( - test_id, compose_version, dry_run, nonzero_rc, stderr, events, 
warnings -): + test_id: str, + compose_version: str, + dry_run: bool, + nonzero_rc: bool, + stderr: str, + events: list[Event], + warnings: list[str], +) -> None: collected_warnings = [] - def collect_warning(msg): + def collect_warning(msg: str) -> None: collected_warnings.append(msg) collected_events = parse_events( - stderr, dry_run=dry_run, warn_function=collect_warning, nonzero_rc=nonzero_rc + stderr.encode("utf-8"), + dry_run=dry_run, + warn_function=collect_warning, + nonzero_rc=nonzero_rc, ) print(collected_events) diff --git a/tests/unit/plugins/module_utils/test__copy.py b/tests/unit/plugins/module_utils/test__copy.py index 8b9feb27..3cdee3ea 100644 --- a/tests/unit/plugins/module_utils/test__copy.py +++ b/tests/unit/plugins/module_utils/test__copy.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.docker.plugins.module_utils._copy import ( @@ -11,7 +13,13 @@ from ansible_collections.community.docker.plugins.module_utils._copy import ( ) -def _simple_generator(sequence): +if t.TYPE_CHECKING: + from collections.abc import Sequence + + T = t.TypeVar("T") + + +def _simple_generator(sequence: Sequence[T]) -> t.Generator[T]: yield from sequence @@ -60,10 +68,12 @@ def _simple_generator(sequence): ), ], ) -def test__stream_generator_to_fileobj(chunks, read_sizes): - chunks = [count * data for count, data in chunks] - stream = _simple_generator(chunks) - expected = b"".join(chunks) +def test__stream_generator_to_fileobj( + chunks: list[tuple[int, bytes]], read_sizes: list[int] +) -> None: + data_chunks = [count * data for count, data in chunks] + stream = _simple_generator(data_chunks) + expected = b"".join(data_chunks) buffer = b"" totally_read = 0 diff --git a/tests/unit/plugins/module_utils/test__image_archive.py b/tests/unit/plugins/module_utils/test__image_archive.py index ce55153b..c467fb74 100644 --- a/tests/unit/plugins/module_utils/test__image_archive.py +++ 
b/tests/unit/plugins/module_utils/test__image_archive.py @@ -5,6 +5,7 @@ from __future__ import annotations import tarfile +import typing as t import pytest @@ -22,7 +23,7 @@ from ..test_support.docker_image_archive_stubbing import ( @pytest.fixture -def tar_file_name(tmpdir): +def tar_file_name(tmpdir: t.Any) -> str: """ Return the name of a non-existing tar file in an existing temporary directory. """ @@ -34,11 +35,11 @@ def tar_file_name(tmpdir): @pytest.mark.parametrize( "expected, value", [("sha256:foo", "foo"), ("sha256:bar", "bar")] ) -def test_api_image_id_from_archive_id(expected, value): +def test_api_image_id_from_archive_id(expected: str, value: str) -> None: assert api_image_id(value) == expected -def test_archived_image_manifest_extracts(tar_file_name): +def test_archived_image_manifest_extracts(tar_file_name: str) -> None: expected_id = "abcde12345" expected_tags = ["foo:latest", "bar:v1"] @@ -46,17 +47,20 @@ def test_archived_image_manifest_extracts(tar_file_name): actual = archived_image_manifest(tar_file_name) + assert actual is not None assert actual.image_id == expected_id assert actual.repo_tags == expected_tags -def test_archived_image_manifest_extracts_nothing_when_file_not_present(tar_file_name): +def test_archived_image_manifest_extracts_nothing_when_file_not_present( + tar_file_name: str, +) -> None: image_id = archived_image_manifest(tar_file_name) assert image_id is None -def test_archived_image_manifest_raises_when_file_not_a_tar(): +def test_archived_image_manifest_raises_when_file_not_a_tar() -> None: try: archived_image_manifest(__file__) raise AssertionError() @@ -65,7 +69,9 @@ def test_archived_image_manifest_raises_when_file_not_a_tar(): assert str(__file__) in str(e) -def test_archived_image_manifest_raises_when_tar_missing_manifest(tar_file_name): +def test_archived_image_manifest_raises_when_tar_missing_manifest( + tar_file_name: str, +) -> None: write_irrelevant_tar(tar_file_name) try: @@ -76,7 +82,9 @@ def 
test_archived_image_manifest_raises_when_tar_missing_manifest(tar_file_name) assert "manifest.json" in str(e.__cause__) -def test_archived_image_manifest_raises_when_manifest_missing_id(tar_file_name): +def test_archived_image_manifest_raises_when_manifest_missing_id( + tar_file_name: str, +) -> None: manifest = [{"foo": "bar"}] write_imitation_archive_with_manifest(tar_file_name, manifest) diff --git a/tests/unit/plugins/module_utils/test__logfmt.py b/tests/unit/plugins/module_utils/test__logfmt.py index 9e2a89a7..3e39c7de 100644 --- a/tests/unit/plugins/module_utils/test__logfmt.py +++ b/tests/unit/plugins/module_utils/test__logfmt.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.docker.plugins.module_utils._logfmt import ( @@ -12,7 +14,7 @@ from ansible_collections.community.docker.plugins.module_utils._logfmt import ( ) -SUCCESS_TEST_CASES = [ +SUCCESS_TEST_CASES: list[tuple[str, dict[str, t.Any], dict[str, t.Any]]] = [ ( 'time="2024-02-02T08:14:10+01:00" level=warning msg="a network with name influxNetwork exists but was not' ' created for project \\"influxdb\\".\\nSet `external: true` to use an existing network"', @@ -59,7 +61,7 @@ SUCCESS_TEST_CASES = [ ] -FAILURE_TEST_CASES = [ +FAILURE_TEST_CASES: list[tuple[str, dict[str, t.Any], str]] = [ ( 'foo=bar a=14 baz="hello kitty" cool%story=bro f %^asdf', {"logrus_mode": True}, @@ -84,14 +86,16 @@ FAILURE_TEST_CASES = [ @pytest.mark.parametrize("line, kwargs, result", SUCCESS_TEST_CASES) -def test_parse_line_success(line, kwargs, result): +def test_parse_line_success( + line: str, kwargs: dict[str, t.Any], result: dict[str, t.Any] +) -> None: res = parse_line(line, **kwargs) print(repr(res)) assert res == result @pytest.mark.parametrize("line, kwargs, message", FAILURE_TEST_CASES) -def test_parse_line_failure(line, kwargs, message): +def test_parse_line_failure(line: str, kwargs: dict[str, t.Any], message: str) -> None: with 
pytest.raises(InvalidLogFmt) as exc: parse_line(line, **kwargs) diff --git a/tests/unit/plugins/module_utils/test__scramble.py b/tests/unit/plugins/module_utils/test__scramble.py index 29446a10..28683639 100644 --- a/tests/unit/plugins/module_utils/test__scramble.py +++ b/tests/unit/plugins/module_utils/test__scramble.py @@ -20,7 +20,7 @@ from ansible_collections.community.docker.plugins.module_utils._scramble import ("hello", b"\x01", "=S=aWRtbW4="), ], ) -def test_scramble_unscramble(plaintext, key, scrambled): +def test_scramble_unscramble(plaintext: str, key: bytes, scrambled: str) -> None: scrambled_ = scramble(plaintext, key) print(f"{scrambled_!r} == {scrambled!r}") assert scrambled_ == scrambled diff --git a/tests/unit/plugins/module_utils/test__util.py b/tests/unit/plugins/module_utils/test__util.py index 0390e411..88663677 100644 --- a/tests/unit/plugins/module_utils/test__util.py +++ b/tests/unit/plugins/module_utils/test__util.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.docker.plugins.module_utils._util import ( @@ -14,15 +16,41 @@ from ansible_collections.community.docker.plugins.module_utils._util import ( ) -DICT_ALLOW_MORE_PRESENT = ( +if t.TYPE_CHECKING: + + class DAMSpec(t.TypedDict): + av: dict[str, t.Any] + bv: dict[str, t.Any] + result: bool + + class Spec(t.TypedDict): + a: t.Any + b: t.Any + method: t.Literal["strict", "ignore", "allow_more_present"] + type: t.Literal["value", "list", "set", "set(dict)", "dict"] + result: bool + + +DICT_ALLOW_MORE_PRESENT: list[DAMSpec] = [ {"av": {}, "bv": {"a": 1}, "result": True}, {"av": {"a": 1}, "bv": {"a": 1, "b": 2}, "result": True}, {"av": {"a": 1}, "bv": {"b": 2}, "result": False}, {"av": {"a": 1}, "bv": {"a": None, "b": 1}, "result": False}, {"av": {"a": None}, "bv": {"b": 1}, "result": False}, -) +] -COMPARE_GENERIC = [ +DICT_ALLOW_MORE_PRESENT_SPECS: list[Spec] = [ + { + "a": entry["av"], + "b": entry["bv"], + 
"method": "allow_more_present", + "type": "dict", + "result": entry["result"], + } + for entry in DICT_ALLOW_MORE_PRESENT +] + +COMPARE_GENERIC: list[Spec] = [ ######################################################################################## # value {"a": 1, "b": 2, "method": "strict", "type": "value", "result": False}, @@ -386,43 +414,34 @@ COMPARE_GENERIC = [ "type": "dict", "result": True, }, -] + [ - { - "a": entry["av"], - "b": entry["bv"], - "method": "allow_more_present", - "type": "dict", - "result": entry["result"], - } - for entry in DICT_ALLOW_MORE_PRESENT ] @pytest.mark.parametrize("entry", DICT_ALLOW_MORE_PRESENT) -def test_dict_allow_more_present(entry): +def test_dict_allow_more_present(entry: DAMSpec) -> None: assert compare_dict_allow_more_present(entry["av"], entry["bv"]) == entry["result"] -@pytest.mark.parametrize("entry", COMPARE_GENERIC) -def test_compare_generic(entry): +@pytest.mark.parametrize("entry", COMPARE_GENERIC + DICT_ALLOW_MORE_PRESENT_SPECS) +def test_compare_generic(entry: Spec) -> None: assert ( compare_generic(entry["a"], entry["b"], entry["method"], entry["type"]) == entry["result"] ) -def test_convert_duration_to_nanosecond(): +def test_convert_duration_to_nanosecond() -> None: nanoseconds = convert_duration_to_nanosecond("5s") assert nanoseconds == 5000000000 nanoseconds = convert_duration_to_nanosecond("1m5s") assert nanoseconds == 65000000000 with pytest.raises(ValueError): - convert_duration_to_nanosecond([1, 2, 3]) + convert_duration_to_nanosecond([1, 2, 3]) # type: ignore with pytest.raises(ValueError): convert_duration_to_nanosecond("10x") -def test_parse_healthcheck(): +def test_parse_healthcheck() -> None: result, disabled = parse_healthcheck( { "test": "sleep 1", diff --git a/tests/unit/plugins/modules/test_docker_container_copy_into.py b/tests/unit/plugins/modules/test_docker_container_copy_into.py index f722f4e9..131ed660 100644 --- a/tests/unit/plugins/modules/test_docker_container_copy_into.py +++ 
b/tests/unit/plugins/modules/test_docker_container_copy_into.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.docker.plugins.modules.docker_container_copy_into import ( @@ -30,7 +32,7 @@ from ansible_collections.community.docker.plugins.modules.docker_container_copy_ ("-1", -1), ], ) -def test_parse_string(value, expected): +def test_parse_string(value: str, expected: int) -> None: assert parse_modern(value) == expected assert parse_octal_string_only(value) == expected @@ -45,10 +47,10 @@ def test_parse_string(value, expected): 123456789012345678901234567890123456789012345678901234567890, ], ) -def test_parse_int(value): +def test_parse_int(value: int) -> None: assert parse_modern(value) == value with pytest.raises(TypeError, match=f"^must be an octal string, got {value}L?$"): - parse_octal_string_only(value) + parse_octal_string_only(value) # type: ignore @pytest.mark.parametrize( @@ -60,7 +62,7 @@ def test_parse_int(value): {}, ], ) -def test_parse_bad_type(value): +def test_parse_bad_type(value: t.Any) -> None: with pytest.raises(TypeError, match="^must be an octal string or an integer, got "): parse_modern(value) with pytest.raises(TypeError, match="^must be an octal string, got "): @@ -75,7 +77,7 @@ def test_parse_bad_type(value): "9", ], ) -def test_parse_bad_value(value): +def test_parse_bad_value(value: str) -> None: with pytest.raises(ValueError): parse_modern(value) with pytest.raises(ValueError): diff --git a/tests/unit/plugins/modules/test_docker_image.py b/tests/unit/plugins/modules/test_docker_image.py index e4523f4c..b12a5591 100644 --- a/tests/unit/plugins/modules/test_docker_image.py +++ b/tests/unit/plugins/modules/test_docker_image.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.docker.plugins.module_utils._image_archive import ( @@ -19,19 +21,24 @@ from 
..test_support.docker_image_archive_stubbing import ( ) -def assert_no_logging(msg): +if t.TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + +def assert_no_logging(msg: str) -> t.NoReturn: raise AssertionError(f"Should not have logged anything but logged {msg}") -def capture_logging(messages): - def capture(msg): +def capture_logging(messages: list[str]) -> Callable[[str], None]: + def capture(msg: str) -> None: messages.append(msg) return capture @pytest.fixture -def tar_file_name(tmpdir): +def tar_file_name(tmpdir: t.Any) -> str: """ Return the name of a non-existing tar file in an existing temporary directory. """ @@ -39,7 +46,7 @@ def tar_file_name(tmpdir): return tmpdir.join("foo.tar") -def test_archived_image_action_when_missing(tar_file_name): +def test_archived_image_action_when_missing(tar_file_name: str) -> None: fake_name = "a:latest" fake_id = "a1" @@ -52,7 +59,7 @@ def test_archived_image_action_when_missing(tar_file_name): assert actual == expected -def test_archived_image_action_when_current(tar_file_name): +def test_archived_image_action_when_current(tar_file_name: str) -> None: fake_name = "b:latest" fake_id = "b2" @@ -65,7 +72,7 @@ def test_archived_image_action_when_current(tar_file_name): assert actual is None -def test_archived_image_action_when_invalid(tar_file_name): +def test_archived_image_action_when_invalid(tar_file_name: str) -> None: fake_name = "c:1.2.3" fake_id = "c3" @@ -73,7 +80,7 @@ def test_archived_image_action_when_invalid(tar_file_name): expected = f"Archived image {fake_name} to {tar_file_name}, overwriting an unreadable archive file" - actual_log = [] + actual_log: list[str] = [] actual = ImageManager.archived_image_action( capture_logging(actual_log), tar_file_name, fake_name, api_image_id(fake_id) ) @@ -84,7 +91,7 @@ def test_archived_image_action_when_invalid(tar_file_name): assert actual_log[0].startswith("Unable to extract manifest summary from archive") -def 
test_archived_image_action_when_obsolete_by_id(tar_file_name): +def test_archived_image_action_when_obsolete_by_id(tar_file_name: str) -> None: fake_name = "d:0.0.1" old_id = "e5" new_id = "d4" @@ -99,7 +106,7 @@ def test_archived_image_action_when_obsolete_by_id(tar_file_name): assert actual == expected -def test_archived_image_action_when_obsolete_by_name(tar_file_name): +def test_archived_image_action_when_obsolete_by_name(tar_file_name: str) -> None: old_name = "hi" new_name = "d:0.0.1" fake_id = "d4" diff --git a/tests/unit/plugins/modules/test_docker_image_build.py b/tests/unit/plugins/modules/test_docker_image_build.py index cddea0f6..9ca54380 100644 --- a/tests/unit/plugins/modules/test_docker_image_build.py +++ b/tests/unit/plugins/modules/test_docker_image_build.py @@ -21,5 +21,5 @@ from ansible_collections.community.docker.plugins.modules.docker_image_build imp ('\rhello, "hi" !\n', '"\rhello, ""hi"" !\n"'), ], ) -def test__quote_csv(value, expected): +def test__quote_csv(value: str, expected: str) -> None: assert _quote_csv(value) == expected diff --git a/tests/unit/plugins/modules/test_docker_network.py b/tests/unit/plugins/modules/test_docker_network.py index 3875a294..d548629d 100644 --- a/tests/unit/plugins/modules/test_docker_network.py +++ b/tests/unit/plugins/modules/test_docker_network.py @@ -6,6 +6,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.docker.plugins.modules.docker_network import ( @@ -23,7 +25,9 @@ from ansible_collections.community.docker.plugins.modules.docker_network import ("fdd1:ac8c:0557:7ce2::/128", "ipv6"), ], ) -def test_validate_cidr_positives(cidr, expected): +def test_validate_cidr_positives( + cidr: str, expected: t.Literal["ipv4", "ipv6"] +) -> None: assert validate_cidr(cidr) == expected @@ -36,7 +40,7 @@ def test_validate_cidr_positives(cidr, expected): "fdd1:ac8c:0557:7ce2::", ], ) -def test_validate_cidr_negatives(cidr): +def 
test_validate_cidr_negatives(cidr: str) -> None: with pytest.raises(ValueError) as e: validate_cidr(cidr) assert f'"{cidr}" is not a valid CIDR' == str(e.value) diff --git a/tests/unit/plugins/modules/test_docker_swarm_service.py b/tests/unit/plugins/modules/test_docker_swarm_service.py index 557c29b6..725976fb 100644 --- a/tests/unit/plugins/modules/test_docker_swarm_service.py +++ b/tests/unit/plugins/modules/test_docker_swarm_service.py @@ -4,66 +4,47 @@ from __future__ import annotations +import typing as t + import pytest - -class APIErrorMock(Exception): - def __init__(self, message, response=None, explanation=None): - self.message = message - self.response = response - self.explanation = explanation +from ansible_collections.community.docker.plugins.modules import ( + docker_swarm_service, +) -@pytest.fixture(autouse=True) -def docker_module_mock(mocker): - docker_module_mock = mocker.MagicMock() - docker_utils_module_mock = mocker.MagicMock() - docker_errors_module_mock = mocker.MagicMock() - docker_errors_module_mock.APIError = APIErrorMock - mock_modules = { - "docker": docker_module_mock, - "docker.utils": docker_utils_module_mock, - "docker.errors": docker_errors_module_mock, - } - return mocker.patch.dict("sys.modules", **mock_modules) +APIError = pytest.importorskip("docker.errors").APIError -@pytest.fixture(autouse=True) -def docker_swarm_service(): - from ansible_collections.community.docker.plugins.modules import ( - docker_swarm_service, - ) - - return docker_swarm_service - - -def test_retry_on_out_of_sequence_error(mocker, docker_swarm_service): +def test_retry_on_out_of_sequence_error(mocker: t.Any) -> None: run_mock = mocker.MagicMock( - side_effect=APIErrorMock( + side_effect=APIError( message="", response=None, explanation="rpc error: code = Unknown desc = update out of sequence", ) ) - manager = docker_swarm_service.DockerServiceManager(client=None) - manager.run = run_mock - with pytest.raises(APIErrorMock): + mocker.patch("time.sleep") + 
manager = docker_swarm_service.DockerServiceManager(client=None) # type: ignore + manager.run = run_mock # type: ignore + with pytest.raises(APIError): manager.run_safe() assert run_mock.call_count == 3 -def test_no_retry_on_general_api_error(mocker, docker_swarm_service): +def test_no_retry_on_general_api_error(mocker: t.Any) -> None: run_mock = mocker.MagicMock( - side_effect=APIErrorMock(message="", response=None, explanation="some error") + side_effect=APIError(message="", response=None, explanation="some error") ) - manager = docker_swarm_service.DockerServiceManager(client=None) - manager.run = run_mock - with pytest.raises(APIErrorMock): + mocker.patch("time.sleep") + manager = docker_swarm_service.DockerServiceManager(client=None) # type: ignore + manager.run = run_mock # type: ignore + with pytest.raises(APIError): manager.run_safe() assert run_mock.call_count == 1 -def test_get_docker_environment(mocker, docker_swarm_service): +def test_get_docker_environment(mocker: t.Any) -> None: env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"} env_dict = {"TEST3": "CC", "TEST4": "D"} env_string = "TEST3=CC,TEST4=D" @@ -103,7 +84,7 @@ def test_get_docker_environment(mocker, docker_swarm_service): assert result == [] -def test_get_nanoseconds_from_raw_option(docker_swarm_service): +def test_get_nanoseconds_from_raw_option() -> None: value = docker_swarm_service.get_nanoseconds_from_raw_option("test", None) assert value is None @@ -117,7 +98,7 @@ def test_get_nanoseconds_from_raw_option(docker_swarm_service): docker_swarm_service.get_nanoseconds_from_raw_option("test", []) -def test_has_dict_changed(docker_swarm_service): +def test_has_dict_changed() -> None: assert not docker_swarm_service.has_dict_changed( {"a": 1}, {"a": 1}, @@ -135,8 +116,7 @@ def test_has_dict_changed(docker_swarm_service): assert not docker_swarm_service.has_dict_changed(None, {}) -def test_has_list_changed(docker_swarm_service): - +def test_has_list_changed() -> None: # List 
comparisons without dictionaries # I could improve the indenting, but pycodestyle wants this instead assert not docker_swarm_service.has_list_changed(None, None) @@ -161,7 +141,7 @@ def test_has_list_changed(docker_swarm_service): assert docker_swarm_service.has_list_changed([None, 1], [2, 1]) assert docker_swarm_service.has_list_changed([2, 1], [None, 1]) assert docker_swarm_service.has_list_changed( - "command --with args", ["command", "--with", "args"] + ["command --with args"], ["command", "--with", "args"] ) assert docker_swarm_service.has_list_changed( ["sleep", "3400"], ["sleep", "3600"], sort_lists=False @@ -259,7 +239,7 @@ def test_has_list_changed(docker_swarm_service): ) -def test_have_networks_changed(docker_swarm_service): +def test_have_networks_changed() -> None: assert not docker_swarm_service.have_networks_changed(None, None) assert not docker_swarm_service.have_networks_changed([], None) @@ -329,14 +309,14 @@ def test_have_networks_changed(docker_swarm_service): ) -def test_get_docker_networks(docker_swarm_service): +def test_get_docker_networks() -> None: network_names = [ "network_1", "network_2", "network_3", "network_4", ] - networks = [ + networks: list[str | dict[str, t.Any]] = [ network_names[0], {"name": network_names[1]}, {"name": network_names[2], "aliases": ["networkalias1"]}, @@ -367,28 +347,27 @@ def test_get_docker_networks(docker_swarm_service): assert "foo" in network["options"] # Test missing name with pytest.raises(TypeError): - docker_swarm_service.get_docker_networks([{"invalid": "err"}], {"err": 1}) + docker_swarm_service.get_docker_networks([{"invalid": "err"}], {"err": "x"}) # test for invalid aliases type with pytest.raises(TypeError): docker_swarm_service.get_docker_networks( - [{"name": "test", "aliases": 1}], {"test": 1} + [{"name": "test", "aliases": 1}], {"test": "x"} ) # Test invalid aliases elements with pytest.raises(TypeError): docker_swarm_service.get_docker_networks( - [{"name": "test", "aliases": [1]}], {"test": 
1} + [{"name": "test", "aliases": [1]}], {"test": "x"} ) # Test for invalid options type with pytest.raises(TypeError): docker_swarm_service.get_docker_networks( - [{"name": "test", "options": 1}], {"test": 1} + [{"name": "test", "options": 1}], {"test": "x"} ) - # Test for invalid networks type - with pytest.raises(TypeError): - docker_swarm_service.get_docker_networks(1, {"test": 1}) # Test for non existing networks with pytest.raises(ValueError): - docker_swarm_service.get_docker_networks([{"name": "idontexist"}], {"test": 1}) + docker_swarm_service.get_docker_networks( + [{"name": "idontexist"}], {"test": "x"} + ) # Test empty values assert docker_swarm_service.get_docker_networks([], {}) == [] assert docker_swarm_service.get_docker_networks(None, {}) is None diff --git a/tests/unit/plugins/plugin_utils/test__unsafe.py b/tests/unit/plugins/plugin_utils/test__unsafe.py index e2de9f8a..99c93b7d 100644 --- a/tests/unit/plugins/plugin_utils/test__unsafe.py +++ b/tests/unit/plugins/plugin_utils/test__unsafe.py @@ -4,6 +4,8 @@ from __future__ import annotations +import typing as t + import pytest from ansible_collections.community.internal_test_tools.tests.unit.utils.trust import ( SUPPORTS_DATA_TAGGING, @@ -23,7 +25,9 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import ( ) -TEST_MAKE_UNSAFE = [ +TEST_MAKE_UNSAFE: list[ + tuple[t.Any, list[tuple[t.Any, ...]], list[tuple[t.Any, ...]]] +] = [ ( _make_trusted("text"), [], @@ -97,7 +101,11 @@ if not SUPPORTS_DATA_TAGGING: @pytest.mark.parametrize( "value, check_unsafe_paths, check_safe_paths", TEST_MAKE_UNSAFE ) -def test_make_unsafe(value, check_unsafe_paths, check_safe_paths): +def test_make_unsafe( + value: t.Any, + check_unsafe_paths: list[tuple[t.Any, ...]], + check_safe_paths: list[tuple[t.Any, ...]], +) -> None: unsafe_value = make_unsafe(value) assert unsafe_value == value for check_path in check_unsafe_paths: @@ -112,7 +120,7 @@ def test_make_unsafe(value, check_unsafe_paths, 
check_safe_paths): assert _is_trusted(obj) -def test_make_unsafe_idempotence(): +def test_make_unsafe_idempotence() -> None: assert make_unsafe(None) is None unsafe_str = _make_untrusted("{{test}}") @@ -122,8 +130,8 @@ def test_make_unsafe_idempotence(): assert id(make_unsafe(safe_str)) != id(safe_str) -def test_make_unsafe_dict_key(): - value = { +def test_make_unsafe_dict_key() -> None: + value: dict[t.Any, t.Any] = { _make_trusted("test"): 2, } if not SUPPORTS_DATA_TAGGING: @@ -144,8 +152,8 @@ def test_make_unsafe_dict_key(): assert not _is_trusted(obj) -def test_make_unsafe_set(): - value = set([_make_trusted("test")]) +def test_make_unsafe_set() -> None: + value: set[t.Any] = set([_make_trusted("test")]) if not SUPPORTS_DATA_TAGGING: value.add(_make_trusted(b"test")) unsafe_value = make_unsafe(value) diff --git a/tests/unit/plugins/test_support/docker_image_archive_stubbing.py b/tests/unit/plugins/test_support/docker_image_archive_stubbing.py index 06bd1a42..38736805 100644 --- a/tests/unit/plugins/test_support/docker_image_archive_stubbing.py +++ b/tests/unit/plugins/test_support/docker_image_archive_stubbing.py @@ -6,10 +6,13 @@ from __future__ import annotations import json import tarfile +import typing as t from tempfile import TemporaryFile -def write_imitation_archive(file_name, image_id, repo_tags): +def write_imitation_archive( + file_name: str, image_id: str, repo_tags: list[str] +) -> None: """ Write a tar file meeting these requirements: @@ -21,7 +24,7 @@ def write_imitation_archive(file_name, image_id, repo_tags): :type file_name: str :param image_id: Fake sha256 hash (without the sha256: prefix) :type image_id: str - :param repo_tags: list of fake image:tag's + :param repo_tags: list of fake image tags :type repo_tags: list """ @@ -30,7 +33,9 @@ def write_imitation_archive(file_name, image_id, repo_tags): write_imitation_archive_with_manifest(file_name, manifest) -def write_imitation_archive_with_manifest(file_name, manifest): +def 
write_imitation_archive_with_manifest( + file_name: str, manifest: list[dict[str, t.Any]] +) -> None: with tarfile.open(file_name, "w") as tf: with TemporaryFile() as f: f.write(json.dumps(manifest).encode("utf-8")) @@ -42,7 +47,7 @@ def write_imitation_archive_with_manifest(file_name, manifest): tf.addfile(ti, f) -def write_irrelevant_tar(file_name): +def write_irrelevant_tar(file_name: str) -> None: """ Create a tar file that does not match the spec for "docker image save" / "docker image load" commands. diff --git a/tests/unit/requirements.txt b/tests/unit/requirements.txt index 3b01b8aa..755590c7 100644 --- a/tests/unit/requirements.txt +++ b/tests/unit/requirements.txt @@ -2,4 +2,5 @@ # GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) # SPDX-License-Identifier: GPL-3.0-or-later +docker requests