Add more type hints.

This commit is contained in:
Felix Fontein 2025-10-21 06:25:30 +02:00
parent 0ff66e8b24
commit 08960a9317
28 changed files with 1104 additions and 650 deletions

View File

@ -13,8 +13,9 @@ from __future__ import annotations
import json import json
import logging import logging
import os
import struct import struct
from functools import partial import typing as t
from urllib.parse import quote from urllib.parse import quote
from .. import auth from .. import auth
@ -50,13 +51,13 @@ from ..utils import config, json_stream, utils
from ..utils.decorators import update_headers from ..utils.decorators import update_headers
from ..utils.proxy import ProxyConfig from ..utils.proxy import ProxyConfig
from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
from .daemon import DaemonApiMixin from ..utils.decorators import minimum_version
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class APIClient(_Session, DaemonApiMixin): class APIClient(_Session):
""" """
A low-level client for the Docker Engine API. A low-level client for the Docker Engine API.
@ -105,16 +106,16 @@ class APIClient(_Session, DaemonApiMixin):
def __init__( def __init__(
self, self,
base_url=None, base_url: str | None = None,
version=None, version: str | None = None,
timeout=DEFAULT_TIMEOUT_SECONDS, timeout: int = DEFAULT_TIMEOUT_SECONDS,
tls=False, tls: bool | TLSConfig = False,
user_agent=DEFAULT_USER_AGENT, user_agent: str = DEFAULT_USER_AGENT,
num_pools=None, num_pools: int | None = None,
credstore_env=None, credstore_env: dict[str, str] | None = None,
use_ssh_client=False, use_ssh_client: bool = False,
max_pool_size=DEFAULT_MAX_POOL_SIZE, max_pool_size: int = DEFAULT_MAX_POOL_SIZE,
): ) -> None:
super().__init__() super().__init__()
fail_on_missing_imports() fail_on_missing_imports()
@ -152,6 +153,9 @@ class APIClient(_Session, DaemonApiMixin):
else DEFAULT_NUM_POOLS else DEFAULT_NUM_POOLS
) )
self._custom_adapter: (
UnixHTTPAdapter | NpipeHTTPAdapter | SSHHTTPAdapter | SSLHTTPAdapter | None
) = None
if base_url.startswith("http+unix://"): if base_url.startswith("http+unix://"):
self._custom_adapter = UnixHTTPAdapter( self._custom_adapter = UnixHTTPAdapter(
base_url, base_url,
@ -223,7 +227,7 @@ class APIClient(_Session, DaemonApiMixin):
f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library." f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library."
) )
def _retrieve_server_version(self): def _retrieve_server_version(self) -> str:
try: try:
version_result = self.version(api_version=False) version_result = self.version(api_version=False)
except Exception as e: except Exception as e:
@ -242,54 +246,87 @@ class APIClient(_Session, DaemonApiMixin):
f"Error while fetching server API version: {e}. Response seems to be broken." f"Error while fetching server API version: {e}. Response seems to be broken."
) from e ) from e
def _set_request_timeout(self, kwargs): def _set_request_timeout(self, kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
"""Prepare the kwargs for an HTTP request by inserting the timeout """Prepare the kwargs for an HTTP request by inserting the timeout
parameter, if not already present.""" parameter, if not already present."""
kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("timeout", self.timeout)
return kwargs return kwargs
@update_headers @update_headers
def _post(self, url, **kwargs): def _post(self, url: str, **kwargs):
return self.post(url, **self._set_request_timeout(kwargs)) return self.post(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _get(self, url, **kwargs): def _get(self, url: str, **kwargs):
return self.get(url, **self._set_request_timeout(kwargs)) return self.get(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _head(self, url, **kwargs): def _head(self, url: str, **kwargs):
return self.head(url, **self._set_request_timeout(kwargs)) return self.head(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _put(self, url, **kwargs): def _put(self, url: str, **kwargs):
return self.put(url, **self._set_request_timeout(kwargs)) return self.put(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _delete(self, url, **kwargs): def _delete(self, url: str, **kwargs):
return self.delete(url, **self._set_request_timeout(kwargs)) return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt, *args, **kwargs): def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
for arg in args: for arg in args:
if not isinstance(arg, str): if not isinstance(arg, str):
raise ValueError( raise ValueError(
f"Expected a string but found {arg} ({type(arg)}) instead" f"Expected a string but found {arg} ({type(arg)}) instead"
) )
quote_f = partial(quote, safe="/:") q_args = [quote(arg, safe="/:") for arg in args]
args = map(quote_f, args)
if kwargs.get("versioned_api", True): if versioned_api:
return f"{self.base_url}/v{self._version}{pathfmt.format(*args)}" return f"{self.base_url}/v{self._version}{pathfmt.format(*q_args)}"
return f"{self.base_url}{pathfmt.format(*args)}" return f"{self.base_url}{pathfmt.format(*q_args)}"
def _raise_for_status(self, response): def _raise_for_status(self, response) -> None:
"""Raises stored :class:`APIError`, if one occurred.""" """Raises stored :class:`APIError`, if one occurred."""
try: try:
response.raise_for_status() response.raise_for_status()
except _HTTPError as e: except _HTTPError as e:
create_api_error_from_http_exception(e) create_api_error_from_http_exception(e)
def _result(self, response, get_json=False, get_binary=False): @t.overload
def _result(
self,
response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[False] = False,
) -> str: ...
@t.overload
def _result(
self,
response,
*,
get_json: t.Literal[True],
get_binary: t.Literal[False] = False,
) -> t.Any: ...
@t.overload
def _result(
self,
response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[True],
) -> bytes: ...
@t.overload
def _result(
self, response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes: ...
def _result(
self, response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes:
if get_json and get_binary: if get_json and get_binary:
raise AssertionError("json and binary must not be both True") raise AssertionError("json and binary must not be both True")
self._raise_for_status(response) self._raise_for_status(response)
@ -300,10 +337,10 @@ class APIClient(_Session, DaemonApiMixin):
return response.content return response.content
return response.text return response.text
def _post_json(self, url, data, **kwargs): def _post_json(self, url: str, data: dict[str, str | None] | t.Any, **kwargs):
# Go <1.1 cannot unserialize null to a string # Go <1.1 cannot unserialize null to a string
# so we do this disgusting thing here. # so we do this disgusting thing here.
data2 = {} data2: dict[str, t.Any] = {}
if data is not None and isinstance(data, dict): if data is not None and isinstance(data, dict):
for k, v in data.items(): for k, v in data.items():
if v is not None: if v is not None:
@ -316,7 +353,7 @@ class APIClient(_Session, DaemonApiMixin):
kwargs["headers"]["Content-Type"] = "application/json" kwargs["headers"]["Content-Type"] = "application/json"
return self._post(url, data=json.dumps(data2), **kwargs) return self._post(url, data=json.dumps(data2), **kwargs)
def _attach_params(self, override=None): def _attach_params(self, override: dict[str, int] | None = None) -> dict[str, int]:
return override or {"stdout": 1, "stderr": 1, "stream": 1} return override or {"stdout": 1, "stderr": 1, "stream": 1}
def _get_raw_response_socket(self, response): def _get_raw_response_socket(self, response):
@ -341,12 +378,24 @@ class APIClient(_Session, DaemonApiMixin):
return sock return sock
def _stream_helper(self, response, decode=False): @t.overload
def _stream_helper(
self, response, *, decode: t.Literal[False] = False
) -> t.Generator[bytes]: ...
@t.overload
def _stream_helper(
self, response, *, decode: t.Literal[True]
) -> t.Generator[t.Any]: ...
def _stream_helper(self, response, *, decode: bool = False) -> t.Generator[t.Any]:
"""Generator for data coming from a chunked-encoded HTTP response.""" """Generator for data coming from a chunked-encoded HTTP response."""
if response.raw._fp.chunked: if response.raw._fp.chunked:
if decode: if decode:
yield from json_stream.json_stream(self._stream_helper(response, False)) yield from json_stream.json_stream(
self._stream_helper(response, decode=False)
)
else: else:
reader = response.raw reader = response.raw
while not reader.closed: while not reader.closed:
@ -362,7 +411,7 @@ class APIClient(_Session, DaemonApiMixin):
# encountered an error immediately # encountered an error immediately
yield self._result(response, get_json=decode) yield self._result(response, get_json=decode)
def _multiplexed_buffer_helper(self, response): def _multiplexed_buffer_helper(self, response) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks read from a buffered """A generator of multiplexed data blocks read from a buffered
response.""" response."""
buf = self._result(response, get_binary=True) buf = self._result(response, get_binary=True)
@ -378,7 +427,7 @@ class APIClient(_Session, DaemonApiMixin):
walker = end walker = end
yield buf[start:end] yield buf[start:end]
def _multiplexed_response_stream_helper(self, response): def _multiplexed_response_stream_helper(self, response) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks coming from a response """A generator of multiplexed data blocks coming from a response
stream.""" stream."""
@ -399,7 +448,19 @@ class APIClient(_Session, DaemonApiMixin):
break break
yield data yield data
def _stream_raw_result(self, response, chunk_size=1, decode=True): @t.overload
def _stream_raw_result(
self, response, *, chunk_size: int = 1, decode: t.Literal[True] = True
) -> t.Generator[str]: ...
@t.overload
def _stream_raw_result(
self, response, *, chunk_size: int = 1, decode: t.Literal[False]
) -> t.Generator[bytes]: ...
def _stream_raw_result(
self, response, *, chunk_size: int = 1, decode: bool = True
) -> t.Generator[str | bytes]:
"""Stream result for TTY-enabled container and raw binary data""" """Stream result for TTY-enabled container and raw binary data"""
self._raise_for_status(response) self._raise_for_status(response)
@ -410,14 +471,81 @@ class APIClient(_Session, DaemonApiMixin):
yield from response.iter_content(chunk_size, decode) yield from response.iter_content(chunk_size, decode)
def _read_from_socket(self, response, stream, tty=True, demux=False): @t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
) -> t.Generator[bytes]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
) -> bytes: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> tuple[bytes, None]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
) -> tuple[bytes, bytes]: ...
@t.overload
def _read_from_socket(
self, response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any: ...
def _read_from_socket(
self, response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any:
"""Consume all data from the socket, close the response and return the """Consume all data from the socket, close the response and return the
data. If stream=True, then a generator is returned instead and the data. If stream=True, then a generator is returned instead and the
caller is responsible for closing the response. caller is responsible for closing the response.
""" """
socket = self._get_raw_response_socket(response) socket = self._get_raw_response_socket(response)
gen = frames_iter(socket, tty) gen: t.Generator = frames_iter(socket, tty)
if demux: if demux:
# The generator will output tuples (stdout, stderr) # The generator will output tuples (stdout, stderr)
@ -434,7 +562,7 @@ class APIClient(_Session, DaemonApiMixin):
finally: finally:
response.close() response.close()
def _disable_socket_timeout(self, socket): def _disable_socket_timeout(self, socket) -> None:
"""Depending on the combination of python version and whether we are """Depending on the combination of python version and whether we are
connecting over http or https, we might need to access _sock, which connecting over http or https, we might need to access _sock, which
may or may not exist; or we may need to just settimeout on socket may or may not exist; or we may need to just settimeout on socket
@ -462,7 +590,27 @@ class APIClient(_Session, DaemonApiMixin):
s.settimeout(None) s.settimeout(None)
def _get_result_tty(self, stream, res, is_tty): @t.overload
def _get_result_tty(
self, stream: t.Literal[True], res, is_tty: t.Literal[True]
) -> t.Generator[str]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[True], res, is_tty: t.Literal[False]
) -> t.Generator[bytes]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res, is_tty: t.Literal[True]
) -> bytes: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res, is_tty: t.Literal[False]
) -> bytes: ...
def _get_result_tty(self, stream: bool, res, is_tty: bool) -> t.Any:
# We should also use raw streaming (without keep-alive) # We should also use raw streaming (without keep-alive)
# if we are dealing with a tty-enabled container. # if we are dealing with a tty-enabled container.
if is_tty: if is_tty:
@ -478,11 +626,11 @@ class APIClient(_Session, DaemonApiMixin):
return self._multiplexed_response_stream_helper(res) return self._multiplexed_response_stream_helper(res)
return sep.join(list(self._multiplexed_buffer_helper(res))) return sep.join(list(self._multiplexed_buffer_helper(res)))
def _unmount(self, *args): def _unmount(self, *args) -> None:
for proto in args: for proto in args:
self.adapters.pop(proto) self.adapters.pop(proto)
def get_adapter(self, url): def get_adapter(self, url: str):
try: try:
return super().get_adapter(url) return super().get_adapter(url)
except _InvalidSchema as e: except _InvalidSchema as e:
@ -491,10 +639,10 @@ class APIClient(_Session, DaemonApiMixin):
raise e raise e
@property @property
def api_version(self): def api_version(self) -> str:
return self._version return self._version
def reload_config(self, dockercfg_path=None): def reload_config(self, dockercfg_path: str | None = None) -> None:
""" """
Force a reload of the auth configuration Force a reload of the auth configuration
@ -510,7 +658,7 @@ class APIClient(_Session, DaemonApiMixin):
dockercfg_path, credstore_env=self.credstore_env dockercfg_path, credstore_env=self.credstore_env
) )
def _set_auth_headers(self, headers): def _set_auth_headers(self, headers: dict[str, str | bytes]) -> None:
log.debug("Looking for auth config") log.debug("Looking for auth config")
# If we do not have any auth data so far, try reloading the config # If we do not have any auth data so far, try reloading the config
@ -537,57 +685,62 @@ class APIClient(_Session, DaemonApiMixin):
else: else:
log.debug("No auth config found") log.debug("No auth config found")
def get_binary(self, pathfmt, *args, **kwargs): def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_binary=True, get_binary=True,
) )
def get_json(self, pathfmt, *args, **kwargs): def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def get_text(self, pathfmt, *args, **kwargs): def get_text(self, pathfmt: str, *args: str, **kwargs) -> str:
return self._result( return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def get_raw_stream(self, pathfmt, *args, **kwargs): def get_raw_stream(
chunk_size = kwargs.pop("chunk_size", DEFAULT_DATA_CHUNK_SIZE) self,
pathfmt: str,
*args: str,
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
**kwargs,
) -> t.Generator[bytes]:
res = self._get( res = self._get(
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
) )
self._raise_for_status(res) self._raise_for_status(res)
return self._stream_raw_result(res, chunk_size, False) return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
def delete_call(self, pathfmt, *args, **kwargs): def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status( self._raise_for_status(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def delete_json(self, pathfmt, *args, **kwargs): def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result( return self._result(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
def post_call(self, pathfmt, *args, **kwargs): def post_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status( self._raise_for_status(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs) self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
) )
def post_json(self, pathfmt, *args, **kwargs): def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None:
data = kwargs.pop("data", None)
self._raise_for_status( self._raise_for_status(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
) )
) )
def post_json_to_binary(self, pathfmt, *args, **kwargs): def post_json_to_binary(
data = kwargs.pop("data", None) self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> bytes:
return self._result( return self._result(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -595,8 +748,9 @@ class APIClient(_Session, DaemonApiMixin):
get_binary=True, get_binary=True,
) )
def post_json_to_json(self, pathfmt, *args, **kwargs): def post_json_to_json(
data = kwargs.pop("data", None) self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> t.Any:
return self._result( return self._result(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -604,17 +758,24 @@ class APIClient(_Session, DaemonApiMixin):
get_json=True, get_json=True,
) )
def post_json_to_text(self, pathfmt, *args, **kwargs): def post_json_to_text(
data = kwargs.pop("data", None) self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> str:
return self._result( return self._result(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs self._url(pathfmt, *args, versioned_api=True), data, **kwargs
), ),
) )
def post_json_to_stream_socket(self, pathfmt, *args, **kwargs): def post_json_to_stream_socket(
data = kwargs.pop("data", None) self,
headers = (kwargs.pop("headers", None) or {}).copy() pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
**kwargs,
):
headers = headers.copy() if headers else {}
headers.update( headers.update(
{ {
"Connection": "Upgrade", "Connection": "Upgrade",
@ -631,18 +792,102 @@ class APIClient(_Session, DaemonApiMixin):
) )
) )
def post_json_to_stream(self, pathfmt, *args, **kwargs): @t.overload
data = kwargs.pop("data", None) def post_json_to_stream(
headers = (kwargs.pop("headers", None) or {}).copy() self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> t.Generator[bytes]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> bytes: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, None]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, bytes]: ...
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: bool = False,
demux: bool = False,
tty: bool = False,
**kwargs,
) -> t.Any:
headers = headers.copy() if headers else {}
headers.update( headers.update(
{ {
"Connection": "Upgrade", "Connection": "Upgrade",
"Upgrade": "tcp", "Upgrade": "tcp",
} }
) )
stream = kwargs.pop("stream", False)
demux = kwargs.pop("demux", False)
tty = kwargs.pop("tty", False)
return self._read_from_socket( return self._read_from_socket(
self._post_json( self._post_json(
self._url(pathfmt, *args, versioned_api=True), self._url(pathfmt, *args, versioned_api=True),
@ -651,13 +896,133 @@ class APIClient(_Session, DaemonApiMixin):
stream=True, stream=True,
**kwargs, **kwargs,
), ),
stream, stream=stream,
tty=tty, tty=tty,
demux=demux, demux=demux,
) )
def post_to_json(self, pathfmt, *args, **kwargs): def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result( return self._result(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs), self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True, get_json=True,
) )
@minimum_version("1.25")
def df(self) -> dict[str, t.Any]:
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self) -> dict[str, t.Any]:
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username: str,
password: str | None = None,
email: str | None = None,
registry: str | None = None,
reauth: bool = False,
dockercfg_path: str | None = None,
) -> dict[str, t.Any]:
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self) -> bool:
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version: bool = True) -> dict[str, t.Any]:
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -1,139 +0,0 @@
# This code is part of the Ansible collection community.docker, but is an independent component.
# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
#
# Copyright (c) 2016-2022 Docker, Inc.
#
# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
# SPDX-License-Identifier: Apache-2.0
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
from .. import auth
from ..utils.decorators import minimum_version
class DaemonApiMixin:
@minimum_version("1.25")
def df(self):
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self):
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username,
password=None,
email=None,
registry=None,
reauth=False,
dockercfg_path=None,
):
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self):
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version=True):
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import base64 import base64
import json import json
import logging import logging
import typing as t
from . import errors from . import errors
from .credentials.errors import CredentialsNotFound, StoreError from .credentials.errors import CredentialsNotFound, StoreError
@ -21,6 +22,12 @@ from .credentials.store import Store
from .utils import config from .utils import config
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
INDEX_NAME = "docker.io" INDEX_NAME = "docker.io"
INDEX_URL = f"https://index.{INDEX_NAME}/v1/" INDEX_URL = f"https://index.{INDEX_NAME}/v1/"
TOKEN_USERNAME = "<token>" TOKEN_USERNAME = "<token>"
@ -28,7 +35,7 @@ TOKEN_USERNAME = "<token>"
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def resolve_repository_name(repo_name): def resolve_repository_name(repo_name: str) -> tuple[str, str]:
if "://" in repo_name: if "://" in repo_name:
raise errors.InvalidRepository( raise errors.InvalidRepository(
f"Repository name cannot contain a scheme ({repo_name})" f"Repository name cannot contain a scheme ({repo_name})"
@ -42,14 +49,14 @@ def resolve_repository_name(repo_name):
return resolve_index_name(index_name), remote_name return resolve_index_name(index_name), remote_name
def resolve_index_name(index_name): def resolve_index_name(index_name: str) -> str:
index_name = convert_to_hostname(index_name) index_name = convert_to_hostname(index_name)
if index_name == "index." + INDEX_NAME: if index_name == "index." + INDEX_NAME:
index_name = INDEX_NAME index_name = INDEX_NAME
return index_name return index_name
def get_config_header(client, registry): def get_config_header(client: APIClient, registry: str) -> bytes | None:
log.debug("Looking for auth config") log.debug("Looking for auth config")
if not client._auth_configs or client._auth_configs.is_empty: if not client._auth_configs or client._auth_configs.is_empty:
log.debug("No auth config in memory - loading from filesystem") log.debug("No auth config in memory - loading from filesystem")
@ -69,32 +76,38 @@ def get_config_header(client, registry):
return None return None
def split_repo_name(repo_name): def split_repo_name(repo_name: str) -> tuple[str, str]:
parts = repo_name.split("/", 1) parts = repo_name.split("/", 1)
if len(parts) == 1 or ( if len(parts) == 1 or (
"." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost" "." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost"
): ):
# This is a docker index repo (ex: username/foobar or ubuntu) # This is a docker index repo (ex: username/foobar or ubuntu)
return INDEX_NAME, repo_name return INDEX_NAME, repo_name
return tuple(parts) return tuple(parts) # type: ignore
def get_credential_store(authconfig, registry): def get_credential_store(
authconfig: dict[str, t.Any] | AuthConfig, registry: str
) -> str | None:
if not isinstance(authconfig, AuthConfig): if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig) authconfig = AuthConfig(authconfig)
return authconfig.get_credential_store(registry) return authconfig.get_credential_store(registry)
class AuthConfig(dict): class AuthConfig(dict):
def __init__(self, dct, credstore_env=None): def __init__(
self, dct: dict[str, t.Any], credstore_env: dict[str, str] | None = None
):
if "auths" not in dct: if "auths" not in dct:
dct["auths"] = {} dct["auths"] = {}
self.update(dct) self.update(dct)
self._credstore_env = credstore_env self._credstore_env = credstore_env
self._stores = {} self._stores: dict[str, Store] = {}
@classmethod @classmethod
def parse_auth(cls, entries, raise_on_error=False): def parse_auth(
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False
) -> dict[str, dict[str, t.Any]]:
""" """
Parses authentication entries Parses authentication entries
@ -107,10 +120,10 @@ class AuthConfig(dict):
Authentication registry. Authentication registry.
""" """
conf = {} conf: dict[str, dict[str, t.Any]] = {}
for registry, entry in entries.items(): for registry, entry in entries.items():
if not isinstance(entry, dict): if not isinstance(entry, dict):
log.debug("Config entry for key %s is not auth config", registry) log.debug("Config entry for key %s is not auth config", registry) # type: ignore
# We sometimes fall back to parsing the whole config as if it # We sometimes fall back to parsing the whole config as if it
# was the auth config by itself, for legacy purposes. In that # was the auth config by itself, for legacy purposes. In that
# case, we fail silently and return an empty conf if any of the # case, we fail silently and return an empty conf if any of the
@ -150,7 +163,12 @@ class AuthConfig(dict):
return conf return conf
@classmethod @classmethod
def load_config(cls, config_path, config_dict, credstore_env=None): def load_config(
cls,
config_path: str | None,
config_dict: dict[str, t.Any] | None,
credstore_env: dict[str, str] | None = None,
) -> t.Self:
""" """
Loads authentication data from a Docker configuration file in the given Loads authentication data from a Docker configuration file in the given
root directory or if config_path is passed use given path. root directory or if config_path is passed use given path.
@ -196,22 +214,24 @@ class AuthConfig(dict):
return cls({"auths": cls.parse_auth(config_dict)}, credstore_env) return cls({"auths": cls.parse_auth(config_dict)}, credstore_env)
@property @property
def auths(self): def auths(self) -> dict[str, dict[str, t.Any]]:
return self.get("auths", {}) return self.get("auths", {})
@property @property
def creds_store(self): def creds_store(self) -> str | None:
return self.get("credsStore", None) return self.get("credsStore", None)
@property @property
def cred_helpers(self): def cred_helpers(self) -> dict[str, t.Any]:
return self.get("credHelpers", {}) return self.get("credHelpers", {})
@property @property
def is_empty(self): def is_empty(self) -> bool:
return not self.auths and not self.creds_store and not self.cred_helpers return not self.auths and not self.creds_store and not self.cred_helpers
def resolve_authconfig(self, registry=None): def resolve_authconfig(
self, registry: str | None = None
) -> dict[str, t.Any] | None:
""" """
Returns the authentication data from the given auth configuration for a Returns the authentication data from the given auth configuration for a
specific registry. As with the Docker client, legacy entries in the specific registry. As with the Docker client, legacy entries in the
@ -244,7 +264,9 @@ class AuthConfig(dict):
log.debug("No entry found") log.debug("No entry found")
return None return None
def _resolve_authconfig_credstore(self, registry, credstore_name): def _resolve_authconfig_credstore(
self, registry: str | None, credstore_name: str
) -> dict[str, t.Any] | None:
if not registry or registry == INDEX_NAME: if not registry or registry == INDEX_NAME:
# The ecosystem is a little schizophrenic with index.docker.io VS # The ecosystem is a little schizophrenic with index.docker.io VS
# docker.io - in that case, it seems the full URL is necessary. # docker.io - in that case, it seems the full URL is necessary.
@ -272,19 +294,19 @@ class AuthConfig(dict):
except StoreError as e: except StoreError as e:
raise errors.DockerException(f"Credentials store error: {e}") raise errors.DockerException(f"Credentials store error: {e}")
def _get_store_instance(self, name): def _get_store_instance(self, name: str):
if name not in self._stores: if name not in self._stores:
self._stores[name] = Store(name, environment=self._credstore_env) self._stores[name] = Store(name, environment=self._credstore_env)
return self._stores[name] return self._stores[name]
def get_credential_store(self, registry): def get_credential_store(self, registry: str | None) -> str | None:
if not registry or registry == INDEX_NAME: if not registry or registry == INDEX_NAME:
registry = INDEX_URL registry = INDEX_URL
return self.cred_helpers.get(registry) or self.creds_store return self.cred_helpers.get(registry) or self.creds_store
def get_all_credentials(self): def get_all_credentials(self) -> dict[str, dict[str, t.Any] | None]:
auth_data = self.auths.copy() auth_data: dict[str, dict[str, t.Any] | None] = self.auths.copy() # type: ignore
if self.creds_store: if self.creds_store:
# Retrieve all credentials from the default store # Retrieve all credentials from the default store
store = self._get_store_instance(self.creds_store) store = self._get_store_instance(self.creds_store)
@ -299,21 +321,23 @@ class AuthConfig(dict):
return auth_data return auth_data
def add_auth(self, reg, data): def add_auth(self, reg: str, data: dict[str, t.Any]) -> None:
self["auths"][reg] = data self["auths"][reg] = data
def resolve_authconfig(authconfig, registry=None, credstore_env=None): def resolve_authconfig(
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None
):
if not isinstance(authconfig, AuthConfig): if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig, credstore_env) authconfig = AuthConfig(authconfig, credstore_env)
return authconfig.resolve_authconfig(registry) return authconfig.resolve_authconfig(registry)
def convert_to_hostname(url): def convert_to_hostname(url: str) -> str:
return url.replace("http://", "").replace("https://", "").split("/", 1)[0] return url.replace("http://", "").replace("https://", "").split("/", 1)[0]
def decode_auth(auth): def decode_auth(auth: str | bytes) -> tuple[str, str]:
if isinstance(auth, str): if isinstance(auth, str):
auth = auth.encode("ascii") auth = auth.encode("ascii")
s = base64.b64decode(auth) s = base64.b64decode(auth)
@ -321,12 +345,14 @@ def decode_auth(auth):
return login.decode("utf8"), pwd.decode("utf8") return login.decode("utf8"), pwd.decode("utf8")
def encode_header(auth): def encode_header(auth: dict[str, t.Any]) -> bytes:
auth_json = json.dumps(auth).encode("ascii") auth_json = json.dumps(auth).encode("ascii")
return base64.urlsafe_b64encode(auth_json) return base64.urlsafe_b64encode(auth_json)
def parse_auth(entries, raise_on_error=False): def parse_auth(
entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
) -> dict[str, dict[str, t.Any]]:
""" """
Parses authentication entries Parses authentication entries
@ -342,11 +368,15 @@ def parse_auth(entries, raise_on_error=False):
return AuthConfig.parse_auth(entries, raise_on_error) return AuthConfig.parse_auth(entries, raise_on_error)
def load_config(config_path=None, config_dict=None, credstore_env=None): def load_config(
config_path: str | None = None,
config_dict: dict[str, t.Any] | None = None,
credstore_env: dict[str, str] | None = None,
) -> AuthConfig:
return AuthConfig.load_config(config_path, config_dict, credstore_env) return AuthConfig.load_config(config_path, config_dict, credstore_env)
def _load_legacy_config(config_file): def _load_legacy_config(config_file: str) -> dict[str, dict[str, t.Any]]:
log.debug("Attempting to parse legacy auth file format") log.debug("Attempting to parse legacy auth file format")
try: try:
data = [] data = []

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json import json
import os import os
import typing as t
from .. import errors from .. import errors
from .config import ( from .config import (
@ -24,7 +25,11 @@ from .config import (
from .context import Context from .context import Context
def create_default_context(): if t.TYPE_CHECKING:
from ..tls import TLSConfig
def create_default_context() -> Context:
host = None host = None
if os.environ.get("DOCKER_HOST"): if os.environ.get("DOCKER_HOST"):
host = os.environ.get("DOCKER_HOST") host = os.environ.get("DOCKER_HOST")
@ -42,7 +47,7 @@ class ContextAPI:
DEFAULT_CONTEXT = None DEFAULT_CONTEXT = None
@classmethod @classmethod
def get_default_context(cls): def get_default_context(cls) -> Context:
context = cls.DEFAULT_CONTEXT context = cls.DEFAULT_CONTEXT
if context is None: if context is None:
context = create_default_context() context = create_default_context()
@ -52,13 +57,13 @@ class ContextAPI:
@classmethod @classmethod
def create_context( def create_context(
cls, cls,
name, name: str,
orchestrator=None, orchestrator: str | None = None,
host=None, host: str | None = None,
tls_cfg=None, tls_cfg: TLSConfig | None = None,
default_namespace=None, default_namespace: str | None = None,
skip_tls_verify=False, skip_tls_verify: bool = False,
): ) -> Context:
"""Creates a new context. """Creates a new context.
Returns: Returns:
(Context): a Context object. (Context): a Context object.
@ -108,7 +113,7 @@ class ContextAPI:
return ctx return ctx
@classmethod @classmethod
def get_context(cls, name=None): def get_context(cls, name: str | None = None) -> Context | None:
"""Retrieves a context object. """Retrieves a context object.
Args: Args:
name (str): The name of the context name (str): The name of the context
@ -136,7 +141,7 @@ class ContextAPI:
return Context.load_context(name) return Context.load_context(name)
@classmethod @classmethod
def contexts(cls): def contexts(cls) -> list[Context]:
"""Context list. """Context list.
Returns: Returns:
(Context): List of context objects. (Context): List of context objects.
@ -170,7 +175,7 @@ class ContextAPI:
return contexts return contexts
@classmethod @classmethod
def get_current_context(cls): def get_current_context(cls) -> Context | None:
"""Get current context. """Get current context.
Returns: Returns:
(Context): current context object. (Context): current context object.
@ -178,7 +183,7 @@ class ContextAPI:
return cls.get_context() return cls.get_context()
@classmethod @classmethod
def set_current_context(cls, name="default"): def set_current_context(cls, name: str = "default") -> None:
ctx = cls.get_context(name) ctx = cls.get_context(name)
if not ctx: if not ctx:
raise errors.ContextNotFound(name) raise errors.ContextNotFound(name)
@ -188,7 +193,7 @@ class ContextAPI:
raise errors.ContextException(f"Failed to set current context: {err}") raise errors.ContextException(f"Failed to set current context: {err}")
@classmethod @classmethod
def remove_context(cls, name): def remove_context(cls, name: str) -> None:
"""Remove a context. Similar to the ``docker context rm`` command. """Remove a context. Similar to the ``docker context rm`` command.
Args: Args:
@ -220,7 +225,7 @@ class ContextAPI:
ctx.remove() ctx.remove()
@classmethod @classmethod
def inspect_context(cls, name="default"): def inspect_context(cls, name: str = "default") -> dict[str, t.Any]:
"""Inspect a context. Similar to the ``docker context inspect`` command. """Inspect a context. Similar to the ``docker context inspect`` command.
Args: Args:

View File

@ -23,7 +23,7 @@ from ..utils.utils import parse_host
METAFILE = "meta.json" METAFILE = "meta.json"
def get_current_context_name_with_source(): def get_current_context_name_with_source() -> tuple[str, str]:
if os.environ.get("DOCKER_HOST"): if os.environ.get("DOCKER_HOST"):
return "default", "DOCKER_HOST environment variable set" return "default", "DOCKER_HOST environment variable set"
if os.environ.get("DOCKER_CONTEXT"): if os.environ.get("DOCKER_CONTEXT"):
@ -41,11 +41,11 @@ def get_current_context_name_with_source():
return "default", "fallback value" return "default", "fallback value"
def get_current_context_name(): def get_current_context_name() -> str:
return get_current_context_name_with_source()[0] return get_current_context_name_with_source()[0]
def write_context_name_to_docker_config(name=None): def write_context_name_to_docker_config(name: str | None = None) -> Exception | None:
if name == "default": if name == "default":
name = None name = None
docker_cfg_path = find_config_file() docker_cfg_path = find_config_file()
@ -62,44 +62,45 @@ def write_context_name_to_docker_config(name=None):
elif name: elif name:
config["currentContext"] = name config["currentContext"] = name
else: else:
return return None
if not docker_cfg_path: if not docker_cfg_path:
docker_cfg_path = get_default_config_file() docker_cfg_path = get_default_config_file()
try: try:
with open(docker_cfg_path, "wt", encoding="utf-8") as f: with open(docker_cfg_path, "wt", encoding="utf-8") as f:
json.dump(config, f, indent=4) json.dump(config, f, indent=4)
return None
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e: # pylint: disable=broad-exception-caught
return e return e
def get_context_id(name): def get_context_id(name: str) -> str:
return hashlib.sha256(name.encode("utf-8")).hexdigest() return hashlib.sha256(name.encode("utf-8")).hexdigest()
def get_context_dir(): def get_context_dir() -> str:
docker_cfg_path = find_config_file() or get_default_config_file() docker_cfg_path = find_config_file() or get_default_config_file()
return os.path.join(os.path.dirname(docker_cfg_path), "contexts") return os.path.join(os.path.dirname(docker_cfg_path), "contexts")
def get_meta_dir(name=None): def get_meta_dir(name: str | None = None) -> str:
meta_dir = os.path.join(get_context_dir(), "meta") meta_dir = os.path.join(get_context_dir(), "meta")
if name: if name:
return os.path.join(meta_dir, get_context_id(name)) return os.path.join(meta_dir, get_context_id(name))
return meta_dir return meta_dir
def get_meta_file(name): def get_meta_file(name) -> str:
return os.path.join(get_meta_dir(name), METAFILE) return os.path.join(get_meta_dir(name), METAFILE)
def get_tls_dir(name=None, endpoint=""): def get_tls_dir(name: str | None = None, endpoint: str = "") -> str:
context_dir = get_context_dir() context_dir = get_context_dir()
if name: if name:
return os.path.join(context_dir, "tls", get_context_id(name), endpoint) return os.path.join(context_dir, "tls", get_context_id(name), endpoint)
return os.path.join(context_dir, "tls") return os.path.join(context_dir, "tls")
def get_context_host(path=None, tls=False): def get_context_host(path: str | None = None, tls: bool = False) -> str:
host = parse_host(path, IS_WINDOWS_PLATFORM, tls) host = parse_host(path, IS_WINDOWS_PLATFORM, tls)
if host == DEFAULT_UNIX_SOCKET: if host == DEFAULT_UNIX_SOCKET:
# remove http+ from default docker socket url # remove http+ from default docker socket url

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json import json
import os import os
import typing as t
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from ..errors import ContextException from ..errors import ContextException
@ -33,21 +34,21 @@ class Context:
def __init__( def __init__(
self, self,
name, name: str,
orchestrator=None, orchestrator: str | None = None,
host=None, host: str | None = None,
endpoints=None, endpoints: dict[str, dict[str, t.Any]] | None = None,
skip_tls_verify=False, skip_tls_verify: bool = False,
tls=False, tls: bool = False,
description=None, description: str | None = None,
): ) -> None:
if not name: if not name:
raise ValueError("Name not provided") raise ValueError("Name not provided")
self.name = name self.name = name
self.context_type = None self.context_type = None
self.orchestrator = orchestrator self.orchestrator = orchestrator
self.endpoints = {} self.endpoints = {}
self.tls_cfg = {} self.tls_cfg: dict[str, TLSConfig] = {}
self.meta_path = IN_MEMORY self.meta_path = IN_MEMORY
self.tls_path = IN_MEMORY self.tls_path = IN_MEMORY
self.description = description self.description = description
@ -89,12 +90,12 @@ class Context:
def set_endpoint( def set_endpoint(
self, self,
name="docker", name: str = "docker",
host=None, host: str | None = None,
tls_cfg=None, tls_cfg: TLSConfig | None = None,
skip_tls_verify=False, skip_tls_verify: bool = False,
def_namespace=None, def_namespace: str | None = None,
): ) -> None:
self.endpoints[name] = { self.endpoints[name] = {
"Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None), "Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None),
"SkipTLSVerify": skip_tls_verify, "SkipTLSVerify": skip_tls_verify,
@ -105,11 +106,11 @@ class Context:
if tls_cfg: if tls_cfg:
self.tls_cfg[name] = tls_cfg self.tls_cfg[name] = tls_cfg
def inspect(self): def inspect(self) -> dict[str, t.Any]:
return self() return self()
@classmethod @classmethod
def load_context(cls, name): def load_context(cls, name: str) -> t.Self | None:
meta = Context._load_meta(name) meta = Context._load_meta(name)
if meta: if meta:
instance = cls( instance = cls(
@ -125,12 +126,12 @@ class Context:
return None return None
@classmethod @classmethod
def _load_meta(cls, name): def _load_meta(cls, name: str) -> dict[str, t.Any] | None:
meta_file = get_meta_file(name) meta_file = get_meta_file(name)
if not os.path.isfile(meta_file): if not os.path.isfile(meta_file):
return None return None
metadata = {} metadata: dict[str, t.Any] = {}
try: try:
with open(meta_file, "rt", encoding="utf-8") as f: with open(meta_file, "rt", encoding="utf-8") as f:
metadata = json.load(f) metadata = json.load(f)
@ -154,7 +155,7 @@ class Context:
return metadata return metadata
def _load_certs(self): def _load_certs(self) -> None:
certs = {} certs = {}
tls_dir = get_tls_dir(self.name) tls_dir = get_tls_dir(self.name)
for endpoint in self.endpoints: for endpoint in self.endpoints:
@ -184,7 +185,7 @@ class Context:
self.tls_cfg = certs self.tls_cfg = certs
self.tls_path = tls_dir self.tls_path = tls_dir
def save(self): def save(self) -> None:
meta_dir = get_meta_dir(self.name) meta_dir = get_meta_dir(self.name)
if not os.path.isdir(meta_dir): if not os.path.isdir(meta_dir):
os.makedirs(meta_dir) os.makedirs(meta_dir)
@ -216,54 +217,54 @@ class Context:
self.meta_path = get_meta_dir(self.name) self.meta_path = get_meta_dir(self.name)
self.tls_path = get_tls_dir(self.name) self.tls_path = get_tls_dir(self.name)
def remove(self): def remove(self) -> None:
if os.path.isdir(self.meta_path): if os.path.isdir(self.meta_path):
rmtree(self.meta_path) rmtree(self.meta_path)
if os.path.isdir(self.tls_path): if os.path.isdir(self.tls_path):
rmtree(self.tls_path) rmtree(self.tls_path)
def __repr__(self): def __repr__(self) -> str:
return f"<{self.__class__.__name__}: '{self.name}'>" return f"<{self.__class__.__name__}: '{self.name}'>"
def __str__(self): def __str__(self) -> str:
return json.dumps(self.__call__(), indent=2) return json.dumps(self.__call__(), indent=2)
def __call__(self): def __call__(self) -> dict[str, t.Any]:
result = self.Metadata result = self.Metadata
result.update(self.TLSMaterial) result.update(self.TLSMaterial)
result.update(self.Storage) result.update(self.Storage)
return result return result
def is_docker_host(self): def is_docker_host(self) -> bool:
return self.context_type is None return self.context_type is None
@property @property
def Name(self): # pylint: disable=invalid-name def Name(self) -> str: # pylint: disable=invalid-name
return self.name return self.name
@property @property
def Host(self): # pylint: disable=invalid-name def Host(self) -> str | None: # pylint: disable=invalid-name
if not self.orchestrator or self.orchestrator == "swarm": if not self.orchestrator or self.orchestrator == "swarm":
endpoint = self.endpoints.get("docker", None) endpoint = self.endpoints.get("docker", None)
if endpoint: if endpoint:
return endpoint.get("Host", None) return endpoint.get("Host", None) # type: ignore
return None return None
return self.endpoints[self.orchestrator].get("Host", None) return self.endpoints[self.orchestrator].get("Host", None) # type: ignore
@property @property
def Orchestrator(self): # pylint: disable=invalid-name def Orchestrator(self) -> str | None: # pylint: disable=invalid-name
return self.orchestrator return self.orchestrator
@property @property
def Metadata(self): # pylint: disable=invalid-name def Metadata(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
meta = {} meta: dict[str, t.Any] = {}
if self.orchestrator: if self.orchestrator:
meta = {"StackOrchestrator": self.orchestrator} meta = {"StackOrchestrator": self.orchestrator}
return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints} return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints}
@property @property
def TLSConfig(self): # pylint: disable=invalid-name def TLSConfig(self) -> TLSConfig | None: # pylint: disable=invalid-name
key = self.orchestrator key = self.orchestrator
if not key or key == "swarm": if not key or key == "swarm":
key = "docker" key = "docker"
@ -272,13 +273,15 @@ class Context:
return None return None
@property @property
def TLSMaterial(self): # pylint: disable=invalid-name def TLSMaterial(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
certs = {} certs: dict[str, t.Any] = {}
for endpoint, tls in self.tls_cfg.items(): for endpoint, tls in self.tls_cfg.items():
cert, key = tls.cert paths = [tls.ca_cert, *tls.cert] if tls.cert else [tls.ca_cert]
certs[endpoint] = list(map(os.path.basename, [tls.ca_cert, cert, key])) certs[endpoint] = [
os.path.basename(path) if path else None for path in paths
]
return {"TLSMaterial": certs} return {"TLSMaterial": certs}
@property @property
def Storage(self): # pylint: disable=invalid-name def Storage(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}} return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}}

View File

@ -11,6 +11,12 @@
from __future__ import annotations from __future__ import annotations
import typing as t
if t.TYPE_CHECKING:
from subprocess import CalledProcessError
class StoreError(RuntimeError): class StoreError(RuntimeError):
pass pass
@ -24,7 +30,7 @@ class InitializationError(StoreError):
pass pass
def process_store_error(cpe, program): def process_store_error(cpe: CalledProcessError, program: str) -> StoreError:
message = cpe.output.decode("utf-8") message = cpe.output.decode("utf-8")
if "credentials not found in native keychain" in message: if "credentials not found in native keychain" in message:
return CredentialsNotFound(f"No matching credentials in {program}") return CredentialsNotFound(f"No matching credentials in {program}")

View File

@ -14,13 +14,14 @@ from __future__ import annotations
import errno import errno
import json import json
import subprocess import subprocess
import typing as t
from . import constants, errors from . import constants, errors
from .utils import create_environment_dict, find_executable from .utils import create_environment_dict, find_executable
class Store: class Store:
def __init__(self, program, environment=None): def __init__(self, program: str, environment: dict[str, str] | None = None) -> None:
"""Create a store object that acts as an interface to """Create a store object that acts as an interface to
perform the basic operations for storing, retrieving perform the basic operations for storing, retrieving
and erasing credentials using `program`. and erasing credentials using `program`.
@ -33,7 +34,7 @@ class Store:
f"{self.program} not installed or not available in PATH" f"{self.program} not installed or not available in PATH"
) )
def get(self, server): def get(self, server: str | bytes) -> dict[str, t.Any]:
"""Retrieve credentials for `server`. If no credentials are found, """Retrieve credentials for `server`. If no credentials are found,
a `StoreError` will be raised. a `StoreError` will be raised.
""" """
@ -53,7 +54,7 @@ class Store:
return result return result
def store(self, server, username, secret): def store(self, server: str, username: str, secret: str) -> bytes:
"""Store credentials for `server`. Raises a `StoreError` if an error """Store credentials for `server`. Raises a `StoreError` if an error
occurs. occurs.
""" """
@ -62,7 +63,7 @@ class Store:
).encode("utf-8") ).encode("utf-8")
return self._execute("store", data_input) return self._execute("store", data_input)
def erase(self, server): def erase(self, server: str | bytes) -> None:
"""Erase credentials for `server`. Raises a `StoreError` if an error """Erase credentials for `server`. Raises a `StoreError` if an error
occurs. occurs.
""" """
@ -70,12 +71,16 @@ class Store:
server = server.encode("utf-8") server = server.encode("utf-8")
self._execute("erase", server) self._execute("erase", server)
def list(self): def list(self) -> t.Any:
"""List stored credentials. Requires v0.4.0+ of the helper.""" """List stored credentials. Requires v0.4.0+ of the helper."""
data = self._execute("list", None) data = self._execute("list", None)
return json.loads(data.decode("utf-8")) return json.loads(data.decode("utf-8"))
def _execute(self, subcmd, data_input): def _execute(self, subcmd: str, data_input: bytes | None) -> bytes:
if self.exe is None:
raise errors.StoreError(
f"{self.program} not installed or not available in PATH"
)
output = None output = None
env = create_environment_dict(self.environment) env = create_environment_dict(self.environment)
try: try:

View File

@ -15,7 +15,7 @@ import os
from shutil import which from shutil import which
def find_executable(executable, path=None): def find_executable(executable: str, path: str | None = None) -> str | None:
""" """
As distutils.spawn.find_executable, but on Windows, look up As distutils.spawn.find_executable, but on Windows, look up
every extension declared in PATHEXT instead of just `.exe` every extension declared in PATHEXT instead of just `.exe`
@ -26,7 +26,7 @@ def find_executable(executable, path=None):
return which(executable, path=path) return which(executable, path=path)
def create_environment_dict(overrides): def create_environment_dict(overrides: dict[str, str] | None) -> dict[str, str]:
""" """
Create and return a copy of os.environ with the specified overrides Create and return a copy of os.environ with the specified overrides
""" """

View File

@ -11,6 +11,8 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
from ._import_helper import HTTPError as _HTTPError from ._import_helper import HTTPError as _HTTPError
@ -25,7 +27,7 @@ class DockerException(Exception):
""" """
def create_api_error_from_http_exception(e): def create_api_error_from_http_exception(e: _HTTPError) -> t.NoReturn:
""" """
Create a suitable APIError from requests.exceptions.HTTPError. Create a suitable APIError from requests.exceptions.HTTPError.
""" """
@ -52,14 +54,16 @@ class APIError(_HTTPError, DockerException):
An HTTP error from the API. An HTTP error from the API.
""" """
def __init__(self, message, response=None, explanation=None): def __init__(
self, message: str | Exception, response=None, explanation: str | None = None
) -> None:
# requests 1.2 supports response as a keyword argument, but # requests 1.2 supports response as a keyword argument, but
# requests 1.1 does not # requests 1.1 does not
super().__init__(message) super().__init__(message)
self.response = response self.response = response
self.explanation = explanation self.explanation = explanation
def __str__(self): def __str__(self) -> str:
message = super().__str__() message = super().__str__()
if self.is_client_error(): if self.is_client_error():
@ -74,19 +78,20 @@ class APIError(_HTTPError, DockerException):
return message return message
@property @property
def status_code(self): def status_code(self) -> int | None:
if self.response is not None: if self.response is not None:
return self.response.status_code return self.response.status_code
return None
def is_error(self): def is_error(self) -> bool:
return self.is_client_error() or self.is_server_error() return self.is_client_error() or self.is_server_error()
def is_client_error(self): def is_client_error(self) -> bool:
if self.status_code is None: if self.status_code is None:
return False return False
return 400 <= self.status_code < 500 return 400 <= self.status_code < 500
def is_server_error(self): def is_server_error(self) -> bool:
if self.status_code is None: if self.status_code is None:
return False return False
return 500 <= self.status_code < 600 return 500 <= self.status_code < 600
@ -121,10 +126,10 @@ class DeprecatedMethod(DockerException):
class TLSParameterError(DockerException): class TLSParameterError(DockerException):
def __init__(self, msg): def __init__(self, msg: str) -> None:
self.msg = msg self.msg = msg
def __str__(self): def __str__(self) -> str:
return self.msg + ( return self.msg + (
". TLS configurations should map the Docker CLI " ". TLS configurations should map the Docker CLI "
"client configurations. See " "client configurations. See "
@ -142,7 +147,14 @@ class ContainerError(DockerException):
Represents a container that has exited with a non-zero exit code. Represents a container that has exited with a non-zero exit code.
""" """
def __init__(self, container, exit_status, command, image, stderr): def __init__(
self,
container: str,
exit_status: int,
command: list[str],
image: str,
stderr: str | None,
):
self.container = container self.container = container
self.exit_status = exit_status self.exit_status = exit_status
self.command = command self.command = command
@ -156,12 +168,12 @@ class ContainerError(DockerException):
class StreamParseError(RuntimeError): class StreamParseError(RuntimeError):
def __init__(self, reason): def __init__(self, reason: Exception) -> None:
self.msg = reason self.msg = reason
class BuildError(DockerException): class BuildError(DockerException):
def __init__(self, reason, build_log): def __init__(self, reason: str, build_log: str) -> None:
super().__init__(reason) super().__init__(reason)
self.msg = reason self.msg = reason
self.build_log = build_log self.build_log = build_log
@ -171,7 +183,7 @@ class ImageLoadError(DockerException):
pass pass
def create_unexpected_kwargs_error(name, kwargs): def create_unexpected_kwargs_error(name: str, kwargs: dict[str, t.Any]) -> TypeError:
quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)] quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)]
text = [f"{name}() "] text = [f"{name}() "]
if len(quoted_kwargs) == 1: if len(quoted_kwargs) == 1:
@ -183,42 +195,44 @@ def create_unexpected_kwargs_error(name, kwargs):
class MissingContextParameter(DockerException): class MissingContextParameter(DockerException):
def __init__(self, param): def __init__(self, param: str) -> None:
self.param = param self.param = param
def __str__(self): def __str__(self) -> str:
return f"missing parameter: {self.param}" return f"missing parameter: {self.param}"
class ContextAlreadyExists(DockerException): class ContextAlreadyExists(DockerException):
def __init__(self, name): def __init__(self, name: str) -> None:
self.name = name self.name = name
def __str__(self): def __str__(self) -> str:
return f"context {self.name} already exists" return f"context {self.name} already exists"
class ContextException(DockerException): class ContextException(DockerException):
def __init__(self, msg): def __init__(self, msg: str) -> None:
self.msg = msg self.msg = msg
def __str__(self): def __str__(self) -> str:
return self.msg return self.msg
class ContextNotFound(DockerException): class ContextNotFound(DockerException):
def __init__(self, name): def __init__(self, name: str) -> None:
self.name = name self.name = name
def __str__(self): def __str__(self) -> str:
return f"context '{self.name}' not found" return f"context '{self.name}' not found"
class MissingRequirementException(DockerException): class MissingRequirementException(DockerException):
def __init__(self, msg, requirement, import_exception): def __init__(
self, msg: str, requirement: str, import_exception: ImportError | str
) -> None:
self.msg = msg self.msg = msg
self.requirement = requirement self.requirement = requirement
self.import_exception = import_exception self.import_exception = import_exception
def __str__(self): def __str__(self) -> str:
return self.msg return self.msg

View File

@ -12,12 +12,18 @@
from __future__ import annotations from __future__ import annotations
import os import os
import ssl import typing as t
from . import errors from . import errors
from .transport.ssladapter import SSLHTTPAdapter from .transport.ssladapter import SSLHTTPAdapter
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class TLSConfig: class TLSConfig:
""" """
TLS configuration. TLS configuration.
@ -27,25 +33,22 @@ class TLSConfig:
ca_cert (str): Path to CA cert file. ca_cert (str): Path to CA cert file.
verify (bool or str): This can be ``False`` or a path to a CA cert verify (bool or str): This can be ``False`` or a path to a CA cert
file. file.
ssl_version (int): A valid `SSL version`_.
assert_hostname (bool): Verify the hostname of the server. assert_hostname (bool): Verify the hostname of the server.
.. _`SSL version`: .. _`SSL version`:
https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1 https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1
""" """
cert = None cert: tuple[str, str] | None = None
ca_cert = None ca_cert: str | None = None
verify = None verify: bool | None = None
ssl_version = None
def __init__( def __init__(
self, self,
client_cert=None, client_cert: tuple[str, str] | None = None,
ca_cert=None, ca_cert: str | None = None,
verify=None, verify: bool | None = None,
ssl_version=None, assert_hostname: bool | None = None,
assert_hostname=None,
): ):
# Argument compatibility/mapping with # Argument compatibility/mapping with
# https://docs.docker.com/engine/articles/https/ # https://docs.docker.com/engine/articles/https/
@ -55,12 +58,6 @@ class TLSConfig:
self.assert_hostname = assert_hostname self.assert_hostname = assert_hostname
# If the user provides an SSL version, we should use their preference
if ssl_version:
self.ssl_version = ssl_version
else:
self.ssl_version = ssl.PROTOCOL_TLS_CLIENT
# "client_cert" must have both or neither cert/key files. In # "client_cert" must have both or neither cert/key files. In
# either case, Alert the user when both are expected, but any are # either case, Alert the user when both are expected, but any are
# missing. # missing.
@ -90,11 +87,10 @@ class TLSConfig:
"Invalid CA certificate provided for `ca_cert`." "Invalid CA certificate provided for `ca_cert`."
) )
def configure_client(self, client): def configure_client(self, client: APIClient) -> None:
""" """
Configure a client with these TLS options. Configure a client with these TLS options.
""" """
client.ssl_version = self.ssl_version
if self.verify and self.ca_cert: if self.verify and self.ca_cert:
client.verify = self.ca_cert client.verify = self.ca_cert
@ -107,7 +103,6 @@ class TLSConfig:
client.mount( client.mount(
"https://", "https://",
SSLHTTPAdapter( SSLHTTPAdapter(
ssl_version=self.ssl_version,
assert_hostname=self.assert_hostname, assert_hostname=self.assert_hostname,
), ),
) )

View File

@ -15,7 +15,7 @@ from .._import_helper import HTTPAdapter as _HTTPAdapter
class BaseHTTPAdapter(_HTTPAdapter): class BaseHTTPAdapter(_HTTPAdapter):
def close(self): def close(self) -> None:
super().close() super().close()
if hasattr(self, "pools"): if hasattr(self, "pools"):
self.pools.clear() self.pools.clear()
@ -24,10 +24,10 @@ class BaseHTTPAdapter(_HTTPAdapter):
# https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356 # https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356
# changes requests.adapters.HTTPAdapter to no longer call get_connection() from # changes requests.adapters.HTTPAdapter to no longer call get_connection() from
# send(), but instead call _get_connection(). # send(), but instead call _get_connection().
def _get_connection(self, request, *args, **kwargs): def _get_connection(self, request, *args, **kwargs): # type: ignore
return self.get_connection(request.url, kwargs.get("proxies")) return self.get_connection(request.url, kwargs.get("proxies"))
# Fix for requests 2.32.2+: # Fix for requests 2.32.2+:
# https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05 # https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): # type: ignore
return self.get_connection(request.url, proxies) return self.get_connection(request.url, proxies)

View File

@ -23,12 +23,12 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class NpipeHTTPConnection(urllib3_connection.HTTPConnection): class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(self, npipe_path, timeout=60): def __init__(self, npipe_path: str, timeout: int = 60) -> None:
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.npipe_path = npipe_path self.npipe_path = npipe_path
self.timeout = timeout self.timeout = timeout
def connect(self): def connect(self) -> None:
sock = NpipeSocket() sock = NpipeSocket()
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.connect(self.npipe_path) sock.connect(self.npipe_path)
@ -36,18 +36,18 @@ class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, npipe_path, timeout=60, maxsize=10): def __init__(self, npipe_path: str, timeout: int = 60, maxsize: int = 10) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.npipe_path = npipe_path self.npipe_path = npipe_path
self.timeout = timeout self.timeout = timeout
def _new_conn(self): def _new_conn(self) -> NpipeHTTPConnection:
return NpipeHTTPConnection(self.npipe_path, self.timeout) return NpipeHTTPConnection(self.npipe_path, self.timeout)
# When re-using connections, urllib3 tries to call select() on our # When re-using connections, urllib3 tries to call select() on our
# NpipeSocket instance, causing a crash. To circumvent this, we override # NpipeSocket instance, causing a crash. To circumvent this, we override
# _get_conn, where that check happens. # _get_conn, where that check happens.
def _get_conn(self, timeout): def _get_conn(self, timeout: int) -> NpipeHTTPConnection:
conn = None conn = None
try: try:
conn = self.pool.get(block=self.block, timeout=timeout) conn = self.pool.get(block=self.block, timeout=timeout)
@ -67,7 +67,6 @@ class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class NpipeHTTPAdapter(BaseHTTPAdapter): class NpipeHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [ __attrs__ = HTTPAdapter.__attrs__ + [
"npipe_path", "npipe_path",
"pools", "pools",
@ -77,11 +76,11 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
base_url, base_url: str,
timeout=60, timeout: int = 60,
pool_connections=constants.DEFAULT_NUM_POOLS, pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
): ) -> None:
self.npipe_path = base_url.replace("npipe://", "") self.npipe_path = base_url.replace("npipe://", "")
self.timeout = timeout self.timeout = timeout
self.max_pool_size = max_pool_size self.max_pool_size = max_pool_size
@ -90,7 +89,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def get_connection(self, url, proxies=None): def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -103,7 +102,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies): def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -15,6 +15,7 @@ import functools
import io import io
import time import time
import traceback import traceback
import typing as t
PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name
@ -29,6 +30,13 @@ except ImportError:
else: else:
PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer, Callable
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
ERROR_PIPE_BUSY = 0xE7 ERROR_PIPE_BUSY = 0xE7
SECURITY_SQOS_PRESENT = 0x100000 SECURITY_SQOS_PRESENT = 0x100000
@ -37,10 +45,12 @@ SECURITY_ANONYMOUS = 0
MAXIMUM_RETRY_COUNT = 10 MAXIMUM_RETRY_COUNT = 10
def check_closed(f): def check_closed(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f) @functools.wraps(f)
def wrapped(self, *args, **kwargs): def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if self._closed: if self._closed: # type: ignore
raise RuntimeError("Can not reuse socket after connection was closed.") raise RuntimeError("Can not reuse socket after connection was closed.")
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
@ -54,25 +64,25 @@ class NpipeSocket:
implemented. implemented.
""" """
def __init__(self, handle=None): def __init__(self, handle=None) -> None:
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
self._handle = handle self._handle = handle
self._address = None self._address: str | None = None
self._closed = False self._closed = False
self.flags = None self.flags: int | None = None
def accept(self): def accept(self) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def bind(self, address): def bind(self, address) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def close(self): def close(self) -> None:
self._handle.Close() self._handle.Close()
self._closed = True self._closed = True
@check_closed @check_closed
def connect(self, address, retry_count=0): def connect(self, address, retry_count: int = 0) -> None:
try: try:
handle = win32file.CreateFile( handle = win32file.CreateFile(
address, address,
@ -100,14 +110,14 @@ class NpipeSocket:
return self.connect(address, retry_count) return self.connect(address, retry_count)
raise e raise e
self.flags = win32pipe.GetNamedPipeInfo(handle)[0] self.flags = win32pipe.GetNamedPipeInfo(handle)[0] # type: ignore
self._handle = handle self._handle = handle
self._address = address self._address = address
@check_closed @check_closed
def connect_ex(self, address): def connect_ex(self, address) -> None:
return self.connect(address) self.connect(address)
@check_closed @check_closed
def detach(self): def detach(self):
@ -115,25 +125,25 @@ class NpipeSocket:
return self._handle return self._handle
@check_closed @check_closed
def dup(self): def dup(self) -> NpipeSocket:
return NpipeSocket(self._handle) return NpipeSocket(self._handle)
def getpeername(self): def getpeername(self) -> str | None:
return self._address return self._address
def getsockname(self): def getsockname(self) -> str | None:
return self._address return self._address
def getsockopt(self, level, optname, buflen=None): def getsockopt(self, level, optname, buflen=None) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def ioctl(self, control, option): def ioctl(self, control, option) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def listen(self, backlog): def listen(self, backlog) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
def makefile(self, mode=None, bufsize=None): def makefile(self, mode: str, bufsize: int | None = None):
if mode.strip("b") != "r": if mode.strip("b") != "r":
raise NotImplementedError() raise NotImplementedError()
rawio = NpipeFileIOBase(self) rawio = NpipeFileIOBase(self)
@ -142,30 +152,30 @@ class NpipeSocket:
return io.BufferedReader(rawio, buffer_size=bufsize) return io.BufferedReader(rawio, buffer_size=bufsize)
@check_closed @check_closed
def recv(self, bufsize, flags=0): def recv(self, bufsize: int, flags: int = 0) -> str:
dummy_err, data = win32file.ReadFile(self._handle, bufsize) dummy_err, data = win32file.ReadFile(self._handle, bufsize)
return data return data
@check_closed @check_closed
def recvfrom(self, bufsize, flags=0): def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[str, str | None]:
data = self.recv(bufsize, flags) data = self.recv(bufsize, flags)
return (data, self._address) return (data, self._address)
@check_closed @check_closed
def recvfrom_into(self, buf, nbytes=0, flags=0): def recvfrom_into(
return self.recv_into(buf, nbytes, flags), self._address self, buf: Buffer, nbytes: int = 0, flags: int = 0
) -> tuple[int, str | None]:
return self.recv_into(buf, nbytes), self._address
@check_closed @check_closed
def recv_into(self, buf, nbytes=0): def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
readbuf = buf readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
if not isinstance(buf, memoryview):
readbuf = memoryview(buf)
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
try: try:
overlapped = pywintypes.OVERLAPPED() overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event overlapped.hEvent = event
dummy_err, dummy_data = win32file.ReadFile( dummy_err, dummy_data = win32file.ReadFile( # type: ignore
self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped
) )
wait_result = win32event.WaitForSingleObject(event, self._timeout) wait_result = win32event.WaitForSingleObject(event, self._timeout)
@ -177,12 +187,12 @@ class NpipeSocket:
win32api.CloseHandle(event) win32api.CloseHandle(event)
@check_closed @check_closed
def send(self, string, flags=0): def send(self, string: Buffer, flags: int = 0) -> int:
event = win32event.CreateEvent(None, True, True, None) event = win32event.CreateEvent(None, True, True, None)
try: try:
overlapped = pywintypes.OVERLAPPED() overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event overlapped.hEvent = event
win32file.WriteFile(self._handle, string, overlapped) win32file.WriteFile(self._handle, string, overlapped) # type: ignore
wait_result = win32event.WaitForSingleObject(event, self._timeout) wait_result = win32event.WaitForSingleObject(event, self._timeout)
if wait_result == win32event.WAIT_TIMEOUT: if wait_result == win32event.WAIT_TIMEOUT:
win32file.CancelIo(self._handle) win32file.CancelIo(self._handle)
@ -192,20 +202,20 @@ class NpipeSocket:
win32api.CloseHandle(event) win32api.CloseHandle(event)
@check_closed @check_closed
def sendall(self, string, flags=0): def sendall(self, string: Buffer, flags: int = 0) -> int:
return self.send(string, flags) return self.send(string, flags)
@check_closed @check_closed
def sendto(self, string, address): def sendto(self, string: Buffer, address: str) -> int:
self.connect(address) self.connect(address)
return self.send(string) return self.send(string)
def setblocking(self, flag): def setblocking(self, flag: bool):
if flag: if flag:
return self.settimeout(None) return self.settimeout(None)
return self.settimeout(0) return self.settimeout(0)
def settimeout(self, value): def settimeout(self, value: int | float | None) -> None:
if value is None: if value is None:
# Blocking mode # Blocking mode
self._timeout = win32event.INFINITE self._timeout = win32event.INFINITE
@ -215,39 +225,39 @@ class NpipeSocket:
# Timeout mode - Value converted to milliseconds # Timeout mode - Value converted to milliseconds
self._timeout = int(value * 1000) self._timeout = int(value * 1000)
def gettimeout(self): def gettimeout(self) -> int | float | None:
return self._timeout return self._timeout
def setsockopt(self, level, optname, value): def setsockopt(self, level, optname, value) -> t.NoReturn:
raise NotImplementedError() raise NotImplementedError()
@check_closed @check_closed
def shutdown(self, how): def shutdown(self, how) -> None:
return self.close() return self.close()
class NpipeFileIOBase(io.RawIOBase): class NpipeFileIOBase(io.RawIOBase):
def __init__(self, npipe_socket): def __init__(self, npipe_socket) -> None:
self.sock = npipe_socket self.sock = npipe_socket
def close(self): def close(self) -> None:
super().close() super().close()
self.sock = None self.sock = None
def fileno(self): def fileno(self) -> int:
return self.sock.fileno() return self.sock.fileno()
def isatty(self): def isatty(self) -> bool:
return False return False
def readable(self): def readable(self) -> bool:
return True return True
def readinto(self, buf): def readinto(self, buf: Buffer) -> int:
return self.sock.recv_into(buf) return self.sock.recv_into(buf)
def seekable(self): def seekable(self) -> bool:
return False return False
def writable(self): def writable(self) -> bool:
return False return False

View File

@ -17,6 +17,7 @@ import signal
import socket import socket
import subprocess import subprocess
import traceback import traceback
import typing as t
from queue import Empty from queue import Empty
from urllib.parse import urlparse from urllib.parse import urlparse
@ -33,12 +34,15 @@ except ImportError:
else: else:
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class SSHSocket(socket.socket): class SSHSocket(socket.socket):
def __init__(self, host): def __init__(self, host: str) -> None:
super().__init__(socket.AF_INET, socket.SOCK_STREAM) super().__init__(socket.AF_INET, socket.SOCK_STREAM)
self.host = host self.host = host
self.port = None self.port = None
@ -48,9 +52,9 @@ class SSHSocket(socket.socket):
if "@" in self.host: if "@" in self.host:
self.user, self.host = self.host.split("@") self.user, self.host = self.host.split("@")
self.proc = None self.proc: subprocess.Popen | None = None
def connect(self, **kwargs): def connect(self, *args_: t.Any, **kwargs: t.Any) -> None:
args = ["ssh"] args = ["ssh"]
if self.user: if self.user:
args = args + ["-l", self.user] args = args + ["-l", self.user]
@ -82,37 +86,48 @@ class SSHSocket(socket.socket):
preexec_fn=preexec_func, preexec_fn=preexec_func,
) )
def _write(self, data): def _write(self, data: Buffer) -> int:
if not self.proc or self.proc.stdin.closed: if not self.proc:
raise RuntimeError( raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first." "SSH subprocess not initiated. connect() must be called first."
) )
assert self.proc.stdin is not None
if self.proc.stdin.closed:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first after close()."
)
written = self.proc.stdin.write(data) written = self.proc.stdin.write(data)
self.proc.stdin.flush() self.proc.stdin.flush()
return written return written
def sendall(self, data): def sendall(self, data: Buffer, *args, **kwargs) -> None:
self._write(data) self._write(data)
def send(self, data): def send(self, data: Buffer, *args, **kwargs) -> int:
return self._write(data) return self._write(data)
def recv(self, n): def recv(self, n: int, *args, **kwargs) -> bytes:
if not self.proc: if not self.proc:
raise RuntimeError( raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first." "SSH subprocess not initiated. connect() must be called first."
) )
assert self.proc.stdout is not None
return self.proc.stdout.read(n) return self.proc.stdout.read(n)
def makefile(self, mode): def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore
if not self.proc: if not self.proc:
self.connect() self.connect()
self.proc.stdout.channel = self assert self.proc is not None
assert self.proc.stdout is not None
self.proc.stdout.channel = self # type: ignore
return self.proc.stdout return self.proc.stdout
def close(self): def close(self) -> None:
if not self.proc or self.proc.stdin.closed: if not self.proc:
return
assert self.proc.stdin is not None
if self.proc.stdin.closed:
return return
self.proc.stdin.write(b"\n\n") self.proc.stdin.write(b"\n\n")
self.proc.stdin.flush() self.proc.stdin.flush()
@ -120,13 +135,19 @@ class SSHSocket(socket.socket):
class SSHConnection(urllib3_connection.HTTPConnection): class SSHConnection(urllib3_connection.HTTPConnection):
def __init__(self, ssh_transport=None, timeout=60, host=None): def __init__(
self,
*,
ssh_transport=None,
timeout: int = 60,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.ssh_transport = ssh_transport self.ssh_transport = ssh_transport
self.timeout = timeout self.timeout = timeout
self.ssh_host = host self.ssh_host = host
def connect(self): def connect(self) -> None:
if self.ssh_transport: if self.ssh_transport:
sock = self.ssh_transport.open_session() sock = self.ssh_transport.open_session()
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
@ -142,7 +163,14 @@ class SSHConnection(urllib3_connection.HTTPConnection):
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
scheme = "ssh" scheme = "ssh"
def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None): def __init__(
self,
*,
ssh_client: paramiko.SSHClient | None = None,
timeout: int = 60,
maxsize: int = 10,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.ssh_transport = None self.ssh_transport = None
self.timeout = timeout self.timeout = timeout
@ -150,13 +178,17 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
self.ssh_transport = ssh_client.get_transport() self.ssh_transport = ssh_client.get_transport()
self.ssh_host = host self.ssh_host = host
def _new_conn(self): def _new_conn(self) -> SSHConnection:
return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host) return SSHConnection(
ssh_transport=self.ssh_transport,
timeout=self.timeout,
host=self.ssh_host,
)
# When re-using connections, urllib3 calls fileno() on our # When re-using connections, urllib3 calls fileno() on our
# SSH channel instance, quickly overloading our fd limit. To avoid this, # SSH channel instance, quickly overloading our fd limit. To avoid this,
# we override _get_conn # we override _get_conn
def _get_conn(self, timeout): def _get_conn(self, timeout: int) -> SSHConnection:
conn = None conn = None
try: try:
conn = self.pool.get(block=self.block, timeout=timeout) conn = self.pool.get(block=self.block, timeout=timeout)
@ -176,7 +208,6 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class SSHHTTPAdapter(BaseHTTPAdapter): class SSHHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [ __attrs__ = HTTPAdapter.__attrs__ + [
"pools", "pools",
"timeout", "timeout",
@ -187,13 +218,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
base_url, base_url: str,
timeout=60, timeout: int = 60,
pool_connections=constants.DEFAULT_NUM_POOLS, pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
shell_out=False, shell_out: bool = False,
): ) -> None:
self.ssh_client = None self.ssh_client: paramiko.SSHClient | None = None
if not shell_out: if not shell_out:
self._create_paramiko_client(base_url) self._create_paramiko_client(base_url)
self._connect() self._connect()
@ -209,30 +240,31 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def _create_paramiko_client(self, base_url): def _create_paramiko_client(self, base_url: str) -> None:
logging.getLogger("paramiko").setLevel(logging.WARNING) logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh_client = paramiko.SSHClient() self.ssh_client = paramiko.SSHClient()
base_url = urlparse(base_url) base_url_p = urlparse(base_url)
self.ssh_params = { assert base_url_p.hostname is not None
"hostname": base_url.hostname, self.ssh_params: dict[str, t.Any] = {
"port": base_url.port, "hostname": base_url_p.hostname,
"username": base_url.username, "port": base_url_p.port,
"username": base_url_p.username,
} }
ssh_config_file = os.path.expanduser("~/.ssh/config") ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file): if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig() conf = paramiko.SSHConfig()
with open(ssh_config_file, "rt", encoding="utf-8") as f: with open(ssh_config_file, "rt", encoding="utf-8") as f:
conf.parse(f) conf.parse(f)
host_config = conf.lookup(base_url.hostname) host_config = conf.lookup(base_url_p.hostname)
if "proxycommand" in host_config: if "proxycommand" in host_config:
self.ssh_params["sock"] = paramiko.ProxyCommand( self.ssh_params["sock"] = paramiko.ProxyCommand(
host_config["proxycommand"] host_config["proxycommand"]
) )
if "hostname" in host_config: if "hostname" in host_config:
self.ssh_params["hostname"] = host_config["hostname"] self.ssh_params["hostname"] = host_config["hostname"]
if base_url.port is None and "port" in host_config: if base_url_p.port is None and "port" in host_config:
self.ssh_params["port"] = host_config["port"] self.ssh_params["port"] = host_config["port"]
if base_url.username is None and "user" in host_config: if base_url_p.username is None and "user" in host_config:
self.ssh_params["username"] = host_config["user"] self.ssh_params["username"] = host_config["user"]
if "identityfile" in host_config: if "identityfile" in host_config:
self.ssh_params["key_filename"] = host_config["identityfile"] self.ssh_params["key_filename"] = host_config["identityfile"]
@ -240,11 +272,11 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
self.ssh_client.load_system_host_keys() self.ssh_client.load_system_host_keys()
self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy()) self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())
def _connect(self): def _connect(self) -> None:
if self.ssh_client: if self.ssh_client:
self.ssh_client.connect(**self.ssh_params) self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url, proxies=None): def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool:
if not self.ssh_client: if not self.ssh_client:
return SSHConnectionPool( return SSHConnectionPool(
ssh_client=self.ssh_client, ssh_client=self.ssh_client,
@ -271,7 +303,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def close(self): def close(self) -> None:
super().close() super().close()
if self.ssh_client: if self.ssh_client:
self.ssh_client.close() self.ssh_client.close()

View File

@ -11,9 +11,7 @@
from __future__ import annotations from __future__ import annotations
from ansible_collections.community.docker.plugins.module_utils._version import ( import typing as t
LooseVersion,
)
from .._import_helper import HTTPAdapter, urllib3 from .._import_helper import HTTPAdapter, urllib3
from .basehttpadapter import BaseHTTPAdapter from .basehttpadapter import BaseHTTPAdapter
@ -30,14 +28,19 @@ PoolManager = urllib3.poolmanager.PoolManager
class SSLHTTPAdapter(BaseHTTPAdapter): class SSLHTTPAdapter(BaseHTTPAdapter):
"""An HTTPS Transport Adapter that uses an arbitrary SSL version.""" """An HTTPS Transport Adapter that uses an arbitrary SSL version."""
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname", "ssl_version"] __attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname"]
def __init__(self, ssl_version=None, assert_hostname=None, **kwargs): def __init__(
self.ssl_version = ssl_version self,
assert_hostname: bool | None = None,
**kwargs,
) -> None:
self.assert_hostname = assert_hostname self.assert_hostname = assert_hostname
super().__init__(**kwargs) super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False): def init_poolmanager(
self, connections: int, maxsize: int, block: bool = False, **kwargs: t.Any
) -> None:
kwargs = { kwargs = {
"num_pools": connections, "num_pools": connections,
"maxsize": maxsize, "maxsize": maxsize,
@ -45,12 +48,10 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
} }
if self.assert_hostname is not None: if self.assert_hostname is not None:
kwargs["assert_hostname"] = self.assert_hostname kwargs["assert_hostname"] = self.assert_hostname
if self.ssl_version and self.can_override_ssl_version():
kwargs["ssl_version"] = self.ssl_version
self.poolmanager = PoolManager(**kwargs) self.poolmanager = PoolManager(**kwargs)
def get_connection(self, *args, **kwargs): def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool:
""" """
Ensure assert_hostname is set correctly on our pool Ensure assert_hostname is set correctly on our pool
@ -61,15 +62,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
conn = super().get_connection(*args, **kwargs) conn = super().get_connection(*args, **kwargs)
if ( if (
self.assert_hostname is not None self.assert_hostname is not None
and conn.assert_hostname != self.assert_hostname and conn.assert_hostname != self.assert_hostname # type: ignore
): ):
conn.assert_hostname = self.assert_hostname conn.assert_hostname = self.assert_hostname # type: ignore
return conn return conn
def can_override_ssl_version(self):
urllib_ver = urllib3.__version__.split("-")[0]
if urllib_ver is None:
return False
if urllib_ver == "dev":
return True
return LooseVersion(urllib_ver) > LooseVersion("1.5")

View File

@ -12,6 +12,7 @@
from __future__ import annotations from __future__ import annotations
import socket import socket
import typing as t
from .. import constants from .. import constants
from .._import_helper import HTTPAdapter, urllib3, urllib3_connection from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
@ -22,26 +23,25 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class UnixHTTPConnection(urllib3_connection.HTTPConnection): class UnixHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(self, base_url: str | bytes, unix_socket, timeout: int = 60) -> None:
def __init__(self, base_url, unix_socket, timeout=60):
super().__init__("localhost", timeout=timeout) super().__init__("localhost", timeout=timeout)
self.base_url = base_url self.base_url = base_url
self.unix_socket = unix_socket self.unix_socket = unix_socket
self.timeout = timeout self.timeout = timeout
self.disable_buffering = False self.disable_buffering = False
def connect(self): def connect(self) -> None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.connect(self.unix_socket) sock.connect(self.unix_socket)
self.sock = sock self.sock = sock
def putheader(self, header, *values): def putheader(self, header: str, *values: str) -> None:
super().putheader(header, *values) super().putheader(header, *values)
if header == "Connection" and "Upgrade" in values: if header == "Connection" and "Upgrade" in values:
self.disable_buffering = True self.disable_buffering = True
def response_class(self, sock, *args, **kwargs): def response_class(self, sock, *args, **kwargs) -> t.Any:
# FIXME: We may need to disable buffering on Py3, # FIXME: We may need to disable buffering on Py3,
# but there's no clear way to do it at the moment. See: # but there's no clear way to do it at the moment. See:
# https://github.com/docker/docker-py/issues/1799 # https://github.com/docker/docker-py/issues/1799
@ -49,18 +49,23 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, base_url, socket_path, timeout=60, maxsize=10): def __init__(
self,
base_url: str | bytes,
socket_path: str,
timeout: int = 60,
maxsize: int = 10,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize) super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.base_url = base_url self.base_url = base_url
self.socket_path = socket_path self.socket_path = socket_path
self.timeout = timeout self.timeout = timeout
def _new_conn(self): def _new_conn(self) -> UnixHTTPConnection:
return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout) return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout)
class UnixHTTPAdapter(BaseHTTPAdapter): class UnixHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [ __attrs__ = HTTPAdapter.__attrs__ + [
"pools", "pools",
"socket_path", "socket_path",
@ -70,11 +75,11 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
def __init__( def __init__(
self, self,
socket_url, socket_url: str,
timeout=60, timeout: int = 60,
pool_connections=constants.DEFAULT_NUM_POOLS, pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE, max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
): ) -> None:
socket_path = socket_url.replace("http+unix://", "") socket_path = socket_url.replace("http+unix://", "")
if not socket_path.startswith("/"): if not socket_path.startswith("/"):
socket_path = "/" + socket_path socket_path = "/" + socket_path
@ -86,7 +91,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
) )
super().__init__() super().__init__()
def get_connection(self, url, proxies=None): def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool:
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -99,7 +104,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
return pool return pool
def request_url(self, request, proxies): def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL # The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket. # does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets # Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -12,6 +12,7 @@
from __future__ import annotations from __future__ import annotations
import socket import socket
import typing as t
from .._import_helper import urllib3 from .._import_helper import urllib3
from ..errors import DockerException from ..errors import DockerException
@ -29,11 +30,11 @@ class CancellableStream:
>>> events.close() >>> events.close()
""" """
def __init__(self, stream, response): def __init__(self, stream, response) -> None:
self._stream = stream self._stream = stream
self._response = response self._response = response
def __iter__(self): def __iter__(self) -> t.Self:
return self return self
def __next__(self): def __next__(self):
@ -46,7 +47,7 @@ class CancellableStream:
next = __next__ next = __next__
def close(self): def close(self) -> None:
""" """
Closes the event streaming. Closes the event streaming.
""" """

View File

@ -17,15 +17,26 @@ import random
import re import re
import tarfile import tarfile
import tempfile import tempfile
import typing as t
from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX
from . import fnmatch from . import fnmatch
if t.TYPE_CHECKING:
from collections.abc import Sequence
_SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/") _SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/")
def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False): def tar(
path: str,
exclude: list[str] | None = None,
dockerfile: tuple[str, str] | tuple[None, None] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
) -> t.IO[bytes]:
root = os.path.abspath(path) root = os.path.abspath(path)
exclude = exclude or [] exclude = exclude or []
dockerfile = dockerfile or (None, None) dockerfile = dockerfile or (None, None)
@ -47,7 +58,9 @@ def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
) )
def exclude_paths(root, patterns, dockerfile=None): def exclude_paths(
root: str, patterns: list[str], dockerfile: str | None = None
) -> set[str]:
""" """
Given a root directory path and a list of .dockerignore patterns, return Given a root directory path and a list of .dockerignore patterns, return
an iterator of all paths (both regular files and directories) in the root an iterator of all paths (both regular files and directories) in the root
@ -64,7 +77,7 @@ def exclude_paths(root, patterns, dockerfile=None):
return set(pm.walk(root)) return set(pm.walk(root))
def build_file_list(root): def build_file_list(root: str) -> list[str]:
files = [] files = []
for dirname, dirnames, fnames in os.walk(root): for dirname, dirnames, fnames in os.walk(root):
for filename in fnames + dirnames: for filename in fnames + dirnames:
@ -74,7 +87,13 @@ def build_file_list(root):
return files return files
def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None): def create_archive(
root: str,
files: Sequence[str] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
extra_files: Sequence[tuple[str, str]] | None = None,
) -> t.IO[bytes]:
extra_files = extra_files or [] extra_files = extra_files or []
if not fileobj: if not fileobj:
fileobj = tempfile.NamedTemporaryFile() fileobj = tempfile.NamedTemporaryFile()
@ -92,7 +111,7 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
if i is None: if i is None:
# This happens when we encounter a socket file. We can safely # This happens when we encounter a socket file. We can safely
# ignore it and proceed. # ignore it and proceed.
continue continue # type: ignore
# Workaround https://bugs.python.org/issue32713 # Workaround https://bugs.python.org/issue32713
if i.mtime < 0 or i.mtime > 8**11 - 1: if i.mtime < 0 or i.mtime > 8**11 - 1:
@ -124,11 +143,11 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
return fileobj return fileobj
def mkbuildcontext(dockerfile): def mkbuildcontext(dockerfile: io.BytesIO | t.IO[bytes]) -> t.IO[bytes]:
f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with
try: try:
with tarfile.open(mode="w", fileobj=f) as t: with tarfile.open(mode="w", fileobj=f) as t:
if isinstance(dockerfile, io.StringIO): if isinstance(dockerfile, io.StringIO): # type: ignore
raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles") raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles")
if isinstance(dockerfile, io.BytesIO): if isinstance(dockerfile, io.BytesIO):
dfinfo = tarfile.TarInfo("Dockerfile") dfinfo = tarfile.TarInfo("Dockerfile")
@ -144,17 +163,17 @@ def mkbuildcontext(dockerfile):
return f return f
def split_path(p): def split_path(p: str) -> list[str]:
return [pt for pt in re.split(_SEP, p) if pt and pt != "."] return [pt for pt in re.split(_SEP, p) if pt and pt != "."]
def normalize_slashes(p): def normalize_slashes(p: str) -> str:
if IS_WINDOWS_PLATFORM: if IS_WINDOWS_PLATFORM:
return "/".join(split_path(p)) return "/".join(split_path(p))
return p return p
def walk(root, patterns, default=True): def walk(root: str, patterns: Sequence[str], default: bool = True) -> t.Generator[str]:
pm = PatternMatcher(patterns) pm = PatternMatcher(patterns)
return pm.walk(root) return pm.walk(root)
@ -162,11 +181,11 @@ def walk(root, patterns, default=True):
# Heavily based on # Heavily based on
# https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go # https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go
class PatternMatcher: class PatternMatcher:
def __init__(self, patterns): def __init__(self, patterns: Sequence[str]) -> None:
self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns])) self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns]))
self.patterns.append(Pattern("!.dockerignore")) self.patterns.append(Pattern("!.dockerignore"))
def matches(self, filepath): def matches(self, filepath: str) -> bool:
matched = False matched = False
parent_path = os.path.dirname(filepath) parent_path = os.path.dirname(filepath)
parent_path_dirs = split_path(parent_path) parent_path_dirs = split_path(parent_path)
@ -185,8 +204,8 @@ class PatternMatcher:
return matched return matched
def walk(self, root): def walk(self, root: str) -> t.Generator[str]:
def rec_walk(current_dir): def rec_walk(current_dir: str) -> t.Generator[str]:
for f in os.listdir(current_dir): for f in os.listdir(current_dir):
fpath = os.path.join(os.path.relpath(current_dir, root), f) fpath = os.path.join(os.path.relpath(current_dir, root), f)
if fpath.startswith("." + os.path.sep): if fpath.startswith("." + os.path.sep):
@ -220,7 +239,7 @@ class PatternMatcher:
class Pattern: class Pattern:
def __init__(self, pattern_str): def __init__(self, pattern_str: str) -> None:
self.exclusion = False self.exclusion = False
if pattern_str.startswith("!"): if pattern_str.startswith("!"):
self.exclusion = True self.exclusion = True
@ -230,8 +249,7 @@ class Pattern:
self.cleaned_pattern = "/".join(self.dirs) self.cleaned_pattern = "/".join(self.dirs)
@classmethod @classmethod
def normalize(cls, p): def normalize(cls, p: str) -> list[str]:
# Remove trailing spaces # Remove trailing spaces
p = p.strip() p = p.strip()
@ -256,11 +274,11 @@ class Pattern:
i += 1 i += 1
return split return split
def match(self, filepath): def match(self, filepath: str) -> bool:
return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern) return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern)
def process_dockerfile(dockerfile, path): def process_dockerfile(dockerfile: str, path: str) -> tuple[str | None, str | None]:
if not dockerfile: if not dockerfile:
return (None, None) return (None, None)
@ -268,7 +286,7 @@ def process_dockerfile(dockerfile, path):
if not os.path.isabs(dockerfile): if not os.path.isabs(dockerfile):
abs_dockerfile = os.path.join(path, dockerfile) abs_dockerfile = os.path.join(path, dockerfile)
if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX): if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX):
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX):])}" abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX) :])}"
if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[ if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[
0 0
] or os.path.relpath(abs_dockerfile, path).startswith(".."): ] or os.path.relpath(abs_dockerfile, path).startswith(".."):

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import json import json
import logging import logging
import os import os
import typing as t
from ..constants import IS_WINDOWS_PLATFORM from ..constants import IS_WINDOWS_PLATFORM
@ -24,11 +25,11 @@ LEGACY_DOCKER_CONFIG_FILENAME = ".dockercfg"
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def get_default_config_file(): def get_default_config_file() -> str:
return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME) return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME)
def find_config_file(config_path=None): def find_config_file(config_path: str | None = None) -> str | None:
homedir = home_dir() homedir = home_dir()
paths = list( paths = list(
filter( filter(
@ -54,14 +55,14 @@ def find_config_file(config_path=None):
return None return None
def config_path_from_environment(): def config_path_from_environment() -> str | None:
config_dir = os.environ.get("DOCKER_CONFIG") config_dir = os.environ.get("DOCKER_CONFIG")
if not config_dir: if not config_dir:
return None return None
return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME)) return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME))
def home_dir(): def home_dir() -> str:
""" """
Get the user's home directory, using the same logic as the Docker Engine Get the user's home directory, using the same logic as the Docker Engine
client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX. client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX.
@ -71,7 +72,7 @@ def home_dir():
return os.path.expanduser("~") return os.path.expanduser("~")
def load_general_config(config_path=None): def load_general_config(config_path: str | None = None) -> dict[str, t.Any]:
config_file = find_config_file(config_path) config_file = find_config_file(config_path)
if not config_file: if not config_file:

View File

@ -12,16 +12,37 @@
from __future__ import annotations from __future__ import annotations
import functools import functools
import typing as t
from .. import errors from .. import errors
from . import utils from . import utils
def minimum_version(version): if t.TYPE_CHECKING:
def decorator(f): from collections.abc import Callable
from ..api.client import APIClient
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
def minimum_version(
version: str,
) -> Callable[
[Callable[t.Concatenate[_Self, _P], _R]],
Callable[t.Concatenate[_Self, _P], _R],
]:
def decorator(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f) @functools.wraps(f)
def wrapper(self, *args, **kwargs): def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if utils.version_lt(self._version, version): # We use _Self instead of APIClient since this is used for mixins for APIClient.
# This unfortunately means that self._version does not exist in the mixin,
# it only exists after mixing in. This is why we ignore types here.
if utils.version_lt(self._version, version): # type: ignore
raise errors.InvalidVersion( raise errors.InvalidVersion(
f"{f.__name__} is not available for version < {version}" f"{f.__name__} is not available for version < {version}"
) )
@ -32,13 +53,16 @@ def minimum_version(version):
return decorator return decorator
def update_headers(f): def update_headers(
def inner(self, *args, **kwargs): f: Callable[t.Concatenate[APIClient, _P], _R],
) -> Callable[t.Concatenate[APIClient, _P], _R]:
def inner(self: APIClient, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if "HttpHeaders" in self._general_configs: if "HttpHeaders" in self._general_configs:
if not kwargs.get("headers"): if not kwargs.get("headers"):
kwargs["headers"] = self._general_configs["HttpHeaders"] kwargs["headers"] = self._general_configs["HttpHeaders"]
else: else:
kwargs["headers"].update(self._general_configs["HttpHeaders"]) # We cannot (yet) model that kwargs["headers"] should be a dictionary
kwargs["headers"].update(self._general_configs["HttpHeaders"]) # type: ignore
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
return inner return inner

View File

@ -32,12 +32,12 @@ _cache: dict[str, re.Pattern] = {}
_MAXCACHE = 100 _MAXCACHE = 100
def _purge(): def _purge() -> None:
"""Clear the pattern cache""" """Clear the pattern cache"""
_cache.clear() _cache.clear()
def fnmatch(name, pat): def fnmatch(name: str, pat: str):
"""Test whether FILENAME matches PATTERN. """Test whether FILENAME matches PATTERN.
Patterns are Unix shell style: Patterns are Unix shell style:
@ -58,7 +58,7 @@ def fnmatch(name, pat):
return fnmatchcase(name, pat) return fnmatchcase(name, pat)
def fnmatchcase(name, pat): def fnmatchcase(name: str, pat: str) -> bool:
"""Test whether FILENAME matches PATTERN, including case. """Test whether FILENAME matches PATTERN, including case.
This is a version of fnmatch() which does not case-normalize This is a version of fnmatch() which does not case-normalize
its arguments. its arguments.
@ -74,7 +74,7 @@ def fnmatchcase(name, pat):
return re_pat.match(name) is not None return re_pat.match(name) is not None
def translate(pat): def translate(pat: str) -> str:
"""Translate a shell PATTERN to a regular expression. """Translate a shell PATTERN to a regular expression.
There is no way to quote meta-characters. There is no way to quote meta-characters.

View File

@ -13,14 +13,22 @@ from __future__ import annotations
import json import json
import json.decoder import json.decoder
import typing as t
from ..errors import StreamParseError from ..errors import StreamParseError
if t.TYPE_CHECKING:
import re
from collections.abc import Callable
_T = t.TypeVar("_T")
json_decoder = json.JSONDecoder() json_decoder = json.JSONDecoder()
def stream_as_text(stream): def stream_as_text(stream: t.Generator[bytes | str]) -> t.Generator[str]:
""" """
Given a stream of bytes or text, if any of the items in the stream Given a stream of bytes or text, if any of the items in the stream
are bytes convert them to text. are bytes convert them to text.
@ -33,20 +41,22 @@ def stream_as_text(stream):
yield data yield data
def json_splitter(buffer): def json_splitter(buffer: str) -> tuple[t.Any, str] | None:
"""Attempt to parse a json object from a buffer. If there is at least one """Attempt to parse a json object from a buffer. If there is at least one
object, return it and the rest of the buffer, otherwise return None. object, return it and the rest of the buffer, otherwise return None.
""" """
buffer = buffer.strip() buffer = buffer.strip()
try: try:
obj, index = json_decoder.raw_decode(buffer) obj, index = json_decoder.raw_decode(buffer)
rest = buffer[json.decoder.WHITESPACE.match(buffer, index).end() :] ws: re.Pattern = json.decoder.WHITESPACE # type: ignore[attr-defined]
m = ws.match(buffer, index)
rest = buffer[m.end() :] if m else buffer[index:]
return obj, rest return obj, rest
except ValueError: except ValueError:
return None return None
def json_stream(stream): def json_stream(stream: t.Generator[str | bytes]) -> t.Generator[t.Any]:
"""Given a stream of text, return a stream of json objects. """Given a stream of text, return a stream of json objects.
This handles streams which are inconsistently buffered (some entries may This handles streams which are inconsistently buffered (some entries may
be newline delimited, and others are not). be newline delimited, and others are not).
@ -54,21 +64,24 @@ def json_stream(stream):
return split_buffer(stream, json_splitter, json_decoder.decode) return split_buffer(stream, json_splitter, json_decoder.decode)
def line_splitter(buffer, separator="\n"): def line_splitter(buffer: str, separator: str = "\n") -> tuple[str, str] | None:
index = buffer.find(str(separator)) index = buffer.find(str(separator))
if index == -1: if index == -1:
return None return None
return buffer[: index + 1], buffer[index + 1 :] return buffer[: index + 1], buffer[index + 1 :]
def split_buffer(stream, splitter=None, decoder=lambda a: a): def split_buffer(
stream: t.Generator[str | bytes],
splitter: Callable[[str], tuple[_T, str] | None],
decoder: Callable[[str], _T],
) -> t.Generator[_T | str]:
"""Given a generator which yields strings and a splitter function, """Given a generator which yields strings and a splitter function,
joins all input, splits on the separator and yields each chunk. joins all input, splits on the separator and yields each chunk.
Unlike string.split(), each chunk includes the trailing Unlike string.split(), each chunk includes the trailing
separator, except for the last one if none was found on the end separator, except for the last one if none was found on the end
of the input. of the input.
""" """
splitter = splitter or line_splitter
buffered = "" buffered = ""
for data in stream_as_text(stream): for data in stream_as_text(stream):

View File

@ -12,6 +12,11 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
if t.TYPE_CHECKING:
from collections.abc import Collection, Sequence
PORT_SPEC = re.compile( PORT_SPEC = re.compile(
@ -26,32 +31,42 @@ PORT_SPEC = re.compile(
) )
def add_port_mapping(port_bindings, internal_port, external): def add_port_mapping(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port: str,
external: str | tuple[str, str | None] | None,
) -> None:
if internal_port in port_bindings: if internal_port in port_bindings:
port_bindings[internal_port].append(external) port_bindings[internal_port].append(external)
else: else:
port_bindings[internal_port] = [external] port_bindings[internal_port] = [external]
def add_port(port_bindings, internal_port_range, external_range): def add_port(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port_range: list[str],
external_range: list[str] | list[tuple[str, str | None]] | None,
) -> None:
if external_range is None: if external_range is None:
for internal_port in internal_port_range: for internal_port in internal_port_range:
add_port_mapping(port_bindings, internal_port, None) add_port_mapping(port_bindings, internal_port, None)
else: else:
ports = zip(internal_port_range, external_range) for internal_port, external_port in zip(internal_port_range, external_range):
for internal_port, external_port in ports: # mypy loses the exact type of eternal_port elements for some reason...
add_port_mapping(port_bindings, internal_port, external_port) add_port_mapping(port_bindings, internal_port, external_port) # type: ignore
def build_port_bindings(ports): def build_port_bindings(
port_bindings = {} ports: Collection[str],
) -> dict[str, list[str | tuple[str, str | None] | None]]:
port_bindings: dict[str, list[str | tuple[str, str | None] | None]] = {}
for port in ports: for port in ports:
internal_port_range, external_range = split_port(port) internal_port_range, external_range = split_port(port)
add_port(port_bindings, internal_port_range, external_range) add_port(port_bindings, internal_port_range, external_range)
return port_bindings return port_bindings
def _raise_invalid_port(port): def _raise_invalid_port(port: str) -> t.NoReturn:
raise ValueError( raise ValueError(
f'Invalid port "{port}", should be ' f'Invalid port "{port}", should be '
"[[remote_ip:]remote_port[-remote_port]:]" "[[remote_ip:]remote_port[-remote_port]:]"
@ -59,39 +74,64 @@ def _raise_invalid_port(port):
) )
def port_range(start, end, proto, randomly_available_port=False): @t.overload
if not start: def port_range(
start: str,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str]: ...
@t.overload
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None: ...
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None:
if start is None:
return start return start
if not end: if end is None:
return [f"{start}{proto}"] return [f"{start}{proto}"]
if randomly_available_port: if randomly_available_port:
return [f"{start}-{end}{proto}"] return [f"{start}-{end}{proto}"]
return [f"{port}{proto}" for port in range(int(start), int(end) + 1)] return [f"{port}{proto}" for port in range(int(start), int(end) + 1)]
def split_port(port): def split_port(
if hasattr(port, "legacy_repr"): port: str,
# This is the worst hack, but it prevents a bug in Compose 1.14.0 ) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]:
# https://github.com/docker/docker-py/issues/1668
# TODO: remove once fixed in Compose stable
port = port.legacy_repr()
port = str(port) port = str(port)
match = PORT_SPEC.match(port) match = PORT_SPEC.match(port)
if match is None: if match is None:
_raise_invalid_port(port) _raise_invalid_port(port)
parts = match.groupdict() parts = match.groupdict()
host = parts["host"] host: str | None = parts["host"]
proto = parts["proto"] or "" proto: str = parts["proto"] or ""
internal = port_range(parts["int"], parts["int_end"], proto) int_p: str = parts["int"]
external = port_range(parts["ext"], parts["ext_end"], "", len(internal) == 1) ext_p: str | None = parts["ext"] or None
internal: list[str] = port_range(int_p, parts["int_end"], proto) # type: ignore
external = port_range(ext_p, parts["ext_end"], "", len(internal) == 1)
if host is None: if host is None:
if external is not None and len(internal) != len(external): if external is not None and len(internal) != len(external):
raise ValueError("Port ranges don't match in length") raise ValueError("Port ranges don't match in length")
return internal, external return internal, external
external_or_none: Sequence[str | None]
if not external: if not external:
external = [None] * len(internal) external_or_none = [None] * len(internal)
elif len(internal) != len(external): else:
raise ValueError("Port ranges don't match in length") external_or_none = external
return internal, [(host, ext_port) for ext_port in external] if len(internal) != len(external_or_none):
raise ValueError("Port ranges don't match in length")
return internal, [(host, ext_port) for ext_port in external_or_none]

View File

@ -20,23 +20,23 @@ class ProxyConfig(dict):
""" """
@property @property
def http(self): def http(self) -> str | None:
return self.get("http") return self.get("http")
@property @property
def https(self): def https(self) -> str | None:
return self.get("https") return self.get("https")
@property @property
def ftp(self): def ftp(self) -> str | None:
return self.get("ftp") return self.get("ftp")
@property @property
def no_proxy(self): def no_proxy(self) -> str | None:
return self.get("no_proxy") return self.get("no_proxy")
@staticmethod @staticmethod
def from_dict(config): def from_dict(config: dict[str, str]) -> ProxyConfig:
""" """
Instantiate a new ProxyConfig from a dictionary that represents a Instantiate a new ProxyConfig from a dictionary that represents a
client configuration, as described in `the documentation`_. client configuration, as described in `the documentation`_.
@ -51,7 +51,7 @@ class ProxyConfig(dict):
no_proxy=config.get("noProxy"), no_proxy=config.get("noProxy"),
) )
def get_environment(self): def get_environment(self) -> dict[str, str]:
""" """
Return a dictionary representing the environment variables used to Return a dictionary representing the environment variables used to
set the proxy settings. set the proxy settings.
@ -67,7 +67,7 @@ class ProxyConfig(dict):
env["no_proxy"] = env["NO_PROXY"] = self.no_proxy env["no_proxy"] = env["NO_PROXY"] = self.no_proxy
return env return env
def inject_proxy_environment(self, environment): def inject_proxy_environment(self, environment: list[str]) -> list[str]:
""" """
Given a list of strings representing environment variables, prepend the Given a list of strings representing environment variables, prepend the
environment variables corresponding to the proxy settings. environment variables corresponding to the proxy settings.
@ -82,5 +82,5 @@ class ProxyConfig(dict):
# variables defined in "environment" to take precedence. # variables defined in "environment" to take precedence.
return proxy_env + environment return proxy_env + environment
def __str__(self): def __str__(self) -> str:
return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})" return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})"

View File

@ -16,10 +16,15 @@ import os
import select import select
import socket as pysocket import socket as pysocket
import struct import struct
import typing as t
from ..transport.npipesocket import NpipeSocket from ..transport.npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Iterable
STDOUT = 1 STDOUT = 1
STDERR = 2 STDERR = 2
@ -33,7 +38,7 @@ class SocketError(Exception):
NPIPE_ENDED = 109 NPIPE_ENDED = 109
def read(socket, n=4096): def read(socket, n: int = 4096) -> bytes | None:
""" """
Reads at most n bytes from socket Reads at most n bytes from socket
""" """
@ -58,6 +63,7 @@ def read(socket, n=4096):
except EnvironmentError as e: except EnvironmentError as e:
if e.errno not in recoverable_errors: if e.errno not in recoverable_errors:
raise raise
return None # TODO ???
except Exception as e: except Exception as e:
is_pipe_ended = ( is_pipe_ended = (
isinstance(socket, NpipeSocket) isinstance(socket, NpipeSocket)
@ -67,11 +73,11 @@ def read(socket, n=4096):
if is_pipe_ended: if is_pipe_ended:
# npipes do not support duplex sockets, so we interpret # npipes do not support duplex sockets, so we interpret
# a PIPE_ENDED error as a close operation (0-length read). # a PIPE_ENDED error as a close operation (0-length read).
return "" return b""
raise raise
def read_exactly(socket, n): def read_exactly(socket, n: int) -> bytes:
""" """
Reads exactly n bytes from socket Reads exactly n bytes from socket
Raises SocketError if there is not enough data Raises SocketError if there is not enough data
@ -85,7 +91,7 @@ def read_exactly(socket, n):
return data return data
def next_frame_header(socket): def next_frame_header(socket) -> tuple[int, int]:
""" """
Returns the stream and size of the next frame of data waiting to be read Returns the stream and size of the next frame of data waiting to be read
from socket, according to the protocol defined here: from socket, according to the protocol defined here:
@ -101,7 +107,7 @@ def next_frame_header(socket):
return (stream, actual) return (stream, actual)
def frames_iter(socket, tty): def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
""" """
Return a generator of frames read from socket. A frame is a tuple where Return a generator of frames read from socket. A frame is a tuple where
the first item is the stream number and the second item is a chunk of data. the first item is the stream number and the second item is a chunk of data.
@ -114,7 +120,7 @@ def frames_iter(socket, tty):
return frames_iter_no_tty(socket) return frames_iter_no_tty(socket)
def frames_iter_no_tty(socket): def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
""" """
Returns a generator of data read from the socket when the tty setting is Returns a generator of data read from the socket when the tty setting is
not enabled. not enabled.
@ -135,20 +141,34 @@ def frames_iter_no_tty(socket):
yield (stream, result) yield (stream, result)
def frames_iter_tty(socket): def frames_iter_tty(socket) -> t.Generator[bytes]:
""" """
Return a generator of data read from the socket when the tty setting is Return a generator of data read from the socket when the tty setting is
enabled. enabled.
""" """
while True: while True:
result = read(socket) result = read(socket)
if len(result) == 0: if not result:
# We have reached EOF # We have reached EOF
return return
yield result yield result
def consume_socket_output(frames, demux=False): @t.overload
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ...
@t.overload
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
@t.overload
def consume_socket_output(
frames, demux: bool = False
) -> bytes | tuple[bytes, bytes]: ...
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]:
""" """
Iterate through frames read from the socket and return the result. Iterate through frames read from the socket and return the result.
@ -167,7 +187,7 @@ def consume_socket_output(frames, demux=False):
# If the streams are demultiplexed, the generator yields tuples # If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr) # (stdout, stderr)
out = [None, None] out: list[bytes | None] = [None, None]
for frame in frames: for frame in frames:
# It is guaranteed that for each frame, one and only one stream # It is guaranteed that for each frame, one and only one stream
# is not None. # is not None.
@ -183,10 +203,10 @@ def consume_socket_output(frames, demux=False):
out[1] = frame[1] out[1] = frame[1]
else: else:
out[1] += frame[1] out[1] += frame[1]
return tuple(out) return tuple(out) # type: ignore
def demux_adaptor(stream_id, data): def demux_adaptor(stream_id: int, data: bytes) -> tuple[bytes | None, bytes | None]:
""" """
Utility to demultiplex stdout and stderr when reading frames from the Utility to demultiplex stdout and stderr when reading frames from the
socket. socket.

View File

@ -18,6 +18,7 @@ import os
import os.path import os.path
import shlex import shlex
import string import string
import typing as t
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from ansible_collections.community.docker.plugins.module_utils._version import ( from ansible_collections.community.docker.plugins.module_utils._version import (
@ -34,32 +35,23 @@ from ..constants import (
from ..tls import TLSConfig from ..tls import TLSConfig
if t.TYPE_CHECKING:
import ssl
from collections.abc import Mapping, Sequence
URLComponents = collections.namedtuple( URLComponents = collections.namedtuple(
"URLComponents", "URLComponents",
"scheme netloc url params query fragment", "scheme netloc url params query fragment",
) )
def create_ipam_pool(*args, **kwargs): def decode_json_header(header: str) -> dict[str, t.Any]:
raise errors.DeprecatedMethod(
"utils.create_ipam_pool has been removed. Please use a "
"docker.types.IPAMPool object instead."
)
def create_ipam_config(*args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_ipam_config has been removed. Please use a "
"docker.types.IPAMConfig object instead."
)
def decode_json_header(header):
data = base64.b64decode(header).decode("utf-8") data = base64.b64decode(header).decode("utf-8")
return json.loads(data) return json.loads(data)
def compare_version(v1, v2): def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]:
"""Compare docker versions """Compare docker versions
>>> v1 = '1.9' >>> v1 = '1.9'
@ -80,43 +72,64 @@ def compare_version(v1, v2):
return 1 return 1
def version_lt(v1, v2): def version_lt(v1: str, v2: str) -> bool:
return compare_version(v1, v2) > 0 return compare_version(v1, v2) > 0
def version_gte(v1, v2): def version_gte(v1: str, v2: str) -> bool:
return not version_lt(v1, v2) return not version_lt(v1, v2)
def _convert_port_binding(binding): def _convert_port_binding(
binding: (
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
),
) -> dict[str, str]:
result = {"HostIp": "", "HostPort": ""} result = {"HostIp": "", "HostPort": ""}
host_port: str | int | None = ""
if isinstance(binding, tuple): if isinstance(binding, tuple):
if len(binding) == 2: if len(binding) == 2:
result["HostPort"] = binding[1] host_port = binding[1] # type: ignore
result["HostIp"] = binding[0] result["HostIp"] = binding[0]
elif isinstance(binding[0], str): elif isinstance(binding[0], str):
result["HostIp"] = binding[0] result["HostIp"] = binding[0]
else: else:
result["HostPort"] = binding[0] host_port = binding[0]
elif isinstance(binding, dict): elif isinstance(binding, dict):
if "HostPort" in binding: if "HostPort" in binding:
result["HostPort"] = binding["HostPort"] host_port = binding["HostPort"]
if "HostIp" in binding: if "HostIp" in binding:
result["HostIp"] = binding["HostIp"] result["HostIp"] = binding["HostIp"]
else: else:
raise ValueError(binding) raise ValueError(binding)
else: else:
result["HostPort"] = binding host_port = binding
if result["HostPort"] is None:
result["HostPort"] = ""
else:
result["HostPort"] = str(result["HostPort"])
result["HostPort"] = str(host_port) if host_port is not None else ""
return result return result
def convert_port_bindings(port_bindings): def convert_port_bindings(
port_bindings: dict[
str | int,
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
| list[
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
],
],
) -> dict[str, list[dict[str, str]]]:
result = {} result = {}
for k, v in port_bindings.items(): for k, v in port_bindings.items():
key = str(k) key = str(k)
@ -129,9 +142,11 @@ def convert_port_bindings(port_bindings):
return result return result
def convert_volume_binds(binds): def convert_volume_binds(
binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int],
) -> list[str]:
if isinstance(binds, list): if isinstance(binds, list):
return binds return binds # type: ignore
result = [] result = []
for k, v in binds.items(): for k, v in binds.items():
@ -149,7 +164,7 @@ def convert_volume_binds(binds):
if "ro" in v: if "ro" in v:
mode = "ro" if v["ro"] else "rw" mode = "ro" if v["ro"] else "rw"
elif "mode" in v: elif "mode" in v:
mode = v["mode"] mode = v["mode"] # type: ignore # TODO
else: else:
mode = "rw" mode = "rw"
@ -165,9 +180,9 @@ def convert_volume_binds(binds):
] ]
if "propagation" in v and v["propagation"] in propagation_modes: if "propagation" in v and v["propagation"] in propagation_modes:
if mode: if mode:
mode = ",".join([mode, v["propagation"]]) mode = ",".join([mode, v["propagation"]]) # type: ignore # TODO
else: else:
mode = v["propagation"] mode = v["propagation"] # type: ignore # TODO
result.append(f"{k}:{bind}:{mode}") result.append(f"{k}:{bind}:{mode}")
else: else:
@ -177,7 +192,7 @@ def convert_volume_binds(binds):
return result return result
def convert_tmpfs_mounts(tmpfs): def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]:
if isinstance(tmpfs, dict): if isinstance(tmpfs, dict):
return tmpfs return tmpfs
@ -204,9 +219,11 @@ def convert_tmpfs_mounts(tmpfs):
return result return result
def convert_service_networks(networks): def convert_service_networks(
networks: list[str | dict[str, str]],
) -> list[dict[str, str]]:
if not networks: if not networks:
return networks return networks # type: ignore
if not isinstance(networks, list): if not isinstance(networks, list):
raise TypeError("networks parameter must be a list.") raise TypeError("networks parameter must be a list.")
@ -218,17 +235,17 @@ def convert_service_networks(networks):
return result return result
def parse_repository_tag(repo_name): def parse_repository_tag(repo_name: str) -> tuple[str, str | None]:
parts = repo_name.rsplit("@", 1) parts = repo_name.rsplit("@", 1)
if len(parts) == 2: if len(parts) == 2:
return tuple(parts) return tuple(parts) # type: ignore
parts = repo_name.rsplit(":", 1) parts = repo_name.rsplit(":", 1)
if len(parts) == 2 and "/" not in parts[1]: if len(parts) == 2 and "/" not in parts[1]:
return tuple(parts) return tuple(parts) # type: ignore
return repo_name, None return repo_name, None
def parse_host(addr, is_win32=False, tls=False): def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str:
# Sensible defaults # Sensible defaults
if not addr and is_win32: if not addr and is_win32:
return DEFAULT_NPIPE return DEFAULT_NPIPE
@ -308,7 +325,7 @@ def parse_host(addr, is_win32=False, tls=False):
).rstrip("/") ).rstrip("/")
def parse_devices(devices): def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]:
device_list = [] device_list = []
for device in devices: for device in devices:
if isinstance(device, dict): if isinstance(device, dict):
@ -337,7 +354,10 @@ def parse_devices(devices):
return device_list return device_list
def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None): def kwargs_from_env(
assert_hostname: bool | None = None,
environment: Mapping[str, str] | None = None,
) -> dict[str, t.Any]:
if not environment: if not environment:
environment = os.environ environment = os.environ
host = environment.get("DOCKER_HOST") host = environment.get("DOCKER_HOST")
@ -347,14 +367,14 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
# empty string for tls verify counts as "false". # empty string for tls verify counts as "false".
# Any value or 'unset' counts as true. # Any value or 'unset' counts as true.
tls_verify = environment.get("DOCKER_TLS_VERIFY") tls_verify_str = environment.get("DOCKER_TLS_VERIFY")
if tls_verify == "": if tls_verify_str == "":
tls_verify = False tls_verify = False
else: else:
tls_verify = tls_verify is not None tls_verify = tls_verify_str is not None
enable_tls = cert_path or tls_verify enable_tls = cert_path or tls_verify
params = {} params: dict[str, t.Any] = {}
if host: if host:
params["base_url"] = host params["base_url"] = host
@ -377,14 +397,13 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
), ),
ca_cert=os.path.join(cert_path, "ca.pem"), ca_cert=os.path.join(cert_path, "ca.pem"),
verify=tls_verify, verify=tls_verify,
ssl_version=ssl_version,
assert_hostname=assert_hostname, assert_hostname=assert_hostname,
) )
return params return params
def convert_filters(filters): def convert_filters(filters: dict[str, bool | str | list[str]]) -> str:
result = {} result = {}
for k, v in filters.items(): for k, v in filters.items():
if isinstance(v, bool): if isinstance(v, bool):
@ -397,7 +416,7 @@ def convert_filters(filters):
return json.dumps(result) return json.dumps(result)
def parse_bytes(s): def parse_bytes(s: int | float | str) -> int | float:
if isinstance(s, (int, float)): if isinstance(s, (int, float)):
return s return s
if len(s) == 0: if len(s) == 0:
@ -435,14 +454,16 @@ def parse_bytes(s):
return s return s
def normalize_links(links): def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]:
if isinstance(links, dict): if isinstance(links, dict):
links = links.items() sorted_links = sorted(links.items())
else:
sorted_links = sorted(links)
return [f"{k}:{v}" if v else k for k, v in sorted(links)] return [f"{k}:{v}" if v else k for k, v in sorted_links]
def parse_env_file(env_file): def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]:
""" """
Reads a line-separated environment file. Reads a line-separated environment file.
The format of each line should be "key=value". The format of each line should be "key=value".
@ -451,7 +472,6 @@ def parse_env_file(env_file):
with open(env_file, "rt", encoding="utf-8") as f: with open(env_file, "rt", encoding="utf-8") as f:
for line in f: for line in f:
if line[0] == "#": if line[0] == "#":
continue continue
@ -471,11 +491,11 @@ def parse_env_file(env_file):
return environment return environment
def split_command(command): def split_command(command: str) -> list[str]:
return shlex.split(command) return shlex.split(command)
def format_environment(environment): def format_environment(environment: Mapping[str, str | bytes]) -> list[str]:
def format_env(key, value): def format_env(key, value):
if value is None: if value is None:
return key return key
@ -487,16 +507,9 @@ def format_environment(environment):
return [format_env(*var) for var in environment.items()] return [format_env(*var) for var in environment.items()]
def format_extra_hosts(extra_hosts, task=False): def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]:
# Use format dictated by Swarm API if container is part of a task # Use format dictated by Swarm API if container is part of a task
if task: if task:
return [f"{v} {k}" for k, v in sorted(extra_hosts.items())] return [f"{v} {k}" for k, v in sorted(extra_hosts.items())]
return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())] return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())]
def create_host_config(self, *args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_host_config has been removed. Please use a "
"docker.types.HostConfig object instead."
)

View File

@ -552,7 +552,7 @@ def _execute_command(
result = client.get_json("/exec/{0}/json", exec_id) result = client.get_json("/exec/{0}/json", exec_id)
rc = result.get("ExitCode") or 0 rc: int = result.get("ExitCode") or 0
stdout = stdout or b"" stdout = stdout or b""
stderr = stderr or b"" stderr = stderr or b""