Add typing to unit tests.

This commit is contained in:
Felix Fontein 2025-10-24 22:12:35 +02:00
parent 13d39a36eb
commit a2deb384d4
35 changed files with 758 additions and 605 deletions

View File

@ -108,7 +108,7 @@ def port_range(
def split_port( def split_port(
port: str, port: str | int,
) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]: ) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]:
port = str(port) port = str(port)
match = PORT_SPEC.match(port) match = PORT_SPEC.match(port)

View File

@ -11,6 +11,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from .utils import format_environment from .utils import format_environment
@ -67,7 +69,17 @@ class ProxyConfig(dict):
env["no_proxy"] = env["NO_PROXY"] = self.no_proxy env["no_proxy"] = env["NO_PROXY"] = self.no_proxy
return env return env
def inject_proxy_environment(self, environment: list[str]) -> list[str]: @t.overload
def inject_proxy_environment(self, environment: list[str]) -> list[str]: ...
@t.overload
def inject_proxy_environment(
self, environment: list[str] | None
) -> list[str] | None: ...
def inject_proxy_environment(
self, environment: list[str] | None
) -> list[str] | None:
""" """
Given a list of strings representing environment variables, prepend the Given a list of strings representing environment variables, prepend the
environment variables corresponding to the proxy settings. environment variables corresponding to the proxy settings.

View File

@ -46,7 +46,7 @@ URLComponents = collections.namedtuple(
) )
def decode_json_header(header: str) -> dict[str, t.Any]: def decode_json_header(header: str | bytes) -> dict[str, t.Any]:
data = base64.b64decode(header).decode("utf-8") data = base64.b64decode(header).decode("utf-8")
return json.loads(data) return json.loads(data)
@ -143,7 +143,12 @@ def convert_port_bindings(
def convert_volume_binds( def convert_volume_binds(
binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int], binds: (
list[str]
| Mapping[
str | bytes, dict[str, str | bytes] | dict[str, str] | bytes | str | int
]
),
) -> list[str]: ) -> list[str]:
if isinstance(binds, list): if isinstance(binds, list):
return binds # type: ignore return binds # type: ignore
@ -403,7 +408,9 @@ def kwargs_from_env(
return params return params
def convert_filters(filters: Mapping[str, bool | str | list[str]]) -> str: def convert_filters(
filters: Mapping[str, bool | str | int | list[int] | list[str] | list[str | int]],
) -> str:
result = {} result = {}
for k, v in filters.items(): for k, v in filters.items():
if isinstance(v, bool): if isinstance(v, bool):

View File

@ -939,6 +939,18 @@ def get_docker_environment(
return sorted(env_list) return sorted(env_list)
@t.overload
def get_docker_networks(
networks: list[str | dict[str, t.Any]], network_ids: dict[str, str]
) -> list[dict[str, t.Any]]: ...
@t.overload
def get_docker_networks(
networks: list[str | dict[str, t.Any]] | None, network_ids: dict[str, str]
) -> list[dict[str, t.Any]] | None: ...
def get_docker_networks( def get_docker_networks(
networks: list[str | dict[str, t.Any]] | None, network_ids: dict[str, str] networks: list[str | dict[str, t.Any]] | None, network_ids: dict[str, str]
) -> list[dict[str, t.Any]] | None: ) -> list[dict[str, t.Any]] | None:

View File

@ -14,8 +14,7 @@ from ansible.plugins.loader import connection_loader
class TestDockerConnectionClass(unittest.TestCase): class TestDockerConnectionClass(unittest.TestCase):
def setUp(self) -> None:
def setUp(self):
self.play_context = PlayContext() self.play_context = PlayContext()
self.play_context.prompt = ( self.play_context.prompt = (
"[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: " "[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: "
@ -29,7 +28,7 @@ class TestDockerConnectionClass(unittest.TestCase):
"community.docker.docker", self.play_context, self.in_stream "community.docker.docker", self.play_context, self.in_stream
) )
def tearDown(self): def tearDown(self) -> None:
pass pass
@mock.patch( @mock.patch(
@ -42,7 +41,7 @@ class TestDockerConnectionClass(unittest.TestCase):
) )
def test_docker_connection_module_too_old( def test_docker_connection_module_too_old(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version, mock_old_docker_version
): ) -> None:
self.dc._version = None self.dc._version = None
self.dc.remote_user = "foo" self.dc.remote_user = "foo"
self.assertRaisesRegex( self.assertRaisesRegex(
@ -61,7 +60,7 @@ class TestDockerConnectionClass(unittest.TestCase):
) )
def test_docker_connection_module( def test_docker_connection_module(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version, mock_old_docker_version
): ) -> None:
self.dc._version = None self.dc._version = None
# old version and new version fail # old version and new version fail
@ -75,7 +74,7 @@ class TestDockerConnectionClass(unittest.TestCase):
) )
def test_docker_connection_module_wrong_cmd( def test_docker_connection_module_wrong_cmd(
self, mock_new_docker_version, mock_old_docker_version self, mock_new_docker_version, mock_old_docker_version
): ) -> None:
self.dc._version = None self.dc._version = None
self.dc.remote_user = "foo" self.dc.remote_user = "foo"
self.assertRaisesRegex( self.assertRaisesRegex(

View File

@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from unittest.mock import create_autospec from unittest.mock import create_autospec
import pytest import pytest
@ -19,14 +20,18 @@ from ansible_collections.community.docker.plugins.inventory.docker_containers im
) )
if t.TYPE_CHECKING:
from collections.abc import Callable
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def templar(): def templar() -> Templar:
dataloader = create_autospec(DataLoader, instance=True) dataloader = create_autospec(DataLoader, instance=True)
return Templar(loader=dataloader) return Templar(loader=dataloader)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def inventory(templar): def inventory(templar) -> InventoryModule:
r = InventoryModule() r = InventoryModule()
r.inventory = InventoryData() r.inventory = InventoryData()
r.templar = templar r.templar = templar
@ -83,7 +88,9 @@ LOVING_THARP_SERVICE = {
} }
def create_get_option(options, default=False): def create_get_option(
options: dict[str, t.Any], default: t.Any = False
) -> Callable[[str], t.Any]:
def get_option(option): def get_option(option):
if option in options: if option in options:
return options[option] return options[option]
@ -93,9 +100,9 @@ def create_get_option(options, default=False):
class FakeClient: class FakeClient:
def __init__(self, *hosts): def __init__(self, *hosts: dict[str, t.Any]) -> None:
self.get_results = {} self.get_results: dict[str, t.Any] = {}
list_reply = [] list_reply: list[dict[str, t.Any]] = []
for host in hosts: for host in hosts:
list_reply.append( list_reply.append(
{ {
@ -109,15 +116,16 @@ class FakeClient:
self.get_results[f"/containers/{host['Id']}/json"] = host self.get_results[f"/containers/{host['Id']}/json"] = host
self.get_results["/containers/json"] = list_reply self.get_results["/containers/json"] = list_reply
def get_json(self, url, *param, **kwargs): def get_json(self, url: str, *param: str, **kwargs) -> t.Any:
url = url.format(*param) url = url.format(*param)
return self.get_results[url] return self.get_results[url]
def test_populate(inventory, mocker): def test_populate(inventory: InventoryModule, mocker) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": True, "verbose_output": True,
@ -130,9 +138,10 @@ def test_populate(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_host"] == "loving_tharp" assert host_1_vars["ansible_host"] == "loving_tharp"
@ -149,10 +158,11 @@ def test_populate(inventory, mocker):
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_service(inventory, mocker): def test_populate_service(inventory: InventoryModule, mocker) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_SERVICE) client = FakeClient(LOVING_THARP_SERVICE)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": False, "verbose_output": False,
@ -166,9 +176,10 @@ def test_populate_service(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_host"] == "loving_tharp" assert host_1_vars["ansible_host"] == "loving_tharp"
@ -207,10 +218,11 @@ def test_populate_service(inventory, mocker):
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_stack(inventory, mocker): def test_populate_stack(inventory: InventoryModule, mocker) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP_STACK) client = FakeClient(LOVING_THARP_STACK)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": False, "verbose_output": False,
@ -226,9 +238,10 @@ def test_populate_stack(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_ssh_host"] == "127.0.0.1" assert host_1_vars["ansible_ssh_host"] == "127.0.0.1"
@ -267,10 +280,11 @@ def test_populate_stack(inventory, mocker):
assert len(inventory.inventory.hosts) == 1 assert len(inventory.inventory.hosts) == 1
def test_populate_filter_none(inventory, mocker): def test_populate_filter_none(inventory: InventoryModule, mocker) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": True, "verbose_output": True,
@ -285,15 +299,16 @@ def test_populate_filter_none(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
assert len(inventory.inventory.hosts) == 0 assert len(inventory.inventory.hosts) == 0
def test_populate_filter(inventory, mocker): def test_populate_filter(inventory: InventoryModule, mocker) -> None:
assert inventory.inventory is not None
client = FakeClient(LOVING_THARP) client = FakeClient(LOVING_THARP)
inventory.get_option = mocker.MagicMock( inventory.get_option = mocker.MagicMock( # type: ignore[method-assign]
side_effect=create_get_option( side_effect=create_get_option(
{ {
"verbose_output": True, "verbose_output": True,
@ -309,9 +324,10 @@ def test_populate_filter(inventory, mocker):
} }
) )
) )
inventory._populate(client) inventory._populate(client) # type: ignore
host_1 = inventory.inventory.get_host("loving_tharp") host_1 = inventory.inventory.get_host("loving_tharp")
assert host_1 is not None
host_1_vars = host_1.get_vars() host_1_vars = host_1.get_vars()
assert host_1_vars["ansible_host"] == "loving_tharp" assert host_1_vars["ansible_host"] == "loving_tharp"

View File

@ -19,6 +19,7 @@ import struct
import tempfile import tempfile
import threading import threading
import time import time
import typing as t
import unittest import unittest
from http.server import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer from socketserver import ThreadingTCPServer
@ -46,14 +47,14 @@ DEFAULT_TIMEOUT_SECONDS = constants.DEFAULT_TIMEOUT_SECONDS
def response( def response(
status_code=200, status_code: int = 200,
content="", content: bytes | dict[str, t.Any] | list[dict[str, t.Any]] = b"",
headers=None, headers: dict[str, str] | None = None,
reason=None, reason: str = "",
elapsed=0, elapsed: int = 0,
request=None, request=None,
raw=None, raw=None,
): ) -> requests.Response:
res = requests.Response() res = requests.Response()
res.status_code = status_code res.status_code = status_code
if not isinstance(content, bytes): if not isinstance(content, bytes):
@ -67,18 +68,20 @@ def response(
return res return res
def fake_resolve_authconfig( def fake_resolve_authconfig( # pylint: disable=keyword-arg-before-vararg
authconfig, registry=None, *args, **kwargs authconfig, *args, registry=None, **kwargs
): # pylint: disable=keyword-arg-before-vararg ) -> None:
return None return None
def fake_inspect_container(self, container, tty=False): def fake_inspect_container(self, container: str, tty: bool = False):
return fake_api.get_fake_inspect_container(tty=tty)[1] return fake_api.get_fake_inspect_container(tty=tty)[1]
def fake_resp(method, url, *args, **kwargs): def fake_resp(
key = None method: str, url: str, *args: t.Any, **kwargs: t.Any
) -> requests.Response:
key: str | tuple[str, str] | None = None
if url in fake_api.fake_responses: if url in fake_api.fake_responses:
key = url key = url
elif (url, method) in fake_api.fake_responses: elif (url, method) in fake_api.fake_responses:
@ -92,23 +95,29 @@ def fake_resp(method, url, *args, **kwargs):
fake_request = mock.Mock(side_effect=fake_resp) fake_request = mock.Mock(side_effect=fake_resp)
def fake_get(self, url, *args, **kwargs): def fake_get(self, url: str, *args, **kwargs) -> requests.Response:
return fake_request("GET", url, *args, **kwargs) return fake_request("GET", url, *args, **kwargs)
def fake_post(self, url, *args, **kwargs): def fake_post(self, url: str, *args, **kwargs) -> requests.Response:
return fake_request("POST", url, *args, **kwargs) return fake_request("POST", url, *args, **kwargs)
def fake_put(self, url, *args, **kwargs): def fake_put(self, url: str, *args, **kwargs) -> requests.Response:
return fake_request("PUT", url, *args, **kwargs) return fake_request("PUT", url, *args, **kwargs)
def fake_delete(self, url, *args, **kwargs): def fake_delete(self, url: str, *args, **kwargs) -> requests.Response:
return fake_request("DELETE", url, *args, **kwargs) return fake_request("DELETE", url, *args, **kwargs)
def fake_read_from_socket(self, response, stream, tty=False, demux=False): def fake_read_from_socket(
self,
response: requests.Response,
stream: bool,
tty: bool = False,
demux: bool = False,
) -> bytes:
return b"" return b""
@ -117,7 +126,7 @@ url_prefix = f"{url_base}v{DEFAULT_DOCKER_API_VERSION}/" # pylint: disable=inva
class BaseAPIClientTest(unittest.TestCase): class BaseAPIClientTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.patcher = mock.patch.multiple( self.patcher = mock.patch.multiple(
"ansible_collections.community.docker.plugins.module_utils._api.api.client.APIClient", "ansible_collections.community.docker.plugins.module_utils._api.api.client.APIClient",
get=fake_get, get=fake_get,
@ -129,11 +138,13 @@ class BaseAPIClientTest(unittest.TestCase):
self.patcher.start() self.patcher.start()
self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION) self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION)
def tearDown(self): def tearDown(self) -> None:
self.client.close() self.client.close()
self.patcher.stop() self.patcher.stop()
def base_create_payload(self, img="busybox", cmd=None): def base_create_payload(
self, img: str = "busybox", cmd: list[str] | None = None
) -> dict[str, t.Any]:
if not cmd: if not cmd:
cmd = ["true"] cmd = ["true"]
return { return {
@ -150,16 +161,16 @@ class BaseAPIClientTest(unittest.TestCase):
class DockerApiTest(BaseAPIClientTest): class DockerApiTest(BaseAPIClientTest):
def test_ctor(self): def test_ctor(self) -> None:
with pytest.raises(errors.DockerException) as excinfo: with pytest.raises(errors.DockerException) as excinfo:
APIClient(version=1.12) APIClient(version=1.12) # type: ignore
assert ( assert (
str(excinfo.value) str(excinfo.value)
== "Version parameter must be a string or None. Found float" == "Version parameter must be a string or None. Found float"
) )
def test_url_valid_resource(self): def test_url_valid_resource(self) -> None:
url = self.client._url("/hello/{0}/world", "somename") url = self.client._url("/hello/{0}/world", "somename")
assert url == f"{url_prefix}hello/somename/world" assert url == f"{url_prefix}hello/somename/world"
@ -172,50 +183,50 @@ class DockerApiTest(BaseAPIClientTest):
url = self.client._url("/images/{0}/push", "localhost:5000/image") url = self.client._url("/images/{0}/push", "localhost:5000/image")
assert url == f"{url_prefix}images/localhost:5000/image/push" assert url == f"{url_prefix}images/localhost:5000/image/push"
def test_url_invalid_resource(self): def test_url_invalid_resource(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.client._url("/hello/{0}/world", ["sakuya", "izayoi"]) self.client._url("/hello/{0}/world", ["sakuya", "izayoi"]) # type: ignore
def test_url_no_resource(self): def test_url_no_resource(self) -> None:
url = self.client._url("/simple") url = self.client._url("/simple")
assert url == f"{url_prefix}simple" assert url == f"{url_prefix}simple"
def test_url_unversioned_api(self): def test_url_unversioned_api(self) -> None:
url = self.client._url("/hello/{0}/world", "somename", versioned_api=False) url = self.client._url("/hello/{0}/world", "somename", versioned_api=False)
assert url == f"{url_base}hello/somename/world" assert url == f"{url_base}hello/somename/world"
def test_version(self): def test_version(self) -> None:
self.client.version() self.client.version()
fake_request.assert_called_with( fake_request.assert_called_with(
"GET", url_prefix + "version", timeout=DEFAULT_TIMEOUT_SECONDS "GET", url_prefix + "version", timeout=DEFAULT_TIMEOUT_SECONDS
) )
def test_version_no_api_version(self): def test_version_no_api_version(self) -> None:
self.client.version(False) self.client.version(False)
fake_request.assert_called_with( fake_request.assert_called_with(
"GET", url_base + "version", timeout=DEFAULT_TIMEOUT_SECONDS "GET", url_base + "version", timeout=DEFAULT_TIMEOUT_SECONDS
) )
def test_retrieve_server_version(self): def test_retrieve_server_version(self) -> None:
client = APIClient(version="auto") client = APIClient(version="auto")
assert isinstance(client._version, str) assert isinstance(client._version, str)
assert not (client._version == "auto") assert not (client._version == "auto")
client.close() client.close()
def test_auto_retrieve_server_version(self): def test_auto_retrieve_server_version(self) -> None:
version = self.client._retrieve_server_version() version = self.client._retrieve_server_version()
assert isinstance(version, str) assert isinstance(version, str)
def test_info(self): def test_info(self) -> None:
self.client.info() self.client.info()
fake_request.assert_called_with( fake_request.assert_called_with(
"GET", url_prefix + "info", timeout=DEFAULT_TIMEOUT_SECONDS "GET", url_prefix + "info", timeout=DEFAULT_TIMEOUT_SECONDS
) )
def test_search(self): def test_search(self) -> None:
self.client.get_json("/images/search", params={"term": "busybox"}) self.client.get_json("/images/search", params={"term": "busybox"})
fake_request.assert_called_with( fake_request.assert_called_with(
@ -225,7 +236,7 @@ class DockerApiTest(BaseAPIClientTest):
timeout=DEFAULT_TIMEOUT_SECONDS, timeout=DEFAULT_TIMEOUT_SECONDS,
) )
def test_login(self): def test_login(self) -> None:
self.client.login("sakuya", "izayoi") self.client.login("sakuya", "izayoi")
args = fake_request.call_args args = fake_request.call_args
assert args[0][0] == "POST" assert args[0][0] == "POST"
@ -242,42 +253,42 @@ class DockerApiTest(BaseAPIClientTest):
"serveraddress": None, "serveraddress": None,
} }
def _socket_path_for_client_session(self, client): def _socket_path_for_client_session(self, client) -> str:
socket_adapter = client.get_adapter("http+docker://") socket_adapter = client.get_adapter("http+docker://")
return socket_adapter.socket_path return socket_adapter.socket_path
def test_url_compatibility_unix(self): def test_url_compatibility_unix(self) -> None:
c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION) c = APIClient(base_url="unix://socket", version=DEFAULT_DOCKER_API_VERSION)
assert self._socket_path_for_client_session(c) == "/socket" assert self._socket_path_for_client_session(c) == "/socket"
def test_url_compatibility_unix_triple_slash(self): def test_url_compatibility_unix_triple_slash(self) -> None:
c = APIClient(base_url="unix:///socket", version=DEFAULT_DOCKER_API_VERSION) c = APIClient(base_url="unix:///socket", version=DEFAULT_DOCKER_API_VERSION)
assert self._socket_path_for_client_session(c) == "/socket" assert self._socket_path_for_client_session(c) == "/socket"
def test_url_compatibility_http_unix_triple_slash(self): def test_url_compatibility_http_unix_triple_slash(self) -> None:
c = APIClient( c = APIClient(
base_url="http+unix:///socket", version=DEFAULT_DOCKER_API_VERSION base_url="http+unix:///socket", version=DEFAULT_DOCKER_API_VERSION
) )
assert self._socket_path_for_client_session(c) == "/socket" assert self._socket_path_for_client_session(c) == "/socket"
def test_url_compatibility_http(self): def test_url_compatibility_http(self) -> None:
c = APIClient( c = APIClient(
base_url="http://hostname:1234", version=DEFAULT_DOCKER_API_VERSION base_url="http://hostname:1234", version=DEFAULT_DOCKER_API_VERSION
) )
assert c.base_url == "http://hostname:1234" assert c.base_url == "http://hostname:1234"
def test_url_compatibility_tcp(self): def test_url_compatibility_tcp(self) -> None:
c = APIClient( c = APIClient(
base_url="tcp://hostname:1234", version=DEFAULT_DOCKER_API_VERSION base_url="tcp://hostname:1234", version=DEFAULT_DOCKER_API_VERSION
) )
assert c.base_url == "http://hostname:1234" assert c.base_url == "http://hostname:1234"
def test_remove_link(self): def test_remove_link(self) -> None:
self.client.delete_call( self.client.delete_call(
"/containers/{0}", "/containers/{0}",
"3cc2351ab11b", "3cc2351ab11b",
@ -291,7 +302,7 @@ class DockerApiTest(BaseAPIClientTest):
timeout=DEFAULT_TIMEOUT_SECONDS, timeout=DEFAULT_TIMEOUT_SECONDS,
) )
def test_stream_helper_decoding(self): def test_stream_helper_decoding(self) -> None:
status_code, content = fake_api.fake_responses[url_prefix + "events"]() status_code, content = fake_api.fake_responses[url_prefix + "events"]()
content_str = json.dumps(content).encode("utf-8") content_str = json.dumps(content).encode("utf-8")
body = io.BytesIO(content_str) body = io.BytesIO(content_str)
@ -318,7 +329,7 @@ class DockerApiTest(BaseAPIClientTest):
raw_resp._fp.seek(0) raw_resp._fp.seek(0)
resp = response(status_code=status_code, content=content, raw=raw_resp) resp = response(status_code=status_code, content=content, raw=raw_resp)
result = next(self.client._stream_helper(resp)) result = next(self.client._stream_helper(resp))
assert result == content_str.decode("utf-8") assert result == content_str.decode("utf-8") # type: ignore
# non-chunked response, pass `decode=True` to the helper # non-chunked response, pass `decode=True` to the helper
raw_resp._fp.seek(0) raw_resp._fp.seek(0)
@ -328,7 +339,7 @@ class DockerApiTest(BaseAPIClientTest):
class UnixSocketStreamTest(unittest.TestCase): class UnixSocketStreamTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
socket_dir = tempfile.mkdtemp() socket_dir = tempfile.mkdtemp()
self.build_context = tempfile.mkdtemp() self.build_context = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, socket_dir) self.addCleanup(shutil.rmtree, socket_dir)
@ -339,23 +350,23 @@ class UnixSocketStreamTest(unittest.TestCase):
server_thread = threading.Thread(target=self.run_server) server_thread = threading.Thread(target=self.run_server)
server_thread.daemon = True server_thread.daemon = True
server_thread.start() server_thread.start()
self.response = None self.response: t.Any = None
self.request_handler = None self.request_handler: t.Any = None
self.addCleanup(server_thread.join) self.addCleanup(server_thread.join)
self.addCleanup(self.stop) self.addCleanup(self.stop)
def stop(self): def stop(self) -> None:
self.stop_server = True self.stop_server = True
def _setup_socket(self): def _setup_socket(self) -> socket.socket:
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(self.socket_file) server_sock.bind(self.socket_file)
# Non-blocking mode so that we can shut the test down easily # Non-blocking mode so that we can shut the test down easily
server_sock.setblocking(0) server_sock.setblocking(0) # type: ignore
server_sock.listen(5) server_sock.listen(5)
return server_sock return server_sock
def run_server(self): def run_server(self) -> None:
try: try:
while not self.stop_server: while not self.stop_server:
try: try:
@ -365,7 +376,7 @@ class UnixSocketStreamTest(unittest.TestCase):
time.sleep(0.01) time.sleep(0.01)
continue continue
connection.setblocking(1) connection.setblocking(1) # type: ignore
try: try:
self.request_handler(connection) self.request_handler(connection)
finally: finally:
@ -373,7 +384,7 @@ class UnixSocketStreamTest(unittest.TestCase):
finally: finally:
self.server_socket.close() self.server_socket.close()
def early_response_sending_handler(self, connection): def early_response_sending_handler(self, connection) -> None:
data = b"" data = b""
headers = None headers = None
@ -395,7 +406,7 @@ class UnixSocketStreamTest(unittest.TestCase):
data += connection.recv(2048) data += connection.recv(2048)
@pytest.mark.skipif(constants.IS_WINDOWS_PLATFORM, reason="Unix only") @pytest.mark.skipif(constants.IS_WINDOWS_PLATFORM, reason="Unix only")
def test_early_stream_response(self): def test_early_stream_response(self) -> None:
self.request_handler = self.early_response_sending_handler self.request_handler = self.early_response_sending_handler
lines = [] lines = []
for i in range(0, 50): for i in range(0, 50):
@ -405,7 +416,7 @@ class UnixSocketStreamTest(unittest.TestCase):
lines.append(b"") lines.append(b"")
self.response = ( self.response = (
b"HTTP/1.1 200 OK\r\n" b"Transfer-Encoding: chunked\r\n" b"\r\n" b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"
) + b"\r\n".join(lines) ) + b"\r\n".join(lines)
with APIClient( with APIClient(
@ -459,8 +470,12 @@ class TCPSocketStreamTest(unittest.TestCase):
built on these islands for generations past? Now shall what of Him? built on these islands for generations past? Now shall what of Him?
""" """
server: ThreadingTCPServer
thread: threading.Thread
address: str
@classmethod @classmethod
def setup_class(cls): def setup_class(cls) -> None:
cls.server = ThreadingTCPServer(("", 0), cls.get_handler_class()) cls.server = ThreadingTCPServer(("", 0), cls.get_handler_class())
cls.thread = threading.Thread(target=cls.server.serve_forever) cls.thread = threading.Thread(target=cls.server.serve_forever)
cls.thread.daemon = True cls.thread.daemon = True
@ -468,13 +483,13 @@ class TCPSocketStreamTest(unittest.TestCase):
cls.address = f"http://{socket.gethostname()}:{cls.server.server_address[1]}" cls.address = f"http://{socket.gethostname()}:{cls.server.server_address[1]}"
@classmethod @classmethod
def teardown_class(cls): def teardown_class(cls) -> None:
cls.server.shutdown() cls.server.shutdown()
cls.server.server_close() cls.server.server_close()
cls.thread.join() cls.thread.join()
@classmethod @classmethod
def get_handler_class(cls): def get_handler_class(cls) -> t.Type[BaseHTTPRequestHandler]:
stdout_data = cls.stdout_data stdout_data = cls.stdout_data
stderr_data = cls.stderr_data stderr_data = cls.stderr_data
@ -510,7 +525,12 @@ class TCPSocketStreamTest(unittest.TestCase):
return Handler return Handler
def request(self, stream=None, tty=None, demux=None): def request(
self,
stream: bool | None = None,
tty: bool | None = None,
demux: bool | None = None,
) -> t.Any:
assert stream is not None and tty is not None and demux is not None assert stream is not None and tty is not None and demux is not None
with APIClient( with APIClient(
base_url=self.address, base_url=self.address,
@ -523,51 +543,51 @@ class TCPSocketStreamTest(unittest.TestCase):
resp = client._post(url, stream=True) resp = client._post(url, stream=True)
return client._read_from_socket(resp, stream=stream, tty=tty, demux=demux) return client._read_from_socket(resp, stream=stream, tty=tty, demux=demux)
def test_read_from_socket_tty(self): def test_read_from_socket_tty(self) -> None:
res = self.request(stream=True, tty=True, demux=False) res = self.request(stream=True, tty=True, demux=False)
assert next(res) == self.stdout_data + self.stderr_data assert next(res) == self.stdout_data + self.stderr_data
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_tty_demux(self): def test_read_from_socket_tty_demux(self) -> None:
res = self.request(stream=True, tty=True, demux=True) res = self.request(stream=True, tty=True, demux=True)
assert next(res) == (self.stdout_data + self.stderr_data, None) assert next(res) == (self.stdout_data + self.stderr_data, None)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_no_tty(self): def test_read_from_socket_no_tty(self) -> None:
res = self.request(stream=True, tty=False, demux=False) res = self.request(stream=True, tty=False, demux=False)
assert next(res) == self.stdout_data assert next(res) == self.stdout_data
assert next(res) == self.stderr_data assert next(res) == self.stderr_data
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_no_tty_demux(self): def test_read_from_socket_no_tty_demux(self) -> None:
res = self.request(stream=True, tty=False, demux=True) res = self.request(stream=True, tty=False, demux=True)
assert (self.stdout_data, None) == next(res) assert (self.stdout_data, None) == next(res)
assert (None, self.stderr_data) == next(res) assert (None, self.stderr_data) == next(res)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(res) next(res)
def test_read_from_socket_no_stream_tty(self): def test_read_from_socket_no_stream_tty(self) -> None:
res = self.request(stream=False, tty=True, demux=False) res = self.request(stream=False, tty=True, demux=False)
assert res == self.stdout_data + self.stderr_data assert res == self.stdout_data + self.stderr_data
def test_read_from_socket_no_stream_tty_demux(self): def test_read_from_socket_no_stream_tty_demux(self) -> None:
res = self.request(stream=False, tty=True, demux=True) res = self.request(stream=False, tty=True, demux=True)
assert res == (self.stdout_data + self.stderr_data, None) assert res == (self.stdout_data + self.stderr_data, None)
def test_read_from_socket_no_stream_no_tty(self): def test_read_from_socket_no_stream_no_tty(self) -> None:
res = self.request(stream=False, tty=False, demux=False) res = self.request(stream=False, tty=False, demux=False)
assert res == self.stdout_data + self.stderr_data assert res == self.stdout_data + self.stderr_data
def test_read_from_socket_no_stream_no_tty_demux(self): def test_read_from_socket_no_stream_no_tty_demux(self) -> None:
res = self.request(stream=False, tty=False, demux=True) res = self.request(stream=False, tty=False, demux=True)
assert res == (self.stdout_data, self.stderr_data) assert res == (self.stdout_data, self.stderr_data)
class UserAgentTest(unittest.TestCase): class UserAgentTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.patcher = mock.patch.object( self.patcher = mock.patch.object(
APIClient, APIClient,
"send", "send",
@ -575,10 +595,10 @@ class UserAgentTest(unittest.TestCase):
) )
self.mock_send = self.patcher.start() self.mock_send = self.patcher.start()
def tearDown(self): def tearDown(self) -> None:
self.patcher.stop() self.patcher.stop()
def test_default_user_agent(self): def test_default_user_agent(self) -> None:
client = APIClient(version=DEFAULT_DOCKER_API_VERSION) client = APIClient(version=DEFAULT_DOCKER_API_VERSION)
client.version() client.version()
@ -587,7 +607,7 @@ class UserAgentTest(unittest.TestCase):
expected = "ansible-community.docker" expected = "ansible-community.docker"
assert headers["User-Agent"] == expected assert headers["User-Agent"] == expected
def test_custom_user_agent(self): def test_custom_user_agent(self) -> None:
client = APIClient(user_agent="foo/bar", version=DEFAULT_DOCKER_API_VERSION) client = APIClient(user_agent="foo/bar", version=DEFAULT_DOCKER_API_VERSION)
client.version() client.version()
@ -598,44 +618,44 @@ class UserAgentTest(unittest.TestCase):
class DisableSocketTest(unittest.TestCase): class DisableSocketTest(unittest.TestCase):
class DummySocket: class DummySocket:
def __init__(self, timeout=60): def __init__(self, timeout: int | float | None = 60) -> None:
self.timeout = timeout self.timeout = timeout
self._sock = None self._sock: t.Any = None
def settimeout(self, timeout): def settimeout(self, timeout: int | float | None) -> None:
self.timeout = timeout self.timeout = timeout
def gettimeout(self): def gettimeout(self) -> int | float | None:
return self.timeout return self.timeout
def setUp(self): def setUp(self) -> None:
self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION) self.client = APIClient(version=DEFAULT_DOCKER_API_VERSION)
def test_disable_socket_timeout(self): def test_disable_socket_timeout(self) -> None:
"""Test that the timeout is disabled on a generic socket object.""" """Test that the timeout is disabled on a generic socket object."""
the_socket = self.DummySocket() the_socket = self.DummySocket()
self.client._disable_socket_timeout(the_socket) self.client._disable_socket_timeout(the_socket) # type: ignore
assert the_socket.timeout is None assert the_socket.timeout is None
def test_disable_socket_timeout2(self): def test_disable_socket_timeout2(self) -> None:
"""Test that the timeouts are disabled on a generic socket object """Test that the timeouts are disabled on a generic socket object
and it's _sock object if present.""" and it's _sock object if present."""
the_socket = self.DummySocket() the_socket = self.DummySocket()
the_socket._sock = self.DummySocket() the_socket._sock = self.DummySocket() # type: ignore
self.client._disable_socket_timeout(the_socket) self.client._disable_socket_timeout(the_socket) # type: ignore
assert the_socket.timeout is None assert the_socket.timeout is None
assert the_socket._sock.timeout is None assert the_socket._sock.timeout is None
def test_disable_socket_timout_non_blocking(self): def test_disable_socket_timout_non_blocking(self) -> None:
"""Test that a non-blocking socket does not get set to blocking.""" """Test that a non-blocking socket does not get set to blocking."""
the_socket = self.DummySocket() the_socket = self.DummySocket()
the_socket._sock = self.DummySocket(0.0) the_socket._sock = self.DummySocket(0.0) # type: ignore
self.client._disable_socket_timeout(the_socket) self.client._disable_socket_timeout(the_socket) # type: ignore
assert the_socket.timeout is None assert the_socket.timeout is None
assert the_socket._sock.timeout == 0.0 assert the_socket._sock.timeout == 0.0

View File

@ -8,6 +8,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api import constants from ansible_collections.community.docker.plugins.module_utils._api import constants
from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.constants import ( from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.constants import (
DEFAULT_DOCKER_API_VERSION, DEFAULT_DOCKER_API_VERSION,
@ -16,6 +18,10 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c
from . import fake_stat from . import fake_stat
if t.TYPE_CHECKING:
from collections.abc import Callable
CURRENT_VERSION = f"v{DEFAULT_DOCKER_API_VERSION}" CURRENT_VERSION = f"v{DEFAULT_DOCKER_API_VERSION}"
FAKE_CONTAINER_ID = "3cc2351ab11b" FAKE_CONTAINER_ID = "3cc2351ab11b"
@ -38,7 +44,7 @@ FAKE_SECRET_NAME = "super_secret"
# for clarity and readability # for clarity and readability
def get_fake_version(): def get_fake_version() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"ApiVersion": "1.35", "ApiVersion": "1.35",
@ -73,7 +79,7 @@ def get_fake_version():
return status_code, response return status_code, response
def get_fake_info(): def get_fake_info() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Containers": 1, "Containers": 1,
@ -86,23 +92,23 @@ def get_fake_info():
return status_code, response return status_code, response
def post_fake_auth(): def post_fake_auth() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Status": "Login Succeeded", "IdentityToken": "9cbaf023786cd7"} response = {"Status": "Login Succeeded", "IdentityToken": "9cbaf023786cd7"}
return status_code, response return status_code, response
def get_fake_ping(): def get_fake_ping() -> tuple[int, str]:
return 200, "OK" return 200, "OK"
def get_fake_search(): def get_fake_search() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [{"Name": "busybox", "Description": "Fake Description"}] response = [{"Name": "busybox", "Description": "Fake Description"}]
return status_code, response return status_code, response
def get_fake_images(): def get_fake_images() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{ {
@ -115,7 +121,7 @@ def get_fake_images():
return status_code, response return status_code, response
def get_fake_image_history(): def get_fake_image_history() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{"Id": "b750fe79269d", "Created": 1364102658, "CreatedBy": "/bin/bash"}, {"Id": "b750fe79269d", "Created": 1364102658, "CreatedBy": "/bin/bash"},
@ -125,14 +131,14 @@ def get_fake_image_history():
return status_code, response return status_code, response
def post_fake_import_image(): def post_fake_import_image() -> tuple[int, str]:
status_code = 200 status_code = 200
response = "Import messages..." response = "Import messages..."
return status_code, response return status_code, response
def get_fake_containers(): def get_fake_containers() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{ {
@ -146,25 +152,25 @@ def get_fake_containers():
return status_code, response return status_code, response
def post_fake_start_container(): def post_fake_start_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_resize_container(): def post_fake_resize_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_create_container(): def post_fake_create_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def get_fake_inspect_container(tty=False): def get_fake_inspect_container(tty: bool = False) -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Id": FAKE_CONTAINER_ID, "Id": FAKE_CONTAINER_ID,
@ -188,7 +194,7 @@ def get_fake_inspect_container(tty=False):
return status_code, response return status_code, response
def get_fake_inspect_image(): def get_fake_inspect_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Id": FAKE_IMAGE_ID, "Id": FAKE_IMAGE_ID,
@ -221,19 +227,19 @@ def get_fake_inspect_image():
return status_code, response return status_code, response
def get_fake_insert_image(): def get_fake_insert_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"StatusCode": 0} response = {"StatusCode": 0}
return status_code, response return status_code, response
def get_fake_wait(): def get_fake_wait() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"StatusCode": 0} response = {"StatusCode": 0}
return status_code, response return status_code, response
def get_fake_logs(): def get_fake_logs() -> tuple[int, bytes]:
status_code = 200 status_code = 200
response = ( response = (
b"\x01\x00\x00\x00\x00\x00\x00\x00" b"\x01\x00\x00\x00\x00\x00\x00\x00"
@ -244,13 +250,13 @@ def get_fake_logs():
return status_code, response return status_code, response
def get_fake_diff(): def get_fake_diff() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [{"Path": "/test", "Kind": 1}] response = [{"Path": "/test", "Kind": 1}]
return status_code, response return status_code, response
def get_fake_events(): def get_fake_events() -> tuple[int, list[dict[str, t.Any]]]:
status_code = 200 status_code = 200
response = [ response = [
{ {
@ -263,19 +269,19 @@ def get_fake_events():
return status_code, response return status_code, response
def get_fake_export(): def get_fake_export() -> tuple[int, str]:
status_code = 200 status_code = 200
response = "Byte Stream...." response = "Byte Stream...."
return status_code, response return status_code, response
def post_fake_exec_create(): def post_fake_exec_create() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_EXEC_ID} response = {"Id": FAKE_EXEC_ID}
return status_code, response return status_code, response
def post_fake_exec_start(): def post_fake_exec_start() -> tuple[int, bytes]:
status_code = 200 status_code = 200
response = ( response = (
b"\x01\x00\x00\x00\x00\x00\x00\x11bin\nboot\ndev\netc\n" b"\x01\x00\x00\x00\x00\x00\x00\x11bin\nboot\ndev\netc\n"
@ -285,12 +291,12 @@ def post_fake_exec_start():
return status_code, response return status_code, response
def post_fake_exec_resize(): def post_fake_exec_resize() -> tuple[int, str]:
status_code = 201 status_code = 201
return status_code, "" return status_code, ""
def get_fake_exec_inspect(): def get_fake_exec_inspect() -> tuple[int, dict[str, t.Any]]:
return 200, { return 200, {
"OpenStderr": True, "OpenStderr": True,
"OpenStdout": True, "OpenStdout": True,
@ -309,102 +315,102 @@ def get_fake_exec_inspect():
} }
def post_fake_stop_container(): def post_fake_stop_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_kill_container(): def post_fake_kill_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_pause_container(): def post_fake_pause_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_unpause_container(): def post_fake_unpause_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_restart_container(): def post_fake_restart_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_rename_container(): def post_fake_rename_container() -> tuple[int, None]:
status_code = 204 status_code = 204
return status_code, None return status_code, None
def delete_fake_remove_container(): def delete_fake_remove_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_image_create(): def post_fake_image_create() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def delete_fake_remove_image(): def delete_fake_remove_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def get_fake_get_image(): def get_fake_get_image() -> tuple[int, str]:
status_code = 200 status_code = 200
response = "Byte Stream...." response = "Byte Stream...."
return status_code, response return status_code, response
def post_fake_load_image(): def post_fake_load_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def post_fake_commit(): def post_fake_commit() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_push(): def post_fake_push() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def post_fake_build_container(): def post_fake_build_container() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_CONTAINER_ID} response = {"Id": FAKE_CONTAINER_ID}
return status_code, response return status_code, response
def post_fake_tag_image(): def post_fake_tag_image() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"Id": FAKE_IMAGE_ID} response = {"Id": FAKE_IMAGE_ID}
return status_code, response return status_code, response
def get_fake_stats(): def get_fake_stats() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = fake_stat.OBJ response = fake_stat.OBJ
return status_code, response return status_code, response
def get_fake_top(): def get_fake_top() -> tuple[int, dict[str, t.Any]]:
return 200, { return 200, {
"Processes": [ "Processes": [
[ [
@ -431,7 +437,7 @@ def get_fake_top():
} }
def get_fake_volume_list(): def get_fake_volume_list() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Volumes": [ "Volumes": [
@ -452,7 +458,7 @@ def get_fake_volume_list():
return status_code, response return status_code, response
def get_fake_volume(): def get_fake_volume() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = { response = {
"Name": "perfectcherryblossom", "Name": "perfectcherryblossom",
@ -464,23 +470,23 @@ def get_fake_volume():
return status_code, response return status_code, response
def fake_remove_volume(): def fake_remove_volume() -> tuple[int, None]:
return 204, None return 204, None
def post_fake_update_container(): def post_fake_update_container() -> tuple[int, dict[str, t.Any]]:
return 200, {"Warnings": []} return 200, {"Warnings": []}
def post_fake_update_node(): def post_fake_update_node() -> tuple[int, None]:
return 200, None return 200, None
def post_fake_join_swarm(): def post_fake_join_swarm() -> tuple[int, None]:
return 200, None return 200, None
def get_fake_network_list(): def get_fake_network_list() -> tuple[int, list[dict[str, t.Any]]]:
return 200, [ return 200, [
{ {
"Name": "bridge", "Name": "bridge",
@ -510,27 +516,27 @@ def get_fake_network_list():
] ]
def get_fake_network(): def get_fake_network() -> tuple[int, dict[str, t.Any]]:
return 200, get_fake_network_list()[1][0] return 200, get_fake_network_list()[1][0]
def post_fake_network(): def post_fake_network() -> tuple[int, dict[str, t.Any]]:
return 201, {"Id": FAKE_NETWORK_ID, "Warnings": []} return 201, {"Id": FAKE_NETWORK_ID, "Warnings": []}
def delete_fake_network(): def delete_fake_network() -> tuple[int, None]:
return 204, None return 204, None
def post_fake_network_connect(): def post_fake_network_connect() -> tuple[int, None]:
return 200, None return 200, None
def post_fake_network_disconnect(): def post_fake_network_disconnect() -> tuple[int, None]:
return 200, None return 200, None
def post_fake_secret(): def post_fake_secret() -> tuple[int, dict[str, t.Any]]:
status_code = 200 status_code = 200
response = {"ID": FAKE_SECRET_ID} response = {"ID": FAKE_SECRET_ID}
return status_code, response return status_code, response
@ -541,7 +547,7 @@ prefix = "http+docker://localhost" # pylint: disable=invalid-name
if constants.IS_WINDOWS_PLATFORM: if constants.IS_WINDOWS_PLATFORM:
prefix = "http+docker://localnpipe" # pylint: disable=invalid-name prefix = "http+docker://localnpipe" # pylint: disable=invalid-name
fake_responses = { fake_responses: dict[str | tuple[str, str], Callable] = {
f"{prefix}/version": get_fake_version, f"{prefix}/version": get_fake_version,
f"{prefix}/{CURRENT_VERSION}/version": get_fake_version, f"{prefix}/{CURRENT_VERSION}/version": get_fake_version,
f"{prefix}/{CURRENT_VERSION}/info": get_fake_info, f"{prefix}/{CURRENT_VERSION}/info": get_fake_info,
@ -574,6 +580,7 @@ fake_responses = {
f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/unpause": post_fake_unpause_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/unpause": post_fake_unpause_container,
f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/restart": post_fake_restart_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b/restart": post_fake_restart_container,
f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b": delete_fake_remove_container, f"{prefix}/{CURRENT_VERSION}/containers/3cc2351ab11b": delete_fake_remove_container,
# TODO: the following is a duplicate of the import endpoint further above!
f"{prefix}/{CURRENT_VERSION}/images/create": post_fake_image_create, f"{prefix}/{CURRENT_VERSION}/images/create": post_fake_image_create,
f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128": delete_fake_remove_image, f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128": delete_fake_remove_image,
f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128/get": get_fake_get_image, f"{prefix}/{CURRENT_VERSION}/images/e9aa60c60128/get": get_fake_get_image,

View File

@ -15,6 +15,7 @@ import os.path
import random import random
import shutil import shutil
import tempfile import tempfile
import typing as t
import unittest import unittest
from unittest import mock from unittest import mock
@ -30,7 +31,7 @@ from ansible_collections.community.docker.plugins.module_utils._api.credentials.
class RegressionTest(unittest.TestCase): class RegressionTest(unittest.TestCase):
def test_803_urlsafe_encode(self): def test_803_urlsafe_encode(self) -> None:
auth_data = {"username": "root", "password": "GR?XGR?XGR?XGR?X"} auth_data = {"username": "root", "password": "GR?XGR?XGR?XGR?X"}
encoded = auth.encode_header(auth_data) encoded = auth.encode_header(auth_data)
assert b"/" not in encoded assert b"/" not in encoded
@ -38,75 +39,75 @@ class RegressionTest(unittest.TestCase):
class ResolveRepositoryNameTest(unittest.TestCase): class ResolveRepositoryNameTest(unittest.TestCase):
def test_resolve_repository_name_hub_library_image(self): def test_resolve_repository_name_hub_library_image(self) -> None:
assert auth.resolve_repository_name("image") == ("docker.io", "image") assert auth.resolve_repository_name("image") == ("docker.io", "image")
def test_resolve_repository_name_dotted_hub_library_image(self): def test_resolve_repository_name_dotted_hub_library_image(self) -> None:
assert auth.resolve_repository_name("image.valid") == ( assert auth.resolve_repository_name("image.valid") == (
"docker.io", "docker.io",
"image.valid", "image.valid",
) )
def test_resolve_repository_name_hub_image(self): def test_resolve_repository_name_hub_image(self) -> None:
assert auth.resolve_repository_name("username/image") == ( assert auth.resolve_repository_name("username/image") == (
"docker.io", "docker.io",
"username/image", "username/image",
) )
def test_explicit_hub_index_library_image(self): def test_explicit_hub_index_library_image(self) -> None:
assert auth.resolve_repository_name("docker.io/image") == ("docker.io", "image") assert auth.resolve_repository_name("docker.io/image") == ("docker.io", "image")
def test_explicit_legacy_hub_index_library_image(self): def test_explicit_legacy_hub_index_library_image(self) -> None:
assert auth.resolve_repository_name("index.docker.io/image") == ( assert auth.resolve_repository_name("index.docker.io/image") == (
"docker.io", "docker.io",
"image", "image",
) )
def test_resolve_repository_name_private_registry(self): def test_resolve_repository_name_private_registry(self) -> None:
assert auth.resolve_repository_name("my.registry.net/image") == ( assert auth.resolve_repository_name("my.registry.net/image") == (
"my.registry.net", "my.registry.net",
"image", "image",
) )
def test_resolve_repository_name_private_registry_with_port(self): def test_resolve_repository_name_private_registry_with_port(self) -> None:
assert auth.resolve_repository_name("my.registry.net:5000/image") == ( assert auth.resolve_repository_name("my.registry.net:5000/image") == (
"my.registry.net:5000", "my.registry.net:5000",
"image", "image",
) )
def test_resolve_repository_name_private_registry_with_username(self): def test_resolve_repository_name_private_registry_with_username(self) -> None:
assert auth.resolve_repository_name("my.registry.net/username/image") == ( assert auth.resolve_repository_name("my.registry.net/username/image") == (
"my.registry.net", "my.registry.net",
"username/image", "username/image",
) )
def test_resolve_repository_name_no_dots_but_port(self): def test_resolve_repository_name_no_dots_but_port(self) -> None:
assert auth.resolve_repository_name("hostname:5000/image") == ( assert auth.resolve_repository_name("hostname:5000/image") == (
"hostname:5000", "hostname:5000",
"image", "image",
) )
def test_resolve_repository_name_no_dots_but_port_and_username(self): def test_resolve_repository_name_no_dots_but_port_and_username(self) -> None:
assert auth.resolve_repository_name("hostname:5000/username/image") == ( assert auth.resolve_repository_name("hostname:5000/username/image") == (
"hostname:5000", "hostname:5000",
"username/image", "username/image",
) )
def test_resolve_repository_name_localhost(self): def test_resolve_repository_name_localhost(self) -> None:
assert auth.resolve_repository_name("localhost/image") == ("localhost", "image") assert auth.resolve_repository_name("localhost/image") == ("localhost", "image")
def test_resolve_repository_name_localhost_with_username(self): def test_resolve_repository_name_localhost_with_username(self) -> None:
assert auth.resolve_repository_name("localhost/username/image") == ( assert auth.resolve_repository_name("localhost/username/image") == (
"localhost", "localhost",
"username/image", "username/image",
) )
def test_invalid_index_name(self): def test_invalid_index_name(self) -> None:
with pytest.raises(errors.InvalidRepository): with pytest.raises(errors.InvalidRepository):
auth.resolve_repository_name("-gecko.com/image") auth.resolve_repository_name("-gecko.com/image")
def encode_auth(auth_info): def encode_auth(auth_info: dict[str, t.Any]) -> bytes:
return base64.b64encode( return base64.b64encode(
auth_info.get("username", "").encode("utf-8") auth_info.get("username", "").encode("utf-8")
+ b":" + b":"
@ -131,19 +132,19 @@ class ResolveAuthTest(unittest.TestCase):
} }
) )
def test_resolve_authconfig_hostname_only(self): def test_resolve_authconfig_hostname_only(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "my.registry.net")["username"] auth.resolve_authconfig(self.auth_config, "my.registry.net")["username"]
== "privateuser" == "privateuser"
) )
def test_resolve_authconfig_no_protocol(self): def test_resolve_authconfig_no_protocol(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")["username"] auth.resolve_authconfig(self.auth_config, "my.registry.net/v1/")["username"]
== "privateuser" == "privateuser"
) )
def test_resolve_authconfig_no_path(self): def test_resolve_authconfig_no_path(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "http://my.registry.net")[ auth.resolve_authconfig(self.auth_config, "http://my.registry.net")[
"username" "username"
@ -151,7 +152,7 @@ class ResolveAuthTest(unittest.TestCase):
== "privateuser" == "privateuser"
) )
def test_resolve_authconfig_no_path_trailing_slash(self): def test_resolve_authconfig_no_path_trailing_slash(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")[ auth.resolve_authconfig(self.auth_config, "http://my.registry.net/")[
"username" "username"
@ -159,7 +160,7 @@ class ResolveAuthTest(unittest.TestCase):
== "privateuser" == "privateuser"
) )
def test_resolve_authconfig_no_path_wrong_secure_proto(self): def test_resolve_authconfig_no_path_wrong_secure_proto(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "https://my.registry.net")[ auth.resolve_authconfig(self.auth_config, "https://my.registry.net")[
"username" "username"
@ -167,7 +168,7 @@ class ResolveAuthTest(unittest.TestCase):
== "privateuser" == "privateuser"
) )
def test_resolve_authconfig_no_path_wrong_insecure_proto(self): def test_resolve_authconfig_no_path_wrong_insecure_proto(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "http://index.docker.io")[ auth.resolve_authconfig(self.auth_config, "http://index.docker.io")[
"username" "username"
@ -175,7 +176,7 @@ class ResolveAuthTest(unittest.TestCase):
== "indexuser" == "indexuser"
) )
def test_resolve_authconfig_path_wrong_proto(self): def test_resolve_authconfig_path_wrong_proto(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")[ auth.resolve_authconfig(self.auth_config, "https://my.registry.net/v1/")[
"username" "username"
@ -183,15 +184,15 @@ class ResolveAuthTest(unittest.TestCase):
== "privateuser" == "privateuser"
) )
def test_resolve_authconfig_default_registry(self): def test_resolve_authconfig_default_registry(self) -> None:
assert auth.resolve_authconfig(self.auth_config)["username"] == "indexuser" assert auth.resolve_authconfig(self.auth_config)["username"] == "indexuser"
def test_resolve_authconfig_default_explicit_none(self): def test_resolve_authconfig_default_explicit_none(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, None)["username"] == "indexuser" auth.resolve_authconfig(self.auth_config, None)["username"] == "indexuser"
) )
def test_resolve_authconfig_fully_explicit(self): def test_resolve_authconfig_fully_explicit(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")[ auth.resolve_authconfig(self.auth_config, "http://my.registry.net/v1/")[
"username" "username"
@ -199,16 +200,16 @@ class ResolveAuthTest(unittest.TestCase):
== "privateuser" == "privateuser"
) )
def test_resolve_authconfig_legacy_config(self): def test_resolve_authconfig_legacy_config(self) -> None:
assert ( assert (
auth.resolve_authconfig(self.auth_config, "legacy.registry.url")["username"] auth.resolve_authconfig(self.auth_config, "legacy.registry.url")["username"]
== "legacyauth" == "legacyauth"
) )
def test_resolve_authconfig_no_match(self): def test_resolve_authconfig_no_match(self) -> None:
assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None assert auth.resolve_authconfig(self.auth_config, "does.not.exist") is None
def test_resolve_registry_and_auth_library_image(self): def test_resolve_registry_and_auth_library_image(self) -> None:
image = "image" image = "image"
assert ( assert (
auth.resolve_authconfig( auth.resolve_authconfig(
@ -217,7 +218,7 @@ class ResolveAuthTest(unittest.TestCase):
== "indexuser" == "indexuser"
) )
def test_resolve_registry_and_auth_hub_image(self): def test_resolve_registry_and_auth_hub_image(self) -> None:
image = "username/image" image = "username/image"
assert ( assert (
auth.resolve_authconfig( auth.resolve_authconfig(
@ -226,7 +227,7 @@ class ResolveAuthTest(unittest.TestCase):
== "indexuser" == "indexuser"
) )
def test_resolve_registry_and_auth_explicit_hub(self): def test_resolve_registry_and_auth_explicit_hub(self) -> None:
image = "docker.io/username/image" image = "docker.io/username/image"
assert ( assert (
auth.resolve_authconfig( auth.resolve_authconfig(
@ -235,7 +236,7 @@ class ResolveAuthTest(unittest.TestCase):
== "indexuser" == "indexuser"
) )
def test_resolve_registry_and_auth_explicit_legacy_hub(self): def test_resolve_registry_and_auth_explicit_legacy_hub(self) -> None:
image = "index.docker.io/username/image" image = "index.docker.io/username/image"
assert ( assert (
auth.resolve_authconfig( auth.resolve_authconfig(
@ -244,7 +245,7 @@ class ResolveAuthTest(unittest.TestCase):
== "indexuser" == "indexuser"
) )
def test_resolve_registry_and_auth_private_registry(self): def test_resolve_registry_and_auth_private_registry(self) -> None:
image = "my.registry.net/image" image = "my.registry.net/image"
assert ( assert (
auth.resolve_authconfig( auth.resolve_authconfig(
@ -253,7 +254,7 @@ class ResolveAuthTest(unittest.TestCase):
== "privateuser" == "privateuser"
) )
def test_resolve_registry_and_auth_unauthenticated_registry(self): def test_resolve_registry_and_auth_unauthenticated_registry(self) -> None:
image = "other.registry.net/image" image = "other.registry.net/image"
assert ( assert (
auth.resolve_authconfig( auth.resolve_authconfig(
@ -262,7 +263,7 @@ class ResolveAuthTest(unittest.TestCase):
is None is None
) )
def test_resolve_auth_with_empty_credstore_and_auth_dict(self): def test_resolve_auth_with_empty_credstore_and_auth_dict(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"auths": auth.parse_auth( "auths": auth.parse_auth(
@ -281,13 +282,13 @@ class ResolveAuthTest(unittest.TestCase):
class LoadConfigTest(unittest.TestCase): class LoadConfigTest(unittest.TestCase):
def test_load_config_no_file(self): def test_load_config_no_file(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg = auth.load_config(folder) cfg = auth.load_config(folder)
assert cfg is not None assert cfg is not None
def test_load_legacy_config(self): def test_load_legacy_config(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg_path = os.path.join(folder, ".dockercfg") cfg_path = os.path.join(folder, ".dockercfg")
@ -299,13 +300,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(cfg_path) cfg = auth.load_config(cfg_path)
assert auth.resolve_authconfig(cfg) is not None assert auth.resolve_authconfig(cfg) is not None
assert cfg.auths[auth.INDEX_NAME] is not None assert cfg.auths[auth.INDEX_NAME] is not None
cfg = cfg.auths[auth.INDEX_NAME] cfg2 = cfg.auths[auth.INDEX_NAME]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("Auth") is None assert cfg2.get("Auth") is None
def test_load_json_config(self): def test_load_json_config(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg_path = os.path.join(folder, ".dockercfg") cfg_path = os.path.join(folder, ".dockercfg")
@ -316,13 +317,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(cfg_path) cfg = auth.load_config(cfg_path)
assert auth.resolve_authconfig(cfg) is not None assert auth.resolve_authconfig(cfg) is not None
assert cfg.auths[auth.INDEX_URL] is not None assert cfg.auths[auth.INDEX_URL] is not None
cfg = cfg.auths[auth.INDEX_URL] cfg2 = cfg.auths[auth.INDEX_URL]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == email assert cfg2["email"] == email
assert cfg.get("Auth") is None assert cfg2.get("Auth") is None
def test_load_modern_json_config(self): def test_load_modern_json_config(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg_path = os.path.join(folder, "config.json") cfg_path = os.path.join(folder, "config.json")
@ -333,12 +334,12 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(cfg_path) cfg = auth.load_config(cfg_path)
assert auth.resolve_authconfig(cfg) is not None assert auth.resolve_authconfig(cfg) is not None
assert cfg.auths[auth.INDEX_URL] is not None assert cfg.auths[auth.INDEX_URL] is not None
cfg = cfg.auths[auth.INDEX_URL] cfg2 = cfg.auths[auth.INDEX_URL]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == email assert cfg2["email"] == email
def test_load_config_with_random_name(self): def test_load_config_with_random_name(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -353,13 +354,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path).auths cfg = auth.load_config(dockercfg_path).auths
assert registry in cfg assert registry in cfg
assert cfg[registry] is not None assert cfg[registry] is not None
cfg = cfg[registry] cfg2 = cfg[registry]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_custom_config_env(self): def test_load_config_custom_config_env(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -375,13 +376,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(None).auths cfg = auth.load_config(None).auths
assert registry in cfg assert registry in cfg
assert cfg[registry] is not None assert cfg[registry] is not None
cfg = cfg[registry] cfg2 = cfg[registry]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_custom_config_env_with_auths(self): def test_load_config_custom_config_env_with_auths(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -398,13 +399,13 @@ class LoadConfigTest(unittest.TestCase):
with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}):
cfg = auth.load_config(None) cfg = auth.load_config(None)
assert registry in cfg.auths assert registry in cfg.auths
cfg = cfg.auths[registry] cfg2 = cfg.auths[registry]
assert cfg["username"] == "sakuya" assert cfg2["username"] == "sakuya"
assert cfg["password"] == "izayoi" assert cfg2["password"] == "izayoi"
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_custom_config_env_utf8(self): def test_load_config_custom_config_env_utf8(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -421,13 +422,13 @@ class LoadConfigTest(unittest.TestCase):
with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": folder}):
cfg = auth.load_config(None) cfg = auth.load_config(None)
assert registry in cfg.auths assert registry in cfg.auths
cfg = cfg.auths[registry] cfg2 = cfg.auths[registry]
assert cfg["username"] == b"sakuya\xc3\xa6".decode("utf8") assert cfg2["username"] == b"sakuya\xc3\xa6".decode("utf8")
assert cfg["password"] == b"izayoi\xc3\xa6".decode("utf8") assert cfg2["password"] == b"izayoi\xc3\xa6".decode("utf8")
assert cfg["email"] == "sakuya@scarlet.net" assert cfg2["email"] == "sakuya@scarlet.net"
assert cfg.get("auth") is None assert cfg2.get("auth") is None
def test_load_config_unknown_keys(self): def test_load_config_unknown_keys(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")
@ -438,7 +439,7 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path) cfg = auth.load_config(dockercfg_path)
assert dict(cfg) == {"auths": {}} assert dict(cfg) == {"auths": {}}
def test_load_config_invalid_auth_dict(self): def test_load_config_invalid_auth_dict(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")
@ -449,7 +450,7 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path) cfg = auth.load_config(dockercfg_path)
assert dict(cfg) == {"auths": {"scarlet.net": {}}} assert dict(cfg) == {"auths": {"scarlet.net": {}}}
def test_load_config_identity_token(self): def test_load_config_identity_token(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
registry = "scarlet.net" registry = "scarlet.net"
token = "1ce1cebb-503e-7043-11aa-7feb8bd4a1ce" token = "1ce1cebb-503e-7043-11aa-7feb8bd4a1ce"
@ -462,13 +463,13 @@ class LoadConfigTest(unittest.TestCase):
cfg = auth.load_config(dockercfg_path) cfg = auth.load_config(dockercfg_path)
assert registry in cfg.auths assert registry in cfg.auths
cfg = cfg.auths[registry] cfg2 = cfg.auths[registry]
assert "IdentityToken" in cfg assert "IdentityToken" in cfg2
assert cfg["IdentityToken"] == token assert cfg2["IdentityToken"] == token
class CredstoreTest(unittest.TestCase): class CredstoreTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.authconfig = auth.AuthConfig({"credsStore": "default"}) self.authconfig = auth.AuthConfig({"credsStore": "default"})
self.default_store = InMemoryStore("default") self.default_store = InMemoryStore("default")
self.authconfig._stores["default"] = self.default_store self.authconfig._stores["default"] = self.default_store
@ -483,7 +484,7 @@ class CredstoreTest(unittest.TestCase):
"hunter2", "hunter2",
) )
def test_get_credential_store(self): def test_get_credential_store(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"credHelpers": { "credHelpers": {
@ -498,7 +499,7 @@ class CredstoreTest(unittest.TestCase):
assert auth_config.get_credential_store("registry2.io") == "powerlock" assert auth_config.get_credential_store("registry2.io") == "powerlock"
assert auth_config.get_credential_store("registry3.io") == "blackbox" assert auth_config.get_credential_store("registry3.io") == "blackbox"
def test_get_credential_store_no_default(self): def test_get_credential_store_no_default(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"credHelpers": { "credHelpers": {
@ -510,7 +511,7 @@ class CredstoreTest(unittest.TestCase):
assert auth_config.get_credential_store("registry2.io") == "powerlock" assert auth_config.get_credential_store("registry2.io") == "powerlock"
assert auth_config.get_credential_store("registry3.io") is None assert auth_config.get_credential_store("registry3.io") is None
def test_get_credential_store_default_index(self): def test_get_credential_store_default_index(self) -> None:
auth_config = auth.AuthConfig( auth_config = auth.AuthConfig(
{ {
"credHelpers": {"https://index.docker.io/v1/": "powerlock"}, "credHelpers": {"https://index.docker.io/v1/": "powerlock"},
@ -522,7 +523,7 @@ class CredstoreTest(unittest.TestCase):
assert auth_config.get_credential_store("docker.io") == "powerlock" assert auth_config.get_credential_store("docker.io") == "powerlock"
assert auth_config.get_credential_store("images.io") == "truesecret" assert auth_config.get_credential_store("images.io") == "truesecret"
def test_get_credential_store_with_plain_dict(self): def test_get_credential_store_with_plain_dict(self) -> None:
auth_config = { auth_config = {
"credHelpers": {"registry1.io": "truesecret", "registry2.io": "powerlock"}, "credHelpers": {"registry1.io": "truesecret", "registry2.io": "powerlock"},
"credsStore": "blackbox", "credsStore": "blackbox",
@ -532,7 +533,7 @@ class CredstoreTest(unittest.TestCase):
assert auth.get_credential_store(auth_config, "registry2.io") == "powerlock" assert auth.get_credential_store(auth_config, "registry2.io") == "powerlock"
assert auth.get_credential_store(auth_config, "registry3.io") == "blackbox" assert auth.get_credential_store(auth_config, "registry3.io") == "blackbox"
def test_get_all_credentials_credstore_only(self): def test_get_all_credentials_credstore_only(self) -> None:
assert self.authconfig.get_all_credentials() == { assert self.authconfig.get_all_credentials() == {
"https://gensokyo.jp/v2": { "https://gensokyo.jp/v2": {
"Username": "sakuya", "Username": "sakuya",
@ -556,7 +557,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_with_empty_credhelper(self): def test_get_all_credentials_with_empty_credhelper(self) -> None:
self.authconfig["credHelpers"] = { self.authconfig["credHelpers"] = {
"registry1.io": "truesecret", "registry1.io": "truesecret",
} }
@ -585,7 +586,7 @@ class CredstoreTest(unittest.TestCase):
"registry1.io": None, "registry1.io": None,
} }
def test_get_all_credentials_with_credhelpers_only(self): def test_get_all_credentials_with_credhelpers_only(self) -> None:
del self.authconfig["credsStore"] del self.authconfig["credsStore"]
assert self.authconfig.get_all_credentials() == {} assert self.authconfig.get_all_credentials() == {}
@ -617,7 +618,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_with_auths_entries(self): def test_get_all_credentials_with_auths_entries(self) -> None:
self.authconfig.add_auth( self.authconfig.add_auth(
"registry1.io", "registry1.io",
{ {
@ -655,7 +656,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_with_empty_auths_entry(self): def test_get_all_credentials_with_empty_auths_entry(self) -> None:
self.authconfig.add_auth("default.com", {}) self.authconfig.add_auth("default.com", {})
assert self.authconfig.get_all_credentials() == { assert self.authconfig.get_all_credentials() == {
@ -681,7 +682,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_credstore_overrides_auth_entry(self): def test_get_all_credentials_credstore_overrides_auth_entry(self) -> None:
self.authconfig.add_auth( self.authconfig.add_auth(
"default.com", "default.com",
{ {
@ -714,7 +715,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_helpers_override_default(self): def test_get_all_credentials_helpers_override_default(self) -> None:
self.authconfig["credHelpers"] = { self.authconfig["credHelpers"] = {
"https://default.com/v2": "truesecret", "https://default.com/v2": "truesecret",
} }
@ -744,7 +745,7 @@ class CredstoreTest(unittest.TestCase):
}, },
} }
def test_get_all_credentials_3_sources(self): def test_get_all_credentials_3_sources(self) -> None:
self.authconfig["credHelpers"] = { self.authconfig["credHelpers"] = {
"registry1.io": "truesecret", "registry1.io": "truesecret",
} }
@ -795,24 +796,27 @@ class CredstoreTest(unittest.TestCase):
class InMemoryStore(Store): class InMemoryStore(Store):
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self.__store = {} self, *args, **kwargs
) -> None:
self.__store: dict[str | bytes, dict[str, t.Any]] = {}
def get(self, server): def get(self, server: str | bytes) -> dict[str, t.Any]:
try: try:
return self.__store[server] return self.__store[server]
except KeyError: except KeyError:
raise CredentialsNotFound() from None raise CredentialsNotFound() from None
def store(self, server, username, secret): def store(self, server: str, username: str, secret: str) -> bytes:
self.__store[server] = { self.__store[server] = {
"ServerURL": server, "ServerURL": server,
"Username": username, "Username": username,
"Secret": secret, "Secret": secret,
} }
return b""
def list(self): def list(self) -> dict[str | bytes, str]:
return dict((k, v["Username"]) for k, v in self.__store.items()) return dict((k, v["Username"]) for k, v in self.__store.items())
def erase(self, server): def erase(self, server: str | bytes) -> None:
del self.__store[server] del self.__store[server]

View File

@ -28,20 +28,20 @@ from ansible_collections.community.docker.plugins.module_utils._api.context.cont
class BaseContextTest(unittest.TestCase): class BaseContextTest(unittest.TestCase):
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="Linux specific path check") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="Linux specific path check")
def test_url_compatibility_on_linux(self): def test_url_compatibility_on_linux(self) -> None:
c = Context("test") c = Context("test")
assert c.Host == DEFAULT_UNIX_SOCKET[5:] assert c.Host == DEFAULT_UNIX_SOCKET[5:]
@pytest.mark.skipif(not IS_WINDOWS_PLATFORM, reason="Windows specific path check") @pytest.mark.skipif(not IS_WINDOWS_PLATFORM, reason="Windows specific path check")
def test_url_compatibility_on_windows(self): def test_url_compatibility_on_windows(self) -> None:
c = Context("test") c = Context("test")
assert c.Host == DEFAULT_NPIPE assert c.Host == DEFAULT_NPIPE
def test_fail_on_default_context_create(self): def test_fail_on_default_context_create(self) -> None:
with pytest.raises(errors.ContextException): with pytest.raises(errors.ContextException):
ContextAPI.create_context("default") ContextAPI.create_context("default")
def test_default_in_context_list(self): def test_default_in_context_list(self) -> None:
found = False found = False
ctx = ContextAPI.contexts() ctx = ContextAPI.contexts()
for c in ctx: for c in ctx:
@ -49,14 +49,16 @@ class BaseContextTest(unittest.TestCase):
found = True found = True
assert found is True assert found is True
def test_get_current_context(self): def test_get_current_context(self) -> None:
assert ContextAPI.get_current_context().Name == "default" context = ContextAPI.get_current_context()
assert context is not None
assert context.Name == "default"
def test_https_host(self): def test_https_host(self) -> None:
c = Context("test", host="tcp://testdomain:8080", tls=True) c = Context("test", host="tcp://testdomain:8080", tls=True)
assert c.Host == "https://testdomain:8080" assert c.Host == "https://testdomain:8080"
def test_context_inspect_without_params(self): def test_context_inspect_without_params(self) -> None:
ctx = ContextAPI.inspect_context() ctx = ContextAPI.inspect_context()
assert ctx["Name"] == "default" assert ctx["Name"] == "default"
assert ctx["Metadata"]["StackOrchestrator"] == "swarm" assert ctx["Metadata"]["StackOrchestrator"] == "swarm"

View File

@ -21,97 +21,97 @@ from ansible_collections.community.docker.plugins.module_utils._api.errors impor
class APIErrorTest(unittest.TestCase): class APIErrorTest(unittest.TestCase):
def test_api_error_is_caught_by_dockerexception(self): def test_api_error_is_caught_by_dockerexception(self) -> None:
try: try:
raise APIError("this should be caught by DockerException") raise APIError("this should be caught by DockerException")
except DockerException: except DockerException:
pass pass
def test_status_code_200(self): def test_status_code_200(self) -> None:
"""The status_code property is present with 200 response.""" """The status_code property is present with 200 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 200 resp.status_code = 200
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.status_code == 200 assert err.status_code == 200
def test_status_code_400(self): def test_status_code_400(self) -> None:
"""The status_code property is present with 400 response.""" """The status_code property is present with 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.status_code == 400 assert err.status_code == 400
def test_status_code_500(self): def test_status_code_500(self) -> None:
"""The status_code property is present with 500 response.""" """The status_code property is present with 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.status_code == 500 assert err.status_code == 500
def test_is_server_error_200(self): def test_is_server_error_200(self) -> None:
"""Report not server error on 200 response.""" """Report not server error on 200 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 200 resp.status_code = 200
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is False assert err.is_server_error() is False
def test_is_server_error_300(self): def test_is_server_error_300(self) -> None:
"""Report not server error on 300 response.""" """Report not server error on 300 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 300 resp.status_code = 300
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is False assert err.is_server_error() is False
def test_is_server_error_400(self): def test_is_server_error_400(self) -> None:
"""Report not server error on 400 response.""" """Report not server error on 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is False assert err.is_server_error() is False
def test_is_server_error_500(self): def test_is_server_error_500(self) -> None:
"""Report server error on 500 response.""" """Report server error on 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_server_error() is True assert err.is_server_error() is True
def test_is_client_error_500(self): def test_is_client_error_500(self) -> None:
"""Report not client error on 500 response.""" """Report not client error on 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_client_error() is False assert err.is_client_error() is False
def test_is_client_error_400(self): def test_is_client_error_400(self) -> None:
"""Report client error on 400 response.""" """Report client error on 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_client_error() is True assert err.is_client_error() is True
def test_is_error_300(self): def test_is_error_300(self) -> None:
"""Report no error on 300 response.""" """Report no error on 300 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 300 resp.status_code = 300
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_error() is False assert err.is_error() is False
def test_is_error_400(self): def test_is_error_400(self) -> None:
"""Report error on 400 response.""" """Report error on 400 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 400 resp.status_code = 400
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_error() is True assert err.is_error() is True
def test_is_error_500(self): def test_is_error_500(self) -> None:
"""Report error on 500 response.""" """Report error on 500 response."""
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("", response=resp) err = APIError("", response=resp)
assert err.is_error() is True assert err.is_error() is True
def test_create_error_from_exception(self): def test_create_error_from_exception(self) -> None:
resp = requests.Response() resp = requests.Response()
resp.status_code = 500 resp.status_code = 500
err = APIError("") err = APIError("")
@ -126,10 +126,10 @@ class APIErrorTest(unittest.TestCase):
class CreateUnexpectedKwargsErrorTest(unittest.TestCase): class CreateUnexpectedKwargsErrorTest(unittest.TestCase):
def test_create_unexpected_kwargs_error_single(self): def test_create_unexpected_kwargs_error_single(self) -> None:
e = create_unexpected_kwargs_error("f", {"foo": "bar"}) e = create_unexpected_kwargs_error("f", {"foo": "bar"})
assert str(e) == "f() got an unexpected keyword argument 'foo'" assert str(e) == "f() got an unexpected keyword argument 'foo'"
def test_create_unexpected_kwargs_error_multiple(self): def test_create_unexpected_kwargs_error_multiple(self) -> None:
e = create_unexpected_kwargs_error("f", {"foo": "bar", "baz": "bosh"}) e = create_unexpected_kwargs_error("f", {"foo": "bar", "baz": "bosh"})
assert str(e) == "f() got unexpected keyword arguments 'baz', 'foo'" assert str(e) == "f() got unexpected keyword arguments 'baz', 'foo'"

View File

@ -18,33 +18,33 @@ from ansible_collections.community.docker.plugins.module_utils._api.transport.ss
class SSHAdapterTest(unittest.TestCase): class SSHAdapterTest(unittest.TestCase):
@staticmethod @staticmethod
def test_ssh_hostname_prefix_trim(): def test_ssh_hostname_prefix_trim() -> None:
conn = SSHHTTPAdapter(base_url="ssh://user@hostname:1234", shell_out=True) conn = SSHHTTPAdapter(base_url="ssh://user@hostname:1234", shell_out=True)
assert conn.ssh_host == "user@hostname:1234" assert conn.ssh_host == "user@hostname:1234"
@staticmethod @staticmethod
def test_ssh_parse_url(): def test_ssh_parse_url() -> None:
c = SSHSocket(host="user@hostname:1234") c = SSHSocket(host="user@hostname:1234")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port == "1234" assert c.port == "1234"
assert c.user == "user" assert c.user == "user"
@staticmethod @staticmethod
def test_ssh_parse_hostname_only(): def test_ssh_parse_hostname_only() -> None:
c = SSHSocket(host="hostname") c = SSHSocket(host="hostname")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port is None assert c.port is None
assert c.user is None assert c.user is None
@staticmethod @staticmethod
def test_ssh_parse_user_and_hostname(): def test_ssh_parse_user_and_hostname() -> None:
c = SSHSocket(host="user@hostname") c = SSHSocket(host="user@hostname")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port is None assert c.port is None
assert c.user == "user" assert c.user == "user"
@staticmethod @staticmethod
def test_ssh_parse_hostname_and_port(): def test_ssh_parse_hostname_and_port() -> None:
c = SSHSocket(host="hostname:22") c = SSHSocket(host="hostname:22")
assert c.host == "hostname" assert c.host == "hostname"
assert c.port == "22" assert c.port == "22"

View File

@ -27,7 +27,7 @@ else:
class SSLAdapterTest(unittest.TestCase): class SSLAdapterTest(unittest.TestCase):
def test_only_uses_tls(self): def test_only_uses_tls(self) -> None:
ssl_context = ssladapter.urllib3.util.ssl_.create_urllib3_context() ssl_context = ssladapter.urllib3.util.ssl_.create_urllib3_context()
assert ssl_context.options & OP_NO_SSLv3 assert ssl_context.options & OP_NO_SSLv3
@ -68,19 +68,19 @@ class MatchHostnameTest(unittest.TestCase):
"version": 3, "version": 3,
} }
def test_match_ip_address_success(self): def test_match_ip_address_success(self) -> None:
assert match_hostname(self.cert, "127.0.0.1") is None assert match_hostname(self.cert, "127.0.0.1") is None
def test_match_localhost_success(self): def test_match_localhost_success(self) -> None:
assert match_hostname(self.cert, "localhost") is None assert match_hostname(self.cert, "localhost") is None
def test_match_dns_success(self): def test_match_dns_success(self) -> None:
assert match_hostname(self.cert, "touhou.gensokyo.jp") is None assert match_hostname(self.cert, "touhou.gensokyo.jp") is None
def test_match_ip_address_failure(self): def test_match_ip_address_failure(self) -> None:
with pytest.raises(CertificateError): with pytest.raises(CertificateError):
match_hostname(self.cert, "192.168.0.25") match_hostname(self.cert, "192.168.0.25")
def test_match_dns_failure(self): def test_match_dns_failure(self) -> None:
with pytest.raises(CertificateError): with pytest.raises(CertificateError):
match_hostname(self.cert, "foobar.co.uk") match_hostname(self.cert, "foobar.co.uk")

View File

@ -14,6 +14,7 @@ import shutil
import socket import socket
import tarfile import tarfile
import tempfile import tempfile
import typing as t
import unittest import unittest
import pytest import pytest
@ -27,7 +28,11 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.build
) )
def make_tree(dirs, files): if t.TYPE_CHECKING:
from collections.abc import Collection
def make_tree(dirs: list[str], files: list[str]) -> str:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
for path in dirs: for path in dirs:
@ -40,11 +45,11 @@ def make_tree(dirs, files):
return base return base
def convert_paths(collection): def convert_paths(collection: Collection[str]) -> set[str]:
return set(map(convert_path, collection)) return set(map(convert_path, collection))
def convert_path(path): def convert_path(path: str) -> str:
return path.replace("/", os.path.sep) return path.replace("/", os.path.sep)
@ -88,26 +93,26 @@ class ExcludePathsTest(unittest.TestCase):
all_paths = set(dirs + files) all_paths = set(dirs + files)
def setUp(self): def setUp(self) -> None:
self.base = make_tree(self.dirs, self.files) self.base = make_tree(self.dirs, self.files)
def tearDown(self): def tearDown(self) -> None:
shutil.rmtree(self.base) shutil.rmtree(self.base)
def exclude(self, patterns, dockerfile=None): def exclude(self, patterns: list[str], dockerfile: str | None = None) -> set[str]:
return set(exclude_paths(self.base, patterns, dockerfile=dockerfile)) return set(exclude_paths(self.base, patterns, dockerfile=dockerfile))
def test_no_excludes(self): def test_no_excludes(self) -> None:
assert self.exclude([""]) == convert_paths(self.all_paths) assert self.exclude([""]) == convert_paths(self.all_paths)
def test_no_dupes(self): def test_no_dupes(self) -> None:
paths = exclude_paths(self.base, ["!a.py"]) paths = exclude_paths(self.base, ["!a.py"])
assert sorted(paths) == sorted(set(paths)) assert sorted(paths) == sorted(set(paths))
def test_wildcard_exclude(self): def test_wildcard_exclude(self) -> None:
assert self.exclude(["*"]) == set(["Dockerfile", ".dockerignore"]) assert self.exclude(["*"]) == set(["Dockerfile", ".dockerignore"])
def test_exclude_dockerfile_dockerignore(self): def test_exclude_dockerfile_dockerignore(self) -> None:
""" """
Even if the .dockerignore file explicitly says to exclude Even if the .dockerignore file explicitly says to exclude
Dockerfile and/or .dockerignore, don't exclude them from Dockerfile and/or .dockerignore, don't exclude them from
@ -117,7 +122,7 @@ class ExcludePathsTest(unittest.TestCase):
self.all_paths self.all_paths
) )
def test_exclude_custom_dockerfile(self): def test_exclude_custom_dockerfile(self) -> None:
""" """
If we're using a custom Dockerfile, make sure that's not If we're using a custom Dockerfile, make sure that's not
excluded. excluded.
@ -135,20 +140,20 @@ class ExcludePathsTest(unittest.TestCase):
set(["foo/Dockerfile3", ".dockerignore"]) set(["foo/Dockerfile3", ".dockerignore"])
) )
def test_exclude_dockerfile_child(self): def test_exclude_dockerfile_child(self) -> None:
includes = self.exclude(["foo/"], dockerfile="foo/Dockerfile3") includes = self.exclude(["foo/"], dockerfile="foo/Dockerfile3")
assert convert_path("foo/Dockerfile3") in includes assert convert_path("foo/Dockerfile3") in includes
assert convert_path("foo/a.py") not in includes assert convert_path("foo/a.py") not in includes
def test_single_filename(self): def test_single_filename(self) -> None:
assert self.exclude(["a.py"]) == convert_paths(self.all_paths - set(["a.py"])) assert self.exclude(["a.py"]) == convert_paths(self.all_paths - set(["a.py"]))
def test_single_filename_leading_dot_slash(self): def test_single_filename_leading_dot_slash(self) -> None:
assert self.exclude(["./a.py"]) == convert_paths(self.all_paths - set(["a.py"])) assert self.exclude(["./a.py"]) == convert_paths(self.all_paths - set(["a.py"]))
# As odd as it sounds, a filename pattern with a trailing slash on the # As odd as it sounds, a filename pattern with a trailing slash on the
# end *will* result in that file being excluded. # end *will* result in that file being excluded.
def test_single_filename_trailing_slash(self): def test_single_filename_trailing_slash(self) -> None:
assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"])) assert self.exclude(["a.py/"]) == convert_paths(self.all_paths - set(["a.py"]))
def test_wildcard_filename_start(self): def test_wildcard_filename_start(self):
@ -156,12 +161,12 @@ class ExcludePathsTest(unittest.TestCase):
self.all_paths - set(["a.py", "b.py", "cde.py"]) self.all_paths - set(["a.py", "b.py", "cde.py"])
) )
def test_wildcard_with_exception(self): def test_wildcard_with_exception(self) -> None:
assert self.exclude(["*.py", "!b.py"]) == convert_paths( assert self.exclude(["*.py", "!b.py"]) == convert_paths(
self.all_paths - set(["a.py", "cde.py"]) self.all_paths - set(["a.py", "cde.py"])
) )
def test_wildcard_with_wildcard_exception(self): def test_wildcard_with_wildcard_exception(self) -> None:
assert self.exclude(["*.*", "!*.go"]) == convert_paths( assert self.exclude(["*.*", "!*.go"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -174,51 +179,51 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_wildcard_filename_end(self): def test_wildcard_filename_end(self) -> None:
assert self.exclude(["a.*"]) == convert_paths( assert self.exclude(["a.*"]) == convert_paths(
self.all_paths - set(["a.py", "a.go"]) self.all_paths - set(["a.py", "a.go"])
) )
def test_question_mark(self): def test_question_mark(self) -> None:
assert self.exclude(["?.py"]) == convert_paths( assert self.exclude(["?.py"]) == convert_paths(
self.all_paths - set(["a.py", "b.py"]) self.all_paths - set(["a.py", "b.py"])
) )
def test_single_subdir_single_filename(self): def test_single_subdir_single_filename(self) -> None:
assert self.exclude(["foo/a.py"]) == convert_paths( assert self.exclude(["foo/a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py"]) self.all_paths - set(["foo/a.py"])
) )
def test_single_subdir_single_filename_leading_slash(self): def test_single_subdir_single_filename_leading_slash(self) -> None:
assert self.exclude(["/foo/a.py"]) == convert_paths( assert self.exclude(["/foo/a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py"]) self.all_paths - set(["foo/a.py"])
) )
def test_exclude_include_absolute_path(self): def test_exclude_include_absolute_path(self) -> None:
base = make_tree([], ["a.py", "b.py"]) base = make_tree([], ["a.py", "b.py"])
assert exclude_paths(base, ["/*", "!/*.py"]) == set(["a.py", "b.py"]) assert exclude_paths(base, ["/*", "!/*.py"]) == set(["a.py", "b.py"])
def test_single_subdir_with_path_traversal(self): def test_single_subdir_with_path_traversal(self) -> None:
assert self.exclude(["foo/whoops/../a.py"]) == convert_paths( assert self.exclude(["foo/whoops/../a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py"]) self.all_paths - set(["foo/a.py"])
) )
def test_single_subdir_wildcard_filename(self): def test_single_subdir_wildcard_filename(self) -> None:
assert self.exclude(["foo/*.py"]) == convert_paths( assert self.exclude(["foo/*.py"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py"]) self.all_paths - set(["foo/a.py", "foo/b.py"])
) )
def test_wildcard_subdir_single_filename(self): def test_wildcard_subdir_single_filename(self) -> None:
assert self.exclude(["*/a.py"]) == convert_paths( assert self.exclude(["*/a.py"]) == convert_paths(
self.all_paths - set(["foo/a.py", "bar/a.py"]) self.all_paths - set(["foo/a.py", "bar/a.py"])
) )
def test_wildcard_subdir_wildcard_filename(self): def test_wildcard_subdir_wildcard_filename(self) -> None:
assert self.exclude(["*/*.py"]) == convert_paths( assert self.exclude(["*/*.py"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py", "bar/a.py"]) self.all_paths - set(["foo/a.py", "foo/b.py", "bar/a.py"])
) )
def test_directory(self): def test_directory(self) -> None:
assert self.exclude(["foo"]) == convert_paths( assert self.exclude(["foo"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -233,7 +238,7 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_directory_with_trailing_slash(self): def test_directory_with_trailing_slash(self) -> None:
assert self.exclude(["foo"]) == convert_paths( assert self.exclude(["foo"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -248,13 +253,13 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_directory_with_single_exception(self): def test_directory_with_single_exception(self) -> None:
assert self.exclude(["foo", "!foo/bar/a.py"]) == convert_paths( assert self.exclude(["foo", "!foo/bar/a.py"]) == convert_paths(
self.all_paths self.all_paths
- set(["foo/a.py", "foo/b.py", "foo", "foo/bar", "foo/Dockerfile3"]) - set(["foo/a.py", "foo/b.py", "foo", "foo/bar", "foo/Dockerfile3"])
) )
def test_directory_with_subdir_exception(self): def test_directory_with_subdir_exception(self) -> None:
assert self.exclude(["foo", "!foo/bar"]) == convert_paths( assert self.exclude(["foo", "!foo/bar"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"]) self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"])
) )
@ -262,17 +267,17 @@ class ExcludePathsTest(unittest.TestCase):
@pytest.mark.skipif( @pytest.mark.skipif(
not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows" not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows"
) )
def test_directory_with_subdir_exception_win32_pathsep(self): def test_directory_with_subdir_exception_win32_pathsep(self) -> None:
assert self.exclude(["foo", "!foo\\bar"]) == convert_paths( assert self.exclude(["foo", "!foo\\bar"]) == convert_paths(
self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"]) self.all_paths - set(["foo/a.py", "foo/b.py", "foo", "foo/Dockerfile3"])
) )
def test_directory_with_wildcard_exception(self): def test_directory_with_wildcard_exception(self) -> None:
assert self.exclude(["foo", "!foo/*.py"]) == convert_paths( assert self.exclude(["foo", "!foo/*.py"]) == convert_paths(
self.all_paths - set(["foo/bar", "foo/bar/a.py", "foo", "foo/Dockerfile3"]) self.all_paths - set(["foo/bar", "foo/bar/a.py", "foo", "foo/Dockerfile3"])
) )
def test_subdirectory(self): def test_subdirectory(self) -> None:
assert self.exclude(["foo/bar"]) == convert_paths( assert self.exclude(["foo/bar"]) == convert_paths(
self.all_paths - set(["foo/bar", "foo/bar/a.py"]) self.all_paths - set(["foo/bar", "foo/bar/a.py"])
) )
@ -280,12 +285,12 @@ class ExcludePathsTest(unittest.TestCase):
@pytest.mark.skipif( @pytest.mark.skipif(
not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows" not IS_WINDOWS_PLATFORM, reason="Backslash patterns only on Windows"
) )
def test_subdirectory_win32_pathsep(self): def test_subdirectory_win32_pathsep(self) -> None:
assert self.exclude(["foo\\bar"]) == convert_paths( assert self.exclude(["foo\\bar"]) == convert_paths(
self.all_paths - set(["foo/bar", "foo/bar/a.py"]) self.all_paths - set(["foo/bar", "foo/bar/a.py"])
) )
def test_double_wildcard(self): def test_double_wildcard(self) -> None:
assert self.exclude(["**/a.py"]) == convert_paths( assert self.exclude(["**/a.py"]) == convert_paths(
self.all_paths - set(["a.py", "foo/a.py", "foo/bar/a.py", "bar/a.py"]) self.all_paths - set(["a.py", "foo/a.py", "foo/bar/a.py", "bar/a.py"])
) )
@ -294,7 +299,7 @@ class ExcludePathsTest(unittest.TestCase):
self.all_paths - set(["foo/bar", "foo/bar/a.py"]) self.all_paths - set(["foo/bar", "foo/bar/a.py"])
) )
def test_single_and_double_wildcard(self): def test_single_and_double_wildcard(self) -> None:
assert self.exclude(["**/target/*/*"]) == convert_paths( assert self.exclude(["**/target/*/*"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -306,7 +311,7 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_trailing_double_wildcard(self): def test_trailing_double_wildcard(self) -> None:
assert self.exclude(["subdir/**"]) == convert_paths( assert self.exclude(["subdir/**"]) == convert_paths(
self.all_paths self.all_paths
- set( - set(
@ -326,7 +331,7 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_double_wildcard_with_exception(self): def test_double_wildcard_with_exception(self) -> None:
assert self.exclude(["**", "!bar", "!foo/bar"]) == convert_paths( assert self.exclude(["**", "!bar", "!foo/bar"]) == convert_paths(
set( set(
[ [
@ -340,13 +345,13 @@ class ExcludePathsTest(unittest.TestCase):
) )
) )
def test_include_wildcard(self): def test_include_wildcard(self) -> None:
# This may be surprising but it matches the CLI's behavior # This may be surprising but it matches the CLI's behavior
# (tested with 18.05.0-ce on linux) # (tested with 18.05.0-ce on linux)
base = make_tree(["a"], ["a/b.py"]) base = make_tree(["a"], ["a/b.py"])
assert exclude_paths(base, ["*", "!*/b.py"]) == set() assert exclude_paths(base, ["*", "!*/b.py"]) == set()
def test_last_line_precedence(self): def test_last_line_precedence(self) -> None:
base = make_tree( base = make_tree(
[], [],
[ [
@ -361,7 +366,7 @@ class ExcludePathsTest(unittest.TestCase):
["README.md", "README-bis.md"] ["README.md", "README-bis.md"]
) )
def test_parent_directory(self): def test_parent_directory(self) -> None:
base = make_tree([], ["a.py", "b.py", "c.py"]) base = make_tree([], ["a.py", "b.py", "c.py"])
# Dockerignore reference stipulates that absolute paths are # Dockerignore reference stipulates that absolute paths are
# equivalent to relative paths, hence /../foo should be # equivalent to relative paths, hence /../foo should be
@ -372,7 +377,7 @@ class ExcludePathsTest(unittest.TestCase):
class TarTest(unittest.TestCase): class TarTest(unittest.TestCase):
def test_tar_with_excludes(self): def test_tar_with_excludes(self) -> None:
dirs = [ dirs = [
"foo", "foo",
"foo/bar", "foo/bar",
@ -420,7 +425,7 @@ class TarTest(unittest.TestCase):
with tarfile.open(fileobj=archive) as tar_data: with tarfile.open(fileobj=archive) as tar_data:
assert sorted(tar_data.getnames()) == sorted(expected_names) assert sorted(tar_data.getnames()) == sorted(expected_names)
def test_tar_with_empty_directory(self): def test_tar_with_empty_directory(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -433,7 +438,7 @@ class TarTest(unittest.TestCase):
IS_WINDOWS_PLATFORM or os.geteuid() == 0, IS_WINDOWS_PLATFORM or os.geteuid() == 0,
reason="root user always has access ; no chmod on Windows", reason="root user always has access ; no chmod on Windows",
) )
def test_tar_with_inaccessible_file(self): def test_tar_with_inaccessible_file(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
full_path = os.path.join(base, "foo") full_path = os.path.join(base, "foo")
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
@ -446,7 +451,7 @@ class TarTest(unittest.TestCase):
assert f"Can not read file in context: {full_path}" in ei.exconly() assert f"Can not read file in context: {full_path}" in ei.exconly()
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_with_file_symlinks(self): def test_tar_with_file_symlinks(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
with open(os.path.join(base, "foo"), "wt", encoding="utf-8") as f: with open(os.path.join(base, "foo"), "wt", encoding="utf-8") as f:
@ -458,7 +463,7 @@ class TarTest(unittest.TestCase):
assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"]
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_with_directory_symlinks(self): def test_tar_with_directory_symlinks(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -469,7 +474,7 @@ class TarTest(unittest.TestCase):
assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"]
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_with_broken_symlinks(self): def test_tar_with_broken_symlinks(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -481,7 +486,7 @@ class TarTest(unittest.TestCase):
assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"] assert sorted(tar_data.getnames()) == ["bar", "bar/foo", "foo"]
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No UNIX sockets on Win32") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No UNIX sockets on Win32")
def test_tar_socket_file(self): def test_tar_socket_file(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
for d in ["foo", "bar"]: for d in ["foo", "bar"]:
@ -493,7 +498,7 @@ class TarTest(unittest.TestCase):
with tarfile.open(fileobj=archive) as tar_data: with tarfile.open(fileobj=archive) as tar_data:
assert sorted(tar_data.getnames()) == ["bar", "foo"] assert sorted(tar_data.getnames()) == ["bar", "foo"]
def tar_test_negative_mtime_bug(self): def tar_test_negative_mtime_bug(self) -> None:
base = tempfile.mkdtemp() base = tempfile.mkdtemp()
filename = os.path.join(base, "th.txt") filename = os.path.join(base, "th.txt")
self.addCleanup(shutil.rmtree, base) self.addCleanup(shutil.rmtree, base)
@ -506,7 +511,7 @@ class TarTest(unittest.TestCase):
assert tar_data.getmember("th.txt").mtime == -3600 assert tar_data.getmember("th.txt").mtime == -3600
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows") @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason="No symlinks on Windows")
def test_tar_directory_link(self): def test_tar_directory_link(self) -> None:
dirs = ["a", "b", "a/c"] dirs = ["a", "b", "a/c"]
files = ["a/hello.py", "b/utils.py", "a/c/descend.py"] files = ["a/hello.py", "b/utils.py", "a/c/descend.py"]
base = make_tree(dirs, files) base = make_tree(dirs, files)

View File

@ -25,55 +25,55 @@ class FindConfigFileTest(unittest.TestCase):
mkdir: Callable[[str], os.PathLike[str]] mkdir: Callable[[str], os.PathLike[str]]
@fixture(autouse=True) @fixture(autouse=True)
def tmpdir(self, tmpdir): def tmpdir(self, tmpdir) -> None:
self.mkdir = tmpdir.mkdir self.mkdir = tmpdir.mkdir
def test_find_config_fallback(self): def test_find_config_fallback(self) -> None:
tmpdir = self.mkdir("test_find_config_fallback") tmpdir = self.mkdir("test_find_config_fallback")
with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}):
assert config.find_config_file() is None assert config.find_config_file() is None
def test_find_config_from_explicit_path(self): def test_find_config_from_explicit_path(self) -> None:
tmpdir = self.mkdir("test_find_config_from_explicit_path") tmpdir = self.mkdir("test_find_config_from_explicit_path")
config_path = tmpdir.ensure("my-config-file.json") config_path = tmpdir.ensure("my-config-file.json") # type: ignore[attr-defined]
assert config.find_config_file(str(config_path)) == str(config_path) assert config.find_config_file(str(config_path)) == str(config_path)
def test_find_config_from_environment(self): def test_find_config_from_environment(self) -> None:
tmpdir = self.mkdir("test_find_config_from_environment") tmpdir = self.mkdir("test_find_config_from_environment")
config_path = tmpdir.ensure("config.json") config_path = tmpdir.ensure("config.json") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"DOCKER_CONFIG": str(tmpdir)}): with mock.patch.dict(os.environ, {"DOCKER_CONFIG": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
@mark.skipif("sys.platform == 'win32'") @mark.skipif("sys.platform == 'win32'")
def test_find_config_from_home_posix(self): def test_find_config_from_home_posix(self) -> None:
tmpdir = self.mkdir("test_find_config_from_home_posix") tmpdir = self.mkdir("test_find_config_from_home_posix")
config_path = tmpdir.ensure(".docker", "config.json") config_path = tmpdir.ensure(".docker", "config.json") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
@mark.skipif("sys.platform == 'win32'") @mark.skipif("sys.platform == 'win32'")
def test_find_config_from_home_legacy_name(self): def test_find_config_from_home_legacy_name(self) -> None:
tmpdir = self.mkdir("test_find_config_from_home_legacy_name") tmpdir = self.mkdir("test_find_config_from_home_legacy_name")
config_path = tmpdir.ensure(".dockercfg") config_path = tmpdir.ensure(".dockercfg") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}): with mock.patch.dict(os.environ, {"HOME": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
@mark.skipif("sys.platform != 'win32'") @mark.skipif("sys.platform != 'win32'")
def test_find_config_from_home_windows(self): def test_find_config_from_home_windows(self) -> None:
tmpdir = self.mkdir("test_find_config_from_home_windows") tmpdir = self.mkdir("test_find_config_from_home_windows")
config_path = tmpdir.ensure(".docker", "config.json") config_path = tmpdir.ensure(".docker", "config.json") # type: ignore[attr-defined]
with mock.patch.dict(os.environ, {"USERPROFILE": str(tmpdir)}): with mock.patch.dict(os.environ, {"USERPROFILE": str(tmpdir)}):
assert config.find_config_file() == str(config_path) assert config.find_config_file() == str(config_path)
class LoadConfigTest(unittest.TestCase): class LoadConfigTest(unittest.TestCase):
def test_load_config_no_file(self): def test_load_config_no_file(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
cfg = config.load_general_config(folder) cfg = config.load_general_config(folder)
@ -81,7 +81,7 @@ class LoadConfigTest(unittest.TestCase):
assert isinstance(cfg, dict) assert isinstance(cfg, dict)
assert not cfg assert not cfg
def test_load_config_custom_headers(self): def test_load_config_custom_headers(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
@ -97,7 +97,7 @@ class LoadConfigTest(unittest.TestCase):
assert "HttpHeaders" in cfg assert "HttpHeaders" in cfg
assert cfg["HttpHeaders"] == {"Name": "Spike", "Surname": "Spiegel"} assert cfg["HttpHeaders"] == {"Name": "Spike", "Surname": "Spiegel"}
def test_load_config_detach_keys(self): def test_load_config_detach_keys(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")
@ -108,7 +108,7 @@ class LoadConfigTest(unittest.TestCase):
cfg = config.load_general_config(dockercfg_path) cfg = config.load_general_config(dockercfg_path)
assert cfg == config_data assert cfg == config_data
def test_load_config_from_env(self): def test_load_config_from_env(self) -> None:
folder = tempfile.mkdtemp() folder = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, folder) self.addCleanup(shutil.rmtree, folder)
dockercfg_path = os.path.join(folder, "config.json") dockercfg_path = os.path.join(folder, "config.json")

View File

@ -22,7 +22,7 @@ from ansible_collections.community.docker.tests.unit.plugins.module_utils._api.c
class DecoratorsTest(unittest.TestCase): class DecoratorsTest(unittest.TestCase):
def test_update_headers(self): def test_update_headers(self) -> None:
sample_headers = { sample_headers = {
"X-Docker-Locale": "en-US", "X-Docker-Locale": "en-US",
} }

View File

@ -8,6 +8,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible_collections.community.docker.plugins.module_utils._api.utils.json_stream import ( from ansible_collections.community.docker.plugins.module_utils._api.utils.json_stream import (
json_splitter, json_splitter,
json_stream, json_stream,
@ -15,41 +17,48 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.json_s
) )
class TestJsonSplitter: if t.TYPE_CHECKING:
T = t.TypeVar("T")
def test_json_splitter_no_object(self):
def create_generator(input_sequence: list[T]) -> t.Generator[T]:
yield from input_sequence
class TestJsonSplitter:
def test_json_splitter_no_object(self) -> None:
data = '{"foo": "bar' data = '{"foo": "bar'
assert json_splitter(data) is None assert json_splitter(data) is None
def test_json_splitter_with_object(self): def test_json_splitter_with_object(self) -> None:
data = '{"foo": "bar"}\n \n{"next": "obj"}' data = '{"foo": "bar"}\n \n{"next": "obj"}'
assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}') assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}')
def test_json_splitter_leading_whitespace(self): def test_json_splitter_leading_whitespace(self) -> None:
data = '\n \r{"foo": "bar"}\n\n {"next": "obj"}' data = '\n \r{"foo": "bar"}\n\n {"next": "obj"}'
assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}') assert json_splitter(data) == ({"foo": "bar"}, '{"next": "obj"}')
class TestStreamAsText: class TestStreamAsText:
def test_stream_with_non_utf_unicode_character(self) -> None:
def test_stream_with_non_utf_unicode_character(self): stream = create_generator([b"\xed\xf3\xf3"])
stream = [b"\xed\xf3\xf3"]
(output,) = stream_as_text(stream) (output,) = stream_as_text(stream)
assert output == "<EFBFBD><EFBFBD><EFBFBD>" assert output == "<EFBFBD><EFBFBD><EFBFBD>"
def test_stream_with_utf_character(self): def test_stream_with_utf_character(self) -> None:
stream = ["ěĝ".encode("utf-8")] stream = create_generator(["ěĝ".encode("utf-8")])
(output,) = stream_as_text(stream) (output,) = stream_as_text(stream)
assert output == "ěĝ" assert output == "ěĝ"
class TestJsonStream: class TestJsonStream:
def test_with_falsy_entries(self) -> None:
def test_with_falsy_entries(self): stream = create_generator(
stream = [ [
'{"one": "two"}\n{}\n', '{"one": "two"}\n{}\n',
"[1, 2, 3]\n[]\n", "[1, 2, 3]\n[]\n",
] ]
)
output = list(json_stream(stream)) output = list(json_stream(stream))
assert output == [ assert output == [
{"one": "two"}, {"one": "two"},
@ -58,7 +67,9 @@ class TestJsonStream:
[], [],
] ]
def test_with_leading_whitespace(self): def test_with_leading_whitespace(self) -> None:
stream = ['\n \r\n {"one": "two"}{"x": 1}', ' {"three": "four"}\t\t{"x": 2}'] stream = create_generator(
['\n \r\n {"one": "two"}{"x": 1}', ' {"three": "four"}\t\t{"x": 2}']
)
output = list(json_stream(stream)) output = list(json_stream(stream))
assert output == [{"one": "two"}, {"x": 1}, {"three": "four"}, {"x": 2}] assert output == [{"one": "two"}, {"x": 1}, {"three": "four"}, {"x": 2}]

View File

@ -19,132 +19,132 @@ from ansible_collections.community.docker.plugins.module_utils._api.utils.ports
class PortsTest(unittest.TestCase): class PortsTest(unittest.TestCase):
def test_split_port_with_host_ip(self): def test_split_port_with_host_ip(self) -> None:
internal_port, external_port = split_port("127.0.0.1:1000:2000") internal_port, external_port = split_port("127.0.0.1:1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("127.0.0.1", "1000")] assert external_port == [("127.0.0.1", "1000")]
def test_split_port_with_protocol(self): def test_split_port_with_protocol(self) -> None:
for protocol in ["tcp", "udp", "sctp"]: for protocol in ["tcp", "udp", "sctp"]:
internal_port, external_port = split_port("127.0.0.1:1000:2000/" + protocol) internal_port, external_port = split_port("127.0.0.1:1000:2000/" + protocol)
assert internal_port == ["2000/" + protocol] assert internal_port == ["2000/" + protocol]
assert external_port == [("127.0.0.1", "1000")] assert external_port == [("127.0.0.1", "1000")]
def test_split_port_with_host_ip_no_port(self): def test_split_port_with_host_ip_no_port(self) -> None:
internal_port, external_port = split_port("127.0.0.1::2000") internal_port, external_port = split_port("127.0.0.1::2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("127.0.0.1", None)] assert external_port == [("127.0.0.1", None)]
def test_split_port_range_with_host_ip_no_port(self): def test_split_port_range_with_host_ip_no_port(self) -> None:
internal_port, external_port = split_port("127.0.0.1::2000-2001") internal_port, external_port = split_port("127.0.0.1::2000-2001")
assert internal_port == ["2000", "2001"] assert internal_port == ["2000", "2001"]
assert external_port == [("127.0.0.1", None), ("127.0.0.1", None)] assert external_port == [("127.0.0.1", None), ("127.0.0.1", None)]
def test_split_port_with_host_port(self): def test_split_port_with_host_port(self) -> None:
internal_port, external_port = split_port("1000:2000") internal_port, external_port = split_port("1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == ["1000"] assert external_port == ["1000"]
def test_split_port_range_with_host_port(self): def test_split_port_range_with_host_port(self) -> None:
internal_port, external_port = split_port("1000-1001:2000-2001") internal_port, external_port = split_port("1000-1001:2000-2001")
assert internal_port == ["2000", "2001"] assert internal_port == ["2000", "2001"]
assert external_port == ["1000", "1001"] assert external_port == ["1000", "1001"]
def test_split_port_random_port_range_with_host_port(self): def test_split_port_random_port_range_with_host_port(self) -> None:
internal_port, external_port = split_port("1000-1001:2000") internal_port, external_port = split_port("1000-1001:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == ["1000-1001"] assert external_port == ["1000-1001"]
def test_split_port_no_host_port(self): def test_split_port_no_host_port(self) -> None:
internal_port, external_port = split_port("2000") internal_port, external_port = split_port("2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port is None assert external_port is None
def test_split_port_range_no_host_port(self): def test_split_port_range_no_host_port(self) -> None:
internal_port, external_port = split_port("2000-2001") internal_port, external_port = split_port("2000-2001")
assert internal_port == ["2000", "2001"] assert internal_port == ["2000", "2001"]
assert external_port is None assert external_port is None
def test_split_port_range_with_protocol(self): def test_split_port_range_with_protocol(self) -> None:
internal_port, external_port = split_port("127.0.0.1:1000-1001:2000-2001/udp") internal_port, external_port = split_port("127.0.0.1:1000-1001:2000-2001/udp")
assert internal_port == ["2000/udp", "2001/udp"] assert internal_port == ["2000/udp", "2001/udp"]
assert external_port == [("127.0.0.1", "1000"), ("127.0.0.1", "1001")] assert external_port == [("127.0.0.1", "1000"), ("127.0.0.1", "1001")]
def test_split_port_with_ipv6_address(self): def test_split_port_with_ipv6_address(self) -> None:
internal_port, external_port = split_port("2001:abcd:ef00::2:1000:2000") internal_port, external_port = split_port("2001:abcd:ef00::2:1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("2001:abcd:ef00::2", "1000")] assert external_port == [("2001:abcd:ef00::2", "1000")]
def test_split_port_with_ipv6_square_brackets_address(self): def test_split_port_with_ipv6_square_brackets_address(self) -> None:
internal_port, external_port = split_port("[2001:abcd:ef00::2]:1000:2000") internal_port, external_port = split_port("[2001:abcd:ef00::2]:1000:2000")
assert internal_port == ["2000"] assert internal_port == ["2000"]
assert external_port == [("2001:abcd:ef00::2", "1000")] assert external_port == [("2001:abcd:ef00::2", "1000")]
def test_split_port_invalid(self): def test_split_port_invalid(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000:2000:tcp") split_port("0.0.0.0:1000:2000:tcp")
def test_split_port_invalid_protocol(self): def test_split_port_invalid_protocol(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000:2000/ftp") split_port("0.0.0.0:1000:2000/ftp")
def test_non_matching_length_port_ranges(self): def test_non_matching_length_port_ranges(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000-1010:2000-2002/tcp") split_port("0.0.0.0:1000-1010:2000-2002/tcp")
def test_port_and_range_invalid(self): def test_port_and_range_invalid(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("0.0.0.0:1000:2000-2002/tcp") split_port("0.0.0.0:1000:2000-2002/tcp")
def test_port_only_with_colon(self): def test_port_only_with_colon(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port(":80") split_port(":80")
def test_host_only_with_colon(self): def test_host_only_with_colon(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("localhost:") split_port("localhost:")
def test_with_no_container_port(self): def test_with_no_container_port(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("localhost:80:") split_port("localhost:80:")
def test_split_port_empty_string(self): def test_split_port_empty_string(self) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
split_port("") split_port("")
def test_split_port_non_string(self): def test_split_port_non_string(self) -> None:
assert split_port(1243) == (["1243"], None) assert split_port(1243) == (["1243"], None)
def test_build_port_bindings_with_one_port(self): def test_build_port_bindings_with_one_port(self) -> None:
port_bindings = build_port_bindings(["127.0.0.1:1000:1000"]) port_bindings = build_port_bindings(["127.0.0.1:1000:1000"])
assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["1000"] == [("127.0.0.1", "1000")]
def test_build_port_bindings_with_matching_internal_ports(self): def test_build_port_bindings_with_matching_internal_ports(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000:1000", "127.0.0.1:2000:1000"] ["127.0.0.1:1000:1000", "127.0.0.1:2000:1000"]
) )
assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")] assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")]
def test_build_port_bindings_with_nonmatching_internal_ports(self): def test_build_port_bindings_with_nonmatching_internal_ports(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"] ["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"]
) )
assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["1000"] == [("127.0.0.1", "1000")]
assert port_bindings["2000"] == [("127.0.0.1", "2000")] assert port_bindings["2000"] == [("127.0.0.1", "2000")]
def test_build_port_bindings_with_port_range(self): def test_build_port_bindings_with_port_range(self) -> None:
port_bindings = build_port_bindings(["127.0.0.1:1000-1001:1000-1001"]) port_bindings = build_port_bindings(["127.0.0.1:1000-1001:1000-1001"])
assert port_bindings["1000"] == [("127.0.0.1", "1000")] assert port_bindings["1000"] == [("127.0.0.1", "1000")]
assert port_bindings["1001"] == [("127.0.0.1", "1001")] assert port_bindings["1001"] == [("127.0.0.1", "1001")]
def test_build_port_bindings_with_matching_internal_port_ranges(self): def test_build_port_bindings_with_matching_internal_port_ranges(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000-1001:1000-1001", "127.0.0.1:2000-2001:1000-1001"] ["127.0.0.1:1000-1001:1000-1001", "127.0.0.1:2000-2001:1000-1001"]
) )
assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")] assert port_bindings["1000"] == [("127.0.0.1", "1000"), ("127.0.0.1", "2000")]
assert port_bindings["1001"] == [("127.0.0.1", "1001"), ("127.0.0.1", "2001")] assert port_bindings["1001"] == [("127.0.0.1", "1001"), ("127.0.0.1", "2001")]
def test_build_port_bindings_with_nonmatching_internal_port_ranges(self): def test_build_port_bindings_with_nonmatching_internal_port_ranges(self) -> None:
port_bindings = build_port_bindings( port_bindings = build_port_bindings(
["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"] ["127.0.0.1:1000:1000", "127.0.0.1:2000:2000"]
) )

View File

@ -33,8 +33,7 @@ ENV = {
class ProxyConfigTest(unittest.TestCase): class ProxyConfigTest(unittest.TestCase):
def test_from_dict(self) -> None:
def test_from_dict(self):
config = ProxyConfig.from_dict( config = ProxyConfig.from_dict(
{ {
"httpProxy": HTTP, "httpProxy": HTTP,
@ -48,7 +47,7 @@ class ProxyConfigTest(unittest.TestCase):
self.assertEqual(CONFIG.ftp, config.ftp) self.assertEqual(CONFIG.ftp, config.ftp)
self.assertEqual(CONFIG.no_proxy, config.no_proxy) self.assertEqual(CONFIG.no_proxy, config.no_proxy)
def test_new(self): def test_new(self) -> None:
config = ProxyConfig() config = ProxyConfig()
self.assertIsNone(config.http) self.assertIsNone(config.http)
self.assertIsNone(config.https) self.assertIsNone(config.https)
@ -61,22 +60,24 @@ class ProxyConfigTest(unittest.TestCase):
self.assertEqual(config.ftp, "c") self.assertEqual(config.ftp, "c")
self.assertEqual(config.no_proxy, "d") self.assertEqual(config.no_proxy, "d")
def test_truthiness(self): def test_truthiness(self) -> None:
assert not ProxyConfig() assert not ProxyConfig()
assert ProxyConfig(http="non-zero") assert ProxyConfig(http="non-zero")
assert ProxyConfig(https="non-zero") assert ProxyConfig(https="non-zero")
assert ProxyConfig(ftp="non-zero") assert ProxyConfig(ftp="non-zero")
assert ProxyConfig(no_proxy="non-zero") assert ProxyConfig(no_proxy="non-zero")
def test_environment(self): def test_environment(self) -> None:
self.assertDictEqual(CONFIG.get_environment(), ENV) self.assertDictEqual(CONFIG.get_environment(), ENV)
empty = ProxyConfig() empty = ProxyConfig()
self.assertDictEqual(empty.get_environment(), {}) self.assertDictEqual(empty.get_environment(), {})
def test_inject_proxy_environment(self): def test_inject_proxy_environment(self) -> None:
# Proxy config is non null, env is None. # Proxy config is non null, env is None.
envlist = CONFIG.inject_proxy_environment(None)
assert envlist is not None
self.assertSetEqual( self.assertSetEqual(
set(CONFIG.inject_proxy_environment(None)), set(envlist),
set(f"{k}={v}" for k, v in ENV.items()), set(f"{k}={v}" for k, v in ENV.items()),
) )

View File

@ -52,13 +52,15 @@ TEST_CERT_DIR = os.path.join(
class KwargsFromEnvTest(unittest.TestCase): class KwargsFromEnvTest(unittest.TestCase):
def setUp(self): os_environ: dict[str, str]
def setUp(self) -> None:
self.os_environ = os.environ.copy() self.os_environ = os.environ.copy()
def tearDown(self): def tearDown(self) -> None:
os.environ = self.os_environ os.environ = self.os_environ # type: ignore
def test_kwargs_from_env_empty(self): def test_kwargs_from_env_empty(self) -> None:
os.environ.update(DOCKER_HOST="", DOCKER_CERT_PATH="") os.environ.update(DOCKER_HOST="", DOCKER_CERT_PATH="")
os.environ.pop("DOCKER_TLS_VERIFY", None) os.environ.pop("DOCKER_TLS_VERIFY", None)
@ -66,7 +68,7 @@ class KwargsFromEnvTest(unittest.TestCase):
assert kwargs.get("base_url") is None assert kwargs.get("base_url") is None
assert kwargs.get("tls") is None assert kwargs.get("tls") is None
def test_kwargs_from_env_tls(self): def test_kwargs_from_env_tls(self) -> None:
os.environ.update( os.environ.update(
DOCKER_HOST="tcp://192.168.59.103:2376", DOCKER_HOST="tcp://192.168.59.103:2376",
DOCKER_CERT_PATH=TEST_CERT_DIR, DOCKER_CERT_PATH=TEST_CERT_DIR,
@ -90,7 +92,7 @@ class KwargsFromEnvTest(unittest.TestCase):
except TypeError as e: except TypeError as e:
self.fail(e) self.fail(e)
def test_kwargs_from_env_tls_verify_false(self): def test_kwargs_from_env_tls_verify_false(self) -> None:
os.environ.update( os.environ.update(
DOCKER_HOST="tcp://192.168.59.103:2376", DOCKER_HOST="tcp://192.168.59.103:2376",
DOCKER_CERT_PATH=TEST_CERT_DIR, DOCKER_CERT_PATH=TEST_CERT_DIR,
@ -113,7 +115,7 @@ class KwargsFromEnvTest(unittest.TestCase):
except TypeError as e: except TypeError as e:
self.fail(e) self.fail(e)
def test_kwargs_from_env_tls_verify_false_no_cert(self): def test_kwargs_from_env_tls_verify_false_no_cert(self) -> None:
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
cert_dir = os.path.join(temp_dir, ".docker") cert_dir = os.path.join(temp_dir, ".docker")
shutil.copytree(TEST_CERT_DIR, cert_dir) shutil.copytree(TEST_CERT_DIR, cert_dir)
@ -125,7 +127,7 @@ class KwargsFromEnvTest(unittest.TestCase):
kwargs = kwargs_from_env(assert_hostname=True) kwargs = kwargs_from_env(assert_hostname=True)
assert "tcp://192.168.59.103:2376" == kwargs["base_url"] assert "tcp://192.168.59.103:2376" == kwargs["base_url"]
def test_kwargs_from_env_no_cert_path(self): def test_kwargs_from_env_no_cert_path(self) -> None:
try: try:
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
cert_dir = os.path.join(temp_dir, ".docker") cert_dir = os.path.join(temp_dir, ".docker")
@ -142,7 +144,7 @@ class KwargsFromEnvTest(unittest.TestCase):
if temp_dir: if temp_dir:
shutil.rmtree(temp_dir) shutil.rmtree(temp_dir)
def test_kwargs_from_env_alternate_env(self): def test_kwargs_from_env_alternate_env(self) -> None:
# Values in os.environ are entirely ignored if an alternate is # Values in os.environ are entirely ignored if an alternate is
# provided # provided
os.environ.update( os.environ.update(
@ -160,30 +162,32 @@ class KwargsFromEnvTest(unittest.TestCase):
class ConverVolumeBindsTest(unittest.TestCase): class ConverVolumeBindsTest(unittest.TestCase):
def test_convert_volume_binds_empty(self): def test_convert_volume_binds_empty(self) -> None:
assert convert_volume_binds({}) == [] assert convert_volume_binds({}) == []
assert convert_volume_binds([]) == [] assert convert_volume_binds([]) == []
def test_convert_volume_binds_list(self): def test_convert_volume_binds_list(self) -> None:
data = ["/a:/a:ro", "/b:/c:z"] data = ["/a:/a:ro", "/b:/c:z"]
assert convert_volume_binds(data) == data assert convert_volume_binds(data) == data
def test_convert_volume_binds_complete(self): def test_convert_volume_binds_complete(self) -> None:
data = {"/mnt/vol1": {"bind": "/data", "mode": "ro"}} data: dict[str | bytes, dict[str, str]] = {
"/mnt/vol1": {"bind": "/data", "mode": "ro"}
}
assert convert_volume_binds(data) == ["/mnt/vol1:/data:ro"] assert convert_volume_binds(data) == ["/mnt/vol1:/data:ro"]
def test_convert_volume_binds_compact(self): def test_convert_volume_binds_compact(self) -> None:
data = {"/mnt/vol1": "/data"} data: dict[str | bytes, str] = {"/mnt/vol1": "/data"}
assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"] assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"]
def test_convert_volume_binds_no_mode(self): def test_convert_volume_binds_no_mode(self) -> None:
data = {"/mnt/vol1": {"bind": "/data"}} data: dict[str | bytes, dict[str, str]] = {"/mnt/vol1": {"bind": "/data"}}
assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"] assert convert_volume_binds(data) == ["/mnt/vol1:/data:rw"]
def test_convert_volume_binds_unicode_bytes_input(self): def test_convert_volume_binds_unicode_bytes_input(self) -> None:
expected = ["/mnt/지연:/unicode/박:rw"] expected = ["/mnt/지연:/unicode/박:rw"]
data = { data: dict[str | bytes, dict[str, str | bytes]] = {
"/mnt/지연".encode("utf-8"): { "/mnt/지연".encode("utf-8"): {
"bind": "/unicode/박".encode("utf-8"), "bind": "/unicode/박".encode("utf-8"),
"mode": "rw", "mode": "rw",
@ -191,15 +195,17 @@ class ConverVolumeBindsTest(unittest.TestCase):
} }
assert convert_volume_binds(data) == expected assert convert_volume_binds(data) == expected
def test_convert_volume_binds_unicode_unicode_input(self): def test_convert_volume_binds_unicode_unicode_input(self) -> None:
expected = ["/mnt/지연:/unicode/박:rw"] expected = ["/mnt/지연:/unicode/박:rw"]
data = {"/mnt/지연": {"bind": "/unicode/박", "mode": "rw"}} data: dict[str | bytes, dict[str, str]] = {
"/mnt/지연": {"bind": "/unicode/박", "mode": "rw"}
}
assert convert_volume_binds(data) == expected assert convert_volume_binds(data) == expected
class ParseEnvFileTest(unittest.TestCase): class ParseEnvFileTest(unittest.TestCase):
def generate_tempfile(self, file_content=None): def generate_tempfile(self, file_content: str) -> str:
""" """
Generates a temporary file for tests with the content Generates a temporary file for tests with the content
of 'file_content' and returns the filename. of 'file_content' and returns the filename.
@ -209,31 +215,31 @@ class ParseEnvFileTest(unittest.TestCase):
local_tempfile.write(file_content.encode("UTF-8")) local_tempfile.write(file_content.encode("UTF-8"))
return local_tempfile.name return local_tempfile.name
def test_parse_env_file_proper(self): def test_parse_env_file_proper(self) -> None:
env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=secret") env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=secret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"} assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_with_equals_character(self): def test_parse_env_file_with_equals_character(self) -> None:
env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=sec==ret") env_file = self.generate_tempfile(file_content="USER=jdoe\nPASS=sec==ret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe", "PASS": "sec==ret"} assert get_parse_env_file == {"USER": "jdoe", "PASS": "sec==ret"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_commented_line(self): def test_parse_env_file_commented_line(self) -> None:
env_file = self.generate_tempfile(file_content="USER=jdoe\n#PASS=secret") env_file = self.generate_tempfile(file_content="USER=jdoe\n#PASS=secret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe"} assert get_parse_env_file == {"USER": "jdoe"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_newline(self): def test_parse_env_file_newline(self) -> None:
env_file = self.generate_tempfile(file_content="\nUSER=jdoe\n\n\nPASS=secret") env_file = self.generate_tempfile(file_content="\nUSER=jdoe\n\n\nPASS=secret")
get_parse_env_file = parse_env_file(env_file) get_parse_env_file = parse_env_file(env_file)
assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"} assert get_parse_env_file == {"USER": "jdoe", "PASS": "secret"}
os.unlink(env_file) os.unlink(env_file)
def test_parse_env_file_invalid_line(self): def test_parse_env_file_invalid_line(self) -> None:
env_file = self.generate_tempfile(file_content="USER jdoe") env_file = self.generate_tempfile(file_content="USER jdoe")
with pytest.raises(DockerException): with pytest.raises(DockerException):
parse_env_file(env_file) parse_env_file(env_file)
@ -241,7 +247,7 @@ class ParseEnvFileTest(unittest.TestCase):
class ParseHostTest(unittest.TestCase): class ParseHostTest(unittest.TestCase):
def test_parse_host(self): def test_parse_host(self) -> None:
invalid_hosts = [ invalid_hosts = [
"foo://0.0.0.0", "foo://0.0.0.0",
"tcp://", "tcp://",
@ -282,16 +288,16 @@ class ParseHostTest(unittest.TestCase):
for host in invalid_hosts: for host in invalid_hosts:
msg = f"Should have failed to parse invalid host: {host}" msg = f"Should have failed to parse invalid host: {host}"
with self.assertRaises(DockerException, msg=msg): with self.assertRaises(DockerException, msg=msg):
parse_host(host, None) parse_host(host)
for host, expected in valid_hosts.items(): for host, expected in valid_hosts.items():
self.assertEqual( self.assertEqual(
parse_host(host, None), parse_host(host),
expected, expected,
msg=f"Failed to parse valid host: {host}", msg=f"Failed to parse valid host: {host}",
) )
def test_parse_host_empty_value(self): def test_parse_host_empty_value(self) -> None:
unix_socket = "http+unix:///var/run/docker.sock" unix_socket = "http+unix:///var/run/docker.sock"
npipe = "npipe:////./pipe/docker_engine" npipe = "npipe:////./pipe/docker_engine"
@ -299,17 +305,17 @@ class ParseHostTest(unittest.TestCase):
assert parse_host(val, is_win32=False) == unix_socket assert parse_host(val, is_win32=False) == unix_socket
assert parse_host(val, is_win32=True) == npipe assert parse_host(val, is_win32=True) == npipe
def test_parse_host_tls(self): def test_parse_host_tls(self) -> None:
host_value = "myhost.docker.net:3348" host_value = "myhost.docker.net:3348"
expected_result = "https://myhost.docker.net:3348" expected_result = "https://myhost.docker.net:3348"
assert parse_host(host_value, tls=True) == expected_result assert parse_host(host_value, tls=True) == expected_result
def test_parse_host_tls_tcp_proto(self): def test_parse_host_tls_tcp_proto(self) -> None:
host_value = "tcp://myhost.docker.net:3348" host_value = "tcp://myhost.docker.net:3348"
expected_result = "https://myhost.docker.net:3348" expected_result = "https://myhost.docker.net:3348"
assert parse_host(host_value, tls=True) == expected_result assert parse_host(host_value, tls=True) == expected_result
def test_parse_host_trailing_slash(self): def test_parse_host_trailing_slash(self) -> None:
host_value = "tcp://myhost.docker.net:2376/" host_value = "tcp://myhost.docker.net:2376/"
expected_result = "http://myhost.docker.net:2376" expected_result = "http://myhost.docker.net:2376"
assert parse_host(host_value) == expected_result assert parse_host(host_value) == expected_result
@ -318,31 +324,31 @@ class ParseHostTest(unittest.TestCase):
class ParseRepositoryTagTest(unittest.TestCase): class ParseRepositoryTagTest(unittest.TestCase):
sha = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" sha = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
def test_index_image_no_tag(self): def test_index_image_no_tag(self) -> None:
assert parse_repository_tag("root") == ("root", None) assert parse_repository_tag("root") == ("root", None)
def test_index_image_tag(self): def test_index_image_tag(self) -> None:
assert parse_repository_tag("root:tag") == ("root", "tag") assert parse_repository_tag("root:tag") == ("root", "tag")
def test_index_user_image_no_tag(self): def test_index_user_image_no_tag(self) -> None:
assert parse_repository_tag("user/repo") == ("user/repo", None) assert parse_repository_tag("user/repo") == ("user/repo", None)
def test_index_user_image_tag(self): def test_index_user_image_tag(self) -> None:
assert parse_repository_tag("user/repo:tag") == ("user/repo", "tag") assert parse_repository_tag("user/repo:tag") == ("user/repo", "tag")
def test_private_reg_image_no_tag(self): def test_private_reg_image_no_tag(self) -> None:
assert parse_repository_tag("url:5000/repo") == ("url:5000/repo", None) assert parse_repository_tag("url:5000/repo") == ("url:5000/repo", None)
def test_private_reg_image_tag(self): def test_private_reg_image_tag(self) -> None:
assert parse_repository_tag("url:5000/repo:tag") == ("url:5000/repo", "tag") assert parse_repository_tag("url:5000/repo:tag") == ("url:5000/repo", "tag")
def test_index_image_sha(self): def test_index_image_sha(self) -> None:
assert parse_repository_tag(f"root@sha256:{self.sha}") == ( assert parse_repository_tag(f"root@sha256:{self.sha}") == (
"root", "root",
f"sha256:{self.sha}", f"sha256:{self.sha}",
) )
def test_private_reg_image_sha(self): def test_private_reg_image_sha(self) -> None:
assert parse_repository_tag(f"url:5000/repo@sha256:{self.sha}") == ( assert parse_repository_tag(f"url:5000/repo@sha256:{self.sha}") == (
"url:5000/repo", "url:5000/repo",
f"sha256:{self.sha}", f"sha256:{self.sha}",
@ -350,7 +356,7 @@ class ParseRepositoryTagTest(unittest.TestCase):
class ParseDeviceTest(unittest.TestCase): class ParseDeviceTest(unittest.TestCase):
def test_dict(self): def test_dict(self) -> None:
devices = parse_devices( devices = parse_devices(
[ [
{ {
@ -366,7 +372,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "r", "CgroupPermissions": "r",
} }
def test_partial_string_definition(self): def test_partial_string_definition(self) -> None:
devices = parse_devices(["/dev/sda1"]) devices = parse_devices(["/dev/sda1"])
assert devices[0] == { assert devices[0] == {
"PathOnHost": "/dev/sda1", "PathOnHost": "/dev/sda1",
@ -374,7 +380,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "rwm", "CgroupPermissions": "rwm",
} }
def test_permissionless_string_definition(self): def test_permissionless_string_definition(self) -> None:
devices = parse_devices(["/dev/sda1:/dev/mnt1"]) devices = parse_devices(["/dev/sda1:/dev/mnt1"])
assert devices[0] == { assert devices[0] == {
"PathOnHost": "/dev/sda1", "PathOnHost": "/dev/sda1",
@ -382,7 +388,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "rwm", "CgroupPermissions": "rwm",
} }
def test_full_string_definition(self): def test_full_string_definition(self) -> None:
devices = parse_devices(["/dev/sda1:/dev/mnt1:r"]) devices = parse_devices(["/dev/sda1:/dev/mnt1:r"])
assert devices[0] == { assert devices[0] == {
"PathOnHost": "/dev/sda1", "PathOnHost": "/dev/sda1",
@ -390,7 +396,7 @@ class ParseDeviceTest(unittest.TestCase):
"CgroupPermissions": "r", "CgroupPermissions": "r",
} }
def test_hybrid_list(self): def test_hybrid_list(self) -> None:
devices = parse_devices( devices = parse_devices(
[ [
"/dev/sda1:/dev/mnt1:rw", "/dev/sda1:/dev/mnt1:rw",
@ -415,12 +421,12 @@ class ParseDeviceTest(unittest.TestCase):
class ParseBytesTest(unittest.TestCase): class ParseBytesTest(unittest.TestCase):
def test_parse_bytes_valid(self): def test_parse_bytes_valid(self) -> None:
assert parse_bytes("512MB") == 536870912 assert parse_bytes("512MB") == 536870912
assert parse_bytes("512M") == 536870912 assert parse_bytes("512M") == 536870912
assert parse_bytes("512m") == 536870912 assert parse_bytes("512m") == 536870912
def test_parse_bytes_invalid(self): def test_parse_bytes_invalid(self) -> None:
with pytest.raises(DockerException): with pytest.raises(DockerException):
parse_bytes("512MK") parse_bytes("512MK")
with pytest.raises(DockerException): with pytest.raises(DockerException):
@ -428,15 +434,15 @@ class ParseBytesTest(unittest.TestCase):
with pytest.raises(DockerException): with pytest.raises(DockerException):
parse_bytes("127.0.0.1K") parse_bytes("127.0.0.1K")
def test_parse_bytes_float(self): def test_parse_bytes_float(self) -> None:
assert parse_bytes("1.5k") == 1536 assert parse_bytes("1.5k") == 1536
class UtilsTest(unittest.TestCase): class UtilsTest(unittest.TestCase):
longMessage = True longMessage = True
def test_convert_filters(self): def test_convert_filters(self) -> None:
tests = [ tests: list[tuple[dict[str, bool | str | int | list[str | int]], str]] = [
({"dangling": True}, '{"dangling": ["true"]}'), ({"dangling": True}, '{"dangling": ["true"]}'),
({"dangling": "true"}, '{"dangling": ["true"]}'), ({"dangling": "true"}, '{"dangling": ["true"]}'),
({"exited": 0}, '{"exited": ["0"]}'), ({"exited": 0}, '{"exited": ["0"]}'),
@ -446,7 +452,7 @@ class UtilsTest(unittest.TestCase):
for filters, expected in tests: for filters, expected in tests:
assert convert_filters(filters) == expected assert convert_filters(filters) == expected
def test_decode_json_header(self): def test_decode_json_header(self) -> None:
obj = {"a": "b", "c": 1} obj = {"a": "b", "c": 1}
data = base64.urlsafe_b64encode(bytes(json.dumps(obj), "utf-8")) data = base64.urlsafe_b64encode(bytes(json.dumps(obj), "utf-8"))
decoded_data = decode_json_header(data) decoded_data = decode_json_header(data)
@ -454,12 +460,12 @@ class UtilsTest(unittest.TestCase):
class SplitCommandTest(unittest.TestCase): class SplitCommandTest(unittest.TestCase):
def test_split_command_with_unicode(self): def test_split_command_with_unicode(self) -> None:
assert split_command("echo μμ") == ["echo", "μμ"] assert split_command("echo μμ") == ["echo", "μμ"]
class FormatEnvironmentTest(unittest.TestCase): class FormatEnvironmentTest(unittest.TestCase):
def test_format_env_binary_unicode_value(self): def test_format_env_binary_unicode_value(self) -> None:
env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"} env_dict = {"ARTIST_NAME": b"\xec\x86\xa1\xec\xa7\x80\xec\x9d\x80"}
assert format_environment(env_dict) == ["ARTIST_NAME=송지은"] assert format_environment(env_dict) == ["ARTIST_NAME=송지은"]

View File

@ -11,7 +11,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
) )
EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[Event]]] = [ EVENT_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[str]]] = [
# ####################################################################################################################### # #######################################################################################################################
# ## Docker Compose 2.18.1 ############################################################################################## # ## Docker Compose 2.18.1 ##############################################################################################
# ####################################################################################################################### # #######################################################################################################################

View File

@ -14,7 +14,7 @@ from ansible_collections.community.docker.plugins.module_utils._compose_v2 impor
from .compose_v2_test_cases import EVENT_TEST_CASES from .compose_v2_test_cases import EVENT_TEST_CASES
EXTRA_TEST_CASES = [ EXTRA_TEST_CASES: list[tuple[str, str, bool, bool, str, list[Event], list[str]]] = [
( (
"2.24.2-manual-build-dry-run", "2.24.2-manual-build-dry-run",
"2.24.2", "2.24.2",
@ -227,9 +227,7 @@ EXTRA_TEST_CASES = [
False, False,
False, False,
# fmt: off # fmt: off
" bash_1 Skipped \n" " bash_1 Skipped \n bash_2 Pulling \n bash_2 Pulled \n",
" bash_2 Pulling \n"
" bash_2 Pulled \n",
# fmt: on # fmt: on
[ [
Event( Event(
@ -361,15 +359,24 @@ _ALL_TEST_CASES = EVENT_TEST_CASES + EXTRA_TEST_CASES
ids=[tc[0] for tc in _ALL_TEST_CASES], ids=[tc[0] for tc in _ALL_TEST_CASES],
) )
def test_parse_events( def test_parse_events(
test_id, compose_version, dry_run, nonzero_rc, stderr, events, warnings test_id: str,
): compose_version: str,
dry_run: bool,
nonzero_rc: bool,
stderr: str,
events: list[Event],
warnings: list[str],
) -> None:
collected_warnings = [] collected_warnings = []
def collect_warning(msg): def collect_warning(msg):
collected_warnings.append(msg) collected_warnings.append(msg)
collected_events = parse_events( collected_events = parse_events(
stderr, dry_run=dry_run, warn_function=collect_warning, nonzero_rc=nonzero_rc stderr.encode("utf-8"),
dry_run=dry_run,
warn_function=collect_warning,
nonzero_rc=nonzero_rc,
) )
print(collected_events) print(collected_events)

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._copy import ( from ansible_collections.community.docker.plugins.module_utils._copy import (
@ -11,7 +13,13 @@ from ansible_collections.community.docker.plugins.module_utils._copy import (
) )
def _simple_generator(sequence): if t.TYPE_CHECKING:
from collections.abc import Sequence
T = t.TypeVar("T")
def _simple_generator(sequence: Sequence[T]) -> t.Generator[T]:
yield from sequence yield from sequence
@ -60,10 +68,12 @@ def _simple_generator(sequence):
), ),
], ],
) )
def test__stream_generator_to_fileobj(chunks, read_sizes): def test__stream_generator_to_fileobj(
chunks = [count * data for count, data in chunks] chunks: list[tuple[int, bytes]], read_sizes: list[int]
stream = _simple_generator(chunks) ) -> None:
expected = b"".join(chunks) data_chunks = [count * data for count, data in chunks]
stream = _simple_generator(data_chunks)
expected = b"".join(data_chunks)
buffer = b"" buffer = b""
totally_read = 0 totally_read = 0

View File

@ -22,7 +22,7 @@ from ..test_support.docker_image_archive_stubbing import (
@pytest.fixture @pytest.fixture
def tar_file_name(tmpdir): def tar_file_name(tmpdir) -> str:
""" """
Return the name of a non-existing tar file in an existing temporary directory. Return the name of a non-existing tar file in an existing temporary directory.
""" """
@ -34,11 +34,11 @@ def tar_file_name(tmpdir):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expected, value", [("sha256:foo", "foo"), ("sha256:bar", "bar")] "expected, value", [("sha256:foo", "foo"), ("sha256:bar", "bar")]
) )
def test_api_image_id_from_archive_id(expected, value): def test_api_image_id_from_archive_id(expected: str, value: str) -> None:
assert api_image_id(value) == expected assert api_image_id(value) == expected
def test_archived_image_manifest_extracts(tar_file_name): def test_archived_image_manifest_extracts(tar_file_name) -> None:
expected_id = "abcde12345" expected_id = "abcde12345"
expected_tags = ["foo:latest", "bar:v1"] expected_tags = ["foo:latest", "bar:v1"]
@ -46,17 +46,20 @@ def test_archived_image_manifest_extracts(tar_file_name):
actual = archived_image_manifest(tar_file_name) actual = archived_image_manifest(tar_file_name)
assert actual is not None
assert actual.image_id == expected_id assert actual.image_id == expected_id
assert actual.repo_tags == expected_tags assert actual.repo_tags == expected_tags
def test_archived_image_manifest_extracts_nothing_when_file_not_present(tar_file_name): def test_archived_image_manifest_extracts_nothing_when_file_not_present(
tar_file_name,
) -> None:
image_id = archived_image_manifest(tar_file_name) image_id = archived_image_manifest(tar_file_name)
assert image_id is None assert image_id is None
def test_archived_image_manifest_raises_when_file_not_a_tar(): def test_archived_image_manifest_raises_when_file_not_a_tar() -> None:
try: try:
archived_image_manifest(__file__) archived_image_manifest(__file__)
raise AssertionError() raise AssertionError()
@ -65,7 +68,9 @@ def test_archived_image_manifest_raises_when_file_not_a_tar():
assert str(__file__) in str(e) assert str(__file__) in str(e)
def test_archived_image_manifest_raises_when_tar_missing_manifest(tar_file_name): def test_archived_image_manifest_raises_when_tar_missing_manifest(
tar_file_name,
) -> None:
write_irrelevant_tar(tar_file_name) write_irrelevant_tar(tar_file_name)
try: try:
@ -76,7 +81,7 @@ def test_archived_image_manifest_raises_when_tar_missing_manifest(tar_file_name)
assert "manifest.json" in str(e.__cause__) assert "manifest.json" in str(e.__cause__)
def test_archived_image_manifest_raises_when_manifest_missing_id(tar_file_name): def test_archived_image_manifest_raises_when_manifest_missing_id(tar_file_name) -> None:
manifest = [{"foo": "bar"}] manifest = [{"foo": "bar"}]
write_imitation_archive_with_manifest(tar_file_name, manifest) write_imitation_archive_with_manifest(tar_file_name, manifest)

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._logfmt import ( from ansible_collections.community.docker.plugins.module_utils._logfmt import (
@ -12,7 +14,7 @@ from ansible_collections.community.docker.plugins.module_utils._logfmt import (
) )
SUCCESS_TEST_CASES = [ SUCCESS_TEST_CASES: list[tuple[str, dict[str, t.Any], dict[str, t.Any]]] = [
( (
'time="2024-02-02T08:14:10+01:00" level=warning msg="a network with name influxNetwork exists but was not' 'time="2024-02-02T08:14:10+01:00" level=warning msg="a network with name influxNetwork exists but was not'
' created for project \\"influxdb\\".\\nSet `external: true` to use an existing network"', ' created for project \\"influxdb\\".\\nSet `external: true` to use an existing network"',
@ -59,7 +61,7 @@ SUCCESS_TEST_CASES = [
] ]
FAILURE_TEST_CASES = [ FAILURE_TEST_CASES: list[tuple[str, dict[str, t.Any], str]] = [
( (
'foo=bar a=14 baz="hello kitty" cool%story=bro f %^asdf', 'foo=bar a=14 baz="hello kitty" cool%story=bro f %^asdf',
{"logrus_mode": True}, {"logrus_mode": True},
@ -84,14 +86,16 @@ FAILURE_TEST_CASES = [
@pytest.mark.parametrize("line, kwargs, result", SUCCESS_TEST_CASES) @pytest.mark.parametrize("line, kwargs, result", SUCCESS_TEST_CASES)
def test_parse_line_success(line, kwargs, result): def test_parse_line_success(
line: str, kwargs: dict[str, t.Any], result: dict[str, t.Any]
) -> None:
res = parse_line(line, **kwargs) res = parse_line(line, **kwargs)
print(repr(res)) print(repr(res))
assert res == result assert res == result
@pytest.mark.parametrize("line, kwargs, message", FAILURE_TEST_CASES) @pytest.mark.parametrize("line, kwargs, message", FAILURE_TEST_CASES)
def test_parse_line_failure(line, kwargs, message): def test_parse_line_failure(line: str, kwargs: dict[str, t.Any], message: str) -> None:
with pytest.raises(InvalidLogFmt) as exc: with pytest.raises(InvalidLogFmt) as exc:
parse_line(line, **kwargs) parse_line(line, **kwargs)

View File

@ -20,7 +20,7 @@ from ansible_collections.community.docker.plugins.module_utils._scramble import
("hello", b"\x01", "=S=aWRtbW4="), ("hello", b"\x01", "=S=aWRtbW4="),
], ],
) )
def test_scramble_unscramble(plaintext, key, scrambled): def test_scramble_unscramble(plaintext: str, key: bytes, scrambled: str) -> None:
scrambled_ = scramble(plaintext, key) scrambled_ = scramble(plaintext, key)
print(f"{scrambled_!r} == {scrambled!r}") print(f"{scrambled_!r} == {scrambled!r}")
assert scrambled_ == scrambled assert scrambled_ == scrambled

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._util import ( from ansible_collections.community.docker.plugins.module_utils._util import (
@ -14,15 +16,41 @@ from ansible_collections.community.docker.plugins.module_utils._util import (
) )
DICT_ALLOW_MORE_PRESENT = ( if t.TYPE_CHECKING:
class DAMSpec(t.TypedDict):
av: dict[str, t.Any]
bv: dict[str, t.Any]
result: bool
class Spec(t.TypedDict):
a: t.Any
b: t.Any
method: t.Literal["strict", "ignore", "allow_more_present"]
type: t.Literal["value", "list", "set", "set(dict)", "dict"]
result: bool
DICT_ALLOW_MORE_PRESENT: list[DAMSpec] = [
{"av": {}, "bv": {"a": 1}, "result": True}, {"av": {}, "bv": {"a": 1}, "result": True},
{"av": {"a": 1}, "bv": {"a": 1, "b": 2}, "result": True}, {"av": {"a": 1}, "bv": {"a": 1, "b": 2}, "result": True},
{"av": {"a": 1}, "bv": {"b": 2}, "result": False}, {"av": {"a": 1}, "bv": {"b": 2}, "result": False},
{"av": {"a": 1}, "bv": {"a": None, "b": 1}, "result": False}, {"av": {"a": 1}, "bv": {"a": None, "b": 1}, "result": False},
{"av": {"a": None}, "bv": {"b": 1}, "result": False}, {"av": {"a": None}, "bv": {"b": 1}, "result": False},
) ]
COMPARE_GENERIC = [ DICT_ALLOW_MORE_PRESENT_SPECS: list[Spec] = [
{
"a": entry["av"],
"b": entry["bv"],
"method": "allow_more_present",
"type": "dict",
"result": entry["result"],
}
for entry in DICT_ALLOW_MORE_PRESENT
]
COMPARE_GENERIC: list[Spec] = [
######################################################################################## ########################################################################################
# value # value
{"a": 1, "b": 2, "method": "strict", "type": "value", "result": False}, {"a": 1, "b": 2, "method": "strict", "type": "value", "result": False},
@ -386,43 +414,34 @@ COMPARE_GENERIC = [
"type": "dict", "type": "dict",
"result": True, "result": True,
}, },
] + [
{
"a": entry["av"],
"b": entry["bv"],
"method": "allow_more_present",
"type": "dict",
"result": entry["result"],
}
for entry in DICT_ALLOW_MORE_PRESENT
] ]
@pytest.mark.parametrize("entry", DICT_ALLOW_MORE_PRESENT) @pytest.mark.parametrize("entry", DICT_ALLOW_MORE_PRESENT)
def test_dict_allow_more_present(entry): def test_dict_allow_more_present(entry: DAMSpec) -> None:
assert compare_dict_allow_more_present(entry["av"], entry["bv"]) == entry["result"] assert compare_dict_allow_more_present(entry["av"], entry["bv"]) == entry["result"]
@pytest.mark.parametrize("entry", COMPARE_GENERIC) @pytest.mark.parametrize("entry", COMPARE_GENERIC + DICT_ALLOW_MORE_PRESENT_SPECS)
def test_compare_generic(entry): def test_compare_generic(entry: Spec) -> None:
assert ( assert (
compare_generic(entry["a"], entry["b"], entry["method"], entry["type"]) compare_generic(entry["a"], entry["b"], entry["method"], entry["type"])
== entry["result"] == entry["result"]
) )
def test_convert_duration_to_nanosecond(): def test_convert_duration_to_nanosecond() -> None:
nanoseconds = convert_duration_to_nanosecond("5s") nanoseconds = convert_duration_to_nanosecond("5s")
assert nanoseconds == 5000000000 assert nanoseconds == 5000000000
nanoseconds = convert_duration_to_nanosecond("1m5s") nanoseconds = convert_duration_to_nanosecond("1m5s")
assert nanoseconds == 65000000000 assert nanoseconds == 65000000000
with pytest.raises(ValueError): with pytest.raises(ValueError):
convert_duration_to_nanosecond([1, 2, 3]) convert_duration_to_nanosecond([1, 2, 3]) # type: ignore
with pytest.raises(ValueError): with pytest.raises(ValueError):
convert_duration_to_nanosecond("10x") convert_duration_to_nanosecond("10x")
def test_parse_healthcheck(): def test_parse_healthcheck() -> None:
result, disabled = parse_healthcheck( result, disabled = parse_healthcheck(
{ {
"test": "sleep 1", "test": "sleep 1",

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.modules.docker_container_copy_into import ( from ansible_collections.community.docker.plugins.modules.docker_container_copy_into import (
@ -30,7 +32,7 @@ from ansible_collections.community.docker.plugins.modules.docker_container_copy_
("-1", -1), ("-1", -1),
], ],
) )
def test_parse_string(value, expected): def test_parse_string(value: str, expected: int) -> None:
assert parse_modern(value) == expected assert parse_modern(value) == expected
assert parse_octal_string_only(value) == expected assert parse_octal_string_only(value) == expected
@ -45,10 +47,10 @@ def test_parse_string(value, expected):
123456789012345678901234567890123456789012345678901234567890, 123456789012345678901234567890123456789012345678901234567890,
], ],
) )
def test_parse_int(value): def test_parse_int(value: int) -> None:
assert parse_modern(value) == value assert parse_modern(value) == value
with pytest.raises(TypeError, match=f"^must be an octal string, got {value}L?$"): with pytest.raises(TypeError, match=f"^must be an octal string, got {value}L?$"):
parse_octal_string_only(value) parse_octal_string_only(value) # type: ignore
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -60,7 +62,7 @@ def test_parse_int(value):
{}, {},
], ],
) )
def test_parse_bad_type(value): def test_parse_bad_type(value: t.Any) -> None:
with pytest.raises(TypeError, match="^must be an octal string or an integer, got "): with pytest.raises(TypeError, match="^must be an octal string or an integer, got "):
parse_modern(value) parse_modern(value)
with pytest.raises(TypeError, match="^must be an octal string, got "): with pytest.raises(TypeError, match="^must be an octal string, got "):
@ -75,7 +77,7 @@ def test_parse_bad_type(value):
"9", "9",
], ],
) )
def test_parse_bad_value(value): def test_parse_bad_value(value: str) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
parse_modern(value) parse_modern(value)
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.module_utils._image_archive import ( from ansible_collections.community.docker.plugins.module_utils._image_archive import (
@ -19,12 +21,17 @@ from ..test_support.docker_image_archive_stubbing import (
) )
def assert_no_logging(msg): if t.TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
def assert_no_logging(msg: str) -> t.NoReturn:
raise AssertionError(f"Should not have logged anything but logged {msg}") raise AssertionError(f"Should not have logged anything but logged {msg}")
def capture_logging(messages): def capture_logging(messages: list[str]) -> Callable[[str], None]:
def capture(msg): def capture(msg: str) -> None:
messages.append(msg) messages.append(msg)
return capture return capture
@ -39,7 +46,7 @@ def tar_file_name(tmpdir):
return tmpdir.join("foo.tar") return tmpdir.join("foo.tar")
def test_archived_image_action_when_missing(tar_file_name): def test_archived_image_action_when_missing(tar_file_name) -> None:
fake_name = "a:latest" fake_name = "a:latest"
fake_id = "a1" fake_id = "a1"
@ -52,7 +59,7 @@ def test_archived_image_action_when_missing(tar_file_name):
assert actual == expected assert actual == expected
def test_archived_image_action_when_current(tar_file_name): def test_archived_image_action_when_current(tar_file_name) -> None:
fake_name = "b:latest" fake_name = "b:latest"
fake_id = "b2" fake_id = "b2"
@ -65,7 +72,7 @@ def test_archived_image_action_when_current(tar_file_name):
assert actual is None assert actual is None
def test_archived_image_action_when_invalid(tar_file_name): def test_archived_image_action_when_invalid(tar_file_name) -> None:
fake_name = "c:1.2.3" fake_name = "c:1.2.3"
fake_id = "c3" fake_id = "c3"
@ -73,7 +80,7 @@ def test_archived_image_action_when_invalid(tar_file_name):
expected = f"Archived image {fake_name} to {tar_file_name}, overwriting an unreadable archive file" expected = f"Archived image {fake_name} to {tar_file_name}, overwriting an unreadable archive file"
actual_log = [] actual_log: list[str] = []
actual = ImageManager.archived_image_action( actual = ImageManager.archived_image_action(
capture_logging(actual_log), tar_file_name, fake_name, api_image_id(fake_id) capture_logging(actual_log), tar_file_name, fake_name, api_image_id(fake_id)
) )
@ -84,7 +91,7 @@ def test_archived_image_action_when_invalid(tar_file_name):
assert actual_log[0].startswith("Unable to extract manifest summary from archive") assert actual_log[0].startswith("Unable to extract manifest summary from archive")
def test_archived_image_action_when_obsolete_by_id(tar_file_name): def test_archived_image_action_when_obsolete_by_id(tar_file_name) -> None:
fake_name = "d:0.0.1" fake_name = "d:0.0.1"
old_id = "e5" old_id = "e5"
new_id = "d4" new_id = "d4"
@ -99,7 +106,7 @@ def test_archived_image_action_when_obsolete_by_id(tar_file_name):
assert actual == expected assert actual == expected
def test_archived_image_action_when_obsolete_by_name(tar_file_name): def test_archived_image_action_when_obsolete_by_name(tar_file_name) -> None:
old_name = "hi" old_name = "hi"
new_name = "d:0.0.1" new_name = "d:0.0.1"
fake_id = "d4" fake_id = "d4"

View File

@ -21,5 +21,5 @@ from ansible_collections.community.docker.plugins.modules.docker_image_build imp
('\rhello, "hi" !\n', '"\rhello, ""hi"" !\n"'), ('\rhello, "hi" !\n', '"\rhello, ""hi"" !\n"'),
], ],
) )
def test__quote_csv(value, expected): def test__quote_csv(value: str, expected: str) -> None:
assert _quote_csv(value) == expected assert _quote_csv(value) == expected

View File

@ -6,6 +6,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.modules.docker_network import ( from ansible_collections.community.docker.plugins.modules.docker_network import (
@ -23,7 +25,9 @@ from ansible_collections.community.docker.plugins.modules.docker_network import
("fdd1:ac8c:0557:7ce2::/128", "ipv6"), ("fdd1:ac8c:0557:7ce2::/128", "ipv6"),
], ],
) )
def test_validate_cidr_positives(cidr, expected): def test_validate_cidr_positives(
cidr: str, expected: t.Literal["ipv4", "ipv6"]
) -> None:
assert validate_cidr(cidr) == expected assert validate_cidr(cidr) == expected
@ -36,7 +40,7 @@ def test_validate_cidr_positives(cidr, expected):
"fdd1:ac8c:0557:7ce2::", "fdd1:ac8c:0557:7ce2::",
], ],
) )
def test_validate_cidr_negatives(cidr): def test_validate_cidr_negatives(cidr: str) -> None:
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
validate_cidr(cidr) validate_cidr(cidr)
assert f'"{cidr}" is not a valid CIDR' == str(e.value) assert f'"{cidr}" is not a valid CIDR' == str(e.value)

View File

@ -4,66 +4,47 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.docker.plugins.modules import (
class APIErrorMock(Exception): docker_swarm_service,
def __init__(self, message, response=None, explanation=None): )
self.message = message
self.response = response
self.explanation = explanation
@pytest.fixture(autouse=True) APIError = pytest.importorskip("docker.errors.APIError")
def docker_module_mock(mocker):
docker_module_mock = mocker.MagicMock()
docker_utils_module_mock = mocker.MagicMock()
docker_errors_module_mock = mocker.MagicMock()
docker_errors_module_mock.APIError = APIErrorMock
mock_modules = {
"docker": docker_module_mock,
"docker.utils": docker_utils_module_mock,
"docker.errors": docker_errors_module_mock,
}
return mocker.patch.dict("sys.modules", **mock_modules)
@pytest.fixture(autouse=True) def test_retry_on_out_of_sequence_error(mocker) -> None:
def docker_swarm_service():
from ansible_collections.community.docker.plugins.modules import (
docker_swarm_service,
)
return docker_swarm_service
def test_retry_on_out_of_sequence_error(mocker, docker_swarm_service):
run_mock = mocker.MagicMock( run_mock = mocker.MagicMock(
side_effect=APIErrorMock( side_effect=APIError(
message="", message="",
response=None, response=None,
explanation="rpc error: code = Unknown desc = update out of sequence", explanation="rpc error: code = Unknown desc = update out of sequence",
) )
) )
manager = docker_swarm_service.DockerServiceManager(client=None) mocker.patch("time.sleep")
manager.run = run_mock manager = docker_swarm_service.DockerServiceManager(client=None) # type: ignore
with pytest.raises(APIErrorMock): manager.run = run_mock # type: ignore
with pytest.raises(APIError):
manager.run_safe() manager.run_safe()
assert run_mock.call_count == 3 assert run_mock.call_count == 3
def test_no_retry_on_general_api_error(mocker, docker_swarm_service): def test_no_retry_on_general_api_error(mocker) -> None:
run_mock = mocker.MagicMock( run_mock = mocker.MagicMock(
side_effect=APIErrorMock(message="", response=None, explanation="some error") side_effect=APIError(message="", response=None, explanation="some error")
) )
manager = docker_swarm_service.DockerServiceManager(client=None) mocker.patch("time.sleep")
manager.run = run_mock manager = docker_swarm_service.DockerServiceManager(client=None) # type: ignore
with pytest.raises(APIErrorMock): manager.run = run_mock # type: ignore
with pytest.raises(APIError):
manager.run_safe() manager.run_safe()
assert run_mock.call_count == 1 assert run_mock.call_count == 1
def test_get_docker_environment(mocker, docker_swarm_service): def test_get_docker_environment(mocker) -> None:
env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"} env_file_result = {"TEST1": "A", "TEST2": "B", "TEST3": "C"}
env_dict = {"TEST3": "CC", "TEST4": "D"} env_dict = {"TEST3": "CC", "TEST4": "D"}
env_string = "TEST3=CC,TEST4=D" env_string = "TEST3=CC,TEST4=D"
@ -103,7 +84,7 @@ def test_get_docker_environment(mocker, docker_swarm_service):
assert result == [] assert result == []
def test_get_nanoseconds_from_raw_option(docker_swarm_service): def test_get_nanoseconds_from_raw_option() -> None:
value = docker_swarm_service.get_nanoseconds_from_raw_option("test", None) value = docker_swarm_service.get_nanoseconds_from_raw_option("test", None)
assert value is None assert value is None
@ -117,7 +98,7 @@ def test_get_nanoseconds_from_raw_option(docker_swarm_service):
docker_swarm_service.get_nanoseconds_from_raw_option("test", []) docker_swarm_service.get_nanoseconds_from_raw_option("test", [])
def test_has_dict_changed(docker_swarm_service): def test_has_dict_changed() -> None:
assert not docker_swarm_service.has_dict_changed( assert not docker_swarm_service.has_dict_changed(
{"a": 1}, {"a": 1},
{"a": 1}, {"a": 1},
@ -135,8 +116,7 @@ def test_has_dict_changed(docker_swarm_service):
assert not docker_swarm_service.has_dict_changed(None, {}) assert not docker_swarm_service.has_dict_changed(None, {})
def test_has_list_changed(docker_swarm_service): def test_has_list_changed() -> None:
# List comparisons without dictionaries # List comparisons without dictionaries
# I could improve the indenting, but pycodestyle wants this instead # I could improve the indenting, but pycodestyle wants this instead
assert not docker_swarm_service.has_list_changed(None, None) assert not docker_swarm_service.has_list_changed(None, None)
@ -161,7 +141,7 @@ def test_has_list_changed(docker_swarm_service):
assert docker_swarm_service.has_list_changed([None, 1], [2, 1]) assert docker_swarm_service.has_list_changed([None, 1], [2, 1])
assert docker_swarm_service.has_list_changed([2, 1], [None, 1]) assert docker_swarm_service.has_list_changed([2, 1], [None, 1])
assert docker_swarm_service.has_list_changed( assert docker_swarm_service.has_list_changed(
"command --with args", ["command", "--with", "args"] ["command --with args"], ["command", "--with", "args"]
) )
assert docker_swarm_service.has_list_changed( assert docker_swarm_service.has_list_changed(
["sleep", "3400"], ["sleep", "3600"], sort_lists=False ["sleep", "3400"], ["sleep", "3600"], sort_lists=False
@ -259,7 +239,7 @@ def test_has_list_changed(docker_swarm_service):
) )
def test_have_networks_changed(docker_swarm_service): def test_have_networks_changed() -> None:
assert not docker_swarm_service.have_networks_changed(None, None) assert not docker_swarm_service.have_networks_changed(None, None)
assert not docker_swarm_service.have_networks_changed([], None) assert not docker_swarm_service.have_networks_changed([], None)
@ -329,14 +309,14 @@ def test_have_networks_changed(docker_swarm_service):
) )
def test_get_docker_networks(docker_swarm_service): def test_get_docker_networks() -> None:
network_names = [ network_names = [
"network_1", "network_1",
"network_2", "network_2",
"network_3", "network_3",
"network_4", "network_4",
] ]
networks = [ networks: list[str | dict[str, t.Any]] = [
network_names[0], network_names[0],
{"name": network_names[1]}, {"name": network_names[1]},
{"name": network_names[2], "aliases": ["networkalias1"]}, {"name": network_names[2], "aliases": ["networkalias1"]},
@ -367,28 +347,27 @@ def test_get_docker_networks(docker_swarm_service):
assert "foo" in network["options"] assert "foo" in network["options"]
# Test missing name # Test missing name
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks([{"invalid": "err"}], {"err": 1}) docker_swarm_service.get_docker_networks([{"invalid": "err"}], {"err": "x"})
# test for invalid aliases type # test for invalid aliases type
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks( docker_swarm_service.get_docker_networks(
[{"name": "test", "aliases": 1}], {"test": 1} [{"name": "test", "aliases": 1}], {"test": "x"}
) )
# Test invalid aliases elements # Test invalid aliases elements
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks( docker_swarm_service.get_docker_networks(
[{"name": "test", "aliases": [1]}], {"test": 1} [{"name": "test", "aliases": [1]}], {"test": "x"}
) )
# Test for invalid options type # Test for invalid options type
with pytest.raises(TypeError): with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks( docker_swarm_service.get_docker_networks(
[{"name": "test", "options": 1}], {"test": 1} [{"name": "test", "options": 1}], {"test": "x"}
) )
# Test for invalid networks type
with pytest.raises(TypeError):
docker_swarm_service.get_docker_networks(1, {"test": 1})
# Test for non existing networks # Test for non existing networks
with pytest.raises(ValueError): with pytest.raises(ValueError):
docker_swarm_service.get_docker_networks([{"name": "idontexist"}], {"test": 1}) docker_swarm_service.get_docker_networks(
[{"name": "idontexist"}], {"test": "x"}
)
# Test empty values # Test empty values
assert docker_swarm_service.get_docker_networks([], {}) == [] assert docker_swarm_service.get_docker_networks([], {}) == []
assert docker_swarm_service.get_docker_networks(None, {}) is None assert docker_swarm_service.get_docker_networks(None, {}) is None

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
import pytest import pytest
from ansible_collections.community.internal_test_tools.tests.unit.utils.trust import ( from ansible_collections.community.internal_test_tools.tests.unit.utils.trust import (
SUPPORTS_DATA_TAGGING, SUPPORTS_DATA_TAGGING,
@ -23,7 +25,9 @@ from ansible_collections.community.docker.plugins.plugin_utils._unsafe import (
) )
TEST_MAKE_UNSAFE = [ TEST_MAKE_UNSAFE: list[
tuple[t.Any, list[tuple[t.Any, ...]], list[tuple[t.Any, ...]]]
] = [
( (
_make_trusted("text"), _make_trusted("text"),
[], [],
@ -97,7 +101,11 @@ if not SUPPORTS_DATA_TAGGING:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value, check_unsafe_paths, check_safe_paths", TEST_MAKE_UNSAFE "value, check_unsafe_paths, check_safe_paths", TEST_MAKE_UNSAFE
) )
def test_make_unsafe(value, check_unsafe_paths, check_safe_paths): def test_make_unsafe(
value: t.Any,
check_unsafe_paths: list[tuple[t.Any, ...]],
check_safe_paths: list[tuple[t.Any, ...]],
) -> None:
unsafe_value = make_unsafe(value) unsafe_value = make_unsafe(value)
assert unsafe_value == value assert unsafe_value == value
for check_path in check_unsafe_paths: for check_path in check_unsafe_paths:
@ -112,7 +120,7 @@ def test_make_unsafe(value, check_unsafe_paths, check_safe_paths):
assert _is_trusted(obj) assert _is_trusted(obj)
def test_make_unsafe_idempotence(): def test_make_unsafe_idempotence() -> None:
assert make_unsafe(None) is None assert make_unsafe(None) is None
unsafe_str = _make_untrusted("{{test}}") unsafe_str = _make_untrusted("{{test}}")
@ -122,8 +130,8 @@ def test_make_unsafe_idempotence():
assert id(make_unsafe(safe_str)) != id(safe_str) assert id(make_unsafe(safe_str)) != id(safe_str)
def test_make_unsafe_dict_key(): def test_make_unsafe_dict_key() -> None:
value = { value: dict[t.Any, t.Any] = {
_make_trusted("test"): 2, _make_trusted("test"): 2,
} }
if not SUPPORTS_DATA_TAGGING: if not SUPPORTS_DATA_TAGGING:
@ -144,8 +152,8 @@ def test_make_unsafe_dict_key():
assert not _is_trusted(obj) assert not _is_trusted(obj)
def test_make_unsafe_set(): def test_make_unsafe_set() -> None:
value = set([_make_trusted("test")]) value: set[t.Any] = set([_make_trusted("test")])
if not SUPPORTS_DATA_TAGGING: if not SUPPORTS_DATA_TAGGING:
value.add(_make_trusted(b"test")) value.add(_make_trusted(b"test"))
unsafe_value = make_unsafe(value) unsafe_value = make_unsafe(value)

View File

@ -6,10 +6,13 @@ from __future__ import annotations
import json import json
import tarfile import tarfile
import typing as t
from tempfile import TemporaryFile from tempfile import TemporaryFile
def write_imitation_archive(file_name, image_id, repo_tags): def write_imitation_archive(
file_name: str, image_id: str, repo_tags: list[str]
) -> None:
""" """
Write a tar file meeting these requirements: Write a tar file meeting these requirements:
@ -21,7 +24,7 @@ def write_imitation_archive(file_name, image_id, repo_tags):
:type file_name: str :type file_name: str
:param image_id: Fake sha256 hash (without the sha256: prefix) :param image_id: Fake sha256 hash (without the sha256: prefix)
:type image_id: str :type image_id: str
:param repo_tags: list of fake image:tag's :param repo_tags: list of fake image tags
:type repo_tags: list :type repo_tags: list
""" """
@ -30,7 +33,9 @@ def write_imitation_archive(file_name, image_id, repo_tags):
write_imitation_archive_with_manifest(file_name, manifest) write_imitation_archive_with_manifest(file_name, manifest)
def write_imitation_archive_with_manifest(file_name, manifest): def write_imitation_archive_with_manifest(
file_name: str, manifest: list[dict[str, t.Any]]
) -> None:
with tarfile.open(file_name, "w") as tf: with tarfile.open(file_name, "w") as tf:
with TemporaryFile() as f: with TemporaryFile() as f:
f.write(json.dumps(manifest).encode("utf-8")) f.write(json.dumps(manifest).encode("utf-8"))
@ -42,7 +47,7 @@ def write_imitation_archive_with_manifest(file_name, manifest):
tf.addfile(ti, f) tf.addfile(ti, f)
def write_irrelevant_tar(file_name): def write_irrelevant_tar(file_name: str) -> None:
""" """
Create a tar file that does not match the spec for "docker image save" / "docker image load" commands. Create a tar file that does not match the spec for "docker image save" / "docker image load" commands.

View File

@ -2,4 +2,5 @@
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) # GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
docker
requests requests