From 08960a93178cd955f2434d379206eb3032b77ada Mon Sep 17 00:00:00 2001 From: Felix Fontein Date: Tue, 21 Oct 2025 06:25:30 +0200 Subject: [PATCH] Add more type hints. --- plugins/module_utils/_api/api/client.py | 511 +++++++++++++++--- plugins/module_utils/_api/api/daemon.py | 139 ----- plugins/module_utils/_api/auth.py | 90 ++- plugins/module_utils/_api/context/api.py | 35 +- plugins/module_utils/_api/context/config.py | 21 +- plugins/module_utils/_api/context/context.py | 81 +-- .../module_utils/_api/credentials/errors.py | 8 +- .../module_utils/_api/credentials/store.py | 17 +- .../module_utils/_api/credentials/utils.py | 4 +- plugins/module_utils/_api/errors.py | 60 +- plugins/module_utils/_api/tls.py | 35 +- .../_api/transport/basehttpadapter.py | 6 +- .../module_utils/_api/transport/npipeconn.py | 25 +- .../_api/transport/npipesocket.py | 102 ++-- .../module_utils/_api/transport/sshconn.py | 108 ++-- .../module_utils/_api/transport/ssladapter.py | 33 +- .../module_utils/_api/transport/unixconn.py | 35 +- plugins/module_utils/_api/types/daemon.py | 7 +- plugins/module_utils/_api/utils/build.py | 58 +- plugins/module_utils/_api/utils/config.py | 11 +- plugins/module_utils/_api/utils/decorators.py | 38 +- plugins/module_utils/_api/utils/fnmatch.py | 8 +- .../module_utils/_api/utils/json_stream.py | 27 +- plugins/module_utils/_api/utils/ports.py | 90 ++- plugins/module_utils/_api/utils/proxy.py | 16 +- plugins/module_utils/_api/utils/socket.py | 44 +- plugins/module_utils/_api/utils/utils.py | 143 ++--- plugins/module_utils/_copy.py | 2 +- 28 files changed, 1104 insertions(+), 650 deletions(-) delete mode 100644 plugins/module_utils/_api/api/daemon.py diff --git a/plugins/module_utils/_api/api/client.py b/plugins/module_utils/_api/api/client.py index af5d1db0..d1bad664 100644 --- a/plugins/module_utils/_api/api/client.py +++ b/plugins/module_utils/_api/api/client.py @@ -13,8 +13,9 @@ from __future__ import annotations import json import logging +import os import struct -from functools import partial +import typing as t from urllib.parse import quote from .. import auth @@ -50,13 +51,13 @@ from ..utils import config, json_stream, utils from ..utils.decorators import update_headers from ..utils.proxy import ProxyConfig from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter -from .daemon import DaemonApiMixin +from ..utils.decorators import minimum_version log = logging.getLogger(__name__) -class APIClient(_Session, DaemonApiMixin): +class APIClient(_Session): """ A low-level client for the Docker Engine API. @@ -105,16 +106,16 @@ class APIClient(_Session, DaemonApiMixin): def __init__( self, - base_url=None, - version=None, - timeout=DEFAULT_TIMEOUT_SECONDS, - tls=False, - user_agent=DEFAULT_USER_AGENT, - num_pools=None, - credstore_env=None, - use_ssh_client=False, - max_pool_size=DEFAULT_MAX_POOL_SIZE, - ): + base_url: str | None = None, + version: str | None = None, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + tls: bool | TLSConfig = False, + user_agent: str = DEFAULT_USER_AGENT, + num_pools: int | None = None, + credstore_env: dict[str, str] | None = None, + use_ssh_client: bool = False, + max_pool_size: int = DEFAULT_MAX_POOL_SIZE, + ) -> None: super().__init__() fail_on_missing_imports() @@ -152,6 +153,9 @@ class APIClient(_Session, DaemonApiMixin): else DEFAULT_NUM_POOLS ) + self._custom_adapter: ( + UnixHTTPAdapter | NpipeHTTPAdapter | SSHHTTPAdapter | SSLHTTPAdapter | None + ) = None if base_url.startswith("http+unix://"): self._custom_adapter = UnixHTTPAdapter( base_url, @@ -223,7 +227,7 @@ class APIClient(_Session, DaemonApiMixin): f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library." ) - def _retrieve_server_version(self): + def _retrieve_server_version(self) -> str: try: version_result = self.version(api_version=False) except Exception as e: @@ -242,54 +246,87 @@ class APIClient(_Session, DaemonApiMixin): f"Error while fetching server API version: {e}. Response seems to be broken." ) from e - def _set_request_timeout(self, kwargs): + def _set_request_timeout(self, kwargs: dict[str, t.Any]) -> dict[str, t.Any]: """Prepare the kwargs for an HTTP request by inserting the timeout parameter, if not already present.""" kwargs.setdefault("timeout", self.timeout) return kwargs @update_headers - def _post(self, url, **kwargs): + def _post(self, url: str, **kwargs): return self.post(url, **self._set_request_timeout(kwargs)) @update_headers - def _get(self, url, **kwargs): + def _get(self, url: str, **kwargs): return self.get(url, **self._set_request_timeout(kwargs)) @update_headers - def _head(self, url, **kwargs): + def _head(self, url: str, **kwargs): return self.head(url, **self._set_request_timeout(kwargs)) @update_headers - def _put(self, url, **kwargs): + def _put(self, url: str, **kwargs): return self.put(url, **self._set_request_timeout(kwargs)) @update_headers - def _delete(self, url, **kwargs): + def _delete(self, url: str, **kwargs): return self.delete(url, **self._set_request_timeout(kwargs)) - def _url(self, pathfmt, *args, **kwargs): + def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str: for arg in args: if not isinstance(arg, str): raise ValueError( f"Expected a string but found {arg} ({type(arg)}) instead" ) - quote_f = partial(quote, safe="/:") - args = map(quote_f, args) + q_args = [quote(arg, safe="/:") for arg in args] - if kwargs.get("versioned_api", True): - return f"{self.base_url}/v{self._version}{pathfmt.format(*args)}" - return f"{self.base_url}{pathfmt.format(*args)}" + if versioned_api: + return f"{self.base_url}/v{self._version}{pathfmt.format(*q_args)}" + return f"{self.base_url}{pathfmt.format(*q_args)}" - def _raise_for_status(self, response): + def _raise_for_status(self, response) -> None: """Raises stored :class:`APIError`, if one occurred.""" try: response.raise_for_status() except _HTTPError as e: create_api_error_from_http_exception(e) - def _result(self, response, get_json=False, get_binary=False): + @t.overload + def _result( + self, + response, + *, + get_json: t.Literal[False] = False, + get_binary: t.Literal[False] = False, + ) -> str: ... + + @t.overload + def _result( + self, + response, + *, + get_json: t.Literal[True], + get_binary: t.Literal[False] = False, + ) -> t.Any: ... + + @t.overload + def _result( + self, + response, + *, + get_json: t.Literal[False] = False, + get_binary: t.Literal[True], + ) -> bytes: ... + + @t.overload + def _result( + self, response, *, get_json: bool = False, get_binary: bool = False + ) -> t.Any | str | bytes: ... + + def _result( + self, response, *, get_json: bool = False, get_binary: bool = False + ) -> t.Any | str | bytes: if get_json and get_binary: raise AssertionError("json and binary must not be both True") self._raise_for_status(response) @@ -300,10 +337,10 @@ class APIClient(_Session, DaemonApiMixin): return response.content return response.text - def _post_json(self, url, data, **kwargs): + def _post_json(self, url: str, data: dict[str, str | None] | t.Any, **kwargs): # Go <1.1 cannot unserialize null to a string # so we do this disgusting thing here. - data2 = {} + data2: dict[str, t.Any] = {} if data is not None and isinstance(data, dict): for k, v in data.items(): if v is not None: @@ -316,7 +353,7 @@ class APIClient(_Session, DaemonApiMixin): kwargs["headers"]["Content-Type"] = "application/json" return self._post(url, data=json.dumps(data2), **kwargs) - def _attach_params(self, override=None): + def _attach_params(self, override: dict[str, int] | None = None) -> dict[str, int]: return override or {"stdout": 1, "stderr": 1, "stream": 1} def _get_raw_response_socket(self, response): @@ -341,12 +378,24 @@ class APIClient(_Session, DaemonApiMixin): return sock - def _stream_helper(self, response, decode=False): + @t.overload + def _stream_helper( + self, response, *, decode: t.Literal[False] = False + ) -> t.Generator[bytes]: ... + + @t.overload + def _stream_helper( + self, response, *, decode: t.Literal[True] + ) -> t.Generator[t.Any]: ... + + def _stream_helper(self, response, *, decode: bool = False) -> t.Generator[t.Any]: """Generator for data coming from a chunked-encoded HTTP response.""" if response.raw._fp.chunked: if decode: - yield from json_stream.json_stream(self._stream_helper(response, False)) + yield from json_stream.json_stream( + self._stream_helper(response, decode=False) + ) else: reader = response.raw while not reader.closed: @@ -362,7 +411,7 @@ class APIClient(_Session, DaemonApiMixin): # encountered an error immediately yield self._result(response, get_json=decode) - def _multiplexed_buffer_helper(self, response): + def _multiplexed_buffer_helper(self, response) -> t.Generator[bytes]: """A generator of multiplexed data blocks read from a buffered response.""" buf = self._result(response, get_binary=True) @@ -378,7 +427,7 @@ class APIClient(_Session, DaemonApiMixin): walker = end yield buf[start:end] - def _multiplexed_response_stream_helper(self, response): + def _multiplexed_response_stream_helper(self, response) -> t.Generator[bytes]: """A generator of multiplexed data blocks coming from a response stream.""" @@ -399,7 +448,19 @@ class APIClient(_Session, DaemonApiMixin): break yield data - def _stream_raw_result(self, response, chunk_size=1, decode=True): + @t.overload + def _stream_raw_result( + self, response, *, chunk_size: int = 1, decode: t.Literal[True] = True + ) -> t.Generator[str]: ... + + @t.overload + def _stream_raw_result( + self, response, *, chunk_size: int = 1, decode: t.Literal[False] + ) -> t.Generator[bytes]: ... + + def _stream_raw_result( + self, response, *, chunk_size: int = 1, decode: bool = True + ) -> t.Generator[str | bytes]: """Stream result for TTY-enabled container and raw binary data""" self._raise_for_status(response) @@ -410,14 +471,81 @@ class APIClient(_Session, DaemonApiMixin): yield from response.iter_content(chunk_size, decode) - def _read_from_socket(self, response, stream, tty=True, demux=False): + @t.overload + def _read_from_socket( + self, + response, + *, + stream: t.Literal[True], + tty: bool = True, + demux: t.Literal[False] = False, + ) -> t.Generator[bytes]: ... + + @t.overload + def _read_from_socket( + self, + response, + *, + stream: t.Literal[True], + tty: t.Literal[True] = True, + demux: t.Literal[True], + ) -> t.Generator[tuple[bytes, None]]: ... + + @t.overload + def _read_from_socket( + self, + response, + *, + stream: t.Literal[True], + tty: t.Literal[False], + demux: t.Literal[True], + ) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ... + + @t.overload + def _read_from_socket( + self, + response, + *, + stream: t.Literal[False], + tty: bool = True, + demux: t.Literal[False] = False, + ) -> bytes: ... + + @t.overload + def _read_from_socket( + self, + response, + *, + stream: t.Literal[False], + tty: t.Literal[True] = True, + demux: t.Literal[True], + ) -> tuple[bytes, None]: ... + + @t.overload + def _read_from_socket( + self, + response, + *, + stream: t.Literal[False], + tty: t.Literal[False], + demux: t.Literal[True], + ) -> tuple[bytes, bytes]: ... + + @t.overload + def _read_from_socket( + self, response, *, stream: bool, tty: bool = True, demux: bool = False + ) -> t.Any: ... + + def _read_from_socket( + self, response, *, stream: bool, tty: bool = True, demux: bool = False + ) -> t.Any: """Consume all data from the socket, close the response and return the data. If stream=True, then a generator is returned instead and the caller is responsible for closing the response. """ socket = self._get_raw_response_socket(response) - gen = frames_iter(socket, tty) + gen: t.Generator = frames_iter(socket, tty) if demux: # The generator will output tuples (stdout, stderr) @@ -434,7 +562,7 @@ class APIClient(_Session, DaemonApiMixin): finally: response.close() - def _disable_socket_timeout(self, socket): + def _disable_socket_timeout(self, socket) -> None: """Depending on the combination of python version and whether we are connecting over http or https, we might need to access _sock, which may or may not exist; or we may need to just settimeout on socket @@ -462,7 +590,27 @@ class APIClient(_Session, DaemonApiMixin): s.settimeout(None) - def _get_result_tty(self, stream, res, is_tty): + @t.overload + def _get_result_tty( + self, stream: t.Literal[True], res, is_tty: t.Literal[True] + ) -> t.Generator[str]: ... + + @t.overload + def _get_result_tty( + self, stream: t.Literal[True], res, is_tty: t.Literal[False] + ) -> t.Generator[bytes]: ... + + @t.overload + def _get_result_tty( + self, stream: t.Literal[False], res, is_tty: t.Literal[True] + ) -> bytes: ... + + @t.overload + def _get_result_tty( + self, stream: t.Literal[False], res, is_tty: t.Literal[False] + ) -> bytes: ... + + def _get_result_tty(self, stream: bool, res, is_tty: bool) -> t.Any: # We should also use raw streaming (without keep-alive) # if we are dealing with a tty-enabled container. if is_tty: @@ -478,11 +626,11 @@ class APIClient(_Session, DaemonApiMixin): return self._multiplexed_response_stream_helper(res) return sep.join(list(self._multiplexed_buffer_helper(res))) - def _unmount(self, *args): + def _unmount(self, *args) -> None: for proto in args: self.adapters.pop(proto) - def get_adapter(self, url): + def get_adapter(self, url: str): try: return super().get_adapter(url) except _InvalidSchema as e: @@ -491,10 +639,10 @@ class APIClient(_Session, DaemonApiMixin): raise e @property - def api_version(self): + def api_version(self) -> str: return self._version - def reload_config(self, dockercfg_path=None): + def reload_config(self, dockercfg_path: str | None = None) -> None: """ Force a reload of the auth configuration @@ -510,7 +658,7 @@ class APIClient(_Session, DaemonApiMixin): dockercfg_path, credstore_env=self.credstore_env ) - def _set_auth_headers(self, headers): + def _set_auth_headers(self, headers: dict[str, str | bytes]) -> None: log.debug("Looking for auth config") # If we do not have any auth data so far, try reloading the config @@ -537,57 +685,62 @@ class APIClient(_Session, DaemonApiMixin): else: log.debug("No auth config found") - def get_binary(self, pathfmt, *args, **kwargs): + def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes: return self._result( self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), get_binary=True, ) - def get_json(self, pathfmt, *args, **kwargs): + def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: return self._result( self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), get_json=True, ) - def get_text(self, pathfmt, *args, **kwargs): + def get_text(self, pathfmt: str, *args: str, **kwargs) -> str: return self._result( self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs) ) - def get_raw_stream(self, pathfmt, *args, **kwargs): - chunk_size = kwargs.pop("chunk_size", DEFAULT_DATA_CHUNK_SIZE) + def get_raw_stream( + self, + pathfmt: str, + *args: str, + chunk_size: int = DEFAULT_DATA_CHUNK_SIZE, + **kwargs, + ) -> t.Generator[bytes]: res = self._get( self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs ) self._raise_for_status(res) - return self._stream_raw_result(res, chunk_size, False) + return self._stream_raw_result(res, chunk_size=chunk_size, decode=False) - def delete_call(self, pathfmt, *args, **kwargs): + def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None: self._raise_for_status( self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs) ) - def delete_json(self, pathfmt, *args, **kwargs): + def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: return self._result( self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs), get_json=True, ) - def post_call(self, pathfmt, *args, **kwargs): + def post_call(self, pathfmt: str, *args: str, **kwargs) -> None: self._raise_for_status( self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs) ) - def post_json(self, pathfmt, *args, **kwargs): - data = kwargs.pop("data", None) + def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None: self._raise_for_status( self._post_json( self._url(pathfmt, *args, versioned_api=True), data, **kwargs ) ) - def post_json_to_binary(self, pathfmt, *args, **kwargs): - data = kwargs.pop("data", None) + def post_json_to_binary( + self, pathfmt: str, *args: str, data: t.Any = None, **kwargs + ) -> bytes: return self._result( self._post_json( self._url(pathfmt, *args, versioned_api=True), data, **kwargs @@ -595,8 +748,9 @@ class APIClient(_Session, DaemonApiMixin): get_binary=True, ) - def post_json_to_json(self, pathfmt, *args, **kwargs): - data = kwargs.pop("data", None) + def post_json_to_json( + self, pathfmt: str, *args: str, data: t.Any = None, **kwargs + ) -> t.Any: return self._result( self._post_json( self._url(pathfmt, *args, versioned_api=True), data, **kwargs @@ -604,17 +758,24 @@ class APIClient(_Session, DaemonApiMixin): get_json=True, ) - def post_json_to_text(self, pathfmt, *args, **kwargs): - data = kwargs.pop("data", None) + def post_json_to_text( + self, pathfmt: str, *args: str, data: t.Any = None, **kwargs + ) -> str: return self._result( self._post_json( self._url(pathfmt, *args, versioned_api=True), data, **kwargs ), ) - def post_json_to_stream_socket(self, pathfmt, *args, **kwargs): - data = kwargs.pop("data", None) - headers = (kwargs.pop("headers", None) or {}).copy() + def post_json_to_stream_socket( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + **kwargs, + ): + headers = headers.copy() if headers else {} headers.update( { "Connection": "Upgrade", @@ -631,18 +792,102 @@ class APIClient(_Session, DaemonApiMixin): ) ) - def post_json_to_stream(self, pathfmt, *args, **kwargs): - data = kwargs.pop("data", None) - headers = (kwargs.pop("headers", None) or {}).copy() + @t.overload + def post_json_to_stream( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + stream: t.Literal[True], + tty: bool = True, + demux: t.Literal[False] = False, + **kwargs, + ) -> t.Generator[bytes]: ... + + @t.overload + def post_json_to_stream( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + stream: t.Literal[True], + tty: t.Literal[True] = True, + demux: t.Literal[True], + **kwargs, + ) -> t.Generator[tuple[bytes, None]]: ... + + @t.overload + def post_json_to_stream( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + stream: t.Literal[True], + tty: t.Literal[False], + demux: t.Literal[True], + **kwargs, + ) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ... + + @t.overload + def post_json_to_stream( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + stream: t.Literal[False], + tty: bool = True, + demux: t.Literal[False] = False, + **kwargs, + ) -> bytes: ... + + @t.overload + def post_json_to_stream( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + stream: t.Literal[False], + tty: t.Literal[True] = True, + demux: t.Literal[True], + **kwargs, + ) -> tuple[bytes, None]: ... + + @t.overload + def post_json_to_stream( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + stream: t.Literal[False], + tty: t.Literal[False], + demux: t.Literal[True], + **kwargs, + ) -> tuple[bytes, bytes]: ... + + def post_json_to_stream( + self, + pathfmt: str, + *args: str, + data: t.Any = None, + headers: dict[str, str] | None = None, + stream: bool = False, + demux: bool = False, + tty: bool = False, + **kwargs, + ) -> t.Any: + headers = headers.copy() if headers else {} headers.update( { "Connection": "Upgrade", "Upgrade": "tcp", } ) - stream = kwargs.pop("stream", False) - demux = kwargs.pop("demux", False) - tty = kwargs.pop("tty", False) return self._read_from_socket( self._post_json( self._url(pathfmt, *args, versioned_api=True), @@ -651,13 +896,133 @@ class APIClient(_Session, DaemonApiMixin): stream=True, **kwargs, ), - stream, + stream=stream, tty=tty, demux=demux, ) - def post_to_json(self, pathfmt, *args, **kwargs): + def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: return self._result( self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs), get_json=True, ) + + @minimum_version("1.25") + def df(self) -> dict[str, t.Any]: + """ + Get data usage information. + + Returns: + (dict): A dictionary representing different resource categories + and their respective data usage. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + url = self._url("/system/df") + return self._result(self._get(url), get_json=True) + + def info(self) -> dict[str, t.Any]: + """ + Display system-wide information. Identical to the ``docker info`` + command. + + Returns: + (dict): The info as a dict + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + return self._result(self._get(self._url("/info")), get_json=True) + + def login( + self, + username: str, + password: str | None = None, + email: str | None = None, + registry: str | None = None, + reauth: bool = False, + dockercfg_path: str | None = None, + ) -> dict[str, t.Any]: + """ + Authenticate with a registry. Similar to the ``docker login`` command. + + Args: + username (str): The registry username + password (str): The plaintext password + email (str): The email for the registry account + registry (str): URL to the registry. E.g. + ``https://index.docker.io/v1/`` + reauth (bool): Whether or not to refresh existing authentication on + the Docker server. + dockercfg_path (str): Use a custom path for the Docker config file + (default ``$HOME/.docker/config.json`` if present, + otherwise ``$HOME/.dockercfg``) + + Returns: + (dict): The response from the login request + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + + # If we do not have any auth data so far, try reloading the config file + # one more time in case anything showed up in there. + # If dockercfg_path is passed check to see if the config file exists, + # if so load that config. + if dockercfg_path and os.path.exists(dockercfg_path): + self._auth_configs = auth.load_config( + dockercfg_path, credstore_env=self.credstore_env + ) + elif not self._auth_configs or self._auth_configs.is_empty: + self._auth_configs = auth.load_config(credstore_env=self.credstore_env) + + authcfg = self._auth_configs.resolve_authconfig(registry) + # If we found an existing auth config for this registry and username + # combination, we can return it immediately unless reauth is requested. + if authcfg and authcfg.get("username", None) == username and not reauth: + return authcfg + + req_data = { + "username": username, + "password": password, + "email": email, + "serveraddress": registry, + } + + response = self._post_json(self._url("/auth"), data=req_data) + if response.status_code == 200: + self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data) + return self._result(response, get_json=True) + + def ping(self) -> bool: + """ + Checks the server is responsive. An exception will be raised if it + is not responding. + + Returns: + (bool) The response from the server. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + return self._result(self._get(self._url("/_ping"))) == "OK" + + def version(self, api_version: bool = True) -> dict[str, t.Any]: + """ + Returns version information from the server. Similar to the ``docker + version`` command. + + Returns: + (dict): The server version information + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + url = self._url("/version", versioned_api=api_version) + return self._result(self._get(url), get_json=True) diff --git a/plugins/module_utils/_api/api/daemon.py b/plugins/module_utils/_api/api/daemon.py deleted file mode 100644 index d16118d7..00000000 --- a/plugins/module_utils/_api/api/daemon.py +++ /dev/null @@ -1,139 +0,0 @@ -# This code is part of the Ansible collection community.docker, but is an independent component. -# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/) -# -# Copyright (c) 2016-2022 Docker, Inc. -# -# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection) -# SPDX-License-Identifier: Apache-2.0 - -# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time. -# Do not use this from other collections or standalone plugins/modules! - -from __future__ import annotations - -import os - -from .. import auth -from ..utils.decorators import minimum_version - - -class DaemonApiMixin: - @minimum_version("1.25") - def df(self): - """ - Get data usage information. - - Returns: - (dict): A dictionary representing different resource categories - and their respective data usage. - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - url = self._url("/system/df") - return self._result(self._get(url), get_json=True) - - def info(self): - """ - Display system-wide information. Identical to the ``docker info`` - command. - - Returns: - (dict): The info as a dict - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - return self._result(self._get(self._url("/info")), get_json=True) - - def login( - self, - username, - password=None, - email=None, - registry=None, - reauth=False, - dockercfg_path=None, - ): - """ - Authenticate with a registry. Similar to the ``docker login`` command. - - Args: - username (str): The registry username - password (str): The plaintext password - email (str): The email for the registry account - registry (str): URL to the registry. E.g. - ``https://index.docker.io/v1/`` - reauth (bool): Whether or not to refresh existing authentication on - the Docker server. - dockercfg_path (str): Use a custom path for the Docker config file - (default ``$HOME/.docker/config.json`` if present, - otherwise ``$HOME/.dockercfg``) - - Returns: - (dict): The response from the login request - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - - # If we do not have any auth data so far, try reloading the config file - # one more time in case anything showed up in there. - # If dockercfg_path is passed check to see if the config file exists, - # if so load that config. - if dockercfg_path and os.path.exists(dockercfg_path): - self._auth_configs = auth.load_config( - dockercfg_path, credstore_env=self.credstore_env - ) - elif not self._auth_configs or self._auth_configs.is_empty: - self._auth_configs = auth.load_config(credstore_env=self.credstore_env) - - authcfg = self._auth_configs.resolve_authconfig(registry) - # If we found an existing auth config for this registry and username - # combination, we can return it immediately unless reauth is requested. - if authcfg and authcfg.get("username", None) == username and not reauth: - return authcfg - - req_data = { - "username": username, - "password": password, - "email": email, - "serveraddress": registry, - } - - response = self._post_json(self._url("/auth"), data=req_data) - if response.status_code == 200: - self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data) - return self._result(response, get_json=True) - - def ping(self): - """ - Checks the server is responsive. An exception will be raised if it - is not responding. - - Returns: - (bool) The response from the server. - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - return self._result(self._get(self._url("/_ping"))) == "OK" - - def version(self, api_version=True): - """ - Returns version information from the server. Similar to the ``docker - version`` command. - - Returns: - (dict): The server version information - - Raises: - :py:class:`docker.errors.APIError` - If the server returns an error. - """ - url = self._url("/version", versioned_api=api_version) - return self._result(self._get(url), get_json=True) diff --git a/plugins/module_utils/_api/auth.py b/plugins/module_utils/_api/auth.py index 317e6c77..0c6cff00 100644 --- a/plugins/module_utils/_api/auth.py +++ b/plugins/module_utils/_api/auth.py @@ -14,6 +14,7 @@ from __future__ import annotations import base64 import json import logging +import typing as t from . import errors from .credentials.errors import CredentialsNotFound, StoreError @@ -21,6 +22,12 @@ from .credentials.store import Store from .utils import config +if t.TYPE_CHECKING: + from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( + APIClient, + ) + + INDEX_NAME = "docker.io" INDEX_URL = f"https://index.{INDEX_NAME}/v1/" TOKEN_USERNAME = "" @@ -28,7 +35,7 @@ TOKEN_USERNAME = "" log = logging.getLogger(__name__) -def resolve_repository_name(repo_name): +def resolve_repository_name(repo_name: str) -> tuple[str, str]: if "://" in repo_name: raise errors.InvalidRepository( f"Repository name cannot contain a scheme ({repo_name})" @@ -42,14 +49,14 @@ def resolve_repository_name(repo_name): return resolve_index_name(index_name), remote_name -def resolve_index_name(index_name): +def resolve_index_name(index_name: str) -> str: index_name = convert_to_hostname(index_name) if index_name == "index." + INDEX_NAME: index_name = INDEX_NAME return index_name -def get_config_header(client, registry): +def get_config_header(client: APIClient, registry: str) -> bytes | None: log.debug("Looking for auth config") if not client._auth_configs or client._auth_configs.is_empty: log.debug("No auth config in memory - loading from filesystem") @@ -69,32 +76,38 @@ def get_config_header(client, registry): return None -def split_repo_name(repo_name): +def split_repo_name(repo_name: str) -> tuple[str, str]: parts = repo_name.split("/", 1) if len(parts) == 1 or ( "." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost" ): # This is a docker index repo (ex: username/foobar or ubuntu) return INDEX_NAME, repo_name - return tuple(parts) + return tuple(parts) # type: ignore -def get_credential_store(authconfig, registry): +def get_credential_store( + authconfig: dict[str, t.Any] | AuthConfig, registry: str +) -> str | None: if not isinstance(authconfig, AuthConfig): authconfig = AuthConfig(authconfig) return authconfig.get_credential_store(registry) class AuthConfig(dict): - def __init__(self, dct, credstore_env=None): + def __init__( + self, dct: dict[str, t.Any], credstore_env: dict[str, str] | None = None + ): if "auths" not in dct: dct["auths"] = {} self.update(dct) self._credstore_env = credstore_env - self._stores = {} + self._stores: dict[str, Store] = {} @classmethod - def parse_auth(cls, entries, raise_on_error=False): + def parse_auth( + cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False + ) -> dict[str, dict[str, t.Any]]: """ Parses authentication entries @@ -107,10 +120,10 @@ class AuthConfig(dict): Authentication registry. """ - conf = {} + conf: dict[str, dict[str, t.Any]] = {} for registry, entry in entries.items(): if not isinstance(entry, dict): - log.debug("Config entry for key %s is not auth config", registry) + log.debug("Config entry for key %s is not auth config", registry) # type: ignore # We sometimes fall back to parsing the whole config as if it # was the auth config by itself, for legacy purposes. In that # case, we fail silently and return an empty conf if any of the @@ -150,7 +163,12 @@ class AuthConfig(dict): return conf @classmethod - def load_config(cls, config_path, config_dict, credstore_env=None): + def load_config( + cls, + config_path: str | None, + config_dict: dict[str, t.Any] | None, + credstore_env: dict[str, str] | None = None, + ) -> t.Self: """ Loads authentication data from a Docker configuration file in the given root directory or if config_path is passed use given path. @@ -196,22 +214,24 @@ class AuthConfig(dict): return cls({"auths": cls.parse_auth(config_dict)}, credstore_env) @property - def auths(self): + def auths(self) -> dict[str, dict[str, t.Any]]: return self.get("auths", {}) @property - def creds_store(self): + def creds_store(self) -> str | None: return self.get("credsStore", None) @property - def cred_helpers(self): + def cred_helpers(self) -> dict[str, t.Any]: return self.get("credHelpers", {}) @property - def is_empty(self): + def is_empty(self) -> bool: return not self.auths and not self.creds_store and not self.cred_helpers - def resolve_authconfig(self, registry=None): + def resolve_authconfig( + self, registry: str | None = None + ) -> dict[str, t.Any] | None: """ Returns the authentication data from the given auth configuration for a specific registry. As with the Docker client, legacy entries in the @@ -244,7 +264,9 @@ class AuthConfig(dict): log.debug("No entry found") return None - def _resolve_authconfig_credstore(self, registry, credstore_name): + def _resolve_authconfig_credstore( + self, registry: str | None, credstore_name: str + ) -> dict[str, t.Any] | None: if not registry or registry == INDEX_NAME: # The ecosystem is a little schizophrenic with index.docker.io VS # docker.io - in that case, it seems the full URL is necessary. @@ -272,19 +294,19 @@ class AuthConfig(dict): except StoreError as e: raise errors.DockerException(f"Credentials store error: {e}") - def _get_store_instance(self, name): + def _get_store_instance(self, name: str): if name not in self._stores: self._stores[name] = Store(name, environment=self._credstore_env) return self._stores[name] - def get_credential_store(self, registry): + def get_credential_store(self, registry: str | None) -> str | None: if not registry or registry == INDEX_NAME: registry = INDEX_URL return self.cred_helpers.get(registry) or self.creds_store - def get_all_credentials(self): - auth_data = self.auths.copy() + def get_all_credentials(self) -> dict[str, dict[str, t.Any] | None]: + auth_data: dict[str, dict[str, t.Any] | None] = self.auths.copy() # type: ignore if self.creds_store: # Retrieve all credentials from the default store store = self._get_store_instance(self.creds_store) @@ -299,21 +321,23 @@ class AuthConfig(dict): return auth_data - def add_auth(self, reg, data): + def add_auth(self, reg: str, data: dict[str, t.Any]) -> None: self["auths"][reg] = data -def resolve_authconfig(authconfig, registry=None, credstore_env=None): +def resolve_authconfig( + authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None +): if not isinstance(authconfig, AuthConfig): authconfig = AuthConfig(authconfig, credstore_env) return authconfig.resolve_authconfig(registry) -def convert_to_hostname(url): +def convert_to_hostname(url: str) -> str: return url.replace("http://", "").replace("https://", "").split("/", 1)[0] -def decode_auth(auth): +def decode_auth(auth: str | bytes) -> tuple[str, str]: if isinstance(auth, str): auth = auth.encode("ascii") s = base64.b64decode(auth) @@ -321,12 +345,14 @@ def decode_auth(auth): return login.decode("utf8"), pwd.decode("utf8") -def encode_header(auth): +def encode_header(auth: dict[str, t.Any]) -> bytes: auth_json = json.dumps(auth).encode("ascii") return base64.urlsafe_b64encode(auth_json) -def parse_auth(entries, raise_on_error=False): +def parse_auth( + entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False +) -> dict[str, dict[str, t.Any]]: """ Parses authentication entries @@ -342,11 +368,15 @@ def parse_auth(entries, raise_on_error=False): return AuthConfig.parse_auth(entries, raise_on_error) -def load_config(config_path=None, config_dict=None, credstore_env=None): +def load_config( + config_path: str | None = None, + config_dict: dict[str, t.Any] | None = None, + credstore_env: dict[str, str] | None = None, +) -> AuthConfig: return AuthConfig.load_config(config_path, config_dict, credstore_env) -def _load_legacy_config(config_file): +def _load_legacy_config(config_file: str) -> dict[str, dict[str, t.Any]]: log.debug("Attempting to parse legacy auth file format") try: data = [] diff --git a/plugins/module_utils/_api/context/api.py b/plugins/module_utils/_api/context/api.py index 2b026ab7..133357d2 100644 --- a/plugins/module_utils/_api/context/api.py +++ b/plugins/module_utils/_api/context/api.py @@ -13,6 +13,7 @@ from __future__ import annotations import json import os +import typing as t from .. import errors from .config import ( @@ -24,7 +25,11 @@ from .config import ( from .context import Context -def create_default_context(): +if t.TYPE_CHECKING: + from ..tls import TLSConfig + + +def create_default_context() -> Context: host = None if os.environ.get("DOCKER_HOST"): host = os.environ.get("DOCKER_HOST") @@ -42,7 +47,7 @@ class ContextAPI: DEFAULT_CONTEXT = None @classmethod - def get_default_context(cls): + def get_default_context(cls) -> Context: context = cls.DEFAULT_CONTEXT if context is None: context = create_default_context() @@ -52,13 +57,13 @@ class ContextAPI: @classmethod def create_context( cls, - name, - orchestrator=None, - host=None, - tls_cfg=None, - default_namespace=None, - skip_tls_verify=False, - ): + name: str, + orchestrator: str | None = None, + host: str | None = None, + tls_cfg: TLSConfig | None = None, + default_namespace: str | None = None, + skip_tls_verify: bool = False, + ) -> Context: """Creates a new context. Returns: (Context): a Context object. @@ -108,7 +113,7 @@ class ContextAPI: return ctx @classmethod - def get_context(cls, name=None): + def get_context(cls, name: str | None = None) -> Context | None: """Retrieves a context object. Args: name (str): The name of the context @@ -136,7 +141,7 @@ class ContextAPI: return Context.load_context(name) @classmethod - def contexts(cls): + def contexts(cls) -> list[Context]: """Context list. Returns: (Context): List of context objects. @@ -170,7 +175,7 @@ class ContextAPI: return contexts @classmethod - def get_current_context(cls): + def get_current_context(cls) -> Context | None: """Get current context. Returns: (Context): current context object. @@ -178,7 +183,7 @@ class ContextAPI: return cls.get_context() @classmethod - def set_current_context(cls, name="default"): + def set_current_context(cls, name: str = "default") -> None: ctx = cls.get_context(name) if not ctx: raise errors.ContextNotFound(name) @@ -188,7 +193,7 @@ class ContextAPI: raise errors.ContextException(f"Failed to set current context: {err}") @classmethod - def remove_context(cls, name): + def remove_context(cls, name: str) -> None: """Remove a context. Similar to the ``docker context rm`` command. Args: @@ -220,7 +225,7 @@ class ContextAPI: ctx.remove() @classmethod - def inspect_context(cls, name="default"): + def inspect_context(cls, name: str = "default") -> dict[str, t.Any]: """Inspect a context. Similar to the ``docker context inspect`` command. Args: diff --git a/plugins/module_utils/_api/context/config.py b/plugins/module_utils/_api/context/config.py index b3ff0aa0..6ab07b0d 100644 --- a/plugins/module_utils/_api/context/config.py +++ b/plugins/module_utils/_api/context/config.py @@ -23,7 +23,7 @@ from ..utils.utils import parse_host METAFILE = "meta.json" -def get_current_context_name_with_source(): +def get_current_context_name_with_source() -> tuple[str, str]: if os.environ.get("DOCKER_HOST"): return "default", "DOCKER_HOST environment variable set" if os.environ.get("DOCKER_CONTEXT"): @@ -41,11 +41,11 @@ def get_current_context_name_with_source(): return "default", "fallback value" -def get_current_context_name(): +def get_current_context_name() -> str: return get_current_context_name_with_source()[0] -def write_context_name_to_docker_config(name=None): +def write_context_name_to_docker_config(name: str | None = None) -> Exception | None: if name == "default": name = None docker_cfg_path = find_config_file() @@ -62,44 +62,45 @@ def write_context_name_to_docker_config(name=None): elif name: config["currentContext"] = name else: - return + return None if not docker_cfg_path: docker_cfg_path = get_default_config_file() try: with open(docker_cfg_path, "wt", encoding="utf-8") as f: json.dump(config, f, indent=4) + return None except Exception as e: # pylint: disable=broad-exception-caught return e -def get_context_id(name): +def get_context_id(name: str) -> str: return hashlib.sha256(name.encode("utf-8")).hexdigest() -def get_context_dir(): +def get_context_dir() -> str: docker_cfg_path = find_config_file() or get_default_config_file() return os.path.join(os.path.dirname(docker_cfg_path), "contexts") -def get_meta_dir(name=None): +def get_meta_dir(name: str | None = None) -> str: meta_dir = os.path.join(get_context_dir(), "meta") if name: return os.path.join(meta_dir, get_context_id(name)) return meta_dir -def get_meta_file(name): +def get_meta_file(name) -> str: return os.path.join(get_meta_dir(name), METAFILE) -def get_tls_dir(name=None, endpoint=""): +def get_tls_dir(name: str | None = None, endpoint: str = "") -> str: context_dir = get_context_dir() if name: return os.path.join(context_dir, "tls", get_context_id(name), endpoint) return os.path.join(context_dir, "tls") -def get_context_host(path=None, tls=False): +def get_context_host(path: str | None = None, tls: bool = False) -> str: host = parse_host(path, IS_WINDOWS_PLATFORM, tls) if host == DEFAULT_UNIX_SOCKET: # remove http+ from default docker socket url diff --git a/plugins/module_utils/_api/context/context.py b/plugins/module_utils/_api/context/context.py index b8a43fb0..aaa3c280 100644 --- a/plugins/module_utils/_api/context/context.py +++ b/plugins/module_utils/_api/context/context.py @@ -13,6 +13,7 @@ from __future__ import annotations import json import os +import typing as t from shutil import copyfile, rmtree from ..errors import ContextException @@ -33,21 +34,21 @@ class Context: def __init__( self, - name, - orchestrator=None, - host=None, - endpoints=None, - skip_tls_verify=False, - tls=False, - description=None, - ): + name: str, + orchestrator: str | None = None, + host: str | None = None, + endpoints: dict[str, dict[str, t.Any]] | None = None, + skip_tls_verify: bool = False, + tls: bool = False, + description: str | None = None, + ) -> None: if not name: raise ValueError("Name not provided") self.name = name self.context_type = None self.orchestrator = orchestrator self.endpoints = {} - self.tls_cfg = {} + self.tls_cfg: dict[str, TLSConfig] = {} self.meta_path = IN_MEMORY self.tls_path = IN_MEMORY self.description = description @@ -89,12 +90,12 @@ class Context: def set_endpoint( self, - name="docker", - host=None, - tls_cfg=None, - skip_tls_verify=False, - def_namespace=None, - ): + name: str = "docker", + host: str | None = None, + tls_cfg: TLSConfig | None = None, + skip_tls_verify: bool = False, + def_namespace: str | None = None, + ) -> None: self.endpoints[name] = { "Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None), "SkipTLSVerify": skip_tls_verify, @@ -105,11 +106,11 @@ class Context: if tls_cfg: self.tls_cfg[name] = tls_cfg - def inspect(self): + def inspect(self) -> dict[str, t.Any]: return self() @classmethod - def load_context(cls, name): + def load_context(cls, name: str) -> t.Self | None: meta = Context._load_meta(name) if meta: instance = cls( @@ -125,12 +126,12 @@ class Context: return None @classmethod - def _load_meta(cls, name): + def _load_meta(cls, name: str) -> dict[str, t.Any] | None: meta_file = get_meta_file(name) if not os.path.isfile(meta_file): return None - metadata = {} + metadata: dict[str, t.Any] = {} try: with open(meta_file, "rt", encoding="utf-8") as f: metadata = json.load(f) @@ -154,7 +155,7 @@ class Context: return metadata - def _load_certs(self): + def _load_certs(self) -> None: certs = {} tls_dir = get_tls_dir(self.name) for endpoint in self.endpoints: @@ -184,7 +185,7 @@ class Context: self.tls_cfg = certs self.tls_path = tls_dir - def save(self): + def save(self) -> None: meta_dir = get_meta_dir(self.name) if not os.path.isdir(meta_dir): os.makedirs(meta_dir) @@ -216,54 +217,54 @@ class Context: self.meta_path = get_meta_dir(self.name) self.tls_path = get_tls_dir(self.name) - def remove(self): + def remove(self) -> None: if os.path.isdir(self.meta_path): rmtree(self.meta_path) if os.path.isdir(self.tls_path): rmtree(self.tls_path) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: '{self.name}'>" - def __str__(self): + def __str__(self) -> str: return json.dumps(self.__call__(), indent=2) - def __call__(self): + def __call__(self) -> dict[str, t.Any]: result = self.Metadata result.update(self.TLSMaterial) result.update(self.Storage) return result - def is_docker_host(self): + def is_docker_host(self) -> bool: return self.context_type is None @property - def Name(self): # pylint: disable=invalid-name + def Name(self) -> str: # pylint: disable=invalid-name return self.name @property - def Host(self): # pylint: disable=invalid-name + def Host(self) -> str | None: # pylint: disable=invalid-name if not self.orchestrator or self.orchestrator == "swarm": endpoint = self.endpoints.get("docker", None) if endpoint: - return endpoint.get("Host", None) + return endpoint.get("Host", None) # type: ignore return None - return self.endpoints[self.orchestrator].get("Host", None) + return self.endpoints[self.orchestrator].get("Host", None) # type: ignore @property - def Orchestrator(self): # pylint: disable=invalid-name + def Orchestrator(self) -> str | None: # pylint: disable=invalid-name return self.orchestrator @property - def Metadata(self): # pylint: disable=invalid-name - meta = {} + def Metadata(self) -> dict[str, t.Any]: # pylint: disable=invalid-name + meta: dict[str, t.Any] = {} if self.orchestrator: meta = {"StackOrchestrator": self.orchestrator} return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints} @property - def TLSConfig(self): # pylint: disable=invalid-name + def TLSConfig(self) -> TLSConfig | None: # pylint: disable=invalid-name key = self.orchestrator if not key or key == "swarm": key = "docker" @@ -272,13 +273,15 @@ class Context: return None @property - def TLSMaterial(self): # pylint: disable=invalid-name - certs = {} + def TLSMaterial(self) -> dict[str, t.Any]: # pylint: disable=invalid-name + certs: dict[str, t.Any] = {} for endpoint, tls in self.tls_cfg.items(): - cert, key = tls.cert - certs[endpoint] = list(map(os.path.basename, [tls.ca_cert, cert, key])) + paths = [tls.ca_cert, *tls.cert] if tls.cert else [tls.ca_cert] + certs[endpoint] = [ + os.path.basename(path) if path else None for path in paths + ] return {"TLSMaterial": certs} @property - def Storage(self): # pylint: disable=invalid-name + def Storage(self) -> dict[str, t.Any]: # pylint: disable=invalid-name return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}} diff --git a/plugins/module_utils/_api/credentials/errors.py b/plugins/module_utils/_api/credentials/errors.py index 323f8f67..6faed91c 100644 --- a/plugins/module_utils/_api/credentials/errors.py +++ b/plugins/module_utils/_api/credentials/errors.py @@ -11,6 +11,12 @@ from __future__ import annotations +import typing as t + + +if t.TYPE_CHECKING: + from subprocess import CalledProcessError + class StoreError(RuntimeError): pass @@ -24,7 +30,7 @@ class InitializationError(StoreError): pass -def process_store_error(cpe, program): +def process_store_error(cpe: CalledProcessError, program: str) -> StoreError: message = cpe.output.decode("utf-8") if "credentials not found in native keychain" in message: return CredentialsNotFound(f"No matching credentials in {program}") diff --git a/plugins/module_utils/_api/credentials/store.py b/plugins/module_utils/_api/credentials/store.py index 5bf5fd28..1d560e91 100644 --- a/plugins/module_utils/_api/credentials/store.py +++ b/plugins/module_utils/_api/credentials/store.py @@ -14,13 +14,14 @@ from __future__ import annotations import errno import json import subprocess +import typing as t from . import constants, errors from .utils import create_environment_dict, find_executable class Store: - def __init__(self, program, environment=None): + def __init__(self, program: str, environment: dict[str, str] | None = None) -> None: """Create a store object that acts as an interface to perform the basic operations for storing, retrieving and erasing credentials using `program`. @@ -33,7 +34,7 @@ class Store: f"{self.program} not installed or not available in PATH" ) - def get(self, server): + def get(self, server: str | bytes) -> dict[str, t.Any]: """Retrieve credentials for `server`. If no credentials are found, a `StoreError` will be raised. """ @@ -53,7 +54,7 @@ class Store: return result - def store(self, server, username, secret): + def store(self, server: str, username: str, secret: str) -> bytes: """Store credentials for `server`. Raises a `StoreError` if an error occurs. """ @@ -62,7 +63,7 @@ class Store: ).encode("utf-8") return self._execute("store", data_input) - def erase(self, server): + def erase(self, server: str | bytes) -> None: """Erase credentials for `server`. Raises a `StoreError` if an error occurs. """ @@ -70,12 +71,16 @@ class Store: server = server.encode("utf-8") self._execute("erase", server) - def list(self): + def list(self) -> t.Any: """List stored credentials. Requires v0.4.0+ of the helper.""" data = self._execute("list", None) return json.loads(data.decode("utf-8")) - def _execute(self, subcmd, data_input): + def _execute(self, subcmd: str, data_input: bytes | None) -> bytes: + if self.exe is None: + raise errors.StoreError( + f"{self.program} not installed or not available in PATH" + ) output = None env = create_environment_dict(self.environment) try: diff --git a/plugins/module_utils/_api/credentials/utils.py b/plugins/module_utils/_api/credentials/utils.py index 7a82c34b..ff63d5df 100644 --- a/plugins/module_utils/_api/credentials/utils.py +++ b/plugins/module_utils/_api/credentials/utils.py @@ -15,7 +15,7 @@ import os from shutil import which -def find_executable(executable, path=None): +def find_executable(executable: str, path: str | None = None) -> str | None: """ As distutils.spawn.find_executable, but on Windows, look up every extension declared in PATHEXT instead of just `.exe` @@ -26,7 +26,7 @@ def find_executable(executable, path=None): return which(executable, path=path) -def create_environment_dict(overrides): +def create_environment_dict(overrides: dict[str, str] | None) -> dict[str, str]: """ Create and return a copy of os.environ with the specified overrides """ diff --git a/plugins/module_utils/_api/errors.py b/plugins/module_utils/_api/errors.py index 548b1d71..ca5494e8 100644 --- a/plugins/module_utils/_api/errors.py +++ b/plugins/module_utils/_api/errors.py @@ -11,6 +11,8 @@ from __future__ import annotations +import typing as t + from ansible.module_utils.common.text.converters import to_native from ._import_helper import HTTPError as _HTTPError @@ -25,7 +27,7 @@ class DockerException(Exception): """ -def create_api_error_from_http_exception(e): +def create_api_error_from_http_exception(e: _HTTPError) -> t.NoReturn: """ Create a suitable APIError from requests.exceptions.HTTPError. """ @@ -52,14 +54,16 @@ class APIError(_HTTPError, DockerException): An HTTP error from the API. """ - def __init__(self, message, response=None, explanation=None): + def __init__( + self, message: str | Exception, response=None, explanation: str | None = None + ) -> None: # requests 1.2 supports response as a keyword argument, but # requests 1.1 does not super().__init__(message) self.response = response self.explanation = explanation - def __str__(self): + def __str__(self) -> str: message = super().__str__() if self.is_client_error(): @@ -74,19 +78,20 @@ class APIError(_HTTPError, DockerException): return message @property - def status_code(self): + def status_code(self) -> int | None: if self.response is not None: return self.response.status_code + return None - def is_error(self): + def is_error(self) -> bool: return self.is_client_error() or self.is_server_error() - def is_client_error(self): + def is_client_error(self) -> bool: if self.status_code is None: return False return 400 <= self.status_code < 500 - def is_server_error(self): + def is_server_error(self) -> bool: if self.status_code is None: return False return 500 <= self.status_code < 600 @@ -121,10 +126,10 @@ class DeprecatedMethod(DockerException): class TLSParameterError(DockerException): - def __init__(self, msg): + def __init__(self, msg: str) -> None: self.msg = msg - def __str__(self): + def __str__(self) -> str: return self.msg + ( ". TLS configurations should map the Docker CLI " "client configurations. See " @@ -142,7 +147,14 @@ class ContainerError(DockerException): Represents a container that has exited with a non-zero exit code. """ - def __init__(self, container, exit_status, command, image, stderr): + def __init__( + self, + container: str, + exit_status: int, + command: list[str], + image: str, + stderr: str | None, + ): self.container = container self.exit_status = exit_status self.command = command @@ -156,12 +168,12 @@ class ContainerError(DockerException): class StreamParseError(RuntimeError): - def __init__(self, reason): + def __init__(self, reason: Exception) -> None: self.msg = reason class BuildError(DockerException): - def __init__(self, reason, build_log): + def __init__(self, reason: str, build_log: str) -> None: super().__init__(reason) self.msg = reason self.build_log = build_log @@ -171,7 +183,7 @@ class ImageLoadError(DockerException): pass -def create_unexpected_kwargs_error(name, kwargs): +def create_unexpected_kwargs_error(name: str, kwargs: dict[str, t.Any]) -> TypeError: quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)] text = [f"{name}() "] if len(quoted_kwargs) == 1: @@ -183,42 +195,44 @@ def create_unexpected_kwargs_error(name, kwargs): class MissingContextParameter(DockerException): - def __init__(self, param): + def __init__(self, param: str) -> None: self.param = param - def __str__(self): + def __str__(self) -> str: return f"missing parameter: {self.param}" class ContextAlreadyExists(DockerException): - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - def __str__(self): + def __str__(self) -> str: return f"context {self.name} already exists" class ContextException(DockerException): - def __init__(self, msg): + def __init__(self, msg: str) -> None: self.msg = msg - def __str__(self): + def __str__(self) -> str: return self.msg class ContextNotFound(DockerException): - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - def __str__(self): + def __str__(self) -> str: return f"context '{self.name}' not found" class MissingRequirementException(DockerException): - def __init__(self, msg, requirement, import_exception): + def __init__( + self, msg: str, requirement: str, import_exception: ImportError | str + ) -> None: self.msg = msg self.requirement = requirement self.import_exception = import_exception - def __str__(self): + def __str__(self) -> str: return self.msg diff --git a/plugins/module_utils/_api/tls.py b/plugins/module_utils/_api/tls.py index 1b81c193..f2918200 100644 --- a/plugins/module_utils/_api/tls.py +++ b/plugins/module_utils/_api/tls.py @@ -12,12 +12,18 @@ from __future__ import annotations import os -import ssl +import typing as t from . import errors from .transport.ssladapter import SSLHTTPAdapter +if t.TYPE_CHECKING: + from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( + APIClient, + ) + + class TLSConfig: """ TLS configuration. @@ -27,25 +33,22 @@ class TLSConfig: ca_cert (str): Path to CA cert file. verify (bool or str): This can be ``False`` or a path to a CA cert file. - ssl_version (int): A valid `SSL version`_. assert_hostname (bool): Verify the hostname of the server. .. _`SSL version`: https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1 """ - cert = None - ca_cert = None - verify = None - ssl_version = None + cert: tuple[str, str] | None = None + ca_cert: str | None = None + verify: bool | None = None def __init__( self, - client_cert=None, - ca_cert=None, - verify=None, - ssl_version=None, - assert_hostname=None, + client_cert: tuple[str, str] | None = None, + ca_cert: str | None = None, + verify: bool | None = None, + assert_hostname: bool | None = None, ): # Argument compatibility/mapping with # https://docs.docker.com/engine/articles/https/ @@ -55,12 +58,6 @@ class TLSConfig: self.assert_hostname = assert_hostname - # If the user provides an SSL version, we should use their preference - if ssl_version: - self.ssl_version = ssl_version - else: - self.ssl_version = ssl.PROTOCOL_TLS_CLIENT - # "client_cert" must have both or neither cert/key files. In # either case, Alert the user when both are expected, but any are # missing. @@ -90,11 +87,10 @@ class TLSConfig: "Invalid CA certificate provided for `ca_cert`." ) - def configure_client(self, client): + def configure_client(self, client: APIClient) -> None: """ Configure a client with these TLS options. """ - client.ssl_version = self.ssl_version if self.verify and self.ca_cert: client.verify = self.ca_cert @@ -107,7 +103,6 @@ class TLSConfig: client.mount( "https://", SSLHTTPAdapter( - ssl_version=self.ssl_version, assert_hostname=self.assert_hostname, ), ) diff --git a/plugins/module_utils/_api/transport/basehttpadapter.py b/plugins/module_utils/_api/transport/basehttpadapter.py index 603ba3eb..90239199 100644 --- a/plugins/module_utils/_api/transport/basehttpadapter.py +++ b/plugins/module_utils/_api/transport/basehttpadapter.py @@ -15,7 +15,7 @@ from .._import_helper import HTTPAdapter as _HTTPAdapter class BaseHTTPAdapter(_HTTPAdapter): - def close(self): + def close(self) -> None: super().close() if hasattr(self, "pools"): self.pools.clear() @@ -24,10 +24,10 @@ class BaseHTTPAdapter(_HTTPAdapter): # https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356 # changes requests.adapters.HTTPAdapter to no longer call get_connection() from # send(), but instead call _get_connection(). - def _get_connection(self, request, *args, **kwargs): + def _get_connection(self, request, *args, **kwargs): # type: ignore return self.get_connection(request.url, kwargs.get("proxies")) # Fix for requests 2.32.2+: # https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05 - def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): + def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): # type: ignore return self.get_connection(request.url, proxies) diff --git a/plugins/module_utils/_api/transport/npipeconn.py b/plugins/module_utils/_api/transport/npipeconn.py index 3f618b38..b49cb907 100644 --- a/plugins/module_utils/_api/transport/npipeconn.py +++ b/plugins/module_utils/_api/transport/npipeconn.py @@ -23,12 +23,12 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class NpipeHTTPConnection(urllib3_connection.HTTPConnection): - def __init__(self, npipe_path, timeout=60): + def __init__(self, npipe_path: str, timeout: int = 60) -> None: super().__init__("localhost", timeout=timeout) self.npipe_path = npipe_path self.timeout = timeout - def connect(self): + def connect(self) -> None: sock = NpipeSocket() sock.settimeout(self.timeout) sock.connect(self.npipe_path) @@ -36,18 +36,18 @@ class NpipeHTTPConnection(urllib3_connection.HTTPConnection): class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): - def __init__(self, npipe_path, timeout=60, maxsize=10): + def __init__(self, npipe_path: str, timeout: int = 60, maxsize: int = 10) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) self.npipe_path = npipe_path self.timeout = timeout - def _new_conn(self): + def _new_conn(self) -> NpipeHTTPConnection: return NpipeHTTPConnection(self.npipe_path, self.timeout) # When re-using connections, urllib3 tries to call select() on our # NpipeSocket instance, causing a crash. To circumvent this, we override # _get_conn, where that check happens. - def _get_conn(self, timeout): + def _get_conn(self, timeout: int) -> NpipeHTTPConnection: conn = None try: conn = self.pool.get(block=self.block, timeout=timeout) @@ -67,7 +67,6 @@ class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class NpipeHTTPAdapter(BaseHTTPAdapter): - __attrs__ = HTTPAdapter.__attrs__ + [ "npipe_path", "pools", @@ -77,11 +76,11 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): def __init__( self, - base_url, - timeout=60, - pool_connections=constants.DEFAULT_NUM_POOLS, - max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, - ): + base_url: str, + timeout: int = 60, + pool_connections: int = constants.DEFAULT_NUM_POOLS, + max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE, + ) -> None: self.npipe_path = base_url.replace("npipe://", "") self.timeout = timeout self.max_pool_size = max_pool_size @@ -90,7 +89,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): ) super().__init__() - def get_connection(self, url, proxies=None): + def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool: with self.pools.lock: pool = self.pools.get(url) if pool: @@ -103,7 +102,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter): return pool - def request_url(self, request, proxies): + def request_url(self, request, proxies) -> str: # The select_proxy utility in requests errors out when the provided URL # does not have a hostname, like is the case when using a UNIX socket. # Since proxies are an irrelevant notion in the case of UNIX sockets diff --git a/plugins/module_utils/_api/transport/npipesocket.py b/plugins/module_utils/_api/transport/npipesocket.py index 23183c39..e4473f49 100644 --- a/plugins/module_utils/_api/transport/npipesocket.py +++ b/plugins/module_utils/_api/transport/npipesocket.py @@ -15,6 +15,7 @@ import functools import io import time import traceback +import typing as t PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name @@ -29,6 +30,13 @@ except ImportError: else: PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name +if t.TYPE_CHECKING: + from collections.abc import Buffer, Callable + + _Self = t.TypeVar("_Self") + _P = t.ParamSpec("_P") + _R = t.TypeVar("_R") + ERROR_PIPE_BUSY = 0xE7 SECURITY_SQOS_PRESENT = 0x100000 @@ -37,10 +45,12 @@ SECURITY_ANONYMOUS = 0 MAXIMUM_RETRY_COUNT = 10 -def check_closed(f): +def check_closed( + f: Callable[t.Concatenate[_Self, _P], _R], +) -> Callable[t.Concatenate[_Self, _P], _R]: @functools.wraps(f) - def wrapped(self, *args, **kwargs): - if self._closed: + def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + if self._closed: # type: ignore raise RuntimeError("Can not reuse socket after connection was closed.") return f(self, *args, **kwargs) @@ -54,25 +64,25 @@ class NpipeSocket: implemented. """ - def __init__(self, handle=None): + def __init__(self, handle=None) -> None: self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT self._handle = handle - self._address = None + self._address: str | None = None self._closed = False - self.flags = None + self.flags: int | None = None - def accept(self): + def accept(self) -> t.NoReturn: raise NotImplementedError() - def bind(self, address): + def bind(self, address) -> t.NoReturn: raise NotImplementedError() - def close(self): + def close(self) -> None: self._handle.Close() self._closed = True @check_closed - def connect(self, address, retry_count=0): + def connect(self, address, retry_count: int = 0) -> None: try: handle = win32file.CreateFile( address, @@ -100,14 +110,14 @@ class NpipeSocket: return self.connect(address, retry_count) raise e - self.flags = win32pipe.GetNamedPipeInfo(handle)[0] + self.flags = win32pipe.GetNamedPipeInfo(handle)[0] # type: ignore self._handle = handle self._address = address @check_closed - def connect_ex(self, address): - return self.connect(address) + def connect_ex(self, address) -> None: + self.connect(address) @check_closed def detach(self): @@ -115,25 +125,25 @@ class NpipeSocket: return self._handle @check_closed - def dup(self): + def dup(self) -> NpipeSocket: return NpipeSocket(self._handle) - def getpeername(self): + def getpeername(self) -> str | None: return self._address - def getsockname(self): + def getsockname(self) -> str | None: return self._address - def getsockopt(self, level, optname, buflen=None): + def getsockopt(self, level, optname, buflen=None) -> t.NoReturn: raise NotImplementedError() - def ioctl(self, control, option): + def ioctl(self, control, option) -> t.NoReturn: raise NotImplementedError() - def listen(self, backlog): + def listen(self, backlog) -> t.NoReturn: raise NotImplementedError() - def makefile(self, mode=None, bufsize=None): + def makefile(self, mode: str, bufsize: int | None = None): if mode.strip("b") != "r": raise NotImplementedError() rawio = NpipeFileIOBase(self) @@ -142,30 +152,30 @@ class NpipeSocket: return io.BufferedReader(rawio, buffer_size=bufsize) @check_closed - def recv(self, bufsize, flags=0): + def recv(self, bufsize: int, flags: int = 0) -> str: dummy_err, data = win32file.ReadFile(self._handle, bufsize) return data @check_closed - def recvfrom(self, bufsize, flags=0): + def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[str, str | None]: data = self.recv(bufsize, flags) return (data, self._address) @check_closed - def recvfrom_into(self, buf, nbytes=0, flags=0): - return self.recv_into(buf, nbytes, flags), self._address + def recvfrom_into( + self, buf: Buffer, nbytes: int = 0, flags: int = 0 + ) -> tuple[int, str | None]: + return self.recv_into(buf, nbytes), self._address @check_closed - def recv_into(self, buf, nbytes=0): - readbuf = buf - if not isinstance(buf, memoryview): - readbuf = memoryview(buf) + def recv_into(self, buf: Buffer, nbytes: int = 0) -> int: + readbuf = buf if isinstance(buf, memoryview) else memoryview(buf) event = win32event.CreateEvent(None, True, True, None) try: overlapped = pywintypes.OVERLAPPED() overlapped.hEvent = event - dummy_err, dummy_data = win32file.ReadFile( + dummy_err, dummy_data = win32file.ReadFile( # type: ignore self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped ) wait_result = win32event.WaitForSingleObject(event, self._timeout) @@ -177,12 +187,12 @@ class NpipeSocket: win32api.CloseHandle(event) @check_closed - def send(self, string, flags=0): + def send(self, string: Buffer, flags: int = 0) -> int: event = win32event.CreateEvent(None, True, True, None) try: overlapped = pywintypes.OVERLAPPED() overlapped.hEvent = event - win32file.WriteFile(self._handle, string, overlapped) + win32file.WriteFile(self._handle, string, overlapped) # type: ignore wait_result = win32event.WaitForSingleObject(event, self._timeout) if wait_result == win32event.WAIT_TIMEOUT: win32file.CancelIo(self._handle) @@ -192,20 +202,20 @@ class NpipeSocket: win32api.CloseHandle(event) @check_closed - def sendall(self, string, flags=0): + def sendall(self, string: Buffer, flags: int = 0) -> int: return self.send(string, flags) @check_closed - def sendto(self, string, address): + def sendto(self, string: Buffer, address: str) -> int: self.connect(address) return self.send(string) - def setblocking(self, flag): + def setblocking(self, flag: bool): if flag: return self.settimeout(None) return self.settimeout(0) - def settimeout(self, value): + def settimeout(self, value: int | float | None) -> None: if value is None: # Blocking mode self._timeout = win32event.INFINITE @@ -215,39 +225,39 @@ class NpipeSocket: # Timeout mode - Value converted to milliseconds self._timeout = int(value * 1000) - def gettimeout(self): + def gettimeout(self) -> int | float | None: return self._timeout - def setsockopt(self, level, optname, value): + def setsockopt(self, level, optname, value) -> t.NoReturn: raise NotImplementedError() @check_closed - def shutdown(self, how): + def shutdown(self, how) -> None: return self.close() class NpipeFileIOBase(io.RawIOBase): - def __init__(self, npipe_socket): + def __init__(self, npipe_socket) -> None: self.sock = npipe_socket - def close(self): + def close(self) -> None: super().close() self.sock = None - def fileno(self): + def fileno(self) -> int: return self.sock.fileno() - def isatty(self): + def isatty(self) -> bool: return False - def readable(self): + def readable(self) -> bool: return True - def readinto(self, buf): + def readinto(self, buf: Buffer) -> int: return self.sock.recv_into(buf) - def seekable(self): + def seekable(self) -> bool: return False - def writable(self): + def writable(self) -> bool: return False diff --git a/plugins/module_utils/_api/transport/sshconn.py b/plugins/module_utils/_api/transport/sshconn.py index 40fbce4c..9173a55a 100644 --- a/plugins/module_utils/_api/transport/sshconn.py +++ b/plugins/module_utils/_api/transport/sshconn.py @@ -17,6 +17,7 @@ import signal import socket import subprocess import traceback +import typing as t from queue import Empty from urllib.parse import urlparse @@ -33,12 +34,15 @@ except ImportError: else: PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name +if t.TYPE_CHECKING: + from collections.abc import Buffer + RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class SSHSocket(socket.socket): - def __init__(self, host): + def __init__(self, host: str) -> None: super().__init__(socket.AF_INET, socket.SOCK_STREAM) self.host = host self.port = None @@ -48,9 +52,9 @@ class SSHSocket(socket.socket): if "@" in self.host: self.user, self.host = self.host.split("@") - self.proc = None + self.proc: subprocess.Popen | None = None - def connect(self, **kwargs): + def connect(self, *args_: t.Any, **kwargs: t.Any) -> None: args = ["ssh"] if self.user: args = args + ["-l", self.user] @@ -82,37 +86,48 @@ class SSHSocket(socket.socket): preexec_fn=preexec_func, ) - def _write(self, data): - if not self.proc or self.proc.stdin.closed: + def _write(self, data: Buffer) -> int: + if not self.proc: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first." ) + assert self.proc.stdin is not None + if self.proc.stdin.closed: + raise RuntimeError( + "SSH subprocess not initiated. connect() must be called first after close()." + ) written = self.proc.stdin.write(data) self.proc.stdin.flush() return written - def sendall(self, data): + def sendall(self, data: Buffer, *args, **kwargs) -> None: self._write(data) - def send(self, data): + def send(self, data: Buffer, *args, **kwargs) -> int: return self._write(data) - def recv(self, n): + def recv(self, n: int, *args, **kwargs) -> bytes: if not self.proc: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first." ) + assert self.proc.stdout is not None return self.proc.stdout.read(n) - def makefile(self, mode): + def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore if not self.proc: self.connect() - self.proc.stdout.channel = self + assert self.proc is not None + assert self.proc.stdout is not None + self.proc.stdout.channel = self # type: ignore return self.proc.stdout - def close(self): - if not self.proc or self.proc.stdin.closed: + def close(self) -> None: + if not self.proc: + return + assert self.proc.stdin is not None + if self.proc.stdin.closed: return self.proc.stdin.write(b"\n\n") self.proc.stdin.flush() @@ -120,13 +135,19 @@ class SSHSocket(socket.socket): class SSHConnection(urllib3_connection.HTTPConnection): - def __init__(self, ssh_transport=None, timeout=60, host=None): + def __init__( + self, + *, + ssh_transport=None, + timeout: int = 60, + host: str, + ) -> None: super().__init__("localhost", timeout=timeout) self.ssh_transport = ssh_transport self.timeout = timeout self.ssh_host = host - def connect(self): + def connect(self) -> None: if self.ssh_transport: sock = self.ssh_transport.open_session() sock.settimeout(self.timeout) @@ -142,7 +163,14 @@ class SSHConnection(urllib3_connection.HTTPConnection): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): scheme = "ssh" - def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None): + def __init__( + self, + *, + ssh_client: paramiko.SSHClient | None = None, + timeout: int = 60, + maxsize: int = 10, + host: str, + ) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) self.ssh_transport = None self.timeout = timeout @@ -150,13 +178,17 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): self.ssh_transport = ssh_client.get_transport() self.ssh_host = host - def _new_conn(self): - return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host) + def _new_conn(self) -> SSHConnection: + return SSHConnection( + ssh_transport=self.ssh_transport, + timeout=self.timeout, + host=self.ssh_host, + ) # When re-using connections, urllib3 calls fileno() on our # SSH channel instance, quickly overloading our fd limit. To avoid this, # we override _get_conn - def _get_conn(self, timeout): + def _get_conn(self, timeout: int) -> SSHConnection: conn = None try: conn = self.pool.get(block=self.block, timeout=timeout) @@ -176,7 +208,6 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHHTTPAdapter(BaseHTTPAdapter): - __attrs__ = HTTPAdapter.__attrs__ + [ "pools", "timeout", @@ -187,13 +218,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter): def __init__( self, - base_url, - timeout=60, - pool_connections=constants.DEFAULT_NUM_POOLS, - max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, - shell_out=False, - ): - self.ssh_client = None + base_url: str, + timeout: int = 60, + pool_connections: int = constants.DEFAULT_NUM_POOLS, + max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE, + shell_out: bool = False, + ) -> None: + self.ssh_client: paramiko.SSHClient | None = None if not shell_out: self._create_paramiko_client(base_url) self._connect() @@ -209,30 +240,31 @@ class SSHHTTPAdapter(BaseHTTPAdapter): ) super().__init__() - def _create_paramiko_client(self, base_url): + def _create_paramiko_client(self, base_url: str) -> None: logging.getLogger("paramiko").setLevel(logging.WARNING) self.ssh_client = paramiko.SSHClient() - base_url = urlparse(base_url) - self.ssh_params = { - "hostname": base_url.hostname, - "port": base_url.port, - "username": base_url.username, + base_url_p = urlparse(base_url) + assert base_url_p.hostname is not None + self.ssh_params: dict[str, t.Any] = { + "hostname": base_url_p.hostname, + "port": base_url_p.port, + "username": base_url_p.username, } ssh_config_file = os.path.expanduser("~/.ssh/config") if os.path.exists(ssh_config_file): conf = paramiko.SSHConfig() with open(ssh_config_file, "rt", encoding="utf-8") as f: conf.parse(f) - host_config = conf.lookup(base_url.hostname) + host_config = conf.lookup(base_url_p.hostname) if "proxycommand" in host_config: self.ssh_params["sock"] = paramiko.ProxyCommand( host_config["proxycommand"] ) if "hostname" in host_config: self.ssh_params["hostname"] = host_config["hostname"] - if base_url.port is None and "port" in host_config: + if base_url_p.port is None and "port" in host_config: self.ssh_params["port"] = host_config["port"] - if base_url.username is None and "user" in host_config: + if base_url_p.username is None and "user" in host_config: self.ssh_params["username"] = host_config["user"] if "identityfile" in host_config: self.ssh_params["key_filename"] = host_config["identityfile"] @@ -240,11 +272,11 @@ class SSHHTTPAdapter(BaseHTTPAdapter): self.ssh_client.load_system_host_keys() self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy()) - def _connect(self): + def _connect(self) -> None: if self.ssh_client: self.ssh_client.connect(**self.ssh_params) - def get_connection(self, url, proxies=None): + def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool: if not self.ssh_client: return SSHConnectionPool( ssh_client=self.ssh_client, @@ -271,7 +303,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter): return pool - def close(self): + def close(self) -> None: super().close() if self.ssh_client: self.ssh_client.close() diff --git a/plugins/module_utils/_api/transport/ssladapter.py b/plugins/module_utils/_api/transport/ssladapter.py index d0cb8f79..2cad6cea 100644 --- a/plugins/module_utils/_api/transport/ssladapter.py +++ b/plugins/module_utils/_api/transport/ssladapter.py @@ -11,9 +11,7 @@ from __future__ import annotations -from ansible_collections.community.docker.plugins.module_utils._version import ( - LooseVersion, -) +import typing as t from .._import_helper import HTTPAdapter, urllib3 from .basehttpadapter import BaseHTTPAdapter @@ -30,14 +28,19 @@ PoolManager = urllib3.poolmanager.PoolManager class SSLHTTPAdapter(BaseHTTPAdapter): """An HTTPS Transport Adapter that uses an arbitrary SSL version.""" - __attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname", "ssl_version"] + __attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname"] - def __init__(self, ssl_version=None, assert_hostname=None, **kwargs): - self.ssl_version = ssl_version + def __init__( + self, + assert_hostname: bool | None = None, + **kwargs, + ) -> None: self.assert_hostname = assert_hostname super().__init__(**kwargs) - def init_poolmanager(self, connections, maxsize, block=False): + def init_poolmanager( + self, connections: int, maxsize: int, block: bool = False, **kwargs: t.Any + ) -> None: kwargs = { "num_pools": connections, "maxsize": maxsize, @@ -45,12 +48,10 @@ class SSLHTTPAdapter(BaseHTTPAdapter): } if self.assert_hostname is not None: kwargs["assert_hostname"] = self.assert_hostname - if self.ssl_version and self.can_override_ssl_version(): - kwargs["ssl_version"] = self.ssl_version self.poolmanager = PoolManager(**kwargs) - def get_connection(self, *args, **kwargs): + def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool: """ Ensure assert_hostname is set correctly on our pool @@ -61,15 +62,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter): conn = super().get_connection(*args, **kwargs) if ( self.assert_hostname is not None - and conn.assert_hostname != self.assert_hostname + and conn.assert_hostname != self.assert_hostname # type: ignore ): - conn.assert_hostname = self.assert_hostname + conn.assert_hostname = self.assert_hostname # type: ignore return conn - - def can_override_ssl_version(self): - urllib_ver = urllib3.__version__.split("-")[0] - if urllib_ver is None: - return False - if urllib_ver == "dev": - return True - return LooseVersion(urllib_ver) > LooseVersion("1.5") diff --git a/plugins/module_utils/_api/transport/unixconn.py b/plugins/module_utils/_api/transport/unixconn.py index 2c615986..1b7894c1 100644 --- a/plugins/module_utils/_api/transport/unixconn.py +++ b/plugins/module_utils/_api/transport/unixconn.py @@ -12,6 +12,7 @@ from __future__ import annotations import socket +import typing as t from .. import constants from .._import_helper import HTTPAdapter, urllib3, urllib3_connection @@ -22,26 +23,25 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class UnixHTTPConnection(urllib3_connection.HTTPConnection): - - def __init__(self, base_url, unix_socket, timeout=60): + def __init__(self, base_url: str | bytes, unix_socket, timeout: int = 60) -> None: super().__init__("localhost", timeout=timeout) self.base_url = base_url self.unix_socket = unix_socket self.timeout = timeout self.disable_buffering = False - def connect(self): + def connect(self) -> None: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(self.timeout) sock.connect(self.unix_socket) self.sock = sock - def putheader(self, header, *values): + def putheader(self, header: str, *values: str) -> None: super().putheader(header, *values) if header == "Connection" and "Upgrade" in values: self.disable_buffering = True - def response_class(self, sock, *args, **kwargs): + def response_class(self, sock, *args, **kwargs) -> t.Any: # FIXME: We may need to disable buffering on Py3, # but there's no clear way to do it at the moment. See: # https://github.com/docker/docker-py/issues/1799 @@ -49,18 +49,23 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection): class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): - def __init__(self, base_url, socket_path, timeout=60, maxsize=10): + def __init__( + self, + base_url: str | bytes, + socket_path: str, + timeout: int = 60, + maxsize: int = 10, + ) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) self.base_url = base_url self.socket_path = socket_path self.timeout = timeout - def _new_conn(self): + def _new_conn(self) -> UnixHTTPConnection: return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout) class UnixHTTPAdapter(BaseHTTPAdapter): - __attrs__ = HTTPAdapter.__attrs__ + [ "pools", "socket_path", @@ -70,11 +75,11 @@ class UnixHTTPAdapter(BaseHTTPAdapter): def __init__( self, - socket_url, - timeout=60, - pool_connections=constants.DEFAULT_NUM_POOLS, - max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, - ): + socket_url: str, + timeout: int = 60, + pool_connections: int = constants.DEFAULT_NUM_POOLS, + max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE, + ) -> None: socket_path = socket_url.replace("http+unix://", "") if not socket_path.startswith("/"): socket_path = "/" + socket_path @@ -86,7 +91,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter): ) super().__init__() - def get_connection(self, url, proxies=None): + def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool: with self.pools.lock: pool = self.pools.get(url) if pool: @@ -99,7 +104,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter): return pool - def request_url(self, request, proxies): + def request_url(self, request, proxies) -> str: # The select_proxy utility in requests errors out when the provided URL # does not have a hostname, like is the case when using a UNIX socket. # Since proxies are an irrelevant notion in the case of UNIX sockets diff --git a/plugins/module_utils/_api/types/daemon.py b/plugins/module_utils/_api/types/daemon.py index 6defe9b2..eb386169 100644 --- a/plugins/module_utils/_api/types/daemon.py +++ b/plugins/module_utils/_api/types/daemon.py @@ -12,6 +12,7 @@ from __future__ import annotations import socket +import typing as t from .._import_helper import urllib3 from ..errors import DockerException @@ -29,11 +30,11 @@ class CancellableStream: >>> events.close() """ - def __init__(self, stream, response): + def __init__(self, stream, response) -> None: self._stream = stream self._response = response - def __iter__(self): + def __iter__(self) -> t.Self: return self def __next__(self): @@ -46,7 +47,7 @@ class CancellableStream: next = __next__ - def close(self): + def close(self) -> None: """ Closes the event streaming. """ diff --git a/plugins/module_utils/_api/utils/build.py b/plugins/module_utils/_api/utils/build.py index d15774be..5c2a90c8 100644 --- a/plugins/module_utils/_api/utils/build.py +++ b/plugins/module_utils/_api/utils/build.py @@ -17,15 +17,26 @@ import random import re import tarfile import tempfile +import typing as t from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX from . import fnmatch +if t.TYPE_CHECKING: + from collections.abc import Sequence + + _SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/") -def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False): +def tar( + path: str, + exclude: list[str] | None = None, + dockerfile: tuple[str, str] | tuple[None, None] | None = None, + fileobj: t.IO[bytes] | None = None, + gzip: bool = False, +) -> t.IO[bytes]: root = os.path.abspath(path) exclude = exclude or [] dockerfile = dockerfile or (None, None) @@ -47,7 +58,9 @@ def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False): ) -def exclude_paths(root, patterns, dockerfile=None): +def exclude_paths( + root: str, patterns: list[str], dockerfile: str | None = None +) -> set[str]: """ Given a root directory path and a list of .dockerignore patterns, return an iterator of all paths (both regular files and directories) in the root @@ -64,7 +77,7 @@ def exclude_paths(root, patterns, dockerfile=None): return set(pm.walk(root)) -def build_file_list(root): +def build_file_list(root: str) -> list[str]: files = [] for dirname, dirnames, fnames in os.walk(root): for filename in fnames + dirnames: @@ -74,7 +87,13 @@ def build_file_list(root): return files -def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None): +def create_archive( + root: str, + files: Sequence[str] | None = None, + fileobj: t.IO[bytes] | None = None, + gzip: bool = False, + extra_files: Sequence[tuple[str, str]] | None = None, +) -> t.IO[bytes]: extra_files = extra_files or [] if not fileobj: fileobj = tempfile.NamedTemporaryFile() @@ -92,7 +111,7 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None) if i is None: # This happens when we encounter a socket file. We can safely # ignore it and proceed. - continue + continue # type: ignore # Workaround https://bugs.python.org/issue32713 if i.mtime < 0 or i.mtime > 8**11 - 1: @@ -124,11 +143,11 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None) return fileobj -def mkbuildcontext(dockerfile): +def mkbuildcontext(dockerfile: io.BytesIO | t.IO[bytes]) -> t.IO[bytes]: f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with try: with tarfile.open(mode="w", fileobj=f) as t: - if isinstance(dockerfile, io.StringIO): + if isinstance(dockerfile, io.StringIO): # type: ignore raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles") if isinstance(dockerfile, io.BytesIO): dfinfo = tarfile.TarInfo("Dockerfile") @@ -144,17 +163,17 @@ def mkbuildcontext(dockerfile): return f -def split_path(p): +def split_path(p: str) -> list[str]: return [pt for pt in re.split(_SEP, p) if pt and pt != "."] -def normalize_slashes(p): +def normalize_slashes(p: str) -> str: if IS_WINDOWS_PLATFORM: return "/".join(split_path(p)) return p -def walk(root, patterns, default=True): +def walk(root: str, patterns: Sequence[str], default: bool = True) -> t.Generator[str]: pm = PatternMatcher(patterns) return pm.walk(root) @@ -162,11 +181,11 @@ def walk(root, patterns, default=True): # Heavily based on # https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go class PatternMatcher: - def __init__(self, patterns): + def __init__(self, patterns: Sequence[str]) -> None: self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns])) self.patterns.append(Pattern("!.dockerignore")) - def matches(self, filepath): + def matches(self, filepath: str) -> bool: matched = False parent_path = os.path.dirname(filepath) parent_path_dirs = split_path(parent_path) @@ -185,8 +204,8 @@ class PatternMatcher: return matched - def walk(self, root): - def rec_walk(current_dir): + def walk(self, root: str) -> t.Generator[str]: + def rec_walk(current_dir: str) -> t.Generator[str]: for f in os.listdir(current_dir): fpath = os.path.join(os.path.relpath(current_dir, root), f) if fpath.startswith("." + os.path.sep): @@ -220,7 +239,7 @@ class PatternMatcher: class Pattern: - def __init__(self, pattern_str): + def __init__(self, pattern_str: str) -> None: self.exclusion = False if pattern_str.startswith("!"): self.exclusion = True @@ -230,8 +249,7 @@ class Pattern: self.cleaned_pattern = "/".join(self.dirs) @classmethod - def normalize(cls, p): - + def normalize(cls, p: str) -> list[str]: # Remove trailing spaces p = p.strip() @@ -256,11 +274,11 @@ class Pattern: i += 1 return split - def match(self, filepath): + def match(self, filepath: str) -> bool: return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern) -def process_dockerfile(dockerfile, path): +def process_dockerfile(dockerfile: str, path: str) -> tuple[str | None, str | None]: if not dockerfile: return (None, None) @@ -268,7 +286,7 @@ def process_dockerfile(dockerfile, path): if not os.path.isabs(dockerfile): abs_dockerfile = os.path.join(path, dockerfile) if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX): - abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX):])}" + abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX) :])}" if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[ 0 ] or os.path.relpath(abs_dockerfile, path).startswith(".."): diff --git a/plugins/module_utils/_api/utils/config.py b/plugins/module_utils/_api/utils/config.py index 934f2dfc..eaa9542a 100644 --- a/plugins/module_utils/_api/utils/config.py +++ b/plugins/module_utils/_api/utils/config.py @@ -14,6 +14,7 @@ from __future__ import annotations import json import logging import os +import typing as t from ..constants import IS_WINDOWS_PLATFORM @@ -24,11 +25,11 @@ LEGACY_DOCKER_CONFIG_FILENAME = ".dockercfg" log = logging.getLogger(__name__) -def get_default_config_file(): +def get_default_config_file() -> str: return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME) -def find_config_file(config_path=None): +def find_config_file(config_path: str | None = None) -> str | None: homedir = home_dir() paths = list( filter( @@ -54,14 +55,14 @@ def find_config_file(config_path=None): return None -def config_path_from_environment(): +def config_path_from_environment() -> str | None: config_dir = os.environ.get("DOCKER_CONFIG") if not config_dir: return None return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME)) -def home_dir(): +def home_dir() -> str: """ Get the user's home directory, using the same logic as the Docker Engine client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX. @@ -71,7 +72,7 @@ def home_dir(): return os.path.expanduser("~") -def load_general_config(config_path=None): +def load_general_config(config_path: str | None = None) -> dict[str, t.Any]: config_file = find_config_file(config_path) if not config_file: diff --git a/plugins/module_utils/_api/utils/decorators.py b/plugins/module_utils/_api/utils/decorators.py index f046ebd3..59821aca 100644 --- a/plugins/module_utils/_api/utils/decorators.py +++ b/plugins/module_utils/_api/utils/decorators.py @@ -12,16 +12,37 @@ from __future__ import annotations import functools +import typing as t from .. import errors from . import utils -def minimum_version(version): - def decorator(f): +if t.TYPE_CHECKING: + from collections.abc import Callable + + from ..api.client import APIClient + + _Self = t.TypeVar("_Self") + _P = t.ParamSpec("_P") + _R = t.TypeVar("_R") + + +def minimum_version( + version: str, +) -> Callable[ + [Callable[t.Concatenate[_Self, _P], _R]], + Callable[t.Concatenate[_Self, _P], _R], +]: + def decorator( + f: Callable[t.Concatenate[_Self, _P], _R], + ) -> Callable[t.Concatenate[_Self, _P], _R]: @functools.wraps(f) - def wrapper(self, *args, **kwargs): - if utils.version_lt(self._version, version): + def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + # We use _Self instead of APIClient since this is used for mixins for APIClient. + # This unfortunately means that self._version does not exist in the mixin, + # it only exists after mixing in. This is why we ignore types here. + if utils.version_lt(self._version, version): # type: ignore raise errors.InvalidVersion( f"{f.__name__} is not available for version < {version}" ) @@ -32,13 +53,16 @@ def minimum_version(version): return decorator -def update_headers(f): - def inner(self, *args, **kwargs): +def update_headers( + f: Callable[t.Concatenate[APIClient, _P], _R], +) -> Callable[t.Concatenate[APIClient, _P], _R]: + def inner(self: APIClient, *args: _P.args, **kwargs: _P.kwargs) -> _R: if "HttpHeaders" in self._general_configs: if not kwargs.get("headers"): kwargs["headers"] = self._general_configs["HttpHeaders"] else: - kwargs["headers"].update(self._general_configs["HttpHeaders"]) + # We cannot (yet) model that kwargs["headers"] should be a dictionary + kwargs["headers"].update(self._general_configs["HttpHeaders"]) # type: ignore return f(self, *args, **kwargs) return inner diff --git a/plugins/module_utils/_api/utils/fnmatch.py b/plugins/module_utils/_api/utils/fnmatch.py index 234226c2..525cf84a 100644 --- a/plugins/module_utils/_api/utils/fnmatch.py +++ b/plugins/module_utils/_api/utils/fnmatch.py @@ -32,12 +32,12 @@ _cache: dict[str, re.Pattern] = {} _MAXCACHE = 100 -def _purge(): +def _purge() -> None: """Clear the pattern cache""" _cache.clear() -def fnmatch(name, pat): +def fnmatch(name: str, pat: str): """Test whether FILENAME matches PATTERN. Patterns are Unix shell style: @@ -58,7 +58,7 @@ def fnmatch(name, pat): return fnmatchcase(name, pat) -def fnmatchcase(name, pat): +def fnmatchcase(name: str, pat: str) -> bool: """Test whether FILENAME matches PATTERN, including case. This is a version of fnmatch() which does not case-normalize its arguments. @@ -74,7 +74,7 @@ def fnmatchcase(name, pat): return re_pat.match(name) is not None -def translate(pat): +def translate(pat: str) -> str: """Translate a shell PATTERN to a regular expression. There is no way to quote meta-characters. diff --git a/plugins/module_utils/_api/utils/json_stream.py b/plugins/module_utils/_api/utils/json_stream.py index dac3d0ca..ada8905e 100644 --- a/plugins/module_utils/_api/utils/json_stream.py +++ b/plugins/module_utils/_api/utils/json_stream.py @@ -13,14 +13,22 @@ from __future__ import annotations import json import json.decoder +import typing as t from ..errors import StreamParseError +if t.TYPE_CHECKING: + import re + from collections.abc import Callable + + _T = t.TypeVar("_T") + + json_decoder = json.JSONDecoder() -def stream_as_text(stream): +def stream_as_text(stream: t.Generator[bytes | str]) -> t.Generator[str]: """ Given a stream of bytes or text, if any of the items in the stream are bytes convert them to text. @@ -33,20 +41,22 @@ def stream_as_text(stream): yield data -def json_splitter(buffer): +def json_splitter(buffer: str) -> tuple[t.Any, str] | None: """Attempt to parse a json object from a buffer. If there is at least one object, return it and the rest of the buffer, otherwise return None. """ buffer = buffer.strip() try: obj, index = json_decoder.raw_decode(buffer) - rest = buffer[json.decoder.WHITESPACE.match(buffer, index).end() :] + ws: re.Pattern = json.decoder.WHITESPACE # type: ignore[attr-defined] + m = ws.match(buffer, index) + rest = buffer[m.end() :] if m else buffer[index:] return obj, rest except ValueError: return None -def json_stream(stream): +def json_stream(stream: t.Generator[str | bytes]) -> t.Generator[t.Any]: """Given a stream of text, return a stream of json objects. This handles streams which are inconsistently buffered (some entries may be newline delimited, and others are not). @@ -54,21 +64,24 @@ def json_stream(stream): return split_buffer(stream, json_splitter, json_decoder.decode) -def line_splitter(buffer, separator="\n"): +def line_splitter(buffer: str, separator: str = "\n") -> tuple[str, str] | None: index = buffer.find(str(separator)) if index == -1: return None return buffer[: index + 1], buffer[index + 1 :] -def split_buffer(stream, splitter=None, decoder=lambda a: a): +def split_buffer( + stream: t.Generator[str | bytes], + splitter: Callable[[str], tuple[_T, str] | None], + decoder: Callable[[str], _T], +) -> t.Generator[_T | str]: """Given a generator which yields strings and a splitter function, joins all input, splits on the separator and yields each chunk. Unlike string.split(), each chunk includes the trailing separator, except for the last one if none was found on the end of the input. """ - splitter = splitter or line_splitter buffered = "" for data in stream_as_text(stream): diff --git a/plugins/module_utils/_api/utils/ports.py b/plugins/module_utils/_api/utils/ports.py index 11a350e6..a033c450 100644 --- a/plugins/module_utils/_api/utils/ports.py +++ b/plugins/module_utils/_api/utils/ports.py @@ -12,6 +12,11 @@ from __future__ import annotations import re +import typing as t + + +if t.TYPE_CHECKING: + from collections.abc import Collection, Sequence PORT_SPEC = re.compile( @@ -26,32 +31,42 @@ PORT_SPEC = re.compile( ) -def add_port_mapping(port_bindings, internal_port, external): +def add_port_mapping( + port_bindings: dict[str, list[str | tuple[str, str | None] | None]], + internal_port: str, + external: str | tuple[str, str | None] | None, +) -> None: if internal_port in port_bindings: port_bindings[internal_port].append(external) else: port_bindings[internal_port] = [external] -def add_port(port_bindings, internal_port_range, external_range): +def add_port( + port_bindings: dict[str, list[str | tuple[str, str | None] | None]], + internal_port_range: list[str], + external_range: list[str] | list[tuple[str, str | None]] | None, +) -> None: if external_range is None: for internal_port in internal_port_range: add_port_mapping(port_bindings, internal_port, None) else: - ports = zip(internal_port_range, external_range) - for internal_port, external_port in ports: - add_port_mapping(port_bindings, internal_port, external_port) + for internal_port, external_port in zip(internal_port_range, external_range): + # mypy loses the exact type of eternal_port elements for some reason... + add_port_mapping(port_bindings, internal_port, external_port) # type: ignore -def build_port_bindings(ports): - port_bindings = {} +def build_port_bindings( + ports: Collection[str], +) -> dict[str, list[str | tuple[str, str | None] | None]]: + port_bindings: dict[str, list[str | tuple[str, str | None] | None]] = {} for port in ports: internal_port_range, external_range = split_port(port) add_port(port_bindings, internal_port_range, external_range) return port_bindings -def _raise_invalid_port(port): +def _raise_invalid_port(port: str) -> t.NoReturn: raise ValueError( f'Invalid port "{port}", should be ' "[[remote_ip:]remote_port[-remote_port]:]" @@ -59,39 +74,64 @@ def _raise_invalid_port(port): ) -def port_range(start, end, proto, randomly_available_port=False): - if not start: +@t.overload +def port_range( + start: str, + end: str | None, + proto: str, + randomly_available_port: bool = False, +) -> list[str]: ... + + +@t.overload +def port_range( + start: str | None, + end: str | None, + proto: str, + randomly_available_port: bool = False, +) -> list[str] | None: ... + + +def port_range( + start: str | None, + end: str | None, + proto: str, + randomly_available_port: bool = False, +) -> list[str] | None: + if start is None: return start - if not end: + if end is None: return [f"{start}{proto}"] if randomly_available_port: return [f"{start}-{end}{proto}"] return [f"{port}{proto}" for port in range(int(start), int(end) + 1)] -def split_port(port): - if hasattr(port, "legacy_repr"): - # This is the worst hack, but it prevents a bug in Compose 1.14.0 - # https://github.com/docker/docker-py/issues/1668 - # TODO: remove once fixed in Compose stable - port = port.legacy_repr() +def split_port( + port: str, +) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]: port = str(port) match = PORT_SPEC.match(port) if match is None: _raise_invalid_port(port) parts = match.groupdict() - host = parts["host"] - proto = parts["proto"] or "" - internal = port_range(parts["int"], parts["int_end"], proto) - external = port_range(parts["ext"], parts["ext_end"], "", len(internal) == 1) + host: str | None = parts["host"] + proto: str = parts["proto"] or "" + int_p: str = parts["int"] + ext_p: str | None = parts["ext"] or None + internal: list[str] = port_range(int_p, parts["int_end"], proto) # type: ignore + external = port_range(ext_p, parts["ext_end"], "", len(internal) == 1) if host is None: if external is not None and len(internal) != len(external): raise ValueError("Port ranges don't match in length") return internal, external + external_or_none: Sequence[str | None] if not external: - external = [None] * len(internal) - elif len(internal) != len(external): - raise ValueError("Port ranges don't match in length") - return internal, [(host, ext_port) for ext_port in external] + external_or_none = [None] * len(internal) + else: + external_or_none = external + if len(internal) != len(external_or_none): + raise ValueError("Port ranges don't match in length") + return internal, [(host, ext_port) for ext_port in external_or_none] diff --git a/plugins/module_utils/_api/utils/proxy.py b/plugins/module_utils/_api/utils/proxy.py index af1fc064..0f5fa9f3 100644 --- a/plugins/module_utils/_api/utils/proxy.py +++ b/plugins/module_utils/_api/utils/proxy.py @@ -20,23 +20,23 @@ class ProxyConfig(dict): """ @property - def http(self): + def http(self) -> str | None: return self.get("http") @property - def https(self): + def https(self) -> str | None: return self.get("https") @property - def ftp(self): + def ftp(self) -> str | None: return self.get("ftp") @property - def no_proxy(self): + def no_proxy(self) -> str | None: return self.get("no_proxy") @staticmethod - def from_dict(config): + def from_dict(config: dict[str, str]) -> ProxyConfig: """ Instantiate a new ProxyConfig from a dictionary that represents a client configuration, as described in `the documentation`_. @@ -51,7 +51,7 @@ class ProxyConfig(dict): no_proxy=config.get("noProxy"), ) - def get_environment(self): + def get_environment(self) -> dict[str, str]: """ Return a dictionary representing the environment variables used to set the proxy settings. @@ -67,7 +67,7 @@ class ProxyConfig(dict): env["no_proxy"] = env["NO_PROXY"] = self.no_proxy return env - def inject_proxy_environment(self, environment): + def inject_proxy_environment(self, environment: list[str]) -> list[str]: """ Given a list of strings representing environment variables, prepend the environment variables corresponding to the proxy settings. @@ -82,5 +82,5 @@ class ProxyConfig(dict): # variables defined in "environment" to take precedence. return proxy_env + environment - def __str__(self): + def __str__(self) -> str: return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})" diff --git a/plugins/module_utils/_api/utils/socket.py b/plugins/module_utils/_api/utils/socket.py index 615018ad..6619e0ff 100644 --- a/plugins/module_utils/_api/utils/socket.py +++ b/plugins/module_utils/_api/utils/socket.py @@ -16,10 +16,15 @@ import os import select import socket as pysocket import struct +import typing as t from ..transport.npipesocket import NpipeSocket +if t.TYPE_CHECKING: + from collections.abc import Iterable + + STDOUT = 1 STDERR = 2 @@ -33,7 +38,7 @@ class SocketError(Exception): NPIPE_ENDED = 109 -def read(socket, n=4096): +def read(socket, n: int = 4096) -> bytes | None: """ Reads at most n bytes from socket """ @@ -58,6 +63,7 @@ def read(socket, n=4096): except EnvironmentError as e: if e.errno not in recoverable_errors: raise + return None # TODO ??? except Exception as e: is_pipe_ended = ( isinstance(socket, NpipeSocket) @@ -67,11 +73,11 @@ def read(socket, n=4096): if is_pipe_ended: # npipes do not support duplex sockets, so we interpret # a PIPE_ENDED error as a close operation (0-length read). - return "" + return b"" raise -def read_exactly(socket, n): +def read_exactly(socket, n: int) -> bytes: """ Reads exactly n bytes from socket Raises SocketError if there is not enough data @@ -85,7 +91,7 @@ def read_exactly(socket, n): return data -def next_frame_header(socket): +def next_frame_header(socket) -> tuple[int, int]: """ Returns the stream and size of the next frame of data waiting to be read from socket, according to the protocol defined here: @@ -101,7 +107,7 @@ def next_frame_header(socket): return (stream, actual) -def frames_iter(socket, tty): +def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]: """ Return a generator of frames read from socket. A frame is a tuple where the first item is the stream number and the second item is a chunk of data. @@ -114,7 +120,7 @@ def frames_iter(socket, tty): return frames_iter_no_tty(socket) -def frames_iter_no_tty(socket): +def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]: """ Returns a generator of data read from the socket when the tty setting is not enabled. @@ -135,20 +141,34 @@ def frames_iter_no_tty(socket): yield (stream, result) -def frames_iter_tty(socket): +def frames_iter_tty(socket) -> t.Generator[bytes]: """ Return a generator of data read from the socket when the tty setting is enabled. """ while True: result = read(socket) - if len(result) == 0: + if not result: # We have reached EOF return yield result -def consume_socket_output(frames, demux=False): +@t.overload +def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ... + + +@t.overload +def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ... + + +@t.overload +def consume_socket_output( + frames, demux: bool = False +) -> bytes | tuple[bytes, bytes]: ... + + +def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]: """ Iterate through frames read from the socket and return the result. @@ -167,7 +187,7 @@ def consume_socket_output(frames, demux=False): # If the streams are demultiplexed, the generator yields tuples # (stdout, stderr) - out = [None, None] + out: list[bytes | None] = [None, None] for frame in frames: # It is guaranteed that for each frame, one and only one stream # is not None. @@ -183,10 +203,10 @@ def consume_socket_output(frames, demux=False): out[1] = frame[1] else: out[1] += frame[1] - return tuple(out) + return tuple(out) # type: ignore -def demux_adaptor(stream_id, data): +def demux_adaptor(stream_id: int, data: bytes) -> tuple[bytes | None, bytes | None]: """ Utility to demultiplex stdout and stderr when reading frames from the socket. diff --git a/plugins/module_utils/_api/utils/utils.py b/plugins/module_utils/_api/utils/utils.py index 31c9b39f..b0a5ed53 100644 --- a/plugins/module_utils/_api/utils/utils.py +++ b/plugins/module_utils/_api/utils/utils.py @@ -18,6 +18,7 @@ import os import os.path import shlex import string +import typing as t from urllib.parse import urlparse, urlunparse from ansible_collections.community.docker.plugins.module_utils._version import ( @@ -34,32 +35,23 @@ from ..constants import ( from ..tls import TLSConfig +if t.TYPE_CHECKING: + import ssl + from collections.abc import Mapping, Sequence + + URLComponents = collections.namedtuple( "URLComponents", "scheme netloc url params query fragment", ) -def create_ipam_pool(*args, **kwargs): - raise errors.DeprecatedMethod( - "utils.create_ipam_pool has been removed. Please use a " - "docker.types.IPAMPool object instead." - ) - - -def create_ipam_config(*args, **kwargs): - raise errors.DeprecatedMethod( - "utils.create_ipam_config has been removed. Please use a " - "docker.types.IPAMConfig object instead." - ) - - -def decode_json_header(header): +def decode_json_header(header: str) -> dict[str, t.Any]: data = base64.b64decode(header).decode("utf-8") return json.loads(data) -def compare_version(v1, v2): +def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]: """Compare docker versions >>> v1 = '1.9' @@ -80,43 +72,64 @@ def compare_version(v1, v2): return 1 -def version_lt(v1, v2): +def version_lt(v1: str, v2: str) -> bool: return compare_version(v1, v2) > 0 -def version_gte(v1, v2): +def version_gte(v1: str, v2: str) -> bool: return not version_lt(v1, v2) -def _convert_port_binding(binding): +def _convert_port_binding( + binding: ( + tuple[str, str | int | None] + | tuple[str | int | None] + | dict[str, str] + | str + | int + ), +) -> dict[str, str]: result = {"HostIp": "", "HostPort": ""} + host_port: str | int | None = "" if isinstance(binding, tuple): if len(binding) == 2: - result["HostPort"] = binding[1] + host_port = binding[1] # type: ignore result["HostIp"] = binding[0] elif isinstance(binding[0], str): result["HostIp"] = binding[0] else: - result["HostPort"] = binding[0] + host_port = binding[0] elif isinstance(binding, dict): if "HostPort" in binding: - result["HostPort"] = binding["HostPort"] + host_port = binding["HostPort"] if "HostIp" in binding: result["HostIp"] = binding["HostIp"] else: raise ValueError(binding) else: - result["HostPort"] = binding - - if result["HostPort"] is None: - result["HostPort"] = "" - else: - result["HostPort"] = str(result["HostPort"]) + host_port = binding + result["HostPort"] = str(host_port) if host_port is not None else "" return result -def convert_port_bindings(port_bindings): +def convert_port_bindings( + port_bindings: dict[ + str | int, + tuple[str, str | int | None] + | tuple[str | int | None] + | dict[str, str] + | str + | int + | list[ + tuple[str, str | int | None] + | tuple[str | int | None] + | dict[str, str] + | str + | int + ], + ], +) -> dict[str, list[dict[str, str]]]: result = {} for k, v in port_bindings.items(): key = str(k) @@ -129,9 +142,11 @@ def convert_port_bindings(port_bindings): return result -def convert_volume_binds(binds): +def convert_volume_binds( + binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int], +) -> list[str]: if isinstance(binds, list): - return binds + return binds # type: ignore result = [] for k, v in binds.items(): @@ -149,7 +164,7 @@ def convert_volume_binds(binds): if "ro" in v: mode = "ro" if v["ro"] else "rw" elif "mode" in v: - mode = v["mode"] + mode = v["mode"] # type: ignore # TODO else: mode = "rw" @@ -165,9 +180,9 @@ def convert_volume_binds(binds): ] if "propagation" in v and v["propagation"] in propagation_modes: if mode: - mode = ",".join([mode, v["propagation"]]) + mode = ",".join([mode, v["propagation"]]) # type: ignore # TODO else: - mode = v["propagation"] + mode = v["propagation"] # type: ignore # TODO result.append(f"{k}:{bind}:{mode}") else: @@ -177,7 +192,7 @@ def convert_volume_binds(binds): return result -def convert_tmpfs_mounts(tmpfs): +def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]: if isinstance(tmpfs, dict): return tmpfs @@ -204,9 +219,11 @@ def convert_tmpfs_mounts(tmpfs): return result -def convert_service_networks(networks): +def convert_service_networks( + networks: list[str | dict[str, str]], +) -> list[dict[str, str]]: if not networks: - return networks + return networks # type: ignore if not isinstance(networks, list): raise TypeError("networks parameter must be a list.") @@ -218,17 +235,17 @@ def convert_service_networks(networks): return result -def parse_repository_tag(repo_name): +def parse_repository_tag(repo_name: str) -> tuple[str, str | None]: parts = repo_name.rsplit("@", 1) if len(parts) == 2: - return tuple(parts) + return tuple(parts) # type: ignore parts = repo_name.rsplit(":", 1) if len(parts) == 2 and "/" not in parts[1]: - return tuple(parts) + return tuple(parts) # type: ignore return repo_name, None -def parse_host(addr, is_win32=False, tls=False): +def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str: # Sensible defaults if not addr and is_win32: return DEFAULT_NPIPE @@ -308,7 +325,7 @@ def parse_host(addr, is_win32=False, tls=False): ).rstrip("/") -def parse_devices(devices): +def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]: device_list = [] for device in devices: if isinstance(device, dict): @@ -337,7 +354,10 @@ def parse_devices(devices): return device_list -def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): +def kwargs_from_env( + assert_hostname: bool | None = None, + environment: Mapping[str, str] | None = None, +) -> dict[str, t.Any]: if not environment: environment = os.environ host = environment.get("DOCKER_HOST") @@ -347,14 +367,14 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): # empty string for tls verify counts as "false". # Any value or 'unset' counts as true. - tls_verify = environment.get("DOCKER_TLS_VERIFY") - if tls_verify == "": + tls_verify_str = environment.get("DOCKER_TLS_VERIFY") + if tls_verify_str == "": tls_verify = False else: - tls_verify = tls_verify is not None + tls_verify = tls_verify_str is not None enable_tls = cert_path or tls_verify - params = {} + params: dict[str, t.Any] = {} if host: params["base_url"] = host @@ -377,14 +397,13 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): ), ca_cert=os.path.join(cert_path, "ca.pem"), verify=tls_verify, - ssl_version=ssl_version, assert_hostname=assert_hostname, ) return params -def convert_filters(filters): +def convert_filters(filters: dict[str, bool | str | list[str]]) -> str: result = {} for k, v in filters.items(): if isinstance(v, bool): @@ -397,7 +416,7 @@ def convert_filters(filters): return json.dumps(result) -def parse_bytes(s): +def parse_bytes(s: int | float | str) -> int | float: if isinstance(s, (int, float)): return s if len(s) == 0: @@ -435,14 +454,16 @@ def parse_bytes(s): return s -def normalize_links(links): +def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]: if isinstance(links, dict): - links = links.items() + sorted_links = sorted(links.items()) + else: + sorted_links = sorted(links) - return [f"{k}:{v}" if v else k for k, v in sorted(links)] + return [f"{k}:{v}" if v else k for k, v in sorted_links] -def parse_env_file(env_file): +def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]: """ Reads a line-separated environment file. The format of each line should be "key=value". @@ -451,7 +472,6 @@ def parse_env_file(env_file): with open(env_file, "rt", encoding="utf-8") as f: for line in f: - if line[0] == "#": continue @@ -471,11 +491,11 @@ def parse_env_file(env_file): return environment -def split_command(command): +def split_command(command: str) -> list[str]: return shlex.split(command) -def format_environment(environment): +def format_environment(environment: Mapping[str, str | bytes]) -> list[str]: def format_env(key, value): if value is None: return key @@ -487,16 +507,9 @@ def format_environment(environment): return [format_env(*var) for var in environment.items()] -def format_extra_hosts(extra_hosts, task=False): +def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]: # Use format dictated by Swarm API if container is part of a task if task: return [f"{v} {k}" for k, v in sorted(extra_hosts.items())] return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())] - - -def create_host_config(self, *args, **kwargs): - raise errors.DeprecatedMethod( - "utils.create_host_config has been removed. Please use a " - "docker.types.HostConfig object instead." - ) diff --git a/plugins/module_utils/_copy.py b/plugins/module_utils/_copy.py index db1808d6..413ba079 100644 --- a/plugins/module_utils/_copy.py +++ b/plugins/module_utils/_copy.py @@ -552,7 +552,7 @@ def _execute_command( result = client.get_json("/exec/{0}/json", exec_id) - rc = result.get("ExitCode") or 0 + rc: int = result.get("ExitCode") or 0 stdout = stdout or b"" stderr = stderr or b""