community.docker/plugins/module_utils/_api/utils/utils.py
Felix Fontein dbc7b0ec18
Cleanup with ruff check (#1182)
* Implement improvements suggested by ruff check.

* Add ruff check to CI.
2025-10-28 06:58:15 +01:00

521 lines
15 KiB
Python

# This code is part of the Ansible collection community.docker, but is an independent component.
# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
#
# Copyright (c) 2016-2022 Docker, Inc.
#
# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
# SPDX-License-Identifier: Apache-2.0
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
# Do not use this from other collections or standalone plugins/modules!
from __future__ import annotations
import base64
import collections
import json
import os
import os.path
import shlex
import string
import typing as t
from urllib.parse import urlparse, urlunparse
from ansible_collections.community.docker.plugins.module_utils._version import (
StrictVersion,
)
from .. import errors
from ..constants import (
BYTE_UNITS,
DEFAULT_HTTP_HOST,
DEFAULT_NPIPE,
DEFAULT_UNIX_SOCKET,
)
from ..tls import TLSConfig
if t.TYPE_CHECKING:
from collections.abc import Mapping, Sequence
# Six-field record handed to ``urlunparse`` by ``parse_host``. The third field
# is named ``url`` but ``parse_host`` fills it with the path component.
URLComponents = collections.namedtuple(
    "URLComponents",
    ["scheme", "netloc", "url", "params", "query", "fragment"],
)
def decode_json_header(header: str | bytes) -> dict[str, t.Any]:
    """Decode a base64-encoded, UTF-8 JSON header.

    :param header: The base64 text (``str`` or ``bytes``) to decode.
    :returns: Whatever ``json.loads`` produces for the decoded payload.
    """
    raw = base64.b64decode(header)
    return json.loads(raw.decode("utf-8"))
def compare_version(v1: str, v2: str) -> t.Literal[-1, 0, 1]:
    """Compare two Docker version strings.

    Note that the sign is the opposite of a classic ``cmp``: the result is
    ``1`` when ``v1`` is the lower version, ``-1`` when ``v1`` is the higher
    version, and ``0`` when both compare equal.

    >>> v1 = '1.9'
    >>> v2 = '1.10'
    >>> compare_version(v1, v2)
    1
    >>> compare_version(v2, v1)
    -1
    >>> compare_version(v2, v2)
    0
    """
    first = StrictVersion(v1)
    second = StrictVersion(v2)
    if first == second:
        return 0
    return -1 if first > second else 1
def version_lt(v1: str, v2: str) -> bool:
    """Return ``True`` when ``v1`` is a strictly lower version than ``v2``."""
    # compare_version() yields a positive value when its first argument is lower.
    ordering = compare_version(v1, v2)
    return ordering > 0
def version_gte(v1: str, v2: str) -> bool:
    """Return ``True`` when ``v1`` is the same version as ``v2`` or a higher one."""
    is_lower = version_lt(v1, v2)
    return not is_lower
def _convert_port_binding(
    binding: (
        tuple[str, str | int | None]
        | tuple[str | int | None]
        | dict[str, str]
        | str
        | int
    ),
) -> dict[str, str]:
    """Normalise one port-binding spec to a ``HostIp``/``HostPort`` dict.

    Accepted shapes:

    * ``(host_ip, host_port)`` tuple;
    * 1-tuple holding either a host IP (when it is a ``str``) or a host port;
    * dict that must contain ``HostPort`` and may contain ``HostIp``;
    * any other value, which is taken as the host port itself.

    :returns: ``{"HostIp": ..., "HostPort": ...}`` where ``HostPort`` is
        always a string (empty when the port is ``None`` or not given).
    :raises ValueError: If a dict binding has no ``HostPort`` key.
    """
    ip = ""
    port: str | int | None = ""

    if isinstance(binding, tuple):
        if len(binding) == 2:
            ip, port = binding  # type: ignore
        elif isinstance(binding[0], str):
            ip = binding[0]
        else:
            port = binding[0]
    elif isinstance(binding, dict):
        if "HostPort" not in binding:
            raise ValueError(binding)
        port = binding["HostPort"]
        if "HostIp" in binding:
            ip = binding["HostIp"]
    else:
        port = binding

    return {"HostIp": ip, "HostPort": "" if port is None else str(port)}
def convert_port_bindings(
    port_bindings: dict[
        str | int,
        tuple[str, str | int | None]
        | tuple[str | int | None]
        | dict[str, str]
        | str
        | int
        | list[
            tuple[str, str | int | None]
            | tuple[str | int | None]
            | dict[str, str]
            | str
            | int
        ],
    ],
) -> dict[str, list[dict[str, str]]]:
    """Convert a port-binding mapping to ``{"<port>/<proto>": [binding, ...]}``.

    Keys are stringified and get a ``/tcp`` suffix when they contain no
    ``/``. Each value, or each element of a list value, is normalised with
    ``_convert_port_binding``; a non-list value becomes a one-element list.
    """
    converted: dict[str, list[dict[str, str]]] = {}
    for container_port, spec in port_bindings.items():
        name = str(container_port)
        if "/" not in name:
            name = f"{name}/tcp"
        specs = spec if isinstance(spec, list) else [spec]
        converted[name] = [_convert_port_binding(item) for item in specs]
    return converted
def convert_volume_binds(
    binds: (
        list[str]
        | Mapping[
            str | bytes, dict[str, str | bytes] | dict[str, str] | bytes | str | int
        ]
    ),
) -> list[str]:
    """Convert volume binds to a list of ``source:target:mode`` strings.

    A list is returned unchanged. For a mapping, each key is the source and
    each value is either the target itself (mode ``rw``) or a dict with a
    required ``bind`` key plus optional ``ro``, ``mode`` and ``propagation``
    keys. ``bytes`` sources and targets are decoded as UTF-8.

    :raises ValueError: If a dict value has both ``ro`` and ``mode``.
    """
    if isinstance(binds, list):
        return binds  # type: ignore

    # NOTE: this is only relevant for Linux hosts
    # (does not apply in Docker Desktop)
    known_propagation = (
        "rshared",
        "shared",
        "rslave",
        "slave",
        "rprivate",
        "private",
    )

    def _text(value):
        # Decode bytes as UTF-8; pass everything else through untouched.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    specs = []
    for source, target in binds.items():
        source = _text(source)

        if not isinstance(target, dict):
            specs.append(f"{source}:{_text(target)}:rw")
            continue

        if "ro" in target and "mode" in target:
            raise ValueError(
                f'Binding cannot contain both "ro" and "mode": {target!r}'
            )

        bind = _text(target["bind"])

        if "ro" in target:
            mode = "ro" if target["ro"] else "rw"
        else:
            mode = target.get("mode", "rw")  # type: ignore # TODO

        # Unknown propagation values are silently ignored.
        if "propagation" in target and target["propagation"] in known_propagation:
            propagation = target["propagation"]
            mode = ",".join([mode, propagation]) if mode else propagation  # type: ignore # TODO

        specs.append(f"{source}:{bind}:{mode}")
    return specs
def convert_tmpfs_mounts(tmpfs: dict[str, str] | list[str]) -> dict[str, str]:
    """Convert tmpfs mounts to a ``{path: options}`` dict.

    A dict is returned unchanged. A list must contain strings of the form
    ``path`` or ``path:options``; only the first ``:`` separates the two, and
    the options are empty when there is no ``:``.

    :raises ValueError: If ``tmpfs`` is neither a dict nor a list, or if a
        list item is not a string.
    """
    if isinstance(tmpfs, dict):
        return tmpfs

    if not isinstance(tmpfs, list):
        raise ValueError(
            f"Expected tmpfs value to be either a list or a dict, found: {type(tmpfs).__name__}"
        )

    mounts = {}
    for entry in tmpfs:
        if not isinstance(entry, str):
            raise ValueError(
                f"Expected item in tmpfs list to be a string, found: {type(entry).__name__}"
            )
        path, _sep, options = entry.partition(":")
        mounts[path] = options
    return mounts
def convert_service_networks(
    networks: list[str | dict[str, str]],
) -> list[dict[str, str]]:
    """Turn plain network names into ``{"Target": name}`` dicts.

    A falsy ``networks`` value is returned unchanged. Non-string items are
    passed through as they are.

    :raises TypeError: If ``networks`` is truthy but not a list.
    """
    if not networks:
        return networks  # type: ignore
    if not isinstance(networks, list):
        raise TypeError("networks parameter must be a list.")

    return [
        {"Target": network} if isinstance(network, str) else network
        for network in networks
    ]
def parse_repository_tag(repo_name: str) -> tuple[str, str | None]:
    """Split an image reference into ``(repository, tag_or_digest)``.

    A digest (the text after the last ``@``) takes precedence. Otherwise the
    text after the last ``:`` is the tag, unless it contains a ``/`` - in that
    case the colon belongs to a ``host:port`` prefix and there is no tag.

    :returns: ``(repository, tag)``; ``tag`` is ``None`` when there is none.
    """
    repository, at_sign, digest = repo_name.rpartition("@")
    if at_sign:
        return repository, digest

    repository, colon, tag = repo_name.rpartition(":")
    if colon and "/" not in tag:
        return repository, tag

    return repo_name, None
def parse_host(addr: str | None, is_win32: bool = False, tls: bool = False) -> str:
    """Normalise a Docker daemon address to the URL form used internally.

    :param addr: The address to parse. When empty, ``DEFAULT_NPIPE`` is
        returned if ``is_win32`` is set and ``DEFAULT_UNIX_SOCKET`` otherwise;
        ``"unix://"`` also maps to ``DEFAULT_UNIX_SOCKET``.
    :param is_win32: Selects the default used for an empty ``addr``.
    :param tls: Whether a ``tcp`` address is rewritten to ``https`` rather
        than ``http``. Overridden by an explicit ``http``/``https`` scheme.
    :returns: A URL with scheme ``http``, ``https``, ``http+unix``, ``npipe``
        or ``ssh``, with trailing slashes stripped.
    :raises errors.DockerException: For the ``fd`` scheme, an unknown scheme,
        ``tcp`` without a network location, a URL with params, query,
        fragment or password, or an ``ssh`` URL with a path.

    A ``tcp`` address without a port gets 2376 when TLS is in effect and 2375
    otherwise; an ``ssh`` address without a port gets 22.
    """
    # Sensible defaults
    if not addr and is_win32:
        return DEFAULT_NPIPE
    if not addr or addr.strip() == "unix://":
        return DEFAULT_UNIX_SOCKET

    addr = addr.strip()

    parsed_url = urlparse(addr)
    proto = parsed_url.scheme
    if not proto or any(x not in string.ascii_letters + "+" for x in proto):
        # https://bugs.python.org/issue754016
        parsed_url = urlparse("//" + addr, "tcp")
        proto = "tcp"

    if proto == "fd":
        raise errors.DockerException("fd protocol is not implemented")

    # These protos are valid aliases for our library but not for the
    # official spec
    if proto in ("http", "https"):
        tls = proto == "https"
        proto = "tcp"
    elif proto == "http+unix":
        proto = "unix"

    if proto not in ("tcp", "unix", "npipe", "ssh"):
        raise errors.DockerException(f"Invalid bind address protocol: {addr}")

    if proto == "tcp" and not parsed_url.netloc:
        # "tcp://" is exceptionally disallowed by convention;
        # omitting a hostname for other protocols is fine
        raise errors.DockerException(f"Invalid bind address format: {addr}")

    if any(
        [parsed_url.params, parsed_url.query, parsed_url.fragment, parsed_url.password]
    ):
        raise errors.DockerException(f"Invalid bind address format: {addr}")

    if parsed_url.path and proto == "ssh":
        raise errors.DockerException(
            f"Invalid bind address format: no path allowed for this protocol: {addr}"
        )

    path = parsed_url.path
    if proto == "unix" and parsed_url.hostname is not None:
        # For legacy reasons, we consider unix://path
        # to be valid and equivalent to unix:///path
        path = f"{parsed_url.hostname}/{path}"

    netloc = parsed_url.netloc
    if proto in ("tcp", "ssh"):
        port = parsed_url.port or 0
        if port <= 0:
            # Docker convention: 2375 is the plain-text daemon port and 2376
            # the TLS one. The two were previously swapped here.
            port = 22 if proto == "ssh" else (2376 if tls else 2375)
            netloc = f"{parsed_url.netloc}:{port}"

        if not parsed_url.hostname:
            netloc = f"{DEFAULT_HTTP_HOST}:{port}"

    # Rewrite schemes to fit library internals (requests adapters)
    if proto == "tcp":
        proto = f"http{'s' if tls else ''}"
    elif proto == "unix":
        proto = "http+unix"

    if proto in ("http+unix", "npipe"):
        return f"{proto}://{path}".rstrip("/")

    return urlunparse(
        URLComponents(
            scheme=proto,
            netloc=netloc,
            url=path,
            params="",
            query="",
            fragment="",
        )
    ).rstrip("/")
def parse_devices(devices: Sequence[dict[str, str] | str]) -> list[dict[str, str]]:
    """Convert device specs to ``PathOnHost``/``PathInContainer``/``CgroupPermissions`` dicts.

    Dict items are passed through unchanged. String items have the form
    ``host_path[:container_path[:permissions]]``; the container path defaults
    to the host path and the permissions default to ``rwm``. Any further
    ``:``-separated fields are ignored.

    :raises errors.DockerException: If an item is neither a dict nor a string.
    """
    parsed = []
    for entry in devices:
        if isinstance(entry, dict):
            parsed.append(entry)
            continue
        if not isinstance(entry, str):
            raise errors.DockerException(f"Invalid device type {type(entry)}")

        # str.split() always yields at least one field, so the host path is
        # always present.
        host_path, *rest = entry.split(":")
        container_path = rest[0] if rest else host_path
        permissions = rest[1] if len(rest) > 1 else "rwm"
        parsed.append(
            {
                "PathOnHost": host_path,
                "PathInContainer": container_path,
                "CgroupPermissions": permissions,
            }
        )
    return parsed
def kwargs_from_env(
    assert_hostname: bool | None = None,
    environment: Mapping[str, str] | None = None,
) -> dict[str, t.Any]:
    """Build client keyword arguments from Docker environment variables.

    Reads ``DOCKER_HOST``, ``DOCKER_CERT_PATH`` and ``DOCKER_TLS_VERIFY`` from
    ``environment``, or from ``os.environ`` when ``environment`` is falsy.

    * ``base_url`` is set when ``DOCKER_HOST`` is non-empty.
    * TLS is enabled when ``DOCKER_CERT_PATH`` is non-empty or when
      ``DOCKER_TLS_VERIFY`` is set to a non-empty value. An unset or empty
      ``DOCKER_TLS_VERIFY`` means "do not verify".
    * With TLS enabled, ``tls`` is a ``TLSConfig`` pointing at ``cert.pem``,
      ``key.pem`` and ``ca.pem`` below the certificate path, which defaults to
      ``~/.docker``.

    :param assert_hostname: Passed on to ``TLSConfig``; forced to ``False``
        when it is ``None`` and verification is off.
    :returns: A dict with zero or more of the keys ``base_url`` and ``tls``.
    """
    env = environment if environment else os.environ
    params: dict[str, t.Any] = {}

    host = env.get("DOCKER_HOST")
    if host:
        params["base_url"] = host

    # An empty string for the cert path is the same as unset.
    cert_path = env.get("DOCKER_CERT_PATH") or None

    # Only a present, non-empty DOCKER_TLS_VERIFY switches verification on.
    raw_verify = env.get("DOCKER_TLS_VERIFY")
    tls_verify = raw_verify is not None and raw_verify != ""

    if not (cert_path or tls_verify):
        return params

    if not cert_path:
        cert_path = os.path.join(os.path.expanduser("~"), ".docker")

    if assert_hostname is None and not tls_verify:
        # assert_hostname is a subset of TLS verification,
        # so if it is not set already then set it to false.
        assert_hostname = False

    client_cert = (
        os.path.join(cert_path, "cert.pem"),
        os.path.join(cert_path, "key.pem"),
    )
    params["tls"] = TLSConfig(
        client_cert=client_cert,
        ca_cert=os.path.join(cert_path, "ca.pem"),
        verify=tls_verify,
        assert_hostname=assert_hostname,
    )
    return params
def convert_filters(
    filters: Mapping[str, bool | str | int | list[int] | list[str] | list[str | int]],
) -> str:
    """Serialise filters to a JSON object that maps names to lists of strings.

    A top-level ``bool`` becomes ``"true"``/``"false"``, a non-list value is
    wrapped in a one-element list, and every non-string list item is passed
    through ``str()``.
    """
    normalised = {}
    for name, value in filters.items():
        if isinstance(value, bool):
            value = "true" if value else "false"
        values = value if isinstance(value, list) else [value]
        normalised[name] = [
            entry if isinstance(entry, str) else str(entry) for entry in values
        ]
    return json.dumps(normalised)
def parse_bytes(s: int | float | str) -> int | float:
    """Convert a size such as ``"512m"`` or ``"1GB"`` to a number of bytes.

    Numbers are returned unchanged and an empty string yields ``0``. A string
    ending in a digit is taken to be in bytes. Otherwise the last character,
    lower-cased, is looked up in ``BYTE_UNITS``; a trailing ``b``/``B`` that
    follows another letter (as in ``"kb"``) is dropped first.

    :raises errors.DockerException: If the unit is unknown or the numeric part
        cannot be converted with ``float()``.
    """
    if isinstance(s, (int, float)):
        return s
    if len(s) == 0:
        return 0

    # Reduce two-letter units such as "kb" or "MB" to their first letter.
    if s[-2:-1].isalpha() and s[-1] in ("b", "B"):
        s = s[:-1]

    last_char = s[-1].lower()
    if last_char.isdigit():
        # A bare number without a unit: assume bytes.
        number_text, unit = s, "b"
    else:
        number_text, unit = s[:-1], last_char

    if unit not in BYTE_UNITS:
        raise errors.DockerException(
            f"The specified value for memory ({s}) should specify the units. The postfix should be one of the `b` `k` `m` `g` characters"
        )

    try:
        number = float(number_text)
    except ValueError as exc:
        raise errors.DockerException(
            f"Failed converting the string value for memory ({number_text}) to an integer."
        ) from exc

    # Truncate back to an integer for the final result.
    return int(number * BYTE_UNITS[unit])
def normalize_links(links: dict[str, str] | Sequence[tuple[str, str]]) -> list[str]:
    """Return sorted ``name:alias`` strings for the given links.

    ``links`` is either a dict or a sequence of ``(name, alias)`` pairs. A
    pair with a falsy alias is rendered as the bare name.
    """
    pairs = links.items() if isinstance(links, dict) else links
    formatted = []
    for name, alias in sorted(pairs):
        formatted.append(f"{name}:{alias}" if alias else name)
    return formatted
def parse_env_file(env_file: str | os.PathLike) -> dict[str, str]:
    """
    Reads a line-separated environment file.
    The format of each line should be "key=value".

    Blank lines and comment lines are skipped. A comment line is one whose
    first non-whitespace character is ``#``, so indented comments are ignored
    as well. Only the first ``=`` separates key and value; surrounding
    whitespace of the whole line is stripped.

    :param env_file: Path of the file, read as UTF-8 text.
    :returns: The parsed variables; later duplicates overwrite earlier ones.
    :raises errors.DockerException: If a remaining line contains no ``=``.
    """
    environment = {}
    with open(env_file, "rt", encoding="utf-8") as f:
        for raw_line in f:
            # Strip first: otherwise an indented "#" line would be parsed as a
            # variable definition or rejected as invalid.
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue

            key, separator, value = line.partition("=")
            if not separator:
                raise errors.DockerException(
                    f"Invalid line in environment file {env_file}:\n{line}"
                )
            environment[key] = value
    return environment
def split_command(command: str) -> list[str]:
    """Split a command line into arguments using ``shlex.split``."""
    arguments = shlex.split(command)
    return arguments
def format_environment(environment: Mapping[str, str | bytes | None]) -> list[str]:
    """Render an environment mapping as a list of ``KEY=value`` strings.

    A ``None`` value yields the bare key, and ``bytes`` values are decoded as
    UTF-8. The order follows the iteration order of the mapping.
    """
    entries = []
    for key, value in environment.items():
        if value is None:
            entries.append(key)
            continue
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        entries.append(f"{key}={value}")
    return entries
def format_extra_hosts(extra_hosts: Mapping[str, str], task: bool = False) -> list[str]:
    """Render a host-to-address mapping as a list sorted by host name.

    Each entry is ``host:address``, or ``address host`` when ``task`` is true.
    """
    # Use format dictated by Swarm API if container is part of a task
    template = "{address} {host}" if task else "{host}:{address}"
    return [
        template.format(host=host, address=address)
        for host, address in sorted(extra_hosts.items())
    ]