Add more typing.

This commit is contained in:
Felix Fontein 2025-10-25 00:09:57 +02:00
parent a2deb384d4
commit 931ae7978c
48 changed files with 430 additions and 314 deletions

View File

@ -3,15 +3,15 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
[mypy] [mypy]
# check_untyped_defs = True -- for later check_untyped_defs = True
# disallow_untyped_defs = True -- for later disallow_untyped_defs = True
# strict = True -- only try to enable once everything (including dependencies!) is typed # strict = True -- only try to enable once everything (including dependencies!) is typed
strict_equality = True strict_equality = True
strict_bytes = True strict_bytes = True
warn_redundant_casts = True warn_redundant_casts = True
# warn_return_any = True -- for later # warn_return_any = True
warn_unreachable = True warn_unreachable = True
[mypy-ansible.*] [mypy-ansible.*]

View File

@ -141,7 +141,7 @@ class Connection(ConnectionBase):
transport = "community.docker.docker" transport = "community.docker.docker"
has_pipelining = True has_pipelining = True
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Note: docker supports running as non-root in some configurations. # Note: docker supports running as non-root in some configurations.
@ -476,7 +476,7 @@ class Connection(ConnectionBase):
display.debug("done with docker.exec_command()") display.debug("done with docker.exec_command()")
return (p.returncode, stdout, stderr) return (p.returncode, stdout, stderr)
def _prefix_login_path(self, remote_path): def _prefix_login_path(self, remote_path: str) -> str:
"""Make sure that we put files into a standard path """Make sure that we put files into a standard path
If a path is relative, then we need to choose where to put it. If a path is relative, then we need to choose where to put it.

View File

@ -192,7 +192,7 @@ class Connection(ConnectionBase):
f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}' f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}'
) )
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.client: AnsibleDockerClient | None = None self.client: AnsibleDockerClient | None = None
@ -319,7 +319,7 @@ class Connection(ConnectionBase):
become_output = [b""] become_output = [b""]
def append_become_output(stream_id, data): def append_become_output(stream_id: int, data: bytes) -> None:
become_output[0] += data become_output[0] += data
exec_socket_handler.set_block_done_callback( exec_socket_handler.set_block_done_callback(

View File

@ -65,7 +65,7 @@ class Connection(ConnectionBase):
transport = "community.docker.nsenter" transport = "community.docker.nsenter"
has_pipelining = False has_pipelining = False
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cwd = None self.cwd = None
self._nsenter_pid = None self._nsenter_pid = None

View File

@ -221,7 +221,10 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return ip_addr return ip_addr
def _should_skip_host( def _should_skip_host(
self, machine_name: str, env_var_tuples, daemon_env: DaemonEnv self,
machine_name: str,
env_var_tuples: list[tuple[str, str]],
daemon_env: DaemonEnv,
) -> bool: ) -> bool:
if not env_var_tuples: if not env_var_tuples:
warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}" warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}"

View File

@ -67,7 +67,7 @@ except ImportError:
pass pass
class FakeURLLIB3: class FakeURLLIB3:
def __init__(self): def __init__(self) -> None:
self._collections = self self._collections = self
self.poolmanager = self self.poolmanager = self
self.connection = self self.connection = self
@ -81,14 +81,14 @@ except ImportError:
) )
class FakeURLLIB3Connection: class FakeURLLIB3Connection:
def __init__(self): def __init__(self) -> None:
self.HTTPConnection = _HTTPConnection # pylint: disable=invalid-name self.HTTPConnection = _HTTPConnection # pylint: disable=invalid-name
urllib3 = FakeURLLIB3() urllib3 = FakeURLLIB3()
urllib3_connection = FakeURLLIB3Connection() urllib3_connection = FakeURLLIB3Connection()
def fail_on_missing_imports(): def fail_on_missing_imports() -> None:
if REQUESTS_IMPORT_ERROR is not None: if REQUESTS_IMPORT_ERROR is not None:
from .errors import MissingRequirementException # pylint: disable=cyclic-import from .errors import MissingRequirementException # pylint: disable=cyclic-import

View File

@ -55,6 +55,7 @@ from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from requests import Response from requests import Response
from requests.adapters import BaseAdapter
from ..._socket_helper import SocketLike from ..._socket_helper import SocketLike
@ -258,23 +259,23 @@ class APIClient(_Session):
return kwargs return kwargs
@update_headers @update_headers
def _post(self, url: str, **kwargs): def _post(self, url: str, **kwargs: t.Any) -> Response:
return self.post(url, **self._set_request_timeout(kwargs)) return self.post(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _get(self, url: str, **kwargs): def _get(self, url: str, **kwargs: t.Any) -> Response:
return self.get(url, **self._set_request_timeout(kwargs)) return self.get(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _head(self, url: str, **kwargs): def _head(self, url: str, **kwargs: t.Any) -> Response:
return self.head(url, **self._set_request_timeout(kwargs)) return self.head(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _put(self, url: str, **kwargs): def _put(self, url: str, **kwargs: t.Any) -> Response:
return self.put(url, **self._set_request_timeout(kwargs)) return self.put(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _delete(self, url: str, **kwargs): def _delete(self, url: str, **kwargs: t.Any) -> Response:
return self.delete(url, **self._set_request_timeout(kwargs)) return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str: def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
@ -343,7 +344,7 @@ class APIClient(_Session):
return response.text return response.text
def _post_json( def _post_json(
self, url: str, data: dict[str, str | None] | t.Any, **kwargs self, url: str, data: dict[str, str | None] | t.Any, **kwargs: t.Any
) -> Response: ) -> Response:
# Go <1.1 cannot unserialize null to a string # Go <1.1 cannot unserialize null to a string
# so we do this disgusting thing here. # so we do this disgusting thing here.
@ -556,22 +557,30 @@ class APIClient(_Session):
""" """
socket = self._get_raw_response_socket(response) socket = self._get_raw_response_socket(response)
gen: t.Generator = frames_iter(socket, tty) gen = frames_iter(socket, tty)
if demux: if demux:
# The generator will output tuples (stdout, stderr) # The generator will output tuples (stdout, stderr)
gen = (demux_adaptor(*frame) for frame in gen) demux_gen: t.Generator[tuple[bytes | None, bytes | None]] = (
demux_adaptor(*frame) for frame in gen
)
if stream:
return demux_gen
try:
# Wait for all the frames, concatenate them, and return the result
return consume_socket_output(demux_gen, demux=True)
finally:
response.close()
else: else:
# The generator will output strings # The generator will output strings
gen = (data for (dummy, data) in gen) mux_gen: t.Generator[bytes] = (data for (dummy, data) in gen)
if stream:
if stream: return mux_gen
return gen try:
try: # Wait for all the frames, concatenate them, and return the result
# Wait for all the frames, concatenate them, and return the result return consume_socket_output(mux_gen, demux=False)
return consume_socket_output(gen, demux=demux) finally:
finally: response.close()
response.close()
def _disable_socket_timeout(self, socket: SocketLike) -> None: def _disable_socket_timeout(self, socket: SocketLike) -> None:
"""Depending on the combination of python version and whether we are """Depending on the combination of python version and whether we are
@ -637,11 +646,11 @@ class APIClient(_Session):
return self._multiplexed_response_stream_helper(res) return self._multiplexed_response_stream_helper(res)
return sep.join(list(self._multiplexed_buffer_helper(res))) return sep.join(list(self._multiplexed_buffer_helper(res)))
def _unmount(self, *args) -> None: def _unmount(self, *args: t.Any) -> None:
for proto in args: for proto in args:
self.adapters.pop(proto) self.adapters.pop(proto)
def get_adapter(self, url: str): def get_adapter(self, url: str) -> BaseAdapter:
try: try:
return super().get_adapter(url) return super().get_adapter(url)
except _InvalidSchema as e: except _InvalidSchema as e:
@ -696,19 +705,19 @@ class APIClient(_Session):
else: else:
log.debug("No auth config found") log.debug("No auth config found")
def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes: def get_binary(self, pathfmt: str, *args: str, **kwargs: t.Any) -> bytes:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_binary=True, get_binary=True,
) )
def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: def get_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def get_text(self, pathfmt: str, *args: str, **kwargs) -> str: def get_text(self, pathfmt: str, *args: str, **kwargs: t.Any) -> str:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
@ -718,7 +727,7 @@ class APIClient(_Session):
pathfmt: str, pathfmt: str,
*args: str, *args: str,
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE, chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
**kwargs, **kwargs: t.Any,
) -> t.Generator[bytes]: ) -> t.Generator[bytes]:
res = self._get( res = self._get(
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
@ -726,23 +735,25 @@ class APIClient(_Session):
self._raise_for_status(res) self._raise_for_status(res)
return self._stream_raw_result(res, chunk_size=chunk_size, decode=False) return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None: def delete_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None:
self._raise_for_status( self._raise_for_status(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: def delete_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result( return self._result(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def post_call(self, pathfmt: str, *args: str, **kwargs) -> None: def post_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None:
self._raise_for_status( self._raise_for_status(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None: def post_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> None:
self._raise_for_status( self._raise_for_status(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -750,7 +761,7 @@ class APIClient(_Session):
) )
def post_json_to_binary( def post_json_to_binary(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> bytes: ) -> bytes:
return self._result( return self._result(
self._post_json( self._post_json(
@ -760,7 +771,7 @@ class APIClient(_Session):
) )
def post_json_to_json( def post_json_to_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> t.Any: ) -> t.Any:
return self._result( return self._result(
self._post_json( self._post_json(
@ -770,7 +781,7 @@ class APIClient(_Session):
) )
def post_json_to_text( def post_json_to_text(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> str: ) -> str:
return self._result( return self._result(
self._post_json( self._post_json(
@ -784,7 +795,7 @@ class APIClient(_Session):
*args: str, *args: str,
data: t.Any = None, data: t.Any = None,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
**kwargs, **kwargs: t.Any,
) -> SocketLike: ) -> SocketLike:
headers = headers.copy() if headers else {} headers = headers.copy() if headers else {}
headers.update( headers.update(
@ -813,7 +824,7 @@ class APIClient(_Session):
stream: t.Literal[True], stream: t.Literal[True],
tty: bool = True, tty: bool = True,
demux: t.Literal[False] = False, demux: t.Literal[False] = False,
**kwargs, **kwargs: t.Any,
) -> t.Generator[bytes]: ... ) -> t.Generator[bytes]: ...
@t.overload @t.overload
@ -826,7 +837,7 @@ class APIClient(_Session):
stream: t.Literal[True], stream: t.Literal[True],
tty: t.Literal[True] = True, tty: t.Literal[True] = True,
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> t.Generator[tuple[bytes, None]]: ... ) -> t.Generator[tuple[bytes, None]]: ...
@t.overload @t.overload
@ -839,7 +850,7 @@ class APIClient(_Session):
stream: t.Literal[True], stream: t.Literal[True],
tty: t.Literal[False], tty: t.Literal[False],
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ... ) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload @t.overload
@ -852,7 +863,7 @@ class APIClient(_Session):
stream: t.Literal[False], stream: t.Literal[False],
tty: bool = True, tty: bool = True,
demux: t.Literal[False] = False, demux: t.Literal[False] = False,
**kwargs, **kwargs: t.Any,
) -> bytes: ... ) -> bytes: ...
@t.overload @t.overload
@ -865,7 +876,7 @@ class APIClient(_Session):
stream: t.Literal[False], stream: t.Literal[False],
tty: t.Literal[True] = True, tty: t.Literal[True] = True,
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> tuple[bytes, None]: ... ) -> tuple[bytes, None]: ...
@t.overload @t.overload
@ -878,7 +889,7 @@ class APIClient(_Session):
stream: t.Literal[False], stream: t.Literal[False],
tty: t.Literal[False], tty: t.Literal[False],
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> tuple[bytes, bytes]: ... ) -> tuple[bytes, bytes]: ...
def post_json_to_stream( def post_json_to_stream(
@ -890,7 +901,7 @@ class APIClient(_Session):
stream: bool = False, stream: bool = False,
demux: bool = False, demux: bool = False,
tty: bool = False, tty: bool = False,
**kwargs, **kwargs: t.Any,
) -> t.Any: ) -> t.Any:
headers = headers.copy() if headers else {} headers = headers.copy() if headers else {}
headers.update( headers.update(
@ -912,7 +923,7 @@ class APIClient(_Session):
demux=demux, demux=demux,
) )
def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: def post_to_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result( return self._result(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,

View File

@ -106,7 +106,7 @@ class AuthConfig(dict):
@classmethod @classmethod
def parse_auth( def parse_auth(
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False cls, entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
) -> dict[str, dict[str, t.Any]]: ) -> dict[str, dict[str, t.Any]]:
""" """
Parses authentication entries Parses authentication entries
@ -294,7 +294,7 @@ class AuthConfig(dict):
except StoreError as e: except StoreError as e:
raise errors.DockerException(f"Credentials store error: {e}") raise errors.DockerException(f"Credentials store error: {e}")
def _get_store_instance(self, name: str): def _get_store_instance(self, name: str) -> Store:
if name not in self._stores: if name not in self._stores:
self._stores[name] = Store(name, environment=self._credstore_env) self._stores[name] = Store(name, environment=self._credstore_env)
return self._stores[name] return self._stores[name]
@ -326,8 +326,10 @@ class AuthConfig(dict):
def resolve_authconfig( def resolve_authconfig(
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None authconfig: AuthConfig | dict[str, t.Any],
): registry: str | None = None,
credstore_env: dict[str, str] | None = None,
) -> dict[str, t.Any] | None:
if not isinstance(authconfig, AuthConfig): if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig, credstore_env) authconfig = AuthConfig(authconfig, credstore_env)
return authconfig.resolve_authconfig(registry) return authconfig.resolve_authconfig(registry)

View File

@ -89,7 +89,7 @@ def get_meta_dir(name: str | None = None) -> str:
return meta_dir return meta_dir
def get_meta_file(name) -> str: def get_meta_file(name: str) -> str:
return os.path.join(get_meta_dir(name), METAFILE) return os.path.join(get_meta_dir(name), METAFILE)

View File

@ -18,6 +18,10 @@ from ansible.module_utils.common.text.converters import to_native
from ._import_helper import HTTPError as _HTTPError from ._import_helper import HTTPError as _HTTPError
if t.TYPE_CHECKING:
from requests import Response
class DockerException(Exception): class DockerException(Exception):
""" """
A base class from which all other exceptions inherit. A base class from which all other exceptions inherit.
@ -55,7 +59,10 @@ class APIError(_HTTPError, DockerException):
""" """
def __init__( def __init__(
self, message: str | Exception, response=None, explanation: str | None = None self,
message: str | Exception,
response: Response | None = None,
explanation: str | None = None,
) -> None: ) -> None:
# requests 1.2 supports response as a keyword argument, but # requests 1.2 supports response as a keyword argument, but
# requests 1.1 does not # requests 1.1 does not

View File

@ -11,6 +11,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from queue import Empty from queue import Empty
from .. import constants from .. import constants
@ -19,6 +20,12 @@ from .basehttpadapter import BaseHTTPAdapter
from .npipesocket import NpipeSocket from .npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Mapping
from requests import PreparedRequest
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
@ -91,7 +98,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool: def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> NpipeHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -104,7 +113,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies) -> str: def request_url(
self, request: PreparedRequest, proxies: Mapping[str, str] | None
) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -64,7 +64,7 @@ class NpipeSocket:
implemented. implemented.
""" """
def __init__(self, handle=None) -> None: def __init__(self, handle: t.Any | None = None) -> None:
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
self._handle = handle self._handle = handle
self._address: str | None = None self._address: str | None = None
@ -74,15 +74,17 @@ class NpipeSocket:
def accept(self) -> t.NoReturn: def accept(self) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def bind(self, address) -> t.NoReturn: def bind(self, address: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def close(self) -> None: def close(self) -> None:
if self._handle is None:
raise ValueError("Handle not present")
self._handle.Close() self._handle.Close()
self._closed = True self._closed = True
@check_closed @check_closed
def connect(self, address, retry_count: int = 0) -> None: def connect(self, address: str, retry_count: int = 0) -> None:
try: try:
handle = win32file.CreateFile( handle = win32file.CreateFile(
address, address,
@ -116,11 +118,11 @@ class NpipeSocket:
self._address = address self._address = address
@check_closed @check_closed
def connect_ex(self, address) -> None: def connect_ex(self, address: str) -> None:
self.connect(address) self.connect(address)
@check_closed @check_closed
def detach(self): def detach(self) -> t.Any:
self._closed = True self._closed = True
return self._handle return self._handle
@ -134,16 +136,18 @@ class NpipeSocket:
def getsockname(self) -> str | None: def getsockname(self) -> str | None:
return self._address return self._address
def getsockopt(self, level, optname, buflen=None) -> t.NoReturn: def getsockopt(
self, level: t.Any, optname: t.Any, buflen: t.Any = None
) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def ioctl(self, control, option) -> t.NoReturn: def ioctl(self, control: t.Any, option: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def listen(self, backlog) -> t.NoReturn: def listen(self, backlog: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def makefile(self, mode: str, bufsize: int | None = None): def makefile(self, mode: str, bufsize: int | None = None) -> t.IO[bytes]:
if mode.strip("b") != "r": if mode.strip("b") != "r":
raise NotImplementedError() raise NotImplementedError()
rawio = NpipeFileIOBase(self) rawio = NpipeFileIOBase(self)
@ -153,6 +157,8 @@ class NpipeSocket:
@check_closed @check_closed
def recv(self, bufsize: int, flags: int = 0) -> str: def recv(self, bufsize: int, flags: int = 0) -> str:
if self._handle is None:
raise ValueError("Handle not present")
dummy_err, data = win32file.ReadFile(self._handle, bufsize) dummy_err, data = win32file.ReadFile(self._handle, bufsize)
return data return data
@ -169,6 +175,8 @@ class NpipeSocket:
@check_closed @check_closed
def recv_into(self, buf: Buffer, nbytes: int = 0) -> int: def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
if self._handle is None:
raise ValueError("Handle not present")
readbuf = buf if isinstance(buf, memoryview) else memoryview(buf) readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
@ -188,6 +196,8 @@ class NpipeSocket:
@check_closed @check_closed
def send(self, string: Buffer, flags: int = 0) -> int: def send(self, string: Buffer, flags: int = 0) -> int:
if self._handle is None:
raise ValueError("Handle not present")
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
try: try:
overlapped = pywintypes.OVERLAPPED() overlapped = pywintypes.OVERLAPPED()
@ -210,7 +220,7 @@ class NpipeSocket:
self.connect(address) self.connect(address)
return self.send(string) return self.send(string)
def setblocking(self, flag: bool): def setblocking(self, flag: bool) -> None:
if flag: if flag:
return self.settimeout(None) return self.settimeout(None)
return self.settimeout(0) return self.settimeout(0)
@ -228,16 +238,16 @@ class NpipeSocket:
def gettimeout(self) -> int | float | None: def gettimeout(self) -> int | float | None:
return self._timeout return self._timeout
def setsockopt(self, level, optname, value) -> t.NoReturn: def setsockopt(self, level: t.Any, optname: t.Any, value: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
@check_closed @check_closed
def shutdown(self, how) -> None: def shutdown(self, how: t.Any) -> None:
return self.close() return self.close()
class NpipeFileIOBase(io.RawIOBase): class NpipeFileIOBase(io.RawIOBase):
def __init__(self, npipe_socket) -> None: def __init__(self, npipe_socket: NpipeSocket | None) -> None:
self.sock = npipe_socket self.sock = npipe_socket
def close(self) -> None: def close(self) -> None:
@ -245,7 +255,10 @@ class NpipeFileIOBase(io.RawIOBase):
self.sock = None self.sock = None
def fileno(self) -> int: def fileno(self) -> int:
return self.sock.fileno() if self.sock is None:
raise RuntimeError("socket is closed")
# TODO: This is definitely a bug, NpipeSocket.fileno() does not exist!
return self.sock.fileno() # type: ignore
def isatty(self) -> bool: def isatty(self) -> bool:
return False return False
@ -254,6 +267,8 @@ class NpipeFileIOBase(io.RawIOBase):
return True return True
def readinto(self, buf: Buffer) -> int: def readinto(self, buf: Buffer) -> int:
if self.sock is None:
raise RuntimeError("socket is closed")
return self.sock.recv_into(buf) return self.sock.recv_into(buf)
def seekable(self) -> bool: def seekable(self) -> bool:

View File

@ -35,7 +35,7 @@ else:
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from collections.abc import Buffer from collections.abc import Buffer, Mapping
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
@ -67,7 +67,7 @@ class SSHSocket(socket.socket):
preexec_func = None preexec_func = None
if not constants.IS_WINDOWS_PLATFORM: if not constants.IS_WINDOWS_PLATFORM:
def f(): def f() -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
preexec_func = f preexec_func = f
@ -100,13 +100,13 @@ class SSHSocket(socket.socket):
self.proc.stdin.flush() self.proc.stdin.flush()
return written return written
def sendall(self, data: Buffer, *args, **kwargs) -> None: def sendall(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> None:
self._write(data) self._write(data)
def send(self, data: Buffer, *args, **kwargs) -> int: def send(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> int:
return self._write(data) return self._write(data)
def recv(self, n: int, *args, **kwargs) -> bytes: def recv(self, n: int, *args: t.Any, **kwargs: t.Any) -> bytes:
if not self.proc: if not self.proc:
raise RuntimeError( raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first." "SSH subprocess not initiated. connect() must be called first."
@ -114,7 +114,7 @@ class SSHSocket(socket.socket):
assert self.proc.stdout is not None assert self.proc.stdout is not None
return self.proc.stdout.read(n) return self.proc.stdout.read(n)
def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore def makefile(self, mode: str, *args: t.Any, **kwargs: t.Any) -> t.IO: # type: ignore
if not self.proc: if not self.proc:
self.connect() self.connect()
assert self.proc is not None assert self.proc is not None
@ -138,7 +138,7 @@ class SSHConnection(urllib3_connection.HTTPConnection):
def __init__( def __init__(
self, self,
*, *,
ssh_transport=None, ssh_transport: paramiko.Transport | None = None,
timeout: int | float = 60, timeout: int | float = 60,
host: str, host: str,
) -> None: ) -> None:
@ -146,18 +146,19 @@ class SSHConnection(urllib3_connection.HTTPConnection):
self.ssh_transport = ssh_transport self.ssh_transport = ssh_transport
self.timeout = timeout self.timeout = timeout
self.ssh_host = host self.ssh_host = host
self.sock: paramiko.Channel | SSHSocket | None = None
def connect(self) -> None: def connect(self) -> None:
if self.ssh_transport: if self.ssh_transport:
sock = self.ssh_transport.open_session() channel = self.ssh_transport.open_session()
sock.settimeout(self.timeout) channel.settimeout(self.timeout)
sock.exec_command("docker system dial-stdio") channel.exec_command("docker system dial-stdio")
self.sock = channel
else: else:
sock = SSHSocket(self.ssh_host) sock = SSHSocket(self.ssh_host)
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.connect() sock.connect()
self.sock = sock
self.sock = sock
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
@ -172,7 +173,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
host: str, host: str,
) -> None: ) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.ssh_transport = None self.ssh_transport: paramiko.Transport | None = None
self.timeout = timeout self.timeout = timeout
if ssh_client: if ssh_client:
self.ssh_transport = ssh_client.get_transport() self.ssh_transport = ssh_client.get_transport()
@ -276,7 +277,9 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
if self.ssh_client: if self.ssh_client:
self.ssh_client.connect(**self.ssh_params) self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool: def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> SSHConnectionPool:
if not self.ssh_client: if not self.ssh_client:
return SSHConnectionPool( return SSHConnectionPool(
ssh_client=self.ssh_client, ssh_client=self.ssh_client,

View File

@ -33,7 +33,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
assert_hostname: bool | None = None, assert_hostname: bool | None = None,
**kwargs, **kwargs: t.Any,
) -> None: ) -> None:
self.assert_hostname = assert_hostname self.assert_hostname = assert_hostname
super().__init__(**kwargs) super().__init__(**kwargs)
@ -51,7 +51,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
self.poolmanager = PoolManager(**kwargs) self.poolmanager = PoolManager(**kwargs)
def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool: def get_connection(self, *args: t.Any, **kwargs: t.Any) -> urllib3.ConnectionPool:
""" """
Ensure assert_hostname is set correctly on our pool Ensure assert_hostname is set correctly on our pool

View File

@ -19,12 +19,20 @@ from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
from .basehttpadapter import BaseHTTPAdapter from .basehttpadapter import BaseHTTPAdapter
if t.TYPE_CHECKING:
from collections.abc import Mapping
from requests import PreparedRequest
from ..._socket_helper import SocketLike
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class UnixHTTPConnection(urllib3_connection.HTTPConnection): class UnixHTTPConnection(urllib3_connection.HTTPConnection):
def __init__( def __init__(
self, base_url: str | bytes, unix_socket, timeout: int | float = 60 self, base_url: str | bytes, unix_socket: str, timeout: int | float = 60
) -> None: ) -> None:
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.base_url = base_url self.base_url = base_url
@ -43,7 +51,7 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
if header == "Connection" and "Upgrade" in values: if header == "Connection" and "Upgrade" in values:
self.disable_buffering = True self.disable_buffering = True
def response_class(self, sock, *args, **kwargs) -> t.Any: def response_class(self, sock: SocketLike, *args: t.Any, **kwargs: t.Any) -> t.Any:
# FIXME: We may need to disable buffering on Py3, # FIXME: We may need to disable buffering on Py3,
# but there's no clear way to do it at the moment. See: # but there's no clear way to do it at the moment. See:
# https://github.com/docker/docker-py/issues/1799 # https://github.com/docker/docker-py/issues/1799
@ -88,12 +96,16 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
self.socket_path = socket_path self.socket_path = socket_path
self.timeout = timeout self.timeout = timeout
self.max_pool_size = max_pool_size self.max_pool_size = max_pool_size
self.pools = RecentlyUsedContainer(
pool_connections, dispose_func=lambda p: p.close() def f(p: t.Any) -> None:
) p.close()
self.pools = RecentlyUsedContainer(pool_connections, dispose_func=f)
super().__init__() super().__init__()
def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool: def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> UnixHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -106,7 +118,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies) -> str: def request_url(self, request: PreparedRequest, proxies: Mapping[str, str]) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -18,7 +18,13 @@ from .._import_helper import urllib3
from ..errors import DockerException from ..errors import DockerException
class CancellableStream: if t.TYPE_CHECKING:
from requests import Response
_T = t.TypeVar("_T")
class CancellableStream(t.Generic[_T]):
""" """
Stream wrapper for real-time events, logs, etc. from the server. Stream wrapper for real-time events, logs, etc. from the server.
@ -30,14 +36,14 @@ class CancellableStream:
>>> events.close() >>> events.close()
""" """
def __init__(self, stream, response) -> None: def __init__(self, stream: t.Generator[_T], response: Response) -> None:
self._stream = stream self._stream = stream
self._response = response self._response = response
def __iter__(self) -> t.Self: def __iter__(self) -> t.Self:
return self return self
def __next__(self): def __next__(self) -> _T:
try: try:
return next(self._stream) return next(self._stream)
except urllib3.exceptions.ProtocolError as exc: except urllib3.exceptions.ProtocolError as exc:
@ -56,7 +62,7 @@ class CancellableStream:
# find the underlying socket object # find the underlying socket object
# based on api.client._get_raw_response_socket # based on api.client._get_raw_response_socket
sock_fp = self._response.raw._fp.fp sock_fp = self._response.raw._fp.fp # type: ignore
if hasattr(sock_fp, "raw"): if hasattr(sock_fp, "raw"):
sock_raw = sock_fp.raw sock_raw = sock_fp.raw
@ -74,7 +80,7 @@ class CancellableStream:
"Cancellable streams not supported for the SSH protocol" "Cancellable streams not supported for the SSH protocol"
) )
else: else:
sock = sock_fp._sock sock = sock_fp._sock # type: ignore
if hasattr(urllib3.contrib, "pyopenssl") and isinstance( if hasattr(urllib3.contrib, "pyopenssl") and isinstance(
sock, urllib3.contrib.pyopenssl.WrappedSocket sock, urllib3.contrib.pyopenssl.WrappedSocket

View File

@ -37,7 +37,7 @@ def _purge() -> None:
_cache.clear() _cache.clear()
def fnmatch(name: str, pat: str): def fnmatch(name: str, pat: str) -> bool:
"""Test whether FILENAME matches PATTERN. """Test whether FILENAME matches PATTERN.
Patterns are Unix shell style: Patterns are Unix shell style:

View File

@ -22,7 +22,9 @@ from ..transport.npipesocket import NpipeSocket
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from collections.abc import Iterable from collections.abc import Iterable, Sequence
from ..._socket_helper import SocketLike
STDOUT = 1 STDOUT = 1
@ -38,14 +40,14 @@ class SocketError(Exception):
NPIPE_ENDED = 109 NPIPE_ENDED = 109
def read(socket, n: int = 4096) -> bytes | None: def read(socket: SocketLike, n: int = 4096) -> bytes | None:
""" """
Reads at most n bytes from socket Reads at most n bytes from socket
""" """
recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK) recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK)
if not isinstance(socket, NpipeSocket): if not isinstance(socket, NpipeSocket): # type: ignore[unreachable]
if not hasattr(select, "poll"): if not hasattr(select, "poll"):
# Limited to 1024 # Limited to 1024
select.select([socket], [], []) select.select([socket], [], [])
@ -66,7 +68,7 @@ def read(socket, n: int = 4096) -> bytes | None:
return None # TODO ??? return None # TODO ???
except Exception as e: except Exception as e:
is_pipe_ended = ( is_pipe_ended = (
isinstance(socket, NpipeSocket) isinstance(socket, NpipeSocket) # type: ignore[unreachable]
and len(e.args) > 0 and len(e.args) > 0
and e.args[0] == NPIPE_ENDED and e.args[0] == NPIPE_ENDED
) )
@ -77,7 +79,7 @@ def read(socket, n: int = 4096) -> bytes | None:
raise raise
def read_exactly(socket, n: int) -> bytes: def read_exactly(socket: SocketLike, n: int) -> bytes:
""" """
Reads exactly n bytes from socket Reads exactly n bytes from socket
Raises SocketError if there is not enough data Raises SocketError if there is not enough data
@ -91,7 +93,7 @@ def read_exactly(socket, n: int) -> bytes:
return data return data
def next_frame_header(socket) -> tuple[int, int]: def next_frame_header(socket: SocketLike) -> tuple[int, int]:
""" """
Returns the stream and size of the next frame of data waiting to be read Returns the stream and size of the next frame of data waiting to be read
from socket, according to the protocol defined here: from socket, according to the protocol defined here:
@ -107,7 +109,7 @@ def next_frame_header(socket) -> tuple[int, int]:
return (stream, actual) return (stream, actual)
def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]: def frames_iter(socket: SocketLike, tty: bool) -> t.Generator[tuple[int, bytes]]:
""" """
Return a generator of frames read from socket. A frame is a tuple where Return a generator of frames read from socket. A frame is a tuple where
the first item is the stream number and the second item is a chunk of data. the first item is the stream number and the second item is a chunk of data.
@ -120,7 +122,7 @@ def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
return frames_iter_no_tty(socket) return frames_iter_no_tty(socket)
def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]: def frames_iter_no_tty(socket: SocketLike) -> t.Generator[tuple[int, bytes]]:
""" """
Returns a generator of data read from the socket when the tty setting is Returns a generator of data read from the socket when the tty setting is
not enabled. not enabled.
@ -141,7 +143,7 @@ def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
yield (stream, result) yield (stream, result)
def frames_iter_tty(socket) -> t.Generator[bytes]: def frames_iter_tty(socket: SocketLike) -> t.Generator[bytes]:
""" """
Return a generator of data read from the socket when the tty setting is Return a generator of data read from the socket when the tty setting is
enabled. enabled.
@ -155,20 +157,42 @@ def frames_iter_tty(socket) -> t.Generator[bytes]:
@t.overload @t.overload
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ... def consume_socket_output(
frames: Sequence[bytes] | t.Generator[bytes], demux: t.Literal[False] = False
) -> bytes: ...
@t.overload
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
@t.overload @t.overload
def consume_socket_output( def consume_socket_output(
frames, demux: bool = False frames: (
Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: t.Literal[True],
) -> tuple[bytes, bytes]: ...
@t.overload
def consume_socket_output(
frames: (
Sequence[bytes]
| Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[bytes]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: bool = False,
) -> bytes | tuple[bytes, bytes]: ... ) -> bytes | tuple[bytes, bytes]: ...
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]: def consume_socket_output(
frames: (
Sequence[bytes]
| Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[bytes]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: bool = False,
) -> bytes | tuple[bytes, bytes]:
""" """
Iterate through frames read from the socket and return the result. Iterate through frames read from the socket and return the result.
@ -183,12 +207,13 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b
if demux is False: if demux is False:
# If the streams are multiplexed, the generator returns strings, that # If the streams are multiplexed, the generator returns strings, that
# we just need to concatenate. # we just need to concatenate.
return b"".join(frames) return b"".join(frames) # type: ignore
# If the streams are demultiplexed, the generator yields tuples # If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr) # (stdout, stderr)
out: list[bytes | None] = [None, None] out: list[bytes | None] = [None, None]
for frame in frames: frame: tuple[bytes | None, bytes | None]
for frame in frames: # type: ignore
# It is guaranteed that for each frame, one and only one stream # It is guaranteed that for each frame, one and only one stream
# is not None. # is not None.
if frame == (None, None): if frame == (None, None):
@ -202,7 +227,7 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b
if out[1] is None: if out[1] is None:
out[1] = frame[1] out[1] = frame[1]
else: else:
out[1] += frame[1] out[1] += frame[1] # type: ignore[operator]
return tuple(out) # type: ignore return tuple(out) # type: ignore

View File

@ -502,8 +502,8 @@ def split_command(command: str) -> list[str]:
return shlex.split(command) return shlex.split(command)
def format_environment(environment: Mapping[str, str | bytes]) -> list[str]: def format_environment(environment: Mapping[str, str | bytes | None]) -> list[str]:
def format_env(key, value): def format_env(key: str, value: str | bytes | None) -> str:
if value is None: if value is None:
return key return key
if isinstance(value, bytes): if isinstance(value, bytes):

View File

@ -91,7 +91,7 @@ if not HAS_DOCKER_PY:
# No Docker SDK for Python. Create a place holder client to allow # No Docker SDK for Python. Create a place holder client to allow
# instantiation of AnsibleModule and proper error handing # instantiation of AnsibleModule and proper error handing
class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined
def __init__(self, **kwargs): def __init__(self, **kwargs: t.Any) -> None:
pass pass
class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined
@ -226,7 +226,7 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg: t.Any, pretty_print: bool = False): def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
@ -609,7 +609,7 @@ class AnsibleDockerClientBase(Client):
return new_tag, old_tag == new_tag return new_tag, old_tag == new_tag
def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]: def inspect_distribution(self, image: str, **kwargs: t.Any) -> dict[str, t.Any]:
""" """
Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
since prior versions did not support accessing private repositories. since prior versions did not support accessing private repositories.
@ -629,7 +629,6 @@ class AnsibleDockerClientBase(Client):
class AnsibleDockerClient(AnsibleDockerClientBase): class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec: dict[str, t.Any] | None = None, argument_spec: dict[str, t.Any] | None = None,
@ -651,7 +650,6 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions_ignore_params: Sequence[str] | None = None, option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None, fail_results: dict[str, t.Any] | None = None,
): ):
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
# in case client.fail() is called. # in case client.fail() is called.
self.fail_results = fail_results or {} self.fail_results = fail_results or {}

View File

@ -146,7 +146,7 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg: t.Any, pretty_print: bool = False): def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
@ -295,7 +295,7 @@ class AnsibleDockerClientBase(Client):
), ),
} }
def depr(*args, **kwargs): def depr(*args: t.Any, **kwargs: t.Any) -> None:
self.deprecate(*args, **kwargs) self.deprecate(*args, **kwargs)
update_tls_hostname( update_tls_hostname(

View File

@ -82,7 +82,7 @@ class AnsibleDockerClientBase:
def __init__( def __init__(
self, self,
common_args, common_args: dict[str, t.Any],
min_docker_api_version: str | None = None, min_docker_api_version: str | None = None,
needs_api_version: bool = True, needs_api_version: bool = True,
) -> None: ) -> None:
@ -91,15 +91,15 @@ class AnsibleDockerClientBase:
self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"] self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"]
if common_args["api_version"] and common_args["api_version"] != "auto": if common_args["api_version"] and common_args["api_version"] != "auto":
self._environment["DOCKER_API_VERSION"] = common_args["api_version"] self._environment["DOCKER_API_VERSION"] = common_args["api_version"]
self._cli = common_args.get("docker_cli") cli = common_args.get("docker_cli")
if self._cli is None: if cli is None:
try: try:
self._cli = get_bin_path("docker") cli = get_bin_path("docker")
except ValueError: except ValueError:
self.fail( self.fail(
"Cannot find docker CLI in path. Please provide it explicitly with the docker_cli parameter" "Cannot find docker CLI in path. Please provide it explicitly with the docker_cli parameter"
) )
self._cli = cli
self._cli_base = [self._cli] self._cli_base = [self._cli]
docker_host = common_args["docker_host"] docker_host = common_args["docker_host"]
if not docker_host and not common_args["cli_context"]: if not docker_host and not common_args["cli_context"]:
@ -149,7 +149,7 @@ class AnsibleDockerClientBase:
"Internal error: cannot have needs_api_version=False with min_docker_api_version not None" "Internal error: cannot have needs_api_version=False with min_docker_api_version not None"
) )
def log(self, msg: str, pretty_print: bool = False): def log(self, msg: str, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
@ -227,7 +227,7 @@ class AnsibleDockerClientBase:
return rc, result, stderr return rc, result, stderr
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg: str, **kwargs) -> t.NoReturn: def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass pass
@abc.abstractmethod @abc.abstractmethod
@ -395,7 +395,6 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
fail_results: dict[str, t.Any] | None = None, fail_results: dict[str, t.Any] | None = None,
needs_api_version: bool = True, needs_api_version: bool = True,
) -> None: ) -> None:
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
# in case client.fail() is called. # in case client.fail() is called.
self.fail_results = fail_results or {} self.fail_results = fail_results or {}
@ -463,7 +462,7 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
) )
return rc, stdout, stderr return rc, stdout, stderr
def fail(self, msg: str, **kwargs) -> t.NoReturn: def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))

View File

@ -971,7 +971,7 @@ class BaseComposeManager(DockerBaseClass):
stderr: str | bytes, stderr: str | bytes,
ignore_service_pull_events: bool = False, ignore_service_pull_events: bool = False,
ignore_build_events: bool = False, ignore_build_events: bool = False,
): ) -> None:
result["changed"] = result.get("changed", False) or has_changes( result["changed"] = result.get("changed", False) or has_changes(
events, events,
ignore_service_pull_events=ignore_service_pull_events, ignore_service_pull_events=ignore_service_pull_events,
@ -989,7 +989,7 @@ class BaseComposeManager(DockerBaseClass):
stdout: str | bytes, stdout: str | bytes,
stderr: bytes, stderr: bytes,
rc: int, rc: int,
): ) -> bool:
return update_failed( return update_failed(
result, result,
events, events,

View File

@ -330,6 +330,8 @@ def stat_file(
client._raise_for_status(response) client._raise_for_status(response)
header = response.headers.get("x-docker-container-path-stat") header = response.headers.get("x-docker-container-path-stat")
try: try:
if header is None:
raise ValueError("x-docker-container-path-stat header not present")
stat_data = json.loads(base64.b64decode(header)) stat_data = json.loads(base64.b64decode(header))
except Exception as exc: except Exception as exc:
raise DockerUnexpectedError( raise DockerUnexpectedError(
@ -482,14 +484,14 @@ def fetch_file(
shutil.copyfileobj(in_f, out_f) shutil.copyfileobj(in_f, out_f)
return in_path return in_path
def process_symlink(in_path, member) -> str: def process_symlink(in_path: str, member: tarfile.TarInfo) -> str:
if os.path.exists(b_out_path): if os.path.exists(b_out_path):
os.unlink(b_out_path) os.unlink(b_out_path)
os.symlink(member.linkname, b_out_path) os.symlink(member.linkname, b_out_path)
return in_path return in_path
def process_other(in_path, member) -> str: def process_other(in_path: str, member: tarfile.TarInfo) -> str:
raise DockerFileCopyError( raise DockerFileCopyError(
f'Remote file "{in_path}" is not a regular file or a symbolic link' f'Remote file "{in_path}" is not a regular file or a symbolic link'
) )

View File

@ -193,7 +193,9 @@ class OptionGroup:
) -> None: ) -> None:
if preprocess is None: if preprocess is None:
def preprocess(module, values): def preprocess(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
return values return values
self.preprocess = preprocess self.preprocess = preprocess
@ -207,8 +209,8 @@ class OptionGroup:
self.ansible_required_by = ansible_required_by or {} self.ansible_required_by = ansible_required_by or {}
self.argument_spec: dict[str, t.Any] = {} self.argument_spec: dict[str, t.Any] = {}
def add_option(self, *args, **kwargs) -> OptionGroup: def add_option(self, name: str, **kwargs: t.Any) -> OptionGroup:
option = Option(*args, owner=self, **kwargs) option = Option(name, owner=self, **kwargs)
if not option.not_a_container_option: if not option.not_a_container_option:
self.options.append(option) self.options.append(option)
self.all_options.append(option) self.all_options.append(option)
@ -788,7 +790,7 @@ def _preprocess_mounts(
) -> dict[str, t.Any]: ) -> dict[str, t.Any]:
last: dict[str, str] = {} last: dict[str, str] = {}
def check_collision(t, name): def check_collision(t: str, name: str) -> None:
if t in last: if t in last:
if name == last[t]: if name == last[t]:
module.fail_json( module.fail_json(
@ -1069,7 +1071,9 @@ def _preprocess_ports(
return values return values
def _compare_platform(option: Option, param_value: t.Any, container_value: t.Any): def _compare_platform(
option: Option, param_value: t.Any, container_value: t.Any
) -> bool:
if option.comparison == "ignore": if option.comparison == "ignore":
return True return True
try: try:

View File

@ -872,7 +872,7 @@ class DockerAPIEngine(Engine[AnsibleDockerClient]):
image: dict[str, t.Any] | None, image: dict[str, t.Any] | None,
values: dict[str, t.Any], values: dict[str, t.Any],
host_info: dict[str, t.Any] | None, host_info: dict[str, t.Any] | None,
): ) -> dict[str, t.Any]:
if len(options) != 1: if len(options) != 1:
raise AssertionError( raise AssertionError(
"host_config_value can only be used for a single option" "host_config_value can only be used for a single option"
@ -1961,7 +1961,14 @@ def _update_value_restart(
} }
def _get_values_ports(module, container, api_version, options, image, host_info): def _get_values_ports(
module: AnsibleModule,
container: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> dict[str, t.Any]:
host_config = container["HostConfig"] host_config = container["HostConfig"]
config = container["Config"] config = container["Config"]

View File

@ -292,7 +292,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
if self.module.params[param] is None: if self.module.params[param] is None:
self.module.params[param] = value self.module.params[param] = value
def fail(self, *args, **kwargs) -> t.NoReturn: def fail(self, *args: str, **kwargs: t.Any) -> t.NoReturn:
# mypy doesn't know that Client has fail() method # mypy doesn't know that Client has fail() method
raise self.client.fail(*args, **kwargs) # type: ignore raise self.client.fail(*args, **kwargs) # type: ignore
@ -714,7 +714,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
container_image: dict[str, t.Any] | None, container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None, image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None, host_info: dict[str, t.Any] | None,
): ) -> None:
assert container.raw is not None assert container.raw is not None
container_values = engine.get_value( container_values = engine.get_value(
self.module, self.module,
@ -767,12 +767,12 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
# Since the order does not matter, sort so that the diff output is better. # Since the order does not matter, sort so that the diff output is better.
if option.name == "expected_mounts": if option.name == "expected_mounts":
# For selected values, use one entry as key # For selected values, use one entry as key
def sort_key_fn(x): def sort_key_fn(x: dict[str, t.Any]) -> t.Any:
return x["target"] return x["target"]
else: else:
# We sort the list of dictionaries by using the sorted items of a dict as its key. # We sort the list of dictionaries by using the sorted items of a dict as its key.
def sort_key_fn(x): def sort_key_fn(x: dict[str, t.Any]) -> t.Any:
return sorted( return sorted(
(a, to_text(b, errors="surrogate_or_strict")) (a, to_text(b, errors="surrogate_or_strict"))
for a, b in x.items() for a, b in x.items()

View File

@ -26,6 +26,7 @@ from ansible_collections.community.docker.plugins.module_utils._socket_helper im
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
from types import TracebackType
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
@ -70,7 +71,12 @@ class DockerSocketHandlerBase:
def __enter__(self) -> t.Self: def __enter__(self) -> t.Self:
return self return self
def __exit__(self, type_, value, tb) -> None: def __exit__(
self,
type_: t.Type[BaseException] | None,
value: BaseException | None,
tb: TracebackType | None,
) -> None:
self._selector.close() self._selector.close()
def set_block_done_callback( def set_block_done_callback(
@ -210,7 +216,7 @@ class DockerSocketHandlerBase:
stdout = [] stdout = []
stderr = [] stderr = []
def append_block(stream_id, data): def append_block(stream_id: int, data: bytes) -> None:
if stream_id == docker_socket.STDOUT: if stream_id == docker_socket.STDOUT:
stdout.append(data) stdout.append(data)
elif stream_id == docker_socket.STDERR: elif stream_id == docker_socket.STDERR:

View File

@ -23,6 +23,12 @@ if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ._common import AnsibleDockerClientBase as CADCB
from ._common_api import AnsibleDockerClientBase as CAPIADCB
from ._common_cli import AnsibleDockerClientBase as CCLIADCB
Client = t.Union[CADCB, CAPIADCB, CCLIADCB]
DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock" DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock"
DEFAULT_TLS = False DEFAULT_TLS = False
@ -119,7 +125,7 @@ def sanitize_result(data: t.Any) -> t.Any:
return data return data
def log_debug(msg: t.Any, pretty_print: bool = False): def log_debug(msg: t.Any, pretty_print: bool = False) -> None:
"""Write a log message to docker.log. """Write a log message to docker.log.
If ``pretty_print=True``, the message will be pretty-printed as JSON. If ``pretty_print=True``, the message will be pretty-printed as JSON.
@ -325,7 +331,7 @@ class DifferenceTracker:
def sanitize_labels( def sanitize_labels(
labels: dict[str, t.Any] | None, labels: dict[str, t.Any] | None,
labels_field: str, labels_field: str,
client=None, client: Client | None = None,
module: AnsibleModule | None = None, module: AnsibleModule | None = None,
) -> None: ) -> None:
def fail(msg: str) -> t.NoReturn: def fail(msg: str) -> t.NoReturn:
@ -371,7 +377,7 @@ def clean_dict_booleans_for_docker_api(
which is the expected format of filters which accept lists such as labels. which is the expected format of filters which accept lists such as labels.
""" """
def sanitize(value): def sanitize(value: t.Any) -> str:
if value is True: if value is True:
return "true" return "true"
if value is False: if value is False:

View File

@ -147,7 +147,7 @@ class PullManager(BaseComposeManager):
f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}" f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}"
) )
def get_pull_cmd(self, dry_run: bool): def get_pull_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["pull"] args = self.get_base_args() + ["pull"]
if self.policy != "always": if self.policy != "always":
args.extend(["--policy", self.policy]) args.extend(["--policy", self.policy])

View File

@ -347,7 +347,7 @@ def retrieve_diff(
max_file_size_for_diff: int, max_file_size_for_diff: int,
regular_stat: dict[str, t.Any] | None = None, regular_stat: dict[str, t.Any] | None = None,
link_target: str | None = None, link_target: str | None = None,
): ) -> None:
if diff is None: if diff is None:
return return
if regular_stat is not None: if regular_stat is not None:
@ -497,9 +497,9 @@ def is_file_idempotent(
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
local_follow_links: bool, local_follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int | None,
force: bool | None = False, force: bool | None = False,
diff: dict[str, t.Any] | None = None, diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -744,9 +744,9 @@ def copy_file_into_container(
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
local_follow_links: bool, local_follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int | None,
force: bool | None = False, force: bool | None = False,
do_diff: bool = False, do_diff: bool = False,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -797,9 +797,9 @@ def is_content_idempotent(
content: bytes, content: bytes,
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int,
force: bool | None = False, force: bool | None = False,
diff: dict[str, t.Any] | None = None, diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -989,9 +989,9 @@ def copy_content_into_container(
content: bytes, content: bytes,
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int,
force: bool | None = False, force: bool | None = False,
do_diff: bool = False, do_diff: bool = False,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -1133,6 +1133,7 @@ def main() -> None:
owner_id, group_id = determine_user_group(client, container) owner_id, group_id = determine_user_group(client, container)
if content is not None: if content is not None:
assert mode is not None # see required_by above
copy_content_into_container( copy_content_into_container(
client, client,
container, container,

View File

@ -667,7 +667,7 @@ class ImageManager(DockerBaseClass):
:rtype: str :rtype: str
""" """
def build_msg(reason): def build_msg(reason: str) -> str:
return f"Archived image {current_image_name} to {archive_path}, {reason}" return f"Archived image {current_image_name} to {archive_path}, {reason}"
try: try:
@ -877,7 +877,7 @@ class ImageManager(DockerBaseClass):
self.push_image(repo, repo_tag) self.push_image(repo, repo_tag)
@staticmethod @staticmethod
def _extract_output_line(line: dict[str, t.Any], output: list[str]): def _extract_output_line(line: dict[str, t.Any], output: list[str]) -> None:
""" """
Extract text line from stream output and, if found, adds it to output. Extract text line from stream output and, if found, adds it to output.
""" """
@ -1165,18 +1165,18 @@ def main() -> None:
("source", "load", ["load_path"]), ("source", "load", ["load_path"]),
] ]
def detect_etc_hosts(client): def detect_etc_hosts(client: AnsibleDockerClient) -> bool:
return client.module.params["build"] and bool( return client.module.params["build"] and bool(
client.module.params["build"].get("etc_hosts") client.module.params["build"].get("etc_hosts")
) )
def detect_build_platform(client): def detect_build_platform(client: AnsibleDockerClient) -> bool:
return ( return (
client.module.params["build"] client.module.params["build"]
and client.module.params["build"].get("platform") is not None and client.module.params["build"].get("platform") is not None
) )
def detect_pull_platform(client): def detect_pull_platform(client: AnsibleDockerClient) -> bool:
return ( return (
client.module.params["pull"] client.module.params["pull"]
and client.module.params["pull"].get("platform") is not None and client.module.params["pull"].get("platform") is not None

View File

@ -379,7 +379,7 @@ def normalize_ipam_config_key(key: str) -> str:
return special_cases.get(key, key.lower()) return special_cases.get(key, key.lower())
def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]): def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]) -> bool:
"""Make sure that a is a subset of b, where None entries of a are ignored.""" """Make sure that a is a subset of b, where None entries of a are ignored."""
for k, v in a.items(): for k, v in a.items():
if v is None: if v is None:

View File

@ -204,12 +204,10 @@ class DockerPluginManager:
elif state == "disable": elif state == "disable":
self.disable() self.disable()
if self.diff or self.check_mode or self.parameters.debug: if self.diff:
if self.diff: self.diff_result["before"], self.diff_result["after"] = (
self.diff_result["before"], self.diff_result["after"] = ( self.diff_tracker.get_before_after()
self.diff_tracker.get_before_after() )
)
self.diff = self.diff_result
def get_existing_plugin(self) -> dict[str, t.Any] | None: def get_existing_plugin(self) -> dict[str, t.Any] | None:
try: try:
@ -409,7 +407,7 @@ class DockerPluginManager:
result: dict[str, t.Any] = { result: dict[str, t.Any] = {
"actions": self.actions, "actions": self.actions,
"changed": self.changed, "changed": self.changed,
"diff": self.diff, "diff": self.diff_result,
"plugin": plugin_data, "plugin": plugin_data,
} }
if ( if (

View File

@ -247,7 +247,9 @@ class DockerSwarmManager(DockerBaseClass):
self.client.fail(f"Error inspecting docker swarm: {exc}") self.client.fail(f"Error inspecting docker swarm: {exc}")
def get_docker_items_list( def get_docker_items_list(
self, docker_object: t.Literal["nodes", "tasks", "services"], filters=None self,
docker_object: t.Literal["nodes", "tasks", "services"],
filters: dict[str, str],
) -> list[dict[str, t.Any]]: ) -> list[dict[str, t.Any]]:
items_list: list[dict[str, t.Any]] = [] items_list: list[dict[str, t.Any]] = []

View File

@ -1463,8 +1463,8 @@ class DockerService(DockerBaseClass):
def from_ansible_params( def from_ansible_params(
cls, cls,
ap: dict[str, t.Any], ap: dict[str, t.Any],
old_service, old_service: DockerService | None,
image_digest, image_digest: str,
secret_ids: dict[str, str], secret_ids: dict[str, str],
config_ids: dict[str, str], config_ids: dict[str, str],
network_ids: dict[str, str], network_ids: dict[str, str],

View File

@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import unittest import unittest
from io import StringIO from io import StringIO
from unittest import mock from unittest import mock
@ -40,7 +41,7 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("docker version", "1.2.3", "", 0), return_value=("docker version", "1.2.3", "", 0),
) )
def test_docker_connection_module_too_old( def test_docker_connection_module_too_old(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
) -> None: ) -> None:
self.dc._version = None self.dc._version = None
self.dc.remote_user = "foo" self.dc.remote_user = "foo"
@ -59,7 +60,7 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("docker version", "1.7.0", "", 0), return_value=("docker version", "1.7.0", "", 0),
) )
def test_docker_connection_module( def test_docker_connection_module(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
) -> None: ) -> None:
self.dc._version = None self.dc._version = None
@ -73,7 +74,7 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("false", "garbage", "", 1), return_value=("false", "garbage", "", 1),
) )
def test_docker_connection_module_wrong_cmd( def test_docker_connection_module_wrong_cmd(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
) -> None: ) -> None:
self.dc._version = None self.dc._version = None
self.dc.remote_user = "foo" self.dc.remote_user = "foo"

View File

@ -31,7 +31,7 @@ def templar() -> Templar:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def inventory(templar) -> InventoryModule: def inventory(templar: Templar) -> InventoryModule:
r = InventoryModule() r = InventoryModule()
r.inventory = InventoryData() r.inventory = InventoryData()
r.templar = templar r.templar = templar
@ -91,7 +91,7 @@ LOVING_THARP_SERVICE = {
def create_get_option( def create_get_option(
options: dict[str, t.Any], default: t.Any = False options: dict[str, t.Any], default: t.Any = False
) -> Callable[[str], t.Any]: ) -> Callable[[str], t.Any]:
def get_option(option): def get_option(option: str) -> t.Any:
if option in options: if option in options:
return options[option] return options[option]
return default return default
@ -116,12 +116,12 @@ class FakeClient:
self.get_results[f"/containers/{host['Id']}/json"] = host self.get_results[f"/containers/{host['Id']}/json"] = host
self.get_results["/containers/json"] = list_reply self.get_results["/containers/json"] = list_reply
def get_json(self, url: str, *param: str, **kwargs) -> t.Any: def get_json(self, url: str, *param: str, **kwargs: t.Any) -> t.Any:
url = url.format(*param) url = url.format(*param)
return self.get_results[url] return self.get_results[url]
def test_populate(inventory: InventoryModule, mocker) -> None: def test_populate(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
@ -158,7 +158,7 @@ def test_populate(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_service(inventory: InventoryModule, mocker) -> None: def test_populate_service(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_SERVICE) client = FakeClient(LOVING_THARP_SERVICE)
@ -218,7 +218,7 @@ def test_populate_service(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_stack(inventory: InventoryModule, mocker) -> None: def test_populate_stack(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_STACK) client = FakeClient(LOVING_THARP_STACK)
@ -280,7 +280,7 @@ def test_populate_stack(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_filter_none(inventory: InventoryModule, mocker) -> None: def test_populate_filter_none(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
@ -304,7 +304,7 @@ def test_populate_filter_none(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 0 assert len(inventory.inventory.hosts) == 0
def test_populate_filter(inventory: InventoryModule, mocker) -> None: def test_populate_filter(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)

View File

@ -43,6 +43,12 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c
from .. import fake_api from .. import fake_api
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.auth import (
AuthConfig,
)
DEFAULT_TIMEOUT_SECONDS = constants.DEFAULT_TIMEOUT_SECONDS DEFAULT_TIMEOUT_SECONDS = constants.DEFAULT_TIMEOUT_SECONDS
@ -52,8 +58,8 @@ def response(
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
reason: str = "", reason: str = "",
elapsed: int = 0, elapsed: int = 0,
request=None, request: requests.PreparedRequest | None = None,
raw=None, raw: urllib3.HTTPResponse | None = None,
) -> requests.Response: ) -> requests.Response:
res = requests.Response() res = requests.Response()
res.status_code = status_code res.status_code = status_code
@ -63,18 +69,18 @@ def response(
res.headers = requests.structures.CaseInsensitiveDict(headers or {}) res.headers = requests.structures.CaseInsensitiveDict(headers or {})
res.reason = reason res.reason = reason
res.elapsed = datetime.timedelta(elapsed) res.elapsed = datetime.timedelta(elapsed)
res.request = request res.request = request # type: ignore
res.raw = raw res.raw = raw
return res return res
def fake_resolve_authconfig( # pylint: disable=keyword-arg-before-vararg def fake_resolve_authconfig( # pylint: disable=keyword-arg-before-vararg
authconfig, *args, registry=None, **kwargs authconfig: AuthConfig, *args: t.Any, registry: str | None = None, **kwargs: t.Any
) -> None: ) -> None:
return None return None
def fake_inspect_container(self, container: str, tty: bool = False): def fake_inspect_container(self: object, container: str, tty: bool = False) -> t.Any:
return fake_api.get_fake_inspect_container(tty=tty)[1] return fake_api.get_fake_inspect_container(tty=tty)[1]
@ -95,24 +101,32 @@ def fake_resp(
fake_request = mock.Mock(side_effect=fake_resp) fake_request = mock.Mock(side_effect=fake_resp)
def fake_get(self, url: str, *args, **kwargs) -> requests.Response: def fake_get(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("GET", url, *args, **kwargs) return fake_request("GET", url, *args, **kwargs)
def fake_post(self, url: str, *args, **kwargs) -> requests.Response: def fake_post(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("POST", url, *args, **kwargs) return fake_request("POST", url, *args, **kwargs)
def fake_put(self, url: str, *args, **kwargs) -> requests.Response: def fake_put(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("PUT", url, *args, **kwargs) return fake_request("PUT", url, *args, **kwargs)
def fake_delete(self, url: str, *args, **kwargs) -> requests.Response: def fake_delete(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("DELETE", url, *args, **kwargs) return fake_request("DELETE", url, *args, **kwargs)
def fake_read_from_socket( def fake_read_from_socket(
self, self: APIClient,
response: requests.Response, response: requests.Response,
stream: bool, stream: bool,
tty: bool = False, tty: bool = False,
@ -253,9 +267,9 @@ class DockerApiTest(BaseAPIClientTest):
"serveraddress": None, "serveraddress": None,
} }
def _socket_path_for_client_session(self, client) -> str: def _socket_path_for_client_session(self, client: APIClient) -> str:
socket_adapter = client.get_adapter("http+docker://") socket_adapter = client.get_adapter("http+docker://")
return socket_adapter.socket_path return socket_adapter.socket_path # type: ignore[attr-defined]
def test_url_compatibility_unix(self) -> None: def test_url_compatibility_unix(self) -> None:
c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION) c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION)
@ -384,7 +398,7 @@ class UnixSocketStreamTest(unittest.TestCase):
finally: finally:
self.server_socket.close() self.server_socket.close()
def early_response_sending_handler(self, connection) -> None: def early_response_sending_handler(self, connection: socket.socket) -> None:
data = b"" data = b""
headers = None headers = None
@ -494,7 +508,7 @@ class TCPSocketStreamTest(unittest.TestCase):
stderr_data = cls.stderr_data stderr_data = cls.stderr_data
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_POST(self): # pylint: disable=invalid-name def do_POST(self) -> None: # pylint: disable=invalid-name
resp_data = self.get_resp_data() resp_data = self.get_resp_data()
self.send_response(101) self.send_response(101)
self.send_header("Content-Type", "application/vnd.docker.raw-stream") self.send_header("Content-Type", "application/vnd.docker.raw-stream")
@ -506,7 +520,7 @@ class TCPSocketStreamTest(unittest.TestCase):
self.wfile.write(resp_data) self.wfile.write(resp_data)
self.wfile.flush() self.wfile.flush()
def get_resp_data(self): def get_resp_data(self) -> bytes:
path = self.path.split("/")[-1] path = self.path.split("/")[-1]
if path == "tty": if path == "tty":
return stdout_data + stderr_data return stdout_data + stderr_data
@ -520,7 +534,7 @@ class TCPSocketStreamTest(unittest.TestCase):
raise NotImplementedError(f"Unknown path {path}") raise NotImplementedError(f"Unknown path {path}")
@staticmethod @staticmethod
def frame_header(stream, data): def frame_header(stream: int, data: bytes) -> bytes:
return struct.pack(">BxxxL", stream, len(data)) return struct.pack(">BxxxL", stream, len(data))
return Handler return Handler

View File

@ -133,126 +133,102 @@ class ResolveAuthTest(unittest.TestCase):
) )
def test_resolve_authconfig_hostname_only(self) -> None: def test_resolve_authconfig_hostname_only(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "my.registry.net")
auth.resolve_authconfig(self.auth_config, "my.registry.net")["username"] assert ac is not None
== "privateuser" assert ac["username"] == "privateuser"
)
def test_resolve_authconfig_no_protocol(self) -> None: def test_resolve_authconfig_no_protocol(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")
auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")["username"] assert ac is not None
== "privateuser" assert ac["username"] == "privateuser"
)
def test_resolve_authconfig_no_path(self) -> None: def test_resolve_authconfig_no_path(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net")
auth.resolve_authconfig(self.auth_config, "http://my.registry.net")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_no_path_trailing_slash(self) -> None: def test_resolve_authconfig_no_path_trailing_slash(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_no_path_wrong_secure_proto(self) -> None: def test_resolve_authconfig_no_path_wrong_secure_proto(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net")
auth.resolve_authconfig(self.auth_config, "https://my.registry.net")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_no_path_wrong_insecure_proto(self) -> None: def test_resolve_authconfig_no_path_wrong_insecure_proto(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://index.docker.io")
auth.resolve_authconfig(self.auth_config, "http://index.docker.io")[ assert ac is not None
"username" assert ac["username"] == "indexuser"
]
== "indexuser"
)
def test_resolve_authconfig_path_wrong_proto(self) -> None: def test_resolve_authconfig_path_wrong_proto(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")
auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_default_registry(self) -> None: def test_resolve_authconfig_default_registry(self) -> None:
assert auth.resolve_authconfig(self.auth_config)["username"] == "indexuser" ac = auth.resolve_authconfig(self.auth_config)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_authconfig_default_explicit_none(self) -> None: def test_resolve_authconfig_default_explicit_none(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, None)
auth.resolve_authconfig(self.auth_config, None)["username"] == "indexuser" assert ac is not None
) assert ac["username"] == "indexuser"
def test_resolve_authconfig_fully_explicit(self) -> None: def test_resolve_authconfig_fully_explicit(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_legacy_config(self) -> None: def test_resolve_authconfig_legacy_config(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "legacy.registry.url")
auth.resolve_authconfig(self.auth_config, "legacy.registry.url")["username"] assert ac is not None
== "legacyauth" assert ac["username"] == "legacyauth"
)
def test_resolve_authconfig_no_match(self) -> None: def test_resolve_authconfig_no_match(self) -> None:
assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None
def test_resolve_registry_and_auth_library_image(self) -> None: def test_resolve_registry_and_auth_library_image(self) -> None:
image = "image" image = "image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_hub_image(self) -> None: def test_resolve_registry_and_auth_hub_image(self) -> None:
image = "username/image" image = "username/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_explicit_hub(self) -> None: def test_resolve_registry_and_auth_explicit_hub(self) -> None:
image = "docker.io/username/image" image = "docker.io/username/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_explicit_legacy_hub(self) -> None: def test_resolve_registry_and_auth_explicit_legacy_hub(self) -> None:
image = "index.docker.io/username/image" image = "index.docker.io/username/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_private_registry(self) -> None: def test_resolve_registry_and_auth_private_registry(self) -> None:
image = "my.registry.net/image" image = "my.registry.net/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "privateuser"
) )
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_registry_and_auth_unauthenticated_registry(self) -> None: def test_resolve_registry_and_auth_unauthenticated_registry(self) -> None:
image = "other.registry.net/image" image = "other.registry.net/image"
@ -278,7 +254,9 @@ class ResolveAuthTest(unittest.TestCase):
"ansible_collections.community.docker.plugins.module_utils._api.auth.AuthConfig._resolve_authconfig_credstore" "ansible_collections.community.docker.plugins.module_utils._api.auth.AuthConfig._resolve_authconfig_credstore"
) as m: ) as m:
m.return_value = None m.return_value = None
assert "indexuser" == auth.resolve_authconfig(auth_config, None)["username"] ac = auth.resolve_authconfig(auth_config, None)
assert ac is not None
assert "indexuser" == ac["username"]
class LoadConfigTest(unittest.TestCase): class LoadConfigTest(unittest.TestCase):
@ -797,7 +775,7 @@ class CredstoreTest(unittest.TestCase):
class InMemoryStore(Store): class InMemoryStore(Store):
def __init__( # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self, *args, **kwargs self, *args: t.Any, **kwargs: t.Any
) -> None: ) -> None:
self.__store: dict[str | bytes, dict[str, t.Any]] = {} self.__store: dict[str | bytes, dict[str, t.Any]] = {}

View File

@ -156,7 +156,7 @@ class ExcludePathsTest(unittest.TestCase):
def test_single_filename_trailing_slash(self) -> None: def test_single_filename_trailing_slash(self) -> None:
assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"])) assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"]))
def test_wildcard_filename_start(self): def test_wildcard_filename_start(self) -> None:
assert self.exclude(["*.py"]) == convert_paths( assert self.exclude(["*.py"]) == convert_paths(
self.all_paths - set(["a.py", "b.py", "cde.py"]) self.all_paths - set(["a.py", "b.py", "cde.py"])
) )

View File

@ -12,6 +12,7 @@ import json
import os import os
import shutil import shutil
import tempfile import tempfile
import typing as t
import unittest import unittest
from collections.abc import Callable from collections.abc import Callable
from unittest import mock from unittest import mock
@ -25,7 +26,7 @@ class FindConfigFileTest(unittest.TestCase):
mkdir: Callable[[str], os.PathLike[str]] mkdir: Callable[[str], os.PathLike[str]]
@fixture(autouse=True) @fixture(autouse=True)
def tmpdir(self, tmpdir) -> None: def tmpdir(self, tmpdir: t.Any) -> None:
self.mkdir = tmpdir.mkdir self.mkdir = tmpdir.mkdir
def test_find_config_fallback(self) -> None: def test_find_config_fallback(self) -> None:

View File

@ -8,6 +8,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import unittest import unittest
from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
@ -27,7 +28,7 @@ class DecoratorsTest(unittest.TestCase):
"X-Docker-Locale": "en-US", "X-Docker-Locale": "en-US",
} }
def f(self, headers=None): def f(self: t.Any, headers: t.Any = None) -> t.Any:
return headers return headers
client = APIClient(version=DEFAULT_DOCKER_API_VERSION) client = APIClient(version=DEFAULT_DOCKER_API_VERSION)

View File

@ -469,7 +469,7 @@ class FormatEnvironmentTest(unittest.TestCase):
env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"} env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"}
assert format_environment(env_dict) == ["ARTIST_NAME=송지은"] assert format_environment(env_dict) == ["ARTIST_NAME=송지은"]
def test_format_env_no_value(self): def test_format_env_no_value(self) -> None:
env_dict = { env_dict = {
"FOO": None, "FOO": None,
"BAR": "", "BAR": "",

View File

@ -369,7 +369,7 @@ def test_parse_events(
) -> None: ) -> None:
collected_warnings = [] collected_warnings = []
def collect_warning(msg): def collect_warning(msg: str) -> None:
collected_warnings.append(msg) collected_warnings.append(msg)
collected_events = parse_events( collected_events = parse_events(

View File

@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import tarfile import tarfile
import typing as t
import pytest import pytest
@ -22,7 +23,7 @@ from ..test_support.docker_image_archive_stubbing import (
@pytest.fixture @pytest.fixture
def tar_file_name(tmpdir) -> str: def tar_file_name(tmpdir: t.Any) -> str:
""" """
Return the name of a non-existing tar file in an existing temporary directory. Return the name of a non-existing tar file in an existing temporary directory.
""" """
@ -38,7 +39,7 @@ def test_api_image_id_from_archive_id(expected: str, value: str) -> None:
assert api_image_id(value) == expected assert api_image_id(value) == expected
def test_archived_image_manifest_extracts(tar_file_name) -> None: def test_archived_image_manifest_extracts(tar_file_name: str) -> None:
expected_id = "abcde12345" expected_id = "abcde12345"
expected_tags = ["foo:latest", "bar:v1"] expected_tags = ["foo:latest", "bar:v1"]
@ -52,7 +53,7 @@ def test_archived_image_manifest_extracts(tar_file_name) -> None:
def test_archived_image_manifest_extracts_nothing_when_file_not_present( def test_archived_image_manifest_extracts_nothing_when_file_not_present(
tar_file_name, tar_file_name: str,
) -> None: ) -> None:
image_id = archived_image_manifest(tar_file_name) image_id = archived_image_manifest(tar_file_name)
@ -69,7 +70,7 @@ def test_archived_image_manifest_raises_when_file_not_a_tar() -> None:
def test_archived_image_manifest_raises_when_tar_missing_manifest( def test_archived_image_manifest_raises_when_tar_missing_manifest(
tar_file_name, tar_file_name: str,
) -> None: ) -> None:
write_irrelevant_tar(tar_file_name) write_irrelevant_tar(tar_file_name)
@ -81,7 +82,9 @@ def test_archived_image_manifest_raises_when_tar_missing_manifest(
assert "manifest.json" in str(e.__cause__) assert "manifest.json" in str(e.__cause__)
def test_archived_image_manifest_raises_when_manifest_missing_id(tar_file_name) -> None: def test_archived_image_manifest_raises_when_manifest_missing_id(
tar_file_name: str,
) -> None:
manifest = [{"foo": "bar"}] manifest = [{"foo": "bar"}]
write_imitation_archive_with_manifest(tar_file_name, manifest) write_imitation_archive_with_manifest(tar_file_name, manifest)

View File

@ -38,7 +38,7 @@ def capture_logging(messages: list[str]) -> Callable[[str], None]:
@pytest.fixture @pytest.fixture
def tar_file_name(tmpdir): def tar_file_name(tmpdir: t.Any) -> str:
""" """
Return the name of a non-existing tar file in an existing temporary directory. Return the name of a non-existing tar file in an existing temporary directory.
""" """
@ -46,7 +46,7 @@ def tar_file_name(tmpdir):
return tmpdir.join("foo.tar") return tmpdir.join("foo.tar")
def test_archived_image_action_when_missing(tar_file_name) -> None: def test_archived_image_action_when_missing(tar_file_name: str) -> None:
fake_name = "a:latest" fake_name = "a:latest"
fake_id = "a1" fake_id = "a1"
@ -59,7 +59,7 @@ def test_archived_image_action_when_missing(tar_file_name) -> None:
assert actual == expected assert actual == expected
def test_archived_image_action_when_current(tar_file_name) -> None: def test_archived_image_action_when_current(tar_file_name: str) -> None:
fake_name = "b:latest" fake_name = "b:latest"
fake_id = "b2" fake_id = "b2"
@ -72,7 +72,7 @@ def test_archived_image_action_when_current(tar_file_name) -> None:
assert actual is None assert actual is None
def test_archived_image_action_when_invalid(tar_file_name) -> None: def test_archived_image_action_when_invalid(tar_file_name: str) -> None:
fake_name = "c:1.2.3" fake_name = "c:1.2.3"
fake_id = "c3" fake_id = "c3"
@ -91,7 +91,7 @@ def test_archived_image_action_when_invalid(tar_file_name) -> None:
assert actual_log[0].startswith("Unable to extract manifest summary from archive") assert actual_log[0].startswith("Unable to extract manifest summary from archive")
def test_archived_image_action_when_obsolete_by_id(tar_file_name) -> None: def test_archived_image_action_when_obsolete_by_id(tar_file_name: str) -> None:
fake_name = "d:0.0.1" fake_name = "d:0.0.1"
old_id = "e5" old_id = "e5"
new_id = "d4" new_id = "d4"
@ -106,7 +106,7 @@ def test_archived_image_action_when_obsolete_by_id(tar_file_name) -> None:
assert actual == expected assert actual == expected
def test_archived_image_action_when_obsolete_by_name(tar_file_name) -> None: def test_archived_image_action_when_obsolete_by_name(tar_file_name: str) -> None:
old_name = "hi" old_name = "hi"
new_name = "d:0.0.1" new_name = "d:0.0.1"
fake_id = "d4" fake_id = "d4"

View File

@ -16,7 +16,7 @@ from ansible_collections.community.docker.plugins.modules import (
APIError = pytest.importorskip("docker.errors.APIError") APIError = pytest.importorskip("docker.errors.APIError")
def test_retry_on_out_of_sequence_error(mocker) -> None: def test_retry_on_out_of_sequence_error(mocker: t.Any) -> None:
run_mock = mocker.MagicMock( run_mock = mocker.MagicMock(
side_effect=APIError( side_effect=APIError(
message="", message="",
@ -32,7 +32,7 @@ def test_retry_on_out_of_sequence_error(mocker) -> None:
assert run_mock.call_count == 3 assert run_mock.call_count == 3
def test_no_retry_on_general_api_error(mocker) -> None: def test_no_retry_on_general_api_error(mocker: t.Any) -> None:
run_mock = mocker.MagicMock( run_mock = mocker.MagicMock(
side_effect=APIError(message="", response=None, explanation="some error") side_effect=APIError(message="", response=None, explanation="some error")
) )
@ -44,7 +44,7 @@ def test_no_retry_on_general_api_error(mocker) -> None:
assert run_mock.call_count == 1 assert run_mock.call_count == 1
def test_get_docker_environment(mocker) -> None: def test_get_docker_environment(mocker: t.Any) -> None:
env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"} env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"}
env_dict = {"TEST3": "CC", "TEST4": "D"} env_dict = {"TEST3": "CC", "TEST4": "D"}
env_string = "TEST3=CC,TEST4=D" env_string = "TEST3=CC,TEST4=D"