Add more typing.

This commit is contained in:
Felix Fontein 2025-10-25 00:09:57 +02:00
parent a2deb384d4
commit 931ae7978c
48 changed files with 430 additions and 314 deletions

View File

@ -3,15 +3,15 @@
# SPDX-License-Identifier: GPL-3.0-or-later
[mypy]
# check_untyped_defs = True -- for later
# disallow_untyped_defs = True -- for later
check_untyped_defs = True
disallow_untyped_defs = True
# strict = True -- only try to enable once everything (including dependencies!) is typed
strict_equality = True
strict_bytes = True
warn_redundant_casts = True
# warn_return_any = True -- for later
# warn_return_any = True
warn_unreachable = True
[mypy-ansible.*]

View File

@ -141,7 +141,7 @@ class Connection(ConnectionBase):
transport = "community.docker.docker"
has_pipelining = True
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
# Note: docker supports running as non-root in some configurations.
@ -476,7 +476,7 @@ class Connection(ConnectionBase):
display.debug("done with docker.exec_command()")
return (p.returncode, stdout, stderr)
def _prefix_login_path(self, remote_path):
def _prefix_login_path(self, remote_path: str) -> str:
"""Make sure that we put files into a standard path
If a path is relative, then we need to choose where to put it.

View File

@ -192,7 +192,7 @@ class Connection(ConnectionBase):
f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}'
)
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.client: AnsibleDockerClient | None = None
@ -319,7 +319,7 @@ class Connection(ConnectionBase):
become_output = [b""]
def append_become_output(stream_id, data):
def append_become_output(stream_id: int, data: bytes) -> None:
become_output[0] += data
exec_socket_handler.set_block_done_callback(

View File

@ -65,7 +65,7 @@ class Connection(ConnectionBase):
transport = "community.docker.nsenter"
has_pipelining = False
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.cwd = None
self._nsenter_pid = None

View File

@ -221,7 +221,10 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return ip_addr
def _should_skip_host(
self, machine_name: str, env_var_tuples, daemon_env: DaemonEnv
self,
machine_name: str,
env_var_tuples: list[tuple[str, str]],
daemon_env: DaemonEnv,
) -> bool:
if not env_var_tuples:
warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}"

View File

@ -67,7 +67,7 @@ except ImportError:
pass
class FakeURLLIB3:
def __init__(self):
def __init__(self) -> None:
self._collections = self
self.poolmanager = self
self.connection = self
@ -81,14 +81,14 @@ except ImportError:
)
class FakeURLLIB3Connection:
def __init__(self):
def __init__(self) -> None:
self.HTTPConnection = _HTTPConnection # pylint: disable=invalid-name
urllib3 = FakeURLLIB3()
urllib3_connection = FakeURLLIB3Connection()
def fail_on_missing_imports():
def fail_on_missing_imports() -> None:
if REQUESTS_IMPORT_ERROR is not None:
from .errors import MissingRequirementException # pylint: disable=cyclic-import

View File

@ -55,6 +55,7 @@ from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
if t.TYPE_CHECKING:
from requests import Response
from requests.adapters import BaseAdapter
from ..._socket_helper import SocketLike
@ -258,23 +259,23 @@ class APIClient(_Session):
return kwargs
@update_headers
def _post(self, url: str, **kwargs):
def _post(self, url: str, **kwargs: t.Any) -> Response:
return self.post(url, **self._set_request_timeout(kwargs))
@update_headers
def _get(self, url: str, **kwargs):
def _get(self, url: str, **kwargs: t.Any) -> Response:
return self.get(url, **self._set_request_timeout(kwargs))
@update_headers
def _head(self, url: str, **kwargs):
def _head(self, url: str, **kwargs: t.Any) -> Response:
return self.head(url, **self._set_request_timeout(kwargs))
@update_headers
def _put(self, url: str, **kwargs):
def _put(self, url: str, **kwargs: t.Any) -> Response:
return self.put(url, **self._set_request_timeout(kwargs))
@update_headers
def _delete(self, url: str, **kwargs):
def _delete(self, url: str, **kwargs: t.Any) -> Response:
return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
@ -343,7 +344,7 @@ class APIClient(_Session):
return response.text
def _post_json(
self, url: str, data: dict[str, str | None] | t.Any, **kwargs
self, url: str, data: dict[str, str | None] | t.Any, **kwargs: t.Any
) -> Response:
# Go <1.1 cannot unserialize null to a string
# so we do this disgusting thing here.
@ -556,22 +557,30 @@ class APIClient(_Session):
"""
socket = self._get_raw_response_socket(response)
gen: t.Generator = frames_iter(socket, tty)
gen = frames_iter(socket, tty)
if demux:
# The generator will output tuples (stdout, stderr)
gen = (demux_adaptor(*frame) for frame in gen)
demux_gen: t.Generator[tuple[bytes | None, bytes | None]] = (
demux_adaptor(*frame) for frame in gen
)
if stream:
return demux_gen
try:
# Wait for all the frames, concatenate them, and return the result
return consume_socket_output(demux_gen, demux=True)
finally:
response.close()
else:
# The generator will output strings
gen = (data for (dummy, data) in gen)
if stream:
return gen
try:
# Wait for all the frames, concatenate them, and return the result
return consume_socket_output(gen, demux=demux)
finally:
response.close()
mux_gen: t.Generator[bytes] = (data for (dummy, data) in gen)
if stream:
return mux_gen
try:
# Wait for all the frames, concatenate them, and return the result
return consume_socket_output(mux_gen, demux=False)
finally:
response.close()
def _disable_socket_timeout(self, socket: SocketLike) -> None:
"""Depending on the combination of python version and whether we are
@ -637,11 +646,11 @@ class APIClient(_Session):
return self._multiplexed_response_stream_helper(res)
return sep.join(list(self._multiplexed_buffer_helper(res)))
def _unmount(self, *args) -> None:
def _unmount(self, *args: t.Any) -> None:
for proto in args:
self.adapters.pop(proto)
def get_adapter(self, url: str):
def get_adapter(self, url: str) -> BaseAdapter:
try:
return super().get_adapter(url)
except _InvalidSchema as e:
@ -696,19 +705,19 @@ class APIClient(_Session):
else:
log.debug("No auth config found")
def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes:
def get_binary(self, pathfmt: str, *args: str, **kwargs: t.Any) -> bytes:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_binary=True,
)
def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
def get_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
def get_text(self, pathfmt: str, *args: str, **kwargs) -> str:
def get_text(self, pathfmt: str, *args: str, **kwargs: t.Any) -> str:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
@ -718,7 +727,7 @@ class APIClient(_Session):
pathfmt: str,
*args: str,
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
**kwargs,
**kwargs: t.Any,
) -> t.Generator[bytes]:
res = self._get(
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
@ -726,23 +735,25 @@ class APIClient(_Session):
self._raise_for_status(res)
return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None:
def delete_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None:
self._raise_for_status(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
def delete_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
def post_call(self, pathfmt: str, *args: str, **kwargs) -> None:
def post_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None:
self._raise_for_status(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None:
def post_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> None:
self._raise_for_status(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -750,7 +761,7 @@ class APIClient(_Session):
)
def post_json_to_binary(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> bytes:
return self._result(
self._post_json(
@ -760,7 +771,7 @@ class APIClient(_Session):
)
def post_json_to_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> t.Any:
return self._result(
self._post_json(
@ -770,7 +781,7 @@ class APIClient(_Session):
)
def post_json_to_text(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> str:
return self._result(
self._post_json(
@ -784,7 +795,7 @@ class APIClient(_Session):
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
**kwargs,
**kwargs: t.Any,
) -> SocketLike:
headers = headers.copy() if headers else {}
headers.update(
@ -813,7 +824,7 @@ class APIClient(_Session):
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
**kwargs: t.Any,
) -> t.Generator[bytes]: ...
@t.overload
@ -826,7 +837,7 @@ class APIClient(_Session):
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
**kwargs: t.Any,
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
@ -839,7 +850,7 @@ class APIClient(_Session):
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
**kwargs: t.Any,
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
@ -852,7 +863,7 @@ class APIClient(_Session):
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
**kwargs: t.Any,
) -> bytes: ...
@t.overload
@ -865,7 +876,7 @@ class APIClient(_Session):
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
**kwargs: t.Any,
) -> tuple[bytes, None]: ...
@t.overload
@ -878,7 +889,7 @@ class APIClient(_Session):
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
**kwargs: t.Any,
) -> tuple[bytes, bytes]: ...
def post_json_to_stream(
@ -890,7 +901,7 @@ class APIClient(_Session):
stream: bool = False,
demux: bool = False,
tty: bool = False,
**kwargs,
**kwargs: t.Any,
) -> t.Any:
headers = headers.copy() if headers else {}
headers.update(
@ -912,7 +923,7 @@ class APIClient(_Session):
demux=demux,
)
def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
def post_to_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,

View File

@ -106,7 +106,7 @@ class AuthConfig(dict):
@classmethod
def parse_auth(
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False
cls, entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
) -> dict[str, dict[str, t.Any]]:
"""
Parses authentication entries
@ -294,7 +294,7 @@ class AuthConfig(dict):
except StoreError as e:
raise errors.DockerException(f"Credentials store error: {e}")
def _get_store_instance(self, name: str):
def _get_store_instance(self, name: str) -> Store:
if name not in self._stores:
self._stores[name] = Store(name, environment=self._credstore_env)
return self._stores[name]
@ -326,8 +326,10 @@ class AuthConfig(dict):
def resolve_authconfig(
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None
):
authconfig: AuthConfig | dict[str, t.Any],
registry: str | None = None,
credstore_env: dict[str, str] | None = None,
) -> dict[str, t.Any] | None:
if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig, credstore_env)
return authconfig.resolve_authconfig(registry)

View File

@ -89,7 +89,7 @@ def get_meta_dir(name: str | None = None) -> str:
return meta_dir
def get_meta_file(name) -> str:
def get_meta_file(name: str) -> str:
return os.path.join(get_meta_dir(name), METAFILE)

View File

@ -18,6 +18,10 @@ from ansible.module_utils.common.text.converters import to_native
from ._import_helper import HTTPError as _HTTPError
if t.TYPE_CHECKING:
from requests import Response
class DockerException(Exception):
"""
A base class from which all other exceptions inherit.
@ -55,7 +59,10 @@ class APIError(_HTTPError, DockerException):
"""
def __init__(
self, message: str | Exception, response=None, explanation: str | None = None
self,
message: str | Exception,
response: Response | None = None,
explanation: str | None = None,
) -> None:
# requests 1.2 supports response as a keyword argument, but
# requests 1.1 does not

View File

@ -11,6 +11,7 @@
from __future__ import annotations
import typing as t
from queue import Empty
from .. import constants
@ -19,6 +20,12 @@ from .basehttpadapter import BaseHTTPAdapter
from .npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Mapping
from requests import PreparedRequest
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
@ -91,7 +98,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
)
super().__init__()
def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool:
def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> NpipeHTTPConnectionPool:
with self.pools.lock:
pool = self.pools.get(url)
if pool:
@ -104,7 +113,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
return pool
def request_url(self, request, proxies) -> str:
def request_url(
self, request: PreparedRequest, proxies: Mapping[str, str] | None
) -> str:
# The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -64,7 +64,7 @@ class NpipeSocket:
implemented.
"""
def __init__(self, handle=None) -> None:
def __init__(self, handle: t.Any | None = None) -> None:
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
self._handle = handle
self._address: str | None = None
@ -74,15 +74,17 @@ class NpipeSocket:
def accept(self) -> t.NoReturn:
raise NotImplementedError()
def bind(self, address) -> t.NoReturn:
def bind(self, address: t.Any) -> t.NoReturn:
raise NotImplementedError()
def close(self) -> None:
if self._handle is None:
raise ValueError("Handle not present")
self._handle.Close()
self._closed = True
@check_closed
def connect(self, address, retry_count: int = 0) -> None:
def connect(self, address: str, retry_count: int = 0) -> None:
try:
handle = win32file.CreateFile(
address,
@ -116,11 +118,11 @@ class NpipeSocket:
self._address = address
@check_closed
def connect_ex(self, address) -> None:
def connect_ex(self, address: str) -> None:
self.connect(address)
@check_closed
def detach(self):
def detach(self) -> t.Any:
self._closed = True
return self._handle
@ -134,16 +136,18 @@ class NpipeSocket:
def getsockname(self) -> str | None:
return self._address
def getsockopt(self, level, optname, buflen=None) -> t.NoReturn:
def getsockopt(
self, level: t.Any, optname: t.Any, buflen: t.Any = None
) -> t.NoReturn:
raise NotImplementedError()
def ioctl(self, control, option) -> t.NoReturn:
def ioctl(self, control: t.Any, option: t.Any) -> t.NoReturn:
raise NotImplementedError()
def listen(self, backlog) -> t.NoReturn:
def listen(self, backlog: t.Any) -> t.NoReturn:
raise NotImplementedError()
def makefile(self, mode: str, bufsize: int | None = None):
def makefile(self, mode: str, bufsize: int | None = None) -> t.IO[bytes]:
if mode.strip("b") != "r":
raise NotImplementedError()
rawio = NpipeFileIOBase(self)
@ -153,6 +157,8 @@ class NpipeSocket:
@check_closed
def recv(self, bufsize: int, flags: int = 0) -> str:
if self._handle is None:
raise ValueError("Handle not present")
dummy_err, data = win32file.ReadFile(self._handle, bufsize)
return data
@ -169,6 +175,8 @@ class NpipeSocket:
@check_closed
def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
if self._handle is None:
raise ValueError("Handle not present")
readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
event = win32event.CreateEvent(None, True, True, None)
@ -188,6 +196,8 @@ class NpipeSocket:
@check_closed
def send(self, string: Buffer, flags: int = 0) -> int:
if self._handle is None:
raise ValueError("Handle not present")
event = win32event.CreateEvent(None, True, True, None)
try:
overlapped = pywintypes.OVERLAPPED()
@ -210,7 +220,7 @@ class NpipeSocket:
self.connect(address)
return self.send(string)
def setblocking(self, flag: bool):
def setblocking(self, flag: bool) -> None:
if flag:
return self.settimeout(None)
return self.settimeout(0)
@ -228,16 +238,16 @@ class NpipeSocket:
def gettimeout(self) -> int | float | None:
return self._timeout
def setsockopt(self, level, optname, value) -> t.NoReturn:
def setsockopt(self, level: t.Any, optname: t.Any, value: t.Any) -> t.NoReturn:
raise NotImplementedError()
@check_closed
def shutdown(self, how) -> None:
def shutdown(self, how: t.Any) -> None:
return self.close()
class NpipeFileIOBase(io.RawIOBase):
def __init__(self, npipe_socket) -> None:
def __init__(self, npipe_socket: NpipeSocket | None) -> None:
self.sock = npipe_socket
def close(self) -> None:
@ -245,7 +255,10 @@ class NpipeFileIOBase(io.RawIOBase):
self.sock = None
def fileno(self) -> int:
return self.sock.fileno()
if self.sock is None:
raise RuntimeError("socket is closed")
# TODO: This is definitely a bug, NpipeSocket.fileno() does not exist!
return self.sock.fileno() # type: ignore
def isatty(self) -> bool:
return False
@ -254,6 +267,8 @@ class NpipeFileIOBase(io.RawIOBase):
return True
def readinto(self, buf: Buffer) -> int:
if self.sock is None:
raise RuntimeError("socket is closed")
return self.sock.recv_into(buf)
def seekable(self) -> bool:

View File

@ -35,7 +35,7 @@ else:
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer
from collections.abc import Buffer, Mapping
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
@ -67,7 +67,7 @@ class SSHSocket(socket.socket):
preexec_func = None
if not constants.IS_WINDOWS_PLATFORM:
def f():
def f() -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN)
preexec_func = f
@ -100,13 +100,13 @@ class SSHSocket(socket.socket):
self.proc.stdin.flush()
return written
def sendall(self, data: Buffer, *args, **kwargs) -> None:
def sendall(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> None:
self._write(data)
def send(self, data: Buffer, *args, **kwargs) -> int:
def send(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> int:
return self._write(data)
def recv(self, n: int, *args, **kwargs) -> bytes:
def recv(self, n: int, *args: t.Any, **kwargs: t.Any) -> bytes:
if not self.proc:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first."
@ -114,7 +114,7 @@ class SSHSocket(socket.socket):
assert self.proc.stdout is not None
return self.proc.stdout.read(n)
def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore
def makefile(self, mode: str, *args: t.Any, **kwargs: t.Any) -> t.IO: # type: ignore
if not self.proc:
self.connect()
assert self.proc is not None
@ -138,7 +138,7 @@ class SSHConnection(urllib3_connection.HTTPConnection):
def __init__(
self,
*,
ssh_transport=None,
ssh_transport: paramiko.Transport | None = None,
timeout: int | float = 60,
host: str,
) -> None:
@ -146,18 +146,19 @@ class SSHConnection(urllib3_connection.HTTPConnection):
self.ssh_transport = ssh_transport
self.timeout = timeout
self.ssh_host = host
self.sock: paramiko.Channel | SSHSocket | None = None
def connect(self) -> None:
if self.ssh_transport:
sock = self.ssh_transport.open_session()
sock.settimeout(self.timeout)
sock.exec_command("docker system dial-stdio")
channel = self.ssh_transport.open_session()
channel.settimeout(self.timeout)
channel.exec_command("docker system dial-stdio")
self.sock = channel
else:
sock = SSHSocket(self.ssh_host)
sock.settimeout(self.timeout)
sock.connect()
self.sock = sock
self.sock = sock
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
@ -172,7 +173,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
host: str,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.ssh_transport = None
self.ssh_transport: paramiko.Transport | None = None
self.timeout = timeout
if ssh_client:
self.ssh_transport = ssh_client.get_transport()
@ -276,7 +277,9 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
if self.ssh_client:
self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool:
def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> SSHConnectionPool:
if not self.ssh_client:
return SSHConnectionPool(
ssh_client=self.ssh_client,

View File

@ -33,7 +33,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
def __init__(
self,
assert_hostname: bool | None = None,
**kwargs,
**kwargs: t.Any,
) -> None:
self.assert_hostname = assert_hostname
super().__init__(**kwargs)
@ -51,7 +51,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
self.poolmanager = PoolManager(**kwargs)
def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool:
def get_connection(self, *args: t.Any, **kwargs: t.Any) -> urllib3.ConnectionPool:
"""
Ensure assert_hostname is set correctly on our pool

View File

@ -19,12 +19,20 @@ from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
from .basehttpadapter import BaseHTTPAdapter
if t.TYPE_CHECKING:
from collections.abc import Mapping
from requests import PreparedRequest
from ..._socket_helper import SocketLike
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class UnixHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(
self, base_url: str | bytes, unix_socket, timeout: int | float = 60
self, base_url: str | bytes, unix_socket: str, timeout: int | float = 60
) -> None:
super().__init__("localhost", timeout=timeout)
self.base_url = base_url
@ -43,7 +51,7 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
if header == "Connection" and "Upgrade" in values:
self.disable_buffering = True
def response_class(self, sock, *args, **kwargs) -> t.Any:
def response_class(self, sock: SocketLike, *args: t.Any, **kwargs: t.Any) -> t.Any:
# FIXME: We may need to disable buffering on Py3,
# but there's no clear way to do it at the moment. See:
# https://github.com/docker/docker-py/issues/1799
@ -88,12 +96,16 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
self.socket_path = socket_path
self.timeout = timeout
self.max_pool_size = max_pool_size
self.pools = RecentlyUsedContainer(
pool_connections, dispose_func=lambda p: p.close()
)
def f(p: t.Any) -> None:
p.close()
self.pools = RecentlyUsedContainer(pool_connections, dispose_func=f)
super().__init__()
def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool:
def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> UnixHTTPConnectionPool:
with self.pools.lock:
pool = self.pools.get(url)
if pool:
@ -106,7 +118,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
return pool
def request_url(self, request, proxies) -> str:
def request_url(self, request: PreparedRequest, proxies: Mapping[str, str]) -> str:
# The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -18,7 +18,13 @@ from .._import_helper import urllib3
from ..errors import DockerException
class CancellableStream:
if t.TYPE_CHECKING:
from requests import Response
_T = t.TypeVar("_T")
class CancellableStream(t.Generic[_T]):
"""
Stream wrapper for real-time events, logs, etc. from the server.
@ -30,14 +36,14 @@ class CancellableStream:
>>> events.close()
"""
def __init__(self, stream, response) -> None:
def __init__(self, stream: t.Generator[_T], response: Response) -> None:
self._stream = stream
self._response = response
def __iter__(self) -> t.Self:
return self
def __next__(self):
def __next__(self) -> _T:
try:
return next(self._stream)
except urllib3.exceptions.ProtocolError as exc:
@ -56,7 +62,7 @@ class CancellableStream:
# find the underlying socket object
# based on api.client._get_raw_response_socket
sock_fp = self._response.raw._fp.fp
sock_fp = self._response.raw._fp.fp # type: ignore
if hasattr(sock_fp, "raw"):
sock_raw = sock_fp.raw
@ -74,7 +80,7 @@ class CancellableStream:
"Cancellable streams not supported for the SSH protocol"
)
else:
sock = sock_fp._sock
sock = sock_fp._sock # type: ignore
if hasattr(urllib3.contrib, "pyopenssl") and isinstance(
sock, urllib3.contrib.pyopenssl.WrappedSocket

View File

@ -37,7 +37,7 @@ def _purge() -> None:
_cache.clear()
def fnmatch(name: str, pat: str):
def fnmatch(name: str, pat: str) -> bool:
"""Test whether FILENAME matches PATTERN.
Patterns are Unix shell style:

View File

@ -22,7 +22,9 @@ from ..transport.npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from ..._socket_helper import SocketLike
STDOUT = 1
@ -38,14 +40,14 @@ class SocketError(Exception):
NPIPE_ENDED = 109
def read(socket, n: int = 4096) -> bytes | None:
def read(socket: SocketLike, n: int = 4096) -> bytes | None:
"""
Reads at most n bytes from socket
"""
recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK)
if not isinstance(socket, NpipeSocket):
if not isinstance(socket, NpipeSocket): # type: ignore[unreachable]
if not hasattr(select, "poll"):
# Limited to 1024
select.select([socket], [], [])
@ -66,7 +68,7 @@ def read(socket, n: int = 4096) -> bytes | None:
return None # TODO ???
except Exception as e:
is_pipe_ended = (
isinstance(socket, NpipeSocket)
isinstance(socket, NpipeSocket) # type: ignore[unreachable]
and len(e.args) > 0
and e.args[0] == NPIPE_ENDED
)
@ -77,7 +79,7 @@ def read(socket, n: int = 4096) -> bytes | None:
raise
def read_exactly(socket, n: int) -> bytes:
def read_exactly(socket: SocketLike, n: int) -> bytes:
"""
Reads exactly n bytes from socket
Raises SocketError if there is not enough data
@ -91,7 +93,7 @@ def read_exactly(socket, n: int) -> bytes:
return data
def next_frame_header(socket) -> tuple[int, int]:
def next_frame_header(socket: SocketLike) -> tuple[int, int]:
"""
Returns the stream and size of the next frame of data waiting to be read
from socket, according to the protocol defined here:
@ -107,7 +109,7 @@ def next_frame_header(socket) -> tuple[int, int]:
return (stream, actual)
def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
def frames_iter(socket: SocketLike, tty: bool) -> t.Generator[tuple[int, bytes]]:
"""
Return a generator of frames read from socket. A frame is a tuple where
the first item is the stream number and the second item is a chunk of data.
@ -120,7 +122,7 @@ def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
return frames_iter_no_tty(socket)
def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
def frames_iter_no_tty(socket: SocketLike) -> t.Generator[tuple[int, bytes]]:
"""
Returns a generator of data read from the socket when the tty setting is
not enabled.
@ -141,7 +143,7 @@ def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
yield (stream, result)
def frames_iter_tty(socket) -> t.Generator[bytes]:
def frames_iter_tty(socket: SocketLike) -> t.Generator[bytes]:
"""
Return a generator of data read from the socket when the tty setting is
enabled.
@ -155,20 +157,42 @@ def frames_iter_tty(socket) -> t.Generator[bytes]:
@t.overload
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ...
@t.overload
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
def consume_socket_output(
frames: Sequence[bytes] | t.Generator[bytes], demux: t.Literal[False] = False
) -> bytes: ...
@t.overload
def consume_socket_output(
frames, demux: bool = False
frames: (
Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: t.Literal[True],
) -> tuple[bytes, bytes]: ...
@t.overload
def consume_socket_output(
frames: (
Sequence[bytes]
| Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[bytes]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: bool = False,
) -> bytes | tuple[bytes, bytes]: ...
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]:
def consume_socket_output(
frames: (
Sequence[bytes]
| Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[bytes]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: bool = False,
) -> bytes | tuple[bytes, bytes]:
"""
Iterate through frames read from the socket and return the result.
@ -183,12 +207,13 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b
if demux is False:
# If the streams are multiplexed, the generator returns strings, that
# we just need to concatenate.
return b"".join(frames)
return b"".join(frames) # type: ignore
# If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr)
out: list[bytes | None] = [None, None]
for frame in frames:
frame: tuple[bytes | None, bytes | None]
for frame in frames: # type: ignore
# It is guaranteed that for each frame, one and only one stream
# is not None.
if frame == (None, None):
@ -202,7 +227,7 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b
if out[1] is None:
out[1] = frame[1]
else:
out[1] += frame[1]
out[1] += frame[1] # type: ignore[operator]
return tuple(out) # type: ignore

View File

@ -502,8 +502,8 @@ def split_command(command: str) -> list[str]:
return shlex.split(command)
def format_environment(environment: Mapping[str, str | bytes]) -> list[str]:
def format_env(key, value):
def format_environment(environment: Mapping[str, str | bytes | None]) -> list[str]:
def format_env(key: str, value: str | bytes | None) -> str:
if value is None:
return key
if isinstance(value, bytes):

View File

@ -91,7 +91,7 @@ if not HAS_DOCKER_PY:
# No Docker SDK for Python. Create a place holder client to allow
# instantiation of AnsibleModule and proper error handing
class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined
def __init__(self, **kwargs):
def __init__(self, **kwargs: t.Any) -> None:
pass
class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined
@ -226,7 +226,7 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
)
def log(self, msg: t.Any, pretty_print: bool = False):
def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass
# if self.debug:
# from .util import log_debug
@ -609,7 +609,7 @@ class AnsibleDockerClientBase(Client):
return new_tag, old_tag == new_tag
def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]:
def inspect_distribution(self, image: str, **kwargs: t.Any) -> dict[str, t.Any]:
"""
Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
since prior versions did not support accessing private repositories.
@ -629,7 +629,6 @@ class AnsibleDockerClientBase(Client):
class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__(
self,
argument_spec: dict[str, t.Any] | None = None,
@ -651,7 +650,6 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None,
):
# Modules can put information in here which will always be returned
# in case client.fail() is called.
self.fail_results = fail_results or {}

View File

@ -146,7 +146,7 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
)
def log(self, msg: t.Any, pretty_print: bool = False):
def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass
# if self.debug:
# from .util import log_debug
@ -295,7 +295,7 @@ class AnsibleDockerClientBase(Client):
),
}
def depr(*args, **kwargs):
def depr(*args: t.Any, **kwargs: t.Any) -> None:
self.deprecate(*args, **kwargs)
update_tls_hostname(

View File

@ -82,7 +82,7 @@ class AnsibleDockerClientBase:
def __init__(
self,
common_args,
common_args: dict[str, t.Any],
min_docker_api_version: str | None = None,
needs_api_version: bool = True,
) -> None:
@ -91,15 +91,15 @@ class AnsibleDockerClientBase:
self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"]
if common_args["api_version"] and common_args["api_version"] != "auto":
self._environment["DOCKER_API_VERSION"] = common_args["api_version"]
self._cli = common_args.get("docker_cli")
if self._cli is None:
cli = common_args.get("docker_cli")
if cli is None:
try:
self._cli = get_bin_path("docker")
cli = get_bin_path("docker")
except ValueError:
self.fail(
"Cannot find docker CLI in path. Please provide it explicitly with the docker_cli parameter"
)
self._cli = cli
self._cli_base = [self._cli]
docker_host = common_args["docker_host"]
if not docker_host and not common_args["cli_context"]:
@ -149,7 +149,7 @@ class AnsibleDockerClientBase:
"Internal error: cannot have needs_api_version=False with min_docker_api_version not None"
)
def log(self, msg: str, pretty_print: bool = False):
def log(self, msg: str, pretty_print: bool = False) -> None:
pass
# if self.debug:
# from .util import log_debug
@ -227,7 +227,7 @@ class AnsibleDockerClientBase:
return rc, result, stderr
@abc.abstractmethod
def fail(self, msg: str, **kwargs) -> t.NoReturn:
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass
@abc.abstractmethod
@ -395,7 +395,6 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
fail_results: dict[str, t.Any] | None = None,
needs_api_version: bool = True,
) -> None:
# Modules can put information in here which will always be returned
# in case client.fail() is called.
self.fail_results = fail_results or {}
@ -463,7 +462,7 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
)
return rc, stdout, stderr
def fail(self, msg: str, **kwargs) -> t.NoReturn:
def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))

View File

@ -971,7 +971,7 @@ class BaseComposeManager(DockerBaseClass):
stderr: str | bytes,
ignore_service_pull_events: bool = False,
ignore_build_events: bool = False,
):
) -> None:
result["changed"] = result.get("changed", False) or has_changes(
events,
ignore_service_pull_events=ignore_service_pull_events,
@ -989,7 +989,7 @@ class BaseComposeManager(DockerBaseClass):
stdout: str | bytes,
stderr: bytes,
rc: int,
):
) -> bool:
return update_failed(
result,
events,

View File

@ -330,6 +330,8 @@ def stat_file(
client._raise_for_status(response)
header = response.headers.get("x-docker-container-path-stat")
try:
if header is None:
raise ValueError("x-docker-container-path-stat header not present")
stat_data = json.loads(base64.b64decode(header))
except Exception as exc:
raise DockerUnexpectedError(
@ -482,14 +484,14 @@ def fetch_file(
shutil.copyfileobj(in_f, out_f)
return in_path
def process_symlink(in_path, member) -> str:
def process_symlink(in_path: str, member: tarfile.TarInfo) -> str:
if os.path.exists(b_out_path):
os.unlink(b_out_path)
os.symlink(member.linkname, b_out_path)
return in_path
def process_other(in_path, member) -> str:
def process_other(in_path: str, member: tarfile.TarInfo) -> str:
raise DockerFileCopyError(
f'Remote file "{in_path}" is not a regular file or a symbolic link'
)

View File

@ -193,7 +193,9 @@ class OptionGroup:
) -> None:
if preprocess is None:
def preprocess(module, values):
def preprocess(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
return values
self.preprocess = preprocess
@ -207,8 +209,8 @@ class OptionGroup:
self.ansible_required_by = ansible_required_by or {}
self.argument_spec: dict[str, t.Any] = {}
def add_option(self, *args, **kwargs) -> OptionGroup:
option = Option(*args, owner=self, **kwargs)
def add_option(self, name: str, **kwargs: t.Any) -> OptionGroup:
option = Option(name, owner=self, **kwargs)
if not option.not_a_container_option:
self.options.append(option)
self.all_options.append(option)
@ -788,7 +790,7 @@ def _preprocess_mounts(
) -> dict[str, t.Any]:
last: dict[str, str] = {}
def check_collision(t, name):
def check_collision(t: str, name: str) -> None:
if t in last:
if name == last[t]:
module.fail_json(
@ -1069,7 +1071,9 @@ def _preprocess_ports(
return values
def _compare_platform(option: Option, param_value: t.Any, container_value: t.Any):
def _compare_platform(
option: Option, param_value: t.Any, container_value: t.Any
) -> bool:
if option.comparison == "ignore":
return True
try:

View File

@ -872,7 +872,7 @@ class DockerAPIEngine(Engine[AnsibleDockerClient]):
image: dict[str, t.Any] | None,
values: dict[str, t.Any],
host_info: dict[str, t.Any] | None,
):
) -> dict[str, t.Any]:
if len(options) != 1:
raise AssertionError(
"host_config_value can only be used for a single option"
@ -1961,7 +1961,14 @@ def _update_value_restart(
}
def _get_values_ports(module, container, api_version, options, image, host_info):
def _get_values_ports(
module: AnsibleModule,
container: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> dict[str, t.Any]:
host_config = container["HostConfig"]
config = container["Config"]

View File

@ -292,7 +292,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
if self.module.params[param] is None:
self.module.params[param] = value
def fail(self, *args, **kwargs) -> t.NoReturn:
def fail(self, *args: str, **kwargs: t.Any) -> t.NoReturn:
# mypy doesn't know that Client has fail() method
raise self.client.fail(*args, **kwargs) # type: ignore
@ -714,7 +714,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
):
) -> None:
assert container.raw is not None
container_values = engine.get_value(
self.module,
@ -767,12 +767,12 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
# Since the order does not matter, sort so that the diff output is better.
if option.name == "expected_mounts":
# For selected values, use one entry as key
def sort_key_fn(x):
def sort_key_fn(x: dict[str, t.Any]) -> t.Any:
return x["target"]
else:
# We sort the list of dictionaries by using the sorted items of a dict as its key.
def sort_key_fn(x):
def sort_key_fn(x: dict[str, t.Any]) -> t.Any:
return sorted(
(a, to_text(b, errors="surrogate_or_strict"))
for a, b in x.items()

View File

@ -26,6 +26,7 @@ from ansible_collections.community.docker.plugins.module_utils._socket_helper im
if t.TYPE_CHECKING:
from collections.abc import Callable
from types import TracebackType
from ansible.module_utils.basic import AnsibleModule
@ -70,7 +71,12 @@ class DockerSocketHandlerBase:
def __enter__(self) -> t.Self:
return self
def __exit__(self, type_, value, tb) -> None:
def __exit__(
self,
type_: t.Type[BaseException] | None,
value: BaseException | None,
tb: TracebackType | None,
) -> None:
self._selector.close()
def set_block_done_callback(
@ -210,7 +216,7 @@ class DockerSocketHandlerBase:
stdout = []
stderr = []
def append_block(stream_id, data):
def append_block(stream_id: int, data: bytes) -> None:
if stream_id == docker_socket.STDOUT:
stdout.append(data)
elif stream_id == docker_socket.STDERR:

View File

@ -23,6 +23,12 @@ if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
from ._common import AnsibleDockerClientBase as CADCB
from ._common_api import AnsibleDockerClientBase as CAPIADCB
from ._common_cli import AnsibleDockerClientBase as CCLIADCB
Client = t.Union[CADCB, CAPIADCB, CCLIADCB]
DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock"
DEFAULT_TLS = False
@ -119,7 +125,7 @@ def sanitize_result(data: t.Any) -> t.Any:
return data
def log_debug(msg: t.Any, pretty_print: bool = False):
def log_debug(msg: t.Any, pretty_print: bool = False) -> None:
"""Write a log message to docker.log.
If ``pretty_print=True``, the message will be pretty-printed as JSON.
@ -325,7 +331,7 @@ class DifferenceTracker:
def sanitize_labels(
labels: dict[str, t.Any] | None,
labels_field: str,
client=None,
client: Client | None = None,
module: AnsibleModule | None = None,
) -> None:
def fail(msg: str) -> t.NoReturn:
@ -371,7 +377,7 @@ def clean_dict_booleans_for_docker_api(
which is the expected format of filters which accept lists such as labels.
"""
def sanitize(value):
def sanitize(value: t.Any) -> str:
if value is True:
return "true"
if value is False:

View File

@ -147,7 +147,7 @@ class PullManager(BaseComposeManager):
f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}"
)
def get_pull_cmd(self, dry_run: bool):
def get_pull_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["pull"]
if self.policy != "always":
args.extend(["--policy", self.policy])

View File

@ -347,7 +347,7 @@ def retrieve_diff(
max_file_size_for_diff: int,
regular_stat: dict[str, t.Any] | None = None,
link_target: str | None = None,
):
) -> None:
if diff is None:
return
if regular_stat is not None:
@ -497,9 +497,9 @@ def is_file_idempotent(
container_path: str,
follow_links: bool,
local_follow_links: bool,
owner_id,
group_id,
mode,
owner_id: int,
group_id: int,
mode: int | None,
force: bool | None = False,
diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1,
@ -744,9 +744,9 @@ def copy_file_into_container(
container_path: str,
follow_links: bool,
local_follow_links: bool,
owner_id,
group_id,
mode,
owner_id: int,
group_id: int,
mode: int | None,
force: bool | None = False,
do_diff: bool = False,
max_file_size_for_diff: int = 1,
@ -797,9 +797,9 @@ def is_content_idempotent(
content: bytes,
container_path: str,
follow_links: bool,
owner_id,
group_id,
mode,
owner_id: int,
group_id: int,
mode: int,
force: bool | None = False,
diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1,
@ -989,9 +989,9 @@ def copy_content_into_container(
content: bytes,
container_path: str,
follow_links: bool,
owner_id,
group_id,
mode,
owner_id: int,
group_id: int,
mode: int,
force: bool | None = False,
do_diff: bool = False,
max_file_size_for_diff: int = 1,
@ -1133,6 +1133,7 @@ def main() -> None:
owner_id, group_id = determine_user_group(client, container)
if content is not None:
assert mode is not None # see required_by above
copy_content_into_container(
client,
container,

View File

@ -667,7 +667,7 @@ class ImageManager(DockerBaseClass):
:rtype: str
"""
def build_msg(reason):
def build_msg(reason: str) -> str:
return f"Archived image {current_image_name} to {archive_path}, {reason}"
try:
@ -877,7 +877,7 @@ class ImageManager(DockerBaseClass):
self.push_image(repo, repo_tag)
@staticmethod
def _extract_output_line(line: dict[str, t.Any], output: list[str]):
def _extract_output_line(line: dict[str, t.Any], output: list[str]) -> None:
"""
Extract text line from stream output and, if found, adds it to output.
"""
@ -1165,18 +1165,18 @@ def main() -> None:
("source", "load", ["load_path"]),
]
def detect_etc_hosts(client):
def detect_etc_hosts(client: AnsibleDockerClient) -> bool:
return client.module.params["build"] and bool(
client.module.params["build"].get("etc_hosts")
)
def detect_build_platform(client):
def detect_build_platform(client: AnsibleDockerClient) -> bool:
return (
client.module.params["build"]
and client.module.params["build"].get("platform") is not None
)
def detect_pull_platform(client):
def detect_pull_platform(client: AnsibleDockerClient) -> bool:
return (
client.module.params["pull"]
and client.module.params["pull"].get("platform") is not None

View File

@ -379,7 +379,7 @@ def normalize_ipam_config_key(key: str) -> str:
return special_cases.get(key, key.lower())
def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]):
def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]) -> bool:
"""Make sure that a is a subset of b, where None entries of a are ignored."""
for k, v in a.items():
if v is None:

View File

@ -204,12 +204,10 @@ class DockerPluginManager:
elif state == "disable":
self.disable()
if self.diff or self.check_mode or self.parameters.debug:
if self.diff:
self.diff_result["before"], self.diff_result["after"] = (
self.diff_tracker.get_before_after()
)
self.diff = self.diff_result
if self.diff:
self.diff_result["before"], self.diff_result["after"] = (
self.diff_tracker.get_before_after()
)
def get_existing_plugin(self) -> dict[str, t.Any] | None:
try:
@ -409,7 +407,7 @@ class DockerPluginManager:
result: dict[str, t.Any] = {
"actions": self.actions,
"changed": self.changed,
"diff": self.diff,
"diff": self.diff_result,
"plugin": plugin_data,
}
if (

View File

@ -247,7 +247,9 @@ class DockerSwarmManager(DockerBaseClass):
self.client.fail(f"Error inspecting docker swarm: {exc}")
def get_docker_items_list(
self, docker_object: t.Literal["nodes", "tasks", "services"], filters=None
self,
docker_object: t.Literal["nodes", "tasks", "services"],
filters: dict[str, str],
) -> list[dict[str, t.Any]]:
items_list: list[dict[str, t.Any]] = []

View File

@ -1463,8 +1463,8 @@ class DockerService(DockerBaseClass):
def from_ansible_params(
cls,
ap: dict[str, t.Any],
old_service,
image_digest,
old_service: DockerService | None,
image_digest: str,
secret_ids: dict[str, str],
config_ids: dict[str, str],
network_ids: dict[str, str],

View File

@ -4,6 +4,7 @@
from __future__ import annotations
import typing as t
import unittest
from io import StringIO
from unittest import mock
@ -40,7 +41,7 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("docker version", "1.2.3", "", 0),
)
def test_docker_connection_module_too_old(
self, mock_new_docker_version, mock_old_docker_version
self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
) -> None:
self.dc._version = None
self.dc.remote_user = "foo"
@ -59,7 +60,7 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("docker version", "1.7.0", "", 0),
)
def test_docker_connection_module(
self, mock_new_docker_version, mock_old_docker_version
self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
) -> None:
self.dc._version = None
@ -73,7 +74,7 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("false", "garbage", "", 1),
)
def test_docker_connection_module_wrong_cmd(
self, mock_new_docker_version, mock_old_docker_version
self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
) -> None:
self.dc._version = None
self.dc.remote_user = "foo"

View File

@ -31,7 +31,7 @@ def templar() -> Templar:
@pytest.fixture(scope="module")
def inventory(templar) -> InventoryModule:
def inventory(templar: Templar) -> InventoryModule:
r = InventoryModule()
r.inventory = InventoryData()
r.templar = templar
@ -91,7 +91,7 @@ LOVING_THARP_SERVICE = {
def create_get_option(
options: dict[str, t.Any], default: t.Any = False
) -> Callable[[str], t.Any]:
def get_option(option):
def get_option(option: str) -> t.Any:
if option in options:
return options[option]
return default
@ -116,12 +116,12 @@ class FakeClient:
self.get_results[f"/containers/{host['Id']}/json"] = host
self.get_results["/containers/json"] = list_reply
def get_json(self, url: str, *param: str, **kwargs) -> t.Any:
def get_json(self, url: str, *param: str, **kwargs: t.Any) -> t.Any:
url = url.format(*param)
return self.get_results[url]
def test_populate(inventory: InventoryModule, mocker) -> None:
def test_populate(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP)
@ -158,7 +158,7 @@ def test_populate(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 1
def test_populate_service(inventory: InventoryModule, mocker) -> None:
def test_populate_service(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_SERVICE)
@ -218,7 +218,7 @@ def test_populate_service(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 1
def test_populate_stack(inventory: InventoryModule, mocker) -> None:
def test_populate_stack(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_STACK)
@ -280,7 +280,7 @@ def test_populate_stack(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 1
def test_populate_filter_none(inventory: InventoryModule, mocker) -> None:
def test_populate_filter_none(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP)
@ -304,7 +304,7 @@ def test_populate_filter_none(inventory: InventoryModule, mocker) -> None:
assert len(inventory.inventory.hosts) == 0
def test_populate_filter(inventory: InventoryModule, mocker) -> None:
def test_populate_filter(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP)

View File

@ -43,6 +43,12 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c
from .. import fake_api
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.auth import (
AuthConfig,
)
DEFAULT_TIMEOUT_SECONDS = constants.DEFAULT_TIMEOUT_SECONDS
@ -52,8 +58,8 @@ def response(
headers: dict[str, str] | None = None,
reason: str = "",
elapsed: int = 0,
request=None,
raw=None,
request: requests.PreparedRequest | None = None,
raw: urllib3.HTTPResponse | None = None,
) -> requests.Response:
res = requests.Response()
res.status_code = status_code
@ -63,18 +69,18 @@ def response(
res.headers = requests.structures.CaseInsensitiveDict(headers or {})
res.reason = reason
res.elapsed = datetime.timedelta(elapsed)
res.request = request
res.request = request # type: ignore
res.raw = raw
return res
def fake_resolve_authconfig( # pylint: disable=keyword-arg-before-vararg
authconfig, *args, registry=None, **kwargs
authconfig: AuthConfig, *args: t.Any, registry: str | None = None, **kwargs: t.Any
) -> None:
return None
def fake_inspect_container(self, container: str, tty: bool = False):
def fake_inspect_container(self: object, container: str, tty: bool = False) -> t.Any:
return fake_api.get_fake_inspect_container(tty=tty)[1]
@ -95,24 +101,32 @@ def fake_resp(
fake_request = mock.Mock(side_effect=fake_resp)
def fake_get(self, url: str, *args, **kwargs) -> requests.Response:
def fake_get(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("GET", url, *args, **kwargs)
def fake_post(self, url: str, *args, **kwargs) -> requests.Response:
def fake_post(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("POST", url, *args, **kwargs)
def fake_put(self, url: str, *args, **kwargs) -> requests.Response:
def fake_put(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("PUT", url, *args, **kwargs)
def fake_delete(self, url: str, *args, **kwargs) -> requests.Response:
def fake_delete(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("DELETE", url, *args, **kwargs)
def fake_read_from_socket(
self,
self: APIClient,
response: requests.Response,
stream: bool,
tty: bool = False,
@ -253,9 +267,9 @@ class DockerApiTest(BaseAPIClientTest):
"serveraddress": None,
}
def _socket_path_for_client_session(self, client) -> str:
def _socket_path_for_client_session(self, client: APIClient) -> str:
socket_adapter = client.get_adapter("http+docker://")
return socket_adapter.socket_path
return socket_adapter.socket_path # type: ignore[attr-defined]
def test_url_compatibility_unix(self) -> None:
c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION)
@ -384,7 +398,7 @@ class UnixSocketStreamTest(unittest.TestCase):
finally:
self.server_socket.close()
def early_response_sending_handler(self, connection) -> None:
def early_response_sending_handler(self, connection: socket.socket) -> None:
data = b""
headers = None
@ -494,7 +508,7 @@ class TCPSocketStreamTest(unittest.TestCase):
stderr_data = cls.stderr_data
class Handler(BaseHTTPRequestHandler):
def do_POST(self): # pylint: disable=invalid-name
def do_POST(self) -> None: # pylint: disable=invalid-name
resp_data = self.get_resp_data()
self.send_response(101)
self.send_header("Content-Type", "application/vnd.docker.raw-stream")
@ -506,7 +520,7 @@ class TCPSocketStreamTest(unittest.TestCase):
self.wfile.write(resp_data)
self.wfile.flush()
def get_resp_data(self):
def get_resp_data(self) -> bytes:
path = self.path.split("/")[-1]
if path == "tty":
return stdout_data + stderr_data
@ -520,7 +534,7 @@ class TCPSocketStreamTest(unittest.TestCase):
raise NotImplementedError(f"Unknown path {path}")
@staticmethod
def frame_header(stream, data):
def frame_header(stream: int, data: bytes) -> bytes:
return struct.pack(">BxxxL", stream, len(data))
return Handler

View File

@ -133,126 +133,102 @@ class ResolveAuthTest(unittest.TestCase):
)
def test_resolve_authconfig_hostname_only(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "my.registry.net")["username"]
== "privateuser"
)
ac = auth.resolve_authconfig(self.auth_config, "my.registry.net")
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_authconfig_no_protocol(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")["username"]
== "privateuser"
)
ac = auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_authconfig_no_path(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "http://my.registry.net")[
"username"
]
== "privateuser"
)
ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net")
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_authconfig_no_path_trailing_slash(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")[
"username"
]
== "privateuser"
)
ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_authconfig_no_path_wrong_secure_proto(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "https://my.registry.net")[
"username"
]
== "privateuser"
)
ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net")
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_authconfig_no_path_wrong_insecure_proto(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "http://index.docker.io")[
"username"
]
== "indexuser"
)
ac = auth.resolve_authconfig(self.auth_config, "http://index.docker.io")
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_authconfig_path_wrong_proto(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")[
"username"
]
== "privateuser"
)
ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_authconfig_default_registry(self) -> None:
assert auth.resolve_authconfig(self.auth_config)["username"] == "indexuser"
ac = auth.resolve_authconfig(self.auth_config)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_authconfig_default_explicit_none(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, None)["username"] == "indexuser"
)
ac = auth.resolve_authconfig(self.auth_config, None)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_authconfig_fully_explicit(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")[
"username"
]
== "privateuser"
)
ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_authconfig_legacy_config(self) -> None:
assert (
auth.resolve_authconfig(self.auth_config, "legacy.registry.url")["username"]
== "legacyauth"
)
ac = auth.resolve_authconfig(self.auth_config, "legacy.registry.url")
assert ac is not None
assert ac["username"] == "legacyauth"
def test_resolve_authconfig_no_match(self) -> None:
assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None
def test_resolve_registry_and_auth_library_image(self) -> None:
image = "image"
assert (
auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
ac = auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_hub_image(self) -> None:
image = "username/image"
assert (
auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
ac = auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_explicit_hub(self) -> None:
image = "docker.io/username/image"
assert (
auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
ac = auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_explicit_legacy_hub(self) -> None:
image = "index.docker.io/username/image"
assert (
auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
ac = auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_private_registry(self) -> None:
image = "my.registry.net/image"
assert (
auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "privateuser"
ac = auth.resolve_authconfig(
self.auth_config, auth.resolve_repository_name(image)[0]
)
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_registry_and_auth_unauthenticated_registry(self) -> None:
image = "other.registry.net/image"
@ -278,7 +254,9 @@ class ResolveAuthTest(unittest.TestCase):
"ansible_collections.community.docker.plugins.module_utils._api.auth.AuthConfig._resolve_authconfig_credstore"
) as m:
m.return_value = None
assert "indexuser" == auth.resolve_authconfig(auth_config, None)["username"]
ac = auth.resolve_authconfig(auth_config, None)
assert ac is not None
assert "indexuser" == ac["username"]
class LoadConfigTest(unittest.TestCase):
@ -797,7 +775,7 @@ class CredstoreTest(unittest.TestCase):
class InMemoryStore(Store):
def __init__( # pylint: disable=super-init-not-called
self, *args, **kwargs
self, *args: t.Any, **kwargs: t.Any
) -> None:
self.__store: dict[str | bytes, dict[str, t.Any]] = {}

View File

@ -156,7 +156,7 @@ class ExcludePathsTest(unittest.TestCase):
def test_single_filename_trailing_slash(self) -> None:
assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"]))
def test_wildcard_filename_start(self):
def test_wildcard_filename_start(self) -> None:
assert self.exclude(["*.py"]) == convert_paths(
self.all_paths - set(["a.py", "b.py", "cde.py"])
)

View File

@ -12,6 +12,7 @@ import json
import os
import shutil
import tempfile
import typing as t
import unittest
from collections.abc import Callable
from unittest import mock
@ -25,7 +26,7 @@ class FindConfigFileTest(unittest.TestCase):
mkdir: Callable[[str], os.PathLike[str]]
@fixture(autouse=True)
def tmpdir(self, tmpdir) -> None:
def tmpdir(self, tmpdir: t.Any) -> None:
self.mkdir = tmpdir.mkdir
def test_find_config_fallback(self) -> None:

View File

@ -8,6 +8,7 @@
from __future__ import annotations
import typing as t
import unittest
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
@ -27,7 +28,7 @@ class DecoratorsTest(unittest.TestCase):
"X-Docker-Locale": "en-US",
}
def f(self, headers=None):
def f(self: t.Any, headers: t.Any = None) -> t.Any:
return headers
client = APIClient(version=DEFAULT_DOCKER_API_VERSION)

View File

@ -469,7 +469,7 @@ class FormatEnvironmentTest(unittest.TestCase):
env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"}
assert format_environment(env_dict) == ["ARTIST_NAME=송지은"]
def test_format_env_no_value(self):
def test_format_env_no_value(self) -> None:
env_dict = {
"FOO": None,
"BAR": "",

View File

@ -369,7 +369,7 @@ def test_parse_events(
) -> None:
collected_warnings = []
def collect_warning(msg):
def collect_warning(msg: str) -> None:
collected_warnings.append(msg)
collected_events = parse_events(

View File

@ -5,6 +5,7 @@
from __future__ import annotations
import tarfile
import typing as t
import pytest
@ -22,7 +23,7 @@ from ..test_support.docker_image_archive_stubbing import (
@pytest.fixture
def tar_file_name(tmpdir) -> str:
def tar_file_name(tmpdir: t.Any) -> str:
"""
Return the name of a non-existing tar file in an existing temporary directory.
"""
@ -38,7 +39,7 @@ def test_api_image_id_from_archive_id(expected: str, value: str) -> None:
assert api_image_id(value) == expected
def test_archived_image_manifest_extracts(tar_file_name) -> None:
def test_archived_image_manifest_extracts(tar_file_name: str) -> None:
expected_id = "abcde12345"
expected_tags = ["foo:latest", "bar:v1"]
@ -52,7 +53,7 @@ def test_archived_image_manifest_extracts(tar_file_name) -> None:
def test_archived_image_manifest_extracts_nothing_when_file_not_present(
tar_file_name,
tar_file_name: str,
) -> None:
image_id = archived_image_manifest(tar_file_name)
@ -69,7 +70,7 @@ def test_archived_image_manifest_raises_when_file_not_a_tar() -> None:
def test_archived_image_manifest_raises_when_tar_missing_manifest(
tar_file_name,
tar_file_name: str,
) -> None:
write_irrelevant_tar(tar_file_name)
@ -81,7 +82,9 @@ def test_archived_image_manifest_raises_when_tar_missing_manifest(
assert "manifest.json" in str(e.__cause__)
def test_archived_image_manifest_raises_when_manifest_missing_id(tar_file_name) -> None:
def test_archived_image_manifest_raises_when_manifest_missing_id(
tar_file_name: str,
) -> None:
manifest = [{"foo": "bar"}]
write_imitation_archive_with_manifest(tar_file_name, manifest)

View File

@ -38,7 +38,7 @@ def capture_logging(messages: list[str]) -> Callable[[str], None]:
@pytest.fixture
def tar_file_name(tmpdir):
def tar_file_name(tmpdir: t.Any) -> str:
"""
Return the name of a non-existing tar file in an existing temporary directory.
"""
@ -46,7 +46,7 @@ def tar_file_name(tmpdir):
return tmpdir.join("foo.tar")
def test_archived_image_action_when_missing(tar_file_name) -> None:
def test_archived_image_action_when_missing(tar_file_name: str) -> None:
fake_name = "a:latest"
fake_id = "a1"
@ -59,7 +59,7 @@ def test_archived_image_action_when_missing(tar_file_name) -> None:
assert actual == expected
def test_archived_image_action_when_current(tar_file_name) -> None:
def test_archived_image_action_when_current(tar_file_name: str) -> None:
fake_name = "b:latest"
fake_id = "b2"
@ -72,7 +72,7 @@ def test_archived_image_action_when_current(tar_file_name) -> None:
assert actual is None
def test_archived_image_action_when_invalid(tar_file_name) -> None:
def test_archived_image_action_when_invalid(tar_file_name: str) -> None:
fake_name = "c:1.2.3"
fake_id = "c3"
@ -91,7 +91,7 @@ def test_archived_image_action_when_invalid(tar_file_name) -> None:
assert actual_log[0].startswith("Unable to extract manifest summary from archive")
def test_archived_image_action_when_obsolete_by_id(tar_file_name) -> None:
def test_archived_image_action_when_obsolete_by_id(tar_file_name: str) -> None:
fake_name = "d:0.0.1"
old_id = "e5"
new_id = "d4"
@ -106,7 +106,7 @@ def test_archived_image_action_when_obsolete_by_id(tar_file_name) -> None:
assert actual == expected
def test_archived_image_action_when_obsolete_by_name(tar_file_name) -> None:
def test_archived_image_action_when_obsolete_by_name(tar_file_name: str) -> None:
old_name = "hi"
new_name = "d:0.0.1"
fake_id = "d4"

View File

@ -16,7 +16,7 @@ from ansible_collections.community.docker.plugins.modules import (
APIError = pytest.importorskip("docker.errors.APIError")
def test_retry_on_out_of_sequence_error(mocker) -> None:
def test_retry_on_out_of_sequence_error(mocker: t.Any) -> None:
run_mock = mocker.MagicMock(
side_effect=APIError(
message="",
@ -32,7 +32,7 @@ def test_retry_on_out_of_sequence_error(mocker) -> None:
assert run_mock.call_count == 3
def test_no_retry_on_general_api_error(mocker) -> None:
def test_no_retry_on_general_api_error(mocker: t.Any) -> None:
run_mock = mocker.MagicMock(
side_effect=APIError(message="", response=None, explanation="some error")
)
@ -44,7 +44,7 @@ def test_no_retry_on_general_api_error(mocker) -> None:
assert run_mock.call_count == 1
def test_get_docker_environment(mocker) -> None:
def test_get_docker_environment(mocker: t.Any) -> None:
env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"}
env_dict = {"TEST3": "CC", "TEST4": "D"}
env_string = "TEST3=CC,TEST4=D"