Add more type hints.

This commit is contained in:
Felix Fontein 2025-10-21 06:25:30 +02:00
parent 0ff66e8b24
commit 08960a9317
28 changed files with 1104 additions and 650 deletions

View File

@ -13,8 +13,9 @@ from __future__ import annotations
import json
import logging
import os
import struct
from functools import partial
import typing as t
from urllib.parse import quote
from .. import auth
@ -50,13 +51,13 @@ from ..utils import config, json_stream, utils
from ..utils.decorators import update_headers
from ..utils.proxy import ProxyConfig
from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
from .daemon import DaemonApiMixin
from ..utils.decorators import minimum_version
log = logging.getLogger(__name__)
class APIClient(_Session, DaemonApiMixin):
class APIClient(_Session):
"""
A low-level client for the Docker Engine API.
@ -105,16 +106,16 @@ class APIClient(_Session, DaemonApiMixin):
def __init__(
self,
base_url=None,
version=None,
timeout=DEFAULT_TIMEOUT_SECONDS,
tls=False,
user_agent=DEFAULT_USER_AGENT,
num_pools=None,
credstore_env=None,
use_ssh_client=False,
max_pool_size=DEFAULT_MAX_POOL_SIZE,
):
base_url: str | None = None,
version: str | None = None,
timeout: int = DEFAULT_TIMEOUT_SECONDS,
tls: bool | TLSConfig = False,
user_agent: str = DEFAULT_USER_AGENT,
num_pools: int | None = None,
credstore_env: dict[str, str] | None = None,
use_ssh_client: bool = False,
max_pool_size: int = DEFAULT_MAX_POOL_SIZE,
) -> None:
super().__init__()
fail_on_missing_imports()
@ -152,6 +153,9 @@ class APIClient(_Session, DaemonApiMixin):
else DEFAULT_NUM_POOLS
)
self._custom_adapter: (
UnixHTTPAdapter | NpipeHTTPAdapter | SSHHTTPAdapter | SSLHTTPAdapter | None
) = None
if base_url.startswith("http+unix://"):
self._custom_adapter = UnixHTTPAdapter(
base_url,
@ -223,7 +227,7 @@ class APIClient(_Session, DaemonApiMixin):
f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library."
)
def _retrieve_server_version(self):
def _retrieve_server_version(self) -> str:
try:
version_result = self.version(api_version=False)
except Exception as e:
@ -242,54 +246,87 @@ class APIClient(_Session, DaemonApiMixin):
f"Error while fetching server API version: {e}. Response seems to be broken."
) from e
def _set_request_timeout(self, kwargs):
def _set_request_timeout(self, kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
"""Prepare the kwargs for an HTTP request by inserting the timeout
parameter, if not already present."""
kwargs.setdefault("timeout", self.timeout)
return kwargs
@update_headers
def _post(self, url, **kwargs):
def _post(self, url: str, **kwargs):
return self.post(url, **self._set_request_timeout(kwargs))
@update_headers
def _get(self, url, **kwargs):
def _get(self, url: str, **kwargs):
return self.get(url, **self._set_request_timeout(kwargs))
@update_headers
def _head(self, url, **kwargs):
def _head(self, url: str, **kwargs):
return self.head(url, **self._set_request_timeout(kwargs))
@update_headers
def _put(self, url, **kwargs):
def _put(self, url: str, **kwargs):
return self.put(url, **self._set_request_timeout(kwargs))
@update_headers
def _delete(self, url, **kwargs):
def _delete(self, url: str, **kwargs):
return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt, *args, **kwargs):
def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
for arg in args:
if not isinstance(arg, str):
raise ValueError(
f"Expected a string but found {arg} ({type(arg)}) instead"
)
quote_f = partial(quote, safe="/:")
args = map(quote_f, args)
q_args = [quote(arg, safe="/:") for arg in args]
if kwargs.get("versioned_api", True):
return f"{self.base_url}/v{self._version}{pathfmt.format(*args)}"
return f"{self.base_url}{pathfmt.format(*args)}"
if versioned_api:
return f"{self.base_url}/v{self._version}{pathfmt.format(*q_args)}"
return f"{self.base_url}{pathfmt.format(*q_args)}"
def _raise_for_status(self, response):
def _raise_for_status(self, response) -> None:
"""Raises stored :class:`APIError`, if one occurred."""
try:
response.raise_for_status()
except _HTTPError as e:
create_api_error_from_http_exception(e)
def _result(self, response, get_json=False, get_binary=False):
@t.overload
def _result(
self,
response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[False] = False,
) -> str: ...
@t.overload
def _result(
self,
response,
*,
get_json: t.Literal[True],
get_binary: t.Literal[False] = False,
) -> t.Any: ...
@t.overload
def _result(
self,
response,
*,
get_json: t.Literal[False] = False,
get_binary: t.Literal[True],
) -> bytes: ...
@t.overload
def _result(
self, response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes: ...
def _result(
self, response, *, get_json: bool = False, get_binary: bool = False
) -> t.Any | str | bytes:
if get_json and get_binary:
raise AssertionError("json and binary must not be both True")
self._raise_for_status(response)
@ -300,10 +337,10 @@ class APIClient(_Session, DaemonApiMixin):
return response.content
return response.text
def _post_json(self, url, data, **kwargs):
def _post_json(self, url: str, data: dict[str, str | None] | t.Any, **kwargs):
# Go <1.1 cannot unserialize null to a string
# so we do this disgusting thing here.
data2 = {}
data2: dict[str, t.Any] = {}
if data is not None and isinstance(data, dict):
for k, v in data.items():
if v is not None:
@ -316,7 +353,7 @@ class APIClient(_Session, DaemonApiMixin):
kwargs["headers"]["Content-Type"] = "application/json"
return self._post(url, data=json.dumps(data2), **kwargs)
def _attach_params(self, override=None):
def _attach_params(self, override: dict[str, int] | None = None) -> dict[str, int]:
return override or {"stdout": 1, "stderr": 1, "stream": 1}
def _get_raw_response_socket(self, response):
@ -341,12 +378,24 @@ class APIClient(_Session, DaemonApiMixin):
return sock
def _stream_helper(self, response, decode=False):
@t.overload
def _stream_helper(
self, response, *, decode: t.Literal[False] = False
) -> t.Generator[bytes]: ...
@t.overload
def _stream_helper(
self, response, *, decode: t.Literal[True]
) -> t.Generator[t.Any]: ...
def _stream_helper(self, response, *, decode: bool = False) -> t.Generator[t.Any]:
"""Generator for data coming from a chunked-encoded HTTP response."""
if response.raw._fp.chunked:
if decode:
yield from json_stream.json_stream(self._stream_helper(response, False))
yield from json_stream.json_stream(
self._stream_helper(response, decode=False)
)
else:
reader = response.raw
while not reader.closed:
@ -362,7 +411,7 @@ class APIClient(_Session, DaemonApiMixin):
# encountered an error immediately
yield self._result(response, get_json=decode)
def _multiplexed_buffer_helper(self, response):
def _multiplexed_buffer_helper(self, response) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks read from a buffered
response."""
buf = self._result(response, get_binary=True)
@ -378,7 +427,7 @@ class APIClient(_Session, DaemonApiMixin):
walker = end
yield buf[start:end]
def _multiplexed_response_stream_helper(self, response):
def _multiplexed_response_stream_helper(self, response) -> t.Generator[bytes]:
"""A generator of multiplexed data blocks coming from a response
stream."""
@ -399,7 +448,19 @@ class APIClient(_Session, DaemonApiMixin):
break
yield data
def _stream_raw_result(self, response, chunk_size=1, decode=True):
@t.overload
def _stream_raw_result(
self, response, *, chunk_size: int = 1, decode: t.Literal[True] = True
) -> t.Generator[str]: ...
@t.overload
def _stream_raw_result(
self, response, *, chunk_size: int = 1, decode: t.Literal[False]
) -> t.Generator[bytes]: ...
def _stream_raw_result(
self, response, *, chunk_size: int = 1, decode: bool = True
) -> t.Generator[str | bytes]:
"""Stream result for TTY-enabled container and raw binary data"""
self._raise_for_status(response)
@ -410,14 +471,81 @@ class APIClient(_Session, DaemonApiMixin):
yield from response.iter_content(chunk_size, decode)
def _read_from_socket(self, response, stream, tty=True, demux=False):
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
) -> t.Generator[bytes]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
) -> bytes: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
) -> tuple[bytes, None]: ...
@t.overload
def _read_from_socket(
self,
response,
*,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
) -> tuple[bytes, bytes]: ...
@t.overload
def _read_from_socket(
self, response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any: ...
def _read_from_socket(
self, response, *, stream: bool, tty: bool = True, demux: bool = False
) -> t.Any:
"""Consume all data from the socket, close the response and return the
data. If stream=True, then a generator is returned instead and the
caller is responsible for closing the response.
"""
socket = self._get_raw_response_socket(response)
gen = frames_iter(socket, tty)
gen: t.Generator = frames_iter(socket, tty)
if demux:
# The generator will output tuples (stdout, stderr)
@ -434,7 +562,7 @@ class APIClient(_Session, DaemonApiMixin):
finally:
response.close()
def _disable_socket_timeout(self, socket):
def _disable_socket_timeout(self, socket) -> None:
"""Depending on the combination of python version and whether we are
connecting over http or https, we might need to access _sock, which
may or may not exist; or we may need to just settimeout on socket
@ -462,7 +590,27 @@ class APIClient(_Session, DaemonApiMixin):
s.settimeout(None)
def _get_result_tty(self, stream, res, is_tty):
@t.overload
def _get_result_tty(
self, stream: t.Literal[True], res, is_tty: t.Literal[True]
) -> t.Generator[str]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[True], res, is_tty: t.Literal[False]
) -> t.Generator[bytes]: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res, is_tty: t.Literal[True]
) -> bytes: ...
@t.overload
def _get_result_tty(
self, stream: t.Literal[False], res, is_tty: t.Literal[False]
) -> bytes: ...
def _get_result_tty(self, stream: bool, res, is_tty: bool) -> t.Any:
# We should also use raw streaming (without keep-alive)
# if we are dealing with a tty-enabled container.
if is_tty:
@ -478,11 +626,11 @@ class APIClient(_Session, DaemonApiMixin):
return self._multiplexed_response_stream_helper(res)
return sep.join(list(self._multiplexed_buffer_helper(res)))
def _unmount(self, *args):
def _unmount(self, *args) -> None:
for proto in args:
self.adapters.pop(proto)
def get_adapter(self, url):
def get_adapter(self, url: str):
try:
return super().get_adapter(url)
except _InvalidSchema as e:
@ -491,10 +639,10 @@ class APIClient(_Session, DaemonApiMixin):
raise e
@property
def api_version(self):
def api_version(self) -> str:
return self._version
def reload_config(self, dockercfg_path=None):
def reload_config(self, dockercfg_path: str | None = None) -> None:
"""
Force a reload of the auth configuration
@ -510,7 +658,7 @@ class APIClient(_Session, DaemonApiMixin):
dockercfg_path, credstore_env=self.credstore_env
)
def _set_auth_headers(self, headers):
def _set_auth_headers(self, headers: dict[str, str | bytes]) -> None:
log.debug("Looking for auth config")
# If we do not have any auth data so far, try reloading the config
@ -537,57 +685,62 @@ class APIClient(_Session, DaemonApiMixin):
else:
log.debug("No auth config found")
def get_binary(self, pathfmt, *args, **kwargs):
def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_binary=True,
)
def get_json(self, pathfmt, *args, **kwargs):
def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
def get_text(self, pathfmt, *args, **kwargs):
def get_text(self, pathfmt: str, *args: str, **kwargs) -> str:
return self._result(
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def get_raw_stream(self, pathfmt, *args, **kwargs):
chunk_size = kwargs.pop("chunk_size", DEFAULT_DATA_CHUNK_SIZE)
def get_raw_stream(
self,
pathfmt: str,
*args: str,
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
**kwargs,
) -> t.Generator[bytes]:
res = self._get(
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
)
self._raise_for_status(res)
return self._stream_raw_result(res, chunk_size, False)
return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
def delete_call(self, pathfmt, *args, **kwargs):
def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def delete_json(self, pathfmt, *args, **kwargs):
def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result(
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
def post_call(self, pathfmt, *args, **kwargs):
def post_call(self, pathfmt: str, *args: str, **kwargs) -> None:
self._raise_for_status(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
)
def post_json(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None:
self._raise_for_status(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
)
)
def post_json_to_binary(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json_to_binary(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> bytes:
return self._result(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -595,8 +748,9 @@ class APIClient(_Session, DaemonApiMixin):
get_binary=True,
)
def post_json_to_json(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json_to_json(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> t.Any:
return self._result(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
@ -604,17 +758,24 @@ class APIClient(_Session, DaemonApiMixin):
get_json=True,
)
def post_json_to_text(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
def post_json_to_text(
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
) -> str:
return self._result(
self._post_json(
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
),
)
def post_json_to_stream_socket(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
headers = (kwargs.pop("headers", None) or {}).copy()
def post_json_to_stream_socket(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
**kwargs,
):
headers = headers.copy() if headers else {}
headers.update(
{
"Connection": "Upgrade",
@ -631,18 +792,102 @@ class APIClient(_Session, DaemonApiMixin):
)
)
def post_json_to_stream(self, pathfmt, *args, **kwargs):
data = kwargs.pop("data", None)
headers = (kwargs.pop("headers", None) or {}).copy()
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> t.Generator[bytes]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[True],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: bool = True,
demux: t.Literal[False] = False,
**kwargs,
) -> bytes: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[True] = True,
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, None]: ...
@t.overload
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: t.Literal[False],
tty: t.Literal[False],
demux: t.Literal[True],
**kwargs,
) -> tuple[bytes, bytes]: ...
def post_json_to_stream(
self,
pathfmt: str,
*args: str,
data: t.Any = None,
headers: dict[str, str] | None = None,
stream: bool = False,
demux: bool = False,
tty: bool = False,
**kwargs,
) -> t.Any:
headers = headers.copy() if headers else {}
headers.update(
{
"Connection": "Upgrade",
"Upgrade": "tcp",
}
)
stream = kwargs.pop("stream", False)
demux = kwargs.pop("demux", False)
tty = kwargs.pop("tty", False)
return self._read_from_socket(
self._post_json(
self._url(pathfmt, *args, versioned_api=True),
@ -651,13 +896,133 @@ class APIClient(_Session, DaemonApiMixin):
stream=True,
**kwargs,
),
stream,
stream=stream,
tty=tty,
demux=demux,
)
def post_to_json(self, pathfmt, *args, **kwargs):
def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
return self._result(
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
get_json=True,
)
@minimum_version("1.25")
def df(self) -> dict[str, t.Any]:
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self) -> dict[str, t.Any]:
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username: str,
password: str | None = None,
email: str | None = None,
registry: str | None = None,
reauth: bool = False,
dockercfg_path: str | None = None,
) -> dict[str, t.Any]:
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self) -> bool:
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version: bool = True) -> dict[str, t.Any]:
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -1,139 +0,0 @@
# This code is part of the Ansible collection community.docker, but is an independent component.
# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
#
# Copyright (c) 2016-2022 Docker, Inc.
#
# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
# SPDX-License-Identifier: Apache-2.0
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import os
from .. import auth
from ..utils.decorators import minimum_version
class DaemonApiMixin:
@minimum_version("1.25")
def df(self):
"""
Get data usage information.
Returns:
(dict): A dictionary representing different resource categories
and their respective data usage.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/system/df")
return self._result(self._get(url), get_json=True)
def info(self):
"""
Display system-wide information. Identical to the ``docker info``
command.
Returns:
(dict): The info as a dict
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/info")), get_json=True)
def login(
self,
username,
password=None,
email=None,
registry=None,
reauth=False,
dockercfg_path=None,
):
"""
Authenticate with a registry. Similar to the ``docker login`` command.
Args:
username (str): The registry username
password (str): The plaintext password
email (str): The email for the registry account
registry (str): URL to the registry. E.g.
``https://index.docker.io/v1/``
reauth (bool): Whether or not to refresh existing authentication on
the Docker server.
dockercfg_path (str): Use a custom path for the Docker config file
(default ``$HOME/.docker/config.json`` if present,
otherwise ``$HOME/.dockercfg``)
Returns:
(dict): The response from the login request
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
# If we do not have any auth data so far, try reloading the config file
# one more time in case anything showed up in there.
# If dockercfg_path is passed check to see if the config file exists,
# if so load that config.
if dockercfg_path and os.path.exists(dockercfg_path):
self._auth_configs = auth.load_config(
dockercfg_path, credstore_env=self.credstore_env
)
elif not self._auth_configs or self._auth_configs.is_empty:
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
authcfg = self._auth_configs.resolve_authconfig(registry)
# If we found an existing auth config for this registry and username
# combination, we can return it immediately unless reauth is requested.
if authcfg and authcfg.get("username", None) == username and not reauth:
return authcfg
req_data = {
"username": username,
"password": password,
"email": email,
"serveraddress": registry,
}
response = self._post_json(self._url("/auth"), data=req_data)
if response.status_code == 200:
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
return self._result(response, get_json=True)
def ping(self):
"""
Checks the server is responsive. An exception will be raised if it
is not responding.
Returns:
(bool) The response from the server.
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
return self._result(self._get(self._url("/_ping"))) == "OK"
def version(self, api_version=True):
"""
Returns version information from the server. Similar to the ``docker
version`` command.
Returns:
(dict): The server version information
Raises:
:py:class:`docker.errors.APIError`
If the server returns an error.
"""
url = self._url("/version", versioned_api=api_version)
return self._result(self._get(url), get_json=True)

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import base64
import json
import logging
import typing as t
from . import errors
from .credentials.errors import CredentialsNotFound, StoreError
@ -21,6 +22,12 @@ from .credentials.store import Store
from .utils import config
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
INDEX_NAME = "docker.io"
INDEX_URL = f"https://index.{INDEX_NAME}/v1/"
TOKEN_USERNAME = "<token>"
@ -28,7 +35,7 @@ TOKEN_USERNAME = "<token>"
log = logging.getLogger(__name__)
def resolve_repository_name(repo_name):
def resolve_repository_name(repo_name: str) -> tuple[str, str]:
if "://" in repo_name:
raise errors.InvalidRepository(
f"Repository name cannot contain a scheme ({repo_name})"
@ -42,14 +49,14 @@ def resolve_repository_name(repo_name):
return resolve_index_name(index_name), remote_name
def resolve_index_name(index_name):
def resolve_index_name(index_name: str) -> str:
index_name = convert_to_hostname(index_name)
if index_name == "index." + INDEX_NAME:
index_name = INDEX_NAME
return index_name
def get_config_header(client, registry):
def get_config_header(client: APIClient, registry: str) -> bytes | None:
log.debug("Looking for auth config")
if not client._auth_configs or client._auth_configs.is_empty:
log.debug("No auth config in memory - loading from filesystem")
@ -69,32 +76,38 @@ def get_config_header(client, registry):
return None
def split_repo_name(repo_name):
def split_repo_name(repo_name: str) -> tuple[str, str]:
parts = repo_name.split("/", 1)
if len(parts) == 1 or (
"." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost"
):
# This is a docker index repo (ex: username/foobar or ubuntu)
return INDEX_NAME, repo_name
return tuple(parts)
return tuple(parts) # type: ignore
def get_credential_store(authconfig, registry):
def get_credential_store(
authconfig: dict[str, t.Any] | AuthConfig, registry: str
) -> str | None:
if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig)
return authconfig.get_credential_store(registry)
class AuthConfig(dict):
def __init__(self, dct, credstore_env=None):
def __init__(
self, dct: dict[str, t.Any], credstore_env: dict[str, str] | None = None
):
if "auths" not in dct:
dct["auths"] = {}
self.update(dct)
self._credstore_env = credstore_env
self._stores = {}
self._stores: dict[str, Store] = {}
@classmethod
def parse_auth(cls, entries, raise_on_error=False):
def parse_auth(
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False
) -> dict[str, dict[str, t.Any]]:
"""
Parses authentication entries
@ -107,10 +120,10 @@ class AuthConfig(dict):
Authentication registry.
"""
conf = {}
conf: dict[str, dict[str, t.Any]] = {}
for registry, entry in entries.items():
if not isinstance(entry, dict):
log.debug("Config entry for key %s is not auth config", registry)
log.debug("Config entry for key %s is not auth config", registry) # type: ignore
# We sometimes fall back to parsing the whole config as if it
# was the auth config by itself, for legacy purposes. In that
# case, we fail silently and return an empty conf if any of the
@ -150,7 +163,12 @@ class AuthConfig(dict):
return conf
@classmethod
def load_config(cls, config_path, config_dict, credstore_env=None):
def load_config(
cls,
config_path: str | None,
config_dict: dict[str, t.Any] | None,
credstore_env: dict[str, str] | None = None,
) -> t.Self:
"""
Loads authentication data from a Docker configuration file in the given
root directory or if config_path is passed use given path.
@ -196,22 +214,24 @@ class AuthConfig(dict):
return cls({"auths": cls.parse_auth(config_dict)}, credstore_env)
@property
def auths(self):
def auths(self) -> dict[str, dict[str, t.Any]]:
return self.get("auths", {})
@property
def creds_store(self):
def creds_store(self) -> str | None:
return self.get("credsStore", None)
@property
def cred_helpers(self):
def cred_helpers(self) -> dict[str, t.Any]:
return self.get("credHelpers", {})
@property
def is_empty(self):
def is_empty(self) -> bool:
return not self.auths and not self.creds_store and not self.cred_helpers
def resolve_authconfig(self, registry=None):
def resolve_authconfig(
self, registry: str | None = None
) -> dict[str, t.Any] | None:
"""
Returns the authentication data from the given auth configuration for a
specific registry. As with the Docker client, legacy entries in the
@ -244,7 +264,9 @@ class AuthConfig(dict):
log.debug("No entry found")
return None
def _resolve_authconfig_credstore(self, registry, credstore_name):
def _resolve_authconfig_credstore(
self, registry: str | None, credstore_name: str
) -> dict[str, t.Any] | None:
if not registry or registry == INDEX_NAME:
# The ecosystem is a little schizophrenic with index.docker.io VS
# docker.io - in that case, it seems the full URL is necessary.
@ -272,19 +294,19 @@ class AuthConfig(dict):
except StoreError as e:
raise errors.DockerException(f"Credentials store error: {e}")
def _get_store_instance(self, name):
def _get_store_instance(self, name: str):
if name not in self._stores:
self._stores[name] = Store(name, environment=self._credstore_env)
return self._stores[name]
def get_credential_store(self, registry):
def get_credential_store(self, registry: str | None) -> str | None:
if not registry or registry == INDEX_NAME:
registry = INDEX_URL
return self.cred_helpers.get(registry) or self.creds_store
def get_all_credentials(self):
auth_data = self.auths.copy()
def get_all_credentials(self) -> dict[str, dict[str, t.Any] | None]:
auth_data: dict[str, dict[str, t.Any] | None] = self.auths.copy() # type: ignore
if self.creds_store:
# Retrieve all credentials from the default store
store = self._get_store_instance(self.creds_store)
@ -299,21 +321,23 @@ class AuthConfig(dict):
return auth_data
def add_auth(self, reg, data):
def add_auth(self, reg: str, data: dict[str, t.Any]) -> None:
self["auths"][reg] = data
def resolve_authconfig(authconfig, registry=None, credstore_env=None):
def resolve_authconfig(
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None
):
if not isinstance(authconfig, AuthConfig):
authconfig = AuthConfig(authconfig, credstore_env)
return authconfig.resolve_authconfig(registry)
def convert_to_hostname(url):
def convert_to_hostname(url: str) -> str:
return url.replace("http://", "").replace("https://", "").split("/", 1)[0]
def decode_auth(auth):
def decode_auth(auth: str | bytes) -> tuple[str, str]:
if isinstance(auth, str):
auth = auth.encode("ascii")
s = base64.b64decode(auth)
@ -321,12 +345,14 @@ def decode_auth(auth):
return login.decode("utf8"), pwd.decode("utf8")
def encode_header(auth):
def encode_header(auth: dict[str, t.Any]) -> bytes:
auth_json = json.dumps(auth).encode("ascii")
return base64.urlsafe_b64encode(auth_json)
def parse_auth(entries, raise_on_error=False):
def parse_auth(
entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
) -> dict[str, dict[str, t.Any]]:
"""
Parses authentication entries
@ -342,11 +368,15 @@ def parse_auth(entries, raise_on_error=False):
return AuthConfig.parse_auth(entries, raise_on_error)
def load_config(config_path=None, config_dict=None, credstore_env=None):
def load_config(
config_path: str | None = None,
config_dict: dict[str, t.Any] | None = None,
credstore_env: dict[str, str] | None = None,
) -> AuthConfig:
return AuthConfig.load_config(config_path, config_dict, credstore_env)
def _load_legacy_config(config_file):
def _load_legacy_config(config_file: str) -> dict[str, dict[str, t.Any]]:
log.debug("Attempting to parse legacy auth file format")
try:
data = []

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json
import os
import typing as t
from .. import errors
from .config import (
@ -24,7 +25,11 @@ from .config import (
from .context import Context
def create_default_context():
if t.TYPE_CHECKING:
from ..tls import TLSConfig
def create_default_context() -> Context:
host = None
if os.environ.get("DOCKER_HOST"):
host = os.environ.get("DOCKER_HOST")
@ -42,7 +47,7 @@ class ContextAPI:
DEFAULT_CONTEXT = None
@classmethod
def get_default_context(cls):
def get_default_context(cls) -> Context:
context = cls.DEFAULT_CONTEXT
if context is None:
context = create_default_context()
@ -52,13 +57,13 @@ class ContextAPI:
@classmethod
def create_context(
cls,
name,
orchestrator=None,
host=None,
tls_cfg=None,
default_namespace=None,
skip_tls_verify=False,
):
name: str,
orchestrator: str | None = None,
host: str | None = None,
tls_cfg: TLSConfig | None = None,
default_namespace: str | None = None,
skip_tls_verify: bool = False,
) -> Context:
"""Creates a new context.
Returns:
(Context): a Context object.
@ -108,7 +113,7 @@ class ContextAPI:
return ctx
@classmethod
def get_context(cls, name=None):
def get_context(cls, name: str | None = None) -> Context | None:
"""Retrieves a context object.
Args:
name (str): The name of the context
@ -136,7 +141,7 @@ class ContextAPI:
return Context.load_context(name)
@classmethod
def contexts(cls):
def contexts(cls) -> list[Context]:
"""Context list.
Returns:
(Context): List of context objects.
@ -170,7 +175,7 @@ class ContextAPI:
return contexts
@classmethod
def get_current_context(cls):
def get_current_context(cls) -> Context | None:
"""Get current context.
Returns:
(Context): current context object.
@ -178,7 +183,7 @@ class ContextAPI:
return cls.get_context()
@classmethod
def set_current_context(cls, name="default"):
def set_current_context(cls, name: str = "default") -> None:
ctx = cls.get_context(name)
if not ctx:
raise errors.ContextNotFound(name)
@ -188,7 +193,7 @@ class ContextAPI:
raise errors.ContextException(f"Failed to set current context: {err}")
@classmethod
def remove_context(cls, name):
def remove_context(cls, name: str) -> None:
"""Remove a context. Similar to the ``docker context rm`` command.
Args:
@ -220,7 +225,7 @@ class ContextAPI:
ctx.remove()
@classmethod
def inspect_context(cls, name="default"):
def inspect_context(cls, name: str = "default") -> dict[str, t.Any]:
"""Inspect a context. Similar to the ``docker context inspect`` command.
Args:

View File

@ -23,7 +23,7 @@ from ..utils.utils import parse_host
METAFILE = "meta.json"
def get_current_context_name_with_source():
def get_current_context_name_with_source() -> tuple[str, str]:
if os.environ.get("DOCKER_HOST"):
return "default", "DOCKER_HOST environment variable set"
if os.environ.get("DOCKER_CONTEXT"):
@ -41,11 +41,11 @@ def get_current_context_name_with_source():
return "default", "fallback value"
def get_current_context_name():
def get_current_context_name() -> str:
return get_current_context_name_with_source()[0]
def write_context_name_to_docker_config(name=None):
def write_context_name_to_docker_config(name: str | None = None) -> Exception | None:
if name == "default":
name = None
docker_cfg_path = find_config_file()
@ -62,44 +62,45 @@ def write_context_name_to_docker_config(name=None):
elif name:
config["currentContext"] = name
else:
return
return None
if not docker_cfg_path:
docker_cfg_path = get_default_config_file()
try:
with open(docker_cfg_path, "wt", encoding="utf-8") as f:
json.dump(config, f, indent=4)
return None
except Exception as e: # pylint: disable=broad-exception-caught
return e
def get_context_id(name):
def get_context_id(name: str) -> str:
return hashlib.sha256(name.encode("utf-8")).hexdigest()
def get_context_dir():
def get_context_dir() -> str:
docker_cfg_path = find_config_file() or get_default_config_file()
return os.path.join(os.path.dirname(docker_cfg_path), "contexts")
def get_meta_dir(name=None):
def get_meta_dir(name: str | None = None) -> str:
meta_dir = os.path.join(get_context_dir(), "meta")
if name:
return os.path.join(meta_dir, get_context_id(name))
return meta_dir
def get_meta_file(name):
def get_meta_file(name) -> str:
return os.path.join(get_meta_dir(name), METAFILE)
def get_tls_dir(name=None, endpoint=""):
def get_tls_dir(name: str | None = None, endpoint: str = "") -> str:
context_dir = get_context_dir()
if name:
return os.path.join(context_dir, "tls", get_context_id(name), endpoint)
return os.path.join(context_dir, "tls")
def get_context_host(path=None, tls=False):
def get_context_host(path: str | None = None, tls: bool = False) -> str:
host = parse_host(path, IS_WINDOWS_PLATFORM, tls)
if host == DEFAULT_UNIX_SOCKET:
# remove http+ from default docker socket url

View File

@ -13,6 +13,7 @@ from __future__ import annotations
import json
import os
import typing as t
from shutil import copyfile, rmtree
from ..errors import ContextException
@ -33,21 +34,21 @@ class Context:
def __init__(
self,
name,
orchestrator=None,
host=None,
endpoints=None,
skip_tls_verify=False,
tls=False,
description=None,
):
name: str,
orchestrator: str | None = None,
host: str | None = None,
endpoints: dict[str, dict[str, t.Any]] | None = None,
skip_tls_verify: bool = False,
tls: bool = False,
description: str | None = None,
) -> None:
if not name:
raise ValueError("Name not provided")
self.name = name
self.context_type = None
self.orchestrator = orchestrator
self.endpoints = {}
self.tls_cfg = {}
self.tls_cfg: dict[str, TLSConfig] = {}
self.meta_path = IN_MEMORY
self.tls_path = IN_MEMORY
self.description = description
@ -89,12 +90,12 @@ class Context:
def set_endpoint(
self,
name="docker",
host=None,
tls_cfg=None,
skip_tls_verify=False,
def_namespace=None,
):
name: str = "docker",
host: str | None = None,
tls_cfg: TLSConfig | None = None,
skip_tls_verify: bool = False,
def_namespace: str | None = None,
) -> None:
self.endpoints[name] = {
"Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None),
"SkipTLSVerify": skip_tls_verify,
@ -105,11 +106,11 @@ class Context:
if tls_cfg:
self.tls_cfg[name] = tls_cfg
def inspect(self):
def inspect(self) -> dict[str, t.Any]:
return self()
@classmethod
def load_context(cls, name):
def load_context(cls, name: str) -> t.Self | None:
meta = Context._load_meta(name)
if meta:
instance = cls(
@ -125,12 +126,12 @@ class Context:
return None
@classmethod
def _load_meta(cls, name):
def _load_meta(cls, name: str) -> dict[str, t.Any] | None:
meta_file = get_meta_file(name)
if not os.path.isfile(meta_file):
return None
metadata = {}
metadata: dict[str, t.Any] = {}
try:
with open(meta_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
@ -154,7 +155,7 @@ class Context:
return metadata
def _load_certs(self):
def _load_certs(self) -> None:
certs = {}
tls_dir = get_tls_dir(self.name)
for endpoint in self.endpoints:
@ -184,7 +185,7 @@ class Context:
self.tls_cfg = certs
self.tls_path = tls_dir
def save(self):
def save(self) -> None:
meta_dir = get_meta_dir(self.name)
if not os.path.isdir(meta_dir):
os.makedirs(meta_dir)
@ -216,54 +217,54 @@ class Context:
self.meta_path = get_meta_dir(self.name)
self.tls_path = get_tls_dir(self.name)
def remove(self):
def remove(self) -> None:
if os.path.isdir(self.meta_path):
rmtree(self.meta_path)
if os.path.isdir(self.tls_path):
rmtree(self.tls_path)
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: '{self.name}'>"
def __str__(self):
def __str__(self) -> str:
return json.dumps(self.__call__(), indent=2)
def __call__(self):
def __call__(self) -> dict[str, t.Any]:
result = self.Metadata
result.update(self.TLSMaterial)
result.update(self.Storage)
return result
def is_docker_host(self):
def is_docker_host(self) -> bool:
return self.context_type is None
@property
def Name(self): # pylint: disable=invalid-name
def Name(self) -> str: # pylint: disable=invalid-name
return self.name
@property
def Host(self): # pylint: disable=invalid-name
def Host(self) -> str | None: # pylint: disable=invalid-name
if not self.orchestrator or self.orchestrator == "swarm":
endpoint = self.endpoints.get("docker", None)
if endpoint:
return endpoint.get("Host", None)
return endpoint.get("Host", None) # type: ignore
return None
return self.endpoints[self.orchestrator].get("Host", None)
return self.endpoints[self.orchestrator].get("Host", None) # type: ignore
@property
def Orchestrator(self): # pylint: disable=invalid-name
def Orchestrator(self) -> str | None: # pylint: disable=invalid-name
return self.orchestrator
@property
def Metadata(self): # pylint: disable=invalid-name
meta = {}
def Metadata(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
meta: dict[str, t.Any] = {}
if self.orchestrator:
meta = {"StackOrchestrator": self.orchestrator}
return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints}
@property
def TLSConfig(self): # pylint: disable=invalid-name
def TLSConfig(self) -> TLSConfig | None: # pylint: disable=invalid-name
key = self.orchestrator
if not key or key == "swarm":
key = "docker"
@ -272,13 +273,15 @@ class Context:
return None
@property
def TLSMaterial(self): # pylint: disable=invalid-name
certs = {}
def TLSMaterial(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
certs: dict[str, t.Any] = {}
for endpoint, tls in self.tls_cfg.items():
cert, key = tls.cert
certs[endpoint] = list(map(os.path.basename, [tls.ca_cert, cert, key]))
paths = [tls.ca_cert, *tls.cert] if tls.cert else [tls.ca_cert]
certs[endpoint] = [
os.path.basename(path) if path else None for path in paths
]
return {"TLSMaterial": certs}
@property
def Storage(self): # pylint: disable=invalid-name
def Storage(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}}

View File

@ -11,6 +11,12 @@
from __future__ import annotations
import typing as t
if t.TYPE_CHECKING:
from subprocess import CalledProcessError
class StoreError(RuntimeError):
pass
@ -24,7 +30,7 @@ class InitializationError(StoreError):
pass
def process_store_error(cpe, program):
def process_store_error(cpe: CalledProcessError, program: str) -> StoreError:
message = cpe.output.decode("utf-8")
if "credentials not found in native keychain" in message:
return CredentialsNotFound(f"No matching credentials in {program}")

View File

@ -14,13 +14,14 @@ from __future__ import annotations
import errno
import json
import subprocess
import typing as t
from . import constants, errors
from .utils import create_environment_dict, find_executable
class Store:
def __init__(self, program, environment=None):
def __init__(self, program: str, environment: dict[str, str] | None = None) -> None:
"""Create a store object that acts as an interface to
perform the basic operations for storing, retrieving
and erasing credentials using `program`.
@ -33,7 +34,7 @@ class Store:
f"{self.program} not installed or not available in PATH"
)
def get(self, server):
def get(self, server: str | bytes) -> dict[str, t.Any]:
"""Retrieve credentials for `server`. If no credentials are found,
a `StoreError` will be raised.
"""
@ -53,7 +54,7 @@ class Store:
return result
def store(self, server, username, secret):
def store(self, server: str, username: str, secret: str) -> bytes:
"""Store credentials for `server`. Raises a `StoreError` if an error
occurs.
"""
@ -62,7 +63,7 @@ class Store:
).encode("utf-8")
return self._execute("store", data_input)
def erase(self, server):
def erase(self, server: str | bytes) -> None:
"""Erase credentials for `server`. Raises a `StoreError` if an error
occurs.
"""
@ -70,12 +71,16 @@ class Store:
server = server.encode("utf-8")
self._execute("erase", server)
def list(self):
def list(self) -> t.Any:
"""List stored credentials. Requires v0.4.0+ of the helper."""
data = self._execute("list", None)
return json.loads(data.decode("utf-8"))
def _execute(self, subcmd, data_input):
def _execute(self, subcmd: str, data_input: bytes | None) -> bytes:
if self.exe is None:
raise errors.StoreError(
f"{self.program} not installed or not available in PATH"
)
output = None
env = create_environment_dict(self.environment)
try:

View File

@ -15,7 +15,7 @@ import os
from shutil import which
def find_executable(executable, path=None):
def find_executable(executable: str, path: str | None = None) -> str | None:
"""
As distutils.spawn.find_executable, but on Windows, look up
every extension declared in PATHEXT instead of just `.exe`
@ -26,7 +26,7 @@ def find_executable(executable, path=None):
return which(executable, path=path)
def create_environment_dict(overrides):
def create_environment_dict(overrides: dict[str, str] | None) -> dict[str, str]:
"""
Create and return a copy of os.environ with the specified overrides
"""

View File

@ -11,6 +11,8 @@
from __future__ import annotations
import typing as t
from ansible.module_utils.common.text.converters import to_native
from ._import_helper import HTTPError as _HTTPError
@ -25,7 +27,7 @@ class DockerException(Exception):
"""
def create_api_error_from_http_exception(e):
def create_api_error_from_http_exception(e: _HTTPError) -> t.NoReturn:
"""
Create a suitable APIError from requests.exceptions.HTTPError.
"""
@ -52,14 +54,16 @@ class APIError(_HTTPError, DockerException):
An HTTP error from the API.
"""
def __init__(self, message, response=None, explanation=None):
def __init__(
self, message: str | Exception, response=None, explanation: str | None = None
) -> None:
# requests 1.2 supports response as a keyword argument, but
# requests 1.1 does not
super().__init__(message)
self.response = response
self.explanation = explanation
def __str__(self):
def __str__(self) -> str:
message = super().__str__()
if self.is_client_error():
@ -74,19 +78,20 @@ class APIError(_HTTPError, DockerException):
return message
@property
def status_code(self):
def status_code(self) -> int | None:
if self.response is not None:
return self.response.status_code
return None
def is_error(self):
def is_error(self) -> bool:
return self.is_client_error() or self.is_server_error()
def is_client_error(self):
def is_client_error(self) -> bool:
if self.status_code is None:
return False
return 400 <= self.status_code < 500
def is_server_error(self):
def is_server_error(self) -> bool:
if self.status_code is None:
return False
return 500 <= self.status_code < 600
@ -121,10 +126,10 @@ class DeprecatedMethod(DockerException):
class TLSParameterError(DockerException):
def __init__(self, msg):
def __init__(self, msg: str) -> None:
self.msg = msg
def __str__(self):
def __str__(self) -> str:
return self.msg + (
". TLS configurations should map the Docker CLI "
"client configurations. See "
@ -142,7 +147,14 @@ class ContainerError(DockerException):
Represents a container that has exited with a non-zero exit code.
"""
def __init__(self, container, exit_status, command, image, stderr):
def __init__(
self,
container: str,
exit_status: int,
command: list[str],
image: str,
stderr: str | None,
):
self.container = container
self.exit_status = exit_status
self.command = command
@ -156,12 +168,12 @@ class ContainerError(DockerException):
class StreamParseError(RuntimeError):
def __init__(self, reason):
def __init__(self, reason: Exception) -> None:
self.msg = reason
class BuildError(DockerException):
def __init__(self, reason, build_log):
def __init__(self, reason: str, build_log: str) -> None:
super().__init__(reason)
self.msg = reason
self.build_log = build_log
@ -171,7 +183,7 @@ class ImageLoadError(DockerException):
pass
def create_unexpected_kwargs_error(name, kwargs):
def create_unexpected_kwargs_error(name: str, kwargs: dict[str, t.Any]) -> TypeError:
quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)]
text = [f"{name}() "]
if len(quoted_kwargs) == 1:
@ -183,42 +195,44 @@ def create_unexpected_kwargs_error(name, kwargs):
class MissingContextParameter(DockerException):
def __init__(self, param):
def __init__(self, param: str) -> None:
self.param = param
def __str__(self):
def __str__(self) -> str:
return f"missing parameter: {self.param}"
class ContextAlreadyExists(DockerException):
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
def __str__(self):
def __str__(self) -> str:
return f"context {self.name} already exists"
class ContextException(DockerException):
def __init__(self, msg):
def __init__(self, msg: str) -> None:
self.msg = msg
def __str__(self):
def __str__(self) -> str:
return self.msg
class ContextNotFound(DockerException):
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
def __str__(self):
def __str__(self) -> str:
return f"context '{self.name}' not found"
class MissingRequirementException(DockerException):
def __init__(self, msg, requirement, import_exception):
def __init__(
self, msg: str, requirement: str, import_exception: ImportError | str
) -> None:
self.msg = msg
self.requirement = requirement
self.import_exception = import_exception
def __str__(self):
def __str__(self) -> str:
return self.msg

View File

@ -12,12 +12,18 @@
from __future__ import annotations
import os
import ssl
import typing as t
from . import errors
from .transport.ssladapter import SSLHTTPAdapter
if t.TYPE_CHECKING:
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
APIClient,
)
class TLSConfig:
"""
TLS configuration.
@ -27,25 +33,22 @@ class TLSConfig:
ca_cert (str): Path to CA cert file.
verify (bool or str): This can be ``False`` or a path to a CA cert
file.
ssl_version (int): A valid `SSL version`_.
assert_hostname (bool): Verify the hostname of the server.
.. _`SSL version`:
https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1
"""
cert = None
ca_cert = None
verify = None
ssl_version = None
cert: tuple[str, str] | None = None
ca_cert: str | None = None
verify: bool | None = None
def __init__(
self,
client_cert=None,
ca_cert=None,
verify=None,
ssl_version=None,
assert_hostname=None,
client_cert: tuple[str, str] | None = None,
ca_cert: str | None = None,
verify: bool | None = None,
assert_hostname: bool | None = None,
):
# Argument compatibility/mapping with
# https://docs.docker.com/engine/articles/https/
@ -55,12 +58,6 @@ class TLSConfig:
self.assert_hostname = assert_hostname
# If the user provides an SSL version, we should use their preference
if ssl_version:
self.ssl_version = ssl_version
else:
self.ssl_version = ssl.PROTOCOL_TLS_CLIENT
# "client_cert" must have both or neither cert/key files. In
# either case, Alert the user when both are expected, but any are
# missing.
@ -90,11 +87,10 @@ class TLSConfig:
"Invalid CA certificate provided for `ca_cert`."
)
def configure_client(self, client):
def configure_client(self, client: APIClient) -> None:
"""
Configure a client with these TLS options.
"""
client.ssl_version = self.ssl_version
if self.verify and self.ca_cert:
client.verify = self.ca_cert
@ -107,7 +103,6 @@ class TLSConfig:
client.mount(
"https://",
SSLHTTPAdapter(
ssl_version=self.ssl_version,
assert_hostname=self.assert_hostname,
),
)

View File

@ -15,7 +15,7 @@ from .._import_helper import HTTPAdapter as _HTTPAdapter
class BaseHTTPAdapter(_HTTPAdapter):
def close(self):
def close(self) -> None:
super().close()
if hasattr(self, "pools"):
self.pools.clear()
@ -24,10 +24,10 @@ class BaseHTTPAdapter(_HTTPAdapter):
# https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356
# changes requests.adapters.HTTPAdapter to no longer call get_connection() from
# send(), but instead call _get_connection().
def _get_connection(self, request, *args, **kwargs):
def _get_connection(self, request, *args, **kwargs): # type: ignore
return self.get_connection(request.url, kwargs.get("proxies"))
# Fix for requests 2.32.2+:
# https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None):
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): # type: ignore
return self.get_connection(request.url, proxies)

View File

@ -23,12 +23,12 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(self, npipe_path, timeout=60):
def __init__(self, npipe_path: str, timeout: int = 60) -> None:
super().__init__("localhost", timeout=timeout)
self.npipe_path = npipe_path
self.timeout = timeout
def connect(self):
def connect(self) -> None:
sock = NpipeSocket()
sock.settimeout(self.timeout)
sock.connect(self.npipe_path)
@ -36,18 +36,18 @@ class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, npipe_path, timeout=60, maxsize=10):
def __init__(self, npipe_path: str, timeout: int = 60, maxsize: int = 10) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.npipe_path = npipe_path
self.timeout = timeout
def _new_conn(self):
def _new_conn(self) -> NpipeHTTPConnection:
return NpipeHTTPConnection(self.npipe_path, self.timeout)
# When re-using connections, urllib3 tries to call select() on our
# NpipeSocket instance, causing a crash. To circumvent this, we override
# _get_conn, where that check happens.
def _get_conn(self, timeout):
def _get_conn(self, timeout: int) -> NpipeHTTPConnection:
conn = None
try:
conn = self.pool.get(block=self.block, timeout=timeout)
@ -67,7 +67,6 @@ class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class NpipeHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [
"npipe_path",
"pools",
@ -77,11 +76,11 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
def __init__(
self,
base_url,
timeout=60,
pool_connections=constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
):
base_url: str,
timeout: int = 60,
pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
) -> None:
self.npipe_path = base_url.replace("npipe://", "")
self.timeout = timeout
self.max_pool_size = max_pool_size
@ -90,7 +89,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
)
super().__init__()
def get_connection(self, url, proxies=None):
def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool:
with self.pools.lock:
pool = self.pools.get(url)
if pool:
@ -103,7 +102,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
return pool
def request_url(self, request, proxies):
def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -15,6 +15,7 @@ import functools
import io
import time
import traceback
import typing as t
PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name
@ -29,6 +30,13 @@ except ImportError:
else:
PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer, Callable
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
ERROR_PIPE_BUSY = 0xE7
SECURITY_SQOS_PRESENT = 0x100000
@ -37,10 +45,12 @@ SECURITY_ANONYMOUS = 0
MAXIMUM_RETRY_COUNT = 10
def check_closed(f):
def check_closed(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if self._closed:
def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if self._closed: # type: ignore
raise RuntimeError("Can not reuse socket after connection was closed.")
return f(self, *args, **kwargs)
@ -54,25 +64,25 @@ class NpipeSocket:
implemented.
"""
def __init__(self, handle=None):
def __init__(self, handle=None) -> None:
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
self._handle = handle
self._address = None
self._address: str | None = None
self._closed = False
self.flags = None
self.flags: int | None = None
def accept(self):
def accept(self) -> t.NoReturn:
raise NotImplementedError()
def bind(self, address):
def bind(self, address) -> t.NoReturn:
raise NotImplementedError()
def close(self):
def close(self) -> None:
self._handle.Close()
self._closed = True
@check_closed
def connect(self, address, retry_count=0):
def connect(self, address, retry_count: int = 0) -> None:
try:
handle = win32file.CreateFile(
address,
@ -100,14 +110,14 @@ class NpipeSocket:
return self.connect(address, retry_count)
raise e
self.flags = win32pipe.GetNamedPipeInfo(handle)[0]
self.flags = win32pipe.GetNamedPipeInfo(handle)[0] # type: ignore
self._handle = handle
self._address = address
@check_closed
def connect_ex(self, address):
return self.connect(address)
def connect_ex(self, address) -> None:
self.connect(address)
@check_closed
def detach(self):
@ -115,25 +125,25 @@ class NpipeSocket:
return self._handle
@check_closed
def dup(self):
def dup(self) -> NpipeSocket:
return NpipeSocket(self._handle)
def getpeername(self):
def getpeername(self) -> str | None:
return self._address
def getsockname(self):
def getsockname(self) -> str | None:
return self._address
def getsockopt(self, level, optname, buflen=None):
def getsockopt(self, level, optname, buflen=None) -> t.NoReturn:
raise NotImplementedError()
def ioctl(self, control, option):
def ioctl(self, control, option) -> t.NoReturn:
raise NotImplementedError()
def listen(self, backlog):
def listen(self, backlog) -> t.NoReturn:
raise NotImplementedError()
def makefile(self, mode=None, bufsize=None):
def makefile(self, mode: str, bufsize: int | None = None):
if mode.strip("b") != "r":
raise NotImplementedError()
rawio = NpipeFileIOBase(self)
@ -142,30 +152,30 @@ class NpipeSocket:
return io.BufferedReader(rawio, buffer_size=bufsize)
@check_closed
def recv(self, bufsize, flags=0):
def recv(self, bufsize: int, flags: int = 0) -> str:
dummy_err, data = win32file.ReadFile(self._handle, bufsize)
return data
@check_closed
def recvfrom(self, bufsize, flags=0):
def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[str, str | None]:
data = self.recv(bufsize, flags)
return (data, self._address)
@check_closed
def recvfrom_into(self, buf, nbytes=0, flags=0):
return self.recv_into(buf, nbytes, flags), self._address
def recvfrom_into(
self, buf: Buffer, nbytes: int = 0, flags: int = 0
) -> tuple[int, str | None]:
return self.recv_into(buf, nbytes), self._address
@check_closed
def recv_into(self, buf, nbytes=0):
readbuf = buf
if not isinstance(buf, memoryview):
readbuf = memoryview(buf)
def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
event = win32event.CreateEvent(None, True, True, None)
try:
overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event
dummy_err, dummy_data = win32file.ReadFile(
dummy_err, dummy_data = win32file.ReadFile( # type: ignore
self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped
)
wait_result = win32event.WaitForSingleObject(event, self._timeout)
@ -177,12 +187,12 @@ class NpipeSocket:
win32api.CloseHandle(event)
@check_closed
def send(self, string, flags=0):
def send(self, string: Buffer, flags: int = 0) -> int:
event = win32event.CreateEvent(None, True, True, None)
try:
overlapped = pywintypes.OVERLAPPED()
overlapped.hEvent = event
win32file.WriteFile(self._handle, string, overlapped)
win32file.WriteFile(self._handle, string, overlapped) # type: ignore
wait_result = win32event.WaitForSingleObject(event, self._timeout)
if wait_result == win32event.WAIT_TIMEOUT:
win32file.CancelIo(self._handle)
@ -192,20 +202,20 @@ class NpipeSocket:
win32api.CloseHandle(event)
@check_closed
def sendall(self, string, flags=0):
def sendall(self, string: Buffer, flags: int = 0) -> int:
return self.send(string, flags)
@check_closed
def sendto(self, string, address):
def sendto(self, string: Buffer, address: str) -> int:
self.connect(address)
return self.send(string)
def setblocking(self, flag):
def setblocking(self, flag: bool):
if flag:
return self.settimeout(None)
return self.settimeout(0)
def settimeout(self, value):
def settimeout(self, value: int | float | None) -> None:
if value is None:
# Blocking mode
self._timeout = win32event.INFINITE
@ -215,39 +225,39 @@ class NpipeSocket:
# Timeout mode - Value converted to milliseconds
self._timeout = int(value * 1000)
def gettimeout(self):
def gettimeout(self) -> int | float | None:
return self._timeout
def setsockopt(self, level, optname, value):
def setsockopt(self, level, optname, value) -> t.NoReturn:
raise NotImplementedError()
@check_closed
def shutdown(self, how):
def shutdown(self, how) -> None:
return self.close()
class NpipeFileIOBase(io.RawIOBase):
def __init__(self, npipe_socket):
def __init__(self, npipe_socket) -> None:
self.sock = npipe_socket
def close(self):
def close(self) -> None:
super().close()
self.sock = None
def fileno(self):
def fileno(self) -> int:
return self.sock.fileno()
def isatty(self):
def isatty(self) -> bool:
return False
def readable(self):
def readable(self) -> bool:
return True
def readinto(self, buf):
def readinto(self, buf: Buffer) -> int:
return self.sock.recv_into(buf)
def seekable(self):
def seekable(self) -> bool:
return False
def writable(self):
def writable(self) -> bool:
return False

View File

@ -17,6 +17,7 @@ import signal
import socket
import subprocess
import traceback
import typing as t
from queue import Empty
from urllib.parse import urlparse
@ -33,12 +34,15 @@ except ImportError:
else:
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
if t.TYPE_CHECKING:
from collections.abc import Buffer
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class SSHSocket(socket.socket):
def __init__(self, host):
def __init__(self, host: str) -> None:
super().__init__(socket.AF_INET, socket.SOCK_STREAM)
self.host = host
self.port = None
@ -48,9 +52,9 @@ class SSHSocket(socket.socket):
if "@" in self.host:
self.user, self.host = self.host.split("@")
self.proc = None
self.proc: subprocess.Popen | None = None
def connect(self, **kwargs):
def connect(self, *args_: t.Any, **kwargs: t.Any) -> None:
args = ["ssh"]
if self.user:
args = args + ["-l", self.user]
@ -82,37 +86,48 @@ class SSHSocket(socket.socket):
preexec_fn=preexec_func,
)
def _write(self, data):
if not self.proc or self.proc.stdin.closed:
def _write(self, data: Buffer) -> int:
if not self.proc:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first."
)
assert self.proc.stdin is not None
if self.proc.stdin.closed:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first after close()."
)
written = self.proc.stdin.write(data)
self.proc.stdin.flush()
return written
def sendall(self, data):
def sendall(self, data: Buffer, *args, **kwargs) -> None:
self._write(data)
def send(self, data):
def send(self, data: Buffer, *args, **kwargs) -> int:
return self._write(data)
def recv(self, n):
def recv(self, n: int, *args, **kwargs) -> bytes:
if not self.proc:
raise RuntimeError(
"SSH subprocess not initiated. connect() must be called first."
)
assert self.proc.stdout is not None
return self.proc.stdout.read(n)
def makefile(self, mode):
def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore
if not self.proc:
self.connect()
self.proc.stdout.channel = self
assert self.proc is not None
assert self.proc.stdout is not None
self.proc.stdout.channel = self # type: ignore
return self.proc.stdout
def close(self):
if not self.proc or self.proc.stdin.closed:
def close(self) -> None:
if not self.proc:
return
assert self.proc.stdin is not None
if self.proc.stdin.closed:
return
self.proc.stdin.write(b"\n\n")
self.proc.stdin.flush()
@ -120,13 +135,19 @@ class SSHSocket(socket.socket):
class SSHConnection(urllib3_connection.HTTPConnection):
def __init__(self, ssh_transport=None, timeout=60, host=None):
def __init__(
self,
*,
ssh_transport=None,
timeout: int = 60,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout)
self.ssh_transport = ssh_transport
self.timeout = timeout
self.ssh_host = host
def connect(self):
def connect(self) -> None:
if self.ssh_transport:
sock = self.ssh_transport.open_session()
sock.settimeout(self.timeout)
@ -142,7 +163,14 @@ class SSHConnection(urllib3_connection.HTTPConnection):
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
scheme = "ssh"
def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
def __init__(
self,
*,
ssh_client: paramiko.SSHClient | None = None,
timeout: int = 60,
maxsize: int = 10,
host: str,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.ssh_transport = None
self.timeout = timeout
@ -150,13 +178,17 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
self.ssh_transport = ssh_client.get_transport()
self.ssh_host = host
def _new_conn(self):
return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host)
def _new_conn(self) -> SSHConnection:
return SSHConnection(
ssh_transport=self.ssh_transport,
timeout=self.timeout,
host=self.ssh_host,
)
# When re-using connections, urllib3 calls fileno() on our
# SSH channel instance, quickly overloading our fd limit. To avoid this,
# we override _get_conn
def _get_conn(self, timeout):
def _get_conn(self, timeout: int) -> SSHConnection:
conn = None
try:
conn = self.pool.get(block=self.block, timeout=timeout)
@ -176,7 +208,6 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class SSHHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [
"pools",
"timeout",
@ -187,13 +218,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
def __init__(
self,
base_url,
timeout=60,
pool_connections=constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
shell_out=False,
):
self.ssh_client = None
base_url: str,
timeout: int = 60,
pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
shell_out: bool = False,
) -> None:
self.ssh_client: paramiko.SSHClient | None = None
if not shell_out:
self._create_paramiko_client(base_url)
self._connect()
@ -209,30 +240,31 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
)
super().__init__()
def _create_paramiko_client(self, base_url):
def _create_paramiko_client(self, base_url: str) -> None:
logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh_client = paramiko.SSHClient()
base_url = urlparse(base_url)
self.ssh_params = {
"hostname": base_url.hostname,
"port": base_url.port,
"username": base_url.username,
base_url_p = urlparse(base_url)
assert base_url_p.hostname is not None
self.ssh_params: dict[str, t.Any] = {
"hostname": base_url_p.hostname,
"port": base_url_p.port,
"username": base_url_p.username,
}
ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig()
with open(ssh_config_file, "rt", encoding="utf-8") as f:
conf.parse(f)
host_config = conf.lookup(base_url.hostname)
host_config = conf.lookup(base_url_p.hostname)
if "proxycommand" in host_config:
self.ssh_params["sock"] = paramiko.ProxyCommand(
host_config["proxycommand"]
)
if "hostname" in host_config:
self.ssh_params["hostname"] = host_config["hostname"]
if base_url.port is None and "port" in host_config:
if base_url_p.port is None and "port" in host_config:
self.ssh_params["port"] = host_config["port"]
if base_url.username is None and "user" in host_config:
if base_url_p.username is None and "user" in host_config:
self.ssh_params["username"] = host_config["user"]
if "identityfile" in host_config:
self.ssh_params["key_filename"] = host_config["identityfile"]
@ -240,11 +272,11 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
self.ssh_client.load_system_host_keys()
self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())
def _connect(self):
def _connect(self) -> None:
if self.ssh_client:
self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url, proxies=None):
def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool:
if not self.ssh_client:
return SSHConnectionPool(
ssh_client=self.ssh_client,
@ -271,7 +303,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
return pool
def close(self):
def close(self) -> None:
super().close()
if self.ssh_client:
self.ssh_client.close()

View File

@ -11,9 +11,7 @@
from __future__ import annotations
from ansible_collections.community.docker.plugins.module_utils._version import (
LooseVersion,
)
import typing as t
from .._import_helper import HTTPAdapter, urllib3
from .basehttpadapter import BaseHTTPAdapter
@ -30,14 +28,19 @@ PoolManager = urllib3.poolmanager.PoolManager
class SSLHTTPAdapter(BaseHTTPAdapter):
"""An HTTPS Transport Adapter that uses an arbitrary SSL version."""
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname", "ssl_version"]
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname"]
def __init__(self, ssl_version=None, assert_hostname=None, **kwargs):
self.ssl_version = ssl_version
def __init__(
self,
assert_hostname: bool | None = None,
**kwargs,
) -> None:
self.assert_hostname = assert_hostname
super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False):
def init_poolmanager(
self, connections: int, maxsize: int, block: bool = False, **kwargs: t.Any
) -> None:
kwargs = {
"num_pools": connections,
"maxsize": maxsize,
@ -45,12 +48,10 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
}
if self.assert_hostname is not None:
kwargs["assert_hostname"] = self.assert_hostname
if self.ssl_version and self.can_override_ssl_version():
kwargs["ssl_version"] = self.ssl_version
self.poolmanager = PoolManager(**kwargs)
def get_connection(self, *args, **kwargs):
def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool:
"""
Ensure assert_hostname is set correctly on our pool
@ -61,15 +62,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
conn = super().get_connection(*args, **kwargs)
if (
self.assert_hostname is not None
and conn.assert_hostname != self.assert_hostname
and conn.assert_hostname != self.assert_hostname # type: ignore
):
conn.assert_hostname = self.assert_hostname
conn.assert_hostname = self.assert_hostname # type: ignore
return conn
def can_override_ssl_version(self):
urllib_ver = urllib3.__version__.split("-")[0]
if urllib_ver is None:
return False
if urllib_ver == "dev":
return True
return LooseVersion(urllib_ver) > LooseVersion("1.5")

View File

@ -12,6 +12,7 @@
from __future__ import annotations
import socket
import typing as t
from .. import constants
from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
@ -22,26 +23,25 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class UnixHTTPConnection(urllib3_connection.HTTPConnection):
def __init__(self, base_url, unix_socket, timeout=60):
def __init__(self, base_url: str | bytes, unix_socket, timeout: int = 60) -> None:
super().__init__("localhost", timeout=timeout)
self.base_url = base_url
self.unix_socket = unix_socket
self.timeout = timeout
self.disable_buffering = False
def connect(self):
def connect(self) -> None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect(self.unix_socket)
self.sock = sock
def putheader(self, header, *values):
def putheader(self, header: str, *values: str) -> None:
super().putheader(header, *values)
if header == "Connection" and "Upgrade" in values:
self.disable_buffering = True
def response_class(self, sock, *args, **kwargs):
def response_class(self, sock, *args, **kwargs) -> t.Any:
# FIXME: We may need to disable buffering on Py3,
# but there's no clear way to do it at the moment. See:
# https://github.com/docker/docker-py/issues/1799
@ -49,18 +49,23 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, base_url, socket_path, timeout=60, maxsize=10):
def __init__(
self,
base_url: str | bytes,
socket_path: str,
timeout: int = 60,
maxsize: int = 10,
) -> None:
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
self.base_url = base_url
self.socket_path = socket_path
self.timeout = timeout
def _new_conn(self):
def _new_conn(self) -> UnixHTTPConnection:
return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout)
class UnixHTTPAdapter(BaseHTTPAdapter):
__attrs__ = HTTPAdapter.__attrs__ + [
"pools",
"socket_path",
@ -70,11 +75,11 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
def __init__(
self,
socket_url,
timeout=60,
pool_connections=constants.DEFAULT_NUM_POOLS,
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
):
socket_url: str,
timeout: int = 60,
pool_connections: int = constants.DEFAULT_NUM_POOLS,
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
) -> None:
socket_path = socket_url.replace("http+unix://", "")
if not socket_path.startswith("/"):
socket_path = "/" + socket_path
@ -86,7 +91,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
)
super().__init__()
def get_connection(self, url, proxies=None):
def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool:
with self.pools.lock:
pool = self.pools.get(url)
if pool:
@ -99,7 +104,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
return pool
def request_url(self, request, proxies):
def request_url(self, request, proxies) -> str:
# The select_proxy utility in requests errors out when the provided URL
# does not have a hostname, like is the case when using a UNIX socket.
# Since proxies are an irrelevant notion in the case of UNIX sockets

View File

@ -12,6 +12,7 @@
from __future__ import annotations
import socket
import typing as t
from .._import_helper import urllib3
from ..errors import DockerException
@ -29,11 +30,11 @@ class CancellableStream:
>>> events.close()
"""
def __init__(self, stream, response):
def __init__(self, stream, response) -> None:
self._stream = stream
self._response = response
def __iter__(self):
def __iter__(self) -> t.Self:
return self
def __next__(self):
@ -46,7 +47,7 @@ class CancellableStream:
next = __next__
def close(self):
def close(self) -> None:
"""
Closes the event streaming.
"""

View File

@ -17,15 +17,26 @@ import random
import re
import tarfile
import tempfile
import typing as t
from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX
from . import fnmatch
if t.TYPE_CHECKING:
from collections.abc import Sequence
_SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/")
def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
def tar(
path: str,
exclude: list[str] | None = None,
dockerfile: tuple[str, str] | tuple[None, None] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
) -> t.IO[bytes]:
root = os.path.abspath(path)
exclude = exclude or []
dockerfile = dockerfile or (None, None)
@ -47,7 +58,9 @@ def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
)
def exclude_paths(root, patterns, dockerfile=None):
def exclude_paths(
root: str, patterns: list[str], dockerfile: str | None = None
) -> set[str]:
"""
Given a root directory path and a list of .dockerignore patterns, return
an iterator of all paths (both regular files and directories) in the root
@ -64,7 +77,7 @@ def exclude_paths(root, patterns, dockerfile=None):
return set(pm.walk(root))
def build_file_list(root):
def build_file_list(root: str) -> list[str]:
files = []
for dirname, dirnames, fnames in os.walk(root):
for filename in fnames + dirnames:
@ -74,7 +87,13 @@ def build_file_list(root):
return files
def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None):
def create_archive(
root: str,
files: Sequence[str] | None = None,
fileobj: t.IO[bytes] | None = None,
gzip: bool = False,
extra_files: Sequence[tuple[str, str]] | None = None,
) -> t.IO[bytes]:
extra_files = extra_files or []
if not fileobj:
fileobj = tempfile.NamedTemporaryFile()
@ -92,7 +111,7 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
if i is None:
# This happens when we encounter a socket file. We can safely
# ignore it and proceed.
continue
continue # type: ignore
# Workaround https://bugs.python.org/issue32713
if i.mtime < 0 or i.mtime > 8**11 - 1:
@ -124,11 +143,11 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
return fileobj
def mkbuildcontext(dockerfile):
def mkbuildcontext(dockerfile: io.BytesIO | t.IO[bytes]) -> t.IO[bytes]:
f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with
try:
with tarfile.open(mode="w", fileobj=f) as t:
if isinstance(dockerfile, io.StringIO):
if isinstance(dockerfile, io.StringIO): # type: ignore
raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles")
if isinstance(dockerfile, io.BytesIO):
dfinfo = tarfile.TarInfo("Dockerfile")
@ -144,17 +163,17 @@ def mkbuildcontext(dockerfile):
return f
def split_path(p):
def split_path(p: str) -> list[str]:
return [pt for pt in re.split(_SEP, p) if pt and pt != "."]
def normalize_slashes(p):
def normalize_slashes(p: str) -> str:
if IS_WINDOWS_PLATFORM:
return "/".join(split_path(p))
return p
def walk(root, patterns, default=True):
def walk(root: str, patterns: Sequence[str], default: bool = True) -> t.Generator[str]:
pm = PatternMatcher(patterns)
return pm.walk(root)
@ -162,11 +181,11 @@ def walk(root, patterns, default=True):
# Heavily based on
# https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go
class PatternMatcher:
def __init__(self, patterns):
def __init__(self, patterns: Sequence[str]) -> None:
self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns]))
self.patterns.append(Pattern("!.dockerignore"))
def matches(self, filepath):
def matches(self, filepath: str) -> bool:
matched = False
parent_path = os.path.dirname(filepath)
parent_path_dirs = split_path(parent_path)
@ -185,8 +204,8 @@ class PatternMatcher:
return matched
def walk(self, root):
def rec_walk(current_dir):
def walk(self, root: str) -> t.Generator[str]:
def rec_walk(current_dir: str) -> t.Generator[str]:
for f in os.listdir(current_dir):
fpath = os.path.join(os.path.relpath(current_dir, root), f)
if fpath.startswith("." + os.path.sep):
@ -220,7 +239,7 @@ class PatternMatcher:
class Pattern:
def __init__(self, pattern_str):
def __init__(self, pattern_str: str) -> None:
self.exclusion = False
if pattern_str.startswith("!"):
self.exclusion = True
@ -230,8 +249,7 @@ class Pattern:
self.cleaned_pattern = "/".join(self.dirs)
@classmethod
def normalize(cls, p):
def normalize(cls, p: str) -> list[str]:
# Remove trailing spaces
p = p.strip()
@ -256,11 +274,11 @@ class Pattern:
i += 1
return split
def match(self, filepath):
def match(self, filepath: str) -> bool:
return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern)
def process_dockerfile(dockerfile, path):
def process_dockerfile(dockerfile: str, path: str) -> tuple[str | None, str | None]:
if not dockerfile:
return (None, None)
@ -268,7 +286,7 @@ def process_dockerfile(dockerfile, path):
if not os.path.isabs(dockerfile):
abs_dockerfile = os.path.join(path, dockerfile)
if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX):
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX):])}"
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX) :])}"
if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[
0
] or os.path.relpath(abs_dockerfile, path).startswith(".."):

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import json
import logging
import os
import typing as t
from ..constants import IS_WINDOWS_PLATFORM
@ -24,11 +25,11 @@ LEGACY_DOCKER_CONFIG_FILENAME = ".dockercfg"
log = logging.getLogger(__name__)
def get_default_config_file():
def get_default_config_file() -> str:
return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME)
def find_config_file(config_path=None):
def find_config_file(config_path: str | None = None) -> str | None:
homedir = home_dir()
paths = list(
filter(
@ -54,14 +55,14 @@ def find_config_file(config_path=None):
return None
def config_path_from_environment():
def config_path_from_environment() -> str | None:
config_dir = os.environ.get("DOCKER_CONFIG")
if not config_dir:
return None
return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME))
def home_dir():
def home_dir() -> str:
"""
Get the user's home directory, using the same logic as the Docker Engine
client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX.
@ -71,7 +72,7 @@ def home_dir():
return os.path.expanduser("~")
def load_general_config(config_path=None):
def load_general_config(config_path: str | None = None) -> dict[str, t.Any]:
config_file = find_config_file(config_path)
if not config_file:

View File

@ -12,16 +12,37 @@
from __future__ import annotations
import functools
import typing as t
from .. import errors
from . import utils
def minimum_version(version):
def decorator(f):
if t.TYPE_CHECKING:
from collections.abc import Callable
from ..api.client import APIClient
_Self = t.TypeVar("_Self")
_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")
def minimum_version(
version: str,
) -> Callable[
[Callable[t.Concatenate[_Self, _P], _R]],
Callable[t.Concatenate[_Self, _P], _R],
]:
def decorator(
f: Callable[t.Concatenate[_Self, _P], _R],
) -> Callable[t.Concatenate[_Self, _P], _R]:
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
if utils.version_lt(self._version, version):
def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
# We use _Self instead of APIClient since this is used for mixins for APIClient.
# This unfortunately means that self._version does not exist in the mixin,
# it only exists after mixing in. This is why we ignore types here.
if utils.version_lt(self._version, version): # type: ignore
raise errors.InvalidVersion(
f"{f.__name__} is not available for version < {version}"
)
@ -32,13 +53,16 @@ def minimum_version(version):
return decorator
def update_headers(f):
def inner(self, *args, **kwargs):
def update_headers(
f: Callable[t.Concatenate[APIClient, _P], _R],
) -> Callable[t.Concatenate[APIClient, _P], _R]:
def inner(self: APIClient, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if "HttpHeaders" in self._general_configs:
if not kwargs.get("headers"):
kwargs["headers"] = self._general_configs["HttpHeaders"]
else:
kwargs["headers"].update(self._general_configs["HttpHeaders"])
# We cannot (yet) model that kwargs["headers"] should be a dictionary
kwargs["headers"].update(self._general_configs["HttpHeaders"]) # type: ignore
return f(self, *args, **kwargs)
return inner

View File

@ -32,12 +32,12 @@ _cache: dict[str, re.Pattern] = {}
_MAXCACHE = 100
def _purge():
def _purge() -> None:
"""Clear the pattern cache"""
_cache.clear()
def fnmatch(name, pat):
def fnmatch(name: str, pat: str):
"""Test whether FILENAME matches PATTERN.
Patterns are Unix shell style:
@ -58,7 +58,7 @@ def fnmatch(name, pat):
return fnmatchcase(name, pat)
def fnmatchcase(name, pat):
def fnmatchcase(name: str, pat: str) -> bool:
"""Test whether FILENAME matches PATTERN, including case.
This is a version of fnmatch() which does not case-normalize
its arguments.
@ -74,7 +74,7 @@ def fnmatchcase(name, pat):
return re_pat.match(name) is not None
def translate(pat):
def translate(pat: str) -> str:
"""Translate a shell PATTERN to a regular expression.
There is no way to quote meta-characters.

View File

@ -13,14 +13,22 @@ from __future__ import annotations
import json
import json.decoder
import typing as t
from ..errors import StreamParseError
if t.TYPE_CHECKING:
import re
from collections.abc import Callable
_T = t.TypeVar("_T")
json_decoder = json.JSONDecoder()
def stream_as_text(stream):
def stream_as_text(stream: t.Generator[bytes | str]) -> t.Generator[str]:
"""
Given a stream of bytes or text, if any of the items in the stream
are bytes convert them to text.
@ -33,20 +41,22 @@ def stream_as_text(stream):
yield data
def json_splitter(buffer):
def json_splitter(buffer: str) -> tuple[t.Any, str] | None:
"""Attempt to parse a json object from a buffer. If there is at least one
object, return it and the rest of the buffer, otherwise return None.
"""
buffer = buffer.strip()
try:
obj, index = json_decoder.raw_decode(buffer)
rest = buffer[json.decoder.WHITESPACE.match(buffer, index).end() :]
ws: re.Pattern = json.decoder.WHITESPACE # type: ignore[attr-defined]
m = ws.match(buffer, index)
rest = buffer[m.end() :] if m else buffer[index:]
return obj, rest
except ValueError:
return None
def json_stream(stream):
def json_stream(stream: t.Generator[str | bytes]) -> t.Generator[t.Any]:
"""Given a stream of text, return a stream of json objects.
This handles streams which are inconsistently buffered (some entries may
be newline delimited, and others are not).
@ -54,21 +64,24 @@ def json_stream(stream):
return split_buffer(stream, json_splitter, json_decoder.decode)
def line_splitter(buffer, separator="\n"):
def line_splitter(buffer: str, separator: str = "\n") -> tuple[str, str] | None:
index = buffer.find(str(separator))
if index == -1:
return None
return buffer[: index + 1], buffer[index + 1 :]
def split_buffer(stream, splitter=None, decoder=lambda a: a):
def split_buffer(
stream: t.Generator[str | bytes],
splitter: Callable[[str], tuple[_T, str] | None],
decoder: Callable[[str], _T],
) -> t.Generator[_T | str]:
"""Given a generator which yields strings and a splitter function,
joins all input, splits on the separator and yields each chunk.
Unlike string.split(), each chunk includes the trailing
separator, except for the last one if none was found on the end
of the input.
"""
splitter = splitter or line_splitter
buffered = ""
for data in stream_as_text(stream):

View File

@ -12,6 +12,11 @@
from __future__ import annotations
import re
import typing as t
if t.TYPE_CHECKING:
from collections.abc import Collection, Sequence
PORT_SPEC = re.compile(
@ -26,32 +31,42 @@ PORT_SPEC = re.compile(
)
def add_port_mapping(port_bindings, internal_port, external):
def add_port_mapping(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port: str,
external: str | tuple[str, str | None] | None,
) -> None:
if internal_port in port_bindings:
port_bindings[internal_port].append(external)
else:
port_bindings[internal_port] = [external]
def add_port(port_bindings, internal_port_range, external_range):
def add_port(
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
internal_port_range: list[str],
external_range: list[str] | list[tuple[str, str | None]] | None,
) -> None:
if external_range is None:
for internal_port in internal_port_range:
add_port_mapping(port_bindings, internal_port, None)
else:
ports = zip(internal_port_range, external_range)
for internal_port, external_port in ports:
add_port_mapping(port_bindings, internal_port, external_port)
for internal_port, external_port in zip(internal_port_range, external_range):
# mypy loses the exact type of eternal_port elements for some reason...
add_port_mapping(port_bindings, internal_port, external_port) # type: ignore
def build_port_bindings(ports):
port_bindings = {}
def build_port_bindings(
ports: Collection[str],
) -> dict[str, list[str | tuple[str, str | None] | None]]:
port_bindings: dict[str, list[str | tuple[str, str | None] | None]] = {}
for port in ports:
internal_port_range, external_range = split_port(port)
add_port(port_bindings, internal_port_range, external_range)
return port_bindings
def _raise_invalid_port(port):
def _raise_invalid_port(port: str) -> t.NoReturn:
raise ValueError(
f'Invalid port "{port}", should be '
"[[remote_ip:]remote_port[-remote_port]:]"
@ -59,39 +74,64 @@ def _raise_invalid_port(port):
)
def port_range(start, end, proto, randomly_available_port=False):
if not start:
@t.overload
def port_range(
start: str,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str]: ...
@t.overload
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None: ...
def port_range(
start: str | None,
end: str | None,
proto: str,
randomly_available_port: bool = False,
) -> list[str] | None:
if start is None:
return start
if not end:
if end is None:
return [f"{start}{proto}"]
if randomly_available_port:
return [f"{start}-{end}{proto}"]
return [f"{port}{proto}" for port in range(int(start), int(end) + 1)]
def split_port(port):
if hasattr(port, "legacy_repr"):
# This is the worst hack, but it prevents a bug in Compose 1.14.0
# https://github.com/docker/docker-py/issues/1668
# TODO: remove once fixed in Compose stable
port = port.legacy_repr()
def split_port(
port: str,
) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]:
port = str(port)
match = PORT_SPEC.match(port)
if match is None:
_raise_invalid_port(port)
parts = match.groupdict()
host = parts["host"]
proto = parts["proto"] or ""
internal = port_range(parts["int"], parts["int_end"], proto)
external = port_range(parts["ext"], parts["ext_end"], "", len(internal) == 1)
host: str | None = parts["host"]
proto: str = parts["proto"] or ""
int_p: str = parts["int"]
ext_p: str | None = parts["ext"] or None
internal: list[str] = port_range(int_p, parts["int_end"], proto) # type: ignore
external = port_range(ext_p, parts["ext_end"], "", len(internal) == 1)
if host is None:
if external is not None and len(internal) != len(external):
raise ValueError("Port ranges don't match in length")
return internal, external
external_or_none: Sequence[str | None]
if not external:
external = [None] * len(internal)
elif len(internal) != len(external):
raise ValueError("Port ranges don't match in length")
return internal, [(host, ext_port) for ext_port in external]
external_or_none = [None] * len(internal)
else:
external_or_none = external
if len(internal) != len(external_or_none):
raise ValueError("Port ranges don't match in length")
return internal, [(host, ext_port) for ext_port in external_or_none]

View File

@ -20,23 +20,23 @@ class ProxyConfig(dict):
"""
@property
def http(self):
def http(self) -> str | None:
return self.get("http")
@property
def https(self):
def https(self) -> str | None:
return self.get("https")
@property
def ftp(self):
def ftp(self) -> str | None:
return self.get("ftp")
@property
def no_proxy(self):
def no_proxy(self) -> str | None:
return self.get("no_proxy")
@staticmethod
def from_dict(config):
def from_dict(config: dict[str, str]) -> ProxyConfig:
"""
Instantiate a new ProxyConfig from a dictionary that represents a
client configuration, as described in `the documentation`_.
@ -51,7 +51,7 @@ class ProxyConfig(dict):
no_proxy=config.get("noProxy"),
)
def get_environment(self):
def get_environment(self) -> dict[str, str]:
"""
Return a dictionary representing the environment variables used to
set the proxy settings.
@ -67,7 +67,7 @@ class ProxyConfig(dict):
env["no_proxy"] = env["NO_PROXY"] = self.no_proxy
return env
def inject_proxy_environment(self, environment):
def inject_proxy_environment(self, environment: list[str]) -> list[str]:
"""
Given a list of strings representing environment variables, prepend the
environment variables corresponding to the proxy settings.
@ -82,5 +82,5 @@ class ProxyConfig(dict):
# variables defined in "environment" to take precedence.
return proxy_env + environment
def __str__(self):
def __str__(self) -> str:
return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})"

View File

@ -16,10 +16,15 @@ import os
import select
import socket as pysocket
import struct
import typing as t
from ..transport.npipesocket import NpipeSocket
if t.TYPE_CHECKING:
from collections.abc import Iterable
STDOUT = 1
STDERR = 2
@ -33,7 +38,7 @@ class SocketError(Exception):
NPIPE_ENDED = 109
def read(socket, n=4096):
def read(socket, n: int = 4096) -> bytes | None:
"""
Reads at most n bytes from socket
"""
@ -58,6 +63,7 @@ def read(socket, n=4096):
except EnvironmentError as e:
if e.errno not in recoverable_errors:
raise
return None # TODO ???
except Exception as e:
is_pipe_ended = (
isinstance(socket, NpipeSocket)
@ -67,11 +73,11 @@ def read(socket, n=4096):
if is_pipe_ended:
# npipes do not support duplex sockets, so we interpret
# a PIPE_ENDED error as a close operation (0-length read).
return ""
return b""
raise
def read_exactly(socket, n):
def read_exactly(socket, n: int) -> bytes:
"""
Reads exactly n bytes from socket
Raises SocketError if there is not enough data
@ -85,7 +91,7 @@ def read_exactly(socket, n):
return data
def next_frame_header(socket):
def next_frame_header(socket) -> tuple[int, int]:
"""
Returns the stream and size of the next frame of data waiting to be read
from socket, according to the protocol defined here:
@ -101,7 +107,7 @@ def next_frame_header(socket):
return (stream, actual)
def frames_iter(socket, tty):
def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
"""
Return a generator of frames read from socket. A frame is a tuple where
the first item is the stream number and the second item is a chunk of data.
@ -114,7 +120,7 @@ def frames_iter(socket, tty):
return frames_iter_no_tty(socket)
def frames_iter_no_tty(socket):
def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
"""
Returns a generator of data read from the socket when the tty setting is
not enabled.
@ -135,20 +141,34 @@ def frames_iter_no_tty(socket):
yield (stream, result)
def frames_iter_tty(socket):
def frames_iter_tty(socket) -> t.Generator[bytes]:
"""
Return a generator of data read from the socket when the tty setting is
enabled.
"""
while True:
result = read(socket)
if len(result) == 0:
if not result:
# We have reached EOF
return
yield result
def consume_socket_output(frames, demux=False):
@t.overload
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ...
@t.overload
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
@t.overload
def consume_socket_output(
frames, demux: bool = False
) -> bytes | tuple[bytes, bytes]: ...
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]:
"""
Iterate through frames read from the socket and return the result.
@ -167,7 +187,7 @@ def consume_socket_output(frames, demux=False):
# If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr)
out = [None, None]
out: list[bytes | None] = [None, None]
for frame in frames:
# It is guaranteed that for each frame, one and only one stream
# is not None.
@ -183,10 +203,10 @@ def consume_socket_output(frames, demux=False):
out[1] = frame[1]
else:
out[1] += frame[1]
return tuple(out)
return tuple(out) # type: ignore
def demux_adaptor(stream_id, data):
def demux_adaptor(stream_id: int, data: bytes) -> tuple[bytes | None, bytes | None]:
"""
Utility to demultiplex stdout and stderr when reading frames from the
socket.

View File

@ -18,6 +18,7 @@ import os
import os.path
import shlex
import string
import typing as t
from urllib.parse import urlparse, urlunparse
from ansible_collections.community.docker.plugins.module_utils._version import (
@ -34,32 +35,23 @@ from ..constants import (
from ..tls import TLSConfig
if t.TYPE_CHECKING:
import ssl
from collections.abc import Mapping, Sequence
URLComponents = collections.namedtuple(
"URLComponents",
"scheme netloc url params query fragment",
)
def create_ipam_pool(*args, **kwargs):
raise errors.DeprecatedMethod(
"utils.create_ipam_pool has been removed. Please use a "
"docker.types.IPAMPool object instead."
)
def create_ipam_config(*args: t.Any, **kwargs: t.Any) -> t.NoReturn:
    """Removed helper kept only to fail loudly.

    Always raises ``errors.DeprecatedMethod``; construct a
    ``docker.types.IPAMConfig`` object instead.
    """
    raise errors.DeprecatedMethod(
        "utils.create_ipam_config has been removed. Please use a "
        "docker.types.IPAMConfig object instead."
    )
def decode_json_header(header: str) -> dict[str, t.Any]:
    """Decode a base64-encoded JSON header value into a dict.

    The header is base64-decoded to UTF-8 text and then parsed as JSON.
    Raises ``binascii.Error`` on invalid base64 and ``json.JSONDecodeError``
    on invalid JSON.
    """
    data = base64.b64decode(header).decode("utf-8")
    return json.loads(data)
def compare_version(v1, v2):
def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]:
"""Compare docker versions
>>> v1 = '1.9'
@ -80,43 +72,64 @@ def compare_version(v1, v2):
return 1
def version_lt(v1: str, v2: str) -> bool:
    """Return True if docker version ``v1`` is strictly lower than ``v2``.

    Delegates to ``compare_version``, which returns a positive value when
    ``v1`` precedes ``v2``.
    """
    return compare_version(v1, v2) > 0
def version_gte(v1: str, v2: str) -> bool:
    """Return True if docker version ``v1`` is greater than or equal to ``v2``.

    Defined as the negation of ``version_lt`` so the two stay consistent.
    """
    return not version_lt(v1, v2)
def _convert_port_binding(binding):
def _convert_port_binding(
binding: (
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
),
) -> dict[str, str]:
result = {"HostIp": "", "HostPort": ""}
host_port: str | int | None = ""
if isinstance(binding, tuple):
if len(binding) == 2:
result["HostPort"] = binding[1]
host_port = binding[1] # type: ignore
result["HostIp"] = binding[0]
elif isinstance(binding[0], str):
result["HostIp"] = binding[0]
else:
result["HostPort"] = binding[0]
host_port = binding[0]
elif isinstance(binding, dict):
if "HostPort" in binding:
result["HostPort"] = binding["HostPort"]
host_port = binding["HostPort"]
if "HostIp" in binding:
result["HostIp"] = binding["HostIp"]
else:
raise ValueError(binding)
else:
result["HostPort"] = binding
if result["HostPort"] is None:
result["HostPort"] = ""
else:
result["HostPort"] = str(result["HostPort"])
host_port = binding
result["HostPort"] = str(host_port) if host_port is not None else ""
return result
def convert_port_bindings(port_bindings):
def convert_port_bindings(
port_bindings: dict[
str | int,
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
| list[
tuple[str, str | int | None]
| tuple[str | int | None]
| dict[str, str]
| str
| int
],
],
) -> dict[str, list[dict[str, str]]]:
result = {}
for k, v in port_bindings.items():
key = str(k)
@ -129,9 +142,11 @@ def convert_port_bindings(port_bindings):
return result
def convert_volume_binds(binds):
def convert_volume_binds(
binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int],
) -> list[str]:
if isinstance(binds, list):
return binds
return binds # type: ignore
result = []
for k, v in binds.items():
@ -149,7 +164,7 @@ def convert_volume_binds(binds):
if "ro" in v:
mode = "ro" if v["ro"] else "rw"
elif "mode" in v:
mode = v["mode"]
mode = v["mode"] # type: ignore # TODO
else:
mode = "rw"
@ -165,9 +180,9 @@ def convert_volume_binds(binds):
]
if "propagation" in v and v["propagation"] in propagation_modes:
if mode:
mode = ",".join([mode, v["propagation"]])
mode = ",".join([mode, v["propagation"]]) # type: ignore # TODO
else:
mode = v["propagation"]
mode = v["propagation"] # type: ignore # TODO
result.append(f"{k}:{bind}:{mode}")
else:
@ -177,7 +192,7 @@ def convert_volume_binds(binds):
return result
def convert_tmpfs_mounts(tmpfs):
def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]:
if isinstance(tmpfs, dict):
return tmpfs
@ -204,9 +219,11 @@ def convert_tmpfs_mounts(tmpfs):
return result
def convert_service_networks(networks):
def convert_service_networks(
networks: list[str | dict[str, str]],
) -> list[dict[str, str]]:
if not networks:
return networks
return networks # type: ignore
if not isinstance(networks, list):
raise TypeError("networks parameter must be a list.")
@ -218,17 +235,17 @@ def convert_service_networks(networks):
return result
def parse_repository_tag(repo_name: str) -> tuple[str, str | None]:
    """Split an image reference into ``(repository, tag_or_digest)``.

    A digest separator (``@``) takes precedence over a tag separator; a
    ``:`` only counts as the tag separator when the part after it contains
    no ``/`` (otherwise it is part of a registry host:port). Returns
    ``(repo_name, None)`` when neither is present.
    """
    parts = repo_name.rsplit("@", 1)
    if len(parts) == 2:
        # Returning an explicit 2-tuple keeps the declared return type exact.
        return parts[0], parts[1]
    parts = repo_name.rsplit(":", 1)
    if len(parts) == 2 and "/" not in parts[1]:
        return parts[0], parts[1]
    return repo_name, None
def parse_host(addr, is_win32=False, tls=False):
def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str:
# Sensible defaults
if not addr and is_win32:
return DEFAULT_NPIPE
@ -308,7 +325,7 @@ def parse_host(addr, is_win32=False, tls=False):
).rstrip("/")
def parse_devices(devices):
def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]:
device_list = []
for device in devices:
if isinstance(device, dict):
@ -337,7 +354,10 @@ def parse_devices(devices):
return device_list
def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
def kwargs_from_env(
assert_hostname: bool | None = None,
environment: Mapping[str, str] | None = None,
) -> dict[str, t.Any]:
if not environment:
environment = os.environ
host = environment.get("DOCKER_HOST")
@ -347,14 +367,14 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
# empty string for tls verify counts as "false".
# Any value or 'unset' counts as true.
tls_verify = environment.get("DOCKER_TLS_VERIFY")
if tls_verify == "":
tls_verify_str = environment.get("DOCKER_TLS_VERIFY")
if tls_verify_str == "":
tls_verify = False
else:
tls_verify = tls_verify is not None
tls_verify = tls_verify_str is not None
enable_tls = cert_path or tls_verify
params = {}
params: dict[str, t.Any] = {}
if host:
params["base_url"] = host
@ -377,14 +397,13 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
),
ca_cert=os.path.join(cert_path, "ca.pem"),
verify=tls_verify,
ssl_version=ssl_version,
assert_hostname=assert_hostname,
)
return params
def convert_filters(filters):
def convert_filters(filters: dict[str, bool | str | list[str]]) -> str:
result = {}
for k, v in filters.items():
if isinstance(v, bool):
@ -397,7 +416,7 @@ def convert_filters(filters):
return json.dumps(result)
def parse_bytes(s):
def parse_bytes(s: int | float | str) -> int | float:
if isinstance(s, (int, float)):
return s
if len(s) == 0:
@ -435,14 +454,16 @@ def parse_bytes(s):
return s
def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]:
    """Render container links as sorted ``"name:alias"`` strings.

    Accepts either a mapping of name -> alias or a sequence of
    ``(name, alias)`` pairs. A falsy alias yields just the name.
    """
    if isinstance(links, dict):
        sorted_links = sorted(links.items())
    else:
        sorted_links = sorted(links)
    return [f"{k}:{v}" if v else k for k, v in sorted_links]
def parse_env_file(env_file):
def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]:
"""
Reads a line-separated environment file.
The format of each line should be "key=value".
@ -451,7 +472,6 @@ def parse_env_file(env_file):
with open(env_file, "rt", encoding="utf-8") as f:
for line in f:
if line[0] == "#":
continue
@ -471,11 +491,11 @@ def parse_env_file(env_file):
return environment
def split_command(command: str) -> list[str]:
    """Split a command string into argv-style tokens using POSIX shell rules."""
    return shlex.split(command)
def format_environment(environment):
def format_environment(environment: Mapping[str, str | bytes]) -> list[str]:
def format_env(key, value):
if value is None:
return key
@ -487,16 +507,9 @@ def format_environment(environment):
return [format_env(*var) for var in environment.items()]
def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]:
    """Format extra hosts (hostname -> IP) for the Engine API, sorted by name.

    Containers use ``"hostname:IP"`` entries; Swarm tasks use ``"IP hostname"``.
    """
    # Use format dictated by Swarm API if container is part of a task
    if task:
        return [f"{v} {k}" for k, v in sorted(extra_hosts.items())]
    return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())]
def create_host_config(self, *args: t.Any, **kwargs: t.Any) -> t.NoReturn:
    """Removed helper kept only to fail loudly.

    Always raises ``errors.DeprecatedMethod``; construct a
    ``docker.types.HostConfig`` object instead.
    """
    raise errors.DeprecatedMethod(
        "utils.create_host_config has been removed. Please use a "
        "docker.types.HostConfig object instead."
    )

View File

@ -552,7 +552,7 @@ def _execute_command(
result = client.get_json("/exec/{0}/json", exec_id)
rc = result.get("ExitCode") or 0
rc: int = result.get("ExitCode") or 0
stdout = stdout or b""
stderr = stderr or b""