Add typing information, 2/n (#1178)

* Add typing to Docker Stack modules. Clean modules up.

* Add typing to Docker Swarm modules.

* Add typing to unit tests.

* Add more typing.

* Add ignore.txt entries.
This commit is contained in:
Felix Fontein 2025-10-25 01:16:04 +02:00 committed by GitHub
parent 3350283bcc
commit 6ad4bfcd40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
84 changed files with 1496 additions and 1161 deletions

View File

@ -3,15 +3,15 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
[mypy] [mypy]
# check_untyped_defs = True -- for later check_untyped_defs = True
# disallow_untyped_defs = True -- for later disallow_untyped_defs = True
# strict = True -- only try to enable once everything (including dependencies!) is typed # strict = True -- only try to enable once everything (including dependencies!) is typed
strict_equality = True strict_equality = True
strict_bytes = True strict_bytes = True
warn_redundant_casts = True warn_redundant_casts = True
# warn_return_any = True -- for later # warn_return_any = True
warn_unreachable = True warn_unreachable = True
[mypy-ansible.*] [mypy-ansible.*]

View File

@ -141,7 +141,7 @@ class Connection(ConnectionBase):
transport = "community.docker.docker" transport = "community.docker.docker"
has_pipelining = True has_pipelining = True
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Note: docker supports running as non-root in some configurations. # Note: docker supports running as non-root in some configurations.
@ -476,7 +476,7 @@ class Connection(ConnectionBase):
display.debug("done with docker.exec_command()") display.debug("done with docker.exec_command()")
return (p.returncode, stdout, stderr) return (p.returncode, stdout, stderr)
def _prefix_login_path(self, remote_path): def _prefix_login_path(self, remote_path: str) -> str:
"""Make sure that we put files into a standard path """Make sure that we put files into a standard path
If a path is relative, then we need to choose where to put it. If a path is relative, then we need to choose where to put it.

View File

@ -192,7 +192,7 @@ class Connection(ConnectionBase):
f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}' f'An unexpected requests error occurred for container "{remote_addr}" when trying to talk to the Docker daemon: {e}'
) )
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.client: AnsibleDockerClient | None = None self.client: AnsibleDockerClient | None = None
@ -319,7 +319,7 @@ class Connection(ConnectionBase):
become_output = [b""] become_output = [b""]
def append_become_output(stream_id, data): def append_become_output(stream_id: int, data: bytes) -> None:
become_output[0] += data become_output[0] += data
exec_socket_handler.set_block_done_callback( exec_socket_handler.set_block_done_callback(

View File

@ -65,7 +65,7 @@ class Connection(ConnectionBase):
transport = "community.docker.nsenter" transport = "community.docker.nsenter"
has_pipelining = False has_pipelining = False
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cwd = None self.cwd = None
self._nsenter_pid = None self._nsenter_pid = None

View File

@ -221,7 +221,10 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
return ip_addr return ip_addr
def _should_skip_host( def _should_skip_host(
self, machine_name: str, env_var_tuples, daemon_env: DaemonEnv self,
machine_name: str,
env_var_tuples: list[tuple[str, str]],
daemon_env: DaemonEnv,
) -> bool: ) -> bool:
if not env_var_tuples: if not env_var_tuples:
warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}" warning_prefix = f"Unable to fetch Docker daemon env vars from Docker Machine for host {machine_name}"

View File

@ -67,7 +67,7 @@ except ImportError:
pass pass
class FakeURLLIB3: class FakeURLLIB3:
def __init__(self): def __init__(self) -> None:
self._collections = self self._collections = self
self.poolmanager = self self.poolmanager = self
self.connection = self self.connection = self
@ -81,14 +81,14 @@ except ImportError:
) )
class FakeURLLIB3Connection: class FakeURLLIB3Connection:
def __init__(self): def __init__(self) -> None:
self.HTTPConnection = _HTTPConnection # pylint: disable=invalid-name self.HTTPConnection = _HTTPConnection # pylint: disable=invalid-name
urllib3 = FakeURLLIB3() urllib3 = FakeURLLIB3()
urllib3_connection = FakeURLLIB3Connection() urllib3_connection = FakeURLLIB3Connection()
def fail_on_missing_imports(): def fail_on_missing_imports() -> None:
if REQUESTS_IMPORT_ERROR is not None: if REQUESTS_IMPORT_ERROR is not None:
from .errors import MissingRequirementException # pylint: disable=cyclic-import from .errors import MissingRequirementException # pylint: disable=cyclic-import

View File

@ -55,6 +55,7 @@ from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from requests import Response from requests import Response
from requests.adapters import BaseAdapter
from ..._socket_helper import SocketLike from ..._socket_helper import SocketLike
@ -258,23 +259,23 @@ class APIClient(_Session):
return kwargs return kwargs
@update_headers @update_headers
def _post(self, url: str, **kwargs): def _post(self, url: str, **kwargs: t.Any) -> Response:
return self.post(url, **self._set_request_timeout(kwargs)) return self.post(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _get(self, url: str, **kwargs): def _get(self, url: str, **kwargs: t.Any) -> Response:
return self.get(url, **self._set_request_timeout(kwargs)) return self.get(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _head(self, url: str, **kwargs): def _head(self, url: str, **kwargs: t.Any) -> Response:
return self.head(url, **self._set_request_timeout(kwargs)) return self.head(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _put(self, url: str, **kwargs): def _put(self, url: str, **kwargs: t.Any) -> Response:
return self.put(url, **self._set_request_timeout(kwargs)) return self.put(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _delete(self, url: str, **kwargs): def _delete(self, url: str, **kwargs: t.Any) -> Response:
return self.delete(url, **self._set_request_timeout(kwargs)) return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str: def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
@ -343,7 +344,7 @@ class APIClient(_Session):
return response.text return response.text
def _post_json( def _post_json(
self, url: str, data: dict[str, str | None] | t.Any, **kwargs self, url: str, data: dict[str, str | None] | t.Any, **kwargs: t.Any
) -> Response: ) -> Response:
# Go <1.1 cannot unserialize null to a string # Go <1.1 cannot unserialize null to a string
# so we do this disgusting thing here. # so we do this disgusting thing here.
@ -556,22 +557,30 @@ class APIClient(_Session):
""" """
socket = self._get_raw_response_socket(response) socket = self._get_raw_response_socket(response)
gen: t.Generator = frames_iter(socket, tty) gen = frames_iter(socket, tty)
if demux: if demux:
# The generator will output tuples (stdout, stderr) # The generator will output tuples (stdout, stderr)
gen = (demux_adaptor(*frame) for frame in gen) demux_gen: t.Generator[tuple[bytes | None, bytes | None]] = (
demux_adaptor(*frame) for frame in gen
)
if stream:
return demux_gen
try:
# Wait for all the frames, concatenate them, and return the result
return consume_socket_output(demux_gen, demux=True)
finally:
response.close()
else: else:
# The generator will output strings # The generator will output strings
gen = (data for (dummy, data) in gen) mux_gen: t.Generator[bytes] = (data for (dummy, data) in gen)
if stream:
if stream: return mux_gen
return gen try:
try: # Wait for all the frames, concatenate them, and return the result
# Wait for all the frames, concatenate them, and return the result return consume_socket_output(mux_gen, demux=False)
return consume_socket_output(gen, demux=demux) finally:
finally: response.close()
response.close()
def _disable_socket_timeout(self, socket: SocketLike) -> None: def _disable_socket_timeout(self, socket: SocketLike) -> None:
"""Depending on the combination of python version and whether we are """Depending on the combination of python version and whether we are
@ -637,11 +646,11 @@ class APIClient(_Session):
return self._multiplexed_response_stream_helper(res) return self._multiplexed_response_stream_helper(res)
return sep.join(list(self._multiplexed_buffer_helper(res))) return sep.join(list(self._multiplexed_buffer_helper(res)))
def _unmount(self, *args) -> None: def _unmount(self, *args: t.Any) -> None:
for proto in args: for proto in args:
self.adapters.pop(proto) self.adapters.pop(proto)
def get_adapter(self, url: str): def get_adapter(self, url: str) -> BaseAdapter:
try: try:
return super().get_adapter(url) return super().get_adapter(url)
except _InvalidSchema as e: except _InvalidSchema as e:
@ -696,19 +705,19 @@ class APIClient(_Session):
else: else:
log.debug("No auth config found") log.debug("No auth config found")
def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes: def get_binary(self, pathfmt: str, *args: str, **kwargs: t.Any) -> bytes:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_binary=True, get_binary=True,
) )
def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: def get_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def get_text(self, pathfmt: str, *args: str, **kwargs) -> str: def get_text(self, pathfmt: str, *args: str, **kwargs: t.Any) -> str:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
@ -718,7 +727,7 @@ class APIClient(_Session):
pathfmt: str, pathfmt: str,
*args: str, *args: str,
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE, chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
**kwargs, **kwargs: t.Any,
) -> t.Generator[bytes]: ) -> t.Generator[bytes]:
res = self._get( res = self._get(
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
@ -726,23 +735,25 @@ class APIClient(_Session):
self._raise_for_status(res) self._raise_for_status(res)
return self._stream_raw_result(res, chunk_size=chunk_size, decode=False) return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None: def delete_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None:
self._raise_for_status( self._raise_for_status(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: def delete_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result( return self._result(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def post_call(self, pathfmt: str, *args: str, **kwargs) -> None: def post_call(self, pathfmt: str, *args: str, **kwargs: t.Any) -> None:
self._raise_for_status( self._raise_for_status(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None: def post_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> None:
self._raise_for_status( self._raise_for_status(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -750,7 +761,7 @@ class APIClient(_Session):
) )
def post_json_to_binary( def post_json_to_binary(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> bytes: ) -> bytes:
return self._result( return self._result(
self._post_json( self._post_json(
@ -760,7 +771,7 @@ class APIClient(_Session):
) )
def post_json_to_json( def post_json_to_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> t.Any: ) -> t.Any:
return self._result( return self._result(
self._post_json( self._post_json(
@ -770,7 +781,7 @@ class APIClient(_Session):
) )
def post_json_to_text( def post_json_to_text(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs self, pathfmt: str, *args: str, data: t.Any = None, **kwargs: t.Any
) -> str: ) -> str:
return self._result( return self._result(
self._post_json( self._post_json(
@ -784,7 +795,7 @@ class APIClient(_Session):
*args: str, *args: str,
data: t.Any = None, data: t.Any = None,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
**kwargs, **kwargs: t.Any,
) -> SocketLike: ) -> SocketLike:
headers = headers.copy() if headers else {} headers = headers.copy() if headers else {}
headers.update( headers.update(
@ -813,7 +824,7 @@ class APIClient(_Session):
stream: t.Literal[True], stream: t.Literal[True],
tty: bool = True, tty: bool = True,
demux: t.Literal[False] = False, demux: t.Literal[False] = False,
**kwargs, **kwargs: t.Any,
) -> t.Generator[bytes]: ... ) -> t.Generator[bytes]: ...
@t.overload @t.overload
@ -826,7 +837,7 @@ class APIClient(_Session):
stream: t.Literal[True], stream: t.Literal[True],
tty: t.Literal[True] = True, tty: t.Literal[True] = True,
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> t.Generator[tuple[bytes, None]]: ... ) -> t.Generator[tuple[bytes, None]]: ...
@t.overload @t.overload
@ -839,7 +850,7 @@ class APIClient(_Session):
stream: t.Literal[True], stream: t.Literal[True],
tty: t.Literal[False], tty: t.Literal[False],
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ... ) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload @t.overload
@ -852,7 +863,7 @@ class APIClient(_Session):
stream: t.Literal[False], stream: t.Literal[False],
tty: bool = True, tty: bool = True,
demux: t.Literal[False] = False, demux: t.Literal[False] = False,
**kwargs, **kwargs: t.Any,
) -> bytes: ... ) -> bytes: ...
@t.overload @t.overload
@ -865,7 +876,7 @@ class APIClient(_Session):
stream: t.Literal[False], stream: t.Literal[False],
tty: t.Literal[True] = True, tty: t.Literal[True] = True,
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> tuple[bytes, None]: ... ) -> tuple[bytes, None]: ...
@t.overload @t.overload
@ -878,7 +889,7 @@ class APIClient(_Session):
stream: t.Literal[False], stream: t.Literal[False],
tty: t.Literal[False], tty: t.Literal[False],
demux: t.Literal[True], demux: t.Literal[True],
**kwargs, **kwargs: t.Any,
) -> tuple[bytes, bytes]: ... ) -> tuple[bytes, bytes]: ...
def post_json_to_stream( def post_json_to_stream(
@ -890,7 +901,7 @@ class APIClient(_Session):
stream: bool = False, stream: bool = False,
demux: bool = False, demux: bool = False,
tty: bool = False, tty: bool = False,
**kwargs, **kwargs: t.Any,
) -> t.Any: ) -> t.Any:
headers = headers.copy() if headers else {} headers = headers.copy() if headers else {}
headers.update( headers.update(
@ -912,7 +923,7 @@ class APIClient(_Session):
demux=demux, demux=demux,
) )
def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any: def post_to_json(self, pathfmt: str, *args: str, **kwargs: t.Any) -> t.Any:
return self._result( return self._result(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,

View File

@ -106,7 +106,7 @@ class AuthConfig(dict):
@classmethod @classmethod
def parse_auth( def parse_auth(
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False cls, entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
) -> dict[str, dict[str, t.Any]]: ) -> dict[str, dict[str, t.Any]]:
""" """
Parses authentication entries Parses authentication entries
@ -294,7 +294,7 @@ class AuthConfig(dict):
except StoreError as e: except StoreError as e:
raise errors.DockerException(f"Credentials store error: {e}") raise errors.DockerException(f"Credentials store error: {e}")
def _get_store_instance(self, name: str): def _get_store_instance(self, name: str) -> Store:
if name not in self._stores: if name not in self._stores:
self._stores[name] = Store(name, environment=self._credstore_env) self._stores[name] = Store(name, environment=self._credstore_env)
return self._stores[name] return self._stores[name]
@ -326,8 +326,10 @@ class AuthConfig(dict):
def resolve_authconfig( def resolve_authconfig(
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None authconfig: AuthConfig | dict[str, t.Any],
): registry: str | None = None,
credstore_env: dict[str, str] | None = None,
) -> dict[str, t.Any] | None:
if not isinstance(authconfig, AuthConfig): if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig, credstore_env) authconfig = AuthConfig(authconfig, credstore_env)
return authconfig.resolve_authconfig(registry) return authconfig.resolve_authconfig(registry)

View File

@ -89,7 +89,7 @@ def get_meta_dir(name: str | None = None) -> str:
return meta_dir return meta_dir
def get_meta_file(name) -> str: def get_meta_file(name: str) -> str:
return os.path.join(get_meta_dir(name), METAFILE) return os.path.join(get_meta_dir(name), METAFILE)

View File

@ -18,6 +18,10 @@ from ansible.module_utils.common.text.converters import to_native
from ._import_helper import HTTPError as _HTTPError from ._import_helper import HTTPError as _HTTPError
if t.TYPE_CHECKING:
from requests import Response
class DockerException(Exception): class DockerException(Exception):
""" """
A base class from which all other exceptions inherit. A base class from which all other exceptions inherit.
@ -55,7 +59,10 @@ class APIError(_HTTPError, DockerException):
""" """
def __init__( def __init__(
self, message: str | Exception, response=None, explanation: str | None = None self,
message: str | Exception,
response: Response | None = None,
explanation: str | None = None,
) -> None: ) -> None:
# requests 1.2 supports response as a keyword argument, but # requests 1.2 supports response as a keyword argument, but
# requests 1.1 does not # requests 1.1 does not

View File

@ -11,6 +11,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from queue import Empty from queue import Empty
from .. import constants from .. import constants
@ -19,6 +20,12 @@ from .basehttpadapter import BaseHTTPAdapter
from .npipesocket import NpipeSocket from .npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Mapping
from requests import PreparedRequest
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
@ -91,7 +98,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool: def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> NpipeHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -104,7 +113,9 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies) -> str: def request_url(
self, request: PreparedRequest, proxies: Mapping[str, str] | None
) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -64,7 +64,7 @@ class NpipeSocket:
implemented. implemented.
""" """
def __init__(self, handle=None) -> None: def __init__(self, handle: t.Any | None = None) -> None:
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
self._handle = handle self._handle = handle
self._address: str | None = None self._address: str | None = None
@ -74,15 +74,17 @@ class NpipeSocket:
def accept(self) -> t.NoReturn: def accept(self) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def bind(self, address) -> t.NoReturn: def bind(self, address: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def close(self) -> None: def close(self) -> None:
if self._handle is None:
raise ValueError("Handle not present")
self._handle.Close() self._handle.Close()
self._closed = True self._closed = True
@check_closed @check_closed
def connect(self, address, retry_count: int = 0) -> None: def connect(self, address: str, retry_count: int = 0) -> None:
try: try:
handle = win32file.CreateFile( handle = win32file.CreateFile(
address, address,
@ -116,11 +118,11 @@ class NpipeSocket:
self._address = address self._address = address
@check_closed @check_closed
def connect_ex(self, address) -> None: def connect_ex(self, address: str) -> None:
self.connect(address) self.connect(address)
@check_closed @check_closed
def detach(self): def detach(self) -> t.Any:
self._closed = True self._closed = True
return self._handle return self._handle
@ -134,16 +136,18 @@ class NpipeSocket:
def getsockname(self) -> str | None: def getsockname(self) -> str | None:
return self._address return self._address
def getsockopt(self, level, optname, buflen=None) -> t.NoReturn: def getsockopt(
self, level: t.Any, optname: t.Any, buflen: t.Any = None
) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def ioctl(self, control, option) -> t.NoReturn: def ioctl(self, control: t.Any, option: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def listen(self, backlog) -> t.NoReturn: def listen(self, backlog: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def makefile(self, mode: str, bufsize: int | None = None): def makefile(self, mode: str, bufsize: int | None = None) -> t.IO[bytes]:
if mode.strip("b") != "r": if mode.strip("b") != "r":
raise NotImplementedError() raise NotImplementedError()
rawio = NpipeFileIOBase(self) rawio = NpipeFileIOBase(self)
@ -153,6 +157,8 @@ class NpipeSocket:
@check_closed @check_closed
def recv(self, bufsize: int, flags: int = 0) -> str: def recv(self, bufsize: int, flags: int = 0) -> str:
if self._handle is None:
raise ValueError("Handle not present")
dummy_err, data = win32file.ReadFile(self._handle, bufsize) dummy_err, data = win32file.ReadFile(self._handle, bufsize)
return data return data
@ -169,6 +175,8 @@ class NpipeSocket:
@check_closed @check_closed
def recv_into(self, buf: Buffer, nbytes: int = 0) -> int: def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
if self._handle is None:
raise ValueError("Handle not present")
readbuf = buf if isinstance(buf, memoryview) else memoryview(buf) readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
@ -188,6 +196,8 @@ class NpipeSocket:
@check_closed @check_closed
def send(self, string: Buffer, flags: int = 0) -> int: def send(self, string: Buffer, flags: int = 0) -> int:
if self._handle is None:
raise ValueError("Handle not present")
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
try: try:
overlapped = pywintypes.OVERLAPPED() overlapped = pywintypes.OVERLAPPED()
@ -210,7 +220,7 @@ class NpipeSocket:
self.connect(address) self.connect(address)
return self.send(string) return self.send(string)
def setblocking(self, flag: bool): def setblocking(self, flag: bool) -> None:
if flag: if flag:
return self.settimeout(None) return self.settimeout(None)
return self.settimeout(0) return self.settimeout(0)
@ -228,16 +238,16 @@ class NpipeSocket:
def gettimeout(self) -> int | float | None: def gettimeout(self) -> int | float | None:
return self._timeout return self._timeout
def setsockopt(self, level, optname, value) -> t.NoReturn: def setsockopt(self, level: t.Any, optname: t.Any, value: t.Any) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
@check_closed @check_closed
def shutdown(self, how) -> None: def shutdown(self, how: t.Any) -> None:
return self.close() return self.close()
class NpipeFileIOBase(io.RawIOBase): class NpipeFileIOBase(io.RawIOBase):
def __init__(self, npipe_socket) -> None: def __init__(self, npipe_socket: NpipeSocket | None) -> None:
self.sock = npipe_socket self.sock = npipe_socket
def close(self) -> None: def close(self) -> None:
@ -245,7 +255,10 @@ class NpipeFileIOBase(io.RawIOBase):
self.sock = None self.sock = None
def fileno(self) -> int: def fileno(self) -> int:
return self.sock.fileno() if self.sock is None:
raise RuntimeError("socket is closed")
# TODO: This is definitely a bug, NpipeSocket.fileno() does not exist!
return self.sock.fileno() # type: ignore
def isatty(self) -> bool: def isatty(self) -> bool:
return False return False
@ -254,6 +267,8 @@ class NpipeFileIOBase(io.RawIOBase):
return True return True
def readinto(self, buf: Buffer) -> int: def readinto(self, buf: Buffer) -> int:
if self.sock is None:
raise RuntimeError("socket is closed")
return self.sock.recv_into(buf) return self.sock.recv_into(buf)
def seekable(self) -> bool: def seekable(self) -> bool:

View File

@ -35,7 +35,7 @@ else:
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from collections.abc import Buffer from collections.abc import Buffer, Mapping
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
@ -67,7 +67,7 @@ class SSHSocket(socket.socket):
preexec_func = None preexec_func = None
if not constants.IS_WINDOWS_PLATFORM: if not constants.IS_WINDOWS_PLATFORM:
def f(): def f() -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
preexec_func = f preexec_func = f
@ -100,13 +100,13 @@ class SSHSocket(socket.socket):
self.proc.stdin.flush() self.proc.stdin.flush()
return written return written
def sendall(self, data: Buffer, *args, **kwargs) -> None: def sendall(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> None:
self._write(data) self._write(data)
def send(self, data: Buffer, *args, **kwargs) -> int: def send(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> int:
return self._write(data) return self._write(data)
def recv(self, n: int, *args, **kwargs) -> bytes: def recv(self, n: int, *args: t.Any, **kwargs: t.Any) -> bytes:
if not self.proc: if not self.proc:
raise RuntimeError( raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first." "SSH subprocess not initiated. connect() must be called first."
@ -114,7 +114,7 @@ class SSHSocket(socket.socket):
assert self.proc.stdout is not None assert self.proc.stdout is not None
return self.proc.stdout.read(n) return self.proc.stdout.read(n)
def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore def makefile(self, mode: str, *args: t.Any, **kwargs: t.Any) -> t.IO: # type: ignore
if not self.proc: if not self.proc:
self.connect() self.connect()
assert self.proc is not None assert self.proc is not None
@ -138,7 +138,7 @@ class SSHConnection(urllib3_connection.HTTPConnection):
def __init__( def __init__(
self, self,
*, *,
ssh_transport=None, ssh_transport: paramiko.Transport | None = None,
timeout: int | float = 60, timeout: int | float = 60,
host: str, host: str,
) -> None: ) -> None:
@ -146,18 +146,19 @@ class SSHConnection(urllib3_connection.HTTPConnection):
self.ssh_transport = ssh_transport self.ssh_transport = ssh_transport
self.timeout = timeout self.timeout = timeout
self.ssh_host = host self.ssh_host = host
self.sock: paramiko.Channel | SSHSocket | None = None
def connect(self) -> None: def connect(self) -> None:
if self.ssh_transport: if self.ssh_transport:
sock = self.ssh_transport.open_session() channel = self.ssh_transport.open_session()
sock.settimeout(self.timeout) channel.settimeout(self.timeout)
sock.exec_command("docker system dial-stdio") channel.exec_command("docker system dial-stdio")
self.sock = channel
else: else:
sock = SSHSocket(self.ssh_host) sock = SSHSocket(self.ssh_host)
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.connect() sock.connect()
self.sock = sock
self.sock = sock
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
@ -172,7 +173,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
host: str, host: str,
) -> None: ) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.ssh_transport = None self.ssh_transport: paramiko.Transport | None = None
self.timeout = timeout self.timeout = timeout
if ssh_client: if ssh_client:
self.ssh_transport = ssh_client.get_transport() self.ssh_transport = ssh_client.get_transport()
@ -276,7 +277,9 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
if self.ssh_client: if self.ssh_client:
self.ssh_client.connect(**self.ssh_params) self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool: def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> SSHConnectionPool:
if not self.ssh_client: if not self.ssh_client:
return SSHConnectionPool( return SSHConnectionPool(
ssh_client=self.ssh_client, ssh_client=self.ssh_client,

View File

@ -33,7 +33,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
assert_hostname: bool | None = None, assert_hostname: bool | None = None,
**kwargs, **kwargs: t.Any,
) -> None: ) -> None:
self.assert_hostname = assert_hostname self.assert_hostname = assert_hostname
super().__init__(**kwargs) super().__init__(**kwargs)
@ -51,7 +51,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
self.poolmanager = PoolManager(**kwargs) self.poolmanager = PoolManager(**kwargs)
def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool: def get_connection(self, *args: t.Any, **kwargs: t.Any) -> urllib3.ConnectionPool:
""" """
Ensure assert_hostname is set correctly on our pool Ensure assert_hostname is set correctly on our pool

View File

@ -19,12 +19,20 @@ from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
from .basehttpadapter import BaseHTTPAdapter from .basehttpadapter import BaseHTTPAdapter
if t.TYPE_CHECKING:
from collections.abc import Mapping
from requests import PreparedRequest
from ..._socket_helper import SocketLike
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class UnixHTTPConnection(urllib3_connection.HTTPConnection): class UnixHTTPConnection(urllib3_connection.HTTPConnection):
def __init__( def __init__(
self, base_url: str | bytes, unix_socket, timeout: int | float = 60 self, base_url: str | bytes, unix_socket: str, timeout: int | float = 60
) -> None: ) -> None:
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.base_url = base_url self.base_url = base_url
@ -43,7 +51,7 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
if header == "Connection" and "Upgrade" in values: if header == "Connection" and "Upgrade" in values:
self.disable_buffering = True self.disable_buffering = True
def response_class(self, sock, *args, **kwargs) -> t.Any: def response_class(self, sock: SocketLike, *args: t.Any, **kwargs: t.Any) -> t.Any:
# FIXME: We may need to disable buffering on Py3, # FIXME: We may need to disable buffering on Py3,
# but there's no clear way to do it at the moment. See: # but there's no clear way to do it at the moment. See:
# https://github.com/docker/docker-py/issues/1799 # https://github.com/docker/docker-py/issues/1799
@ -88,12 +96,16 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
self.socket_path = socket_path self.socket_path = socket_path
self.timeout = timeout self.timeout = timeout
self.max_pool_size = max_pool_size self.max_pool_size = max_pool_size
self.pools = RecentlyUsedContainer(
pool_connections, dispose_func=lambda p: p.close() def f(p: t.Any) -> None:
) p.close()
self.pools = RecentlyUsedContainer(pool_connections, dispose_func=f)
super().__init__() super().__init__()
def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool: def get_connection(
self, url: str | bytes, proxies: Mapping[str, str] | None = None
) -> UnixHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -106,7 +118,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies) -> str: def request_url(self, request: PreparedRequest, proxies: Mapping[str, str]) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -18,7 +18,13 @@ from .._import_helper import urllib3
from ..errors import DockerException from ..errors import DockerException
class CancellableStream: if t.TYPE_CHECKING:
from requests import Response
_T = t.TypeVar("_T")
class CancellableStream(t.Generic[_T]):
""" """
Stream wrapper for real-time events, logs, etc. from the server. Stream wrapper for real-time events, logs, etc. from the server.
@ -30,14 +36,14 @@ class CancellableStream:
>>> events.close() >>> events.close()
""" """
def __init__(self, stream, response) -> None: def __init__(self, stream: t.Generator[_T], response: Response) -> None:
self._stream = stream self._stream = stream
self._response = response self._response = response
def __iter__(self) -> t.Self: def __iter__(self) -> t.Self:
return self return self
def __next__(self): def __next__(self) -> _T:
try: try:
return next(self._stream) return next(self._stream)
except urllib3.exceptions.ProtocolError as exc: except urllib3.exceptions.ProtocolError as exc:
@ -56,7 +62,7 @@ class CancellableStream:
# find the underlying socket object # find the underlying socket object
# based on api.client._get_raw_response_socket # based on api.client._get_raw_response_socket
sock_fp = self._response.raw._fp.fp sock_fp = self._response.raw._fp.fp # type: ignore
if hasattr(sock_fp, "raw"): if hasattr(sock_fp, "raw"):
sock_raw = sock_fp.raw sock_raw = sock_fp.raw
@ -74,7 +80,7 @@ class CancellableStream:
"Cancellable streams not supported for the SSH protocol" "Cancellable streams not supported for the SSH protocol"
) )
else: else:
sock = sock_fp._sock sock = sock_fp._sock # type: ignore
if hasattr(urllib3.contrib, "pyopenssl") and isinstance( if hasattr(urllib3.contrib, "pyopenssl") and isinstance(
sock, urllib3.contrib.pyopenssl.WrappedSocket sock, urllib3.contrib.pyopenssl.WrappedSocket

View File

@ -37,7 +37,7 @@ def _purge() -> None:
_cache.clear() _cache.clear()
def fnmatch(name: str, pat: str): def fnmatch(name: str, pat: str) -> bool:
"""Test whether FILENAME matches PATTERN. """Test whether FILENAME matches PATTERN.
Patterns are Unix shell style: Patterns are Unix shell style:

View File

@ -108,7 +108,7 @@ def port_range(
def split_port( def split_port(
port: str, port: str | int,
) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]: ) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]:
port = str(port) port = str(port)
match = PORT_SPEC.match(port) match = PORT_SPEC.match(port)

View File

@ -11,6 +11,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from .utils import format_environment from .utils import format_environment
@ -67,7 +69,17 @@ class ProxyConfig(dict):
env["no_proxy"] = env["NO_PROXY"] = self.no_proxy env["no_proxy"] = env["NO_PROXY"] = self.no_proxy
return env return env
def inject_proxy_environment(self, environment: list[str]) -> list[str]: @t.overload
def inject_proxy_environment(self, environment: list[str]) -> list[str]: ...
@t.overload
def inject_proxy_environment(
self, environment: list[str] | None
) -> list[str] | None: ...
def inject_proxy_environment(
self, environment: list[str] | None
) -> list[str] | None:
""" """
Given a list of strings representing environment variables, prepend the Given a list of strings representing environment variables, prepend the
environment variables corresponding to the proxy settings. environment variables corresponding to the proxy settings.

View File

@ -22,7 +22,9 @@ from ..transport.npipesocket import NpipeSocket
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from collections.abc import Iterable from collections.abc import Iterable, Sequence
from ..._socket_helper import SocketLike
STDOUT = 1 STDOUT = 1
@ -38,14 +40,14 @@ class SocketError(Exception):
NPIPE_ENDED = 109 NPIPE_ENDED = 109
def read(socket, n: int = 4096) -> bytes | None: def read(socket: SocketLike, n: int = 4096) -> bytes | None:
""" """
Reads at most n bytes from socket Reads at most n bytes from socket
""" """
recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK) recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK)
if not isinstance(socket, NpipeSocket): if not isinstance(socket, NpipeSocket): # type: ignore[unreachable]
if not hasattr(select, "poll"): if not hasattr(select, "poll"):
# Limited to 1024 # Limited to 1024
select.select([socket], [], []) select.select([socket], [], [])
@ -66,7 +68,7 @@ def read(socket, n: int = 4096) -> bytes | None:
return None # TODO ??? return None # TODO ???
except Exception as e: except Exception as e:
is_pipe_ended = ( is_pipe_ended = (
isinstance(socket, NpipeSocket) isinstance(socket, NpipeSocket) # type: ignore[unreachable]
and len(e.args) > 0 and len(e.args) > 0
and e.args[0] == NPIPE_ENDED and e.args[0] == NPIPE_ENDED
) )
@ -77,7 +79,7 @@ def read(socket, n: int = 4096) -> bytes | None:
raise raise
def read_exactly(socket, n: int) -> bytes: def read_exactly(socket: SocketLike, n: int) -> bytes:
""" """
Reads exactly n bytes from socket Reads exactly n bytes from socket
Raises SocketError if there is not enough data Raises SocketError if there is not enough data
@ -91,7 +93,7 @@ def read_exactly(socket, n: int) -> bytes:
return data return data
def next_frame_header(socket) -> tuple[int, int]: def next_frame_header(socket: SocketLike) -> tuple[int, int]:
""" """
Returns the stream and size of the next frame of data waiting to be read Returns the stream and size of the next frame of data waiting to be read
from socket, according to the protocol defined here: from socket, according to the protocol defined here:
@ -107,7 +109,7 @@ def next_frame_header(socket) -> tuple[int, int]:
return (stream, actual) return (stream, actual)
def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]: def frames_iter(socket: SocketLike, tty: bool) -> t.Generator[tuple[int, bytes]]:
""" """
Return a generator of frames read from socket. A frame is a tuple where Return a generator of frames read from socket. A frame is a tuple where
the first item is the stream number and the second item is a chunk of data. the first item is the stream number and the second item is a chunk of data.
@ -120,7 +122,7 @@ def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
return frames_iter_no_tty(socket) return frames_iter_no_tty(socket)
def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]: def frames_iter_no_tty(socket: SocketLike) -> t.Generator[tuple[int, bytes]]:
""" """
Returns a generator of data read from the socket when the tty setting is Returns a generator of data read from the socket when the tty setting is
not enabled. not enabled.
@ -141,7 +143,7 @@ def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
yield (stream, result) yield (stream, result)
def frames_iter_tty(socket) -> t.Generator[bytes]: def frames_iter_tty(socket: SocketLike) -> t.Generator[bytes]:
""" """
Return a generator of data read from the socket when the tty setting is Return a generator of data read from the socket when the tty setting is
enabled. enabled.
@ -155,20 +157,42 @@ def frames_iter_tty(socket) -> t.Generator[bytes]:
@t.overload @t.overload
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ... def consume_socket_output(
frames: Sequence[bytes] | t.Generator[bytes], demux: t.Literal[False] = False
) -> bytes: ...
@t.overload
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
@t.overload @t.overload
def consume_socket_output( def consume_socket_output(
frames, demux: bool = False frames: (
Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: t.Literal[True],
) -> tuple[bytes, bytes]: ...
@t.overload
def consume_socket_output(
frames: (
Sequence[bytes]
| Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[bytes]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: bool = False,
) -> bytes | tuple[bytes, bytes]: ... ) -> bytes | tuple[bytes, bytes]: ...
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]: def consume_socket_output(
frames: (
Sequence[bytes]
| Sequence[tuple[bytes | None, bytes | None]]
| t.Generator[bytes]
| t.Generator[tuple[bytes | None, bytes | None]]
),
demux: bool = False,
) -> bytes | tuple[bytes, bytes]:
""" """
Iterate through frames read from the socket and return the result. Iterate through frames read from the socket and return the result.
@ -183,12 +207,13 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b
if demux is False: if demux is False:
# If the streams are multiplexed, the generator returns strings, that # If the streams are multiplexed, the generator returns strings, that
# we just need to concatenate. # we just need to concatenate.
return b"".join(frames) return b"".join(frames) # type: ignore
# If the streams are demultiplexed, the generator yields tuples # If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr) # (stdout, stderr)
out: list[bytes | None] = [None, None] out: list[bytes | None] = [None, None]
for frame in frames: frame: tuple[bytes | None, bytes | None]
for frame in frames: # type: ignore
# It is guaranteed that for each frame, one and only one stream # It is guaranteed that for each frame, one and only one stream
# is not None. # is not None.
if frame == (None, None): if frame == (None, None):
@ -202,7 +227,7 @@ def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, b
if out[1] is None: if out[1] is None:
out[1] = frame[1] out[1] = frame[1]
else: else:
out[1] += frame[1] out[1] += frame[1] # type: ignore[operator]
return tuple(out) # type: ignore return tuple(out) # type: ignore

View File

@ -46,7 +46,7 @@ URLComponents = collections.namedtuple(
) )
def decode_json_header(header: str) -> dict[str, t.Any]: def decode_json_header(header: str | bytes) -> dict[str, t.Any]:
data = base64.b64decode(header).decode("utf-8") data = base64.b64decode(header).decode("utf-8")
return json.loads(data) return json.loads(data)
@ -143,7 +143,12 @@ def convert_port_bindings(
def convert_volume_binds( def convert_volume_binds(
binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int], binds: (
list[str]
| Mapping[
str | bytes, dict[str, str | bytes] | dict[str, str] | bytes | str | int
]
),
) -> list[str]: ) -> list[str]:
if isinstance(binds, list): if isinstance(binds, list):
return binds # type: ignore return binds # type: ignore
@ -403,7 +408,9 @@ def kwargs_from_env(
return params return params
def convert_filters(filters: Mapping[str, bool | str | list[str]]) -> str: def convert_filters(
filters: Mapping[str, bool | str | int | list[int] | list[str] | list[str | int]],
) -> str:
result = {} result = {}
for k, v in filters.items(): for k, v in filters.items():
if isinstance(v, bool): if isinstance(v, bool):
@ -495,8 +502,8 @@ def split_command(command: str) -> list[str]:
return shlex.split(command) return shlex.split(command)
def format_environment(environment: Mapping[str, str | bytes]) -> list[str]: def format_environment(environment: Mapping[str, str | bytes | None]) -> list[str]:
def format_env(key, value): def format_env(key: str, value: str | bytes | None) -> str:
if value is None: if value is None:
return key return key
if isinstance(value, bytes): if isinstance(value, bytes):

View File

@ -91,7 +91,7 @@ if not HAS_DOCKER_PY:
# No Docker SDK for Python. Create a place holder client to allow # No Docker SDK for Python. Create a place holder client to allow
# instantiation of AnsibleModule and proper error handing # instantiation of AnsibleModule and proper error handing
class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined class Client: # type: ignore # noqa: F811, pylint: disable=function-redefined
def __init__(self, **kwargs): def __init__(self, **kwargs: t.Any) -> None:
pass pass
class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined class APIError(Exception): # type: ignore # noqa: F811, pylint: disable=function-redefined
@ -226,7 +226,7 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg: t.Any, pretty_print: bool = False): def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
@ -609,7 +609,7 @@ class AnsibleDockerClientBase(Client):
return new_tag, old_tag == new_tag return new_tag, old_tag == new_tag
def inspect_distribution(self, image: str, **kwargs) -> dict[str, t.Any]: def inspect_distribution(self, image: str, **kwargs: t.Any) -> dict[str, t.Any]:
""" """
Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0
since prior versions did not support accessing private repositories. since prior versions did not support accessing private repositories.
@ -629,7 +629,6 @@ class AnsibleDockerClientBase(Client):
class AnsibleDockerClient(AnsibleDockerClientBase): class AnsibleDockerClient(AnsibleDockerClientBase):
def __init__( def __init__(
self, self,
argument_spec: dict[str, t.Any] | None = None, argument_spec: dict[str, t.Any] | None = None,
@ -651,7 +650,6 @@ class AnsibleDockerClient(AnsibleDockerClientBase):
option_minimal_versions_ignore_params: Sequence[str] | None = None, option_minimal_versions_ignore_params: Sequence[str] | None = None,
fail_results: dict[str, t.Any] | None = None, fail_results: dict[str, t.Any] | None = None,
): ):
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
# in case client.fail() is called. # in case client.fail() is called.
self.fail_results = fail_results or {} self.fail_results = fail_results or {}

View File

@ -146,7 +146,7 @@ class AnsibleDockerClientBase(Client):
f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}." f"Docker API version is {self.docker_api_version_str}. Minimum version required is {min_docker_api_version}."
) )
def log(self, msg: t.Any, pretty_print: bool = False): def log(self, msg: t.Any, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
@ -295,7 +295,7 @@ class AnsibleDockerClientBase(Client):
), ),
} }
def depr(*args, **kwargs): def depr(*args: t.Any, **kwargs: t.Any) -> None:
self.deprecate(*args, **kwargs) self.deprecate(*args, **kwargs)
update_tls_hostname( update_tls_hostname(

View File

@ -82,7 +82,7 @@ class AnsibleDockerClientBase:
def __init__( def __init__(
self, self,
common_args, common_args: dict[str, t.Any],
min_docker_api_version: str | None = None, min_docker_api_version: str | None = None,
needs_api_version: bool = True, needs_api_version: bool = True,
) -> None: ) -> None:
@ -91,15 +91,15 @@ class AnsibleDockerClientBase:
self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"] self._environment["DOCKER_TLS_HOSTNAME"] = common_args["tls_hostname"]
if common_args["api_version"] and common_args["api_version"] != "auto": if common_args["api_version"] and common_args["api_version"] != "auto":
self._environment["DOCKER_API_VERSION"] = common_args["api_version"] self._environment["DOCKER_API_VERSION"] = common_args["api_version"]
self._cli = common_args.get("docker_cli") cli = common_args.get("docker_cli")
if self._cli is None: if cli is None:
try: try:
self._cli = get_bin_path("docker") cli = get_bin_path("docker")
except ValueError: except ValueError:
self.fail( self.fail(
"Cannot find docker CLI in path. Please provide it explicitly with the docker_cli parameter" "Cannot find docker CLI in path. Please provide it explicitly with the docker_cli parameter"
) )
self._cli = cli
self._cli_base = [self._cli] self._cli_base = [self._cli]
docker_host = common_args["docker_host"] docker_host = common_args["docker_host"]
if not docker_host and not common_args["cli_context"]: if not docker_host and not common_args["cli_context"]:
@ -149,7 +149,7 @@ class AnsibleDockerClientBase:
"Internal error: cannot have needs_api_version=False with min_docker_api_version not None" "Internal error: cannot have needs_api_version=False with min_docker_api_version not None"
) )
def log(self, msg: str, pretty_print: bool = False): def log(self, msg: str, pretty_print: bool = False) -> None:
pass pass
# if self.debug: # if self.debug:
# from .util import log_debug # from .util import log_debug
@ -227,7 +227,7 @@ class AnsibleDockerClientBase:
return rc, result, stderr return rc, result, stderr
@abc.abstractmethod @abc.abstractmethod
def fail(self, msg: str, **kwargs) -> t.NoReturn: def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
pass pass
@abc.abstractmethod @abc.abstractmethod
@ -395,7 +395,6 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
fail_results: dict[str, t.Any] | None = None, fail_results: dict[str, t.Any] | None = None,
needs_api_version: bool = True, needs_api_version: bool = True,
) -> None: ) -> None:
# Modules can put information in here which will always be returned # Modules can put information in here which will always be returned
# in case client.fail() is called. # in case client.fail() is called.
self.fail_results = fail_results or {} self.fail_results = fail_results or {}
@ -463,7 +462,7 @@ class AnsibleModuleDockerClient(AnsibleDockerClientBase):
) )
return rc, stdout, stderr return rc, stdout, stderr
def fail(self, msg: str, **kwargs) -> t.NoReturn: def fail(self, msg: str, **kwargs: t.Any) -> t.NoReturn:
self.fail_results.update(kwargs) self.fail_results.update(kwargs)
self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) self.module.fail_json(msg=msg, **sanitize_result(self.fail_results))

View File

@ -971,7 +971,7 @@ class BaseComposeManager(DockerBaseClass):
stderr: str | bytes, stderr: str | bytes,
ignore_service_pull_events: bool = False, ignore_service_pull_events: bool = False,
ignore_build_events: bool = False, ignore_build_events: bool = False,
): ) -> None:
result["changed"] = result.get("changed", False) or has_changes( result["changed"] = result.get("changed", False) or has_changes(
events, events,
ignore_service_pull_events=ignore_service_pull_events, ignore_service_pull_events=ignore_service_pull_events,
@ -989,7 +989,7 @@ class BaseComposeManager(DockerBaseClass):
stdout: str | bytes, stdout: str | bytes,
stderr: bytes, stderr: bytes,
rc: int, rc: int,
): ) -> bool:
return update_failed( return update_failed(
result, result,
events, events,

View File

@ -330,6 +330,8 @@ def stat_file(
client._raise_for_status(response) client._raise_for_status(response)
header = response.headers.get("x-docker-container-path-stat") header = response.headers.get("x-docker-container-path-stat")
try: try:
if header is None:
raise ValueError("x-docker-container-path-stat header not present")
stat_data = json.loads(base64.b64decode(header)) stat_data = json.loads(base64.b64decode(header))
except Exception as exc: except Exception as exc:
raise DockerUnexpectedError( raise DockerUnexpectedError(
@ -482,14 +484,14 @@ def fetch_file(
shutil.copyfileobj(in_f, out_f) shutil.copyfileobj(in_f, out_f)
return in_path return in_path
def process_symlink(in_path, member) -> str: def process_symlink(in_path: str, member: tarfile.TarInfo) -> str:
if os.path.exists(b_out_path): if os.path.exists(b_out_path):
os.unlink(b_out_path) os.unlink(b_out_path)
os.symlink(member.linkname, b_out_path) os.symlink(member.linkname, b_out_path)
return in_path return in_path
def process_other(in_path, member) -> str: def process_other(in_path: str, member: tarfile.TarInfo) -> str:
raise DockerFileCopyError( raise DockerFileCopyError(
f'Remote file "{in_path}" is not a regular file or a symbolic link' f'Remote file "{in_path}" is not a regular file or a symbolic link'
) )

View File

@ -193,7 +193,9 @@ class OptionGroup:
) -> None: ) -> None:
if preprocess is None: if preprocess is None:
def preprocess(module, values): def preprocess(
module: AnsibleModule, values: dict[str, t.Any]
) -> dict[str, t.Any]:
return values return values
self.preprocess = preprocess self.preprocess = preprocess
@ -207,8 +209,8 @@ class OptionGroup:
self.ansible_required_by = ansible_required_by or {} self.ansible_required_by = ansible_required_by or {}
self.argument_spec: dict[str, t.Any] = {} self.argument_spec: dict[str, t.Any] = {}
def add_option(self, *args, **kwargs) -> OptionGroup: def add_option(self, name: str, **kwargs: t.Any) -> OptionGroup:
option = Option(*args, owner=self, **kwargs) option = Option(name, owner=self, **kwargs)
if not option.not_a_container_option: if not option.not_a_container_option:
self.options.append(option) self.options.append(option)
self.all_options.append(option) self.all_options.append(option)
@ -788,7 +790,7 @@ def _preprocess_mounts(
) -> dict[str, t.Any]: ) -> dict[str, t.Any]:
last: dict[str, str] = {} last: dict[str, str] = {}
def check_collision(t, name): def check_collision(t: str, name: str) -> None:
if t in last: if t in last:
if name == last[t]: if name == last[t]:
module.fail_json( module.fail_json(
@ -1069,7 +1071,9 @@ def _preprocess_ports(
return values return values
def _compare_platform(option: Option, param_value: t.Any, container_value: t.Any): def _compare_platform(
option: Option, param_value: t.Any, container_value: t.Any
) -> bool:
if option.comparison == "ignore": if option.comparison == "ignore":
return True return True
try: try:

View File

@ -872,7 +872,7 @@ class DockerAPIEngine(Engine[AnsibleDockerClient]):
image: dict[str, t.Any] | None, image: dict[str, t.Any] | None,
values: dict[str, t.Any], values: dict[str, t.Any],
host_info: dict[str, t.Any] | None, host_info: dict[str, t.Any] | None,
): ) -> dict[str, t.Any]:
if len(options) != 1: if len(options) != 1:
raise AssertionError( raise AssertionError(
"host_config_value can only be used for a single option" "host_config_value can only be used for a single option"
@ -1961,7 +1961,14 @@ def _update_value_restart(
} }
def _get_values_ports(module, container, api_version, options, image, host_info): def _get_values_ports(
module: AnsibleModule,
container: dict[str, t.Any],
api_version: LooseVersion,
options: list[Option],
image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None,
) -> dict[str, t.Any]:
host_config = container["HostConfig"] host_config = container["HostConfig"]
config = container["Config"] config = container["Config"]

View File

@ -292,7 +292,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
if self.module.params[param] is None: if self.module.params[param] is None:
self.module.params[param] = value self.module.params[param] = value
def fail(self, *args, **kwargs) -> t.NoReturn: def fail(self, *args: str, **kwargs: t.Any) -> t.NoReturn:
# mypy doesn't know that Client has fail() method # mypy doesn't know that Client has fail() method
raise self.client.fail(*args, **kwargs) # type: ignore raise self.client.fail(*args, **kwargs) # type: ignore
@ -714,7 +714,7 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
container_image: dict[str, t.Any] | None, container_image: dict[str, t.Any] | None,
image: dict[str, t.Any] | None, image: dict[str, t.Any] | None,
host_info: dict[str, t.Any] | None, host_info: dict[str, t.Any] | None,
): ) -> None:
assert container.raw is not None assert container.raw is not None
container_values = engine.get_value( container_values = engine.get_value(
self.module, self.module,
@ -767,12 +767,12 @@ class ContainerManager(DockerBaseClass, t.Generic[Client]):
# Since the order does not matter, sort so that the diff output is better. # Since the order does not matter, sort so that the diff output is better.
if option.name == "expected_mounts": if option.name == "expected_mounts":
# For selected values, use one entry as key # For selected values, use one entry as key
def sort_key_fn(x): def sort_key_fn(x: dict[str, t.Any]) -> t.Any:
return x["target"] return x["target"]
else: else:
# We sort the list of dictionaries by using the sorted items of a dict as its key. # We sort the list of dictionaries by using the sorted items of a dict as its key.
def sort_key_fn(x): def sort_key_fn(x: dict[str, t.Any]) -> t.Any:
return sorted( return sorted(
(a, to_text(b, errors="surrogate_or_strict")) (a, to_text(b, errors="surrogate_or_strict"))
for a, b in x.items() for a, b in x.items()

View File

@ -26,6 +26,7 @@ from ansible_collections.community.docker.plugins.module_utils._socket_helper im
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
from types import TracebackType
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
@ -70,7 +71,12 @@ class DockerSocketHandlerBase:
def __enter__(self) -> t.Self: def __enter__(self) -> t.Self:
return self return self
def __exit__(self, type_, value, tb) -> None: def __exit__(
self,
type_: t.Type[BaseException] | None,
value: BaseException | None,
tb: TracebackType | None,
) -> None:
self._selector.close() self._selector.close()
def set_block_done_callback( def set_block_done_callback(
@ -210,7 +216,7 @@ class DockerSocketHandlerBase:
stdout = [] stdout = []
stderr = [] stderr = []
def append_block(stream_id, data): def append_block(stream_id: int, data: bytes) -> None:
if stream_id == docker_socket.STDOUT: if stream_id == docker_socket.STDOUT:
stdout.append(data) stdout.append(data)
elif stream_id == docker_socket.STDERR: elif stream_id == docker_socket.STDERR:

View File

@ -28,7 +28,6 @@ from ansible_collections.community.docker.plugins.module_utils._version import (
class AnsibleDockerSwarmClient(AnsibleDockerClient): class AnsibleDockerSwarmClient(AnsibleDockerClient):
def get_swarm_node_id(self) -> str | None: def get_swarm_node_id(self) -> str | None:
""" """
Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID
@ -281,7 +280,7 @@ class AnsibleDockerSwarmClient(AnsibleDockerClient):
def get_node_name_by_id(self, nodeid: str) -> str: def get_node_name_by_id(self, nodeid: str) -> str:
return self.get_node_inspect(nodeid)["Description"]["Hostname"] return self.get_node_inspect(nodeid)["Description"]["Hostname"]
def get_unlock_key(self) -> str | None: def get_unlock_key(self) -> dict[str, t.Any] | None:
if self.docker_py_version < LooseVersion("2.7.0"): if self.docker_py_version < LooseVersion("2.7.0"):
return None return None
return super().get_unlock_key() return super().get_unlock_key()

View File

@ -23,6 +23,12 @@ if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ._common import AnsibleDockerClientBase as CADCB
from ._common_api import AnsibleDockerClientBase as CAPIADCB
from ._common_cli import AnsibleDockerClientBase as CCLIADCB
Client = t.Union[CADCB, CAPIADCB, CCLIADCB]
DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock" DEFAULT_DOCKER_HOST = "unix:///var/run/docker.sock"
DEFAULT_TLS = False DEFAULT_TLS = False
@ -119,7 +125,7 @@ def sanitize_result(data: t.Any) -> t.Any:
return data return data
def log_debug(msg: t.Any, pretty_print: bool = False): def log_debug(msg: t.Any, pretty_print: bool = False) -> None:
"""Write a log message to docker.log. """Write a log message to docker.log.
If ``pretty_print=True``, the message will be pretty-printed as JSON. If ``pretty_print=True``, the message will be pretty-printed as JSON.
@ -325,7 +331,7 @@ class DifferenceTracker:
def sanitize_labels( def sanitize_labels(
labels: dict[str, t.Any] | None, labels: dict[str, t.Any] | None,
labels_field: str, labels_field: str,
client=None, client: Client | None = None,
module: AnsibleModule | None = None, module: AnsibleModule | None = None,
) -> None: ) -> None:
def fail(msg: str) -> t.NoReturn: def fail(msg: str) -> t.NoReturn:
@ -371,7 +377,7 @@ def clean_dict_booleans_for_docker_api(
which is the expected format of filters which accept lists such as labels. which is the expected format of filters which accept lists such as labels.
""" """
def sanitize(value): def sanitize(value: t.Any) -> str:
if value is True: if value is True:
return "true" return "true"
if value is False: if value is False:

View File

@ -147,7 +147,7 @@ class PullManager(BaseComposeManager):
f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}" f"--ignore-buildable is only supported since Docker Compose 2.15.0. {self.client.get_cli()} has version {self.compose_version}"
) )
def get_pull_cmd(self, dry_run: bool): def get_pull_cmd(self, dry_run: bool) -> list[str]:
args = self.get_base_args() + ["pull"] args = self.get_base_args() + ["pull"]
if self.policy != "always": if self.policy != "always":
args.extend(["--policy", self.policy]) args.extend(["--policy", self.policy])

View File

@ -198,6 +198,7 @@ config_name:
import base64 import base64
import hashlib import hashlib
import traceback import traceback
import typing as t
try: try:
@ -220,9 +221,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class ConfigManager(DockerBaseClass): class ConfigManager(DockerBaseClass):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
def __init__(self, client, results):
super().__init__() super().__init__()
self.client = client self.client = client
@ -253,10 +252,10 @@ class ConfigManager(DockerBaseClass):
if self.rolling_versions: if self.rolling_versions:
self.version = 0 self.version = 0
self.data_key = None self.data_key: str | None = None
self.configs = [] self.configs: list[dict[str, t.Any]] = []
def __call__(self): def __call__(self) -> None:
self.get_config() self.get_config()
if self.state == "present": if self.state == "present":
self.data_key = hashlib.sha224(self.data).hexdigest() self.data_key = hashlib.sha224(self.data).hexdigest()
@ -265,7 +264,7 @@ class ConfigManager(DockerBaseClass):
elif self.state == "absent": elif self.state == "absent":
self.absent() self.absent()
def get_version(self, config): def get_version(self, config: dict[str, t.Any]) -> int:
try: try:
return int( return int(
config.get("Spec", {}).get("Labels", {}).get("ansible_version", 0) config.get("Spec", {}).get("Labels", {}).get("ansible_version", 0)
@ -273,14 +272,14 @@ class ConfigManager(DockerBaseClass):
except ValueError: except ValueError:
return 0 return 0
def remove_old_versions(self): def remove_old_versions(self) -> None:
if not self.rolling_versions or self.versions_to_keep < 0: if not self.rolling_versions or self.versions_to_keep < 0:
return return
if not self.check_mode: if not self.check_mode:
while len(self.configs) > max(self.versions_to_keep, 1): while len(self.configs) > max(self.versions_to_keep, 1):
self.remove_config(self.configs.pop(0)) self.remove_config(self.configs.pop(0))
def get_config(self): def get_config(self) -> None:
"""Find an existing config.""" """Find an existing config."""
try: try:
configs = self.client.configs(filters={"name": self.name}) configs = self.client.configs(filters={"name": self.name})
@ -299,9 +298,9 @@ class ConfigManager(DockerBaseClass):
config for config in configs if config["Spec"]["Name"] == self.name config for config in configs if config["Spec"]["Name"] == self.name
] ]
def create_config(self): def create_config(self) -> str | None:
"""Create a new config""" """Create a new config"""
config_id = None config_id: str | dict[str, t.Any] | None = None
# We ca not see the data after creation, so adding a label we can use for idempotency check # We ca not see the data after creation, so adding a label we can use for idempotency check
labels = {"ansible_key": self.data_key} labels = {"ansible_key": self.data_key}
if self.rolling_versions: if self.rolling_versions:
@ -325,18 +324,18 @@ class ConfigManager(DockerBaseClass):
self.client.fail(f"Error creating config: {exc}") self.client.fail(f"Error creating config: {exc}")
if isinstance(config_id, dict): if isinstance(config_id, dict):
config_id = config_id["ID"] return config_id["ID"]
return config_id return config_id
def remove_config(self, config): def remove_config(self, config: dict[str, t.Any]) -> None:
try: try:
if not self.check_mode: if not self.check_mode:
self.client.remove_config(config["ID"]) self.client.remove_config(config["ID"])
except APIError as exc: except APIError as exc:
self.client.fail(f"Error removing config {config['Spec']['Name']}: {exc}") self.client.fail(f"Error removing config {config['Spec']['Name']}: {exc}")
def present(self): def present(self) -> None:
"""Handles state == 'present', creating or updating the config""" """Handles state == 'present', creating or updating the config"""
if self.configs: if self.configs:
config = self.configs[-1] config = self.configs[-1]
@ -378,7 +377,7 @@ class ConfigManager(DockerBaseClass):
self.results["config_id"] = self.create_config() self.results["config_id"] = self.create_config()
self.results["config_name"] = self.name self.results["config_name"] = self.name
def absent(self): def absent(self) -> None:
"""Handles state == 'absent', removing the config""" """Handles state == 'absent', removing the config"""
if self.configs: if self.configs:
for config in self.configs: for config in self.configs:
@ -386,7 +385,7 @@ class ConfigManager(DockerBaseClass):
self.results["changed"] = True self.results["changed"] = True
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"state": { "state": {

View File

@ -347,7 +347,7 @@ def retrieve_diff(
max_file_size_for_diff: int, max_file_size_for_diff: int,
regular_stat: dict[str, t.Any] | None = None, regular_stat: dict[str, t.Any] | None = None,
link_target: str | None = None, link_target: str | None = None,
): ) -> None:
if diff is None: if diff is None:
return return
if regular_stat is not None: if regular_stat is not None:
@ -497,9 +497,9 @@ def is_file_idempotent(
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
local_follow_links: bool, local_follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int | None,
force: bool | None = False, force: bool | None = False,
diff: dict[str, t.Any] | None = None, diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -744,9 +744,9 @@ def copy_file_into_container(
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
local_follow_links: bool, local_follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int | None,
force: bool | None = False, force: bool | None = False,
do_diff: bool = False, do_diff: bool = False,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -797,9 +797,9 @@ def is_content_idempotent(
content: bytes, content: bytes,
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int,
force: bool | None = False, force: bool | None = False,
diff: dict[str, t.Any] | None = None, diff: dict[str, t.Any] | None = None,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -989,9 +989,9 @@ def copy_content_into_container(
content: bytes, content: bytes,
container_path: str, container_path: str,
follow_links: bool, follow_links: bool,
owner_id, owner_id: int,
group_id, group_id: int,
mode, mode: int,
force: bool | None = False, force: bool | None = False,
do_diff: bool = False, do_diff: bool = False,
max_file_size_for_diff: int = 1, max_file_size_for_diff: int = 1,
@ -1133,6 +1133,7 @@ def main() -> None:
owner_id, group_id = determine_user_group(client, container) owner_id, group_id = determine_user_group(client, container)
if content is not None: if content is not None:
assert mode is not None # see required_by above
copy_content_into_container( copy_content_into_container(
client, client,
container, container,

View File

@ -667,7 +667,7 @@ class ImageManager(DockerBaseClass):
:rtype: str :rtype: str
""" """
def build_msg(reason): def build_msg(reason: str) -> str:
return f"Archived image {current_image_name} to {archive_path}, {reason}" return f"Archived image {current_image_name} to {archive_path}, {reason}"
try: try:
@ -877,7 +877,7 @@ class ImageManager(DockerBaseClass):
self.push_image(repo, repo_tag) self.push_image(repo, repo_tag)
@staticmethod @staticmethod
def _extract_output_line(line: dict[str, t.Any], output: list[str]): def _extract_output_line(line: dict[str, t.Any], output: list[str]) -> None:
""" """
Extract text line from stream output and, if found, adds it to output. Extract text line from stream output and, if found, adds it to output.
""" """
@ -1165,18 +1165,18 @@ def main() -> None:
("source", "load", ["load_path"]), ("source", "load", ["load_path"]),
] ]
def detect_etc_hosts(client): def detect_etc_hosts(client: AnsibleDockerClient) -> bool:
return client.module.params["build"] and bool( return client.module.params["build"] and bool(
client.module.params["build"].get("etc_hosts") client.module.params["build"].get("etc_hosts")
) )
def detect_build_platform(client): def detect_build_platform(client: AnsibleDockerClient) -> bool:
return ( return (
client.module.params["build"] client.module.params["build"]
and client.module.params["build"].get("platform") is not None and client.module.params["build"].get("platform") is not None
) )
def detect_pull_platform(client): def detect_pull_platform(client: AnsibleDockerClient) -> bool:
return ( return (
client.module.params["pull"] client.module.params["pull"]
and client.module.params["pull"].get("platform") is not None and client.module.params["pull"].get("platform") is not None

View File

@ -379,7 +379,7 @@ def normalize_ipam_config_key(key: str) -> str:
return special_cases.get(key, key.lower()) return special_cases.get(key, key.lower())
def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]): def dicts_are_essentially_equal(a: dict[str, t.Any], b: dict[str, t.Any]) -> bool:
"""Make sure that a is a subset of b, where None entries of a are ignored.""" """Make sure that a is a subset of b, where None entries of a are ignored."""
for k, v in a.items(): for k, v in a.items():
if v is None: if v is None:

View File

@ -134,6 +134,7 @@ node:
""" """
import traceback import traceback
import typing as t
try: try:
@ -157,18 +158,19 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass): class TaskParameters(DockerBaseClass):
def __init__(self, client): hostname: str
def __init__(self, client: AnsibleDockerSwarmClient) -> None:
super().__init__() super().__init__()
# Spec # Spec
self.name = None self.labels: dict[str, t.Any] | None = None
self.labels = None self.labels_state: t.Literal["merge", "replace"] = "merge"
self.labels_state = None self.labels_to_remove: list[str] | None = None
self.labels_to_remove = None
# Node # Node
self.availability = None self.availability: t.Literal["active", "pause", "drain"] | None = None
self.role = None self.role: t.Literal["worker", "manager"] | None = None
for key, value in client.module.params.items(): for key, value in client.module.params.items():
setattr(self, key, value) setattr(self, key, value)
@ -177,9 +179,9 @@ class TaskParameters(DockerBaseClass):
class SwarmNodeManager(DockerBaseClass): class SwarmNodeManager(DockerBaseClass):
def __init__(
def __init__(self, client, results): self, client: AnsibleDockerSwarmClient, results: dict[str, t.Any]
) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
@ -192,10 +194,9 @@ class SwarmNodeManager(DockerBaseClass):
self.node_update() self.node_update()
def node_update(self): def node_update(self) -> None:
if not (self.client.check_if_swarm_node(node_id=self.parameters.hostname)): if not (self.client.check_if_swarm_node(node_id=self.parameters.hostname)):
self.client.fail("This node is not part of a swarm.") self.client.fail("This node is not part of a swarm.")
return
if self.client.check_if_swarm_node_is_down(): if self.client.check_if_swarm_node_is_down():
self.client.fail("Can not update the node. The node is down.") self.client.fail("Can not update the node. The node is down.")
@ -206,7 +207,7 @@ class SwarmNodeManager(DockerBaseClass):
self.client.fail(f"Failed to get node information for {exc}") self.client.fail(f"Failed to get node information for {exc}")
changed = False changed = False
node_spec = { node_spec: dict[str, t.Any] = {
"Availability": self.parameters.availability, "Availability": self.parameters.availability,
"Role": self.parameters.role, "Role": self.parameters.role,
"Labels": self.parameters.labels, "Labels": self.parameters.labels,
@ -277,7 +278,7 @@ class SwarmNodeManager(DockerBaseClass):
self.results["changed"] = changed self.results["changed"] = changed
def main(): def main() -> None:
argument_spec = { argument_spec = {
"hostname": {"type": "str", "required": True}, "hostname": {"type": "str", "required": True},
"labels": {"type": "dict"}, "labels": {"type": "dict"},

View File

@ -87,6 +87,7 @@ nodes:
""" """
import traceback import traceback
import typing as t
from ansible_collections.community.docker.plugins.module_utils._common import ( from ansible_collections.community.docker.plugins.module_utils._common import (
RequestException, RequestException,
@ -103,9 +104,8 @@ except ImportError:
pass pass
def get_node_facts(client): def get_node_facts(client: AnsibleDockerSwarmClient) -> list[dict[str, t.Any]]:
results: list[dict[str, t.Any]] = []
results = []
if client.module.params["self"] is True: if client.module.params["self"] is True:
self_node_id = client.get_swarm_node_id() self_node_id = client.get_swarm_node_id()
@ -114,8 +114,8 @@ def get_node_facts(client):
return results return results
if client.module.params["name"] is None: if client.module.params["name"] is None:
node_info = client.get_all_nodes_inspect() node_info_list = client.get_all_nodes_inspect()
return node_info return node_info_list
nodes = client.module.params["name"] nodes = client.module.params["name"]
if not isinstance(nodes, list): if not isinstance(nodes, list):
@ -130,7 +130,7 @@ def get_node_facts(client):
return results return results
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "list", "elements": "str"}, "name": {"type": "list", "elements": "str"},
"self": {"type": "bool", "default": False}, "self": {"type": "bool", "default": False},

View File

@ -204,12 +204,10 @@ class DockerPluginManager:
elif state == "disable": elif state == "disable":
self.disable() self.disable()
if self.diff or self.check_mode or self.parameters.debug: if self.diff:
if self.diff: self.diff_result["before"], self.diff_result["after"] = (
self.diff_result["before"], self.diff_result["after"] = ( self.diff_tracker.get_before_after()
self.diff_tracker.get_before_after() )
)
self.diff = self.diff_result
def get_existing_plugin(self) -> dict[str, t.Any] | None: def get_existing_plugin(self) -> dict[str, t.Any] | None:
try: try:
@ -409,7 +407,7 @@ class DockerPluginManager:
result: dict[str, t.Any] = { result: dict[str, t.Any] = {
"actions": self.actions, "actions": self.actions,
"changed": self.changed, "changed": self.changed,
"diff": self.diff, "diff": self.diff_result,
"plugin": plugin_data, "plugin": plugin_data,
} }
if ( if (

View File

@ -190,6 +190,7 @@ secret_name:
import base64 import base64
import hashlib import hashlib
import traceback import traceback
import typing as t
try: try:
@ -212,9 +213,7 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class SecretManager(DockerBaseClass): class SecretManager(DockerBaseClass):
def __init__(self, client: AnsibleDockerClient, results: dict[str, t.Any]) -> None:
def __init__(self, client, results):
super().__init__() super().__init__()
self.client = client self.client = client
@ -244,10 +243,10 @@ class SecretManager(DockerBaseClass):
if self.rolling_versions: if self.rolling_versions:
self.version = 0 self.version = 0
self.data_key = None self.data_key: str | None = None
self.secrets = [] self.secrets: list[dict[str, t.Any]] = []
def __call__(self): def __call__(self) -> None:
self.get_secret() self.get_secret()
if self.state == "present": if self.state == "present":
self.data_key = hashlib.sha224(self.data).hexdigest() self.data_key = hashlib.sha224(self.data).hexdigest()
@ -256,7 +255,7 @@ class SecretManager(DockerBaseClass):
elif self.state == "absent": elif self.state == "absent":
self.absent() self.absent()
def get_version(self, secret): def get_version(self, secret: dict[str, t.Any]) -> int:
try: try:
return int( return int(
secret.get("Spec", {}).get("Labels", {}).get("ansible_version", 0) secret.get("Spec", {}).get("Labels", {}).get("ansible_version", 0)
@ -264,14 +263,14 @@ class SecretManager(DockerBaseClass):
except ValueError: except ValueError:
return 0 return 0
def remove_old_versions(self): def remove_old_versions(self) -> None:
if not self.rolling_versions or self.versions_to_keep < 0: if not self.rolling_versions or self.versions_to_keep < 0:
return return
if not self.check_mode: if not self.check_mode:
while len(self.secrets) > max(self.versions_to_keep, 1): while len(self.secrets) > max(self.versions_to_keep, 1):
self.remove_secret(self.secrets.pop(0)) self.remove_secret(self.secrets.pop(0))
def get_secret(self): def get_secret(self) -> None:
"""Find an existing secret.""" """Find an existing secret."""
try: try:
secrets = self.client.secrets(filters={"name": self.name}) secrets = self.client.secrets(filters={"name": self.name})
@ -290,9 +289,9 @@ class SecretManager(DockerBaseClass):
secret for secret in secrets if secret["Spec"]["Name"] == self.name secret for secret in secrets if secret["Spec"]["Name"] == self.name
] ]
def create_secret(self): def create_secret(self) -> str | None:
"""Create a new secret""" """Create a new secret"""
secret_id = None secret_id: str | dict[str, t.Any] | None = None
# We cannot see the data after creation, so adding a label we can use for idempotency check # We cannot see the data after creation, so adding a label we can use for idempotency check
labels = {"ansible_key": self.data_key} labels = {"ansible_key": self.data_key}
if self.rolling_versions: if self.rolling_versions:
@ -312,18 +311,18 @@ class SecretManager(DockerBaseClass):
self.client.fail(f"Error creating secret: {exc}") self.client.fail(f"Error creating secret: {exc}")
if isinstance(secret_id, dict): if isinstance(secret_id, dict):
secret_id = secret_id["ID"] return secret_id["ID"]
return secret_id return secret_id
def remove_secret(self, secret): def remove_secret(self, secret: dict[str, t.Any]) -> None:
try: try:
if not self.check_mode: if not self.check_mode:
self.client.remove_secret(secret["ID"]) self.client.remove_secret(secret["ID"])
except APIError as exc: except APIError as exc:
self.client.fail(f"Error removing secret {secret['Spec']['Name']}: {exc}") self.client.fail(f"Error removing secret {secret['Spec']['Name']}: {exc}")
def present(self): def present(self) -> None:
"""Handles state == 'present', creating or updating the secret""" """Handles state == 'present', creating or updating the secret"""
if self.secrets: if self.secrets:
secret = self.secrets[-1] secret = self.secrets[-1]
@ -357,7 +356,7 @@ class SecretManager(DockerBaseClass):
self.results["secret_id"] = self.create_secret() self.results["secret_id"] = self.create_secret()
self.results["secret_name"] = self.name self.results["secret_name"] = self.name
def absent(self): def absent(self) -> None:
"""Handles state == 'absent', removing the secret""" """Handles state == 'absent', removing the secret"""
if self.secrets: if self.secrets:
for secret in self.secrets: for secret in self.secrets:
@ -365,7 +364,7 @@ class SecretManager(DockerBaseClass):
self.results["changed"] = True self.results["changed"] = True
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"state": { "state": {

View File

@ -158,6 +158,7 @@ import json
import os import os
import tempfile import tempfile
import traceback import traceback
import typing as t
from time import sleep from time import sleep
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
@ -183,7 +184,9 @@ except ImportError:
HAS_YAML = False HAS_YAML = False
def docker_stack_services(client, stack_name): def docker_stack_services(
client: AnsibleModuleDockerClient, stack_name: str
) -> list[str]:
dummy_rc, out, err = client.call_cli( dummy_rc, out, err = client.call_cli(
"stack", "services", stack_name, "--format", "{{.Name}}" "stack", "services", stack_name, "--format", "{{.Name}}"
) )
@ -192,7 +195,9 @@ def docker_stack_services(client, stack_name):
return to_native(out).strip().split("\n") return to_native(out).strip().split("\n")
def docker_service_inspect(client, service_name): def docker_service_inspect(
client: AnsibleModuleDockerClient, service_name: str
) -> dict[str, t.Any] | None:
rc, out, dummy_err = client.call_cli("service", "inspect", service_name) rc, out, dummy_err = client.call_cli("service", "inspect", service_name)
if rc != 0: if rc != 0:
return None return None
@ -200,7 +205,9 @@ def docker_service_inspect(client, service_name):
return ret return ret
def docker_stack_deploy(client, stack_name, compose_files): def docker_stack_deploy(
client: AnsibleModuleDockerClient, stack_name: str, compose_files: list[str]
) -> tuple[int, str, str]:
command = ["stack", "deploy"] command = ["stack", "deploy"]
if client.module.params["prune"]: if client.module.params["prune"]:
command += ["--prune"] command += ["--prune"]
@ -217,14 +224,21 @@ def docker_stack_deploy(client, stack_name, compose_files):
return rc, to_native(out), to_native(err) return rc, to_native(out), to_native(err)
def docker_stack_inspect(client, stack_name): def docker_stack_inspect(
ret = {} client: AnsibleModuleDockerClient, stack_name: str
) -> dict[str, dict[str, t.Any] | None]:
ret: dict[str, dict[str, t.Any] | None] = {}
for service_name in docker_stack_services(client, stack_name): for service_name in docker_stack_services(client, stack_name):
ret[service_name] = docker_service_inspect(client, service_name) ret[service_name] = docker_service_inspect(client, service_name)
return ret return ret
def docker_stack_rm(client, stack_name, retries, interval): def docker_stack_rm(
client: AnsibleModuleDockerClient,
stack_name: str,
retries: int,
interval: int | float,
) -> tuple[int, str, str]:
command = ["stack", "rm", stack_name] command = ["stack", "rm", stack_name]
if not client.module.params["detach"]: if not client.module.params["detach"]:
command += ["--detach=false"] command += ["--detach=false"]
@ -237,7 +251,7 @@ def docker_stack_rm(client, stack_name, retries, interval):
return rc, to_native(out), to_native(err) return rc, to_native(out), to_native(err)
def main(): def main() -> None:
client = AnsibleModuleDockerClient( client = AnsibleModuleDockerClient(
argument_spec={ argument_spec={
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
@ -258,10 +272,10 @@ def main():
) )
if not HAS_JSONDIFF: if not HAS_JSONDIFF:
return client.fail("jsondiff is not installed, try 'pip install jsondiff'") client.fail("jsondiff is not installed, try 'pip install jsondiff'")
if not HAS_YAML: if not HAS_YAML:
return client.fail("yaml is not installed, try 'pip install pyyaml'") client.fail("yaml is not installed, try 'pip install pyyaml'")
try: try:
state = client.module.params["state"] state = client.module.params["state"]

View File

@ -85,16 +85,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_cli impor
) )
def docker_stack_list(module): def main() -> None:
docker_bin = module.get_bin_path("docker", required=True)
rc, out, err = module.run_command(
[docker_bin, "stack", "ls", "--format={{json .}}"]
)
return rc, out.strip(), err.strip()
def main():
client = AnsibleModuleDockerClient( client = AnsibleModuleDockerClient(
argument_spec={}, argument_spec={},
supports_check_mode=True, supports_check_mode=True,

View File

@ -93,16 +93,7 @@ from ansible_collections.community.docker.plugins.module_utils._common_cli impor
) )
def docker_stack_task(module, stack_name): def main() -> None:
docker_bin = module.get_bin_path("docker", required=True)
rc, out, err = module.run_command(
[docker_bin, "stack", "ps", stack_name, "--format={{json .}}"]
)
return rc, out.strip(), err.strip()
def main():
client = AnsibleModuleDockerClient( client = AnsibleModuleDockerClient(
argument_spec={"name": {"type": "str", "required": True}}, argument_spec={"name": {"type": "str", "required": True}},
supports_check_mode=True, supports_check_mode=True,

View File

@ -292,6 +292,7 @@ actions:
import json import json
import traceback import traceback
import typing as t
try: try:
@ -314,40 +315,40 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class TaskParameters(DockerBaseClass): class TaskParameters(DockerBaseClass):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.advertise_addr = None self.advertise_addr: str | None = None
self.listen_addr = None self.listen_addr: str | None = None
self.remote_addrs = None self.remote_addrs: list[str] | None = None
self.join_token = None self.join_token: str | None = None
self.data_path_addr = None self.data_path_addr: str | None = None
self.data_path_port = None self.data_path_port: int | None = None
self.spec = None self.spec = None
# Spec # Spec
self.snapshot_interval = None self.snapshot_interval: int | None = None
self.task_history_retention_limit = None self.task_history_retention_limit: int | None = None
self.keep_old_snapshots = None self.keep_old_snapshots: int | None = None
self.log_entries_for_slow_followers = None self.log_entries_for_slow_followers: int | None = None
self.heartbeat_tick = None self.heartbeat_tick: int | None = None
self.election_tick = None self.election_tick: int | None = None
self.dispatcher_heartbeat_period = None self.dispatcher_heartbeat_period: int | None = None
self.node_cert_expiry = None self.node_cert_expiry: int | None = None
self.name = None self.name: str | None = None
self.labels = None self.labels: dict[str, t.Any] | None = None
self.log_driver = None self.log_driver = None
self.signing_ca_cert = None self.signing_ca_cert: str | None = None
self.signing_ca_key = None self.signing_ca_key: str | None = None
self.ca_force_rotate = None self.ca_force_rotate: int | None = None
self.autolock_managers = None self.autolock_managers: bool | None = None
self.rotate_worker_token = None self.rotate_worker_token: bool | None = None
self.rotate_manager_token = None self.rotate_manager_token: bool | None = None
self.default_addr_pool = None self.default_addr_pool: list[str] | None = None
self.subnet_size = None self.subnet_size: int | None = None
@staticmethod @staticmethod
def from_ansible_params(client): def from_ansible_params(client: AnsibleDockerSwarmClient) -> TaskParameters:
result = TaskParameters() result = TaskParameters()
for key, value in client.module.params.items(): for key, value in client.module.params.items():
if key in result.__dict__: if key in result.__dict__:
@ -356,7 +357,7 @@ class TaskParameters(DockerBaseClass):
result.update_parameters(client) result.update_parameters(client)
return result return result
def update_from_swarm_info(self, swarm_info): def update_from_swarm_info(self, swarm_info: dict[str, t.Any]) -> None:
spec = swarm_info["Spec"] spec = swarm_info["Spec"]
ca_config = spec.get("CAConfig") or {} ca_config = spec.get("CAConfig") or {}
@ -400,7 +401,7 @@ class TaskParameters(DockerBaseClass):
if "LogDriver" in spec["TaskDefaults"]: if "LogDriver" in spec["TaskDefaults"]:
self.log_driver = spec["TaskDefaults"]["LogDriver"] self.log_driver = spec["TaskDefaults"]["LogDriver"]
def update_parameters(self, client): def update_parameters(self, client: AnsibleDockerSwarmClient) -> None:
assign = { assign = {
"snapshot_interval": "snapshot_interval", "snapshot_interval": "snapshot_interval",
"task_history_retention_limit": "task_history_retention_limit", "task_history_retention_limit": "task_history_retention_limit",
@ -427,7 +428,12 @@ class TaskParameters(DockerBaseClass):
params[dest] = value params[dest] = value
self.spec = client.create_swarm_spec(**params) self.spec = client.create_swarm_spec(**params)
def compare_to_active(self, other, client, differences): def compare_to_active(
self,
other: TaskParameters,
client: AnsibleDockerSwarmClient,
differences: DifferenceTracker,
) -> DifferenceTracker:
for k in self.__dict__: for k in self.__dict__:
if k in ( if k in (
"advertise_addr", "advertise_addr",
@ -459,26 +465,28 @@ class TaskParameters(DockerBaseClass):
class SwarmManager(DockerBaseClass): class SwarmManager(DockerBaseClass):
def __init__(
def __init__(self, client, results): self, client: AnsibleDockerSwarmClient, results: dict[str, t.Any]
) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.results = results self.results = results
self.check_mode = self.client.check_mode self.check_mode = self.client.check_mode
self.swarm_info = {} self.swarm_info: dict[str, t.Any] = {}
self.state = client.module.params["state"] self.state: t.Literal["present", "join", "absent", "remove"] = (
self.force = client.module.params["force"] client.module.params["state"]
self.node_id = client.module.params["node_id"] )
self.force: bool = client.module.params["force"]
self.node_id: str | None = client.module.params["node_id"]
self.differences = DifferenceTracker() self.differences = DifferenceTracker()
self.parameters = TaskParameters.from_ansible_params(client) self.parameters = TaskParameters.from_ansible_params(client)
self.created = False self.created = False
def __call__(self): def __call__(self) -> None:
choice_map = { choice_map = {
"present": self.init_swarm, "present": self.init_swarm,
"join": self.join, "join": self.join,
@ -486,14 +494,14 @@ class SwarmManager(DockerBaseClass):
"remove": self.remove, "remove": self.remove,
} }
choice_map.get(self.state)() choice_map[self.state]()
if self.client.module._diff or self.parameters.debug: if self.client.module._diff or self.parameters.debug:
diff = {} diff = {}
diff["before"], diff["after"] = self.differences.get_before_after() diff["before"], diff["after"] = self.differences.get_before_after()
self.results["diff"] = diff self.results["diff"] = diff
def inspect_swarm(self): def inspect_swarm(self) -> None:
try: try:
data = self.client.inspect_swarm() data = self.client.inspect_swarm()
json_str = json.dumps(data, ensure_ascii=False) json_str = json.dumps(data, ensure_ascii=False)
@ -507,7 +515,7 @@ class SwarmManager(DockerBaseClass):
except APIError: except APIError:
pass pass
def get_unlock_key(self): def get_unlock_key(self) -> dict[str, t.Any]:
default = {"UnlockKey": None} default = {"UnlockKey": None}
if not self.has_swarm_lock_changed(): if not self.has_swarm_lock_changed():
return default return default
@ -516,18 +524,18 @@ class SwarmManager(DockerBaseClass):
except APIError: except APIError:
return default return default
def has_swarm_lock_changed(self): def has_swarm_lock_changed(self) -> bool:
return self.parameters.autolock_managers and ( return bool(self.parameters.autolock_managers) and (
self.created or self.differences.has_difference_for("autolock_managers") self.created or self.differences.has_difference_for("autolock_managers")
) )
def init_swarm(self): def init_swarm(self) -> None:
if not self.force and self.client.check_if_swarm_manager(): if not self.force and self.client.check_if_swarm_manager():
self.__update_swarm() self.__update_swarm()
return return
if not self.check_mode: if not self.check_mode:
init_arguments = { init_arguments: dict[str, t.Any] = {
"advertise_addr": self.parameters.advertise_addr, "advertise_addr": self.parameters.advertise_addr,
"listen_addr": self.parameters.listen_addr, "listen_addr": self.parameters.listen_addr,
"force_new_cluster": self.force, "force_new_cluster": self.force,
@ -562,7 +570,7 @@ class SwarmManager(DockerBaseClass):
"UnlockKey": self.swarm_info.get("UnlockKey"), "UnlockKey": self.swarm_info.get("UnlockKey"),
} }
def __update_swarm(self): def __update_swarm(self) -> None:
try: try:
self.inspect_swarm() self.inspect_swarm()
version = self.swarm_info["Version"]["Index"] version = self.swarm_info["Version"]["Index"]
@ -587,13 +595,12 @@ class SwarmManager(DockerBaseClass):
) )
except APIError as exc: except APIError as exc:
self.client.fail(f"Can not update a Swarm Cluster: {exc}") self.client.fail(f"Can not update a Swarm Cluster: {exc}")
return
self.inspect_swarm() self.inspect_swarm()
self.results["actions"].append("Swarm cluster updated") self.results["actions"].append("Swarm cluster updated")
self.results["changed"] = True self.results["changed"] = True
def join(self): def join(self) -> None:
if self.client.check_if_swarm_node(): if self.client.check_if_swarm_node():
self.results["actions"].append("This node is already part of a swarm.") self.results["actions"].append("This node is already part of a swarm.")
return return
@ -614,7 +621,7 @@ class SwarmManager(DockerBaseClass):
self.differences.add("joined", parameter=True, active=False) self.differences.add("joined", parameter=True, active=False)
self.results["changed"] = True self.results["changed"] = True
def leave(self): def leave(self) -> None:
if not self.client.check_if_swarm_node(): if not self.client.check_if_swarm_node():
self.results["actions"].append("This node is not part of a swarm.") self.results["actions"].append("This node is not part of a swarm.")
return return
@ -627,7 +634,7 @@ class SwarmManager(DockerBaseClass):
self.differences.add("joined", parameter="absent", active="present") self.differences.add("joined", parameter="absent", active="present")
self.results["changed"] = True self.results["changed"] = True
def remove(self): def remove(self) -> None:
if not self.client.check_if_swarm_manager(): if not self.client.check_if_swarm_manager():
self.client.fail("This node is not a manager.") self.client.fail("This node is not a manager.")
@ -655,11 +662,12 @@ class SwarmManager(DockerBaseClass):
self.results["changed"] = True self.results["changed"] = True
def _detect_remove_operation(client): def _detect_remove_operation(client: AnsibleDockerSwarmClient) -> bool:
return client.module.params["state"] == "remove" return client.module.params["state"] == "remove"
def main(): def main() -> None:
# TODO: missing option log_driver?
argument_spec = { argument_spec = {
"advertise_addr": {"type": "str"}, "advertise_addr": {"type": "str"},
"data_path_addr": {"type": "str"}, "data_path_addr": {"type": "str"},

View File

@ -186,6 +186,7 @@ tasks:
""" """
import traceback import traceback
import typing as t
try: try:
@ -207,16 +208,20 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
class DockerSwarmManager(DockerBaseClass): class DockerSwarmManager(DockerBaseClass):
def __init__(
def __init__(self, client, results): self, client: AnsibleDockerSwarmClient, results: dict[str, t.Any]
) -> None:
super().__init__() super().__init__()
self.client = client self.client = client
self.results = results self.results = results
self.verbose_output = self.client.module.params["verbose_output"] self.verbose_output = self.client.module.params["verbose_output"]
listed_objects = ["tasks", "services", "nodes"] listed_objects: list[t.Literal["nodes", "tasks", "services"]] = [
"tasks",
"services",
"nodes",
]
self.client.fail_task_if_not_swarm_manager() self.client.fail_task_if_not_swarm_manager()
@ -235,15 +240,18 @@ class DockerSwarmManager(DockerBaseClass):
if self.client.module.params["unlock_key"]: if self.client.module.params["unlock_key"]:
self.results["swarm_unlock_key"] = self.get_docker_swarm_unlock_key() self.results["swarm_unlock_key"] = self.get_docker_swarm_unlock_key()
def get_docker_swarm_facts(self): def get_docker_swarm_facts(self) -> dict[str, t.Any]:
try: try:
return self.client.inspect_swarm() return self.client.inspect_swarm()
except APIError as exc: except APIError as exc:
self.client.fail(f"Error inspecting docker swarm: {exc}") self.client.fail(f"Error inspecting docker swarm: {exc}")
def get_docker_items_list(self, docker_object=None, filters=None): def get_docker_items_list(
items = None self,
items_list = [] docker_object: t.Literal["nodes", "tasks", "services"],
filters: dict[str, str],
) -> list[dict[str, t.Any]]:
items_list: list[dict[str, t.Any]] = []
try: try:
if docker_object == "nodes": if docker_object == "nodes":
@ -252,6 +260,8 @@ class DockerSwarmManager(DockerBaseClass):
items = self.client.tasks(filters=filters) items = self.client.tasks(filters=filters)
elif docker_object == "services": elif docker_object == "services":
items = self.client.services(filters=filters) items = self.client.services(filters=filters)
else:
raise ValueError(f"Invalid docker_object {docker_object}")
except APIError as exc: except APIError as exc:
self.client.fail( self.client.fail(
f"Error inspecting docker swarm for object '{docker_object}': {exc}" f"Error inspecting docker swarm for object '{docker_object}': {exc}"
@ -276,7 +286,7 @@ class DockerSwarmManager(DockerBaseClass):
return items_list return items_list
@staticmethod @staticmethod
def get_essential_facts_nodes(item): def get_essential_facts_nodes(item: dict[str, t.Any]) -> dict[str, t.Any]:
object_essentials = {} object_essentials = {}
object_essentials["ID"] = item.get("ID") object_essentials["ID"] = item.get("ID")
@ -298,7 +308,7 @@ class DockerSwarmManager(DockerBaseClass):
return object_essentials return object_essentials
def get_essential_facts_tasks(self, item): def get_essential_facts_tasks(self, item: dict[str, t.Any]) -> dict[str, t.Any]:
object_essentials = {} object_essentials = {}
object_essentials["ID"] = item["ID"] object_essentials["ID"] = item["ID"]
@ -319,7 +329,7 @@ class DockerSwarmManager(DockerBaseClass):
return object_essentials return object_essentials
@staticmethod @staticmethod
def get_essential_facts_services(item): def get_essential_facts_services(item: dict[str, t.Any]) -> dict[str, t.Any]:
object_essentials = {} object_essentials = {}
object_essentials["ID"] = item["ID"] object_essentials["ID"] = item["ID"]
@ -343,12 +353,12 @@ class DockerSwarmManager(DockerBaseClass):
return object_essentials return object_essentials
def get_docker_swarm_unlock_key(self): def get_docker_swarm_unlock_key(self) -> str | None:
unlock_key = self.client.get_unlock_key() or {} unlock_key = self.client.get_unlock_key() or {}
return unlock_key.get("UnlockKey") or None return unlock_key.get("UnlockKey") or None
def main(): def main() -> None:
argument_spec = { argument_spec = {
"nodes": {"type": "bool", "default": False}, "nodes": {"type": "bool", "default": False},
"nodes_filters": {"type": "dict"}, "nodes_filters": {"type": "dict"},

View File

@ -853,6 +853,7 @@ EXAMPLES = r"""
import shlex import shlex
import time import time
import traceback import traceback
import typing as t
from ansible.module_utils.basic import human_to_bytes from ansible.module_utils.basic import human_to_bytes
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
@ -891,7 +892,9 @@ except ImportError:
pass pass
def get_docker_environment(env, env_files): def get_docker_environment(
env: str | dict[str, t.Any] | list[t.Any] | None, env_files: list[str] | None
) -> list[str] | None:
""" """
Will return a list of "KEY=VALUE" items. Supplied env variable can Will return a list of "KEY=VALUE" items. Supplied env variable can
be either a list or a dictionary. be either a list or a dictionary.
@ -899,7 +902,7 @@ def get_docker_environment(env, env_files):
If environment files are combined with explicit environment variables, If environment files are combined with explicit environment variables,
the explicit environment variables take precedence. the explicit environment variables take precedence.
""" """
env_dict = {} env_dict: dict[str, str] = {}
if env_files: if env_files:
for env_file in env_files: for env_file in env_files:
parsed_env_file = parse_env_file(env_file) parsed_env_file = parse_env_file(env_file)
@ -936,7 +939,21 @@ def get_docker_environment(env, env_files):
return sorted(env_list) return sorted(env_list)
def get_docker_networks(networks, network_ids): @t.overload
def get_docker_networks(
networks: list[str | dict[str, t.Any]], network_ids: dict[str, str]
) -> list[dict[str, t.Any]]: ...
@t.overload
def get_docker_networks(
networks: list[str | dict[str, t.Any]] | None, network_ids: dict[str, str]
) -> list[dict[str, t.Any]] | None: ...
def get_docker_networks(
networks: list[str | dict[str, t.Any]] | None, network_ids: dict[str, str]
) -> list[dict[str, t.Any]] | None:
""" """
Validate a list of network names or a list of network dictionaries. Validate a list of network names or a list of network dictionaries.
Network names will be resolved to ids by using the network_ids mapping. Network names will be resolved to ids by using the network_ids mapping.
@ -945,6 +962,7 @@ def get_docker_networks(networks, network_ids):
return None return None
parsed_networks = [] parsed_networks = []
for network in networks: for network in networks:
parsed_network: dict[str, t.Any]
if isinstance(network, str): if isinstance(network, str):
parsed_network = {"name": network} parsed_network = {"name": network}
elif isinstance(network, dict): elif isinstance(network, dict):
@ -988,7 +1006,7 @@ def get_docker_networks(networks, network_ids):
return parsed_networks or [] return parsed_networks or []
def get_nanoseconds_from_raw_option(name, value): def get_nanoseconds_from_raw_option(name: str, value: t.Any) -> int | None:
if value is None: if value is None:
return None return None
if isinstance(value, int): if isinstance(value, int):
@ -1003,12 +1021,14 @@ def get_nanoseconds_from_raw_option(name, value):
) )
def get_value(key, values, default=None): def get_value(key: str, values: dict[str, t.Any], default: t.Any = None) -> t.Any:
value = values.get(key) value = values.get(key)
return value if value is not None else default return value if value is not None else default
def has_dict_changed(new_dict, old_dict): def has_dict_changed(
new_dict: dict[str, t.Any] | None, old_dict: dict[str, t.Any] | None
) -> bool:
""" """
Check if new_dict has differences compared to old_dict while Check if new_dict has differences compared to old_dict while
ignoring keys in old_dict which are None in new_dict. ignoring keys in old_dict which are None in new_dict.
@ -1019,6 +1039,9 @@ def has_dict_changed(new_dict, old_dict):
return True return True
if not old_dict and new_dict: if not old_dict and new_dict:
return True return True
if old_dict is None:
# in this case new_dict is empty, only the type checker didn't notice
return False
defined_options = { defined_options = {
option: value for option, value in new_dict.items() if value is not None option: value for option, value in new_dict.items() if value is not None
} }
@ -1031,12 +1054,17 @@ def has_dict_changed(new_dict, old_dict):
return False return False
def has_list_changed(new_list, old_list, sort_lists=True, sort_key=None): def has_list_changed(
new_list: list[t.Any] | None,
old_list: list[t.Any] | None,
sort_lists: bool = True,
sort_key: str | None = None,
) -> bool:
""" """
Check two lists have differences. Sort lists by default. Check two lists have differences. Sort lists by default.
""" """
def sort_list(unsorted_list): def sort_list(unsorted_list: list[t.Any]) -> list[t.Any]:
""" """
Sort a given list. Sort a given list.
The list may contain dictionaries, so use the sort key to handle them. The list may contain dictionaries, so use the sort key to handle them.
@ -1093,7 +1121,10 @@ def has_list_changed(new_list, old_list, sort_lists=True, sort_key=None):
return False return False
def have_networks_changed(new_networks, old_networks): def have_networks_changed(
new_networks: list[dict[str, t.Any]] | None,
old_networks: list[dict[str, t.Any]] | None,
) -> bool:
"""Special case list checking for networks to sort aliases""" """Special case list checking for networks to sort aliases"""
if new_networks is None: if new_networks is None:
@ -1123,68 +1154,72 @@ def have_networks_changed(new_networks, old_networks):
class DockerService(DockerBaseClass): class DockerService(DockerBaseClass):
def __init__(self, docker_api_version, docker_py_version): def __init__(
self, docker_api_version: LooseVersion, docker_py_version: LooseVersion
) -> None:
super().__init__() super().__init__()
self.image = "" self.image: str | None = ""
self.command = None self.command: t.Any = None
self.args = None self.args: list[str] | None = None
self.endpoint_mode = None self.endpoint_mode: t.Literal["vip", "dnsrr"] | None = None
self.dns = None self.dns: list[str] | None = None
self.healthcheck = None self.healthcheck: dict[str, t.Any] | None = None
self.healthcheck_disabled = None self.healthcheck_disabled: bool | None = None
self.hostname = None self.hostname: str | None = None
self.hosts = None self.hosts: dict[str, t.Any] | None = None
self.tty = None self.tty: bool | None = None
self.dns_search = None self.dns_search: list[str] | None = None
self.dns_options = None self.dns_options: list[str] | None = None
self.env = None self.env: t.Any = None
self.force_update = None self.force_update: int | None = None
self.groups = None self.groups: list[str] | None = None
self.log_driver = None self.log_driver: str | None = None
self.log_driver_options = None self.log_driver_options: dict[str, t.Any] | None = None
self.labels = None self.labels: dict[str, t.Any] | None = None
self.container_labels = None self.container_labels: dict[str, t.Any] | None = None
self.sysctls = None self.sysctls: dict[str, t.Any] | None = None
self.limit_cpu = None self.limit_cpu: float | None = None
self.limit_memory = None self.limit_memory: int | None = None
self.reserve_cpu = None self.reserve_cpu: float | None = None
self.reserve_memory = None self.reserve_memory: int | None = None
self.mode = "replicated" self.mode: t.Literal["replicated", "global", "replicated-job"] = "replicated"
self.user = None self.user: str | None = None
self.mounts = None self.mounts: list[dict[str, t.Any]] | None = None
self.configs = None self.configs: list[dict[str, t.Any]] | None = None
self.secrets = None self.secrets: list[dict[str, t.Any]] | None = None
self.constraints = None self.constraints: list[str] | None = None
self.replicas_max_per_node = None self.replicas_max_per_node: int | None = None
self.networks = None self.networks: list[t.Any] | None = None
self.stop_grace_period = None self.stop_grace_period: int | None = None
self.stop_signal = None self.stop_signal: str | None = None
self.publish = None self.publish: list[dict[str, t.Any]] | None = None
self.placement_preferences = None self.placement_preferences: list[dict[str, t.Any]] | None = None
self.replicas = -1 self.replicas: int | None = -1
self.service_id = False self.service_id = False
self.service_version = False self.service_version = False
self.read_only = None self.read_only: bool | None = None
self.restart_policy = None self.restart_policy: t.Literal["none", "on-failure", "any"] | None = None
self.restart_policy_attempts = None self.restart_policy_attempts: int | None = None
self.restart_policy_delay = None self.restart_policy_delay: str | None = None
self.restart_policy_window = None self.restart_policy_window: str | None = None
self.rollback_config = None self.rollback_config: dict[str, t.Any] | None = None
self.update_delay = None self.update_delay: str | None = None
self.update_parallelism = None self.update_parallelism: int | None = None
self.update_failure_action = None self.update_failure_action: (
self.update_monitor = None t.Literal["continue", "pause", "rollback"] | None
self.update_max_failure_ratio = None ) = None
self.update_order = None self.update_monitor: str | None = None
self.working_dir = None self.update_max_failure_ratio: float | None = None
self.init = None self.update_order: str | None = None
self.cap_add = None self.working_dir: str | None = None
self.cap_drop = None self.init: bool | None = None
self.cap_add: list[str] | None = None
self.cap_drop: list[str] | None = None
self.docker_api_version = docker_api_version self.docker_api_version = docker_api_version
self.docker_py_version = docker_py_version self.docker_py_version = docker_py_version
def get_facts(self): def get_facts(self) -> dict[str, t.Any]:
return { return {
"image": self.image, "image": self.image,
"mounts": self.mounts, "mounts": self.mounts,
@ -1242,19 +1277,21 @@ class DockerService(DockerBaseClass):
} }
@property @property
def can_update_networks(self): def can_update_networks(self) -> bool:
# Before Docker API 1.29 adding/removing networks was not supported # Before Docker API 1.29 adding/removing networks was not supported
return self.docker_api_version >= LooseVersion( return self.docker_api_version >= LooseVersion(
"1.29" "1.29"
) and self.docker_py_version >= LooseVersion("2.7") ) and self.docker_py_version >= LooseVersion("2.7")
@property @property
def can_use_task_template_networks(self): def can_use_task_template_networks(self) -> bool:
# In Docker API 1.25 attaching networks to TaskTemplate is preferred over Spec # In Docker API 1.25 attaching networks to TaskTemplate is preferred over Spec
return self.docker_py_version >= LooseVersion("2.7") return self.docker_py_version >= LooseVersion("2.7")
@staticmethod @staticmethod
def get_restart_config_from_ansible_params(params): def get_restart_config_from_ansible_params(
params: dict[str, t.Any],
) -> dict[str, t.Any]:
restart_config = params["restart_config"] or {} restart_config = params["restart_config"] or {}
condition = get_value( condition = get_value(
"condition", "condition",
@ -1282,7 +1319,9 @@ class DockerService(DockerBaseClass):
} }
@staticmethod @staticmethod
def get_update_config_from_ansible_params(params): def get_update_config_from_ansible_params(
params: dict[str, t.Any],
) -> dict[str, t.Any]:
update_config = params["update_config"] or {} update_config = params["update_config"] or {}
parallelism = get_value( parallelism = get_value(
"parallelism", "parallelism",
@ -1320,7 +1359,9 @@ class DockerService(DockerBaseClass):
} }
@staticmethod @staticmethod
def get_rollback_config_from_ansible_params(params): def get_rollback_config_from_ansible_params(
params: dict[str, t.Any],
) -> dict[str, t.Any] | None:
if params["rollback_config"] is None: if params["rollback_config"] is None:
return None return None
rollback_config = params["rollback_config"] or {} rollback_config = params["rollback_config"] or {}
@ -1340,7 +1381,7 @@ class DockerService(DockerBaseClass):
} }
@staticmethod @staticmethod
def get_logging_from_ansible_params(params): def get_logging_from_ansible_params(params: dict[str, t.Any]) -> dict[str, t.Any]:
logging_config = params["logging"] or {} logging_config = params["logging"] or {}
driver = get_value( driver = get_value(
"driver", "driver",
@ -1356,7 +1397,7 @@ class DockerService(DockerBaseClass):
} }
@staticmethod @staticmethod
def get_limits_from_ansible_params(params): def get_limits_from_ansible_params(params: dict[str, t.Any]) -> dict[str, t.Any]:
limits = params["limits"] or {} limits = params["limits"] or {}
cpus = get_value( cpus = get_value(
"cpus", "cpus",
@ -1379,7 +1420,9 @@ class DockerService(DockerBaseClass):
} }
@staticmethod @staticmethod
def get_reservations_from_ansible_params(params): def get_reservations_from_ansible_params(
params: dict[str, t.Any],
) -> dict[str, t.Any]:
reservations = params["reservations"] or {} reservations = params["reservations"] or {}
cpus = get_value( cpus = get_value(
"cpus", "cpus",
@ -1403,7 +1446,7 @@ class DockerService(DockerBaseClass):
} }
@staticmethod @staticmethod
def get_placement_from_ansible_params(params): def get_placement_from_ansible_params(params: dict[str, t.Any]) -> dict[str, t.Any]:
placement = params["placement"] or {} placement = params["placement"] or {}
constraints = get_value("constraints", placement) constraints = get_value("constraints", placement)
@ -1419,14 +1462,14 @@ class DockerService(DockerBaseClass):
@classmethod @classmethod
def from_ansible_params( def from_ansible_params(
cls, cls,
ap, ap: dict[str, t.Any],
old_service, old_service: DockerService | None,
image_digest, image_digest: str,
secret_ids, secret_ids: dict[str, str],
config_ids, config_ids: dict[str, str],
network_ids, network_ids: dict[str, str],
client, client: AnsibleDockerClient,
): ) -> DockerService:
s = DockerService(client.docker_api_version, client.docker_py_version) s = DockerService(client.docker_api_version, client.docker_py_version)
s.image = image_digest s.image = image_digest
s.args = ap["args"] s.args = ap["args"]
@ -1596,7 +1639,7 @@ class DockerService(DockerBaseClass):
return s return s
def compare(self, os): def compare(self, os: DockerService) -> tuple[bool, DifferenceTracker, bool, bool]:
differences = DifferenceTracker() differences = DifferenceTracker()
needs_rebuild = False needs_rebuild = False
force_update = False force_update = False
@ -1784,7 +1827,7 @@ class DockerService(DockerBaseClass):
differences.add( differences.add(
"update_order", parameter=self.update_order, active=os.update_order "update_order", parameter=self.update_order, active=os.update_order
) )
has_image_changed, change = self.has_image_changed(os.image) has_image_changed, change = self.has_image_changed(os.image or "")
if has_image_changed: if has_image_changed:
differences.add("image", parameter=self.image, active=change) differences.add("image", parameter=self.image, active=change)
if self.user and self.user != os.user: if self.user and self.user != os.user:
@ -1828,7 +1871,7 @@ class DockerService(DockerBaseClass):
force_update, force_update,
) )
def has_healthcheck_changed(self, old_publish): def has_healthcheck_changed(self, old_publish: DockerService) -> bool:
if self.healthcheck_disabled is False and self.healthcheck is None: if self.healthcheck_disabled is False and self.healthcheck is None:
return False return False
if self.healthcheck_disabled: if self.healthcheck_disabled:
@ -1838,14 +1881,14 @@ class DockerService(DockerBaseClass):
return False return False
return self.healthcheck != old_publish.healthcheck return self.healthcheck != old_publish.healthcheck
def has_publish_changed(self, old_publish): def has_publish_changed(self, old_publish: list[dict[str, t.Any]] | None) -> bool:
if self.publish is None: if self.publish is None:
return False return False
old_publish = old_publish or [] old_publish = old_publish or []
if len(self.publish) != len(old_publish): if len(self.publish) != len(old_publish):
return True return True
def publish_sorter(item): def publish_sorter(item: dict[str, t.Any]) -> tuple[int, int, str]:
return ( return (
item.get("published_port") or 0, item.get("published_port") or 0,
item.get("target_port") or 0, item.get("target_port") or 0,
@ -1869,12 +1912,13 @@ class DockerService(DockerBaseClass):
return True return True
return False return False
def has_image_changed(self, old_image): def has_image_changed(self, old_image: str) -> tuple[bool, str]:
assert self.image is not None
if "@" not in self.image: if "@" not in self.image:
old_image = old_image.split("@")[0] old_image = old_image.split("@")[0]
return self.image != old_image, old_image return self.image != old_image, old_image
def build_container_spec(self): def build_container_spec(self) -> types.ContainerSpec:
mounts = None mounts = None
if self.mounts is not None: if self.mounts is not None:
mounts = [] mounts = []
@ -1945,7 +1989,7 @@ class DockerService(DockerBaseClass):
secrets.append(types.SecretReference(**secret_args)) secrets.append(types.SecretReference(**secret_args))
dns_config_args = {} dns_config_args: dict[str, t.Any] = {}
if self.dns is not None: if self.dns is not None:
dns_config_args["nameservers"] = self.dns dns_config_args["nameservers"] = self.dns
if self.dns_search is not None: if self.dns_search is not None:
@ -1954,7 +1998,7 @@ class DockerService(DockerBaseClass):
dns_config_args["options"] = self.dns_options dns_config_args["options"] = self.dns_options
dns_config = types.DNSConfig(**dns_config_args) if dns_config_args else None dns_config = types.DNSConfig(**dns_config_args) if dns_config_args else None
container_spec_args = {} container_spec_args: dict[str, t.Any] = {}
if self.command is not None: if self.command is not None:
container_spec_args["command"] = self.command container_spec_args["command"] = self.command
if self.args is not None: if self.args is not None:
@ -2004,8 +2048,8 @@ class DockerService(DockerBaseClass):
return types.ContainerSpec(self.image, **container_spec_args) return types.ContainerSpec(self.image, **container_spec_args)
def build_placement(self): def build_placement(self) -> types.Placement | None:
placement_args = {} placement_args: dict[str, t.Any] = {}
if self.constraints is not None: if self.constraints is not None:
placement_args["constraints"] = self.constraints placement_args["constraints"] = self.constraints
if self.replicas_max_per_node is not None: if self.replicas_max_per_node is not None:
@ -2018,8 +2062,8 @@ class DockerService(DockerBaseClass):
] ]
return types.Placement(**placement_args) if placement_args else None return types.Placement(**placement_args) if placement_args else None
def build_update_config(self): def build_update_config(self) -> types.UpdateConfig | None:
update_config_args = {} update_config_args: dict[str, t.Any] = {}
if self.update_parallelism is not None: if self.update_parallelism is not None:
update_config_args["parallelism"] = self.update_parallelism update_config_args["parallelism"] = self.update_parallelism
if self.update_delay is not None: if self.update_delay is not None:
@ -2034,16 +2078,16 @@ class DockerService(DockerBaseClass):
update_config_args["order"] = self.update_order update_config_args["order"] = self.update_order
return types.UpdateConfig(**update_config_args) if update_config_args else None return types.UpdateConfig(**update_config_args) if update_config_args else None
def build_log_driver(self): def build_log_driver(self) -> types.DriverConfig | None:
log_driver_args = {} log_driver_args: dict[str, t.Any] = {}
if self.log_driver is not None: if self.log_driver is not None:
log_driver_args["name"] = self.log_driver log_driver_args["name"] = self.log_driver
if self.log_driver_options is not None: if self.log_driver_options is not None:
log_driver_args["options"] = self.log_driver_options log_driver_args["options"] = self.log_driver_options
return types.DriverConfig(**log_driver_args) if log_driver_args else None return types.DriverConfig(**log_driver_args) if log_driver_args else None
def build_restart_policy(self): def build_restart_policy(self) -> types.RestartPolicy | None:
restart_policy_args = {} restart_policy_args: dict[str, t.Any] = {}
if self.restart_policy is not None: if self.restart_policy is not None:
restart_policy_args["condition"] = self.restart_policy restart_policy_args["condition"] = self.restart_policy
if self.restart_policy_delay is not None: if self.restart_policy_delay is not None:
@ -2056,7 +2100,7 @@ class DockerService(DockerBaseClass):
types.RestartPolicy(**restart_policy_args) if restart_policy_args else None types.RestartPolicy(**restart_policy_args) if restart_policy_args else None
) )
def build_rollback_config(self): def build_rollback_config(self) -> types.RollbackConfig | None:
if self.rollback_config is None: if self.rollback_config is None:
return None return None
rollback_config_options = [ rollback_config_options = [
@ -2078,8 +2122,8 @@ class DockerService(DockerBaseClass):
else None else None
) )
def build_resources(self): def build_resources(self) -> types.Resources | None:
resources_args = {} resources_args: dict[str, t.Any] = {}
if self.limit_cpu is not None: if self.limit_cpu is not None:
resources_args["cpu_limit"] = int(self.limit_cpu * 1000000000.0) resources_args["cpu_limit"] = int(self.limit_cpu * 1000000000.0)
if self.limit_memory is not None: if self.limit_memory is not None:
@ -2090,12 +2134,16 @@ class DockerService(DockerBaseClass):
resources_args["mem_reservation"] = self.reserve_memory resources_args["mem_reservation"] = self.reserve_memory
return types.Resources(**resources_args) if resources_args else None return types.Resources(**resources_args) if resources_args else None
def build_task_template(self, container_spec, placement=None): def build_task_template(
self,
container_spec: types.ContainerSpec,
placement: types.Placement | None = None,
) -> types.TaskTemplate:
log_driver = self.build_log_driver() log_driver = self.build_log_driver()
restart_policy = self.build_restart_policy() restart_policy = self.build_restart_policy()
resources = self.build_resources() resources = self.build_resources()
task_template_args = {} task_template_args: dict[str, t.Any] = {}
if placement is not None: if placement is not None:
task_template_args["placement"] = placement task_template_args["placement"] = placement
if log_driver is not None: if log_driver is not None:
@ -2112,12 +2160,12 @@ class DockerService(DockerBaseClass):
task_template_args["networks"] = networks task_template_args["networks"] = networks
return types.TaskTemplate(container_spec=container_spec, **task_template_args) return types.TaskTemplate(container_spec=container_spec, **task_template_args)
def build_service_mode(self): def build_service_mode(self) -> types.ServiceMode:
if self.mode == "global": if self.mode == "global":
self.replicas = None self.replicas = None
return types.ServiceMode(self.mode, replicas=self.replicas) return types.ServiceMode(self.mode, replicas=self.replicas)
def build_networks(self): def build_networks(self) -> list[dict[str, t.Any]] | None:
networks = None networks = None
if self.networks is not None: if self.networks is not None:
networks = [] networks = []
@ -2130,8 +2178,8 @@ class DockerService(DockerBaseClass):
networks.append(docker_network) networks.append(docker_network)
return networks return networks
def build_endpoint_spec(self): def build_endpoint_spec(self) -> types.EndpointSpec | None:
endpoint_spec_args = {} endpoint_spec_args: dict[str, t.Any] = {}
if self.publish is not None: if self.publish is not None:
ports = [] ports = []
for port in self.publish: for port in self.publish:
@ -2149,7 +2197,7 @@ class DockerService(DockerBaseClass):
endpoint_spec_args["mode"] = self.endpoint_mode endpoint_spec_args["mode"] = self.endpoint_mode
return types.EndpointSpec(**endpoint_spec_args) if endpoint_spec_args else None return types.EndpointSpec(**endpoint_spec_args) if endpoint_spec_args else None
def build_docker_service(self): def build_docker_service(self) -> dict[str, t.Any]:
container_spec = self.build_container_spec() container_spec = self.build_container_spec()
placement = self.build_placement() placement = self.build_placement()
task_template = self.build_task_template(container_spec, placement) task_template = self.build_task_template(container_spec, placement)
@ -2159,7 +2207,10 @@ class DockerService(DockerBaseClass):
service_mode = self.build_service_mode() service_mode = self.build_service_mode()
endpoint_spec = self.build_endpoint_spec() endpoint_spec = self.build_endpoint_spec()
service = {"task_template": task_template, "mode": service_mode} service: dict[str, t.Any] = {
"task_template": task_template,
"mode": service_mode,
}
if update_config: if update_config:
service["update_config"] = update_config service["update_config"] = update_config
if rollback_config: if rollback_config:
@ -2176,13 +2227,12 @@ class DockerService(DockerBaseClass):
class DockerServiceManager: class DockerServiceManager:
def __init__(self, client: AnsibleDockerClient):
def __init__(self, client):
self.client = client self.client = client
self.retries = 2 self.retries = 2
self.diff_tracker = None self.diff_tracker: DifferenceTracker | None = None
def get_service(self, name): def get_service(self, name: str) -> DockerService | None:
try: try:
raw_data = self.client.inspect_service(name) raw_data = self.client.inspect_service(name)
except NotFound: except NotFound:
@ -2415,7 +2465,9 @@ class DockerServiceManager:
ds.init = task_template_data["ContainerSpec"].get("Init", False) ds.init = task_template_data["ContainerSpec"].get("Init", False)
return ds return ds
def update_service(self, name, old_service, new_service): def update_service(
self, name: str, old_service: DockerService, new_service: DockerService
) -> None:
service_data = new_service.build_docker_service() service_data = new_service.build_docker_service()
result = self.client.update_service( result = self.client.update_service(
old_service.service_id, old_service.service_id,
@ -2427,15 +2479,15 @@ class DockerServiceManager:
# (see https://github.com/docker/docker-py/pull/2272) # (see https://github.com/docker/docker-py/pull/2272)
self.client.report_warnings(result, ["Warning"]) self.client.report_warnings(result, ["Warning"])
def create_service(self, name, service): def create_service(self, name: str, service: DockerService) -> None:
service_data = service.build_docker_service() service_data = service.build_docker_service()
result = self.client.create_service(name=name, **service_data) result = self.client.create_service(name=name, **service_data)
self.client.report_warnings(result, ["Warning"]) self.client.report_warnings(result, ["Warning"])
def remove_service(self, name): def remove_service(self, name: str) -> None:
self.client.remove_service(name) self.client.remove_service(name)
def get_image_digest(self, name, resolve=False): def get_image_digest(self, name: str, resolve: bool = False) -> str:
if not name or not resolve: if not name or not resolve:
return name return name
repo, tag = parse_repository_tag(name) repo, tag = parse_repository_tag(name)
@ -2446,10 +2498,10 @@ class DockerServiceManager:
digest = distribution_data["Descriptor"]["digest"] digest = distribution_data["Descriptor"]["digest"]
return f"{name}@{digest}" return f"{name}@{digest}"
def get_networks_names_ids(self): def get_networks_names_ids(self) -> dict[str, str]:
return {network["Name"]: network["Id"] for network in self.client.networks()} return {network["Name"]: network["Id"] for network in self.client.networks()}
def get_missing_secret_ids(self): def get_missing_secret_ids(self) -> dict[str, str]:
""" """
Resolve missing secret ids by looking them up by name Resolve missing secret ids by looking them up by name
""" """
@ -2471,7 +2523,7 @@ class DockerServiceManager:
self.client.fail(f'Could not find a secret named "{secret_name}"') self.client.fail(f'Could not find a secret named "{secret_name}"')
return secrets return secrets
def get_missing_config_ids(self): def get_missing_config_ids(self) -> dict[str, str]:
""" """
Resolve missing config ids by looking them up by name Resolve missing config ids by looking them up by name
""" """
@ -2493,7 +2545,7 @@ class DockerServiceManager:
self.client.fail(f'Could not find a config named "{config_name}"') self.client.fail(f'Could not find a config named "{config_name}"')
return configs return configs
def run(self): def run(self) -> tuple[str, bool, bool, list[str], dict[str, t.Any]]:
self.diff_tracker = DifferenceTracker() self.diff_tracker = DifferenceTracker()
module = self.client.module module = self.client.module
@ -2582,7 +2634,7 @@ class DockerServiceManager:
return msg, changed, rebuilt, differences.get_legacy_docker_diffs(), facts return msg, changed, rebuilt, differences.get_legacy_docker_diffs(), facts
def run_safe(self): def run_safe(self) -> tuple[str, bool, bool, list[str], dict[str, t.Any]]:
while True: while True:
try: try:
return self.run() return self.run()
@ -2596,20 +2648,20 @@ class DockerServiceManager:
raise raise
def _detect_publish_mode_usage(client): def _detect_publish_mode_usage(client: AnsibleDockerClient) -> bool:
for publish_def in client.module.params["publish"] or []: for publish_def in client.module.params["publish"] or []:
if publish_def.get("mode"): if publish_def.get("mode"):
return True return True
return False return False
def _detect_healthcheck_start_period(client): def _detect_healthcheck_start_period(client: AnsibleDockerClient) -> bool:
if client.module.params["healthcheck"]: if client.module.params["healthcheck"]:
return client.module.params["healthcheck"]["start_period"] is not None return client.module.params["healthcheck"]["start_period"] is not None
return False return False
def _detect_mount_tmpfs_usage(client): def _detect_mount_tmpfs_usage(client: AnsibleDockerClient) -> bool:
for mount in client.module.params["mounts"] or []: for mount in client.module.params["mounts"] or []:
if mount.get("type") == "tmpfs": if mount.get("type") == "tmpfs":
return True return True
@ -2620,14 +2672,14 @@ def _detect_mount_tmpfs_usage(client):
return False return False
def _detect_update_config_failure_action_rollback(client): def _detect_update_config_failure_action_rollback(client: AnsibleDockerClient) -> bool:
rollback_config_failure_action = (client.module.params["update_config"] or {}).get( rollback_config_failure_action = (client.module.params["update_config"] or {}).get(
"failure_action" "failure_action"
) )
return rollback_config_failure_action == "rollback" return rollback_config_failure_action == "rollback"
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
"image": {"type": "str"}, "image": {"type": "str"},
@ -2948,6 +3000,7 @@ def main():
"swarm_service": facts, "swarm_service": facts,
} }
if client.module._diff: if client.module._diff:
assert dsm.diff_tracker is not None
before, after = dsm.diff_tracker.get_before_after() before, after = dsm.diff_tracker.get_before_after()
results["diff"] = {"before": before, "after": after} results["diff"] = {"before": before, "after": after}

View File

@ -63,6 +63,7 @@ service:
""" """
import traceback import traceback
import typing as t
try: try:
@ -79,12 +80,12 @@ from ansible_collections.community.docker.plugins.module_utils._swarm import (
) )
def get_service_info(client): def get_service_info(client: AnsibleDockerSwarmClient) -> dict[str, t.Any] | None:
service = client.module.params["name"] service = client.module.params["name"]
return client.get_service_inspect(service_id=service, skip_missing=True) return client.get_service_inspect(service_id=service, skip_missing=True)
def main(): def main() -> None:
argument_spec = { argument_spec = {
"name": {"type": "str", "required": True}, "name": {"type": "str", "required": True},
} }

View File

@ -5,6 +5,7 @@ plugins/module_utils/_api/api/client.py pep8:E704
plugins/module_utils/_api/transport/sshconn.py no-assert plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_api/utils/ports.py pep8:E704 plugins/module_utils/_api/utils/ports.py pep8:E704
plugins/module_utils/_api/utils/proxy.py pep8:E704
plugins/module_utils/_api/utils/socket.py pep8:E704 plugins/module_utils/_api/utils/socket.py pep8:E704
plugins/module_utils/_common_cli.py pep8:E704 plugins/module_utils/_common_cli.py pep8:E704
plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_module_container/module.py no-assert
@ -12,6 +13,7 @@ plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert plugins/module_utils/_socket_handler.py no-assert
plugins/module_utils/_swarm.py pep8:E704 plugins/module_utils/_swarm.py pep8:E704
plugins/module_utils/_util.py pep8:E704 plugins/module_utils/_util.py pep8:E704
plugins/modules/docker_container_copy_into.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_container_exec.py pylint:unpacking-non-sequence plugins/modules/docker_container_exec.py pylint:unpacking-non-sequence
@ -19,4 +21,6 @@ plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_swarm_service.py no-assert
plugins/modules/docker_swarm_service.py pep8:E704
plugins/modules/docker_volume.py no-assert plugins/modules/docker_volume.py no-assert

View File

@ -5,6 +5,7 @@ plugins/module_utils/_api/api/client.py pep8:E704
plugins/module_utils/_api/transport/sshconn.py no-assert plugins/module_utils/_api/transport/sshconn.py no-assert
plugins/module_utils/_api/utils/build.py no-assert plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_api/utils/ports.py pep8:E704 plugins/module_utils/_api/utils/ports.py pep8:E704
plugins/module_utils/_api/utils/proxy.py pep8:E704
plugins/module_utils/_api/utils/socket.py pep8:E704 plugins/module_utils/_api/utils/socket.py pep8:E704
plugins/module_utils/_common_cli.py pep8:E704 plugins/module_utils/_common_cli.py pep8:E704
plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_module_container/module.py no-assert
@ -12,10 +13,13 @@ plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert plugins/module_utils/_socket_handler.py no-assert
plugins/module_utils/_swarm.py pep8:E704 plugins/module_utils/_swarm.py pep8:E704
plugins/module_utils/_util.py pep8:E704 plugins/module_utils/_util.py pep8:E704
plugins/modules/docker_container_copy_into.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_swarm_service.py no-assert
plugins/modules/docker_swarm_service.py pep8:E704
plugins/modules/docker_volume.py no-assert plugins/modules/docker_volume.py no-assert

View File

@ -6,10 +6,12 @@ plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_swarm_service.py no-assert
plugins/modules/docker_volume.py no-assert plugins/modules/docker_volume.py no-assert

View File

@ -6,10 +6,12 @@ plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_swarm_service.py no-assert
plugins/modules/docker_volume.py no-assert plugins/modules/docker_volume.py no-assert

View File

@ -6,10 +6,12 @@ plugins/module_utils/_api/utils/build.py no-assert
plugins/module_utils/_module_container/module.py no-assert plugins/module_utils/_module_container/module.py no-assert
plugins/module_utils/_platform.py no-assert plugins/module_utils/_platform.py no-assert
plugins/module_utils/_socket_handler.py no-assert plugins/module_utils/_socket_handler.py no-assert
plugins/modules/docker_container_copy_into.py no-assert
plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin plugins/modules/docker_container_copy_into.py validate-modules:undocumented-parameter # _max_file_size_for_diff is used by the action plugin
plugins/modules/docker_container_exec.py no-assert plugins/modules/docker_container_exec.py no-assert
plugins/modules/docker_image.py no-assert plugins/modules/docker_image.py no-assert
plugins/modules/docker_image_tag.py no-assert plugins/modules/docker_image_tag.py no-assert
plugins/modules/docker_login.py no-assert plugins/modules/docker_login.py no-assert
plugins/modules/docker_plugin.py no-assert plugins/modules/docker_plugin.py no-assert
plugins/modules/docker_swarm_service.py no-assert
plugins/modules/docker_volume.py no-assert plugins/modules/docker_volume.py no-assert

View File

@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import unittest import unittest
from io import StringIO from io import StringIO
from unittest import mock from unittest import mock
@ -14,8 +15,7 @@ from ansible.plugins.loader import connection_loader
class TestDockerConnectionClass(unittest.TestCase): class TestDockerConnectionClass(unittest.TestCase):
def setUp(self) -> None:
def setUp(self):
self.play_context = PlayContext() self.play_context = PlayContext()
self.play_context.prompt = ( self.play_context.prompt = (
"[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: "
@ -29,7 +29,7 @@ class TestDockerConnectionClass(unittest.TestCase):
"community.docker.docker", self.play_context, self.in_stream "community.docker.docker", self.play_context, self.in_stream
) )
def tearDown(self): def tearDown(self) -> None:
pass pass
@mock.patch( @mock.patch(
@ -41,8 +41,8 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("docker version", "1.2.3", "", 0), return_value=("docker version", "1.2.3", "", 0),
) )
def test_docker_connection_module_too_old( def test_docker_connection_module_too_old(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
): ) -> None:
self.dc._version = None self.dc._version = None
self.dc.remote_user = "foo" self.dc.remote_user = "foo"
self.assertRaisesRegex( self.assertRaisesRegex(
@ -60,8 +60,8 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("docker version", "1.7.0", "", 0), return_value=("docker version", "1.7.0", "", 0),
) )
def test_docker_connection_module( def test_docker_connection_module(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
): ) -> None:
self.dc._version = None self.dc._version = None
# old version and new version fail # old version and new version fail
@ -74,8 +74,8 @@ class TestDockerConnectionClass(unittest.TestCase):
return_value=("false", "garbage", "", 1), return_value=("false", "garbage", "", 1),
) )
def test_docker_connection_module_wrong_cmd( def test_docker_connection_module_wrong_cmd(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version: t.Any, mock_old_docker_version: t.Any
): ) -> None:
self.dc._version = None self.dc._version = None
self.dc.remote_user = "foo" self.dc.remote_user = "foo"
self.assertRaisesRegex( self.assertRaisesRegex(

View File

@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from unittest.mock import create_autospec from unittest.mock import create_autospec
import pytest import pytest
@ -19,14 +20,18 @@ from ansible_collections.community.docker.plugins.inventory.docker_containers im
) )
if t.TYPE_CHECKING:
from collections.abc import Callable
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def templar(): def templar() -> Templar:
dataloader = create_autospec(DataLoader, instance=True) dataloader = create_autospec(DataLoader, instance=True)
return Templar(loader=dataloader) return Templar(loader=dataloader)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def inventory(templar): def inventory(templar: Templar) -> InventoryModule:
r = InventoryModule() r = InventoryModule()
r.inventory = InventoryData() r.inventory = InventoryData()
r.templar = templar r.templar = templar
@ -83,8 +88,10 @@ LOVING_THARP_SERVICE = {
} }
def create_get_option(options, default=False): def create_get_option(
def get_option(option): options: dict[str, t.Any], default: t.Any = False
) -> Callable[[str], t.Any]:
def get_option(option: str) -> t.Any:
if option in options: if option in options:
return options[option] return options[option]
return default return default
@ -93,9 +100,9 @@ def create_get_option(options, default=False):
class FakeClient: class FakeClient:
def __init__(self, *hosts): def __init__(self, *hosts: dict[str, t.Any]) -> None:
self.get_results = {} self.get_results: dict[str, t.Any] = {}
list_reply = [] list_reply: list[dict[str, t.Any]] = []
for host in hosts: for host in hosts:
list_reply.append( list_reply.append(
{ {
@ -109,15 +116,16 @@ class FakeClient:
self.get_results[f"/containers/{host['Id']}/json"] = host self.get_results[f"/containers/{host['Id']}/json"] = host
self.get_results["/containers/json"] = list_reply self.get_results["/containers/json"] = list_reply
def get_json(self, url, *param, **kwargs): def get_json(self, url: str, *param: str, **kwargs: t.Any) -> t.Any:
url = url.format(*param) url = url.format(*param)
return self.get_results[url] return self.get_results[url]
def test_populate(inventory, mocker): def test_populate(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": True, "verbose_output": True,
@ -130,9 +138,10 @@ def test_populate(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_host"] == "loving_tharp" assert host_1_vars["ansible_host"] == "loving_tharp"
@ -149,10 +158,11 @@ def test_populate(inventory, mocker):
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_service(inventory, mocker): def test_populate_service(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_SERVICE) client = FakeClient(LOVING_THARP_SERVICE)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": False, "verbose_output": False,
@ -166,9 +176,10 @@ def test_populate_service(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_host"] == "loving_tharp" assert host_1_vars["ansible_host"] == "loving_tharp"
@ -207,10 +218,11 @@ def test_populate_service(inventory, mocker):
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_stack(inventory, mocker): def test_populate_stack(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_STACK) client = FakeClient(LOVING_THARP_STACK)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": False, "verbose_output": False,
@ -226,9 +238,10 @@ def test_populate_stack(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_ssh_host"] == "127.0.0.1" assert host_1_vars["ansible_ssh_host"] == "127.0.0.1"
@ -267,10 +280,11 @@ def test_populate_stack(inventory, mocker):
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_filter_none(inventory, mocker): def test_populate_filter_none(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": True, "verbose_output": True,
@ -285,15 +299,16 @@ def test_populate_filter_none(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
assert len(inventory.inventory.hosts) == 0 assert len(inventory.inventory.hosts) == 0
def test_populate_filter(inventory, mocker): def test_populate_filter(inventory: InventoryModule, mocker: t.Any) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": True, "verbose_output": True,
@ -309,9 +324,10 @@ def test_populate_filter(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_host"] == "loving_tharp" assert host_1_vars["ansible_host"] == "loving_tharp"

View File

@ -19,6 +19,7 @@ import struct
import tempfile import tempfile
import threading import threading
import time import time
import typing as t
import unittest import unittest
from http.server import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer from socketserver import ThreadingTCPServer
@ -42,18 +43,24 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c
from .. import fake_api from .. import fake_api
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.auth import (
AuthConfig,
)
DEFAULT_TIMEOUT_SECONDS = constants.DEFAULT_TIMEOUT_SECONDS DEFAULT_TIMEOUT_SECONDS = constants.DEFAULT_TIMEOUT_SECONDS
def response( def response(
status_code=200, status_code: int = 200,
content="", content: bytes | dict[str, t.Any] | list[dict[str, t.Any]] = b"",
headers=None, headers: dict[str, str] | None = None,
reason=None, reason: str = "",
elapsed=0, elapsed: int = 0,
request=None, request: requests.PreparedRequest | None = None,
raw=None, raw: urllib3.HTTPResponse | None = None,
): ) -> requests.Response:
res = requests.Response() res = requests.Response()
res.status_code = status_code res.status_code = status_code
if not isinstance(content, bytes): if not isinstance(content, bytes):
@ -62,23 +69,25 @@ def response(
res.headers = requests.structures.CaseInsensitiveDict(headers or {}) res.headers = requests.structures.CaseInsensitiveDict(headers or {})
res.reason = reason res.reason = reason
res.elapsed = datetime.timedelta(elapsed) res.elapsed = datetime.timedelta(elapsed)
res.request = request res.request = request # type: ignore
res.raw = raw res.raw = raw
return res return res
def fake_resolve_authconfig( def fake_resolve_authconfig( # pylint: disable=keyword-arg-before-vararg
authconfig, registry=None, *args, **kwargs authconfig: AuthConfig, *args: t.Any, registry: str | None = None, **kwargs: t.Any
): # pylint: disable=keyword-arg-before-vararg ) -> None:
return None return None
def fake_inspect_container(self, container, tty=False): def fake_inspect_container(self: object, container: str, tty: bool = False) -> t.Any:
return fake_api.get_fake_inspect_container(tty=tty)[1] return fake_api.get_fake_inspect_container(tty=tty)[1]
def fake_resp(method, url, *args, **kwargs): def fake_resp(
key = None method: str, url: str, *args: t.Any, **kwargs: t.Any
) -> requests.Response:
key: str | tuple[str, str] | None = None
if url in fake_api.fake_responses: if url in fake_api.fake_responses:
key = url key = url
elif (url, method) in fake_api.fake_responses: elif (url, method) in fake_api.fake_responses:
@ -92,23 +101,37 @@ def fake_resp(method, url, *args, **kwargs):
fake_request = mock.Mock(side_effect=fake_resp) fake_request = mock.Mock(side_effect=fake_resp)
def fake_get(self, url, *args, **kwargs): def fake_get(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("GET", url, *args, **kwargs) return fake_request("GET", url, *args, **kwargs)
def fake_post(self, url, *args, **kwargs): def fake_post(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("POST", url, *args, **kwargs) return fake_request("POST", url, *args, **kwargs)
def fake_put(self, url, *args, **kwargs): def fake_put(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("PUT", url, *args, **kwargs) return fake_request("PUT", url, *args, **kwargs)
def fake_delete(self, url, *args, **kwargs): def fake_delete(
self: APIClient, url: str, *args: str, **kwargs: t.Any
) -> requests.Response:
return fake_request("DELETE", url, *args, **kwargs) return fake_request("DELETE", url, *args, **kwargs)
def fake_read_from_socket(self, response, stream, tty=False, demux=False): def fake_read_from_socket(
self: APIClient,
response: requests.Response,
stream: bool,
tty: bool = False,
demux: bool = False,
) -> bytes:
return b"" return b""
@ -117,7 +140,7 @@ url_prefix = f"{url_base}v{DEFAULT_DOCKER_API_VERSION}/" # pylint: disable=inva
class BaseAPIClientTest(unittest.TestCase): class BaseAPIClientTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.patcher = mock.patch.multiple( self.patcher = mock.patch.multiple(
"ansible_collections.community.docker.plugins.module_utils._api.api.client.APIClient", "ansible_collections.community.docker.plugins.module_utils._api.api.client.APIClient",
get=fake_get, get=fake_get,
@ -129,11 +152,13 @@ class BaseAPIClientTest(unittest.TestCase):
self.patcher.start() self.patcher.start()
self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION) self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION)
def tearDown(self): def tearDown(self) -> None:
self.client.close() self.client.close()
self.patcher.stop() self.patcher.stop()
def base_create_payload(self, img="busybox", cmd=None): def base_create_payload(
self, img: str = "busybox", cmd: list[str] | None = None
) -> dict[str, t.Any]:
if not cmd: if not cmd:
cmd = ["true"] cmd = ["true"]
return { return {
@ -150,16 +175,16 @@ class BaseAPIClientTest(unittest.TestCase):
class DockerApiTest(BaseAPIClientTest): class DockerApiTest(BaseAPIClientTest):
def test_ctor(self): def test_ctor(self) -> None:
with pytest.raises(errors.DockerException) as excinfo: with pytest.raises(errors.DockerException) as excinfo:
APIClient(version=1.12) APIClient(version=1.12) # type: ignore
assert ( assert (
str(excinfo.value) str(excinfo.value)
== "Version parameter must be a string or None. Found float" == "Version parameter must be a string or None. Found float"
) )
def test_url_valid_resource(self): def test_url_valid_resource(self) -> None:
url = self.client._url("/hello/{0}/world", "somename") url = self.client._url("/hello/{0}/world", "somename")
assert url == f"{url_prefix}hello/somename/world" assert url == f"{url_prefix}hello/somename/world"
@ -172,50 +197,50 @@ class DockerApiTest(BaseAPIClientTest):
url = self.client._url("/images/{0}/push", "localhost:5000/image") url = self.client._url("/images/{0}/push", "localhost:5000/image")
assert url == f"{url_prefix}images/localhost:5000/image/push" assert url == f"{url_prefix}images/localhost:5000/image/push"
def test_url_invalid_resource(self): def test_url_invalid_resource(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.client._url("/hello/{0}/world", ["sakuya", "izayoi"]) self.client._url("/hello/{0}/world", ["sakuya", "izayoi"]) # type: ignore
def test_url_no_resource(self): def test_url_no_resource(self) -> None:
url = self.client._url("/simple") url = self.client._url("/simple")
assert url == f"{url_prefix}simple" assert url == f"{url_prefix}simple"
def test_url_unversioned_api(self): def test_url_unversioned_api(self) -> None:
url = self.client._url("/hello/{0}/world", "somename", versioned_api=False) url = self.client._url("/hello/{0}/world", "somename", versioned_api=False)
assert url == f"{url_base}hello/somename/world" assert url == f"{url_base}hello/somename/world"
def test_version(self): def test_version(self) -> None:
self.client.version() self.client.version()
fake_request.assert_called_with( fake_request.assert_called_with(
"GET", url_prefix + "version", timeout=DEFAULT_TIMEOUT_SECONDS "GET", url_prefix + "version", timeout=DEFAULT_TIMEOUT_SECONDS
) )
def test_version_no_api_version(self): def test_version_no_api_version(self) -> None:
self.client.version(False) self.client.version(False)
fake_request.assert_called_with( fake_request.assert_called_with(
"GET", url_base + "version", timeout=DEFAULT_TIMEOUT_SECONDS "GET", url_base + "version", timeout=DEFAULT_TIMEOUT_SECONDS
) )
def test_retrieve_server_version(self): def test_retrieve_server_version(self) -> None:
client = APIClient(version="auto") client = APIClient(version="auto")
assert isinstance(client._version, str) assert isinstance(client._version, str)
assert not (client._version == "auto") assert not (client._version == "auto")
client.close() client.close()
def test_auto_retrieve_server_version(self): def test_auto_retrieve_server_version(self) -> None:
version = self.client._retrieve_server_version() version = self.client._retrieve_server_version()
assert isinstance(version, str) assert isinstance(version, str)
def test_info(self): def test_info(self) -> None:
self.client.info() self.client.info()
fake_request.assert_called_with( fake_request.assert_called_with(
"GET", url_prefix + "info", timeout=DEFAULT_TIMEOUT_SECONDS "GET", url_prefix + "info", timeout=DEFAULT_TIMEOUT_SECONDS
) )
def test_search(self): def test_search(self) -> None:
self.client.get_json("/images/search", params={"term": "busybox"}) self.client.get_json("/images/search", params={"term": "busybox"})
fake_request.assert_called_with( fake_request.assert_called_with(
@ -225,7 +250,7 @@ class DockerApiTest(BaseAPIClientTest):
timeout=DEFAULT_TIMEOUT_SECONDS, timeout=DEFAULT_TIMEOUT_SECONDS,
) )
def test_login(self): def test_login(self) -> None:
self.client.login("sakuya", "izayoi") self.client.login("sakuya", "izayoi")
args = fake_request.call_args args = fake_request.call_args
assert args[0][0] == "POST" assert args[0][0] == "POST"
@ -242,42 +267,42 @@ class DockerApiTest(BaseAPIClientTest):
"serveraddress": None, "serveraddress": None,
} }
def _socket_path_for_client_session(self, client): def _socket_path_for_client_session(self, client: APIClient) -> str:
socket_adapter = client.get_adapter("http+docker://") socket_adapter = client.get_adapter("http+docker://")
return socket_adapter.socket_path return socket_adapter.socket_path # type: ignore[attr-defined]
def test_url_compatibility_unix(self): def test_url_compatibility_unix(self) -> None:
c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION) c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION)
assert self._socket_path_for_client_session(c) == "/socket" assert self._socket_path_for_client_session(c) == "/socket"
def test_url_compatibility_unix_triple_slash(self): def test_url_compatibility_unix_triple_slash(self) -> None:
c = APIClient(base_url="unix:///socket", version=DEFAULT_DOCKER_API_VERSION) c = APIClient(base_url="unix:///socket", version=DEFAULT_DOCKER_API_VERSION)
assert self._socket_path_for_client_session(c) == "/socket" assert self._socket_path_for_client_session(c) == "/socket"
def test_url_compatibility_http_unix_triple_slash(self): def test_url_compatibility_http_unix_triple_slash(self) -> None:
c = APIClient( c = APIClient(
base_url="http+unix:///socket", version=DEFAULT_DOCKER_API_VERSION base_url="http+unix:///socket", version=DEFAULT_DOCKER_API_VERSION
) )
assert self._socket_path_for_client_session(c) == "/socket" assert self._socket_path_for_client_session(c) == "/socket"
def test_url_compatibility_http(self): def test_url_compatibility_http(self) -> None:
c = APIClient( c = APIClient(
base_url="http://hostname:1234", version=DEFAULT_DOCKER_API_VERSION base_url="http://hostname:1234", version=DEFAULT_DOCKER_API_VERSION
) )
assert c.base_url == "http://hostname:1234" assert c.base_url == "http://hostname:1234"
def test_url_compatibility_tcp(self): def test_url_compatibility_tcp(self) -> None:
c = APIClient( c = APIClient(
base_url="tcp://hostname:1234", version=DEFAULT_DOCKER_API_VERSION base_url="tcp://hostname:1234", version=DEFAULT_DOCKER_API_VERSION
) )
assert c.base_url == "http://hostname:1234" assert c.base_url == "http://hostname:1234"
def test_remove_link(self): def test_remove_link(self) -> None:
self.client.delete_call( self.client.delete_call(
"/containers/{0}", "/containers/{0}",
"3cc2351ab11b", "3cc2351ab11b",
@ -291,7 +316,7 @@ class DockerApiTest(BaseAPIClientTest):
timeout=DEFAULT_TIMEOUT_SECONDS, timeout=DEFAULT_TIMEOUT_SECONDS,
) )
def test_stream_helper_decoding(self): def test_stream_helper_decoding(self) -> None:
status_code, content = fake_api.fake_responses[url_prefix + "events"]() status_code, content = fake_api.fake_responses[url_prefix + "events"]()
content_str = json.dumps(content).encode("utf-8") content_str = json.dumps(content).encode("utf-8")
body = io.BytesIO(content_str) body = io.BytesIO(content_str)
@ -318,7 +343,7 @@ class DockerApiTest(BaseAPIClientTest):
raw_resp._fp.seek(0) raw_resp._fp.seek(0)
resp = response(status_code=status_code, content=content, raw=raw_resp) resp = response(status_code=status_code, content=content, raw=raw_resp)
result = next(self.client._stream_helper(resp)) result = next(self.client._stream_helper(resp))
assert result == content_str.decode("utf-8") assert result == content_str.decode("utf-8") # type: ignore
# non-chunked response, pass `decode=True` to the helper # non-chunked response, pass `decode=True` to the helper
raw_resp._fp.seek(0) raw_resp._fp.seek(0)
@ -328,7 +353,7 @@ class DockerApiTest(BaseAPIClientTest):
class UnixSocketStreamTest(unittest.TestCase): class UnixSocketStreamTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
socket_dir = tempfile.mkdtemp() socket_dir = tempfile.mkdtemp()
self.build_context = tempfile.mkdtemp() self.build_context = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, socket_dir) self.addCleanup(shutil.rmtree, socket_dir)
@ -339,23 +364,23 @@ class UnixSocketStreamTest(unittest.TestCase):
server_thread = threading.Thread(target=self.run_server) server_thread = threading.Thread(target=self.run_server)
server_thread.daemon = True server_thread.daemon = True
server_thread.start() server_thread.start()
self.response = None self.response: t.Any = None
self.request_handler = None self.request_handler: t.Any = None
self.addCleanup(server_thread.join) self.addCleanup(server_thread.join)
self.addCleanup(self.stop) self.addCleanup(self.stop)
def stop(self): def stop(self) -> None:
self.stop_server = True self.stop_server = True
def _setup_socket(self): def _setup_socket(self) -> socket.socket:
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(self.socket_file) server_sock.bind(self.socket_file)
# Non-blocking mode so that we can shut the test down easily # Non-blocking mode so that we can shut the test down easily
server_sock.setblocking(0) server_sock.setblocking(0) # type: ignore
server_sock.listen(5) server_sock.listen(5)
return server_sock return server_sock
def run_server(self): def run_server(self) -> None:
try: try:
while not self.stop_server: while not self.stop_server:
try: try:
@ -365,7 +390,7 @@ class UnixSocketStreamTest(unittest.TestCase):
time.sleep(0.01) time.sleep(0.01)
continue continue
connection.setblocking(1) connection.setblocking(1) # type: ignore
try: try:
self.request_handler(connection) self.request_handler(connection)
finally: finally:
@ -373,7 +398,7 @@ class UnixSocketStreamTest(unittest.TestCase):
finally: finally:
self.server_socket.close() self.server_socket.close()
def early_response_sending_handler(self, connection): def early_response_sending_handler(self, connection: socket.socket) -> None:
data = b"" data = b""
headers = None headers = None
@ -395,7 +420,7 @@ class UnixSocketStreamTest(unittest.TestCase):
data += connection.recv(2048) data += connection.recv(2048)
@pytest.mark.skipif(constants.IS_WINDOWS_PLATFORM, reason="Unix only") @pytest.mark.skipif(constants.IS_WINDOWS_PLATFORM, reason="Unix only")
def test_early_stream_response(self): def test_early_stream_response(self) -> None:
self.request_handler = self.early_response_sending_handler self.request_handler = self.early_response_sending_handler
lines = [] lines = []
for i in range(0, 50): for i in range(0, 50):
@ -405,7 +430,7 @@ class UnixSocketStreamTest(unittest.TestCase):
lines.append(b"") lines.append(b"")
self.response = ( self.response = (
b"HTTP/1.1 200 OK\r\n" b"Transfer-Encoding: chunked\r\n" b"\r\n" b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"
) + b"\r\n".join(lines) ) + b"\r\n".join(lines)
with APIClient( with APIClient(
@ -459,8 +484,12 @@ class TCPSocketStreamTest(unittest.TestCase):
built on these islands for generations past? Now shall what of Him? built on these islands for generations past? Now shall what of Him?
""" """
server: ThreadingTCPServer
thread: threading.Thread
address: str
@classmethod @classmethod
def setup_class(cls): def setup_class(cls) -> None:
cls.server = ThreadingTCPServer(("", 0), cls.get_handler_class()) cls.server = ThreadingTCPServer(("", 0), cls.get_handler_class())
cls.thread = threading.Thread(target=cls.server.serve_forever) cls.thread = threading.Thread(target=cls.server.serve_forever)
cls.thread.daemon = True cls.thread.daemon = True
@ -468,18 +497,18 @@ class TCPSocketStreamTest(unittest.TestCase):
cls.address = f"http://{socket.gethostname()}:{cls.server.server_address[1]}" cls.address = f"http://{socket.gethostname()}:{cls.server.server_address[1]}"
@classmethod @classmethod
def teardown_class(cls): def teardown_class(cls) -> None:
cls.server.shutdown() cls.server.shutdown()
cls.server.server_close() cls.server.server_close()
cls.thread.join() cls.thread.join()
@classmethod @classmethod
def get_handler_class(cls): def get_handler_class(cls) -> t.Type[BaseHTTPRequestHandler]:
stdout_data = cls.stdout_data stdout_data = cls.stdout_data
stderr_data = cls.stderr_data stderr_data = cls.stderr_data
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_POST(self): # pylint: disable=invalid-name def do_POST(self) -> None: # pylint: disable=invalid-name
resp_data = self.get_resp_data() resp_data = self.get_resp_data()
self.send_response(101) self.send_response(101)
self.send_header("Content-Type", "application/vnd.docker.raw-stream") self.send_header("Content-Type", "application/vnd.docker.raw-stream")
@ -491,7 +520,7 @@ class TCPSocketStreamTest(unittest.TestCase):
self.wfile.write(resp_data) self.wfile.write(resp_data)
self.wfile.flush() self.wfile.flush()
def get_resp_data(self): def get_resp_data(self) -> bytes:
path = self.path.split("/")[-1] path = self.path.split("/")[-1]
if path == "tty": if path == "tty":
return stdout_data + stderr_data return stdout_data + stderr_data
@ -505,12 +534,17 @@ class TCPSocketStreamTest(unittest.TestCase):
raise NotImplementedError(f"Unknown path {path}") raise NotImplementedError(f"Unknown path {path}")
@staticmethod @staticmethod
def frame_header(stream, data): def frame_header(stream: int, data: bytes) -> bytes:
return struct.pack(">BxxxL", stream, len(data)) return struct.pack(">BxxxL", stream, len(data))
return Handler return Handler
def request(self, stream=None, tty=None, demux=None): def request(
self,
stream: bool | None = None,
tty: bool | None = None,
demux: bool | None = None,
) -> t.Any:
assert stream is not None and tty is not None and demux is not None assert stream is not None and tty is not None and demux is not None
with APIClient( with APIClient(
base_url=self.address, base_url=self.address,
@ -523,51 +557,51 @@ class TCPSocketStreamTest(unittest.TestCase):
resp = client._post(url, stream=True) resp = client._post(url, stream=True)
return client._read_from_socket(resp, stream=stream, tty=tty, demux=demux) return client._read_from_socket(resp, stream=stream, tty=tty, demux=demux)
def test_read_from_socket_tty(self): def test_read_from_socket_tty(self) -> None:
res = self.request(stream=True, tty=True, demux=False) res = self.request(stream=True, tty=True, demux=False)
assert next(res) == self.stdout_data + self.stderr_data assert next(res) == self.stdout_data + self.stderr_data
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_tty_demux(self): def test_read_from_socket_tty_demux(self) -> None:
res = self.request(stream=True, tty=True, demux=True) res = self.request(stream=True, tty=True, demux=True)
assert next(res) == (self.stdout_data + self.stderr_data, None) assert next(res) == (self.stdout_data + self.stderr_data, None)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_no_tty(self): def test_read_from_socket_no_tty(self) -> None:
res = self.request(stream=True, tty=False, demux=False) res = self.request(stream=True, tty=False, demux=False)
assert next(res) == self.stdout_data assert next(res) == self.stdout_data
assert next(res) == self.stderr_data assert next(res) == self.stderr_data
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_no_tty_demux(self): def test_read_from_socket_no_tty_demux(self) -> None:
res = self.request(stream=True, tty=False, demux=True) res = self.request(stream=True, tty=False, demux=True)
assert (self.stdout_data, None) == next(res) assert (self.stdout_data, None) == next(res)
assert (None, self.stderr_data) == next(res) assert (None, self.stderr_data) == next(res)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_no_stream_tty(self): def test_read_from_socket_no_stream_tty(self) -> None:
res = self.request(stream=False, tty=True, demux=False) res = self.request(stream=False, tty=True, demux=False)
assert res == self.stdout_data + self.stderr_data assert res == self.stdout_data + self.stderr_data
def test_read_from_socket_no_stream_tty_demux(self): def test_read_from_socket_no_stream_tty_demux(self) -> None:
res = self.request(stream=False, tty=True, demux=True) res = self.request(stream=False, tty=True, demux=True)
assert res == (self.stdout_data + self.stderr_data, None) assert res == (self.stdout_data + self.stderr_data, None)
def test_read_from_socket_no_stream_no_tty(self): def test_read_from_socket_no_stream_no_tty(self) -> None:
res = self.request(stream=False, tty=False, demux=False) res = self.request(stream=False, tty=False, demux=False)
assert res == self.stdout_data + self.stderr_data assert res == self.stdout_data + self.stderr_data
def test_read_from_socket_no_stream_no_tty_demux(self): def test_read_from_socket_no_stream_no_tty_demux(self) -> None:
res = self.request(stream=False, tty=False, demux=True) res = self.request(stream=False, tty=False, demux=True)
assert res == (self.stdout_data, self.stderr_data) assert res == (self.stdout_data, self.stderr_data)
class UserAgentTest(unittest.TestCase): class UserAgentTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.patcher = mock.patch.object( self.patcher = mock.patch.object(
APIClient, APIClient,
"send", "send",
@ -575,10 +609,10 @@ class UserAgentTest(unittest.TestCase):
) )
self.mock_send = self.patcher.start() self.mock_send = self.patcher.start()
def tearDown(self): def tearDown(self) -> None:
self.patcher.stop() self.patcher.stop()
def test_default_user_agent(self): def test_default_user_agent(self) -> None:
client = APIClient(version=DEFAULT_DOCKER_API_VERSION) client = APIClient(version=DEFAULT_DOCKER_API_VERSION)
client.version() client.version()
@ -587,7 +621,7 @@ class UserAgentTest(unittest.TestCase):
expected = "ansible-community.docker" expected = "ansible-community.docker"
assert headers["User-Agent"] == expected assert headers["User-Agent"] == expected
def test_custom_user_agent(self): def test_custom_user_agent(self) -> None:
client = APIClient(user_agent="foo/bar", version=DEFAULT_DOCKER_API_VERSION) client = APIClient(user_agent="foo/bar", version=DEFAULT_DOCKER_API_VERSION)
client.version() client.version()
@ -598,44 +632,44 @@ class UserAgentTest(unittest.TestCase):
class DisableSocketTest(unittest.TestCase): class DisableSocketTest(unittest.TestCase):
class DummySocket: class DummySocket:
def __init__(self, timeout=60): def __init__(self, timeout: int | float | None = 60) -> None:
self.timeout = timeout self.timeout = timeout
self._sock = None self._sock: t.Any = None
def settimeout(self, timeout): def settimeout(self, timeout: int | float | None) -> None:
self.timeout = timeout self.timeout = timeout
def gettimeout(self): def gettimeout(self) -> int | float | None:
return self.timeout return self.timeout
def setUp(self): def setUp(self) -> None:
self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION) self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION)
def test_disable_socket_timeout(self): def test_disable_socket_timeout(self) -> None:
"""Test that the timeout is disabled on a generic socket object.""" """Test that the timeout is disabled on a generic socket object."""
the_socket = self.DummySocket() the_socket = self.DummySocket()
self.client._disable_socket_timeout(the_socket) self.client._disable_socket_timeout(the_socket) # type: ignore
assert the_socket.timeout is None assert the_socket.timeout is None
def test_disable_socket_timeout2(self): def test_disable_socket_timeout2(self) -> None:
"""Test that the timeouts are disabled on a generic socket object """Test that the timeouts are disabled on a generic socket object
and it's _sock object if present.""" and it's _sock object if present."""
the_socket = self.DummySocket() the_socket = self.DummySocket()
the_socket._sock = self.DummySocket() the_socket._sock = self.DummySocket() # type: ignore
self.client._disable_socket_timeout(the_socket) self.client._disable_socket_timeout(the_socket) # type: ignore
assert the_socket.timeout is None assert the_socket.timeout is None
assert the_socket._sock.timeout is None assert the_socket._sock.timeout is None
def test_disable_socket_timout_non_blocking(self): def test_disable_socket_timout_non_blocking(self) -> None:
"""Test that a non-blocking socket does not get set to blocking.""" """Test that a non-blocking socket does not get set to blocking."""
the_socket = self.DummySocket() the_socket = self.DummySocket()
the_socket._sock = self.DummySocket(0.0) the_socket._sock = self.DummySocket(0.0) # type: ignore
self.client._disable_socket_timeout(the_socket) self.client._disable_socket_timeout(the_socket) # type: ignore
assert the_socket.timeout is None assert the_socket.timeout is None
assert the_socket._sock.timeout == 0.0 assert the_socket._sock.timeout == 0.0

View File

@ -8,6 +8,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api import constants from ansible_collections.community.docker.plugins.module_utils._api import constants
from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.constants import ( from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.constants import (
DEFAULT_DOCKER_API_VERSION, DEFAULT_DOCKER_API_VERSION,
@ -16,6 +18,10 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c
from . import fake_stat from . import fake_stat
if t.TYPE_CHECKING:
from collections.abc import Callable
CURRENT_VERSION = f"v{DEFAULT_DOCKER_API_VERSION}" CURRENT_VERSION = f"v{DEFAULT_DOCKER_API_VERSION}"
FAKE_CONTAINER_ID = "3cc2351ab11b" FAKE_CONTAINER_ID = "3cc2351ab11b"
@ -38,7 +44,7 @@ FAKE_SECRET_NAME = "super_secret"
# for clarity and readability # for clarity and readability
def get_fake_version(): def get_fake_version() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"ApiVersion": "1.35", "ApiVersion": "1.35",
@ -73,7 +79,7 @@ def get_fake_version():
return status_code, response return status_code, response
def get_fake_info(): def get_fake_info() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Containers": 1, "Containers": 1,
@ -86,23 +92,23 @@ def get_fake_info():
return status_code, response return status_code, response
def post_fake_auth(): def post_fake_auth() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Status": "Login Succeeded", "IdentityToken": "9cbaf023786cd7"} response = {"Status": "Login Succeeded", "IdentityToken": "9cbaf023786cd7"}
return status_code, response return status_code, response
def get_fake_ping(): def get_fake_ping() -> tuple[int, str]:
return 200, "OK" return 200, "OK"
def get_fake_search(): def get_fake_search() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [{"Name": "busybox", "Description": "Fake Description"}] response = [{"Name": "busybox", "Description": "Fake Description"}]
return status_code, response return status_code, response
def get_fake_images(): def get_fake_images() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{ {
@ -115,7 +121,7 @@ def get_fake_images():
return status_code, response return status_code, response
def get_fake_image_history(): def get_fake_image_history() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{"Id": "b750fe79269d", "Created": 1364102658, "CreatedBy": "/bin/bash"}, {"Id": "b750fe79269d", "Created": 1364102658, "CreatedBy": "/bin/bash"},
@ -125,14 +131,14 @@ def get_fake_image_history():
return status_code, response return status_code, response
def post_fake_import_image(): def post_fake_import_image() -> tuple[int, str]:
status_code = 200 status_code = 200
response = "Import messages..." response = "Import messages..."
return status_code, response return status_code, response
def get_fake_containers(): def get_fake_containers() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{ {
@ -146,25 +152,25 @@ def get_fake_containers():
return status_code, response return status_code, response
def post_fake_start_container(): def post_fake_start_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_resize_container(): def post_fake_resize_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_create_container(): def post_fake_create_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def get_fake_inspect_container(tty=False): def get_fake_inspect_container(tty: bool = False) -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Id": FAKE_CONTAINER_ID, "Id": FAKE_CONTAINER_ID,
@ -188,7 +194,7 @@ def get_fake_inspect_container(tty=False):
return status_code, response return status_code, response
def get_fake_inspect_image(): def get_fake_inspect_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Id": FAKE_IMAGE_ID, "Id": FAKE_IMAGE_ID,
@ -221,19 +227,19 @@ def get_fake_inspect_image():
return status_code, response return status_code, response
def get_fake_insert_image(): def get_fake_insert_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"StatusCode": 0} response = {"StatusCode": 0}
return status_code, response return status_code, response
def get_fake_wait(): def get_fake_wait() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"StatusCode": 0} response = {"StatusCode": 0}
return status_code, response return status_code, response
def get_fake_logs(): def get_fake_logs() -> tuple[int, bytes]:
status_code = 200 status_code = 200
response = ( response = (
b"\x01\x00\x00\x00\x00\x00\x00\x00" b"\x01\x00\x00\x00\x00\x00\x00\x00"
@ -244,13 +250,13 @@ def get_fake_logs():
return status_code, response return status_code, response
def get_fake_diff(): def get_fake_diff() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [{"Path": "/test", "Kind": 1}] response = [{"Path": "/test", "Kind": 1}]
return status_code, response return status_code, response
def get_fake_events(): def get_fake_events() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{ {
@ -263,19 +269,19 @@ def get_fake_events():
return status_code, response return status_code, response
def get_fake_export(): def get_fake_export() -> tuple[int, str]:
status_code = 200 status_code = 200
response = "Byte Stream...." response = "Byte Stream...."
return status_code, response return status_code, response
def post_fake_exec_create(): def post_fake_exec_create() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_EXEC_ID} response = {"Id": FAKE_EXEC_ID}
return status_code, response return status_code, response
def post_fake_exec_start(): def post_fake_exec_start() -> tuple[int, bytes]:
status_code = 200 status_code = 200
response = ( response = (
b"\x01\x00\x00\x00\x00\x00\x00\x11bin\nboot\ndev\netc\n" b"\x01\x00\x00\x00\x00\x00\x00\x11bin\nboot\ndev\netc\n"
@ -285,12 +291,12 @@ def post_fake_exec_start():
return status_code, response return status_code, response
def post_fake_exec_resize(): def post_fake_exec_resize() -> tuple[int, str]:
status_code = 201 status_code = 201
return status_code, "" return status_code, ""
def get_fake_exec_inspect(): def get_fake_exec_inspect() -> tuple[int, dict[str, t.Any]]:
return 200, { return 200, {
"OpenStderr": True, "OpenStderr": True,
"OpenStdout": True, "OpenStdout": True,
@ -309,102 +315,102 @@ def get_fake_exec_inspect():
} }
def post_fake_stop_container(): def post_fake_stop_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_kill_container(): def post_fake_kill_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_pause_container(): def post_fake_pause_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_unpause_container(): def post_fake_unpause_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_restart_container(): def post_fake_restart_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_rename_container(): def post_fake_rename_container() -> tuple[int, None]:
status_code = 204 status_code = 204
return status_code, None return status_code, None
def delete_fake_remove_container(): def delete_fake_remove_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_image_create(): def post_fake_image_create() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def delete_fake_remove_image(): def delete_fake_remove_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def get_fake_get_image(): def get_fake_get_image() -> tuple[int, str]:
status_code = 200 status_code = 200
response = "Byte Stream...." response = "Byte Stream...."
return status_code, response return status_code, response
def post_fake_load_image(): def post_fake_load_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def post_fake_commit(): def post_fake_commit() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_push(): def post_fake_push() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def post_fake_build_container(): def post_fake_build_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_tag_image(): def post_fake_tag_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def get_fake_stats(): def get_fake_stats() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = fake_stat.OBJ response = fake_stat.OBJ
return status_code, response return status_code, response
def get_fake_top(): def get_fake_top() -> tuple[int, dict[str, t.Any]]:
return 200, { return 200, {
"Processes": [ "Processes": [
[ [
@ -431,7 +437,7 @@ def get_fake_top():
} }
def get_fake_volume_list(): def get_fake_volume_list() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Volumes": [ "Volumes": [
@ -452,7 +458,7 @@ def get_fake_volume_list():
return status_code, response return status_code, response
def get_fake_volume(): def get_fake_volume() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Name": "perfectcherryblossom", "Name": "perfectcherryblossom",
@ -464,23 +470,23 @@ def get_fake_volume():
return status_code, response return status_code, response
def fake_remove_volume(): def fake_remove_volume() -> tuple[int, None]:
return 204, None return 204, None
def post_fake_update_container(): def post_fake_update_container() -> tuple[int, dict[str, t.Any]]:
return 200, {"Warnings": []} return 200, {"Warnings": []}
def post_fake_update_node(): def post_fake_update_node() -> tuple[int, None]:
return 200, None return 200, None
def post_fake_join_swarm(): def post_fake_join_swarm() -> tuple[int, None]:
return 200, None return 200, None
def get_fake_network_list(): def get_fake_network_list() -> tuple[int, list[dict[str, t.Any]]]:
return 200, [ return 200, [
{ {
"Name": "bridge", "Name": "bridge",
@ -510,27 +516,27 @@ def get_fake_network_list():
] ]
def get_fake_network(): def get_fake_network() -> tuple[int, dict[str, t.Any]]:
return 200, get_fake_network_list()[1][0] return 200, get_fake_network_list()[1][0]
def post_fake_network(): def post_fake_network() -> tuple[int, dict[str, t.Any]]:
return 201, {"Id": FAKE_NETWORK_ID, "Warnings": []} return 201, {"Id": FAKE_NETWORK_ID, "Warnings": []}
def delete_fake_network(): def delete_fake_network() -> tuple[int, None]:
return 204, None return 204, None
def post_fake_network_connect(): def post_fake_network_connect() -> tuple[int, None]:
return 200, None return 200, None
def post_fake_network_disconnect(): def post_fake_network_disconnect() -> tuple[int, None]:
return 200, None return 200, None
def post_fake_secret(): def post_fake_secret() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"ID": FAKE_SECRET_ID} response = {"ID": FAKE_SECRET_ID}
return status_code, response return status_code, response
@ -541,7 +547,7 @@ prefix = "http+docker://localhost" # pylint: disable=invalid-name
if constants.IS_WINDOWS_PLATFORM: if constants.IS_WINDOWS_PLATFORM:
prefix = "http+docker://localnpipe" # pylint: disable=invalid-name prefix = "http+docker://localnpipe" # pylint: disable=invalid-name
fake_responses = { fake_responses: dict[str | tuple[str, str], Callable] = {
f"{prefix}/version": get_fake_version, f"{prefix}/version": get_fake_version,
f"{prefix}/{CURRENT_VERSION}/version": get_fake_version, f"{prefix}/{CURRENT_VERSION}/version": get_fake_version,
f"{prefix}/{CURRENT_VERSION}/info": get_fake_info, f"{prefix}/{CURRENT_VERSION}/info": get_fake_info,
@ -574,6 +580,7 @@ fake_responses = {
f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/unpause": post_fake_unpause_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/unpause": post_fake_unpause_container,
f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/restart": post_fake_restart_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/restart": post_fake_restart_container,
f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b": delete_fake_remove_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b": delete_fake_remove_container,
# TODO: the following is a duplicate of the import endpoint further above!
f"{prefix}/{CURRENT_VERSION}/images/create": post_fake_image_create, f"{prefix}/{CURRENT_VERSION}/images/create": post_fake_image_create,
f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128": delete_fake_remove_image, f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128": delete_fake_remove_image,
f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128/get": get_fake_get_image, f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128/get": get_fake_get_image,

View File

@ -15,6 +15,7 @@ import os.path
import random import random
import shutil import shutil
import tempfile import tempfile
import typing as t
import unittest import unittest
from unittest import mock from unittest import mock
@ -30,7 +31,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.credentials.
class RegressionTest(unittest.TestCase): class RegressionTest(unittest.TestCase):
def test_803_urlsafe_encode(self): def test_803_urlsafe_encode(self) -> None:
auth_data = {"username": "root", "password": "GR?XGR?XGR?XGR?X"} auth_data = {"username": "root", "password": "GR?XGR?XGR?XGR?X"}
encoded = auth.encode_header(auth_data) encoded = auth.encode_header(auth_data)
assert b"/" not in encoded assert b"/" not in encoded
@ -38,75 +39,75 @@ class RegressionTest(unittest.TestCase):
class ResolveRepositoryNameTest(unittest.TestCase): class ResolveRepositoryNameTest(unittest.TestCase):
def test_resolve_repository_name_hub_library_image(self): def test_resolve_repository_name_hub_library_image(self) -> None:
assert auth.resolve_repository_name("image") == ("docker.io", "image") assert auth.resolve_repository_name("image") == ("docker.io", "image")
def test_resolve_repository_name_dotted_hub_library_image(self): def test_resolve_repository_name_dotted_hub_library_image(self) -> None:
assert auth.resolve_repository_name("image.valid") == ( assert auth.resolve_repository_name("image.valid") == (
"docker.io", "docker.io",
"image.valid", "image.valid",
) )
def test_resolve_repository_name_hub_image(self): def test_resolve_repository_name_hub_image(self) -> None:
assert auth.resolve_repository_name("username/image") == ( assert auth.resolve_repository_name("username/image") == (
"docker.io", "docker.io",
"username/image", "username/image",
) )
def test_explicit_hub_index_library_image(self): def test_explicit_hub_index_library_image(self) -> None:
assert auth.resolve_repository_name("docker.io/image") == ("docker.io", "image") assert auth.resolve_repository_name("docker.io/image") == ("docker.io", "image")
def test_explicit_legacy_hub_index_library_image(self): def test_explicit_legacy_hub_index_library_image(self) -> None:
assert auth.resolve_repository_name("index.docker.io/image") == ( assert auth.resolve_repository_name("index.docker.io/image") == (
"docker.io", "docker.io",
"image", "image",
) )
def test_resolve_repository_name_private_registry(self): def test_resolve_repository_name_private_registry(self) -> None:
assert auth.resolve_repository_name("my.registry.net/image") == ( assert auth.resolve_repository_name("my.registry.net/image") == (
"my.registry.net", "my.registry.net",
"image", "image",
) )
def test_resolve_repository_name_private_registry_with_port(self): def test_resolve_repository_name_private_registry_with_port(self) -> None:
assert auth.resolve_repository_name("my.registry.net:5000/image") == ( assert auth.resolve_repository_name("my.registry.net:5000/image") == (
"my.registry.net:5000", "my.registry.net:5000",
"image", "image",
) )
def test_resolve_repository_name_private_registry_with_username(self): def test_resolve_repository_name_private_registry_with_username(self) -> None:
assert auth.resolve_repository_name("my.registry.net/username/image") == ( assert auth.resolve_repository_name("my.registry.net/username/image") == (
"my.registry.net", "my.registry.net",
"username/image", "username/image",
) )
def test_resolve_repository_name_no_dots_but_port(self): def test_resolve_repository_name_no_dots_but_port(self) -> None:
assert auth.resolve_repository_name("hostname:5000/image") == ( assert auth.resolve_repository_name("hostname:5000/image") == (
"hostname:5000", "hostname:5000",
"image", "image",
) )
def test_resolve_repository_name_no_dots_but_port_and_username(self): def test_resolve_repository_name_no_dots_but_port_and_username(self) -> None:
assert auth.resolve_repository_name("hostname:5000/username/image") == ( assert auth.resolve_repository_name("hostname:5000/username/image") == (
"hostname:5000", "hostname:5000",
"username/image", "username/image",
) )
def test_resolve_repository_name_localhost(self): def test_resolve_repository_name_localhost(self) -> None:
assert auth.resolve_repository_name("localhost/image") == ("localhost", "image") assert auth.resolve_repository_name("localhost/image") == ("localhost", "image")
def test_resolve_repository_name_localhost_with_username(self): def test_resolve_repository_name_localhost_with_username(self) -> None:
assert auth.resolve_repository_name("localhost/username/image") == ( assert auth.resolve_repository_name("localhost/username/image") == (
"localhost", "localhost",
"username/image", "username/image",
) )
def test_invalid_index_name(self): def test_invalid_index_name(self) -> None:
with pytest.raises(errors.InvalidRepository): with pytest.raises(errors.InvalidRepository):
auth.resolve_repository_name("-gecko.com/image") auth.resolve_repository_name("-gecko.com/image")
def encode_auth(auth_info): def encode_auth(auth_info: dict[str, t.Any]) -> bytes:
return base64.b64encode( return base64.b64encode(
auth_info.get("username", "").encode("utf-8") auth_info.get("username", "").encode("utf-8")
+ b":" + b":"
@ -131,129 +132,105 @@ class ResolveAuthTest(unittest.TestCase):
} }
) )
def test_resolve_authconfig_hostname_only(self): def test_resolve_authconfig_hostname_only(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "my.registry.net")
auth.resolve_authconfig(self.auth_config, "my.registry.net")["username"] assert ac is not None
== "privateuser" assert ac["username"] == "privateuser"
)
def test_resolve_authconfig_no_protocol(self): def test_resolve_authconfig_no_protocol(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")
auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")["username"] assert ac is not None
== "privateuser" assert ac["username"] == "privateuser"
)
def test_resolve_authconfig_no_path(self): def test_resolve_authconfig_no_path(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net")
auth.resolve_authconfig(self.auth_config, "http://my.registry.net")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_no_path_trailing_slash(self): def test_resolve_authconfig_no_path_trailing_slash(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_no_path_wrong_secure_proto(self): def test_resolve_authconfig_no_path_wrong_secure_proto(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net")
auth.resolve_authconfig(self.auth_config, "https://my.registry.net")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_no_path_wrong_insecure_proto(self): def test_resolve_authconfig_no_path_wrong_insecure_proto(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://index.docker.io")
auth.resolve_authconfig(self.auth_config, "http://index.docker.io")[ assert ac is not None
"username" assert ac["username"] == "indexuser"
]
== "indexuser"
)
def test_resolve_authconfig_path_wrong_proto(self): def test_resolve_authconfig_path_wrong_proto(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")
auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_default_registry(self): def test_resolve_authconfig_default_registry(self) -> None:
assert auth.resolve_authconfig(self.auth_config)["username"] == "indexuser" ac = auth.resolve_authconfig(self.auth_config)
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_authconfig_default_explicit_none(self): def test_resolve_authconfig_default_explicit_none(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, None)
auth.resolve_authconfig(self.auth_config, None)["username"] == "indexuser" assert ac is not None
) assert ac["username"] == "indexuser"
def test_resolve_authconfig_fully_explicit(self): def test_resolve_authconfig_fully_explicit(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")[ assert ac is not None
"username" assert ac["username"] == "privateuser"
]
== "privateuser"
)
def test_resolve_authconfig_legacy_config(self): def test_resolve_authconfig_legacy_config(self) -> None:
assert ( ac = auth.resolve_authconfig(self.auth_config, "legacy.registry.url")
auth.resolve_authconfig(self.auth_config, "legacy.registry.url")["username"] assert ac is not None
== "legacyauth" assert ac["username"] == "legacyauth"
)
def test_resolve_authconfig_no_match(self): def test_resolve_authconfig_no_match(self) -> None:
assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None
def test_resolve_registry_and_auth_library_image(self): def test_resolve_registry_and_auth_library_image(self) -> None:
image = "image" image = "image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_hub_image(self): def test_resolve_registry_and_auth_hub_image(self) -> None:
image = "username/image" image = "username/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_explicit_hub(self): def test_resolve_registry_and_auth_explicit_hub(self) -> None:
image = "docker.io/username/image" image = "docker.io/username/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_explicit_legacy_hub(self): def test_resolve_registry_and_auth_explicit_legacy_hub(self) -> None:
image = "index.docker.io/username/image" image = "index.docker.io/username/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "indexuser"
) )
assert ac is not None
assert ac["username"] == "indexuser"
def test_resolve_registry_and_auth_private_registry(self): def test_resolve_registry_and_auth_private_registry(self) -> None:
image = "my.registry.net/image" image = "my.registry.net/image"
assert ( ac = auth.resolve_authconfig(
auth.resolve_authconfig( self.auth_config, auth.resolve_repository_name(image)[0]
self.auth_config, auth.resolve_repository_name(image)[0]
)["username"]
== "privateuser"
) )
assert ac is not None
assert ac["username"] == "privateuser"
def test_resolve_registry_and_auth_unauthenticated_registry(self): def test_resolve_registry_and_auth_unauthenticated_registry(self) -> None:
image = "other.registry.net/image" image = "other.registry.net/image"
assert ( assert (
auth.resolve_authconfig( auth.resolve_authconfig(
@ -262,7 +239,7 @@ class ResolveAuthTest(unittest.TestCase):
is None is None
) )
def test_resolve_auth_with_empty_credstore_and_auth_dict(self): def test_resolve_auth_with_empty_credstore_and_auth_dict(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"auths": auth.parse_auth( "auths": auth.parse_auth(
@ -277,17 +254,19 @@ class ResolveAuthTest(unittest.TestCase):
"ansible_collections.community.docker.plugins.module_utils._api.auth.AuthConfig._resolve_authconfig_credstore" "ansible_collections.community.docker.plugins.module_utils._api.auth.AuthConfig._resolve_authconfig_credstore"
) as m: ) as m:
m.return_value = None m.return_value = None
assert "indexuser" == auth.resolve_authconfig(auth_config, None)["username"] ac = auth.resolve_authconfig(auth_config, None)
assert ac is not None
assert "indexuser" == ac["username"]
class LoadConfigTest(unittest.TestCase): class LoadConfigTest(unittest.TestCase):
def test_load_config_no_file(self): def test_load_config_no_file(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg = auth.load_config(folder) cfg = auth.load_config(folder)
assert cfg is not None assert cfg is not None
def test_load_legacy_config(self): def test_load_legacy_config(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg_path = os.path.join(folder, ".dockercfg") cfg_path = os.path.join(folder, ".dockercfg")
@ -299,13 +278,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(cfg_path) cfg = auth.load_config(cfg_path)
assert auth.resolve_authconfig(cfg) is not None assert auth.resolve_authconfig(cfg) is not None
assert cfg.auths[auth.INDEX_NAME] is not None assert cfg.auths[auth.INDEX_NAME] is not None
cfg = cfg.auths[auth.INDEX_NAME] cfg2 = cfg.auths[auth.INDEX_NAME]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("Auth") is None assert cfg2.get("Auth") is None
def test_load_json_config(self): def test_load_json_config(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg_path = os.path.join(folder, ".dockercfg") cfg_path = os.path.join(folder, ".dockercfg")
@ -316,13 +295,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(cfg_path) cfg = auth.load_config(cfg_path)
assert auth.resolve_authconfig(cfg) is not None assert auth.resolve_authconfig(cfg) is not None
assert cfg.auths[auth.INDEX_URL] is not None assert cfg.auths[auth.INDEX_URL] is not None
cfg = cfg.auths[auth.INDEX_URL] cfg2 = cfg.auths[auth.INDEX_URL]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == email assert cfg2["email"] == email
assert cfg.get("Auth") is None assert cfg2.get("Auth") is None
def test_load_modern_json_config(self): def test_load_modern_json_config(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg_path = os.path.join(folder, "config.json") cfg_path = os.path.join(folder, "config.json")
@ -333,12 +312,12 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(cfg_path) cfg = auth.load_config(cfg_path)
assert auth.resolve_authconfig(cfg) is not None assert auth.resolve_authconfig(cfg) is not None
assert cfg.auths[auth.INDEX_URL] is not None assert cfg.auths[auth.INDEX_URL] is not None
cfg = cfg.auths[auth.INDEX_URL] cfg2 = cfg.auths[auth.INDEX_URL]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == email assert cfg2["email"] == email
def test_load_config_with_random_name(self): def test_load_config_with_random_name(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -353,13 +332,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path).auths cfg = auth.load_config(dockercfg_path).auths
assert registry in cfg assert registry in cfg
assert cfg[registry] is not None assert cfg[registry] is not None
cfg = cfg[registry] cfg2 = cfg[registry]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_custom_config_env(self): def test_load_config_custom_config_env(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -375,13 +354,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(None).auths cfg = auth.load_config(None).auths
assert registry in cfg assert registry in cfg
assert cfg[registry] is not None assert cfg[registry] is not None
cfg = cfg[registry] cfg2 = cfg[registry]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_custom_config_env_with_auths(self): def test_load_config_custom_config_env_with_auths(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -398,13 +377,13 @@ class LoadConfigTest(unittest.TestCase):
with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}):
cfg = auth.load_config(None) cfg = auth.load_config(None)
assert registry in cfg.auths assert registry in cfg.auths
cfg = cfg.auths[registry] cfg2 = cfg.auths[registry]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_custom_config_env_utf8(self): def test_load_config_custom_config_env_utf8(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -421,13 +400,13 @@ class LoadConfigTest(unittest.TestCase):
with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}):
cfg = auth.load_config(None) cfg = auth.load_config(None)
assert registry in cfg.auths assert registry in cfg.auths
cfg = cfg.auths[registry] cfg2 = cfg.auths[registry]
assert cfg["username"] == b"sakuya\xc3\xa6".decode("utf8") assert cfg2["username"] == b"sakuya\xc3\xa6".decode("utf8")
assert cfg["password"] == b"izayoi\xc3\xa6".decode("utf8") assert cfg2["password"] == b"izayoi\xc3\xa6".decode("utf8")
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_unknown_keys(self): def test_load_config_unknown_keys(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")
@ -438,7 +417,7 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path) cfg = auth.load_config(dockercfg_path)
assert dict(cfg) == {"auths": {}} assert dict(cfg) == {"auths": {}}
def test_load_config_invalid_auth_dict(self): def test_load_config_invalid_auth_dict(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")
@ -449,7 +428,7 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path) cfg = auth.load_config(dockercfg_path)
assert dict(cfg) == {"auths": {"scarlet.net": {}}} assert dict(cfg) == {"auths": {"scarlet.net": {}}}
def test_load_config_identity_token(self): def test_load_config_identity_token(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
registry = "scarlet.net" registry = "scarlet.net"
token = "1ce1cebb-503e-7043-11aa-7feb8bd4a1ce" token = "1ce1cebb-503e-7043-11aa-7feb8bd4a1ce"
@ -462,13 +441,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path) cfg = auth.load_config(dockercfg_path)
assert registry in cfg.auths assert registry in cfg.auths
cfg = cfg.auths[registry] cfg2 = cfg.auths[registry]
assert "IdentityToken" in cfg assert "IdentityToken" in cfg2
assert cfg["IdentityToken"] == token assert cfg2["IdentityToken"] == token
class CredstoreTest(unittest.TestCase): class CredstoreTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.authconfig = auth.AuthConfig({"credsStore": "default"}) self.authconfig = auth.AuthConfig({"credsStore": "default"})
self.default_store = InMemoryStore("default") self.default_store = InMemoryStore("default")
self.authconfig._stores["default"] = self.default_store self.authconfig._stores["default"] = self.default_store
@ -483,7 +462,7 @@ class CredstoreTest(unittest.TestCase):
"hunter2", "hunter2",
) )
def test_get_credential_store(self): def test_get_credential_store(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"credHelpers": { "credHelpers": {
@ -498,7 +477,7 @@ class CredstoreTest(unittest.TestCase):
assert auth_config.get_credential_store("registry2.io") == "powerlock" assert auth_config.get_credential_store("registry2.io") == "powerlock"
assert auth_config.get_credential_store("registry3.io") == "blackbox" assert auth_config.get_credential_store("registry3.io") == "blackbox"
def test_get_credential_store_no_default(self): def test_get_credential_store_no_default(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"credHelpers": { "credHelpers": {
@ -510,7 +489,7 @@ class CredstoreTest(unittest.TestCase):
assert auth_config.get_credential_store("registry2.io") == "powerlock" assert auth_config.get_credential_store("registry2.io") == "powerlock"
assert auth_config.get_credential_store("registry3.io") is None assert auth_config.get_credential_store("registry3.io") is None
def test_get_credential_store_default_index(self): def test_get_credential_store_default_index(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"credHelpers": {"https://index.docker.io/v1/": "powerlock"}, "credHelpers": {"https://index.docker.io/v1/": "powerlock"},
@ -522,7 +501,7 @@ class CredstoreTest(unittest.TestCase):
assert auth_config.get_credential_store("docker.io") == "powerlock" assert auth_config.get_credential_store("docker.io") == "powerlock"
assert auth_config.get_credential_store("images.io") == "truesecret" assert auth_config.get_credential_store("images.io") == "truesecret"
def test_get_credential_store_with_plain_dict(self): def test_get_credential_store_with_plain_dict(self) -> None:
auth_config = { auth_config = {
"credHelpers": {"registry1.io": "truesecret", "registry2.io": "powerlock"}, "credHelpers": {"registry1.io": "truesecret", "registry2.io": "powerlock"},
"credsStore": "blackbox", "credsStore": "blackbox",
@ -532,7 +511,7 @@ class CredstoreTest(unittest.TestCase):
assert auth.get_credential_store(auth_config, "registry2.io") == "powerlock" assert auth.get_credential_store(auth_config, "registry2.io") == "powerlock"
assert auth.get_credential_store(auth_config, "registry3.io") == "blackbox" assert auth.get_credential_store(auth_config, "registry3.io") == "blackbox"
def test_get_all_credentials_credstore_only(self): def test_get_all_credentials_credstore_only(self) -> None:
assert self.authconfig.get_all_credentials() == { assert self.authconfig.get_all_credentials() == {
"https://gensokyo.jp/v2": { "https://gensokyo.jp/v2": {
"Username": "sakuya", "Username": "sakuya",
@ -556,7 +535,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_with_empty_credhelper(self): def test_get_all_credentials_with_empty_credhelper(self) -> None:
self.authconfig["credHelpers"] = { self.authconfig["credHelpers"] = {
"registry1.io": "truesecret", "registry1.io": "truesecret",
} }
@ -585,7 +564,7 @@ class CredstoreTest(unittest.TestCase):
"registry1.io": None, "registry1.io": None,
} }
def test_get_all_credentials_with_credhelpers_only(self): def test_get_all_credentials_with_credhelpers_only(self) -> None:
del self.authconfig["credsStore"] del self.authconfig["credsStore"]
assert self.authconfig.get_all_credentials() == {} assert self.authconfig.get_all_credentials() == {}
@ -617,7 +596,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_with_auths_entries(self): def test_get_all_credentials_with_auths_entries(self) -> None:
self.authconfig.add_auth( self.authconfig.add_auth(
"registry1.io", "registry1.io",
{ {
@ -655,7 +634,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_with_empty_auths_entry(self): def test_get_all_credentials_with_empty_auths_entry(self) -> None:
self.authconfig.add_auth("default.com", {}) self.authconfig.add_auth("default.com", {})
assert self.authconfig.get_all_credentials() == { assert self.authconfig.get_all_credentials() == {
@ -681,7 +660,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_credstore_overrides_auth_entry(self): def test_get_all_credentials_credstore_overrides_auth_entry(self) -> None:
self.authconfig.add_auth( self.authconfig.add_auth(
"default.com", "default.com",
{ {
@ -714,7 +693,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_helpers_override_default(self): def test_get_all_credentials_helpers_override_default(self) -> None:
self.authconfig["credHelpers"] = { self.authconfig["credHelpers"] = {
"https://default.com/v2": "truesecret", "https://default.com/v2": "truesecret",
} }
@ -744,7 +723,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_3_sources(self): def test_get_all_credentials_3_sources(self) -> None:
self.authconfig["credHelpers"] = { self.authconfig["credHelpers"] = {
"registry1.io": "truesecret", "registry1.io": "truesecret",
} }
@ -795,24 +774,27 @@ class CredstoreTest(unittest.TestCase):
class InMemoryStore(Store): class InMemoryStore(Store):
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self.__store = {} self, *args: t.Any, **kwargs: t.Any
) -> None:
self.__store: dict[str | bytes, dict[str, t.Any]] = {}
def get(self, server): def get(self, server: str | bytes) -> dict[str, t.Any]:
try: try:
return self.__store[server] return self.__store[server]
except KeyError: except KeyError:
raise CredentialsNotFound() from None raise CredentialsNotFound() from None
def store(self, server, username, secret): def store(self, server: str, username: str, secret: str) -> bytes:
self.__store[server] = { self.__store[server] = {
"ServerURL": server, "ServerURL": server,
"Username": username, "Username": username,
"Secret": secret, "Secret": secret,
} }
return b""
def list(self): def list(self) -> dict[str | bytes, str]:
return dict((k, v["Username"]) for k, v in self.__store.items()) return dict((k, v["Username"]) for k, v in self.__store.items())
def erase(self, server): def erase(self, server: str | bytes) -> None:
del self.__store[server] del self.__store[server]

View File

@ -28,20 +28,20 @@ from ansible_collections.community.docker.plugins.module_utils._api.context.cont
class BaseContextTest(unittest.TestCase): class BaseContextTest(unittest.TestCase):
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="Linux specific path check") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="Linux specific path check")
def test_url_compatibility_on_linux(self): def test_url_compatibility_on_linux(self) -> None:
c = Context("test") c = Context("test")
assert c.Host == DEFAULT_UNIX_SOCKET[5:] assert c.Host == DEFAULT_UNIX_SOCKET[5:]
@pytest.mark.skipif(not IS_WINDOWS_PLATFORM, reason="Windows specific path check") @pytest.mark.skipif(not IS_WINDOWS_PLATFORM, reason="Windows specific path check")
def test_url_compatibility_on_windows(self): def test_url_compatibility_on_windows(self) -> None:
c = Context("test") c = Context("test")
assert c.Host == DEFAULT_NPIPE assert c.Host == DEFAULT_NPIPE
def test_fail_on_default_context_create(self): def test_fail_on_default_context_create(self) -> None:
with pytest.raises(errors.ContextException): with pytest.raises(errors.ContextException):
ContextAPI.create_context("default") ContextAPI.create_context("default")
def test_default_in_context_list(self): def test_default_in_context_list(self) -> None:
found = False found = False
ctx = ContextAPI.contexts() ctx = ContextAPI.contexts()
for c in ctx: for c in ctx:
@ -49,14 +49,16 @@ class BaseContextTest(unittest.TestCase):
found = True found = True
assert found is True assert found is True
def test_get_current_context(self): def test_get_current_context(self) -> None:
assert ContextAPI.get_current_context().Name == "default" context = ContextAPI.get_current_context()
assert context is not None
assert context.Name == "default"
def test_https_host(self): def test_https_host(self) -> None:
c = Context("test", host="tcp://testdomain:8080", tls=True) c = Context("test", host="tcp://testdomain:8080", tls=True)
assert c.Host == "https://testdomain:8080" assert c.Host == "https://testdomain:8080"
def test_context_inspect_without_params(self): def test_context_inspect_without_params(self) -> None:
ctx = ContextAPI.inspect_context() ctx = ContextAPI.inspect_context()
assert ctx["Name"] == "default" assert ctx["Name"] == "default"
assert ctx["Metadata"]["StackOrchestrator"] == "swarm" assert ctx["Metadata"]["StackOrchestrator"] == "swarm"

View File

@ -21,97 +21,97 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
class APIErrorTest(unittest.TestCase): class APIErrorTest(unittest.TestCase):
def test_api_error_is_caught_by_dockerexception(self): def test_api_error_is_caught_by_dockerexception(self) -> None:
try: try:
raise APIError("this should be caught by DockerException") raise APIError("this should be caught by DockerException")
except DockerException: except DockerException:
pass pass
def test_status_code_200(self): def test_status_code_200(self) -> None:
"""The status_code property is present with 200 response.""" """The status_code property is present with 200 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 200 resp.status_code = 200
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.status_code == 200 assert err.status_code == 200
def test_status_code_400(self): def test_status_code_400(self) -> None:
"""The status_code property is present with 400 response.""" """The status_code property is present with 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.status_code == 400 assert err.status_code == 400
def test_status_code_500(self): def test_status_code_500(self) -> None:
"""The status_code property is present with 500 response.""" """The status_code property is present with 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.status_code == 500 assert err.status_code == 500
def test_is_server_error_200(self): def test_is_server_error_200(self) -> None:
"""Report not server error on 200 response.""" """Report not server error on 200 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 200 resp.status_code = 200
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is False assert err.is_server_error() is False
def test_is_server_error_300(self): def test_is_server_error_300(self) -> None:
"""Report not server error on 300 response.""" """Report not server error on 300 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 300 resp.status_code = 300
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is False assert err.is_server_error() is False
def test_is_server_error_400(self): def test_is_server_error_400(self) -> None:
"""Report not server error on 400 response.""" """Report not server error on 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is False assert err.is_server_error() is False
def test_is_server_error_500(self): def test_is_server_error_500(self) -> None:
"""Report server error on 500 response.""" """Report server error on 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is True assert err.is_server_error() is True
def test_is_client_error_500(self): def test_is_client_error_500(self) -> None:
"""Report not client error on 500 response.""" """Report not client error on 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_client_error() is False assert err.is_client_error() is False
def test_is_client_error_400(self): def test_is_client_error_400(self) -> None:
"""Report client error on 400 response.""" """Report client error on 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_client_error() is True assert err.is_client_error() is True
def test_is_error_300(self): def test_is_error_300(self) -> None:
"""Report no error on 300 response.""" """Report no error on 300 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 300 resp.status_code = 300
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_error() is False assert err.is_error() is False
def test_is_error_400(self): def test_is_error_400(self) -> None:
"""Report error on 400 response.""" """Report error on 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_error() is True assert err.is_error() is True
def test_is_error_500(self): def test_is_error_500(self) -> None:
"""Report error on 500 response.""" """Report error on 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_error() is True assert err.is_error() is True
def test_create_error_from_exception(self): def test_create_error_from_exception(self) -> None:
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("") err = APIError("")
@ -126,10 +126,10 @@ class APIErrorTest(unittest.TestCase):
class CreateUnexpectedKwargsErrorTest(unittest.TestCase): class CreateUnexpectedKwargsErrorTest(unittest.TestCase):
def test_create_unexpected_kwargs_error_single(self): def test_create_unexpected_kwargs_error_single(self) -> None:
e = create_unexpected_kwargs_error("f", {"foo": "bar"}) e = create_unexpected_kwargs_error("f", {"foo": "bar"})
assert str(e) == "f() got an unexpected keyword argument 'foo'" assert str(e) == "f() got an unexpected keyword argument 'foo'"
def test_create_unexpected_kwargs_error_multiple(self): def test_create_unexpected_kwargs_error_multiple(self) -> None:
e = create_unexpected_kwargs_error("f", {"foo": "bar", "baz": "bosh"}) e = create_unexpected_kwargs_error("f", {"foo": "bar", "baz": "bosh"})
assert str(e) == "f() got unexpected keyword arguments 'baz', 'foo'" assert str(e) == "f() got unexpected keyword arguments 'baz', 'foo'"

View File

@ -18,33 +18,33 @@ from ansible_collections.community.docker.plugins.module_utils._api.transport.ss
class SSHAdapterTest(unittest.TestCase): class SSHAdapterTest(unittest.TestCase):
@staticmethod @staticmethod
def test_ssh_hostname_prefix_trim(): def test_ssh_hostname_prefix_trim() -> None:
conn = SSHHTTPAdapter(base_url="ssh://user@hostname:1234", shell_out=True) conn = SSHHTTPAdapter(base_url="ssh://user@hostname:1234", shell_out=True)
assert conn.ssh_host == "user@hostname:1234" assert conn.ssh_host == "user@hostname:1234"
@staticmethod @staticmethod
def test_ssh_parse_url(): def test_ssh_parse_url() -> None:
c = SSHSocket(host="user@hostname:1234") c = SSHSocket(host="user@hostname:1234")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port == "1234" assert c.port == "1234"
assert c.user == "user" assert c.user == "user"
@staticmethod @staticmethod
def test_ssh_parse_hostname_only(): def test_ssh_parse_hostname_only() -> None:
c = SSHSocket(host="hostname") c = SSHSocket(host="hostname")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port is None assert c.port is None
assert c.user is None assert c.user is None
@staticmethod @staticmethod
def test_ssh_parse_user_and_hostname(): def test_ssh_parse_user_and_hostname() -> None:
c = SSHSocket(host="user@hostname") c = SSHSocket(host="user@hostname")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port is None assert c.port is None
assert c.user == "user" assert c.user == "user"
@staticmethod @staticmethod
def test_ssh_parse_hostname_and_port(): def test_ssh_parse_hostname_and_port() -> None:
c = SSHSocket(host="hostname:22") c = SSHSocket(host="hostname:22")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port == "22" assert c.port == "22"

View File

@ -27,7 +27,7 @@ else:
class SSLAdapterTest(unittest.TestCase): class SSLAdapterTest(unittest.TestCase):
def test_only_uses_tls(self): def test_only_uses_tls(self) -> None:
ssl_context = ssladapter.urllib3.util.ssl_.create_urllib3_context() ssl_context = ssladapter.urllib3.util.ssl_.create_urllib3_context()
assert ssl_context.options & OP_NO_SSLv3 assert ssl_context.options & OP_NO_SSLv3
@ -68,19 +68,19 @@ class MatchHostnameTest(unittest.TestCase):
"version": 3, "version": 3,
} }
def test_match_ip_address_success(self): def test_match_ip_address_success(self) -> None:
assert match_hostname(self.cert, "127.0.0.1") is None assert match_hostname(self.cert, "127.0.0.1") is None
def test_match_localhost_success(self): def test_match_localhost_success(self) -> None:
assert match_hostname(self.cert, "localhost") is None assert match_hostname(self.cert, "localhost") is None
def test_match_dns_success(self): def test_match_dns_success(self) -> None:
assert match_hostname(self.cert, "touhou.gensokyo.jp") is None assert match_hostname(self.cert, "touhou.gensokyo.jp") is None
def test_match_ip_address_failure(self): def test_match_ip_address_failure(self) -> None:
with pytest.raises(CertificateError): with pytest.raises(CertificateError):
match_hostname(self.cert, "192.168.0.25") match_hostname(self.cert, "192.168.0.25")
def test_match_dns_failure(self): def test_match_dns_failure(self) -> None:
with pytest.raises(CertificateError): with pytest.raises(CertificateError):
match_hostname(self.cert, "foobar.co.uk") match_hostname(self.cert, "foobar.co.uk")

View File

@ -14,6 +14,7 @@ import shutil
import socket import socket
import tarfile import tarfile
import tempfile import tempfile
import typing as t
import unittest import unittest
import pytest import pytest
@ -27,7 +28,11 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.build
) )
def make_tree(dirs, files): if t.TYPE_CHECKING:
from collections.abc import Collection
def make_tree(dirs: list[str], files: list[str]) -> str:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
for path in dirs: for path in dirs:
@ -40,11 +45,11 @@ def make_tree(dirs, files):
return base return base
def convert_paths(collection): def convert_paths(collection: Collection[str]) -> set[str]:
return set(map(convert_path, collection)) return set(map(convert_path, collection))
def convert_path(path): def convert_path(path: str) -> str:
return path.replace("/", os.path.sep) return path.replace("/", os.path.sep)
@ -88,26 +93,26 @@ class ExcludePathsTest(unittest.TestCase):
all_paths = set(dirs + files) all_paths = set(dirs + files)
def setUp(self): def setUp(self) -> None:
self.base = make_tree(self.dirs, self.files) self.base = make_tree(self.dirs, self.files)
def tearDown(self): def tearDown(self) -> None:
shutil.rmtree(self.base) shutil.rmtree(self.base)
def exclude(self, patterns, dockerfile=None): def exclude(self, patterns: list[str], dockerfile: str | None = None) -> set[str]:
return set(exclude_paths(self.base, patterns, dockerfile=dockerfile)) return set(exclude_paths(self.base, patterns, dockerfile=dockerfile))
def test_no_excludes(self): def test_no_excludes(self) -> None:
assert self.exclude([""]) == convert_paths(self.all_paths) assert self.exclude([""]) == convert_paths(self.all_paths)
def test_no_dupes(self): def test_no_dupes(self) -> None:
paths = exclude_paths(self.base, ["!a.py"]) paths = exclude_paths(self.base, ["!a.py"])
assert sorted(paths) == sorted(set(paths)) assert sorted(paths) == sorted(set(paths))
def test_wildcard_exclude(self): def test_wildcard_exclude(self) -> None:
assert self.exclude(["*"]) == set(["Dockerfile", ".dockerignore"]) assert self.exclude(["*"]) == set(["Dockerfile", ".dockerignore"])
def test_exclude_dockerfile_dockerignore(self): def test_exclude_dockerfile_dockerignore(self) -> None:
""" """
Even if the .dockerignore file explicitly says to exclude Even if the .dockerignore file explicitly says to exclude
Dockerfile and/or .dockerignore, don't exclude them from Dockerfile and/or .dockerignore, don't exclude them from
@ -117,7 +122,7 @@ class ExcludePathsTest(unittest.TestCase):
self.all_paths self.all_paths
) )
def test_exclude_custom_dockerfile(self): def test_exclude_custom_dockerfile(self) -> None:
""" """
If we're using a custom Dockerfile, make sure that's not If we're using a custom Dockerfile, make sure that's not
excluded. excluded.
@ -135,33 +140,33 @@ class ExcludePathsTest(unittest.TestCase):
set(["foo/Dockerfile3", ".dockerignore"]) set(["foo/Dockerfile3", ".dockerignore"])
) )
def test_exclude_dockerfile_child(self): def test_exclude_dockerfile_child(self) -> None:
includes = self.exclude(["foo/"], dockerfile="foo/Dockerfile3") includes = self.exclude(["foo/"], dockerfile="foo/Dockerfile3")
assert convert_path("foo/Dockerfile3") in includes assert convert_path("foo/Dockerfile3") in includes
assert convert_path("foo/a.py") not in includes assert convert_path("foo/a.py") not in includes
def test_single_filename(self): def test_single_filename(self) -> None:
assert self.exclude(["a.py"]) == convert_paths(self.all_paths - set(["a.py"])) assert self.exclude(["a.py"]) == convert_paths(self.all_paths - set(["a.py"]))
def test_single_filename_leading_dot_slash(self): def test_single_filename_leading_dot_slash(self) -> None:
assert self.exclude(["./a.py"]) == convert_paths(self.all_paths - set(["a.py"])) assert self.exclude(["./a.py"]) == convert_paths(self.all_paths - set(["a.py"]))
# As odd as it sounds, a filename pattern with a trailing slash on the # As odd as it sounds, a filename pattern with a trailing slash on the
# end *will* result in that file being excluded. # end *will* result in that file being excluded.
def test_single_filename_trailing_slash(self): def test_single_filename_trailing_slash(self) -> None:
assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"])) assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"]))
def test_wildcard_filename_start(self): def test_wildcard_filename_start(self) -> None:
assert self.exclude(["*.py"]) == convert_paths( assert self.exclude(["*.py"]) == convert_paths(
self.all_paths - set(["a.py", "b.py", "cde.py"]) self.all_paths - set(["a.py", "b.py", "cde.py"])
) )
def test_wildcard_with_exception(self): def test_wildcard_with_exception(self) -> None:
assert self.exclude(["*.py", "!b.py"]) == convert_paths( assert self.exclude(["*.py", "!b.py"]) == convert_paths(
self.all_paths - set(["a.py", "cde.py"]) self.all_paths - set(["a.py", "cde.py"])
) )
def test_wildcard_with_wildcard_exception(self): def test_wildcard_with_wildcard_exception(self) -> None:
assert self.exclude(["*.*", "!*.go"]) == convert_paths( assert self.exclude(["*.*", "!*.go"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -174,51 +179,51 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_wildcard_filename_end(self): def test_wildcard_filename_end(self) -> None:
assert self.exclude(["a.*"]) == convert_paths( assert self.exclude(["a.*"]) == convert_paths(
self.all_paths - set(["a.py", "a.go"]) self.all_paths - set(["a.py", "a.go"])
) )
def test_question_mark(self): def test_question_mark(self) -> None:
assert self.exclude(["?.py"]) == convert_paths( assert self.exclude(["?.py"]) == convert_paths(
self.all_paths - set(["a.py", "b.py"]) self.all_paths - set(["a.py", "b.py"])
) )
def test_single_subdir_single_filename(self): def test_single_subdir_single_filename(self) -> None:
assert self.exclude(["foo/a.py"]) == convert_paths( assert self.exclude(["foo/a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py"]) self.all_paths - set(["foo/a.py"])
) )
def test_single_subdir_single_filename_leading_slash(self): def test_single_subdir_single_filename_leading_slash(self) -> None:
assert self.exclude(["/foo/a.py"]) == convert_paths( assert self.exclude(["/foo/a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py"]) self.all_paths - set(["foo/a.py"])
) )
def test_exclude_include_absolute_path(self): def test_exclude_include_absolute_path(self) -> None:
base = make_tree([], ["a.py", "b.py"]) base = make_tree([], ["a.py", "b.py"])
assert exclude_paths(base, ["/*", "!/*.py"]) == set(["a.py", "b.py"]) assert exclude_paths(base, ["/*", "!/*.py"]) == set(["a.py", "b.py"])
def test_single_subdir_with_path_traversal(self): def test_single_subdir_with_path_traversal(self) -> None:
assert self.exclude(["foo/whoops/../a.py"]) == convert_paths( assert self.exclude(["foo/whoops/../a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py"]) self.all_paths - set(["foo/a.py"])
) )
def test_single_subdir_wildcard_filename(self): def test_single_subdir_wildcard_filename(self) -> None:
assert self.exclude(["foo/*.py"]) == convert_paths( assert self.exclude(["foo/*.py"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py"]) self.all_paths - set(["foo/a.py", "foo/b.py"])
) )
def test_wildcard_subdir_single_filename(self): def test_wildcard_subdir_single_filename(self) -> None:
assert self.exclude(["*/a.py"]) == convert_paths( assert self.exclude(["*/a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py", "bar/a.py"]) self.all_paths - set(["foo/a.py", "bar/a.py"])
) )
def test_wildcard_subdir_wildcard_filename(self): def test_wildcard_subdir_wildcard_filename(self) -> None:
assert self.exclude(["*/*.py"]) == convert_paths( assert self.exclude(["*/*.py"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py", "bar/a.py"]) self.all_paths - set(["foo/a.py", "foo/b.py", "bar/a.py"])
) )
def test_directory(self): def test_directory(self) -> None:
assert self.exclude(["foo"]) == convert_paths( assert self.exclude(["foo"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -233,7 +238,7 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_directory_with_trailing_slash(self): def test_directory_with_trailing_slash(self) -> None:
assert self.exclude(["foo"]) == convert_paths( assert self.exclude(["foo"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -248,13 +253,13 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_directory_with_single_exception(self): def test_directory_with_single_exception(self) -> None:
assert self.exclude(["foo", "!foo/bar/a.py"]) == convert_paths( assert self.exclude(["foo", "!foo/bar/a.py"]) == convert_paths(
self.all_paths self.all_paths
- set(["foo/a.py", "foo/b.py", "foo", "foo/bar", "foo/Dockerfile3"]) - set(["foo/a.py", "foo/b.py", "foo", "foo/bar", "foo/Dockerfile3"])
) )
def test_directory_with_subdir_exception(self): def test_directory_with_subdir_exception(self) -> None:
assert self.exclude(["foo", "!foo/bar"]) == convert_paths( assert self.exclude(["foo", "!foo/bar"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"]) self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"])
) )
@ -262,17 +267,17 @@ class ExcludePathsTest(unittest.TestCase):
@pytest.mark.skipif( @pytest.mark.skipif(
not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows" not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows"
) )
def test_directory_with_subdir_exception_win32_pathsep(self): def test_directory_with_subdir_exception_win32_pathsep(self) -> None:
assert self.exclude(["foo", "!foo\\bar"]) == convert_paths( assert self.exclude(["foo", "!foo\\bar"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"]) self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"])
) )
def test_directory_with_wildcard_exception(self): def test_directory_with_wildcard_exception(self) -> None:
assert self.exclude(["foo", "!foo/*.py"]) == convert_paths( assert self.exclude(["foo", "!foo/*.py"]) == convert_paths(
self.all_paths - set(["foo/bar", "foo/bar/a.py", "foo", "foo/Dockerfile3"]) self.all_paths - set(["foo/bar", "foo/bar/a.py", "foo", "foo/Dockerfile3"])
) )
def test_subdirectory(self): def test_subdirectory(self) -> None:
assert self.exclude(["foo/bar"]) == convert_paths( assert self.exclude(["foo/bar"]) == convert_paths(
self.all_paths - set(["foo/bar", "foo/bar/a.py"]) self.all_paths - set(["foo/bar", "foo/bar/a.py"])
) )
@ -280,12 +285,12 @@ class ExcludePathsTest(unittest.TestCase):
@pytest.mark.skipif( @pytest.mark.skipif(
not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows" not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows"
) )
def test_subdirectory_win32_pathsep(self): def test_subdirectory_win32_pathsep(self) -> None:
assert self.exclude(["foo\\bar"]) == convert_paths( assert self.exclude(["foo\\bar"]) == convert_paths(
self.all_paths - set(["foo/bar", "foo/bar/a.py"]) self.all_paths - set(["foo/bar", "foo/bar/a.py"])
) )
def test_double_wildcard(self): def test_double_wildcard(self) -> None:
assert self.exclude(["**/a.py"]) == convert_paths( assert self.exclude(["**/a.py"]) == convert_paths(
self.all_paths - set(["a.py", "foo/a.py", "foo/bar/a.py", "bar/a.py"]) self.all_paths - set(["a.py", "foo/a.py", "foo/bar/a.py", "bar/a.py"])
) )
@ -294,7 +299,7 @@ class ExcludePathsTest(unittest.TestCase):
self.all_paths - set(["foo/bar", "foo/bar/a.py"]) self.all_paths - set(["foo/bar", "foo/bar/a.py"])
) )
def test_single_and_double_wildcard(self): def test_single_and_double_wildcard(self) -> None:
assert self.exclude(["**/target/*/*"]) == convert_paths( assert self.exclude(["**/target/*/*"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -306,7 +311,7 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_trailing_double_wildcard(self): def test_trailing_double_wildcard(self) -> None:
assert self.exclude(["subdir/**"]) == convert_paths( assert self.exclude(["subdir/**"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -326,7 +331,7 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_double_wildcard_with_exception(self): def test_double_wildcard_with_exception(self) -> None:
assert self.exclude(["**", "!bar", "!foo/bar"]) == convert_paths( assert self.exclude(["**", "!bar", "!foo/bar"]) == convert_paths(
set( set(
[ [
@ -340,13 +345,13 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_include_wildcard(self): def test_include_wildcard(self) -> None:
# This may be surprising but it matches the CLI's behavior # This may be surprising but it matches the CLI's behavior
# (tested with 18.05.0-ce on linux) # (tested with 18.05.0-ce on linux)
base = make_tree(["a"], ["a/b.py"]) base = make_tree(["a"], ["a/b.py"])
assert exclude_paths(base, ["*", "!*/b.py"]) == set() assert exclude_paths(base, ["*", "!*/b.py"]) == set()
def test_last_line_precedence(self): def test_last_line_precedence(self) -> None:
base = make_tree( base = make_tree(
[], [],
[ [
@ -361,7 +366,7 @@ class ExcludePathsTest(unittest.TestCase):
["README.md", "README-bis.md"] ["README.md", "README-bis.md"]
) )
def test_parent_directory(self): def test_parent_directory(self) -> None:
base = make_tree([], ["a.py", "b.py", "c.py"]) base = make_tree([], ["a.py", "b.py", "c.py"])
# Dockerignore reference stipulates that absolute paths are # Dockerignore reference stipulates that absolute paths are
# equivalent to relative paths, hence /../foo should be # equivalent to relative paths, hence /../foo should be
@ -372,7 +377,7 @@ class ExcludePathsTest(unittest.TestCase):
class TarTest(unittest.TestCase): class TarTest(unittest.TestCase):
def test_tar_with_excludes(self): def test_tar_with_excludes(self) -> None:
dirs = [ dirs = [
"foo", "foo",
"foo/bar", "foo/bar",
@ -420,7 +425,7 @@ class TarTest(unittest.TestCase):
with tarfile.open(fileobj=archive) as tar_data: with tarfile.open(fileobj=archive) as tar_data:
assert sorted(tar_data.getnames()) == sorted(expected_names) assert sorted(tar_data.getnames()) == sorted(expected_names)
def test_tar_with_empty_directory(self): def test_tar_with_empty_directory(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -433,7 +438,7 @@ class TarTest(unittest.TestCase):
IS_WINDOWS_PLATFORM or os.geteuid() == 0, IS_WINDOWS_PLATFORM or os.geteuid() == 0,
reason="root user always has access ; no chmod on Windows", reason="root user always has access ; no chmod on Windows",
) )
def test_tar_with_inaccessible_file(self): def test_tar_with_inaccessible_file(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
full_path = os.path.join(base, "foo") full_path = os.path.join(base, "foo")
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
@ -446,7 +451,7 @@ class TarTest(unittest.TestCase):
assert f"Can not read file in context: {full_path}" in ei.exconly() assert f"Can not read file in context: {full_path}" in ei.exconly()
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_with_file_symlinks(self): def test_tar_with_file_symlinks(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
with open(os.path.join(base, "foo"), "wt", encoding="utf-8") as f: with open(os.path.join(base, "foo"), "wt", encoding="utf-8") as f:
@ -458,7 +463,7 @@ class TarTest(unittest.TestCase):
assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"]
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_with_directory_symlinks(self): def test_tar_with_directory_symlinks(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -469,7 +474,7 @@ class TarTest(unittest.TestCase):
assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"]
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_with_broken_symlinks(self): def test_tar_with_broken_symlinks(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -481,7 +486,7 @@ class TarTest(unittest.TestCase):
assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"]
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No UNIX sockets on Win32") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No UNIX sockets on Win32")
def test_tar_socket_file(self): def test_tar_socket_file(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -493,7 +498,7 @@ class TarTest(unittest.TestCase):
with tarfile.open(fileobj=archive) as tar_data: with tarfile.open(fileobj=archive) as tar_data:
assert sorted(tar_data.getnames()) == ["bar", "foo"] assert sorted(tar_data.getnames()) == ["bar", "foo"]
def tar_test_negative_mtime_bug(self): def tar_test_negative_mtime_bug(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
filename = os.path.join(base, "th.txt") filename = os.path.join(base, "th.txt")
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
@ -506,7 +511,7 @@ class TarTest(unittest.TestCase):
assert tar_data.getmember("th.txt").mtime == -3600 assert tar_data.getmember("th.txt").mtime == -3600
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_directory_link(self): def test_tar_directory_link(self) -> None:
dirs = ["a", "b", "a/c"] dirs = ["a", "b", "a/c"]
files = ["a/hello.py", "b/utils.py", "a/c/descend.py"] files = ["a/hello.py", "b/utils.py", "a/c/descend.py"]
base = make_tree(dirs, files) base = make_tree(dirs, files)

View File

@ -12,6 +12,7 @@ import json
import os import os
import shutil import shutil
import tempfile import tempfile
import typing as t
import unittest import unittest
from collections.abc import Callable from collections.abc import Callable
from unittest import mock from unittest import mock
@ -25,55 +26,55 @@ class FindConfigFileTest(unittest.TestCase):
mkdir: Callable[[str], os.PathLike[str]] mkdir: Callable[[str], os.PathLike[str]]
@fixture(autouse=True) @fixture(autouse=True)
def tmpdir(self, tmpdir): def tmpdir(self, tmpdir: t.Any) -> None:
self.mkdir = tmpdir.mkdir self.mkdir = tmpdir.mkdir
def test_find_config_fallback(self): def test_find_config_fallback(self) -> None:
tmpdir = self.mkdir("test_find_config_fallback") tmpdir = self.mkdir("test_find_config_fallback")
with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}):
assert config.find_config_file() is None assert config.find_config_file() is None
def test_find_config_from_explicit_path(self): def test_find_config_from_explicit_path(self) -> None:
tmpdir = self.mkdir("test_find_config_from_explicit_path") tmpdir = self.mkdir("test_find_config_from_explicit_path")
config_path = tmpdir.ensure("my-config-file.json") config_path = tmpdir.ensure("my-config-file.json") # type: ignore[attr-defined]
assert config.find_config_file(str(config_path)) == str(config_path) assert config.find_config_file(str(config_path)) == str(config_path)
def test_find_config_from_environment(self): def test_find_config_from_environment(self) -> None:
tmpdir = self.mkdir("test_find_config_from_environment") tmpdir = self.mkdir("test_find_config_from_environment")
config_path = tmpdir.ensure("config.json") config_path = tmpdir.ensure("config.json") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"DOCKER_CONFIG": str(tmpdir)}): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
@mark.skipif("sys.platform == 'win32'") @mark.skipif("sys.platform == 'win32'")
def test_find_config_from_home_posix(self): def test_find_config_from_home_posix(self) -> None:
tmpdir = self.mkdir("test_find_config_from_home_posix") tmpdir = self.mkdir("test_find_config_from_home_posix")
config_path = tmpdir.ensure(".docker", "config.json") config_path = tmpdir.ensure(".docker", "config.json") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
@mark.skipif("sys.platform == 'win32'") @mark.skipif("sys.platform == 'win32'")
def test_find_config_from_home_legacy_name(self): def test_find_config_from_home_legacy_name(self) -> None:
tmpdir = self.mkdir("test_find_config_from_home_legacy_name") tmpdir = self.mkdir("test_find_config_from_home_legacy_name")
config_path = tmpdir.ensure(".dockercfg") config_path = tmpdir.ensure(".dockercfg") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
@mark.skipif("sys.platform != 'win32'") @mark.skipif("sys.platform != 'win32'")
def test_find_config_from_home_windows(self): def test_find_config_from_home_windows(self) -> None:
tmpdir = self.mkdir("test_find_config_from_home_windows") tmpdir = self.mkdir("test_find_config_from_home_windows")
config_path = tmpdir.ensure(".docker", "config.json") config_path = tmpdir.ensure(".docker", "config.json") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"USERPROFILE": str(tmpdir)}): with mock.patch.dict(os.environ, {"USERPROFILE": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
class LoadConfigTest(unittest.TestCase): class LoadConfigTest(unittest.TestCase):
def test_load_config_no_file(self): def test_load_config_no_file(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg = config.load_general_config(folder) cfg = config.load_general_config(folder)
@ -81,7 +82,7 @@ class LoadConfigTest(unittest.TestCase):
assert isinstance(cfg, dict) assert isinstance(cfg, dict)
assert not cfg assert not cfg
def test_load_config_custom_headers(self): def test_load_config_custom_headers(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -97,7 +98,7 @@ class LoadConfigTest(unittest.TestCase):
assert "HttpHeaders" in cfg assert "HttpHeaders" in cfg
assert cfg["HttpHeaders"] == {"Name": "Spike", "Surname": "Spiegel"} assert cfg["HttpHeaders"] == {"Name": "Spike", "Surname": "Spiegel"}
def test_load_config_detach_keys(self): def test_load_config_detach_keys(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")
@ -108,7 +109,7 @@ class LoadConfigTest(unittest.TestCase):
cfg = config.load_general_config(dockercfg_path) cfg = config.load_general_config(dockercfg_path)
assert cfg == config_data assert cfg == config_data
def test_load_config_from_env(self): def test_load_config_from_env(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")

View File

@ -8,6 +8,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import unittest import unittest
from ansible_collections.community.docker.plugins.module_utils._api.api.client import ( from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
@ -22,12 +23,12 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c
class DecoratorsTest(unittest.TestCase): class DecoratorsTest(unittest.TestCase):
def test_update_headers(self): def test_update_headers(self) -> None:
sample_headers = { sample_headers = {
"X-Docker-Locale": "en-US", "X-Docker-Locale": "en-US",
} }
def f(self, headers=None): def f(self: t.Any, headers: t.Any = None) -> t.Any:
return headers return headers
client = APIClient(version=DEFAULT_DOCKER_API_VERSION) client = APIClient(version=DEFAULT_DOCKER_API_VERSION)

View File

@ -8,6 +8,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.utils.json_stream import ( from ansible_collections.community.docker.plugins.module_utils._api.utils.json_stream import (
json_splitter, json_splitter,
json_stream, json_stream,
@ -15,41 +17,48 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.json_s
) )
class TestJsonSplitter: if t.TYPE_CHECKING:
T = t.TypeVar("T")
def test_json_splitter_no_object(self):
def create_generator(input_sequence: list[T]) -> t.Generator[T]:
yield from input_sequence
class TestJsonSplitter:
def test_json_splitter_no_object(self) -> None:
data = '{"foo": "bar' data = '{"foo": "bar'
assert json_splitter(data) is None assert json_splitter(data) is None
def test_json_splitter_with_object(self): def test_json_splitter_with_object(self) -> None:
data = '{"foo": "bar"}\n \n{"next": "obj"}' data = '{"foo": "bar"}\n \n{"next": "obj"}'
assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}') assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}')
def test_json_splitter_leading_whitespace(self): def test_json_splitter_leading_whitespace(self) -> None:
data = '\n \r{"foo": "bar"}\n\n {"next": "obj"}' data = '\n \r{"foo": "bar"}\n\n {"next": "obj"}'
assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}') assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}')
class TestStreamAsText: class TestStreamAsText:
def test_stream_with_non_utf_unicode_character(self) -> None:
def test_stream_with_non_utf_unicode_character(self): stream = create_generator([b"\xed\xf3\xf3"])
stream = [b"\xed\xf3\xf3"]
(output,) = stream_as_text(stream) (output,) = stream_as_text(stream)
assert output == "<EFBFBD><EFBFBD><EFBFBD>" assert output == "<EFBFBD><EFBFBD><EFBFBD>"
def test_stream_with_utf_character(self): def test_stream_with_utf_character(self) -> None:
stream = ["ěĝ".encode("utf-8")] stream = create_generator(["ěĝ".encode("utf-8")])
(output,) = stream_as_text(stream) (output,) = stream_as_text(stream)
assert output == "ěĝ" assert output == "ěĝ"
class TestJsonStream: class TestJsonStream:
def test_with_falsy_entries(self) -> None:
def test_with_falsy_entries(self): stream = create_generator(
stream = [ [
'{"one": "two"}\n{}\n', '{"one": "two"}\n{}\n',
"[1, 2, 3]\n[]\n", "[1, 2, 3]\n[]\n",
] ]
)
output = list(json_stream(stream)) output = list(json_stream(stream))
assert output == [ assert output == [
{"one": "two"}, {"one": "two"},
@ -58,7 +67,9 @@ class TestJsonStream:
[], [],
] ]
def test_with_leading_whitespace(self): def test_with_leading_whitespace(self) -> None:
stream = ['\n \r\n {"one": "two"}{"x": 1}', ' {"three": "four"}\t\t{"x": 2}'] stream = create_generator(
['\n \r\n {"one": "two"}{"x": 1}', ' {"three": "four"}\t\t{"x": 2}']
)
output = list(json_stream(stream)) output = list(json_stream(stream))
assert output == [{"one": "two"}, {"x": 1}, {"three": "four"}, {"x": 2}] assert output == [{"one": "two"}, {"x": 1}, {"three": "four"}, {"x": 2}]

View File

@ -19,132 +19,132 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.ports
class PortsTest(unittest.TestCase): class PortsTest(unittest.TestCase):
def test_split_port_with_host_ip(self): def test_split_port_with_host_ip(self) -> None:
internal_port, external_port = split_port("127.0.0.1:1000:2000") internal_port, external_port = split_port("127.0.0.1:1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("127.0.0.1", "1000")] assert external_port == [("127.0.0.1", "1000")]
def test_split_port_with_protocol(self): def test_split_port_with_protocol(self) -> None:
for protocol in ["tcp", "udp", "sctp"]: for protocol in ["tcp", "udp", "sctp"]:
internal_port, external_port = split_port("127.0.0.1:1000:2000/" + protocol) internal_port, external_port = split_port("127.0.0.1:1000:2000/" + protocol)
assert internal_port == ["2000/" + protocol] assert internal_port == ["2000/" + protocol]
assert external_port == [("127.0.0.1", "1000")] assert external_port == [("127.0.0.1", "1000")]
def test_split_port_with_host_ip_no_port(self): def test_split_port_with_host_ip_no_port(self) -> None:
internal_port, external_port = split_port("127.0.0.1::2000") internal_port, external_port = split_port("127.0.0.1::2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("127.0.0.1", None)] assert external_port == [("127.0.0.1", None)]
def test_split_port_range_with_host_ip_no_port(self): def test_split_port_range_with_host_ip_no_port(self) -> None:
internal_port, external_port = split_port("127.0.0.1::2000-2001") internal_port, external_port = split_port("127.0.0.1::2000-2001")
assert internal_port == ["2000", "2001"] assert internal_port == ["2000", "2001"]
assert external_port == [("127.0.0.1", None), ("127.0.0.1", None)] assert external_port == [("127.0.0.1", None), ("127.0.0.1", None)]
def test_split_port_with_host_port(self): def test_split_port_with_host_port(self) -> None:
internal_port, external_port = split_port("1000:2000") internal_port, external_port = split_port("1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == ["1000"] assert external_port == ["1000"]
def test_split_port_range_with_host_port(self): def test_split_port_range_with_host_port(self) -> None:
internal_port, external_port = split_port("1000-1001:2000-2001") internal_port, external_port = split_port("1000-1001:2000-2001")
assert internal_port == ["2000", "2001"] assert internal_port == ["2000", "2001"]
assert external_port == ["1000", "1001"] assert external_port == ["1000", "1001"]
def test_split_port_random_port_range_with_host_port(self): def test_split_port_random_port_range_with_host_port(self) -> None:
internal_port, external_port = split_port("1000-1001:2000") internal_port, external_port = split_port("1000-1001:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == ["1000-1001"] assert external_port == ["1000-1001"]
def test_split_port_no_host_port(self): def test_split_port_no_host_port(self) -> None:
internal_port, external_port = split_port("2000") internal_port, external_port = split_port("2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port is None assert external_port is None
def test_split_port_range_no_host_port(self): def test_split_port_range_no_host_port(self) -> None:
internal_port, external_port = split_port("2000-2001") internal_port, external_port = split_port("2000-2001")
assert internal_port == ["2000", "2001"] assert internal_port == ["2000", "2001"]
assert external_port is None assert external_port is None
def test_split_port_range_with_protocol(self): def test_split_port_range_with_protocol(self) -> None:
internal_port, external_port = split_port("127.0.0.1:1000-1001:2000-2001/udp") internal_port, external_port = split_port("127.0.0.1:1000-1001:2000-2001/udp")
assert internal_port == ["2000/udp", "2001/udp"] assert internal_port == ["2000/udp", "2001/udp"]
assert external_port == [("127.0.0.1", "1000"), ("127.0.0.1", "1001")] assert external_port == [("127.0.0.1", "1000"), ("127.0.0.1", "1001")]
def test_split_port_with_ipv6_address(self): def test_split_port_with_ipv6_address(self) -> None:
internal_port, external_port = split_port("2001:abcd:ef00::2:1000:2000") internal_port, external_port = split_port("2001:abcd:ef00::2:1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("2001:abcd:ef00::2", "1000")] assert external_port == [("2001:abcd:ef00::2", "1000")]
def test_split_port_with_ipv6_square_brackets_address(self): def test_split_port_with_ipv6_square_brackets_address(self) -> None:
internal_port, external_port = split_port("[2001:abcd:ef00::2]:1000:2000") internal_port, external_port = split_port("[2001:abcd:ef00::2]:1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("2001:abcd:ef00::2", "1000")] assert external_port == [("2001:abcd:ef00::2", "1000")]
def test_split_port_invalid(self): def test_split_port_invalid(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000:2000:tcp") split_port("0.0.0.0:1000:2000:tcp")
def test_split_port_invalid_protocol(self): def test_split_port_invalid_protocol(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000:2000/ftp") split_port("0.0.0.0:1000:2000/ftp")
def test_non_matching_length_port_ranges(self): def test_non_matching_length_port_ranges(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000-1010:2000-2002/tcp") split_port("0.0.0.0:1000-1010:2000-2002/tcp")
def test_port_and_range_invalid(self): def test_port_and_range_invalid(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000:2000-2002/tcp") split_port("0.0.0.0:1000:2000-2002/tcp")
def test_port_only_with_colon(self): def test_port_only_with_colon(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port(":80") split_port(":80")
def test_host_only_with_colon(self): def test_host_only_with_colon(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("localhost:") split_port("localhost:")
def test_with_no_container_port(self): def test_with_no_container_port(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("localhost:80:") split_port("localhost:80:")
def test_split_port_empty_string(self): def test_split_port_empty_string(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("") split_port("")
def test_split_port_non_string(self): def test_split_port_non_string(self) -> None:
assert split_port(1243) == (["1243"], None) assert split_port(1243) == (["1243"], None)
def test_build_port_bindings_with_one_port(self): def test_build_port_bindings_with_one_port(self) -> None:
port_bindings = build_port_bindings(["127.0.0.1:1000:1000"]) port_bindings = build_port_bindings(["127.0.0.1:1000:1000"])
assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["1000"] == [("127.0.0.1", "1000")]
def test_build_port_bindings_with_matching_internal_ports(self): def test_build_port_bindings_with_matching_internal_ports(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000:1000", "127.0.0.1:2000:1000"] ["127.0.0.1:1000:1000", "127.0.0.1:2000:1000"]
) )
assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")] assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")]
def test_build_port_bindings_with_nonmatching_internal_ports(self): def test_build_port_bindings_with_nonmatching_internal_ports(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"] ["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"]
) )
assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["1000"] == [("127.0.0.1", "1000")]
assert port_bindings["2000"] == [("127.0.0.1", "2000")] assert port_bindings["2000"] == [("127.0.0.1", "2000")]
def test_build_port_bindings_with_port_range(self): def test_build_port_bindings_with_port_range(self) -> None:
port_bindings = build_port_bindings(["127.0.0.1:1000-1001:1000-1001"]) port_bindings = build_port_bindings(["127.0.0.1:1000-1001:1000-1001"])
assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["1000"] == [("127.0.0.1", "1000")]
assert port_bindings["1001"] == [("127.0.0.1", "1001")] assert port_bindings["1001"] == [("127.0.0.1", "1001")]
def test_build_port_bindings_with_matching_internal_port_ranges(self): def test_build_port_bindings_with_matching_internal_port_ranges(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000-1001:1000-1001", "127.0.0.1:2000-2001:1000-1001"] ["127.0.0.1:1000-1001:1000-1001", "127.0.0.1:2000-2001:1000-1001"]
) )
assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")] assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")]
assert port_bindings["1001"] == [("127.0.0.1", "1001"), ("127.0.0.1", "2001")] assert port_bindings["1001"] == [("127.0.0.1", "1001"), ("127.0.0.1", "2001")]
def test_build_port_bindings_with_nonmatching_internal_port_ranges(self): def test_build_port_bindings_with_nonmatching_internal_port_ranges(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"] ["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"]
) )

View File

@ -33,8 +33,7 @@ ENV = {
class ProxyConfigTest(unittest.TestCase): class ProxyConfigTest(unittest.TestCase):
def test_from_dict(self) -> None:
def test_from_dict(self):
config = ProxyConfig.from_dict( config = ProxyConfig.from_dict(
{ {
"httpProxy": HTTP, "httpProxy": HTTP,
@ -48,7 +47,7 @@ class ProxyConfigTest(unittest.TestCase):
self.assertEqual(CONFIG.ftp, config.ftp) self.assertEqual(CONFIG.ftp, config.ftp)
self.assertEqual(CONFIG.no_proxy, config.no_proxy) self.assertEqual(CONFIG.no_proxy, config.no_proxy)
def test_new(self): def test_new(self) -> None:
config = ProxyConfig() config = ProxyConfig()
self.assertIsNone(config.http) self.assertIsNone(config.http)
self.assertIsNone(config.https) self.assertIsNone(config.https)
@ -61,22 +60,24 @@ class ProxyConfigTest(unittest.TestCase):
self.assertEqual(config.ftp, "c") self.assertEqual(config.ftp, "c")
self.assertEqual(config.no_proxy, "d") self.assertEqual(config.no_proxy, "d")
def test_truthiness(self): def test_truthiness(self) -> None:
assert not ProxyConfig() assert not ProxyConfig()
assert ProxyConfig(http="non-zero") assert ProxyConfig(http="non-zero")
assert ProxyConfig(https="non-zero") assert ProxyConfig(https="non-zero")
assert ProxyConfig(ftp="non-zero") assert ProxyConfig(ftp="non-zero")
assert ProxyConfig(no_proxy="non-zero") assert ProxyConfig(no_proxy="non-zero")
def test_environment(self): def test_environment(self) -> None:
self.assertDictEqual(CONFIG.get_environment(), ENV) self.assertDictEqual(CONFIG.get_environment(), ENV)
empty = ProxyConfig() empty = ProxyConfig()
self.assertDictEqual(empty.get_environment(), {}) self.assertDictEqual(empty.get_environment(), {})
def test_inject_proxy_environment(self): def test_inject_proxy_environment(self) -> None:
# Proxy config is non null, env is None. # Proxy config is non null, env is None.
envlist = CONFIG.inject_proxy_environment(None)
assert envlist is not None
self.assertSetEqual( self.assertSetEqual(
set(CONFIG.inject_proxy_environment(None)), set(envlist),
set(f"{k}={v}" for k, v in ENV.items()), set(f"{k}={v}" for k, v in ENV.items()),
) )

View File

@ -52,13 +52,15 @@ TEST_CERT_DIR = os.path.join(
class KwargsFromEnvTest(unittest.TestCase): class KwargsFromEnvTest(unittest.TestCase):
def setUp(self): os_environ: dict[str, str]
def setUp(self) -> None:
self.os_environ = os.environ.copy() self.os_environ = os.environ.copy()
def tearDown(self): def tearDown(self) -> None:
os.environ = self.os_environ os.environ = self.os_environ # type: ignore
def test_kwargs_from_env_empty(self): def test_kwargs_from_env_empty(self) -> None:
os.environ.update(DOCKER_HOST="", DOCKER_CERT_PATH="") os.environ.update(DOCKER_HOST="", DOCKER_CERT_PATH="")
os.environ.pop("DOCKER_TLS_VERIFY", None) os.environ.pop("DOCKER_TLS_VERIFY", None)
@ -66,7 +68,7 @@ class KwargsFromEnvTest(unittest.TestCase):
assert kwargs.get("base_url") is None assert kwargs.get("base_url") is None
assert kwargs.get("tls") is None assert kwargs.get("tls") is None
def test_kwargs_from_env_tls(self): def test_kwargs_from_env_tls(self) -> None:
os.environ.update( os.environ.update(
DOCKER_HOST="tcp://192.168.59.103:2376", DOCKER_HOST="tcp://192.168.59.103:2376",
DOCKER_CERT_PATH=TEST_CERT_DIR, DOCKER_CERT_PATH=TEST_CERT_DIR,
@ -90,7 +92,7 @@ class KwargsFromEnvTest(unittest.TestCase):
except TypeError as e: except TypeError as e:
self.fail(e) self.fail(e)
def test_kwargs_from_env_tls_verify_false(self): def test_kwargs_from_env_tls_verify_false(self) -> None:
os.environ.update( os.environ.update(
DOCKER_HOST="tcp://192.168.59.103:2376", DOCKER_HOST="tcp://192.168.59.103:2376",
DOCKER_CERT_PATH=TEST_CERT_DIR, DOCKER_CERT_PATH=TEST_CERT_DIR,
@ -113,7 +115,7 @@ class KwargsFromEnvTest(unittest.TestCase):
except TypeError as e: except TypeError as e:
self.fail(e) self.fail(e)
def test_kwargs_from_env_tls_verify_false_no_cert(self): def test_kwargs_from_env_tls_verify_false_no_cert(self) -> None:
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
cert_dir = os.path.join(temp_dir, ".docker") cert_dir = os.path.join(temp_dir, ".docker")
shutil.copytree(TEST_CERT_DIR, cert_dir) shutil.copytree(TEST_CERT_DIR, cert_dir)
@ -125,7 +127,7 @@ class KwargsFromEnvTest(unittest.TestCase):
kwargs = kwargs_from_env(assert_hostname=True) kwargs = kwargs_from_env(assert_hostname=True)
assert "tcp://192.168.59.103:2376" == kwargs["base_url"] assert "tcp://192.168.59.103:2376" == kwargs["base_url"]
def test_kwargs_from_env_no_cert_path(self): def test_kwargs_from_env_no_cert_path(self) -> None:
try: try:
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
cert_dir = os.path.join(temp_dir, ".docker") cert_dir = os.path.join(temp_dir, ".docker")
@ -142,7 +144,7 @@ class KwargsFromEnvTest(unittest.TestCase):
if temp_dir: if temp_dir:
shutil.rmtree(temp_dir) shutil.rmtree(temp_dir)
def test_kwargs_from_env_alternate_env(self): def test_kwargs_from_env_alternate_env(self) -> None:
# Values in os.environ are entirely ignored if an alternate is # Values in os.environ are entirely ignored if an alternate is
# provided # provided
os.environ.update( os.environ.update(
@ -160,30 +162,32 @@ class KwargsFromEnvTest(unittest.TestCase):
class ConverVolumeBindsTest(unittest.TestCase): class ConverVolumeBindsTest(unittest.TestCase):
def test_convert_volume_binds_empty(self): def test_convert_volume_binds_empty(self) -> None:
assert convert_volume_binds({}) == [] assert convert_volume_binds({}) == []
assert convert_volume_binds([]) == [] assert convert_volume_binds([]) == []
def test_convert_volume_binds_list(self): def test_convert_volume_binds_list(self) -> None:
data = ["/a:/a:ro", "/b:/c:z"] data = ["/a:/a:ro", "/b:/c:z"]
assert convert_volume_binds(data) == data assert convert_volume_binds(data) == data
def test_convert_volume_binds_complete(self): def test_convert_volume_binds_complete(self) -> None:
data = {"/mnt/vol1": {"bind": "/data", "mode": "ro"}} data: dict[str | bytes, dict[str, str]] = {
"/mnt/vol1": {"bind": "/data", "mode": "ro"}
}
assert convert_volume_binds(data) == ["/mnt/vol1:/data:ro"] assert convert_volume_binds(data) == ["/mnt/vol1:/data:ro"]
def test_convert_volume_binds_compact(self): def test_convert_volume_binds_compact(self) -> None:
data = {"/mnt/vol1": "/data"} data: dict[str | bytes, str] = {"/mnt/vol1": "/data"}
assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"] assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"]
def test_convert_volume_binds_no_mode(self): def test_convert_volume_binds_no_mode(self) -> None:
data = {"/mnt/vol1": {"bind": "/data"}} data: dict[str | bytes, dict[str, str]] = {"/mnt/vol1": {"bind": "/data"}}
assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"] assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"]
def test_convert_volume_binds_unicode_bytes_input(self): def test_convert_volume_binds_unicode_bytes_input(self) -> None:
expected = ["/mnt/지연:/unicode/박:rw"] expected = ["/mnt/지연:/unicode/박:rw"]
data = { data: dict[str | bytes, dict[str, str | bytes]] = {
"/mnt/지연".encode("utf-8"): { "/mnt/지연".encode("utf-8"): {
"bind": "/unicode/박".encode("utf-8"), "bind": "/unicode/박".encode("utf-8"),
"mode": "rw", "mode": "rw",
@ -191,15 +195,17 @@ class ConverVolumeBindsTest(unittest.TestCase):
} }
assert convert_volume_binds(data) == expected assert convert_volume_binds(data) == expected
def test_convert_volume_binds_unicode_unicode_input(self): def test_convert_volume_binds_unicode_unicode_input(self) -> None:
expected = ["/mnt/지연:/unicode/박:rw"] expected = ["/mnt/지연:/unicode/박:rw"]
data = {"/mnt/지연": {"bind": "/unicode/박", "mode": "rw"}} data: dict[str | bytes, dict[str, str]] = {
"/mnt/지연": {"bind": "/unicode/박", "mode": "rw"}
}
assert convert_volume_binds(data) == expected assert convert_volume_binds(data) == expected
class ParseEnvFileTest(unittest.TestCase): class ParseEnvFileTest(unittest.TestCase):
def generate_tempfile(self, file_content=None): def generate_tempfile(self, file_content: str) -> str:
""" """
Generates a temporary file for tests with the content Generates a temporary file for tests with the content
of 'file_content' and returns the filename. of 'file_content' and returns the filename.
@ -209,31 +215,31 @@ class ParseEnvFileTest(unittest.TestCase):
local_tempfile.write(file_content.encode("UTF-8")) local_tempfile.write(file_content.encode("UTF-8"))
return local_tempfile.name return local_tempfile.name
def test_parse_env_file_proper(self): def test_parse_env_file_proper(self) -> None:
env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=secret") env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=secret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"} assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_with_equals_character(self): def test_parse_env_file_with_equals_character(self) -> None:
env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=sec==ret") env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=sec==ret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe", "PASS": "sec==ret"} assert get_parse_env_file == {"USER": "jdoe", "PASS": "sec==ret"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_commented_line(self): def test_parse_env_file_commented_line(self) -> None:
env_file = self.generate_tempfile(file_content="USER=jdoe\n#PASS=secret") env_file = self.generate_tempfile(file_content="USER=jdoe\n#PASS=secret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe"} assert get_parse_env_file == {"USER": "jdoe"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_newline(self): def test_parse_env_file_newline(self) -> None:
env_file = self.generate_tempfile(file_content="\nUSER=jdoe\n\n\nPASS=secret") env_file = self.generate_tempfile(file_content="\nUSER=jdoe\n\n\nPASS=secret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"} assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_invalid_line(self): def test_parse_env_file_invalid_line(self) -> None:
env_file = self.generate_tempfile(file_content="USER jdoe") env_file = self.generate_tempfile(file_content="USER jdoe")
with pytest.raises(DockerException): with pytest.raises(DockerException):
parse_env_file(env_file) parse_env_file(env_file)
@ -241,7 +247,7 @@ class ParseEnvFileTest(unittest.TestCase):
class ParseHostTest(unittest.TestCase): class ParseHostTest(unittest.TestCase):
def test_parse_host(self): def test_parse_host(self) -> None:
invalid_hosts = [ invalid_hosts = [
"foo://0.0.0.0", "foo://0.0.0.0",
"tcp://", "tcp://",
@ -282,16 +288,16 @@ class ParseHostTest(unittest.TestCase):
for host in invalid_hosts: for host in invalid_hosts:
msg = f"Should have failed to parse invalid host: {host}" msg = f"Should have failed to parse invalid host: {host}"
with self.assertRaises(DockerException, msg=msg): with self.assertRaises(DockerException, msg=msg):
parse_host(host, None) parse_host(host)
for host, expected in valid_hosts.items(): for host, expected in valid_hosts.items():
self.assertEqual( self.assertEqual(
parse_host(host, None), parse_host(host),
expected, expected,
msg=f"Failed to parse valid host: {host}", msg=f"Failed to parse valid host: {host}",
) )
def test_parse_host_empty_value(self): def test_parse_host_empty_value(self) -> None:
unix_socket = "http+unix:///var/run/docker.sock" unix_socket = "http+unix:///var/run/docker.sock"
npipe = "npipe:////./pipe/docker_engine" npipe = "npipe:////./pipe/docker_engine"
@ -299,17 +305,17 @@ class ParseHostTest(unittest.TestCase):
assert parse_host(val, is_win32=False) == unix_socket assert parse_host(val, is_win32=False) == unix_socket
assert parse_host(val, is_win32=True) == npipe assert parse_host(val, is_win32=True) == npipe
def test_parse_host_tls(self): def test_parse_host_tls(self) -> None:
host_value = "myhost.docker.net:3348" host_value = "myhost.docker.net:3348"
expected_result = "https://myhost.docker.net:3348" expected_result = "https://myhost.docker.net:3348"
assert parse_host(host_value, tls=True) == expected_result assert parse_host(host_value, tls=True) == expected_result
def test_parse_host_tls_tcp_proto(self): def test_parse_host_tls_tcp_proto(self) -> None:
host_value = "tcp://myhost.docker.net:3348" host_value = "tcp://myhost.docker.net:3348"
expected_result = "https://myhost.docker.net:3348" expected_result = "https://myhost.docker.net:3348"
assert parse_host(host_value, tls=True) == expected_result assert parse_host(host_value, tls=True) == expected_result
def test_parse_host_trailing_slash(self): def test_parse_host_trailing_slash(self) -> None:
host_value = "tcp://myhost.docker.net:2376/" host_value = "tcp://myhost.docker.net:2376/"
expected_result = "http://myhost.docker.net:2376" expected_result = "http://myhost.docker.net:2376"
assert parse_host(host_value) == expected_result assert parse_host(host_value) == expected_result
@ -318,31 +324,31 @@ class ParseHostTest(unittest.TestCase):
class ParseRepositoryTagTest(unittest.TestCase): class ParseRepositoryTagTest(unittest.TestCase):
sha = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" sha = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
def test_index_image_no_tag(self): def test_index_image_no_tag(self) -> None:
assert parse_repository_tag("root") == ("root", None) assert parse_repository_tag("root") == ("root", None)
def test_index_image_tag(self): def test_index_image_tag(self) -> None:
assert parse_repository_tag("root:tag") == ("root", "tag") assert parse_repository_tag("root:tag") == ("root", "tag")
def test_index_user_image_no_tag(self): def test_index_user_image_no_tag(self) -> None:
assert parse_repository_tag("user/repo") == ("user/repo", None) assert parse_repository_tag("user/repo") == ("user/repo", None)
def test_index_user_image_tag(self): def test_index_user_image_tag(self) -> None:
assert parse_repository_tag("user/repo:tag") == ("user/repo", "tag") assert parse_repository_tag("user/repo:tag") == ("user/repo", "tag")
def test_private_reg_image_no_tag(self): def test_private_reg_image_no_tag(self) -> None:
assert parse_repository_tag("url:5000/repo") == ("url:5000/repo", None) assert parse_repository_tag("url:5000/repo") == ("url:5000/repo", None)
def test_private_reg_image_tag(self): def test_private_reg_image_tag(self) -> None:
assert parse_repository_tag("url:5000/repo:tag") == ("url:5000/repo", "tag") assert parse_repository_tag("url:5000/repo:tag") == ("url:5000/repo", "tag")
def test_index_image_sha(self): def test_index_image_sha(self) -> None:
assert parse_repository_tag(f"root@sha256:{self.sha}") == ( assert parse_repository_tag(f"root@sha256:{self.sha}") == (
"root", "root",
f"sha256:{self.sha}", f"sha256:{self.sha}",
) )
def test_private_reg_image_sha(self): def test_private_reg_image_sha(self) -> None:
assert parse_repository_tag(f"url:5000/repo@sha256:{self.sha}") == ( assert parse_repository_tag(f"url:5000/repo@sha256:{self.sha}") == (
"url:5000/repo", "url:5000/repo",
f"sha256:{self.sha}", f"sha256:{self.sha}",
@ -350,7 +356,7 @@ class ParseRepositoryTagTest(unittest.TestCase):
class ParseDeviceTest(unittest.TestCase): class ParseDeviceTest(unittest.TestCase):
def test_dict(self): def test_dict(self) -> None:
devices = parse_devices( devices = parse_devices(
[ [
{ {
@ -366,7 +372,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "r", "CgroupPermissions": "r",
} }
def test_partial_string_definition(self): def test_partial_string_definition(self) -> None:
devices = parse_devices(["/dev/sda1"]) devices = parse_devices(["/dev/sda1"])
assert devices[0] == { assert devices[0] == {
"PathOnHost": "/dev/sda1", "PathOnHost": "/dev/sda1",
@ -374,7 +380,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "rwm", "CgroupPermissions": "rwm",
} }
def test_permissionless_string_definition(self): def test_permissionless_string_definition(self) -> None:
devices = parse_devices(["/dev/sda1:/dev/mnt1"]) devices = parse_devices(["/dev/sda1:/dev/mnt1"])
assert devices[0] == { assert devices[0] == {
"PathOnHost": "/dev/sda1", "PathOnHost": "/dev/sda1",
@ -382,7 +388,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "rwm", "CgroupPermissions": "rwm",
} }
def test_full_string_definition(self): def test_full_string_definition(self) -> None:
devices = parse_devices(["/dev/sda1:/dev/mnt1:r"]) devices = parse_devices(["/dev/sda1:/dev/mnt1:r"])
assert devices[0] == { assert devices[0] == {
"PathOnHost": "/dev/sda1", "PathOnHost": "/dev/sda1",
@ -390,7 +396,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "r", "CgroupPermissions": "r",
} }
def test_hybrid_list(self): def test_hybrid_list(self) -> None:
devices = parse_devices( devices = parse_devices(
[ [
"/dev/sda1:/dev/mnt1:rw", "/dev/sda1:/dev/mnt1:rw",
@ -415,12 +421,12 @@ class ParseDeviceTest(unittest.TestCase):
class ParseBytesTest(unittest.TestCase): class ParseBytesTest(unittest.TestCase):
def test_parse_bytes_valid(self): def test_parse_bytes_valid(self) -> None:
assert parse_bytes("512MB") == 536870912 assert parse_bytes("512MB") == 536870912
assert parse_bytes("512M") == 536870912 assert parse_bytes("512M") == 536870912
assert parse_bytes("512m") == 536870912 assert parse_bytes("512m") == 536870912
def test_parse_bytes_invalid(self): def test_parse_bytes_invalid(self) -> None:
with pytest.raises(DockerException): with pytest.raises(DockerException):
parse_bytes("512MK") parse_bytes("512MK")
with pytest.raises(DockerException): with pytest.raises(DockerException):
@ -428,15 +434,15 @@ class ParseBytesTest(unittest.TestCase):
with pytest.raises(DockerException): with pytest.raises(DockerException):
parse_bytes("127.0.0.1K") parse_bytes("127.0.0.1K")
def test_parse_bytes_float(self): def test_parse_bytes_float(self) -> None:
assert parse_bytes("1.5k") == 1536 assert parse_bytes("1.5k") == 1536
class UtilsTest(unittest.TestCase): class UtilsTest(unittest.TestCase):
longMessage = True longMessage = True
def test_convert_filters(self): def test_convert_filters(self) -> None:
tests = [ tests: list[tuple[dict[str, bool | str | int | list[str | int]], str]] = [
({"dangling": True}, '{"dangling": ["true"]}'), ({"dangling": True}, '{"dangling": ["true"]}'),
({"dangling": "true"}, '{"dangling": ["true"]}'), ({"dangling": "true"}, '{"dangling": ["true"]}'),
({"exited": 0}, '{"exited": ["0"]}'), ({"exited": 0}, '{"exited": ["0"]}'),
@ -446,7 +452,7 @@ class UtilsTest(unittest.TestCase):
for filters, expected in tests: for filters, expected in tests:
assert convert_filters(filters) == expected assert convert_filters(filters) == expected
def test_decode_json_header(self): def test_decode_json_header(self) -> None:
obj = {"a": "b", "c": 1} obj = {"a": "b", "c": 1}
data = base64.urlsafe_b64encode(bytes(json.dumps(obj), "utf-8")) data = base64.urlsafe_b64encode(bytes(json.dumps(obj), "utf-8"))
decoded_data = decode_json_header(data) decoded_data = decode_json_header(data)
@ -454,16 +460,16 @@ class UtilsTest(unittest.TestCase):
class SplitCommandTest(unittest.TestCase): class SplitCommandTest(unittest.TestCase):
def test_split_command_with_unicode(self): def test_split_command_with_unicode(self) -> None:
assert split_command("echo μμ") == ["echo", "μμ"] assert split_command("echo μμ") == ["echo", "μμ"]
class FormatEnvironmentTest(unittest.TestCase): class FormatEnvironmentTest(unittest.TestCase):
def test_format_env_binary_unicode_value(self): def test_format_env_binary_unicode_value(self) -> None:
env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"} env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"}
assert format_environment(env_dict) == ["ARTIST_NAME=송지은"] assert format_environment(env_dict) == ["ARTIST_NAME=송지은"]
def test_format_env_no_value(self): def test_format_env_no_value(self) -> None:
env_dict = { env_dict = {
"FOO": None, "FOO": None,
"BAR": "", "BAR": "",

View File

@ -11,7 +11,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
) )
EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[Event]]] = [ EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[str]]] = [
# ####################################################################################################################### # #######################################################################################################################
# ## Docker Compose 2.18.1 ############################################################################################## # ## Docker Compose 2.18.1 ##############################################################################################
# ####################################################################################################################### # #######################################################################################################################

View File

@ -14,7 +14,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
from .compose_v2_test_cases import EVENT_TEST_CASES from .compose_v2_test_cases import EVENT_TEST_CASES
EXTRA_TEST_CASES = [ EXTRA_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[str]]] = [
( (
"2.24.2-manual-build-dry-run", "2.24.2-manual-build-dry-run",
"2.24.2", "2.24.2",
@ -227,9 +227,7 @@ EXTRA_TEST_CASES = [
False, False,
False, False,
# fmt: off # fmt: off
" bash_1 Skipped \n" " bash_1 Skipped \n bash_2 Pulling \n bash_2 Pulled \n",
" bash_2 Pulling \n"
" bash_2 Pulled \n",
# fmt: on # fmt: on
[ [
Event( Event(
@ -361,15 +359,24 @@ _ALL_TEST_CASES = EVENT_TEST_CASES + EXTRA_TEST_CASES
ids=[tc[0] for tc in _ALL_TEST_CASES], ids=[tc[0] for tc in _ALL_TEST_CASES],
) )
def test_parse_events( def test_parse_events(
test_id, compose_version, dry_run, nonzero_rc, stderr, events, warnings test_id: str,
): compose_version: str,
dry_run: bool,
nonzero_rc: bool,
stderr: str,
events: list[Event],
warnings: list[str],
) -> None:
collected_warnings = [] collected_warnings = []
def collect_warning(msg): def collect_warning(msg: str) -> None:
collected_warnings.append(msg) collected_warnings.append(msg)
collected_events = parse_events( collected_events = parse_events(
stderr, dry_run=dry_run, warn_function=collect_warning, nonzero_rc=nonzero_rc stderr.encode("utf-8"),
dry_run=dry_run,
warn_function=collect_warning,
nonzero_rc=nonzero_rc,
) )
print(collected_events) print(collected_events)

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._copy import ( from ansible_collections.community.docker.plugins.module_utils._copy import (
@ -11,7 +13,13 @@ from ansible_collections.community.docker.plugins.module_utils._copy import (
) )
def _simple_generator(sequence): if t.TYPE_CHECKING:
from collections.abc import Sequence
T = t.TypeVar("T")
def _simple_generator(sequence: Sequence[T]) -> t.Generator[T]:
yield from sequence yield from sequence
@ -60,10 +68,12 @@ def _simple_generator(sequence):
), ),
], ],
) )
def test__stream_generator_to_fileobj(chunks, read_sizes): def test__stream_generator_to_fileobj(
chunks = [count * data for count, data in chunks] chunks: list[tuple[int, bytes]], read_sizes: list[int]
stream = _simple_generator(chunks) ) -> None:
expected = b"".join(chunks) data_chunks = [count * data for count, data in chunks]
stream = _simple_generator(data_chunks)
expected = b"".join(data_chunks)
buffer = b"" buffer = b""
totally_read = 0 totally_read = 0

View File

@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import tarfile import tarfile
import typing as t
import pytest import pytest
@ -22,7 +23,7 @@ from ..test_support.docker_image_archive_stubbing import (
@pytest.fixture @pytest.fixture
def tar_file_name(tmpdir): def tar_file_name(tmpdir: t.Any) -> str:
""" """
Return the name of a non-existing tar file in an existing temporary directory. Return the name of a non-existing tar file in an existing temporary directory.
""" """
@ -34,11 +35,11 @@ def tar_file_name(tmpdir):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expected, value", [("sha256:foo", "foo"), ("sha256:bar", "bar")] "expected, value", [("sha256:foo", "foo"), ("sha256:bar", "bar")]
) )
def test_api_image_id_from_archive_id(expected, value): def test_api_image_id_from_archive_id(expected: str, value: str) -> None:
assert api_image_id(value) == expected assert api_image_id(value) == expected
def test_archived_image_manifest_extracts(tar_file_name): def test_archived_image_manifest_extracts(tar_file_name: str) -> None:
expected_id = "abcde12345" expected_id = "abcde12345"
expected_tags = ["foo:latest", "bar:v1"] expected_tags = ["foo:latest", "bar:v1"]
@ -46,17 +47,20 @@ def test_archived_image_manifest_extracts(tar_file_name):
actual = archived_image_manifest(tar_file_name) actual = archived_image_manifest(tar_file_name)
assert actual is not None
assert actual.image_id == expected_id assert actual.image_id == expected_id
assert actual.repo_tags == expected_tags assert actual.repo_tags == expected_tags
def test_archived_image_manifest_extracts_nothing_when_file_not_present(tar_file_name): def test_archived_image_manifest_extracts_nothing_when_file_not_present(
tar_file_name: str,
) -> None:
image_id = archived_image_manifest(tar_file_name) image_id = archived_image_manifest(tar_file_name)
assert image_id is None assert image_id is None
def test_archived_image_manifest_raises_when_file_not_a_tar(): def test_archived_image_manifest_raises_when_file_not_a_tar() -> None:
try: try:
archived_image_manifest(__file__) archived_image_manifest(__file__)
raise AssertionError() raise AssertionError()
@ -65,7 +69,9 @@ def test_archived_image_manifest_raises_when_file_not_a_tar():
assert str(__file__) in str(e) assert str(__file__) in str(e)
def test_archived_image_manifest_raises_when_tar_missing_manifest(tar_file_name): def test_archived_image_manifest_raises_when_tar_missing_manifest(
tar_file_name: str,
) -> None:
write_irrelevant_tar(tar_file_name) write_irrelevant_tar(tar_file_name)
try: try:
@ -76,7 +82,9 @@ def test_archived_image_manifest_raises_when_tar_missing_manifest(tar_file_name)
assert "manifest.json" in str(e.__cause__) assert "manifest.json" in str(e.__cause__)
def test_archived_image_manifest_raises_when_manifest_missing_id(tar_file_name): def test_archived_image_manifest_raises_when_manifest_missing_id(
tar_file_name: str,
) -> None:
manifest = [{"foo": "bar"}] manifest = [{"foo": "bar"}]
write_imitation_archive_with_manifest(tar_file_name, manifest) write_imitation_archive_with_manifest(tar_file_name, manifest)

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._logfmt import ( from ansible_collections.community.docker.plugins.module_utils._logfmt import (
@ -12,7 +14,7 @@ from ansible_collections.community.docker.plugins.module_utils._logfmt import (
) )
SUCCESS_TEST_CASES = [ SUCCESS_TEST_CASES: list[tuple[str, dict[str, t.Any], dict[str, t.Any]]] = [
( (
'time="2024-02-02T08:14:10+01:00" level=warning msg="a network with name influxNetwork exists but was not' 'time="2024-02-02T08:14:10+01:00" level=warning msg="a network with name influxNetwork exists but was not'
' created for project \\"influxdb\\".\\nSet `external: true` to use an existing network"', ' created for project \\"influxdb\\".\\nSet `external: true` to use an existing network"',
@ -59,7 +61,7 @@ SUCCESS_TEST_CASES = [
] ]
FAILURE_TEST_CASES = [ FAILURE_TEST_CASES: list[tuple[str, dict[str, t.Any], str]] = [
( (
'foo=bar a=14 baz="hello kitty" cool%story=bro f %^asdf', 'foo=bar a=14 baz="hello kitty" cool%story=bro f %^asdf',
{"logrus_mode": True}, {"logrus_mode": True},
@ -84,14 +86,16 @@ FAILURE_TEST_CASES = [
@pytest.mark.parametrize("line, kwargs, result", SUCCESS_TEST_CASES) @pytest.mark.parametrize("line, kwargs, result", SUCCESS_TEST_CASES)
def test_parse_line_success(line, kwargs, result): def test_parse_line_success(
line: str, kwargs: dict[str, t.Any], result: dict[str, t.Any]
) -> None:
res = parse_line(line, **kwargs) res = parse_line(line, **kwargs)
print(repr(res)) print(repr(res))
assert res == result assert res == result
@pytest.mark.parametrize("line, kwargs, message", FAILURE_TEST_CASES) @pytest.mark.parametrize("line, kwargs, message", FAILURE_TEST_CASES)
def test_parse_line_failure(line, kwargs, message): def test_parse_line_failure(line: str, kwargs: dict[str, t.Any], message: str) -> None:
with pytest.raises(InvalidLogFmt) as exc: with pytest.raises(InvalidLogFmt) as exc:
parse_line(line, **kwargs) parse_line(line, **kwargs)

View File

@ -20,7 +20,7 @@ from ansible_collections.community.docker.plugins.module_utils._scramble import
("hello", b"\x01", "=S=aWRtbW4="), ("hello", b"\x01", "=S=aWRtbW4="),
], ],
) )
def test_scramble_unscramble(plaintext, key, scrambled): def test_scramble_unscramble(plaintext: str, key: bytes, scrambled: str) -> None:
scrambled_ = scramble(plaintext, key) scrambled_ = scramble(plaintext, key)
print(f"{scrambled_!r} == {scrambled!r}") print(f"{scrambled_!r} == {scrambled!r}")
assert scrambled_ == scrambled assert scrambled_ == scrambled

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._util import ( from ansible_collections.community.docker.plugins.module_utils._util import (
@ -14,15 +16,41 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
DICT_ALLOW_MORE_PRESENT = ( if t.TYPE_CHECKING:
class DAMSpec(t.TypedDict):
av: dict[str, t.Any]
bv: dict[str, t.Any]
result: bool
class Spec(t.TypedDict):
a: t.Any
b: t.Any
method: t.Literal["strict", "ignore", "allow_more_present"]
type: t.Literal["value", "list", "set", "set(dict)", "dict"]
result: bool
DICT_ALLOW_MORE_PRESENT: list[DAMSpec] = [
{"av": {}, "bv": {"a": 1}, "result": True}, {"av": {}, "bv": {"a": 1}, "result": True},
{"av": {"a": 1}, "bv": {"a": 1, "b": 2}, "result": True}, {"av": {"a": 1}, "bv": {"a": 1, "b": 2}, "result": True},
{"av": {"a": 1}, "bv": {"b": 2}, "result": False}, {"av": {"a": 1}, "bv": {"b": 2}, "result": False},
{"av": {"a": 1}, "bv": {"a": None, "b": 1}, "result": False}, {"av": {"a": 1}, "bv": {"a": None, "b": 1}, "result": False},
{"av": {"a": None}, "bv": {"b": 1}, "result": False}, {"av": {"a": None}, "bv": {"b": 1}, "result": False},
) ]
COMPARE_GENERIC = [ DICT_ALLOW_MORE_PRESENT_SPECS: list[Spec] = [
{
"a": entry["av"],
"b": entry["bv"],
"method": "allow_more_present",
"type": "dict",
"result": entry["result"],
}
for entry in DICT_ALLOW_MORE_PRESENT
]
COMPARE_GENERIC: list[Spec] = [
######################################################################################## ########################################################################################
# value # value
{"a": 1, "b": 2, "method": "strict", "type": "value", "result": False}, {"a": 1, "b": 2, "method": "strict", "type": "value", "result": False},
@ -386,43 +414,34 @@ COMPARE_GENERIC = [
"type": "dict", "type": "dict",
"result": True, "result": True,
}, },
] + [
{
"a": entry["av"],
"b": entry["bv"],
"method": "allow_more_present",
"type": "dict",
"result": entry["result"],
}
for entry in DICT_ALLOW_MORE_PRESENT
] ]
@pytest.mark.parametrize("entry", DICT_ALLOW_MORE_PRESENT) @pytest.mark.parametrize("entry", DICT_ALLOW_MORE_PRESENT)
def test_dict_allow_more_present(entry): def test_dict_allow_more_present(entry: DAMSpec) -> None:
assert compare_dict_allow_more_present(entry["av"], entry["bv"]) == entry["result"] assert compare_dict_allow_more_present(entry["av"], entry["bv"]) == entry["result"]
@pytest.mark.parametrize("entry", COMPARE_GENERIC) @pytest.mark.parametrize("entry", COMPARE_GENERIC + DICT_ALLOW_MORE_PRESENT_SPECS)
def test_compare_generic(entry): def test_compare_generic(entry: Spec) -> None:
assert ( assert (
compare_generic(entry["a"], entry["b"], entry["method"], entry["type"]) compare_generic(entry["a"], entry["b"], entry["method"], entry["type"])
== entry["result"] == entry["result"]
) )
def test_convert_duration_to_nanosecond(): def test_convert_duration_to_nanosecond() -> None:
nanoseconds = convert_duration_to_nanosecond("5s") nanoseconds = convert_duration_to_nanosecond("5s")
assert nanoseconds == 5000000000 assert nanoseconds == 5000000000
nanoseconds = convert_duration_to_nanosecond("1m5s") nanoseconds = convert_duration_to_nanosecond("1m5s")
assert nanoseconds == 65000000000 assert nanoseconds == 65000000000
with pytest.raises(ValueError): with pytest.raises(ValueError):
convert_duration_to_nanosecond([1, 2, 3]) convert_duration_to_nanosecond([1, 2, 3]) # type: ignore
with pytest.raises(ValueError): with pytest.raises(ValueError):
convert_duration_to_nanosecond("10x") convert_duration_to_nanosecond("10x")
def test_parse_healthcheck(): def test_parse_healthcheck() -> None:
result, disabled = parse_healthcheck( result, disabled = parse_healthcheck(
{ {
"test": "sleep 1", "test": "sleep 1",

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.modules.docker_container_copy_into import ( from ansible_collections.community.docker.plugins.modules.docker_container_copy_into import (
@ -30,7 +32,7 @@ from ansible_collections.community.docker.plugins.modules.docker_container_copy_
("-1", -1), ("-1", -1),
], ],
) )
def test_parse_string(value, expected): def test_parse_string(value: str, expected: int) -> None:
assert parse_modern(value) == expected assert parse_modern(value) == expected
assert parse_octal_string_only(value) == expected assert parse_octal_string_only(value) == expected
@ -45,10 +47,10 @@ def test_parse_string(value, expected):
123456789012345678901234567890123456789012345678901234567890, 123456789012345678901234567890123456789012345678901234567890,
], ],
) )
def test_parse_int(value): def test_parse_int(value: int) -> None:
assert parse_modern(value) == value assert parse_modern(value) == value
with pytest.raises(TypeError, match=f"^must be an octal string, got {value}L?$"): with pytest.raises(TypeError, match=f"^must be an octal string, got {value}L?$"):
parse_octal_string_only(value) parse_octal_string_only(value) # type: ignore
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -60,7 +62,7 @@ def test_parse_int(value):
{}, {},
], ],
) )
def test_parse_bad_type(value): def test_parse_bad_type(value: t.Any) -> None:
with pytest.raises(TypeError, match="^must be an octal string or an integer, got "): with pytest.raises(TypeError, match="^must be an octal string or an integer, got "):
parse_modern(value) parse_modern(value)
with pytest.raises(TypeError, match="^must be an octal string, got "): with pytest.raises(TypeError, match="^must be an octal string, got "):
@ -75,7 +77,7 @@ def test_parse_bad_type(value):
"9", "9",
], ],
) )
def test_parse_bad_value(value): def test_parse_bad_value(value: str) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
parse_modern(value) parse_modern(value)
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._image_archive import ( from ansible_collections.community.docker.plugins.module_utils._image_archive import (
@ -19,19 +21,24 @@ from ..test_support.docker_image_archive_stubbing import (
) )
def assert_no_logging(msg): if t.TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
def assert_no_logging(msg: str) -> t.NoReturn:
raise AssertionError(f"Should not have logged anything but logged {msg}") raise AssertionError(f"Should not have logged anything but logged {msg}")
def capture_logging(messages): def capture_logging(messages: list[str]) -> Callable[[str], None]:
def capture(msg): def capture(msg: str) -> None:
messages.append(msg) messages.append(msg)
return capture return capture
@pytest.fixture @pytest.fixture
def tar_file_name(tmpdir): def tar_file_name(tmpdir: t.Any) -> str:
""" """
Return the name of a non-existing tar file in an existing temporary directory. Return the name of a non-existing tar file in an existing temporary directory.
""" """
@ -39,7 +46,7 @@ def tar_file_name(tmpdir):
return tmpdir.join("foo.tar") return tmpdir.join("foo.tar")
def test_archived_image_action_when_missing(tar_file_name): def test_archived_image_action_when_missing(tar_file_name: str) -> None:
fake_name = "a:latest" fake_name = "a:latest"
fake_id = "a1" fake_id = "a1"
@ -52,7 +59,7 @@ def test_archived_image_action_when_missing(tar_file_name):
assert actual == expected assert actual == expected
def test_archived_image_action_when_current(tar_file_name): def test_archived_image_action_when_current(tar_file_name: str) -> None:
fake_name = "b:latest" fake_name = "b:latest"
fake_id = "b2" fake_id = "b2"
@ -65,7 +72,7 @@ def test_archived_image_action_when_current(tar_file_name):
assert actual is None assert actual is None
def test_archived_image_action_when_invalid(tar_file_name): def test_archived_image_action_when_invalid(tar_file_name: str) -> None:
fake_name = "c:1.2.3" fake_name = "c:1.2.3"
fake_id = "c3" fake_id = "c3"
@ -73,7 +80,7 @@ def test_archived_image_action_when_invalid(tar_file_name):
expected = f"Archived image {fake_name} to {tar_file_name}, overwriting an unreadable archive file" expected = f"Archived image {fake_name} to {tar_file_name}, overwriting an unreadable archive file"
actual_log = [] actual_log: list[str] = []
actual = ImageManager.archived_image_action( actual = ImageManager.archived_image_action(
capture_logging(actual_log), tar_file_name, fake_name, api_image_id(fake_id) capture_logging(actual_log), tar_file_name, fake_name, api_image_id(fake_id)
) )
@ -84,7 +91,7 @@ def test_archived_image_action_when_invalid(tar_file_name):
assert actual_log[0].startswith("Unable to extract manifest summary from archive") assert actual_log[0].startswith("Unable to extract manifest summary from archive")
def test_archived_image_action_when_obsolete_by_id(tar_file_name): def test_archived_image_action_when_obsolete_by_id(tar_file_name: str) -> None:
fake_name = "d:0.0.1" fake_name = "d:0.0.1"
old_id = "e5" old_id = "e5"
new_id = "d4" new_id = "d4"
@ -99,7 +106,7 @@ def test_archived_image_action_when_obsolete_by_id(tar_file_name):
assert actual == expected assert actual == expected
def test_archived_image_action_when_obsolete_by_name(tar_file_name): def test_archived_image_action_when_obsolete_by_name(tar_file_name: str) -> None:
old_name = "hi" old_name = "hi"
new_name = "d:0.0.1" new_name = "d:0.0.1"
fake_id = "d4" fake_id = "d4"

View File

@ -21,5 +21,5 @@ from ansible_collections.community.docker.plugins.modules.docker_image_build imp
('\rhello, "hi" !\n', '"\rhello, ""hi"" !\n"'), ('\rhello, "hi" !\n', '"\rhello, ""hi"" !\n"'),
], ],
) )
def test__quote_csv(value, expected): def test__quote_csv(value: str, expected: str) -> None:
assert _quote_csv(value) == expected assert _quote_csv(value) == expected

View File

@ -6,6 +6,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.modules.docker_network import ( from ansible_collections.community.docker.plugins.modules.docker_network import (
@ -23,7 +25,9 @@ from ansible_collections.community.docker.plugins.modules.docker_network import
("fdd1:ac8c:0557:7ce2::/128", "ipv6"), ("fdd1:ac8c:0557:7ce2::/128", "ipv6"),
], ],
) )
def test_validate_cidr_positives(cidr, expected): def test_validate_cidr_positives(
cidr: str, expected: t.Literal["ipv4", "ipv6"]
) -> None:
assert validate_cidr(cidr) == expected assert validate_cidr(cidr) == expected
@ -36,7 +40,7 @@ def test_validate_cidr_positives(cidr, expected):
"fdd1:ac8c:0557:7ce2::", "fdd1:ac8c:0557:7ce2::",
], ],
) )
def test_validate_cidr_negatives(cidr): def test_validate_cidr_negatives(cidr: str) -> None:
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
validate_cidr(cidr) validate_cidr(cidr)
assert f'"{cidr}" is not a valid CIDR' == str(e.value) assert f'"{cidr}" is not a valid CIDR' == str(e.value)

View File

@ -4,66 +4,47 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.modules import (
class APIErrorMock(Exception): docker_swarm_service,
def __init__(self, message, response=None, explanation=None): )
self.message = message
self.response = response
self.explanation = explanation
@pytest.fixture(autouse=True) APIError = pytest.importorskip("docker.errors.APIError")
def docker_module_mock(mocker):
docker_module_mock = mocker.MagicMock()
docker_utils_module_mock = mocker.MagicMock()
docker_errors_module_mock = mocker.MagicMock()
docker_errors_module_mock.APIError = APIErrorMock
mock_modules = {
"docker": docker_module_mock,
"docker.utils": docker_utils_module_mock,
"docker.errors": docker_errors_module_mock,
}
return mocker.patch.dict("sys.modules", **mock_modules)
@pytest.fixture(autouse=True) def test_retry_on_out_of_sequence_error(mocker: t.Any) -> None:
def docker_swarm_service():
from ansible_collections.community.docker.plugins.modules import (
docker_swarm_service,
)
return docker_swarm_service
def test_retry_on_out_of_sequence_error(mocker, docker_swarm_service):
run_mock = mocker.MagicMock( run_mock = mocker.MagicMock(
side_effect=APIErrorMock( side_effect=APIError(
message="", message="",
response=None, response=None,
explanation="rpc error: code = Unknown desc = update out of sequence", explanation="rpc error: code = Unknown desc = update out of sequence",
) )
) )
manager = docker_swarm_service.DockerServiceManager(client=None) mocker.patch("time.sleep")
manager.run = run_mock manager = docker_swarm_service.DockerServiceManager(client=None) # type: ignore
with pytest.raises(APIErrorMock): manager.run = run_mock # type: ignore
with pytest.raises(APIError):
manager.run_safe() manager.run_safe()
assert run_mock.call_count == 3 assert run_mock.call_count == 3
def test_no_retry_on_general_api_error(mocker, docker_swarm_service): def test_no_retry_on_general_api_error(mocker: t.Any) -> None:
run_mock = mocker.MagicMock( run_mock = mocker.MagicMock(
side_effect=APIErrorMock(message="", response=None, explanation="some error") side_effect=APIError(message="", response=None, explanation="some error")
) )
manager = docker_swarm_service.DockerServiceManager(client=None) mocker.patch("time.sleep")
manager.run = run_mock manager = docker_swarm_service.DockerServiceManager(client=None) # type: ignore
with pytest.raises(APIErrorMock): manager.run = run_mock # type: ignore
with pytest.raises(APIError):
manager.run_safe() manager.run_safe()
assert run_mock.call_count == 1 assert run_mock.call_count == 1
def test_get_docker_environment(mocker, docker_swarm_service): def test_get_docker_environment(mocker: t.Any) -> None:
env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"} env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"}
env_dict = {"TEST3": "CC", "TEST4": "D"} env_dict = {"TEST3": "CC", "TEST4": "D"}
env_string = "TEST3=CC,TEST4=D" env_string = "TEST3=CC,TEST4=D"
@ -103,7 +84,7 @@ def test_get_docker_environment(mocker, docker_swarm_service):
assert result == [] assert result == []
def test_get_nanoseconds_from_raw_option(docker_swarm_service): def test_get_nanoseconds_from_raw_option() -> None:
value = docker_swarm_service.get_nanoseconds_from_raw_option("test", None) value = docker_swarm_service.get_nanoseconds_from_raw_option("test", None)
assert value is None assert value is None
@ -117,7 +98,7 @@ def test_get_nanoseconds_from_raw_option(docker_swarm_service):
docker_swarm_service.get_nanoseconds_from_raw_option("test", []) docker_swarm_service.get_nanoseconds_from_raw_option("test", [])
def test_has_dict_changed(docker_swarm_service): def test_has_dict_changed() -> None:
assert not docker_swarm_service.has_dict_changed( assert not docker_swarm_service.has_dict_changed(
{"a": 1}, {"a": 1},
{"a": 1}, {"a": 1},
@ -135,8 +116,7 @@ def test_has_dict_changed(docker_swarm_service):
assert not docker_swarm_service.has_dict_changed(None, {}) assert not docker_swarm_service.has_dict_changed(None, {})
def test_has_list_changed(docker_swarm_service): def test_has_list_changed() -> None:
# List comparisons without dictionaries # List comparisons without dictionaries
# I could improve the indenting, but pycodestyle wants this instead # I could improve the indenting, but pycodestyle wants this instead
assert not docker_swarm_service.has_list_changed(None, None) assert not docker_swarm_service.has_list_changed(None, None)
@ -161,7 +141,7 @@ def test_has_list_changed(docker_swarm_service):
assert docker_swarm_service.has_list_changed([None, 1], [2, 1]) assert docker_swarm_service.has_list_changed([None, 1], [2, 1])
assert docker_swarm_service.has_list_changed([2, 1], [None, 1]) assert docker_swarm_service.has_list_changed([2, 1], [None, 1])
assert docker_swarm_service.has_list_changed( assert docker_swarm_service.has_list_changed(
"command --with args", ["command", "--with", "args"] ["command --with args"], ["command", "--with", "args"]
) )
assert docker_swarm_service.has_list_changed( assert docker_swarm_service.has_list_changed(
["sleep", "3400"], ["sleep", "3600"], sort_lists=False ["sleep", "3400"], ["sleep", "3600"], sort_lists=False
@ -259,7 +239,7 @@ def test_has_list_changed(docker_swarm_service):
) )
def test_have_networks_changed(docker_swarm_service): def test_have_networks_changed() -> None:
assert not docker_swarm_service.have_networks_changed(None, None) assert not docker_swarm_service.have_networks_changed(None, None)
assert not docker_swarm_service.have_networks_changed([], None) assert not docker_swarm_service.have_networks_changed([], None)
@ -329,14 +309,14 @@ def test_have_networks_changed(docker_swarm_service):
) )
def test_get_docker_networks(docker_swarm_service): def test_get_docker_networks() -> None:
network_names = [ network_names = [
"network_1", "network_1",
"network_2", "network_2",
"network_3", "network_3",
"network_4", "network_4",
] ]
networks = [ networks: list[str | dict[str, t.Any]] = [
network_names[0], network_names[0],
{"name": network_names[1]}, {"name": network_names[1]},
{"name": network_names[2], "aliases": ["networkalias1"]}, {"name": network_names[2], "aliases": ["networkalias1"]},
@ -367,28 +347,27 @@ def test_get_docker_networks(docker_swarm_service):
assert "foo" in network["options"] assert "foo" in network["options"]
# Test missing name # Test missing name
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks([{"invalid": "err"}], {"err": 1}) docker_swarm_service.get_docker_networks([{"invalid": "err"}], {"err": "x"})
# test for invalid aliases type # test for invalid aliases type
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks( docker_swarm_service.get_docker_networks(
[{"name": "test", "aliases": 1}], {"test": 1} [{"name": "test", "aliases": 1}], {"test": "x"}
) )
# Test invalid aliases elements # Test invalid aliases elements
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks( docker_swarm_service.get_docker_networks(
[{"name": "test", "aliases": [1]}], {"test": 1} [{"name": "test", "aliases": [1]}], {"test": "x"}
) )
# Test for invalid options type # Test for invalid options type
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks( docker_swarm_service.get_docker_networks(
[{"name": "test", "options": 1}], {"test": 1} [{"name": "test", "options": 1}], {"test": "x"}
) )
# Test for invalid networks type
with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks(1, {"test": 1})
# Test for non existing networks # Test for non existing networks
with pytest.raises(ValueError): with pytest.raises(ValueError):
docker_swarm_service.get_docker_networks([{"name": "idontexist"}], {"test": 1}) docker_swarm_service.get_docker_networks(
[{"name": "idontexist"}], {"test": "x"}
)
# Test empty values # Test empty values
assert docker_swarm_service.get_docker_networks([], {}) == [] assert docker_swarm_service.get_docker_networks([], {}) == []
assert docker_swarm_service.get_docker_networks(None, {}) is None assert docker_swarm_service.get_docker_networks(None, {}) is None

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.internal_test_tools.tests.unit.utils.trust import ( from ansible_collections.community.internal_test_tools.tests.unit.utils.trust import (
SUPPORTS_DATA_TAGGING, SUPPORTS_DATA_TAGGING,
@ -23,7 +25,9 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
) )
TEST_MAKE_UNSAFE = [ TEST_MAKE_UNSAFE: list[
tuple[t.Any, list[tuple[t.Any, ...]], list[tuple[t.Any, ...]]]
] = [
( (
_make_trusted("text"), _make_trusted("text"),
[], [],
@ -97,7 +101,11 @@ if not SUPPORTS_DATA_TAGGING:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value, check_unsafe_paths, check_safe_paths", TEST_MAKE_UNSAFE "value, check_unsafe_paths, check_safe_paths", TEST_MAKE_UNSAFE
) )
def test_make_unsafe(value, check_unsafe_paths, check_safe_paths): def test_make_unsafe(
value: t.Any,
check_unsafe_paths: list[tuple[t.Any, ...]],
check_safe_paths: list[tuple[t.Any, ...]],
) -> None:
unsafe_value = make_unsafe(value) unsafe_value = make_unsafe(value)
assert unsafe_value == value assert unsafe_value == value
for check_path in check_unsafe_paths: for check_path in check_unsafe_paths:
@ -112,7 +120,7 @@ def test_make_unsafe(value, check_unsafe_paths, check_safe_paths):
assert _is_trusted(obj) assert _is_trusted(obj)
def test_make_unsafe_idempotence(): def test_make_unsafe_idempotence() -> None:
assert make_unsafe(None) is None assert make_unsafe(None) is None
unsafe_str = _make_untrusted("{{test}}") unsafe_str = _make_untrusted("{{test}}")
@ -122,8 +130,8 @@ def test_make_unsafe_idempotence():
assert id(make_unsafe(safe_str)) != id(safe_str) assert id(make_unsafe(safe_str)) != id(safe_str)
def test_make_unsafe_dict_key(): def test_make_unsafe_dict_key() -> None:
value = { value: dict[t.Any, t.Any] = {
_make_trusted("test"): 2, _make_trusted("test"): 2,
} }
if not SUPPORTS_DATA_TAGGING: if not SUPPORTS_DATA_TAGGING:
@ -144,8 +152,8 @@ def test_make_unsafe_dict_key():
assert not _is_trusted(obj) assert not _is_trusted(obj)
def test_make_unsafe_set(): def test_make_unsafe_set() -> None:
value = set([_make_trusted("test")]) value: set[t.Any] = set([_make_trusted("test")])
if not SUPPORTS_DATA_TAGGING: if not SUPPORTS_DATA_TAGGING:
value.add(_make_trusted(b"test")) value.add(_make_trusted(b"test"))
unsafe_value = make_unsafe(value) unsafe_value = make_unsafe(value)

View File

@ -6,10 +6,13 @@ from __future__ import annotations
import json import json
import tarfile import tarfile
import typing as t
from tempfile import TemporaryFile from tempfile import TemporaryFile
def write_imitation_archive(file_name, image_id, repo_tags): def write_imitation_archive(
file_name: str, image_id: str, repo_tags: list[str]
) -> None:
""" """
Write a tar file meeting these requirements: Write a tar file meeting these requirements:
@ -21,7 +24,7 @@ def write_imitation_archive(file_name, image_id, repo_tags):
:type file_name: str :type file_name: str
:param image_id: Fake sha256 hash (without the sha256: prefix) :param image_id: Fake sha256 hash (without the sha256: prefix)
:type image_id: str :type image_id: str
:param repo_tags: list of fake image:tag's :param repo_tags: list of fake image tags
:type repo_tags: list :type repo_tags: list
""" """
@ -30,7 +33,9 @@ def write_imitation_archive(file_name, image_id, repo_tags):
write_imitation_archive_with_manifest(file_name, manifest) write_imitation_archive_with_manifest(file_name, manifest)
def write_imitation_archive_with_manifest(file_name, manifest): def write_imitation_archive_with_manifest(
file_name: str, manifest: list[dict[str, t.Any]]
) -> None:
with tarfile.open(file_name, "w") as tf: with tarfile.open(file_name, "w") as tf:
with TemporaryFile() as f: with TemporaryFile() as f:
f.write(json.dumps(manifest).encode("utf-8")) f.write(json.dumps(manifest).encode("utf-8"))
@ -42,7 +47,7 @@ def write_imitation_archive_with_manifest(file_name, manifest):
tf.addfile(ti, f) tf.addfile(ti, f)
def write_irrelevant_tar(file_name): def write_irrelevant_tar(file_name: str) -> None:
""" """
Create a tar file that does not match the spec for "docker image save" / "docker image load" commands. Create a tar file that does not match the spec for "docker image save" / "docker image load" commands.

View File

@ -2,4 +2,5 @@
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) # GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
docker
requests requests