mirror of
https://github.com/ansible-collections/community.docker.git
synced 2026-03-16 04:04:31 +00:00
Add more type hints.
This commit is contained in:
parent
0ff66e8b24
commit
08960a9317
@ -13,8 +13,9 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
from functools import partial
|
||||
import typing as t
|
||||
from urllib.parse import quote
|
||||
|
||||
from .. import auth
|
||||
@ -50,13 +51,13 @@ from ..utils import config, json_stream, utils
|
||||
from ..utils.decorators import update_headers
|
||||
from ..utils.proxy import ProxyConfig
|
||||
from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
|
||||
from .daemon import DaemonApiMixin
|
||||
from ..utils.decorators import minimum_version
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIClient(_Session, DaemonApiMixin):
|
||||
class APIClient(_Session):
|
||||
"""
|
||||
A low-level client for the Docker Engine API.
|
||||
|
||||
@ -105,16 +106,16 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url=None,
|
||||
version=None,
|
||||
timeout=DEFAULT_TIMEOUT_SECONDS,
|
||||
tls=False,
|
||||
user_agent=DEFAULT_USER_AGENT,
|
||||
num_pools=None,
|
||||
credstore_env=None,
|
||||
use_ssh_client=False,
|
||||
max_pool_size=DEFAULT_MAX_POOL_SIZE,
|
||||
):
|
||||
base_url: str | None = None,
|
||||
version: str | None = None,
|
||||
timeout: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
tls: bool | TLSConfig = False,
|
||||
user_agent: str = DEFAULT_USER_AGENT,
|
||||
num_pools: int | None = None,
|
||||
credstore_env: dict[str, str] | None = None,
|
||||
use_ssh_client: bool = False,
|
||||
max_pool_size: int = DEFAULT_MAX_POOL_SIZE,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
fail_on_missing_imports()
|
||||
@ -152,6 +153,9 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
else DEFAULT_NUM_POOLS
|
||||
)
|
||||
|
||||
self._custom_adapter: (
|
||||
UnixHTTPAdapter | NpipeHTTPAdapter | SSHHTTPAdapter | SSLHTTPAdapter | None
|
||||
) = None
|
||||
if base_url.startswith("http+unix://"):
|
||||
self._custom_adapter = UnixHTTPAdapter(
|
||||
base_url,
|
||||
@ -223,7 +227,7 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
f"API versions below {MINIMUM_DOCKER_API_VERSION} are no longer supported by this library."
|
||||
)
|
||||
|
||||
def _retrieve_server_version(self):
|
||||
def _retrieve_server_version(self) -> str:
|
||||
try:
|
||||
version_result = self.version(api_version=False)
|
||||
except Exception as e:
|
||||
@ -242,54 +246,87 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
f"Error while fetching server API version: {e}. Response seems to be broken."
|
||||
) from e
|
||||
|
||||
def _set_request_timeout(self, kwargs):
|
||||
def _set_request_timeout(self, kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
|
||||
"""Prepare the kwargs for an HTTP request by inserting the timeout
|
||||
parameter, if not already present."""
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
return kwargs
|
||||
|
||||
@update_headers
|
||||
def _post(self, url, **kwargs):
|
||||
def _post(self, url: str, **kwargs):
|
||||
return self.post(url, **self._set_request_timeout(kwargs))
|
||||
|
||||
@update_headers
|
||||
def _get(self, url, **kwargs):
|
||||
def _get(self, url: str, **kwargs):
|
||||
return self.get(url, **self._set_request_timeout(kwargs))
|
||||
|
||||
@update_headers
|
||||
def _head(self, url, **kwargs):
|
||||
def _head(self, url: str, **kwargs):
|
||||
return self.head(url, **self._set_request_timeout(kwargs))
|
||||
|
||||
@update_headers
|
||||
def _put(self, url, **kwargs):
|
||||
def _put(self, url: str, **kwargs):
|
||||
return self.put(url, **self._set_request_timeout(kwargs))
|
||||
|
||||
@update_headers
|
||||
def _delete(self, url, **kwargs):
|
||||
def _delete(self, url: str, **kwargs):
|
||||
return self.delete(url, **self._set_request_timeout(kwargs))
|
||||
|
||||
def _url(self, pathfmt, *args, **kwargs):
|
||||
def _url(self, pathfmt: str, *args: str, versioned_api: bool = True) -> str:
|
||||
for arg in args:
|
||||
if not isinstance(arg, str):
|
||||
raise ValueError(
|
||||
f"Expected a string but found {arg} ({type(arg)}) instead"
|
||||
)
|
||||
|
||||
quote_f = partial(quote, safe="/:")
|
||||
args = map(quote_f, args)
|
||||
q_args = [quote(arg, safe="/:") for arg in args]
|
||||
|
||||
if kwargs.get("versioned_api", True):
|
||||
return f"{self.base_url}/v{self._version}{pathfmt.format(*args)}"
|
||||
return f"{self.base_url}{pathfmt.format(*args)}"
|
||||
if versioned_api:
|
||||
return f"{self.base_url}/v{self._version}{pathfmt.format(*q_args)}"
|
||||
return f"{self.base_url}{pathfmt.format(*q_args)}"
|
||||
|
||||
def _raise_for_status(self, response):
|
||||
def _raise_for_status(self, response) -> None:
|
||||
"""Raises stored :class:`APIError`, if one occurred."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except _HTTPError as e:
|
||||
create_api_error_from_http_exception(e)
|
||||
|
||||
def _result(self, response, get_json=False, get_binary=False):
|
||||
@t.overload
|
||||
def _result(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
get_json: t.Literal[False] = False,
|
||||
get_binary: t.Literal[False] = False,
|
||||
) -> str: ...
|
||||
|
||||
@t.overload
|
||||
def _result(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
get_json: t.Literal[True],
|
||||
get_binary: t.Literal[False] = False,
|
||||
) -> t.Any: ...
|
||||
|
||||
@t.overload
|
||||
def _result(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
get_json: t.Literal[False] = False,
|
||||
get_binary: t.Literal[True],
|
||||
) -> bytes: ...
|
||||
|
||||
@t.overload
|
||||
def _result(
|
||||
self, response, *, get_json: bool = False, get_binary: bool = False
|
||||
) -> t.Any | str | bytes: ...
|
||||
|
||||
def _result(
|
||||
self, response, *, get_json: bool = False, get_binary: bool = False
|
||||
) -> t.Any | str | bytes:
|
||||
if get_json and get_binary:
|
||||
raise AssertionError("json and binary must not be both True")
|
||||
self._raise_for_status(response)
|
||||
@ -300,10 +337,10 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
return response.content
|
||||
return response.text
|
||||
|
||||
def _post_json(self, url, data, **kwargs):
|
||||
def _post_json(self, url: str, data: dict[str, str | None] | t.Any, **kwargs):
|
||||
# Go <1.1 cannot unserialize null to a string
|
||||
# so we do this disgusting thing here.
|
||||
data2 = {}
|
||||
data2: dict[str, t.Any] = {}
|
||||
if data is not None and isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
if v is not None:
|
||||
@ -316,7 +353,7 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
kwargs["headers"]["Content-Type"] = "application/json"
|
||||
return self._post(url, data=json.dumps(data2), **kwargs)
|
||||
|
||||
def _attach_params(self, override=None):
|
||||
def _attach_params(self, override: dict[str, int] | None = None) -> dict[str, int]:
|
||||
return override or {"stdout": 1, "stderr": 1, "stream": 1}
|
||||
|
||||
def _get_raw_response_socket(self, response):
|
||||
@ -341,12 +378,24 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
|
||||
return sock
|
||||
|
||||
def _stream_helper(self, response, decode=False):
|
||||
@t.overload
|
||||
def _stream_helper(
|
||||
self, response, *, decode: t.Literal[False] = False
|
||||
) -> t.Generator[bytes]: ...
|
||||
|
||||
@t.overload
|
||||
def _stream_helper(
|
||||
self, response, *, decode: t.Literal[True]
|
||||
) -> t.Generator[t.Any]: ...
|
||||
|
||||
def _stream_helper(self, response, *, decode: bool = False) -> t.Generator[t.Any]:
|
||||
"""Generator for data coming from a chunked-encoded HTTP response."""
|
||||
|
||||
if response.raw._fp.chunked:
|
||||
if decode:
|
||||
yield from json_stream.json_stream(self._stream_helper(response, False))
|
||||
yield from json_stream.json_stream(
|
||||
self._stream_helper(response, decode=False)
|
||||
)
|
||||
else:
|
||||
reader = response.raw
|
||||
while not reader.closed:
|
||||
@ -362,7 +411,7 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
# encountered an error immediately
|
||||
yield self._result(response, get_json=decode)
|
||||
|
||||
def _multiplexed_buffer_helper(self, response):
|
||||
def _multiplexed_buffer_helper(self, response) -> t.Generator[bytes]:
|
||||
"""A generator of multiplexed data blocks read from a buffered
|
||||
response."""
|
||||
buf = self._result(response, get_binary=True)
|
||||
@ -378,7 +427,7 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
walker = end
|
||||
yield buf[start:end]
|
||||
|
||||
def _multiplexed_response_stream_helper(self, response):
|
||||
def _multiplexed_response_stream_helper(self, response) -> t.Generator[bytes]:
|
||||
"""A generator of multiplexed data blocks coming from a response
|
||||
stream."""
|
||||
|
||||
@ -399,7 +448,19 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
break
|
||||
yield data
|
||||
|
||||
def _stream_raw_result(self, response, chunk_size=1, decode=True):
|
||||
@t.overload
|
||||
def _stream_raw_result(
|
||||
self, response, *, chunk_size: int = 1, decode: t.Literal[True] = True
|
||||
) -> t.Generator[str]: ...
|
||||
|
||||
@t.overload
|
||||
def _stream_raw_result(
|
||||
self, response, *, chunk_size: int = 1, decode: t.Literal[False]
|
||||
) -> t.Generator[bytes]: ...
|
||||
|
||||
def _stream_raw_result(
|
||||
self, response, *, chunk_size: int = 1, decode: bool = True
|
||||
) -> t.Generator[str | bytes]:
|
||||
"""Stream result for TTY-enabled container and raw binary data"""
|
||||
self._raise_for_status(response)
|
||||
|
||||
@ -410,14 +471,81 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
|
||||
yield from response.iter_content(chunk_size, decode)
|
||||
|
||||
def _read_from_socket(self, response, stream, tty=True, demux=False):
|
||||
@t.overload
|
||||
def _read_from_socket(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
stream: t.Literal[True],
|
||||
tty: bool = True,
|
||||
demux: t.Literal[False] = False,
|
||||
) -> t.Generator[bytes]: ...
|
||||
|
||||
@t.overload
|
||||
def _read_from_socket(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
stream: t.Literal[True],
|
||||
tty: t.Literal[True] = True,
|
||||
demux: t.Literal[True],
|
||||
) -> t.Generator[tuple[bytes, None]]: ...
|
||||
|
||||
@t.overload
|
||||
def _read_from_socket(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
stream: t.Literal[True],
|
||||
tty: t.Literal[False],
|
||||
demux: t.Literal[True],
|
||||
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
|
||||
|
||||
@t.overload
|
||||
def _read_from_socket(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
stream: t.Literal[False],
|
||||
tty: bool = True,
|
||||
demux: t.Literal[False] = False,
|
||||
) -> bytes: ...
|
||||
|
||||
@t.overload
|
||||
def _read_from_socket(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
stream: t.Literal[False],
|
||||
tty: t.Literal[True] = True,
|
||||
demux: t.Literal[True],
|
||||
) -> tuple[bytes, None]: ...
|
||||
|
||||
@t.overload
|
||||
def _read_from_socket(
|
||||
self,
|
||||
response,
|
||||
*,
|
||||
stream: t.Literal[False],
|
||||
tty: t.Literal[False],
|
||||
demux: t.Literal[True],
|
||||
) -> tuple[bytes, bytes]: ...
|
||||
|
||||
@t.overload
|
||||
def _read_from_socket(
|
||||
self, response, *, stream: bool, tty: bool = True, demux: bool = False
|
||||
) -> t.Any: ...
|
||||
|
||||
def _read_from_socket(
|
||||
self, response, *, stream: bool, tty: bool = True, demux: bool = False
|
||||
) -> t.Any:
|
||||
"""Consume all data from the socket, close the response and return the
|
||||
data. If stream=True, then a generator is returned instead and the
|
||||
caller is responsible for closing the response.
|
||||
"""
|
||||
socket = self._get_raw_response_socket(response)
|
||||
|
||||
gen = frames_iter(socket, tty)
|
||||
gen: t.Generator = frames_iter(socket, tty)
|
||||
|
||||
if demux:
|
||||
# The generator will output tuples (stdout, stderr)
|
||||
@ -434,7 +562,7 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def _disable_socket_timeout(self, socket):
|
||||
def _disable_socket_timeout(self, socket) -> None:
|
||||
"""Depending on the combination of python version and whether we are
|
||||
connecting over http or https, we might need to access _sock, which
|
||||
may or may not exist; or we may need to just settimeout on socket
|
||||
@ -462,7 +590,27 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
|
||||
s.settimeout(None)
|
||||
|
||||
def _get_result_tty(self, stream, res, is_tty):
|
||||
@t.overload
|
||||
def _get_result_tty(
|
||||
self, stream: t.Literal[True], res, is_tty: t.Literal[True]
|
||||
) -> t.Generator[str]: ...
|
||||
|
||||
@t.overload
|
||||
def _get_result_tty(
|
||||
self, stream: t.Literal[True], res, is_tty: t.Literal[False]
|
||||
) -> t.Generator[bytes]: ...
|
||||
|
||||
@t.overload
|
||||
def _get_result_tty(
|
||||
self, stream: t.Literal[False], res, is_tty: t.Literal[True]
|
||||
) -> bytes: ...
|
||||
|
||||
@t.overload
|
||||
def _get_result_tty(
|
||||
self, stream: t.Literal[False], res, is_tty: t.Literal[False]
|
||||
) -> bytes: ...
|
||||
|
||||
def _get_result_tty(self, stream: bool, res, is_tty: bool) -> t.Any:
|
||||
# We should also use raw streaming (without keep-alive)
|
||||
# if we are dealing with a tty-enabled container.
|
||||
if is_tty:
|
||||
@ -478,11 +626,11 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
return self._multiplexed_response_stream_helper(res)
|
||||
return sep.join(list(self._multiplexed_buffer_helper(res)))
|
||||
|
||||
def _unmount(self, *args):
|
||||
def _unmount(self, *args) -> None:
|
||||
for proto in args:
|
||||
self.adapters.pop(proto)
|
||||
|
||||
def get_adapter(self, url):
|
||||
def get_adapter(self, url: str):
|
||||
try:
|
||||
return super().get_adapter(url)
|
||||
except _InvalidSchema as e:
|
||||
@ -491,10 +639,10 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
raise e
|
||||
|
||||
@property
|
||||
def api_version(self):
|
||||
def api_version(self) -> str:
|
||||
return self._version
|
||||
|
||||
def reload_config(self, dockercfg_path=None):
|
||||
def reload_config(self, dockercfg_path: str | None = None) -> None:
|
||||
"""
|
||||
Force a reload of the auth configuration
|
||||
|
||||
@ -510,7 +658,7 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
dockercfg_path, credstore_env=self.credstore_env
|
||||
)
|
||||
|
||||
def _set_auth_headers(self, headers):
|
||||
def _set_auth_headers(self, headers: dict[str, str | bytes]) -> None:
|
||||
log.debug("Looking for auth config")
|
||||
|
||||
# If we do not have any auth data so far, try reloading the config
|
||||
@ -537,57 +685,62 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
else:
|
||||
log.debug("No auth config found")
|
||||
|
||||
def get_binary(self, pathfmt, *args, **kwargs):
|
||||
def get_binary(self, pathfmt: str, *args: str, **kwargs) -> bytes:
|
||||
return self._result(
|
||||
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
|
||||
get_binary=True,
|
||||
)
|
||||
|
||||
def get_json(self, pathfmt, *args, **kwargs):
|
||||
def get_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
|
||||
return self._result(
|
||||
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs),
|
||||
get_json=True,
|
||||
)
|
||||
|
||||
def get_text(self, pathfmt, *args, **kwargs):
|
||||
def get_text(self, pathfmt: str, *args: str, **kwargs) -> str:
|
||||
return self._result(
|
||||
self._get(self._url(pathfmt, *args, versioned_api=True), **kwargs)
|
||||
)
|
||||
|
||||
def get_raw_stream(self, pathfmt, *args, **kwargs):
|
||||
chunk_size = kwargs.pop("chunk_size", DEFAULT_DATA_CHUNK_SIZE)
|
||||
def get_raw_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
chunk_size: int = DEFAULT_DATA_CHUNK_SIZE,
|
||||
**kwargs,
|
||||
) -> t.Generator[bytes]:
|
||||
res = self._get(
|
||||
self._url(pathfmt, *args, versioned_api=True), stream=True, **kwargs
|
||||
)
|
||||
self._raise_for_status(res)
|
||||
return self._stream_raw_result(res, chunk_size, False)
|
||||
return self._stream_raw_result(res, chunk_size=chunk_size, decode=False)
|
||||
|
||||
def delete_call(self, pathfmt, *args, **kwargs):
|
||||
def delete_call(self, pathfmt: str, *args: str, **kwargs) -> None:
|
||||
self._raise_for_status(
|
||||
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs)
|
||||
)
|
||||
|
||||
def delete_json(self, pathfmt, *args, **kwargs):
|
||||
def delete_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
|
||||
return self._result(
|
||||
self._delete(self._url(pathfmt, *args, versioned_api=True), **kwargs),
|
||||
get_json=True,
|
||||
)
|
||||
|
||||
def post_call(self, pathfmt, *args, **kwargs):
|
||||
def post_call(self, pathfmt: str, *args: str, **kwargs) -> None:
|
||||
self._raise_for_status(
|
||||
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs)
|
||||
)
|
||||
|
||||
def post_json(self, pathfmt, *args, **kwargs):
|
||||
data = kwargs.pop("data", None)
|
||||
def post_json(self, pathfmt: str, *args: str, data: t.Any = None, **kwargs) -> None:
|
||||
self._raise_for_status(
|
||||
self._post_json(
|
||||
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def post_json_to_binary(self, pathfmt, *args, **kwargs):
|
||||
data = kwargs.pop("data", None)
|
||||
def post_json_to_binary(
|
||||
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
|
||||
) -> bytes:
|
||||
return self._result(
|
||||
self._post_json(
|
||||
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
|
||||
@ -595,8 +748,9 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
get_binary=True,
|
||||
)
|
||||
|
||||
def post_json_to_json(self, pathfmt, *args, **kwargs):
|
||||
data = kwargs.pop("data", None)
|
||||
def post_json_to_json(
|
||||
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
|
||||
) -> t.Any:
|
||||
return self._result(
|
||||
self._post_json(
|
||||
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
|
||||
@ -604,17 +758,24 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
get_json=True,
|
||||
)
|
||||
|
||||
def post_json_to_text(self, pathfmt, *args, **kwargs):
|
||||
data = kwargs.pop("data", None)
|
||||
def post_json_to_text(
|
||||
self, pathfmt: str, *args: str, data: t.Any = None, **kwargs
|
||||
) -> str:
|
||||
return self._result(
|
||||
self._post_json(
|
||||
self._url(pathfmt, *args, versioned_api=True), data, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
def post_json_to_stream_socket(self, pathfmt, *args, **kwargs):
|
||||
data = kwargs.pop("data", None)
|
||||
headers = (kwargs.pop("headers", None) or {}).copy()
|
||||
def post_json_to_stream_socket(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
headers = headers.copy() if headers else {}
|
||||
headers.update(
|
||||
{
|
||||
"Connection": "Upgrade",
|
||||
@ -631,18 +792,102 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
)
|
||||
)
|
||||
|
||||
def post_json_to_stream(self, pathfmt, *args, **kwargs):
|
||||
data = kwargs.pop("data", None)
|
||||
headers = (kwargs.pop("headers", None) or {}).copy()
|
||||
@t.overload
|
||||
def post_json_to_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
stream: t.Literal[True],
|
||||
tty: bool = True,
|
||||
demux: t.Literal[False] = False,
|
||||
**kwargs,
|
||||
) -> t.Generator[bytes]: ...
|
||||
|
||||
@t.overload
|
||||
def post_json_to_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
stream: t.Literal[True],
|
||||
tty: t.Literal[True] = True,
|
||||
demux: t.Literal[True],
|
||||
**kwargs,
|
||||
) -> t.Generator[tuple[bytes, None]]: ...
|
||||
|
||||
@t.overload
|
||||
def post_json_to_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
stream: t.Literal[True],
|
||||
tty: t.Literal[False],
|
||||
demux: t.Literal[True],
|
||||
**kwargs,
|
||||
) -> t.Generator[tuple[bytes, None] | tuple[None, bytes]]: ...
|
||||
|
||||
@t.overload
|
||||
def post_json_to_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
stream: t.Literal[False],
|
||||
tty: bool = True,
|
||||
demux: t.Literal[False] = False,
|
||||
**kwargs,
|
||||
) -> bytes: ...
|
||||
|
||||
@t.overload
|
||||
def post_json_to_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
stream: t.Literal[False],
|
||||
tty: t.Literal[True] = True,
|
||||
demux: t.Literal[True],
|
||||
**kwargs,
|
||||
) -> tuple[bytes, None]: ...
|
||||
|
||||
@t.overload
|
||||
def post_json_to_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
stream: t.Literal[False],
|
||||
tty: t.Literal[False],
|
||||
demux: t.Literal[True],
|
||||
**kwargs,
|
||||
) -> tuple[bytes, bytes]: ...
|
||||
|
||||
def post_json_to_stream(
|
||||
self,
|
||||
pathfmt: str,
|
||||
*args: str,
|
||||
data: t.Any = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
stream: bool = False,
|
||||
demux: bool = False,
|
||||
tty: bool = False,
|
||||
**kwargs,
|
||||
) -> t.Any:
|
||||
headers = headers.copy() if headers else {}
|
||||
headers.update(
|
||||
{
|
||||
"Connection": "Upgrade",
|
||||
"Upgrade": "tcp",
|
||||
}
|
||||
)
|
||||
stream = kwargs.pop("stream", False)
|
||||
demux = kwargs.pop("demux", False)
|
||||
tty = kwargs.pop("tty", False)
|
||||
return self._read_from_socket(
|
||||
self._post_json(
|
||||
self._url(pathfmt, *args, versioned_api=True),
|
||||
@ -651,13 +896,133 @@ class APIClient(_Session, DaemonApiMixin):
|
||||
stream=True,
|
||||
**kwargs,
|
||||
),
|
||||
stream,
|
||||
stream=stream,
|
||||
tty=tty,
|
||||
demux=demux,
|
||||
)
|
||||
|
||||
def post_to_json(self, pathfmt, *args, **kwargs):
|
||||
def post_to_json(self, pathfmt: str, *args: str, **kwargs) -> t.Any:
|
||||
return self._result(
|
||||
self._post(self._url(pathfmt, *args, versioned_api=True), **kwargs),
|
||||
get_json=True,
|
||||
)
|
||||
|
||||
@minimum_version("1.25")
|
||||
def df(self) -> dict[str, t.Any]:
|
||||
"""
|
||||
Get data usage information.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary representing different resource categories
|
||||
and their respective data usage.
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
url = self._url("/system/df")
|
||||
return self._result(self._get(url), get_json=True)
|
||||
|
||||
def info(self) -> dict[str, t.Any]:
|
||||
"""
|
||||
Display system-wide information. Identical to the ``docker info``
|
||||
command.
|
||||
|
||||
Returns:
|
||||
(dict): The info as a dict
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
return self._result(self._get(self._url("/info")), get_json=True)
|
||||
|
||||
def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str | None = None,
|
||||
email: str | None = None,
|
||||
registry: str | None = None,
|
||||
reauth: bool = False,
|
||||
dockercfg_path: str | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Authenticate with a registry. Similar to the ``docker login`` command.
|
||||
|
||||
Args:
|
||||
username (str): The registry username
|
||||
password (str): The plaintext password
|
||||
email (str): The email for the registry account
|
||||
registry (str): URL to the registry. E.g.
|
||||
``https://index.docker.io/v1/``
|
||||
reauth (bool): Whether or not to refresh existing authentication on
|
||||
the Docker server.
|
||||
dockercfg_path (str): Use a custom path for the Docker config file
|
||||
(default ``$HOME/.docker/config.json`` if present,
|
||||
otherwise ``$HOME/.dockercfg``)
|
||||
|
||||
Returns:
|
||||
(dict): The response from the login request
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
|
||||
# If we do not have any auth data so far, try reloading the config file
|
||||
# one more time in case anything showed up in there.
|
||||
# If dockercfg_path is passed check to see if the config file exists,
|
||||
# if so load that config.
|
||||
if dockercfg_path and os.path.exists(dockercfg_path):
|
||||
self._auth_configs = auth.load_config(
|
||||
dockercfg_path, credstore_env=self.credstore_env
|
||||
)
|
||||
elif not self._auth_configs or self._auth_configs.is_empty:
|
||||
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
|
||||
|
||||
authcfg = self._auth_configs.resolve_authconfig(registry)
|
||||
# If we found an existing auth config for this registry and username
|
||||
# combination, we can return it immediately unless reauth is requested.
|
||||
if authcfg and authcfg.get("username", None) == username and not reauth:
|
||||
return authcfg
|
||||
|
||||
req_data = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
"email": email,
|
||||
"serveraddress": registry,
|
||||
}
|
||||
|
||||
response = self._post_json(self._url("/auth"), data=req_data)
|
||||
if response.status_code == 200:
|
||||
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
|
||||
return self._result(response, get_json=True)
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""
|
||||
Checks the server is responsive. An exception will be raised if it
|
||||
is not responding.
|
||||
|
||||
Returns:
|
||||
(bool) The response from the server.
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
return self._result(self._get(self._url("/_ping"))) == "OK"
|
||||
|
||||
def version(self, api_version: bool = True) -> dict[str, t.Any]:
|
||||
"""
|
||||
Returns version information from the server. Similar to the ``docker
|
||||
version`` command.
|
||||
|
||||
Returns:
|
||||
(dict): The server version information
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
url = self._url("/version", versioned_api=api_version)
|
||||
return self._result(self._get(url), get_json=True)
|
||||
|
||||
@ -1,139 +0,0 @@
|
||||
# This code is part of the Ansible collection community.docker, but is an independent component.
|
||||
# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
|
||||
#
|
||||
# Copyright (c) 2016-2022 Docker, Inc.
|
||||
#
|
||||
# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
|
||||
# Do not use this from other collections or standalone plugins/modules!
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from .. import auth
|
||||
from ..utils.decorators import minimum_version
|
||||
|
||||
|
||||
class DaemonApiMixin:
|
||||
@minimum_version("1.25")
|
||||
def df(self):
|
||||
"""
|
||||
Get data usage information.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary representing different resource categories
|
||||
and their respective data usage.
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
url = self._url("/system/df")
|
||||
return self._result(self._get(url), get_json=True)
|
||||
|
||||
def info(self):
|
||||
"""
|
||||
Display system-wide information. Identical to the ``docker info``
|
||||
command.
|
||||
|
||||
Returns:
|
||||
(dict): The info as a dict
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
return self._result(self._get(self._url("/info")), get_json=True)
|
||||
|
||||
def login(
|
||||
self,
|
||||
username,
|
||||
password=None,
|
||||
email=None,
|
||||
registry=None,
|
||||
reauth=False,
|
||||
dockercfg_path=None,
|
||||
):
|
||||
"""
|
||||
Authenticate with a registry. Similar to the ``docker login`` command.
|
||||
|
||||
Args:
|
||||
username (str): The registry username
|
||||
password (str): The plaintext password
|
||||
email (str): The email for the registry account
|
||||
registry (str): URL to the registry. E.g.
|
||||
``https://index.docker.io/v1/``
|
||||
reauth (bool): Whether or not to refresh existing authentication on
|
||||
the Docker server.
|
||||
dockercfg_path (str): Use a custom path for the Docker config file
|
||||
(default ``$HOME/.docker/config.json`` if present,
|
||||
otherwise ``$HOME/.dockercfg``)
|
||||
|
||||
Returns:
|
||||
(dict): The response from the login request
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
|
||||
# If we do not have any auth data so far, try reloading the config file
|
||||
# one more time in case anything showed up in there.
|
||||
# If dockercfg_path is passed check to see if the config file exists,
|
||||
# if so load that config.
|
||||
if dockercfg_path and os.path.exists(dockercfg_path):
|
||||
self._auth_configs = auth.load_config(
|
||||
dockercfg_path, credstore_env=self.credstore_env
|
||||
)
|
||||
elif not self._auth_configs or self._auth_configs.is_empty:
|
||||
self._auth_configs = auth.load_config(credstore_env=self.credstore_env)
|
||||
|
||||
authcfg = self._auth_configs.resolve_authconfig(registry)
|
||||
# If we found an existing auth config for this registry and username
|
||||
# combination, we can return it immediately unless reauth is requested.
|
||||
if authcfg and authcfg.get("username", None) == username and not reauth:
|
||||
return authcfg
|
||||
|
||||
req_data = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
"email": email,
|
||||
"serveraddress": registry,
|
||||
}
|
||||
|
||||
response = self._post_json(self._url("/auth"), data=req_data)
|
||||
if response.status_code == 200:
|
||||
self._auth_configs.add_auth(registry or auth.INDEX_NAME, req_data)
|
||||
return self._result(response, get_json=True)
|
||||
|
||||
def ping(self):
|
||||
"""
|
||||
Checks the server is responsive. An exception will be raised if it
|
||||
is not responding.
|
||||
|
||||
Returns:
|
||||
(bool) The response from the server.
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
return self._result(self._get(self._url("/_ping"))) == "OK"
|
||||
|
||||
def version(self, api_version=True):
|
||||
"""
|
||||
Returns version information from the server. Similar to the ``docker
|
||||
version`` command.
|
||||
|
||||
Returns:
|
||||
(dict): The server version information
|
||||
|
||||
Raises:
|
||||
:py:class:`docker.errors.APIError`
|
||||
If the server returns an error.
|
||||
"""
|
||||
url = self._url("/version", versioned_api=api_version)
|
||||
return self._result(self._get(url), get_json=True)
|
||||
@ -14,6 +14,7 @@ from __future__ import annotations
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from . import errors
|
||||
from .credentials.errors import CredentialsNotFound, StoreError
|
||||
@ -21,6 +22,12 @@ from .credentials.store import Store
|
||||
from .utils import config
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
|
||||
APIClient,
|
||||
)
|
||||
|
||||
|
||||
INDEX_NAME = "docker.io"
|
||||
INDEX_URL = f"https://index.{INDEX_NAME}/v1/"
|
||||
TOKEN_USERNAME = "<token>"
|
||||
@ -28,7 +35,7 @@ TOKEN_USERNAME = "<token>"
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resolve_repository_name(repo_name):
|
||||
def resolve_repository_name(repo_name: str) -> tuple[str, str]:
|
||||
if "://" in repo_name:
|
||||
raise errors.InvalidRepository(
|
||||
f"Repository name cannot contain a scheme ({repo_name})"
|
||||
@ -42,14 +49,14 @@ def resolve_repository_name(repo_name):
|
||||
return resolve_index_name(index_name), remote_name
|
||||
|
||||
|
||||
def resolve_index_name(index_name):
|
||||
def resolve_index_name(index_name: str) -> str:
|
||||
index_name = convert_to_hostname(index_name)
|
||||
if index_name == "index." + INDEX_NAME:
|
||||
index_name = INDEX_NAME
|
||||
return index_name
|
||||
|
||||
|
||||
def get_config_header(client, registry):
|
||||
def get_config_header(client: APIClient, registry: str) -> bytes | None:
|
||||
log.debug("Looking for auth config")
|
||||
if not client._auth_configs or client._auth_configs.is_empty:
|
||||
log.debug("No auth config in memory - loading from filesystem")
|
||||
@ -69,32 +76,38 @@ def get_config_header(client, registry):
|
||||
return None
|
||||
|
||||
|
||||
def split_repo_name(repo_name):
|
||||
def split_repo_name(repo_name: str) -> tuple[str, str]:
|
||||
parts = repo_name.split("/", 1)
|
||||
if len(parts) == 1 or (
|
||||
"." not in parts[0] and ":" not in parts[0] and parts[0] != "localhost"
|
||||
):
|
||||
# This is a docker index repo (ex: username/foobar or ubuntu)
|
||||
return INDEX_NAME, repo_name
|
||||
return tuple(parts)
|
||||
return tuple(parts) # type: ignore
|
||||
|
||||
|
||||
def get_credential_store(authconfig, registry):
|
||||
def get_credential_store(
|
||||
authconfig: dict[str, t.Any] | AuthConfig, registry: str
|
||||
) -> str | None:
|
||||
if not isinstance(authconfig, AuthConfig):
|
||||
authconfig = AuthConfig(authconfig)
|
||||
return authconfig.get_credential_store(registry)
|
||||
|
||||
|
||||
class AuthConfig(dict):
|
||||
def __init__(self, dct, credstore_env=None):
|
||||
def __init__(
|
||||
self, dct: dict[str, t.Any], credstore_env: dict[str, str] | None = None
|
||||
):
|
||||
if "auths" not in dct:
|
||||
dct["auths"] = {}
|
||||
self.update(dct)
|
||||
self._credstore_env = credstore_env
|
||||
self._stores = {}
|
||||
self._stores: dict[str, Store] = {}
|
||||
|
||||
@classmethod
|
||||
def parse_auth(cls, entries, raise_on_error=False):
|
||||
def parse_auth(
|
||||
cls, entries: dict[str, dict[str, t.Any]], raise_on_error=False
|
||||
) -> dict[str, dict[str, t.Any]]:
|
||||
"""
|
||||
Parses authentication entries
|
||||
|
||||
@ -107,10 +120,10 @@ class AuthConfig(dict):
|
||||
Authentication registry.
|
||||
"""
|
||||
|
||||
conf = {}
|
||||
conf: dict[str, dict[str, t.Any]] = {}
|
||||
for registry, entry in entries.items():
|
||||
if not isinstance(entry, dict):
|
||||
log.debug("Config entry for key %s is not auth config", registry)
|
||||
log.debug("Config entry for key %s is not auth config", registry) # type: ignore
|
||||
# We sometimes fall back to parsing the whole config as if it
|
||||
# was the auth config by itself, for legacy purposes. In that
|
||||
# case, we fail silently and return an empty conf if any of the
|
||||
@ -150,7 +163,12 @@ class AuthConfig(dict):
|
||||
return conf
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, config_path, config_dict, credstore_env=None):
|
||||
def load_config(
|
||||
cls,
|
||||
config_path: str | None,
|
||||
config_dict: dict[str, t.Any] | None,
|
||||
credstore_env: dict[str, str] | None = None,
|
||||
) -> t.Self:
|
||||
"""
|
||||
Loads authentication data from a Docker configuration file in the given
|
||||
root directory or if config_path is passed use given path.
|
||||
@ -196,22 +214,24 @@ class AuthConfig(dict):
|
||||
return cls({"auths": cls.parse_auth(config_dict)}, credstore_env)
|
||||
|
||||
@property
|
||||
def auths(self):
|
||||
def auths(self) -> dict[str, dict[str, t.Any]]:
|
||||
return self.get("auths", {})
|
||||
|
||||
@property
|
||||
def creds_store(self):
|
||||
def creds_store(self) -> str | None:
|
||||
return self.get("credsStore", None)
|
||||
|
||||
@property
|
||||
def cred_helpers(self):
|
||||
def cred_helpers(self) -> dict[str, t.Any]:
|
||||
return self.get("credHelpers", {})
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
def is_empty(self) -> bool:
|
||||
return not self.auths and not self.creds_store and not self.cred_helpers
|
||||
|
||||
def resolve_authconfig(self, registry=None):
|
||||
def resolve_authconfig(
|
||||
self, registry: str | None = None
|
||||
) -> dict[str, t.Any] | None:
|
||||
"""
|
||||
Returns the authentication data from the given auth configuration for a
|
||||
specific registry. As with the Docker client, legacy entries in the
|
||||
@ -244,7 +264,9 @@ class AuthConfig(dict):
|
||||
log.debug("No entry found")
|
||||
return None
|
||||
|
||||
def _resolve_authconfig_credstore(self, registry, credstore_name):
|
||||
def _resolve_authconfig_credstore(
|
||||
self, registry: str | None, credstore_name: str
|
||||
) -> dict[str, t.Any] | None:
|
||||
if not registry or registry == INDEX_NAME:
|
||||
# The ecosystem is a little schizophrenic with index.docker.io VS
|
||||
# docker.io - in that case, it seems the full URL is necessary.
|
||||
@ -272,19 +294,19 @@ class AuthConfig(dict):
|
||||
except StoreError as e:
|
||||
raise errors.DockerException(f"Credentials store error: {e}")
|
||||
|
||||
def _get_store_instance(self, name):
|
||||
def _get_store_instance(self, name: str):
|
||||
if name not in self._stores:
|
||||
self._stores[name] = Store(name, environment=self._credstore_env)
|
||||
return self._stores[name]
|
||||
|
||||
def get_credential_store(self, registry):
|
||||
def get_credential_store(self, registry: str | None) -> str | None:
|
||||
if not registry or registry == INDEX_NAME:
|
||||
registry = INDEX_URL
|
||||
|
||||
return self.cred_helpers.get(registry) or self.creds_store
|
||||
|
||||
def get_all_credentials(self):
|
||||
auth_data = self.auths.copy()
|
||||
def get_all_credentials(self) -> dict[str, dict[str, t.Any] | None]:
|
||||
auth_data: dict[str, dict[str, t.Any] | None] = self.auths.copy() # type: ignore
|
||||
if self.creds_store:
|
||||
# Retrieve all credentials from the default store
|
||||
store = self._get_store_instance(self.creds_store)
|
||||
@ -299,21 +321,23 @@ class AuthConfig(dict):
|
||||
|
||||
return auth_data
|
||||
|
||||
def add_auth(self, reg, data):
|
||||
def add_auth(self, reg: str, data: dict[str, t.Any]) -> None:
|
||||
self["auths"][reg] = data
|
||||
|
||||
|
||||
def resolve_authconfig(authconfig, registry=None, credstore_env=None):
|
||||
def resolve_authconfig(
|
||||
authconfig, registry: str | None = None, credstore_env: dict[str, str] | None = None
|
||||
):
|
||||
if not isinstance(authconfig, AuthConfig):
|
||||
authconfig = AuthConfig(authconfig, credstore_env)
|
||||
return authconfig.resolve_authconfig(registry)
|
||||
|
||||
|
||||
def convert_to_hostname(url):
|
||||
def convert_to_hostname(url: str) -> str:
|
||||
return url.replace("http://", "").replace("https://", "").split("/", 1)[0]
|
||||
|
||||
|
||||
def decode_auth(auth):
|
||||
def decode_auth(auth: str | bytes) -> tuple[str, str]:
|
||||
if isinstance(auth, str):
|
||||
auth = auth.encode("ascii")
|
||||
s = base64.b64decode(auth)
|
||||
@ -321,12 +345,14 @@ def decode_auth(auth):
|
||||
return login.decode("utf8"), pwd.decode("utf8")
|
||||
|
||||
|
||||
def encode_header(auth):
|
||||
def encode_header(auth: dict[str, t.Any]) -> bytes:
|
||||
auth_json = json.dumps(auth).encode("ascii")
|
||||
return base64.urlsafe_b64encode(auth_json)
|
||||
|
||||
|
||||
def parse_auth(entries, raise_on_error=False):
|
||||
def parse_auth(
|
||||
entries: dict[str, dict[str, t.Any]], raise_on_error: bool = False
|
||||
) -> dict[str, dict[str, t.Any]]:
|
||||
"""
|
||||
Parses authentication entries
|
||||
|
||||
@ -342,11 +368,15 @@ def parse_auth(entries, raise_on_error=False):
|
||||
return AuthConfig.parse_auth(entries, raise_on_error)
|
||||
|
||||
|
||||
def load_config(config_path=None, config_dict=None, credstore_env=None):
|
||||
def load_config(
|
||||
config_path: str | None = None,
|
||||
config_dict: dict[str, t.Any] | None = None,
|
||||
credstore_env: dict[str, str] | None = None,
|
||||
) -> AuthConfig:
|
||||
return AuthConfig.load_config(config_path, config_dict, credstore_env)
|
||||
|
||||
|
||||
def _load_legacy_config(config_file):
|
||||
def _load_legacy_config(config_file: str) -> dict[str, dict[str, t.Any]]:
|
||||
log.debug("Attempting to parse legacy auth file format")
|
||||
try:
|
||||
data = []
|
||||
|
||||
@ -13,6 +13,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from .. import errors
|
||||
from .config import (
|
||||
@ -24,7 +25,11 @@ from .config import (
|
||||
from .context import Context
|
||||
|
||||
|
||||
def create_default_context():
|
||||
if t.TYPE_CHECKING:
|
||||
from ..tls import TLSConfig
|
||||
|
||||
|
||||
def create_default_context() -> Context:
|
||||
host = None
|
||||
if os.environ.get("DOCKER_HOST"):
|
||||
host = os.environ.get("DOCKER_HOST")
|
||||
@ -42,7 +47,7 @@ class ContextAPI:
|
||||
DEFAULT_CONTEXT = None
|
||||
|
||||
@classmethod
|
||||
def get_default_context(cls):
|
||||
def get_default_context(cls) -> Context:
|
||||
context = cls.DEFAULT_CONTEXT
|
||||
if context is None:
|
||||
context = create_default_context()
|
||||
@ -52,13 +57,13 @@ class ContextAPI:
|
||||
@classmethod
|
||||
def create_context(
|
||||
cls,
|
||||
name,
|
||||
orchestrator=None,
|
||||
host=None,
|
||||
tls_cfg=None,
|
||||
default_namespace=None,
|
||||
skip_tls_verify=False,
|
||||
):
|
||||
name: str,
|
||||
orchestrator: str | None = None,
|
||||
host: str | None = None,
|
||||
tls_cfg: TLSConfig | None = None,
|
||||
default_namespace: str | None = None,
|
||||
skip_tls_verify: bool = False,
|
||||
) -> Context:
|
||||
"""Creates a new context.
|
||||
Returns:
|
||||
(Context): a Context object.
|
||||
@ -108,7 +113,7 @@ class ContextAPI:
|
||||
return ctx
|
||||
|
||||
@classmethod
|
||||
def get_context(cls, name=None):
|
||||
def get_context(cls, name: str | None = None) -> Context | None:
|
||||
"""Retrieves a context object.
|
||||
Args:
|
||||
name (str): The name of the context
|
||||
@ -136,7 +141,7 @@ class ContextAPI:
|
||||
return Context.load_context(name)
|
||||
|
||||
@classmethod
|
||||
def contexts(cls):
|
||||
def contexts(cls) -> list[Context]:
|
||||
"""Context list.
|
||||
Returns:
|
||||
(Context): List of context objects.
|
||||
@ -170,7 +175,7 @@ class ContextAPI:
|
||||
return contexts
|
||||
|
||||
@classmethod
|
||||
def get_current_context(cls):
|
||||
def get_current_context(cls) -> Context | None:
|
||||
"""Get current context.
|
||||
Returns:
|
||||
(Context): current context object.
|
||||
@ -178,7 +183,7 @@ class ContextAPI:
|
||||
return cls.get_context()
|
||||
|
||||
@classmethod
|
||||
def set_current_context(cls, name="default"):
|
||||
def set_current_context(cls, name: str = "default") -> None:
|
||||
ctx = cls.get_context(name)
|
||||
if not ctx:
|
||||
raise errors.ContextNotFound(name)
|
||||
@ -188,7 +193,7 @@ class ContextAPI:
|
||||
raise errors.ContextException(f"Failed to set current context: {err}")
|
||||
|
||||
@classmethod
|
||||
def remove_context(cls, name):
|
||||
def remove_context(cls, name: str) -> None:
|
||||
"""Remove a context. Similar to the ``docker context rm`` command.
|
||||
|
||||
Args:
|
||||
@ -220,7 +225,7 @@ class ContextAPI:
|
||||
ctx.remove()
|
||||
|
||||
@classmethod
|
||||
def inspect_context(cls, name="default"):
|
||||
def inspect_context(cls, name: str = "default") -> dict[str, t.Any]:
|
||||
"""Inspect a context. Similar to the ``docker context inspect`` command.
|
||||
|
||||
Args:
|
||||
|
||||
@ -23,7 +23,7 @@ from ..utils.utils import parse_host
|
||||
METAFILE = "meta.json"
|
||||
|
||||
|
||||
def get_current_context_name_with_source():
|
||||
def get_current_context_name_with_source() -> tuple[str, str]:
|
||||
if os.environ.get("DOCKER_HOST"):
|
||||
return "default", "DOCKER_HOST environment variable set"
|
||||
if os.environ.get("DOCKER_CONTEXT"):
|
||||
@ -41,11 +41,11 @@ def get_current_context_name_with_source():
|
||||
return "default", "fallback value"
|
||||
|
||||
|
||||
def get_current_context_name():
|
||||
def get_current_context_name() -> str:
|
||||
return get_current_context_name_with_source()[0]
|
||||
|
||||
|
||||
def write_context_name_to_docker_config(name=None):
|
||||
def write_context_name_to_docker_config(name: str | None = None) -> Exception | None:
|
||||
if name == "default":
|
||||
name = None
|
||||
docker_cfg_path = find_config_file()
|
||||
@ -62,44 +62,45 @@ def write_context_name_to_docker_config(name=None):
|
||||
elif name:
|
||||
config["currentContext"] = name
|
||||
else:
|
||||
return
|
||||
return None
|
||||
if not docker_cfg_path:
|
||||
docker_cfg_path = get_default_config_file()
|
||||
try:
|
||||
with open(docker_cfg_path, "wt", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=4)
|
||||
return None
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
return e
|
||||
|
||||
|
||||
def get_context_id(name):
|
||||
def get_context_id(name: str) -> str:
|
||||
return hashlib.sha256(name.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def get_context_dir():
|
||||
def get_context_dir() -> str:
|
||||
docker_cfg_path = find_config_file() or get_default_config_file()
|
||||
return os.path.join(os.path.dirname(docker_cfg_path), "contexts")
|
||||
|
||||
|
||||
def get_meta_dir(name=None):
|
||||
def get_meta_dir(name: str | None = None) -> str:
|
||||
meta_dir = os.path.join(get_context_dir(), "meta")
|
||||
if name:
|
||||
return os.path.join(meta_dir, get_context_id(name))
|
||||
return meta_dir
|
||||
|
||||
|
||||
def get_meta_file(name):
|
||||
def get_meta_file(name) -> str:
|
||||
return os.path.join(get_meta_dir(name), METAFILE)
|
||||
|
||||
|
||||
def get_tls_dir(name=None, endpoint=""):
|
||||
def get_tls_dir(name: str | None = None, endpoint: str = "") -> str:
|
||||
context_dir = get_context_dir()
|
||||
if name:
|
||||
return os.path.join(context_dir, "tls", get_context_id(name), endpoint)
|
||||
return os.path.join(context_dir, "tls")
|
||||
|
||||
|
||||
def get_context_host(path=None, tls=False):
|
||||
def get_context_host(path: str | None = None, tls: bool = False) -> str:
|
||||
host = parse_host(path, IS_WINDOWS_PLATFORM, tls)
|
||||
if host == DEFAULT_UNIX_SOCKET:
|
||||
# remove http+ from default docker socket url
|
||||
|
||||
@ -13,6 +13,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import typing as t
|
||||
from shutil import copyfile, rmtree
|
||||
|
||||
from ..errors import ContextException
|
||||
@ -33,21 +34,21 @@ class Context:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
orchestrator=None,
|
||||
host=None,
|
||||
endpoints=None,
|
||||
skip_tls_verify=False,
|
||||
tls=False,
|
||||
description=None,
|
||||
):
|
||||
name: str,
|
||||
orchestrator: str | None = None,
|
||||
host: str | None = None,
|
||||
endpoints: dict[str, dict[str, t.Any]] | None = None,
|
||||
skip_tls_verify: bool = False,
|
||||
tls: bool = False,
|
||||
description: str | None = None,
|
||||
) -> None:
|
||||
if not name:
|
||||
raise ValueError("Name not provided")
|
||||
self.name = name
|
||||
self.context_type = None
|
||||
self.orchestrator = orchestrator
|
||||
self.endpoints = {}
|
||||
self.tls_cfg = {}
|
||||
self.tls_cfg: dict[str, TLSConfig] = {}
|
||||
self.meta_path = IN_MEMORY
|
||||
self.tls_path = IN_MEMORY
|
||||
self.description = description
|
||||
@ -89,12 +90,12 @@ class Context:
|
||||
|
||||
def set_endpoint(
|
||||
self,
|
||||
name="docker",
|
||||
host=None,
|
||||
tls_cfg=None,
|
||||
skip_tls_verify=False,
|
||||
def_namespace=None,
|
||||
):
|
||||
name: str = "docker",
|
||||
host: str | None = None,
|
||||
tls_cfg: TLSConfig | None = None,
|
||||
skip_tls_verify: bool = False,
|
||||
def_namespace: str | None = None,
|
||||
) -> None:
|
||||
self.endpoints[name] = {
|
||||
"Host": get_context_host(host, not skip_tls_verify or tls_cfg is not None),
|
||||
"SkipTLSVerify": skip_tls_verify,
|
||||
@ -105,11 +106,11 @@ class Context:
|
||||
if tls_cfg:
|
||||
self.tls_cfg[name] = tls_cfg
|
||||
|
||||
def inspect(self):
|
||||
def inspect(self) -> dict[str, t.Any]:
|
||||
return self()
|
||||
|
||||
@classmethod
|
||||
def load_context(cls, name):
|
||||
def load_context(cls, name: str) -> t.Self | None:
|
||||
meta = Context._load_meta(name)
|
||||
if meta:
|
||||
instance = cls(
|
||||
@ -125,12 +126,12 @@ class Context:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _load_meta(cls, name):
|
||||
def _load_meta(cls, name: str) -> dict[str, t.Any] | None:
|
||||
meta_file = get_meta_file(name)
|
||||
if not os.path.isfile(meta_file):
|
||||
return None
|
||||
|
||||
metadata = {}
|
||||
metadata: dict[str, t.Any] = {}
|
||||
try:
|
||||
with open(meta_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
@ -154,7 +155,7 @@ class Context:
|
||||
|
||||
return metadata
|
||||
|
||||
def _load_certs(self):
|
||||
def _load_certs(self) -> None:
|
||||
certs = {}
|
||||
tls_dir = get_tls_dir(self.name)
|
||||
for endpoint in self.endpoints:
|
||||
@ -184,7 +185,7 @@ class Context:
|
||||
self.tls_cfg = certs
|
||||
self.tls_path = tls_dir
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
meta_dir = get_meta_dir(self.name)
|
||||
if not os.path.isdir(meta_dir):
|
||||
os.makedirs(meta_dir)
|
||||
@ -216,54 +217,54 @@ class Context:
|
||||
self.meta_path = get_meta_dir(self.name)
|
||||
self.tls_path = get_tls_dir(self.name)
|
||||
|
||||
def remove(self):
|
||||
def remove(self) -> None:
|
||||
if os.path.isdir(self.meta_path):
|
||||
rmtree(self.meta_path)
|
||||
if os.path.isdir(self.tls_path):
|
||||
rmtree(self.tls_path)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}: '{self.name}'>"
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.__call__(), indent=2)
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> dict[str, t.Any]:
|
||||
result = self.Metadata
|
||||
result.update(self.TLSMaterial)
|
||||
result.update(self.Storage)
|
||||
return result
|
||||
|
||||
def is_docker_host(self):
|
||||
def is_docker_host(self) -> bool:
|
||||
return self.context_type is None
|
||||
|
||||
@property
|
||||
def Name(self): # pylint: disable=invalid-name
|
||||
def Name(self) -> str: # pylint: disable=invalid-name
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def Host(self): # pylint: disable=invalid-name
|
||||
def Host(self) -> str | None: # pylint: disable=invalid-name
|
||||
if not self.orchestrator or self.orchestrator == "swarm":
|
||||
endpoint = self.endpoints.get("docker", None)
|
||||
if endpoint:
|
||||
return endpoint.get("Host", None)
|
||||
return endpoint.get("Host", None) # type: ignore
|
||||
return None
|
||||
|
||||
return self.endpoints[self.orchestrator].get("Host", None)
|
||||
return self.endpoints[self.orchestrator].get("Host", None) # type: ignore
|
||||
|
||||
@property
|
||||
def Orchestrator(self): # pylint: disable=invalid-name
|
||||
def Orchestrator(self) -> str | None: # pylint: disable=invalid-name
|
||||
return self.orchestrator
|
||||
|
||||
@property
|
||||
def Metadata(self): # pylint: disable=invalid-name
|
||||
meta = {}
|
||||
def Metadata(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
|
||||
meta: dict[str, t.Any] = {}
|
||||
if self.orchestrator:
|
||||
meta = {"StackOrchestrator": self.orchestrator}
|
||||
return {"Name": self.name, "Metadata": meta, "Endpoints": self.endpoints}
|
||||
|
||||
@property
|
||||
def TLSConfig(self): # pylint: disable=invalid-name
|
||||
def TLSConfig(self) -> TLSConfig | None: # pylint: disable=invalid-name
|
||||
key = self.orchestrator
|
||||
if not key or key == "swarm":
|
||||
key = "docker"
|
||||
@ -272,13 +273,15 @@ class Context:
|
||||
return None
|
||||
|
||||
@property
|
||||
def TLSMaterial(self): # pylint: disable=invalid-name
|
||||
certs = {}
|
||||
def TLSMaterial(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
|
||||
certs: dict[str, t.Any] = {}
|
||||
for endpoint, tls in self.tls_cfg.items():
|
||||
cert, key = tls.cert
|
||||
certs[endpoint] = list(map(os.path.basename, [tls.ca_cert, cert, key]))
|
||||
paths = [tls.ca_cert, *tls.cert] if tls.cert else [tls.ca_cert]
|
||||
certs[endpoint] = [
|
||||
os.path.basename(path) if path else None for path in paths
|
||||
]
|
||||
return {"TLSMaterial": certs}
|
||||
|
||||
@property
|
||||
def Storage(self): # pylint: disable=invalid-name
|
||||
def Storage(self) -> dict[str, t.Any]: # pylint: disable=invalid-name
|
||||
return {"Storage": {"MetadataPath": self.meta_path, "TLSPath": self.tls_path}}
|
||||
|
||||
@ -11,6 +11,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
|
||||
class StoreError(RuntimeError):
|
||||
pass
|
||||
@ -24,7 +30,7 @@ class InitializationError(StoreError):
|
||||
pass
|
||||
|
||||
|
||||
def process_store_error(cpe, program):
|
||||
def process_store_error(cpe: CalledProcessError, program: str) -> StoreError:
|
||||
message = cpe.output.decode("utf-8")
|
||||
if "credentials not found in native keychain" in message:
|
||||
return CredentialsNotFound(f"No matching credentials in {program}")
|
||||
|
||||
@ -14,13 +14,14 @@ from __future__ import annotations
|
||||
import errno
|
||||
import json
|
||||
import subprocess
|
||||
import typing as t
|
||||
|
||||
from . import constants, errors
|
||||
from .utils import create_environment_dict, find_executable
|
||||
|
||||
|
||||
class Store:
|
||||
def __init__(self, program, environment=None):
|
||||
def __init__(self, program: str, environment: dict[str, str] | None = None) -> None:
|
||||
"""Create a store object that acts as an interface to
|
||||
perform the basic operations for storing, retrieving
|
||||
and erasing credentials using `program`.
|
||||
@ -33,7 +34,7 @@ class Store:
|
||||
f"{self.program} not installed or not available in PATH"
|
||||
)
|
||||
|
||||
def get(self, server):
|
||||
def get(self, server: str | bytes) -> dict[str, t.Any]:
|
||||
"""Retrieve credentials for `server`. If no credentials are found,
|
||||
a `StoreError` will be raised.
|
||||
"""
|
||||
@ -53,7 +54,7 @@ class Store:
|
||||
|
||||
return result
|
||||
|
||||
def store(self, server, username, secret):
|
||||
def store(self, server: str, username: str, secret: str) -> bytes:
|
||||
"""Store credentials for `server`. Raises a `StoreError` if an error
|
||||
occurs.
|
||||
"""
|
||||
@ -62,7 +63,7 @@ class Store:
|
||||
).encode("utf-8")
|
||||
return self._execute("store", data_input)
|
||||
|
||||
def erase(self, server):
|
||||
def erase(self, server: str | bytes) -> None:
|
||||
"""Erase credentials for `server`. Raises a `StoreError` if an error
|
||||
occurs.
|
||||
"""
|
||||
@ -70,12 +71,16 @@ class Store:
|
||||
server = server.encode("utf-8")
|
||||
self._execute("erase", server)
|
||||
|
||||
def list(self):
|
||||
def list(self) -> t.Any:
|
||||
"""List stored credentials. Requires v0.4.0+ of the helper."""
|
||||
data = self._execute("list", None)
|
||||
return json.loads(data.decode("utf-8"))
|
||||
|
||||
def _execute(self, subcmd, data_input):
|
||||
def _execute(self, subcmd: str, data_input: bytes | None) -> bytes:
|
||||
if self.exe is None:
|
||||
raise errors.StoreError(
|
||||
f"{self.program} not installed or not available in PATH"
|
||||
)
|
||||
output = None
|
||||
env = create_environment_dict(self.environment)
|
||||
try:
|
||||
|
||||
@ -15,7 +15,7 @@ import os
|
||||
from shutil import which
|
||||
|
||||
|
||||
def find_executable(executable, path=None):
|
||||
def find_executable(executable: str, path: str | None = None) -> str | None:
|
||||
"""
|
||||
As distutils.spawn.find_executable, but on Windows, look up
|
||||
every extension declared in PATHEXT instead of just `.exe`
|
||||
@ -26,7 +26,7 @@ def find_executable(executable, path=None):
|
||||
return which(executable, path=path)
|
||||
|
||||
|
||||
def create_environment_dict(overrides):
|
||||
def create_environment_dict(overrides: dict[str, str] | None) -> dict[str, str]:
|
||||
"""
|
||||
Create and return a copy of os.environ with the specified overrides
|
||||
"""
|
||||
|
||||
@ -11,6 +11,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.module_utils.common.text.converters import to_native
|
||||
|
||||
from ._import_helper import HTTPError as _HTTPError
|
||||
@ -25,7 +27,7 @@ class DockerException(Exception):
|
||||
"""
|
||||
|
||||
|
||||
def create_api_error_from_http_exception(e):
|
||||
def create_api_error_from_http_exception(e: _HTTPError) -> t.NoReturn:
|
||||
"""
|
||||
Create a suitable APIError from requests.exceptions.HTTPError.
|
||||
"""
|
||||
@ -52,14 +54,16 @@ class APIError(_HTTPError, DockerException):
|
||||
An HTTP error from the API.
|
||||
"""
|
||||
|
||||
def __init__(self, message, response=None, explanation=None):
|
||||
def __init__(
|
||||
self, message: str | Exception, response=None, explanation: str | None = None
|
||||
) -> None:
|
||||
# requests 1.2 supports response as a keyword argument, but
|
||||
# requests 1.1 does not
|
||||
super().__init__(message)
|
||||
self.response = response
|
||||
self.explanation = explanation
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
message = super().__str__()
|
||||
|
||||
if self.is_client_error():
|
||||
@ -74,19 +78,20 @@ class APIError(_HTTPError, DockerException):
|
||||
return message
|
||||
|
||||
@property
|
||||
def status_code(self):
|
||||
def status_code(self) -> int | None:
|
||||
if self.response is not None:
|
||||
return self.response.status_code
|
||||
return None
|
||||
|
||||
def is_error(self):
|
||||
def is_error(self) -> bool:
|
||||
return self.is_client_error() or self.is_server_error()
|
||||
|
||||
def is_client_error(self):
|
||||
def is_client_error(self) -> bool:
|
||||
if self.status_code is None:
|
||||
return False
|
||||
return 400 <= self.status_code < 500
|
||||
|
||||
def is_server_error(self):
|
||||
def is_server_error(self) -> bool:
|
||||
if self.status_code is None:
|
||||
return False
|
||||
return 500 <= self.status_code < 600
|
||||
@ -121,10 +126,10 @@ class DeprecatedMethod(DockerException):
|
||||
|
||||
|
||||
class TLSParameterError(DockerException):
|
||||
def __init__(self, msg):
|
||||
def __init__(self, msg: str) -> None:
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.msg + (
|
||||
". TLS configurations should map the Docker CLI "
|
||||
"client configurations. See "
|
||||
@ -142,7 +147,14 @@ class ContainerError(DockerException):
|
||||
Represents a container that has exited with a non-zero exit code.
|
||||
"""
|
||||
|
||||
def __init__(self, container, exit_status, command, image, stderr):
|
||||
def __init__(
|
||||
self,
|
||||
container: str,
|
||||
exit_status: int,
|
||||
command: list[str],
|
||||
image: str,
|
||||
stderr: str | None,
|
||||
):
|
||||
self.container = container
|
||||
self.exit_status = exit_status
|
||||
self.command = command
|
||||
@ -156,12 +168,12 @@ class ContainerError(DockerException):
|
||||
|
||||
|
||||
class StreamParseError(RuntimeError):
|
||||
def __init__(self, reason):
|
||||
def __init__(self, reason: Exception) -> None:
|
||||
self.msg = reason
|
||||
|
||||
|
||||
class BuildError(DockerException):
|
||||
def __init__(self, reason, build_log):
|
||||
def __init__(self, reason: str, build_log: str) -> None:
|
||||
super().__init__(reason)
|
||||
self.msg = reason
|
||||
self.build_log = build_log
|
||||
@ -171,7 +183,7 @@ class ImageLoadError(DockerException):
|
||||
pass
|
||||
|
||||
|
||||
def create_unexpected_kwargs_error(name, kwargs):
|
||||
def create_unexpected_kwargs_error(name: str, kwargs: dict[str, t.Any]) -> TypeError:
|
||||
quoted_kwargs = [f"'{k}'" for k in sorted(kwargs)]
|
||||
text = [f"{name}() "]
|
||||
if len(quoted_kwargs) == 1:
|
||||
@ -183,42 +195,44 @@ def create_unexpected_kwargs_error(name, kwargs):
|
||||
|
||||
|
||||
class MissingContextParameter(DockerException):
|
||||
def __init__(self, param):
|
||||
def __init__(self, param: str) -> None:
|
||||
self.param = param
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"missing parameter: {self.param}"
|
||||
|
||||
|
||||
class ContextAlreadyExists(DockerException):
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"context {self.name} already exists"
|
||||
|
||||
|
||||
class ContextException(DockerException):
|
||||
def __init__(self, msg):
|
||||
def __init__(self, msg: str) -> None:
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.msg
|
||||
|
||||
|
||||
class ContextNotFound(DockerException):
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"context '{self.name}' not found"
|
||||
|
||||
|
||||
class MissingRequirementException(DockerException):
|
||||
def __init__(self, msg, requirement, import_exception):
|
||||
def __init__(
|
||||
self, msg: str, requirement: str, import_exception: ImportError | str
|
||||
) -> None:
|
||||
self.msg = msg
|
||||
self.requirement = requirement
|
||||
self.import_exception = import_exception
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.msg
|
||||
|
||||
@ -12,12 +12,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ssl
|
||||
import typing as t
|
||||
|
||||
from . import errors
|
||||
from .transport.ssladapter import SSLHTTPAdapter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ansible_collections.community.docker.plugins.module_utils._api.api.client import (
|
||||
APIClient,
|
||||
)
|
||||
|
||||
|
||||
class TLSConfig:
|
||||
"""
|
||||
TLS configuration.
|
||||
@ -27,25 +33,22 @@ class TLSConfig:
|
||||
ca_cert (str): Path to CA cert file.
|
||||
verify (bool or str): This can be ``False`` or a path to a CA cert
|
||||
file.
|
||||
ssl_version (int): A valid `SSL version`_.
|
||||
assert_hostname (bool): Verify the hostname of the server.
|
||||
|
||||
.. _`SSL version`:
|
||||
https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1
|
||||
"""
|
||||
|
||||
cert = None
|
||||
ca_cert = None
|
||||
verify = None
|
||||
ssl_version = None
|
||||
cert: tuple[str, str] | None = None
|
||||
ca_cert: str | None = None
|
||||
verify: bool | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_cert=None,
|
||||
ca_cert=None,
|
||||
verify=None,
|
||||
ssl_version=None,
|
||||
assert_hostname=None,
|
||||
client_cert: tuple[str, str] | None = None,
|
||||
ca_cert: str | None = None,
|
||||
verify: bool | None = None,
|
||||
assert_hostname: bool | None = None,
|
||||
):
|
||||
# Argument compatibility/mapping with
|
||||
# https://docs.docker.com/engine/articles/https/
|
||||
@ -55,12 +58,6 @@ class TLSConfig:
|
||||
|
||||
self.assert_hostname = assert_hostname
|
||||
|
||||
# If the user provides an SSL version, we should use their preference
|
||||
if ssl_version:
|
||||
self.ssl_version = ssl_version
|
||||
else:
|
||||
self.ssl_version = ssl.PROTOCOL_TLS_CLIENT
|
||||
|
||||
# "client_cert" must have both or neither cert/key files. In
|
||||
# either case, Alert the user when both are expected, but any are
|
||||
# missing.
|
||||
@ -90,11 +87,10 @@ class TLSConfig:
|
||||
"Invalid CA certificate provided for `ca_cert`."
|
||||
)
|
||||
|
||||
def configure_client(self, client):
|
||||
def configure_client(self, client: APIClient) -> None:
|
||||
"""
|
||||
Configure a client with these TLS options.
|
||||
"""
|
||||
client.ssl_version = self.ssl_version
|
||||
|
||||
if self.verify and self.ca_cert:
|
||||
client.verify = self.ca_cert
|
||||
@ -107,7 +103,6 @@ class TLSConfig:
|
||||
client.mount(
|
||||
"https://",
|
||||
SSLHTTPAdapter(
|
||||
ssl_version=self.ssl_version,
|
||||
assert_hostname=self.assert_hostname,
|
||||
),
|
||||
)
|
||||
|
||||
@ -15,7 +15,7 @@ from .._import_helper import HTTPAdapter as _HTTPAdapter
|
||||
|
||||
|
||||
class BaseHTTPAdapter(_HTTPAdapter):
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
super().close()
|
||||
if hasattr(self, "pools"):
|
||||
self.pools.clear()
|
||||
@ -24,10 +24,10 @@ class BaseHTTPAdapter(_HTTPAdapter):
|
||||
# https://github.com/psf/requests/commit/c0813a2d910ea6b4f8438b91d315b8d181302356
|
||||
# changes requests.adapters.HTTPAdapter to no longer call get_connection() from
|
||||
# send(), but instead call _get_connection().
|
||||
def _get_connection(self, request, *args, **kwargs):
|
||||
def _get_connection(self, request, *args, **kwargs): # type: ignore
|
||||
return self.get_connection(request.url, kwargs.get("proxies"))
|
||||
|
||||
# Fix for requests 2.32.2+:
|
||||
# https://github.com/psf/requests/commit/c98e4d133ef29c46a9b68cd783087218a8075e05
|
||||
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None):
|
||||
def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): # type: ignore
|
||||
return self.get_connection(request.url, proxies)
|
||||
|
||||
@ -23,12 +23,12 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
|
||||
|
||||
|
||||
class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
|
||||
def __init__(self, npipe_path, timeout=60):
|
||||
def __init__(self, npipe_path: str, timeout: int = 60) -> None:
|
||||
super().__init__("localhost", timeout=timeout)
|
||||
self.npipe_path = npipe_path
|
||||
self.timeout = timeout
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> None:
|
||||
sock = NpipeSocket()
|
||||
sock.settimeout(self.timeout)
|
||||
sock.connect(self.npipe_path)
|
||||
@ -36,18 +36,18 @@ class NpipeHTTPConnection(urllib3_connection.HTTPConnection):
|
||||
|
||||
|
||||
class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||
def __init__(self, npipe_path, timeout=60, maxsize=10):
|
||||
def __init__(self, npipe_path: str, timeout: int = 60, maxsize: int = 10) -> None:
|
||||
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
|
||||
self.npipe_path = npipe_path
|
||||
self.timeout = timeout
|
||||
|
||||
def _new_conn(self):
|
||||
def _new_conn(self) -> NpipeHTTPConnection:
|
||||
return NpipeHTTPConnection(self.npipe_path, self.timeout)
|
||||
|
||||
# When re-using connections, urllib3 tries to call select() on our
|
||||
# NpipeSocket instance, causing a crash. To circumvent this, we override
|
||||
# _get_conn, where that check happens.
|
||||
def _get_conn(self, timeout):
|
||||
def _get_conn(self, timeout: int) -> NpipeHTTPConnection:
|
||||
conn = None
|
||||
try:
|
||||
conn = self.pool.get(block=self.block, timeout=timeout)
|
||||
@ -67,7 +67,6 @@ class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||
|
||||
|
||||
class NpipeHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
__attrs__ = HTTPAdapter.__attrs__ + [
|
||||
"npipe_path",
|
||||
"pools",
|
||||
@ -77,11 +76,11 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url,
|
||||
timeout=60,
|
||||
pool_connections=constants.DEFAULT_NUM_POOLS,
|
||||
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
|
||||
):
|
||||
base_url: str,
|
||||
timeout: int = 60,
|
||||
pool_connections: int = constants.DEFAULT_NUM_POOLS,
|
||||
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
|
||||
) -> None:
|
||||
self.npipe_path = base_url.replace("npipe://", "")
|
||||
self.timeout = timeout
|
||||
self.max_pool_size = max_pool_size
|
||||
@ -90,7 +89,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
def get_connection(self, url: str | bytes, proxies=None) -> NpipeHTTPConnectionPool:
|
||||
with self.pools.lock:
|
||||
pool = self.pools.get(url)
|
||||
if pool:
|
||||
@ -103,7 +102,7 @@ class NpipeHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
return pool
|
||||
|
||||
def request_url(self, request, proxies):
|
||||
def request_url(self, request, proxies) -> str:
|
||||
# The select_proxy utility in requests errors out when the provided URL
|
||||
# does not have a hostname, like is the case when using a UNIX socket.
|
||||
# Since proxies are an irrelevant notion in the case of UNIX sockets
|
||||
|
||||
@ -15,6 +15,7 @@ import functools
|
||||
import io
|
||||
import time
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
|
||||
PYWIN32_IMPORT_ERROR: str | None # pylint: disable=invalid-name
|
||||
@ -29,6 +30,13 @@ except ImportError:
|
||||
else:
|
||||
PYWIN32_IMPORT_ERROR = None # pylint: disable=invalid-name
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from collections.abc import Buffer, Callable
|
||||
|
||||
_Self = t.TypeVar("_Self")
|
||||
_P = t.ParamSpec("_P")
|
||||
_R = t.TypeVar("_R")
|
||||
|
||||
|
||||
ERROR_PIPE_BUSY = 0xE7
|
||||
SECURITY_SQOS_PRESENT = 0x100000
|
||||
@ -37,10 +45,12 @@ SECURITY_ANONYMOUS = 0
|
||||
MAXIMUM_RETRY_COUNT = 10
|
||||
|
||||
|
||||
def check_closed(f):
|
||||
def check_closed(
|
||||
f: Callable[t.Concatenate[_Self, _P], _R],
|
||||
) -> Callable[t.Concatenate[_Self, _P], _R]:
|
||||
@functools.wraps(f)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self._closed:
|
||||
def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if self._closed: # type: ignore
|
||||
raise RuntimeError("Can not reuse socket after connection was closed.")
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
@ -54,25 +64,25 @@ class NpipeSocket:
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, handle=None):
|
||||
def __init__(self, handle=None) -> None:
|
||||
self._timeout = win32pipe.NMPWAIT_USE_DEFAULT_WAIT
|
||||
self._handle = handle
|
||||
self._address = None
|
||||
self._address: str | None = None
|
||||
self._closed = False
|
||||
self.flags = None
|
||||
self.flags: int | None = None
|
||||
|
||||
def accept(self):
|
||||
def accept(self) -> t.NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
def bind(self, address):
|
||||
def bind(self, address) -> t.NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self._handle.Close()
|
||||
self._closed = True
|
||||
|
||||
@check_closed
|
||||
def connect(self, address, retry_count=0):
|
||||
def connect(self, address, retry_count: int = 0) -> None:
|
||||
try:
|
||||
handle = win32file.CreateFile(
|
||||
address,
|
||||
@ -100,14 +110,14 @@ class NpipeSocket:
|
||||
return self.connect(address, retry_count)
|
||||
raise e
|
||||
|
||||
self.flags = win32pipe.GetNamedPipeInfo(handle)[0]
|
||||
self.flags = win32pipe.GetNamedPipeInfo(handle)[0] # type: ignore
|
||||
|
||||
self._handle = handle
|
||||
self._address = address
|
||||
|
||||
@check_closed
|
||||
def connect_ex(self, address):
|
||||
return self.connect(address)
|
||||
def connect_ex(self, address) -> None:
|
||||
self.connect(address)
|
||||
|
||||
@check_closed
|
||||
def detach(self):
|
||||
@ -115,25 +125,25 @@ class NpipeSocket:
|
||||
return self._handle
|
||||
|
||||
@check_closed
|
||||
def dup(self):
|
||||
def dup(self) -> NpipeSocket:
|
||||
return NpipeSocket(self._handle)
|
||||
|
||||
def getpeername(self):
|
||||
def getpeername(self) -> str | None:
|
||||
return self._address
|
||||
|
||||
def getsockname(self):
|
||||
def getsockname(self) -> str | None:
|
||||
return self._address
|
||||
|
||||
def getsockopt(self, level, optname, buflen=None):
|
||||
def getsockopt(self, level, optname, buflen=None) -> t.NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
def ioctl(self, control, option):
|
||||
def ioctl(self, control, option) -> t.NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
def listen(self, backlog):
|
||||
def listen(self, backlog) -> t.NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
def makefile(self, mode=None, bufsize=None):
|
||||
def makefile(self, mode: str, bufsize: int | None = None):
|
||||
if mode.strip("b") != "r":
|
||||
raise NotImplementedError()
|
||||
rawio = NpipeFileIOBase(self)
|
||||
@ -142,30 +152,30 @@ class NpipeSocket:
|
||||
return io.BufferedReader(rawio, buffer_size=bufsize)
|
||||
|
||||
@check_closed
|
||||
def recv(self, bufsize, flags=0):
|
||||
def recv(self, bufsize: int, flags: int = 0) -> str:
|
||||
dummy_err, data = win32file.ReadFile(self._handle, bufsize)
|
||||
return data
|
||||
|
||||
@check_closed
|
||||
def recvfrom(self, bufsize, flags=0):
|
||||
def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[str, str | None]:
|
||||
data = self.recv(bufsize, flags)
|
||||
return (data, self._address)
|
||||
|
||||
@check_closed
|
||||
def recvfrom_into(self, buf, nbytes=0, flags=0):
|
||||
return self.recv_into(buf, nbytes, flags), self._address
|
||||
def recvfrom_into(
|
||||
self, buf: Buffer, nbytes: int = 0, flags: int = 0
|
||||
) -> tuple[int, str | None]:
|
||||
return self.recv_into(buf, nbytes), self._address
|
||||
|
||||
@check_closed
|
||||
def recv_into(self, buf, nbytes=0):
|
||||
readbuf = buf
|
||||
if not isinstance(buf, memoryview):
|
||||
readbuf = memoryview(buf)
|
||||
def recv_into(self, buf: Buffer, nbytes: int = 0) -> int:
|
||||
readbuf = buf if isinstance(buf, memoryview) else memoryview(buf)
|
||||
|
||||
event = win32event.CreateEvent(None, True, True, None)
|
||||
try:
|
||||
overlapped = pywintypes.OVERLAPPED()
|
||||
overlapped.hEvent = event
|
||||
dummy_err, dummy_data = win32file.ReadFile(
|
||||
dummy_err, dummy_data = win32file.ReadFile( # type: ignore
|
||||
self._handle, readbuf[:nbytes] if nbytes else readbuf, overlapped
|
||||
)
|
||||
wait_result = win32event.WaitForSingleObject(event, self._timeout)
|
||||
@ -177,12 +187,12 @@ class NpipeSocket:
|
||||
win32api.CloseHandle(event)
|
||||
|
||||
@check_closed
|
||||
def send(self, string, flags=0):
|
||||
def send(self, string: Buffer, flags: int = 0) -> int:
|
||||
event = win32event.CreateEvent(None, True, True, None)
|
||||
try:
|
||||
overlapped = pywintypes.OVERLAPPED()
|
||||
overlapped.hEvent = event
|
||||
win32file.WriteFile(self._handle, string, overlapped)
|
||||
win32file.WriteFile(self._handle, string, overlapped) # type: ignore
|
||||
wait_result = win32event.WaitForSingleObject(event, self._timeout)
|
||||
if wait_result == win32event.WAIT_TIMEOUT:
|
||||
win32file.CancelIo(self._handle)
|
||||
@ -192,20 +202,20 @@ class NpipeSocket:
|
||||
win32api.CloseHandle(event)
|
||||
|
||||
@check_closed
|
||||
def sendall(self, string, flags=0):
|
||||
def sendall(self, string: Buffer, flags: int = 0) -> int:
|
||||
return self.send(string, flags)
|
||||
|
||||
@check_closed
|
||||
def sendto(self, string, address):
|
||||
def sendto(self, string: Buffer, address: str) -> int:
|
||||
self.connect(address)
|
||||
return self.send(string)
|
||||
|
||||
def setblocking(self, flag):
|
||||
def setblocking(self, flag: bool):
|
||||
if flag:
|
||||
return self.settimeout(None)
|
||||
return self.settimeout(0)
|
||||
|
||||
def settimeout(self, value):
|
||||
def settimeout(self, value: int | float | None) -> None:
|
||||
if value is None:
|
||||
# Blocking mode
|
||||
self._timeout = win32event.INFINITE
|
||||
@ -215,39 +225,39 @@ class NpipeSocket:
|
||||
# Timeout mode - Value converted to milliseconds
|
||||
self._timeout = int(value * 1000)
|
||||
|
||||
def gettimeout(self):
|
||||
def gettimeout(self) -> int | float | None:
|
||||
return self._timeout
|
||||
|
||||
def setsockopt(self, level, optname, value):
|
||||
def setsockopt(self, level, optname, value) -> t.NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
@check_closed
|
||||
def shutdown(self, how):
|
||||
def shutdown(self, how) -> None:
|
||||
return self.close()
|
||||
|
||||
|
||||
class NpipeFileIOBase(io.RawIOBase):
|
||||
def __init__(self, npipe_socket):
|
||||
def __init__(self, npipe_socket) -> None:
|
||||
self.sock = npipe_socket
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
super().close()
|
||||
self.sock = None
|
||||
|
||||
def fileno(self):
|
||||
def fileno(self) -> int:
|
||||
return self.sock.fileno()
|
||||
|
||||
def isatty(self):
|
||||
def isatty(self) -> bool:
|
||||
return False
|
||||
|
||||
def readable(self):
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def readinto(self, buf):
|
||||
def readinto(self, buf: Buffer) -> int:
|
||||
return self.sock.recv_into(buf)
|
||||
|
||||
def seekable(self):
|
||||
def seekable(self) -> bool:
|
||||
return False
|
||||
|
||||
def writable(self):
|
||||
def writable(self) -> bool:
|
||||
return False
|
||||
|
||||
@ -17,6 +17,7 @@ import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import traceback
|
||||
import typing as t
|
||||
from queue import Empty
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@ -33,12 +34,15 @@ except ImportError:
|
||||
else:
|
||||
PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from collections.abc import Buffer
|
||||
|
||||
|
||||
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
|
||||
|
||||
|
||||
class SSHSocket(socket.socket):
|
||||
def __init__(self, host):
|
||||
def __init__(self, host: str) -> None:
|
||||
super().__init__(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.host = host
|
||||
self.port = None
|
||||
@ -48,9 +52,9 @@ class SSHSocket(socket.socket):
|
||||
if "@" in self.host:
|
||||
self.user, self.host = self.host.split("@")
|
||||
|
||||
self.proc = None
|
||||
self.proc: subprocess.Popen | None = None
|
||||
|
||||
def connect(self, **kwargs):
|
||||
def connect(self, *args_: t.Any, **kwargs: t.Any) -> None:
|
||||
args = ["ssh"]
|
||||
if self.user:
|
||||
args = args + ["-l", self.user]
|
||||
@ -82,37 +86,48 @@ class SSHSocket(socket.socket):
|
||||
preexec_fn=preexec_func,
|
||||
)
|
||||
|
||||
def _write(self, data):
|
||||
if not self.proc or self.proc.stdin.closed:
|
||||
def _write(self, data: Buffer) -> int:
|
||||
if not self.proc:
|
||||
raise RuntimeError(
|
||||
"SSH subprocess not initiated. connect() must be called first."
|
||||
)
|
||||
assert self.proc.stdin is not None
|
||||
if self.proc.stdin.closed:
|
||||
raise RuntimeError(
|
||||
"SSH subprocess not initiated. connect() must be called first after close()."
|
||||
)
|
||||
written = self.proc.stdin.write(data)
|
||||
self.proc.stdin.flush()
|
||||
return written
|
||||
|
||||
def sendall(self, data):
|
||||
def sendall(self, data: Buffer, *args, **kwargs) -> None:
|
||||
self._write(data)
|
||||
|
||||
def send(self, data):
|
||||
def send(self, data: Buffer, *args, **kwargs) -> int:
|
||||
return self._write(data)
|
||||
|
||||
def recv(self, n):
|
||||
def recv(self, n: int, *args, **kwargs) -> bytes:
|
||||
if not self.proc:
|
||||
raise RuntimeError(
|
||||
"SSH subprocess not initiated. connect() must be called first."
|
||||
)
|
||||
assert self.proc.stdout is not None
|
||||
return self.proc.stdout.read(n)
|
||||
|
||||
def makefile(self, mode):
|
||||
def makefile(self, mode: str, *args, **kwargs) -> t.IO: # type: ignore
|
||||
if not self.proc:
|
||||
self.connect()
|
||||
self.proc.stdout.channel = self
|
||||
assert self.proc is not None
|
||||
assert self.proc.stdout is not None
|
||||
self.proc.stdout.channel = self # type: ignore
|
||||
|
||||
return self.proc.stdout
|
||||
|
||||
def close(self):
|
||||
if not self.proc or self.proc.stdin.closed:
|
||||
def close(self) -> None:
|
||||
if not self.proc:
|
||||
return
|
||||
assert self.proc.stdin is not None
|
||||
if self.proc.stdin.closed:
|
||||
return
|
||||
self.proc.stdin.write(b"\n\n")
|
||||
self.proc.stdin.flush()
|
||||
@ -120,13 +135,19 @@ class SSHSocket(socket.socket):
|
||||
|
||||
|
||||
class SSHConnection(urllib3_connection.HTTPConnection):
|
||||
def __init__(self, ssh_transport=None, timeout=60, host=None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ssh_transport=None,
|
||||
timeout: int = 60,
|
||||
host: str,
|
||||
) -> None:
|
||||
super().__init__("localhost", timeout=timeout)
|
||||
self.ssh_transport = ssh_transport
|
||||
self.timeout = timeout
|
||||
self.ssh_host = host
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> None:
|
||||
if self.ssh_transport:
|
||||
sock = self.ssh_transport.open_session()
|
||||
sock.settimeout(self.timeout)
|
||||
@ -142,7 +163,14 @@ class SSHConnection(urllib3_connection.HTTPConnection):
|
||||
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||
scheme = "ssh"
|
||||
|
||||
def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ssh_client: paramiko.SSHClient | None = None,
|
||||
timeout: int = 60,
|
||||
maxsize: int = 10,
|
||||
host: str,
|
||||
) -> None:
|
||||
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
|
||||
self.ssh_transport = None
|
||||
self.timeout = timeout
|
||||
@ -150,13 +178,17 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||
self.ssh_transport = ssh_client.get_transport()
|
||||
self.ssh_host = host
|
||||
|
||||
def _new_conn(self):
|
||||
return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host)
|
||||
def _new_conn(self) -> SSHConnection:
|
||||
return SSHConnection(
|
||||
ssh_transport=self.ssh_transport,
|
||||
timeout=self.timeout,
|
||||
host=self.ssh_host,
|
||||
)
|
||||
|
||||
# When re-using connections, urllib3 calls fileno() on our
|
||||
# SSH channel instance, quickly overloading our fd limit. To avoid this,
|
||||
# we override _get_conn
|
||||
def _get_conn(self, timeout):
|
||||
def _get_conn(self, timeout: int) -> SSHConnection:
|
||||
conn = None
|
||||
try:
|
||||
conn = self.pool.get(block=self.block, timeout=timeout)
|
||||
@ -176,7 +208,6 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||
|
||||
|
||||
class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
__attrs__ = HTTPAdapter.__attrs__ + [
|
||||
"pools",
|
||||
"timeout",
|
||||
@ -187,13 +218,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url,
|
||||
timeout=60,
|
||||
pool_connections=constants.DEFAULT_NUM_POOLS,
|
||||
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
|
||||
shell_out=False,
|
||||
):
|
||||
self.ssh_client = None
|
||||
base_url: str,
|
||||
timeout: int = 60,
|
||||
pool_connections: int = constants.DEFAULT_NUM_POOLS,
|
||||
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
|
||||
shell_out: bool = False,
|
||||
) -> None:
|
||||
self.ssh_client: paramiko.SSHClient | None = None
|
||||
if not shell_out:
|
||||
self._create_paramiko_client(base_url)
|
||||
self._connect()
|
||||
@ -209,30 +240,31 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def _create_paramiko_client(self, base_url):
|
||||
def _create_paramiko_client(self, base_url: str) -> None:
|
||||
logging.getLogger("paramiko").setLevel(logging.WARNING)
|
||||
self.ssh_client = paramiko.SSHClient()
|
||||
base_url = urlparse(base_url)
|
||||
self.ssh_params = {
|
||||
"hostname": base_url.hostname,
|
||||
"port": base_url.port,
|
||||
"username": base_url.username,
|
||||
base_url_p = urlparse(base_url)
|
||||
assert base_url_p.hostname is not None
|
||||
self.ssh_params: dict[str, t.Any] = {
|
||||
"hostname": base_url_p.hostname,
|
||||
"port": base_url_p.port,
|
||||
"username": base_url_p.username,
|
||||
}
|
||||
ssh_config_file = os.path.expanduser("~/.ssh/config")
|
||||
if os.path.exists(ssh_config_file):
|
||||
conf = paramiko.SSHConfig()
|
||||
with open(ssh_config_file, "rt", encoding="utf-8") as f:
|
||||
conf.parse(f)
|
||||
host_config = conf.lookup(base_url.hostname)
|
||||
host_config = conf.lookup(base_url_p.hostname)
|
||||
if "proxycommand" in host_config:
|
||||
self.ssh_params["sock"] = paramiko.ProxyCommand(
|
||||
host_config["proxycommand"]
|
||||
)
|
||||
if "hostname" in host_config:
|
||||
self.ssh_params["hostname"] = host_config["hostname"]
|
||||
if base_url.port is None and "port" in host_config:
|
||||
if base_url_p.port is None and "port" in host_config:
|
||||
self.ssh_params["port"] = host_config["port"]
|
||||
if base_url.username is None and "user" in host_config:
|
||||
if base_url_p.username is None and "user" in host_config:
|
||||
self.ssh_params["username"] = host_config["user"]
|
||||
if "identityfile" in host_config:
|
||||
self.ssh_params["key_filename"] = host_config["identityfile"]
|
||||
@ -240,11 +272,11 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||
self.ssh_client.load_system_host_keys()
|
||||
self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())
|
||||
|
||||
def _connect(self):
|
||||
def _connect(self) -> None:
|
||||
if self.ssh_client:
|
||||
self.ssh_client.connect(**self.ssh_params)
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
def get_connection(self, url: str | bytes, proxies=None) -> SSHConnectionPool:
|
||||
if not self.ssh_client:
|
||||
return SSHConnectionPool(
|
||||
ssh_client=self.ssh_client,
|
||||
@ -271,7 +303,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
return pool
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
super().close()
|
||||
if self.ssh_client:
|
||||
self.ssh_client.close()
|
||||
|
||||
@ -11,9 +11,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ansible_collections.community.docker.plugins.module_utils._version import (
|
||||
LooseVersion,
|
||||
)
|
||||
import typing as t
|
||||
|
||||
from .._import_helper import HTTPAdapter, urllib3
|
||||
from .basehttpadapter import BaseHTTPAdapter
|
||||
@ -30,14 +28,19 @@ PoolManager = urllib3.poolmanager.PoolManager
|
||||
class SSLHTTPAdapter(BaseHTTPAdapter):
|
||||
"""An HTTPS Transport Adapter that uses an arbitrary SSL version."""
|
||||
|
||||
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname", "ssl_version"]
|
||||
__attrs__ = HTTPAdapter.__attrs__ + ["assert_hostname"]
|
||||
|
||||
def __init__(self, ssl_version=None, assert_hostname=None, **kwargs):
|
||||
self.ssl_version = ssl_version
|
||||
def __init__(
|
||||
self,
|
||||
assert_hostname: bool | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.assert_hostname = assert_hostname
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def init_poolmanager(self, connections, maxsize, block=False):
|
||||
def init_poolmanager(
|
||||
self, connections: int, maxsize: int, block: bool = False, **kwargs: t.Any
|
||||
) -> None:
|
||||
kwargs = {
|
||||
"num_pools": connections,
|
||||
"maxsize": maxsize,
|
||||
@ -45,12 +48,10 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
|
||||
}
|
||||
if self.assert_hostname is not None:
|
||||
kwargs["assert_hostname"] = self.assert_hostname
|
||||
if self.ssl_version and self.can_override_ssl_version():
|
||||
kwargs["ssl_version"] = self.ssl_version
|
||||
|
||||
self.poolmanager = PoolManager(**kwargs)
|
||||
|
||||
def get_connection(self, *args, **kwargs):
|
||||
def get_connection(self, *args, **kwargs) -> urllib3.ConnectionPool:
|
||||
"""
|
||||
Ensure assert_hostname is set correctly on our pool
|
||||
|
||||
@ -61,15 +62,7 @@ class SSLHTTPAdapter(BaseHTTPAdapter):
|
||||
conn = super().get_connection(*args, **kwargs)
|
||||
if (
|
||||
self.assert_hostname is not None
|
||||
and conn.assert_hostname != self.assert_hostname
|
||||
and conn.assert_hostname != self.assert_hostname # type: ignore
|
||||
):
|
||||
conn.assert_hostname = self.assert_hostname
|
||||
conn.assert_hostname = self.assert_hostname # type: ignore
|
||||
return conn
|
||||
|
||||
def can_override_ssl_version(self):
|
||||
urllib_ver = urllib3.__version__.split("-")[0]
|
||||
if urllib_ver is None:
|
||||
return False
|
||||
if urllib_ver == "dev":
|
||||
return True
|
||||
return LooseVersion(urllib_ver) > LooseVersion("1.5")
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing as t
|
||||
|
||||
from .. import constants
|
||||
from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
|
||||
@ -22,26 +23,25 @@ RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
|
||||
|
||||
|
||||
class UnixHTTPConnection(urllib3_connection.HTTPConnection):
|
||||
|
||||
def __init__(self, base_url, unix_socket, timeout=60):
|
||||
def __init__(self, base_url: str | bytes, unix_socket, timeout: int = 60) -> None:
|
||||
super().__init__("localhost", timeout=timeout)
|
||||
self.base_url = base_url
|
||||
self.unix_socket = unix_socket
|
||||
self.timeout = timeout
|
||||
self.disable_buffering = False
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> None:
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.timeout)
|
||||
sock.connect(self.unix_socket)
|
||||
self.sock = sock
|
||||
|
||||
def putheader(self, header, *values):
|
||||
def putheader(self, header: str, *values: str) -> None:
|
||||
super().putheader(header, *values)
|
||||
if header == "Connection" and "Upgrade" in values:
|
||||
self.disable_buffering = True
|
||||
|
||||
def response_class(self, sock, *args, **kwargs):
|
||||
def response_class(self, sock, *args, **kwargs) -> t.Any:
|
||||
# FIXME: We may need to disable buffering on Py3,
|
||||
# but there's no clear way to do it at the moment. See:
|
||||
# https://github.com/docker/docker-py/issues/1799
|
||||
@ -49,18 +49,23 @@ class UnixHTTPConnection(urllib3_connection.HTTPConnection):
|
||||
|
||||
|
||||
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||
def __init__(self, base_url, socket_path, timeout=60, maxsize=10):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | bytes,
|
||||
socket_path: str,
|
||||
timeout: int = 60,
|
||||
maxsize: int = 10,
|
||||
) -> None:
|
||||
super().__init__("localhost", timeout=timeout, maxsize=maxsize)
|
||||
self.base_url = base_url
|
||||
self.socket_path = socket_path
|
||||
self.timeout = timeout
|
||||
|
||||
def _new_conn(self):
|
||||
def _new_conn(self) -> UnixHTTPConnection:
|
||||
return UnixHTTPConnection(self.base_url, self.socket_path, self.timeout)
|
||||
|
||||
|
||||
class UnixHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
__attrs__ = HTTPAdapter.__attrs__ + [
|
||||
"pools",
|
||||
"socket_path",
|
||||
@ -70,11 +75,11 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket_url,
|
||||
timeout=60,
|
||||
pool_connections=constants.DEFAULT_NUM_POOLS,
|
||||
max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
|
||||
):
|
||||
socket_url: str,
|
||||
timeout: int = 60,
|
||||
pool_connections: int = constants.DEFAULT_NUM_POOLS,
|
||||
max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
|
||||
) -> None:
|
||||
socket_path = socket_url.replace("http+unix://", "")
|
||||
if not socket_path.startswith("/"):
|
||||
socket_path = "/" + socket_path
|
||||
@ -86,7 +91,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
def get_connection(self, url: str | bytes, proxies=None) -> UnixHTTPConnectionPool:
|
||||
with self.pools.lock:
|
||||
pool = self.pools.get(url)
|
||||
if pool:
|
||||
@ -99,7 +104,7 @@ class UnixHTTPAdapter(BaseHTTPAdapter):
|
||||
|
||||
return pool
|
||||
|
||||
def request_url(self, request, proxies):
|
||||
def request_url(self, request, proxies) -> str:
|
||||
# The select_proxy utility in requests errors out when the provided URL
|
||||
# does not have a hostname, like is the case when using a UNIX socket.
|
||||
# Since proxies are an irrelevant notion in the case of UNIX sockets
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing as t
|
||||
|
||||
from .._import_helper import urllib3
|
||||
from ..errors import DockerException
|
||||
@ -29,11 +30,11 @@ class CancellableStream:
|
||||
>>> events.close()
|
||||
"""
|
||||
|
||||
def __init__(self, stream, response):
|
||||
def __init__(self, stream, response) -> None:
|
||||
self._stream = stream
|
||||
self._response = response
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> t.Self:
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
@ -46,7 +47,7 @@ class CancellableStream:
|
||||
|
||||
next = __next__
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Closes the event streaming.
|
||||
"""
|
||||
|
||||
@ -17,15 +17,26 @@ import random
|
||||
import re
|
||||
import tarfile
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
from ..constants import IS_WINDOWS_PLATFORM, WINDOWS_LONGPATH_PREFIX
|
||||
from . import fnmatch
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
_SEP = re.compile("/|\\\\") if IS_WINDOWS_PLATFORM else re.compile("/")
|
||||
|
||||
|
||||
def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
|
||||
def tar(
|
||||
path: str,
|
||||
exclude: list[str] | None = None,
|
||||
dockerfile: tuple[str, str] | tuple[None, None] | None = None,
|
||||
fileobj: t.IO[bytes] | None = None,
|
||||
gzip: bool = False,
|
||||
) -> t.IO[bytes]:
|
||||
root = os.path.abspath(path)
|
||||
exclude = exclude or []
|
||||
dockerfile = dockerfile or (None, None)
|
||||
@ -47,7 +58,9 @@ def tar(path, exclude=None, dockerfile=None, fileobj=None, gzip=False):
|
||||
)
|
||||
|
||||
|
||||
def exclude_paths(root, patterns, dockerfile=None):
|
||||
def exclude_paths(
|
||||
root: str, patterns: list[str], dockerfile: str | None = None
|
||||
) -> set[str]:
|
||||
"""
|
||||
Given a root directory path and a list of .dockerignore patterns, return
|
||||
an iterator of all paths (both regular files and directories) in the root
|
||||
@ -64,7 +77,7 @@ def exclude_paths(root, patterns, dockerfile=None):
|
||||
return set(pm.walk(root))
|
||||
|
||||
|
||||
def build_file_list(root):
|
||||
def build_file_list(root: str) -> list[str]:
|
||||
files = []
|
||||
for dirname, dirnames, fnames in os.walk(root):
|
||||
for filename in fnames + dirnames:
|
||||
@ -74,7 +87,13 @@ def build_file_list(root):
|
||||
return files
|
||||
|
||||
|
||||
def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None):
|
||||
def create_archive(
|
||||
root: str,
|
||||
files: Sequence[str] | None = None,
|
||||
fileobj: t.IO[bytes] | None = None,
|
||||
gzip: bool = False,
|
||||
extra_files: Sequence[tuple[str, str]] | None = None,
|
||||
) -> t.IO[bytes]:
|
||||
extra_files = extra_files or []
|
||||
if not fileobj:
|
||||
fileobj = tempfile.NamedTemporaryFile()
|
||||
@ -92,7 +111,7 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
|
||||
if i is None:
|
||||
# This happens when we encounter a socket file. We can safely
|
||||
# ignore it and proceed.
|
||||
continue
|
||||
continue # type: ignore
|
||||
|
||||
# Workaround https://bugs.python.org/issue32713
|
||||
if i.mtime < 0 or i.mtime > 8**11 - 1:
|
||||
@ -124,11 +143,11 @@ def create_archive(root, files=None, fileobj=None, gzip=False, extra_files=None)
|
||||
return fileobj
|
||||
|
||||
|
||||
def mkbuildcontext(dockerfile):
|
||||
def mkbuildcontext(dockerfile: io.BytesIO | t.IO[bytes]) -> t.IO[bytes]:
|
||||
f = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with
|
||||
try:
|
||||
with tarfile.open(mode="w", fileobj=f) as t:
|
||||
if isinstance(dockerfile, io.StringIO):
|
||||
if isinstance(dockerfile, io.StringIO): # type: ignore
|
||||
raise TypeError("Please use io.BytesIO to create in-memory Dockerfiles")
|
||||
if isinstance(dockerfile, io.BytesIO):
|
||||
dfinfo = tarfile.TarInfo("Dockerfile")
|
||||
@ -144,17 +163,17 @@ def mkbuildcontext(dockerfile):
|
||||
return f
|
||||
|
||||
|
||||
def split_path(p):
|
||||
def split_path(p: str) -> list[str]:
|
||||
return [pt for pt in re.split(_SEP, p) if pt and pt != "."]
|
||||
|
||||
|
||||
def normalize_slashes(p):
|
||||
def normalize_slashes(p: str) -> str:
|
||||
if IS_WINDOWS_PLATFORM:
|
||||
return "/".join(split_path(p))
|
||||
return p
|
||||
|
||||
|
||||
def walk(root, patterns, default=True):
|
||||
def walk(root: str, patterns: Sequence[str], default: bool = True) -> t.Generator[str]:
|
||||
pm = PatternMatcher(patterns)
|
||||
return pm.walk(root)
|
||||
|
||||
@ -162,11 +181,11 @@ def walk(root, patterns, default=True):
|
||||
# Heavily based on
|
||||
# https://github.com/moby/moby/blob/master/pkg/fileutils/fileutils.go
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns):
|
||||
def __init__(self, patterns: Sequence[str]) -> None:
|
||||
self.patterns = list(filter(lambda p: p.dirs, [Pattern(p) for p in patterns]))
|
||||
self.patterns.append(Pattern("!.dockerignore"))
|
||||
|
||||
def matches(self, filepath):
|
||||
def matches(self, filepath: str) -> bool:
|
||||
matched = False
|
||||
parent_path = os.path.dirname(filepath)
|
||||
parent_path_dirs = split_path(parent_path)
|
||||
@ -185,8 +204,8 @@ class PatternMatcher:
|
||||
|
||||
return matched
|
||||
|
||||
def walk(self, root):
|
||||
def rec_walk(current_dir):
|
||||
def walk(self, root: str) -> t.Generator[str]:
|
||||
def rec_walk(current_dir: str) -> t.Generator[str]:
|
||||
for f in os.listdir(current_dir):
|
||||
fpath = os.path.join(os.path.relpath(current_dir, root), f)
|
||||
if fpath.startswith("." + os.path.sep):
|
||||
@ -220,7 +239,7 @@ class PatternMatcher:
|
||||
|
||||
|
||||
class Pattern:
|
||||
def __init__(self, pattern_str):
|
||||
def __init__(self, pattern_str: str) -> None:
|
||||
self.exclusion = False
|
||||
if pattern_str.startswith("!"):
|
||||
self.exclusion = True
|
||||
@ -230,8 +249,7 @@ class Pattern:
|
||||
self.cleaned_pattern = "/".join(self.dirs)
|
||||
|
||||
@classmethod
|
||||
def normalize(cls, p):
|
||||
|
||||
def normalize(cls, p: str) -> list[str]:
|
||||
# Remove trailing spaces
|
||||
p = p.strip()
|
||||
|
||||
@ -256,11 +274,11 @@ class Pattern:
|
||||
i += 1
|
||||
return split
|
||||
|
||||
def match(self, filepath):
|
||||
def match(self, filepath: str) -> bool:
|
||||
return fnmatch.fnmatch(normalize_slashes(filepath), self.cleaned_pattern)
|
||||
|
||||
|
||||
def process_dockerfile(dockerfile, path):
|
||||
def process_dockerfile(dockerfile: str, path: str) -> tuple[str | None, str | None]:
|
||||
if not dockerfile:
|
||||
return (None, None)
|
||||
|
||||
@ -268,7 +286,7 @@ def process_dockerfile(dockerfile, path):
|
||||
if not os.path.isabs(dockerfile):
|
||||
abs_dockerfile = os.path.join(path, dockerfile)
|
||||
if IS_WINDOWS_PLATFORM and path.startswith(WINDOWS_LONGPATH_PREFIX):
|
||||
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX):])}"
|
||||
abs_dockerfile = f"{WINDOWS_LONGPATH_PREFIX}{os.path.normpath(abs_dockerfile[len(WINDOWS_LONGPATH_PREFIX) :])}"
|
||||
if os.path.splitdrive(path)[0] != os.path.splitdrive(abs_dockerfile)[
|
||||
0
|
||||
] or os.path.relpath(abs_dockerfile, path).startswith(".."):
|
||||
|
||||
@ -14,6 +14,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ..constants import IS_WINDOWS_PLATFORM
|
||||
|
||||
@ -24,11 +25,11 @@ LEGACY_DOCKER_CONFIG_FILENAME = ".dockercfg"
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_default_config_file():
|
||||
def get_default_config_file() -> str:
|
||||
return os.path.join(home_dir(), DOCKER_CONFIG_FILENAME)
|
||||
|
||||
|
||||
def find_config_file(config_path=None):
|
||||
def find_config_file(config_path: str | None = None) -> str | None:
|
||||
homedir = home_dir()
|
||||
paths = list(
|
||||
filter(
|
||||
@ -54,14 +55,14 @@ def find_config_file(config_path=None):
|
||||
return None
|
||||
|
||||
|
||||
def config_path_from_environment():
|
||||
def config_path_from_environment() -> str | None:
|
||||
config_dir = os.environ.get("DOCKER_CONFIG")
|
||||
if not config_dir:
|
||||
return None
|
||||
return os.path.join(config_dir, os.path.basename(DOCKER_CONFIG_FILENAME))
|
||||
|
||||
|
||||
def home_dir():
|
||||
def home_dir() -> str:
|
||||
"""
|
||||
Get the user's home directory, using the same logic as the Docker Engine
|
||||
client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX.
|
||||
@ -71,7 +72,7 @@ def home_dir():
|
||||
return os.path.expanduser("~")
|
||||
|
||||
|
||||
def load_general_config(config_path=None):
|
||||
def load_general_config(config_path: str | None = None) -> dict[str, t.Any]:
|
||||
config_file = find_config_file(config_path)
|
||||
|
||||
if not config_file:
|
||||
|
||||
@ -12,16 +12,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import typing as t
|
||||
|
||||
from .. import errors
|
||||
from . import utils
|
||||
|
||||
|
||||
def minimum_version(version):
|
||||
def decorator(f):
|
||||
if t.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from ..api.client import APIClient
|
||||
|
||||
_Self = t.TypeVar("_Self")
|
||||
_P = t.ParamSpec("_P")
|
||||
_R = t.TypeVar("_R")
|
||||
|
||||
|
||||
def minimum_version(
|
||||
version: str,
|
||||
) -> Callable[
|
||||
[Callable[t.Concatenate[_Self, _P], _R]],
|
||||
Callable[t.Concatenate[_Self, _P], _R],
|
||||
]:
|
||||
def decorator(
|
||||
f: Callable[t.Concatenate[_Self, _P], _R],
|
||||
) -> Callable[t.Concatenate[_Self, _P], _R]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if utils.version_lt(self._version, version):
|
||||
def wrapper(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
# We use _Self instead of APIClient since this is used for mixins for APIClient.
|
||||
# This unfortunately means that self._version does not exist in the mixin,
|
||||
# it only exists after mixing in. This is why we ignore types here.
|
||||
if utils.version_lt(self._version, version): # type: ignore
|
||||
raise errors.InvalidVersion(
|
||||
f"{f.__name__} is not available for version < {version}"
|
||||
)
|
||||
@ -32,13 +53,16 @@ def minimum_version(version):
|
||||
return decorator
|
||||
|
||||
|
||||
def update_headers(f):
|
||||
def inner(self, *args, **kwargs):
|
||||
def update_headers(
|
||||
f: Callable[t.Concatenate[APIClient, _P], _R],
|
||||
) -> Callable[t.Concatenate[APIClient, _P], _R]:
|
||||
def inner(self: APIClient, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if "HttpHeaders" in self._general_configs:
|
||||
if not kwargs.get("headers"):
|
||||
kwargs["headers"] = self._general_configs["HttpHeaders"]
|
||||
else:
|
||||
kwargs["headers"].update(self._general_configs["HttpHeaders"])
|
||||
# We cannot (yet) model that kwargs["headers"] should be a dictionary
|
||||
kwargs["headers"].update(self._general_configs["HttpHeaders"]) # type: ignore
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
@ -32,12 +32,12 @@ _cache: dict[str, re.Pattern] = {}
|
||||
_MAXCACHE = 100
|
||||
|
||||
|
||||
def _purge():
|
||||
def _purge() -> None:
|
||||
"""Clear the pattern cache"""
|
||||
_cache.clear()
|
||||
|
||||
|
||||
def fnmatch(name, pat):
|
||||
def fnmatch(name: str, pat: str):
|
||||
"""Test whether FILENAME matches PATTERN.
|
||||
|
||||
Patterns are Unix shell style:
|
||||
@ -58,7 +58,7 @@ def fnmatch(name, pat):
|
||||
return fnmatchcase(name, pat)
|
||||
|
||||
|
||||
def fnmatchcase(name, pat):
|
||||
def fnmatchcase(name: str, pat: str) -> bool:
|
||||
"""Test whether FILENAME matches PATTERN, including case.
|
||||
This is a version of fnmatch() which does not case-normalize
|
||||
its arguments.
|
||||
@ -74,7 +74,7 @@ def fnmatchcase(name, pat):
|
||||
return re_pat.match(name) is not None
|
||||
|
||||
|
||||
def translate(pat):
|
||||
def translate(pat: str) -> str:
|
||||
"""Translate a shell PATTERN to a regular expression.
|
||||
|
||||
There is no way to quote meta-characters.
|
||||
|
||||
@ -13,14 +13,22 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import json.decoder
|
||||
import typing as t
|
||||
|
||||
from ..errors import StreamParseError
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
json_decoder = json.JSONDecoder()
|
||||
|
||||
|
||||
def stream_as_text(stream):
|
||||
def stream_as_text(stream: t.Generator[bytes | str]) -> t.Generator[str]:
|
||||
"""
|
||||
Given a stream of bytes or text, if any of the items in the stream
|
||||
are bytes convert them to text.
|
||||
@ -33,20 +41,22 @@ def stream_as_text(stream):
|
||||
yield data
|
||||
|
||||
|
||||
def json_splitter(buffer):
|
||||
def json_splitter(buffer: str) -> tuple[t.Any, str] | None:
|
||||
"""Attempt to parse a json object from a buffer. If there is at least one
|
||||
object, return it and the rest of the buffer, otherwise return None.
|
||||
"""
|
||||
buffer = buffer.strip()
|
||||
try:
|
||||
obj, index = json_decoder.raw_decode(buffer)
|
||||
rest = buffer[json.decoder.WHITESPACE.match(buffer, index).end() :]
|
||||
ws: re.Pattern = json.decoder.WHITESPACE # type: ignore[attr-defined]
|
||||
m = ws.match(buffer, index)
|
||||
rest = buffer[m.end() :] if m else buffer[index:]
|
||||
return obj, rest
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def json_stream(stream):
|
||||
def json_stream(stream: t.Generator[str | bytes]) -> t.Generator[t.Any]:
|
||||
"""Given a stream of text, return a stream of json objects.
|
||||
This handles streams which are inconsistently buffered (some entries may
|
||||
be newline delimited, and others are not).
|
||||
@ -54,21 +64,24 @@ def json_stream(stream):
|
||||
return split_buffer(stream, json_splitter, json_decoder.decode)
|
||||
|
||||
|
||||
def line_splitter(buffer, separator="\n"):
|
||||
def line_splitter(buffer: str, separator: str = "\n") -> tuple[str, str] | None:
|
||||
index = buffer.find(str(separator))
|
||||
if index == -1:
|
||||
return None
|
||||
return buffer[: index + 1], buffer[index + 1 :]
|
||||
|
||||
|
||||
def split_buffer(stream, splitter=None, decoder=lambda a: a):
|
||||
def split_buffer(
|
||||
stream: t.Generator[str | bytes],
|
||||
splitter: Callable[[str], tuple[_T, str] | None],
|
||||
decoder: Callable[[str], _T],
|
||||
) -> t.Generator[_T | str]:
|
||||
"""Given a generator which yields strings and a splitter function,
|
||||
joins all input, splits on the separator and yields each chunk.
|
||||
Unlike string.split(), each chunk includes the trailing
|
||||
separator, except for the last one if none was found on the end
|
||||
of the input.
|
||||
"""
|
||||
splitter = splitter or line_splitter
|
||||
buffered = ""
|
||||
|
||||
for data in stream_as_text(stream):
|
||||
|
||||
@ -12,6 +12,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from collections.abc import Collection, Sequence
|
||||
|
||||
|
||||
PORT_SPEC = re.compile(
|
||||
@ -26,32 +31,42 @@ PORT_SPEC = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def add_port_mapping(port_bindings, internal_port, external):
|
||||
def add_port_mapping(
|
||||
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
|
||||
internal_port: str,
|
||||
external: str | tuple[str, str | None] | None,
|
||||
) -> None:
|
||||
if internal_port in port_bindings:
|
||||
port_bindings[internal_port].append(external)
|
||||
else:
|
||||
port_bindings[internal_port] = [external]
|
||||
|
||||
|
||||
def add_port(port_bindings, internal_port_range, external_range):
|
||||
def add_port(
|
||||
port_bindings: dict[str, list[str | tuple[str, str | None] | None]],
|
||||
internal_port_range: list[str],
|
||||
external_range: list[str] | list[tuple[str, str | None]] | None,
|
||||
) -> None:
|
||||
if external_range is None:
|
||||
for internal_port in internal_port_range:
|
||||
add_port_mapping(port_bindings, internal_port, None)
|
||||
else:
|
||||
ports = zip(internal_port_range, external_range)
|
||||
for internal_port, external_port in ports:
|
||||
add_port_mapping(port_bindings, internal_port, external_port)
|
||||
for internal_port, external_port in zip(internal_port_range, external_range):
|
||||
# mypy loses the exact type of eternal_port elements for some reason...
|
||||
add_port_mapping(port_bindings, internal_port, external_port) # type: ignore
|
||||
|
||||
|
||||
def build_port_bindings(ports):
|
||||
port_bindings = {}
|
||||
def build_port_bindings(
|
||||
ports: Collection[str],
|
||||
) -> dict[str, list[str | tuple[str, str | None] | None]]:
|
||||
port_bindings: dict[str, list[str | tuple[str, str | None] | None]] = {}
|
||||
for port in ports:
|
||||
internal_port_range, external_range = split_port(port)
|
||||
add_port(port_bindings, internal_port_range, external_range)
|
||||
return port_bindings
|
||||
|
||||
|
||||
def _raise_invalid_port(port):
|
||||
def _raise_invalid_port(port: str) -> t.NoReturn:
|
||||
raise ValueError(
|
||||
f'Invalid port "{port}", should be '
|
||||
"[[remote_ip:]remote_port[-remote_port]:]"
|
||||
@ -59,39 +74,64 @@ def _raise_invalid_port(port):
|
||||
)
|
||||
|
||||
|
||||
def port_range(start, end, proto, randomly_available_port=False):
|
||||
if not start:
|
||||
@t.overload
|
||||
def port_range(
|
||||
start: str,
|
||||
end: str | None,
|
||||
proto: str,
|
||||
randomly_available_port: bool = False,
|
||||
) -> list[str]: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def port_range(
|
||||
start: str | None,
|
||||
end: str | None,
|
||||
proto: str,
|
||||
randomly_available_port: bool = False,
|
||||
) -> list[str] | None: ...
|
||||
|
||||
|
||||
def port_range(
|
||||
start: str | None,
|
||||
end: str | None,
|
||||
proto: str,
|
||||
randomly_available_port: bool = False,
|
||||
) -> list[str] | None:
|
||||
if start is None:
|
||||
return start
|
||||
if not end:
|
||||
if end is None:
|
||||
return [f"{start}{proto}"]
|
||||
if randomly_available_port:
|
||||
return [f"{start}-{end}{proto}"]
|
||||
return [f"{port}{proto}" for port in range(int(start), int(end) + 1)]
|
||||
|
||||
|
||||
def split_port(port):
|
||||
if hasattr(port, "legacy_repr"):
|
||||
# This is the worst hack, but it prevents a bug in Compose 1.14.0
|
||||
# https://github.com/docker/docker-py/issues/1668
|
||||
# TODO: remove once fixed in Compose stable
|
||||
port = port.legacy_repr()
|
||||
def split_port(
|
||||
port: str,
|
||||
) -> tuple[list[str], list[str] | list[tuple[str, str | None]] | None]:
|
||||
port = str(port)
|
||||
match = PORT_SPEC.match(port)
|
||||
if match is None:
|
||||
_raise_invalid_port(port)
|
||||
parts = match.groupdict()
|
||||
|
||||
host = parts["host"]
|
||||
proto = parts["proto"] or ""
|
||||
internal = port_range(parts["int"], parts["int_end"], proto)
|
||||
external = port_range(parts["ext"], parts["ext_end"], "", len(internal) == 1)
|
||||
host: str | None = parts["host"]
|
||||
proto: str = parts["proto"] or ""
|
||||
int_p: str = parts["int"]
|
||||
ext_p: str | None = parts["ext"] or None
|
||||
internal: list[str] = port_range(int_p, parts["int_end"], proto) # type: ignore
|
||||
external = port_range(ext_p, parts["ext_end"], "", len(internal) == 1)
|
||||
|
||||
if host is None:
|
||||
if external is not None and len(internal) != len(external):
|
||||
raise ValueError("Port ranges don't match in length")
|
||||
return internal, external
|
||||
external_or_none: Sequence[str | None]
|
||||
if not external:
|
||||
external = [None] * len(internal)
|
||||
elif len(internal) != len(external):
|
||||
raise ValueError("Port ranges don't match in length")
|
||||
return internal, [(host, ext_port) for ext_port in external]
|
||||
external_or_none = [None] * len(internal)
|
||||
else:
|
||||
external_or_none = external
|
||||
if len(internal) != len(external_or_none):
|
||||
raise ValueError("Port ranges don't match in length")
|
||||
return internal, [(host, ext_port) for ext_port in external_or_none]
|
||||
|
||||
@ -20,23 +20,23 @@ class ProxyConfig(dict):
|
||||
"""
|
||||
|
||||
@property
|
||||
def http(self):
|
||||
def http(self) -> str | None:
|
||||
return self.get("http")
|
||||
|
||||
@property
|
||||
def https(self):
|
||||
def https(self) -> str | None:
|
||||
return self.get("https")
|
||||
|
||||
@property
|
||||
def ftp(self):
|
||||
def ftp(self) -> str | None:
|
||||
return self.get("ftp")
|
||||
|
||||
@property
|
||||
def no_proxy(self):
|
||||
def no_proxy(self) -> str | None:
|
||||
return self.get("no_proxy")
|
||||
|
||||
@staticmethod
|
||||
def from_dict(config):
|
||||
def from_dict(config: dict[str, str]) -> ProxyConfig:
|
||||
"""
|
||||
Instantiate a new ProxyConfig from a dictionary that represents a
|
||||
client configuration, as described in `the documentation`_.
|
||||
@ -51,7 +51,7 @@ class ProxyConfig(dict):
|
||||
no_proxy=config.get("noProxy"),
|
||||
)
|
||||
|
||||
def get_environment(self):
|
||||
def get_environment(self) -> dict[str, str]:
|
||||
"""
|
||||
Return a dictionary representing the environment variables used to
|
||||
set the proxy settings.
|
||||
@ -67,7 +67,7 @@ class ProxyConfig(dict):
|
||||
env["no_proxy"] = env["NO_PROXY"] = self.no_proxy
|
||||
return env
|
||||
|
||||
def inject_proxy_environment(self, environment):
|
||||
def inject_proxy_environment(self, environment: list[str]) -> list[str]:
|
||||
"""
|
||||
Given a list of strings representing environment variables, prepend the
|
||||
environment variables corresponding to the proxy settings.
|
||||
@ -82,5 +82,5 @@ class ProxyConfig(dict):
|
||||
# variables defined in "environment" to take precedence.
|
||||
return proxy_env + environment
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"ProxyConfig(http={self.http}, https={self.https}, ftp={self.ftp}, no_proxy={self.no_proxy})"
|
||||
|
||||
@ -16,10 +16,15 @@ import os
|
||||
import select
|
||||
import socket as pysocket
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ..transport.npipesocket import NpipeSocket
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
STDOUT = 1
|
||||
STDERR = 2
|
||||
|
||||
@ -33,7 +38,7 @@ class SocketError(Exception):
|
||||
NPIPE_ENDED = 109
|
||||
|
||||
|
||||
def read(socket, n=4096):
|
||||
def read(socket, n: int = 4096) -> bytes | None:
|
||||
"""
|
||||
Reads at most n bytes from socket
|
||||
"""
|
||||
@ -58,6 +63,7 @@ def read(socket, n=4096):
|
||||
except EnvironmentError as e:
|
||||
if e.errno not in recoverable_errors:
|
||||
raise
|
||||
return None # TODO ???
|
||||
except Exception as e:
|
||||
is_pipe_ended = (
|
||||
isinstance(socket, NpipeSocket)
|
||||
@ -67,11 +73,11 @@ def read(socket, n=4096):
|
||||
if is_pipe_ended:
|
||||
# npipes do not support duplex sockets, so we interpret
|
||||
# a PIPE_ENDED error as a close operation (0-length read).
|
||||
return ""
|
||||
return b""
|
||||
raise
|
||||
|
||||
|
||||
def read_exactly(socket, n):
|
||||
def read_exactly(socket, n: int) -> bytes:
|
||||
"""
|
||||
Reads exactly n bytes from socket
|
||||
Raises SocketError if there is not enough data
|
||||
@ -85,7 +91,7 @@ def read_exactly(socket, n):
|
||||
return data
|
||||
|
||||
|
||||
def next_frame_header(socket):
|
||||
def next_frame_header(socket) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the stream and size of the next frame of data waiting to be read
|
||||
from socket, according to the protocol defined here:
|
||||
@ -101,7 +107,7 @@ def next_frame_header(socket):
|
||||
return (stream, actual)
|
||||
|
||||
|
||||
def frames_iter(socket, tty):
|
||||
def frames_iter(socket, tty: bool) -> t.Generator[tuple[int, bytes]]:
|
||||
"""
|
||||
Return a generator of frames read from socket. A frame is a tuple where
|
||||
the first item is the stream number and the second item is a chunk of data.
|
||||
@ -114,7 +120,7 @@ def frames_iter(socket, tty):
|
||||
return frames_iter_no_tty(socket)
|
||||
|
||||
|
||||
def frames_iter_no_tty(socket):
|
||||
def frames_iter_no_tty(socket) -> t.Generator[tuple[int, bytes]]:
|
||||
"""
|
||||
Returns a generator of data read from the socket when the tty setting is
|
||||
not enabled.
|
||||
@ -135,20 +141,34 @@ def frames_iter_no_tty(socket):
|
||||
yield (stream, result)
|
||||
|
||||
|
||||
def frames_iter_tty(socket):
|
||||
def frames_iter_tty(socket) -> t.Generator[bytes]:
|
||||
"""
|
||||
Return a generator of data read from the socket when the tty setting is
|
||||
enabled.
|
||||
"""
|
||||
while True:
|
||||
result = read(socket)
|
||||
if len(result) == 0:
|
||||
if not result:
|
||||
# We have reached EOF
|
||||
return
|
||||
yield result
|
||||
|
||||
|
||||
def consume_socket_output(frames, demux=False):
|
||||
@t.overload
|
||||
def consume_socket_output(frames, demux: t.Literal[False] = False) -> bytes: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def consume_socket_output(frames, demux: t.Literal[True]) -> tuple[bytes, bytes]: ...
|
||||
|
||||
|
||||
@t.overload
|
||||
def consume_socket_output(
|
||||
frames, demux: bool = False
|
||||
) -> bytes | tuple[bytes, bytes]: ...
|
||||
|
||||
|
||||
def consume_socket_output(frames, demux: bool = False) -> bytes | tuple[bytes, bytes]:
|
||||
"""
|
||||
Iterate through frames read from the socket and return the result.
|
||||
|
||||
@ -167,7 +187,7 @@ def consume_socket_output(frames, demux=False):
|
||||
|
||||
# If the streams are demultiplexed, the generator yields tuples
|
||||
# (stdout, stderr)
|
||||
out = [None, None]
|
||||
out: list[bytes | None] = [None, None]
|
||||
for frame in frames:
|
||||
# It is guaranteed that for each frame, one and only one stream
|
||||
# is not None.
|
||||
@ -183,10 +203,10 @@ def consume_socket_output(frames, demux=False):
|
||||
out[1] = frame[1]
|
||||
else:
|
||||
out[1] += frame[1]
|
||||
return tuple(out)
|
||||
return tuple(out) # type: ignore
|
||||
|
||||
|
||||
def demux_adaptor(stream_id, data):
|
||||
def demux_adaptor(stream_id: int, data: bytes) -> tuple[bytes | None, bytes | None]:
|
||||
"""
|
||||
Utility to demultiplex stdout and stderr when reading frames from the
|
||||
socket.
|
||||
|
||||
@ -18,6 +18,7 @@ import os
|
||||
import os.path
|
||||
import shlex
|
||||
import string
|
||||
import typing as t
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from ansible_collections.community.docker.plugins.module_utils._version import (
|
||||
@ -34,32 +35,23 @@ from ..constants import (
|
||||
from ..tls import TLSConfig
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import ssl
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
|
||||
URLComponents = collections.namedtuple(
|
||||
"URLComponents",
|
||||
"scheme netloc url params query fragment",
|
||||
)
|
||||
|
||||
|
||||
def create_ipam_pool(*args, **kwargs):
|
||||
raise errors.DeprecatedMethod(
|
||||
"utils.create_ipam_pool has been removed. Please use a "
|
||||
"docker.types.IPAMPool object instead."
|
||||
)
|
||||
|
||||
|
||||
def create_ipam_config(*args, **kwargs):
|
||||
raise errors.DeprecatedMethod(
|
||||
"utils.create_ipam_config has been removed. Please use a "
|
||||
"docker.types.IPAMConfig object instead."
|
||||
)
|
||||
|
||||
|
||||
def decode_json_header(header):
|
||||
def decode_json_header(header: str) -> dict[str, t.Any]:
|
||||
data = base64.b64decode(header).decode("utf-8")
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
def compare_version(v1, v2):
|
||||
def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]:
|
||||
"""Compare docker versions
|
||||
|
||||
>>> v1 = '1.9'
|
||||
@ -80,43 +72,64 @@ def compare_version(v1, v2):
|
||||
return 1
|
||||
|
||||
|
||||
def version_lt(v1, v2):
|
||||
def version_lt(v1: str, v2: str) -> bool:
|
||||
return compare_version(v1, v2) > 0
|
||||
|
||||
|
||||
def version_gte(v1, v2):
|
||||
def version_gte(v1: str, v2: str) -> bool:
|
||||
return not version_lt(v1, v2)
|
||||
|
||||
|
||||
def _convert_port_binding(binding):
|
||||
def _convert_port_binding(
|
||||
binding: (
|
||||
tuple[str, str | int | None]
|
||||
| tuple[str | int | None]
|
||||
| dict[str, str]
|
||||
| str
|
||||
| int
|
||||
),
|
||||
) -> dict[str, str]:
|
||||
result = {"HostIp": "", "HostPort": ""}
|
||||
host_port: str | int | None = ""
|
||||
if isinstance(binding, tuple):
|
||||
if len(binding) == 2:
|
||||
result["HostPort"] = binding[1]
|
||||
host_port = binding[1] # type: ignore
|
||||
result["HostIp"] = binding[0]
|
||||
elif isinstance(binding[0], str):
|
||||
result["HostIp"] = binding[0]
|
||||
else:
|
||||
result["HostPort"] = binding[0]
|
||||
host_port = binding[0]
|
||||
elif isinstance(binding, dict):
|
||||
if "HostPort" in binding:
|
||||
result["HostPort"] = binding["HostPort"]
|
||||
host_port = binding["HostPort"]
|
||||
if "HostIp" in binding:
|
||||
result["HostIp"] = binding["HostIp"]
|
||||
else:
|
||||
raise ValueError(binding)
|
||||
else:
|
||||
result["HostPort"] = binding
|
||||
|
||||
if result["HostPort"] is None:
|
||||
result["HostPort"] = ""
|
||||
else:
|
||||
result["HostPort"] = str(result["HostPort"])
|
||||
host_port = binding
|
||||
|
||||
result["HostPort"] = str(host_port) if host_port is not None else ""
|
||||
return result
|
||||
|
||||
|
||||
def convert_port_bindings(port_bindings):
|
||||
def convert_port_bindings(
|
||||
port_bindings: dict[
|
||||
str | int,
|
||||
tuple[str, str | int | None]
|
||||
| tuple[str | int | None]
|
||||
| dict[str, str]
|
||||
| str
|
||||
| int
|
||||
| list[
|
||||
tuple[str, str | int | None]
|
||||
| tuple[str | int | None]
|
||||
| dict[str, str]
|
||||
| str
|
||||
| int
|
||||
],
|
||||
],
|
||||
) -> dict[str, list[dict[str, str]]]:
|
||||
result = {}
|
||||
for k, v in port_bindings.items():
|
||||
key = str(k)
|
||||
@ -129,9 +142,11 @@ def convert_port_bindings(port_bindings):
|
||||
return result
|
||||
|
||||
|
||||
def convert_volume_binds(binds):
|
||||
def convert_volume_binds(
|
||||
binds: list[str] | Mapping[str | bytes, dict[str, str | bytes] | bytes | str | int],
|
||||
) -> list[str]:
|
||||
if isinstance(binds, list):
|
||||
return binds
|
||||
return binds # type: ignore
|
||||
|
||||
result = []
|
||||
for k, v in binds.items():
|
||||
@ -149,7 +164,7 @@ def convert_volume_binds(binds):
|
||||
if "ro" in v:
|
||||
mode = "ro" if v["ro"] else "rw"
|
||||
elif "mode" in v:
|
||||
mode = v["mode"]
|
||||
mode = v["mode"] # type: ignore # TODO
|
||||
else:
|
||||
mode = "rw"
|
||||
|
||||
@ -165,9 +180,9 @@ def convert_volume_binds(binds):
|
||||
]
|
||||
if "propagation" in v and v["propagation"] in propagation_modes:
|
||||
if mode:
|
||||
mode = ",".join([mode, v["propagation"]])
|
||||
mode = ",".join([mode, v["propagation"]]) # type: ignore # TODO
|
||||
else:
|
||||
mode = v["propagation"]
|
||||
mode = v["propagation"] # type: ignore # TODO
|
||||
|
||||
result.append(f"{k}:{bind}:{mode}")
|
||||
else:
|
||||
@ -177,7 +192,7 @@ def convert_volume_binds(binds):
|
||||
return result
|
||||
|
||||
|
||||
def convert_tmpfs_mounts(tmpfs):
|
||||
def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]:
|
||||
if isinstance(tmpfs, dict):
|
||||
return tmpfs
|
||||
|
||||
@ -204,9 +219,11 @@ def convert_tmpfs_mounts(tmpfs):
|
||||
return result
|
||||
|
||||
|
||||
def convert_service_networks(networks):
|
||||
def convert_service_networks(
|
||||
networks: list[str | dict[str, str]],
|
||||
) -> list[dict[str, str]]:
|
||||
if not networks:
|
||||
return networks
|
||||
return networks # type: ignore
|
||||
if not isinstance(networks, list):
|
||||
raise TypeError("networks parameter must be a list.")
|
||||
|
||||
@ -218,17 +235,17 @@ def convert_service_networks(networks):
|
||||
return result
|
||||
|
||||
|
||||
def parse_repository_tag(repo_name):
|
||||
def parse_repository_tag(repo_name: str) -> tuple[str, str | None]:
|
||||
parts = repo_name.rsplit("@", 1)
|
||||
if len(parts) == 2:
|
||||
return tuple(parts)
|
||||
return tuple(parts) # type: ignore
|
||||
parts = repo_name.rsplit(":", 1)
|
||||
if len(parts) == 2 and "/" not in parts[1]:
|
||||
return tuple(parts)
|
||||
return tuple(parts) # type: ignore
|
||||
return repo_name, None
|
||||
|
||||
|
||||
def parse_host(addr, is_win32=False, tls=False):
|
||||
def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str:
|
||||
# Sensible defaults
|
||||
if not addr and is_win32:
|
||||
return DEFAULT_NPIPE
|
||||
@ -308,7 +325,7 @@ def parse_host(addr, is_win32=False, tls=False):
|
||||
).rstrip("/")
|
||||
|
||||
|
||||
def parse_devices(devices):
|
||||
def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]:
|
||||
device_list = []
|
||||
for device in devices:
|
||||
if isinstance(device, dict):
|
||||
@ -337,7 +354,10 @@ def parse_devices(devices):
|
||||
return device_list
|
||||
|
||||
|
||||
def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
|
||||
def kwargs_from_env(
|
||||
assert_hostname: bool | None = None,
|
||||
environment: Mapping[str, str] | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
if not environment:
|
||||
environment = os.environ
|
||||
host = environment.get("DOCKER_HOST")
|
||||
@ -347,14 +367,14 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
|
||||
|
||||
# empty string for tls verify counts as "false".
|
||||
# Any value or 'unset' counts as true.
|
||||
tls_verify = environment.get("DOCKER_TLS_VERIFY")
|
||||
if tls_verify == "":
|
||||
tls_verify_str = environment.get("DOCKER_TLS_VERIFY")
|
||||
if tls_verify_str == "":
|
||||
tls_verify = False
|
||||
else:
|
||||
tls_verify = tls_verify is not None
|
||||
tls_verify = tls_verify_str is not None
|
||||
enable_tls = cert_path or tls_verify
|
||||
|
||||
params = {}
|
||||
params: dict[str, t.Any] = {}
|
||||
|
||||
if host:
|
||||
params["base_url"] = host
|
||||
@ -377,14 +397,13 @@ def kwargs_from_env(ssl_version=None, assert_hostname=None, environment=None):
|
||||
),
|
||||
ca_cert=os.path.join(cert_path, "ca.pem"),
|
||||
verify=tls_verify,
|
||||
ssl_version=ssl_version,
|
||||
assert_hostname=assert_hostname,
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def convert_filters(filters):
|
||||
def convert_filters(filters: dict[str, bool | str | list[str]]) -> str:
|
||||
result = {}
|
||||
for k, v in filters.items():
|
||||
if isinstance(v, bool):
|
||||
@ -397,7 +416,7 @@ def convert_filters(filters):
|
||||
return json.dumps(result)
|
||||
|
||||
|
||||
def parse_bytes(s):
|
||||
def parse_bytes(s: int | float | str) -> int | float:
|
||||
if isinstance(s, (int, float)):
|
||||
return s
|
||||
if len(s) == 0:
|
||||
@ -435,14 +454,16 @@ def parse_bytes(s):
|
||||
return s
|
||||
|
||||
|
||||
def normalize_links(links):
|
||||
def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]:
|
||||
if isinstance(links, dict):
|
||||
links = links.items()
|
||||
sorted_links = sorted(links.items())
|
||||
else:
|
||||
sorted_links = sorted(links)
|
||||
|
||||
return [f"{k}:{v}" if v else k for k, v in sorted(links)]
|
||||
return [f"{k}:{v}" if v else k for k, v in sorted_links]
|
||||
|
||||
|
||||
def parse_env_file(env_file):
|
||||
def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]:
|
||||
"""
|
||||
Reads a line-separated environment file.
|
||||
The format of each line should be "key=value".
|
||||
@ -451,7 +472,6 @@ def parse_env_file(env_file):
|
||||
|
||||
with open(env_file, "rt", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
|
||||
if line[0] == "#":
|
||||
continue
|
||||
|
||||
@ -471,11 +491,11 @@ def parse_env_file(env_file):
|
||||
return environment
|
||||
|
||||
|
||||
def split_command(command):
|
||||
def split_command(command: str) -> list[str]:
|
||||
return shlex.split(command)
|
||||
|
||||
|
||||
def format_environment(environment):
|
||||
def format_environment(environment: Mapping[str, str | bytes]) -> list[str]:
|
||||
def format_env(key, value):
|
||||
if value is None:
|
||||
return key
|
||||
@ -487,16 +507,9 @@ def format_environment(environment):
|
||||
return [format_env(*var) for var in environment.items()]
|
||||
|
||||
|
||||
def format_extra_hosts(extra_hosts, task=False):
|
||||
def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]:
|
||||
# Use format dictated by Swarm API if container is part of a task
|
||||
if task:
|
||||
return [f"{v} {k}" for k, v in sorted(extra_hosts.items())]
|
||||
|
||||
return [f"{k}:{v}" for k, v in sorted(extra_hosts.items())]
|
||||
|
||||
|
||||
def create_host_config(self, *args, **kwargs):
|
||||
raise errors.DeprecatedMethod(
|
||||
"utils.create_host_config has been removed. Please use a "
|
||||
"docker.types.HostConfig object instead."
|
||||
)
|
||||
|
||||
@ -552,7 +552,7 @@ def _execute_command(
|
||||
|
||||
result = client.get_json("/exec/{0}/json", exec_id)
|
||||
|
||||
rc = result.get("ExitCode") or 0
|
||||
rc: int = result.get("ExitCode") or 0
|
||||
stdout = stdout or b""
|
||||
stderr = stderr or b""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user