mirror of
https://github.com/ansible-collections/community.docker.git
synced 2025-12-16 11:58:43 +00:00
* Add typing to Docker Stack modules. Clean modules up. * Add typing to Docker Swarm modules. * Add typing to unit tests. * Add more typing. * Add ignore.txt entries.
313 lines
10 KiB
Python
313 lines
10 KiB
Python
# This code is part of the Ansible collection community.docker, but is an independent component.
|
|
# This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/)
|
|
#
|
|
# Copyright (c) 2016-2022 Docker, Inc.
|
|
#
|
|
# It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection)
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time.
|
|
# Do not use this from other collections or standalone plugins/modules!
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import signal
|
|
import socket
|
|
import subprocess
|
|
import traceback
|
|
import typing as t
|
|
from queue import Empty
|
|
from urllib.parse import urlparse
|
|
|
|
from .. import constants
|
|
from .._import_helper import HTTPAdapter, urllib3, urllib3_connection
|
|
from .basehttpadapter import BaseHTTPAdapter
|
|
|
|
|
|
# Traceback text captured if importing paramiko failed; None when the import
# succeeded. Callers can surface this to explain why paramiko mode is
# unavailable (shell-out mode still works without paramiko).
PARAMIKO_IMPORT_ERROR: str | None  # pylint: disable=invalid-name
try:
    import paramiko
except ImportError:
    PARAMIKO_IMPORT_ERROR = traceback.format_exc()  # pylint: disable=invalid-name
else:
    PARAMIKO_IMPORT_ERROR = None  # pylint: disable=invalid-name


# Imported only for type checking: Buffer/Mapping are used in annotations,
# which are lazy thanks to ``from __future__ import annotations``.
if t.TYPE_CHECKING:
    from collections.abc import Buffer, Mapping


# Alias for urllib3's LRU container; used by SSHHTTPAdapter to cache one
# connection pool per URL and dispose evicted pools.
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
|
|
|
|
|
|
class SSHSocket(socket.socket):
    """Socket-like transport that tunnels Docker API traffic through ``ssh``.

    Instead of opening a TCP connection directly, ``connect()`` spawns an
    ``ssh ... docker system dial-stdio`` subprocess and all reads/writes are
    shuttled through the child's stdin/stdout pipes.
    """

    def __init__(self, host: str) -> None:
        super().__init__(socket.AF_INET, socket.SOCK_STREAM)
        self.host = host
        self.port = None
        self.user = None
        # Accept "user@host:port" specifications: peel off the port first,
        # then the user name.
        if ":" in self.host:
            self.host, self.port = self.host.split(":")
        if "@" in self.host:
            self.user, self.host = self.host.split("@")

        # The ssh child process; created lazily by connect().
        self.proc: subprocess.Popen | None = None

    def connect(self, *args_: t.Any, **kwargs: t.Any) -> None:
        """Spawn the ssh subprocess that dials the remote Docker daemon."""
        cmd = ["ssh"]
        if self.user:
            cmd = cmd + ["-l", self.user]

        if self.port:
            cmd = cmd + ["-p", self.port]

        cmd = cmd + ["--", self.host, "docker system dial-stdio"]

        preexec_func = None
        if not constants.IS_WINDOWS_PLATFORM:
            # Keep a Ctrl-C in the parent from also killing the ssh child.
            def ignore_sigint() -> None:
                signal.signal(signal.SIGINT, signal.SIG_IGN)

            preexec_func = ignore_sigint

        # drop LD_LIBRARY_PATH and SSL_CERT_FILE
        env = {
            name: value
            for name, value in os.environ.items()
            if name not in ("LD_LIBRARY_PATH", "SSL_CERT_FILE")
        }

        self.proc = subprocess.Popen(  # pylint: disable=consider-using-with
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stdin=subprocess.PIPE,
            preexec_fn=preexec_func,
        )

    def _write(self, data: Buffer) -> int:
        """Write *data* to the child's stdin, flush, and return bytes written.

        Raises RuntimeError when connect() has not been called, or when the
        pipe was already closed via close().
        """
        if not self.proc:
            raise RuntimeError(
                "SSH subprocess not initiated. connect() must be called first."
            )
        assert self.proc.stdin is not None
        if self.proc.stdin.closed:
            raise RuntimeError(
                "SSH subprocess not initiated. connect() must be called first after close()."
            )
        count = self.proc.stdin.write(data)
        self.proc.stdin.flush()
        return count

    def sendall(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> None:
        self._write(data)

    def send(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> int:
        return self._write(data)

    def recv(self, n: int, *args: t.Any, **kwargs: t.Any) -> bytes:
        """Read up to *n* bytes from the child's stdout."""
        if not self.proc:
            raise RuntimeError(
                "SSH subprocess not initiated. connect() must be called first."
            )
        assert self.proc.stdout is not None
        return self.proc.stdout.read(n)

    def makefile(self, mode: str, *args: t.Any, **kwargs: t.Any) -> t.IO:  # type: ignore
        if not self.proc:
            self.connect()
        assert self.proc is not None
        assert self.proc.stdout is not None
        # NOTE(review): consumers appear to reach the socket back through the
        # returned file object's ``channel`` attribute — confirm against callers.
        self.proc.stdout.channel = self  # type: ignore

        return self.proc.stdout

    def close(self) -> None:
        """Nudge the remote end with a blank line, then terminate ssh."""
        proc = self.proc
        if not proc:
            return
        assert proc.stdin is not None
        if proc.stdin.closed:
            return
        proc.stdin.write(b"\n\n")
        proc.stdin.flush()
        proc.terminate()
|
|
|
|
|
|
class SSHConnection(urllib3_connection.HTTPConnection):
    """``HTTPConnection`` whose transport is an SSH channel or subprocess.

    When *ssh_transport* (a paramiko transport) is given, traffic flows over
    a paramiko session channel; otherwise the connection shells out through
    :class:`SSHSocket`.
    """

    def __init__(
        self,
        *,
        ssh_transport: paramiko.Transport | None = None,
        timeout: int | float = 60,
        host: str,
    ) -> None:
        super().__init__("localhost", timeout=timeout)
        self.ssh_transport = ssh_transport
        self.timeout = timeout
        self.ssh_host = host
        self.sock: paramiko.Channel | SSHSocket | None = None

    def connect(self) -> None:
        """Open the underlying channel/socket and start dial-stdio."""
        if not self.ssh_transport:
            # Shell-out mode: spawn a local ssh client subprocess.
            shell_sock = SSHSocket(self.ssh_host)
            shell_sock.settimeout(self.timeout)
            shell_sock.connect()
            self.sock = shell_sock
            return
        # Paramiko mode: open a session channel on the existing transport.
        channel = self.ssh_transport.open_session()
        channel.settimeout(self.timeout)
        channel.exec_command("docker system dial-stdio")
        self.sock = channel
|
|
|
|
|
|
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
    """Connection pool that hands out :class:`SSHConnection` instances.

    Holds at most *maxsize* connections to a single SSH host; the paramiko
    transport is shared when *ssh_client* is given.
    """

    scheme = "ssh"

    def __init__(
        self,
        *,
        ssh_client: paramiko.SSHClient | None = None,
        timeout: int | float = 60,
        maxsize: int = 10,
        host: str,
    ) -> None:
        super().__init__("localhost", timeout=timeout, maxsize=maxsize)
        self.ssh_transport: paramiko.Transport | None = None
        self.timeout = timeout
        if ssh_client:
            self.ssh_transport = ssh_client.get_transport()
        self.ssh_host = host

    def _new_conn(self) -> SSHConnection:
        """Create a fresh connection to this pool's SSH host."""
        return SSHConnection(
            ssh_transport=self.ssh_transport,
            timeout=self.timeout,
            host=self.ssh_host,
        )

    # When re-using connections, urllib3 calls fileno() on our
    # SSH channel instance, quickly overloading our fd limit. To avoid this,
    # we override _get_conn
    def _get_conn(self, timeout: int | float) -> SSHConnection:
        """Take a pooled connection, creating a new one when none is free."""
        conn = None
        try:
            conn = self.pool.get(block=self.block, timeout=timeout)

        except AttributeError as exc:  # self.pool is None
            raise urllib3.exceptions.ClosedPoolError(self, "Pool is closed.") from exc

        except Empty as exc:
            # Only fatal when the pool is configured to block; otherwise we
            # just build a fresh connection below.
            if self.block:
                raise urllib3.exceptions.EmptyPoolError(
                    self,
                    "Pool reached maximum size and no more connections are allowed.",
                ) from exc

        return conn or self._new_conn()
|
|
|
|
|
|
class SSHHTTPAdapter(BaseHTTPAdapter):
    """Requests transport adapter that reaches the Docker daemon over SSH.

    Two modes are supported: a paramiko-based SSH client (default), or
    shelling out to the local ``ssh`` binary when *shell_out* is true.
    """

    __attrs__ = HTTPAdapter.__attrs__ + [
        "pools",
        "timeout",
        "ssh_client",
        "ssh_params",
        "max_pool_size",
    ]

    def __init__(
        self,
        base_url: str,
        timeout: int | float = 60,
        pool_connections: int = constants.DEFAULT_NUM_POOLS,
        max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE,
        shell_out: bool = False,
    ) -> None:
        self.ssh_client: paramiko.SSHClient | None = None
        if not shell_out:
            # Paramiko mode: build and connect the SSH client up front.
            self._create_paramiko_client(base_url)
            self._connect()

        # Keep only the "[user@]host[:port]" portion of the URL.
        self.ssh_host = base_url
        if base_url.startswith("ssh://"):
            self.ssh_host = base_url[len("ssh://") :]

        self.timeout = timeout
        self.max_pool_size = max_pool_size
        # LRU of connection pools keyed by URL; evicted pools are closed.
        self.pools = RecentlyUsedContainer(
            pool_connections, dispose_func=lambda p: p.close()
        )
        super().__init__()

    def _create_paramiko_client(self, base_url: str) -> None:
        """Build ``self.ssh_client`` and ``self.ssh_params`` from *base_url*,
        honoring ``~/.ssh/config`` entries when that file exists."""
        logging.getLogger("paramiko").setLevel(logging.WARNING)
        self.ssh_client = paramiko.SSHClient()
        parsed = urlparse(base_url)
        assert parsed.hostname is not None
        self.ssh_params: dict[str, t.Any] = {
            "hostname": parsed.hostname,
            "port": parsed.port,
            "username": parsed.username,
        }
        ssh_config_file = os.path.expanduser("~/.ssh/config")
        if os.path.exists(ssh_config_file):
            conf = paramiko.SSHConfig()
            with open(ssh_config_file, "rt", encoding="utf-8") as f:
                conf.parse(f)
            host_config = conf.lookup(parsed.hostname)
            if "proxycommand" in host_config:
                self.ssh_params["sock"] = paramiko.ProxyCommand(
                    host_config["proxycommand"]
                )
            if "hostname" in host_config:
                self.ssh_params["hostname"] = host_config["hostname"]
            # Port and user from the URL take precedence over ssh_config.
            if parsed.port is None and "port" in host_config:
                self.ssh_params["port"] = host_config["port"]
            if parsed.username is None and "user" in host_config:
                self.ssh_params["username"] = host_config["user"]
            if "identityfile" in host_config:
                self.ssh_params["key_filename"] = host_config["identityfile"]

        self.ssh_client.load_system_host_keys()
        self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())

    def _connect(self) -> None:
        """Open the paramiko SSH connection (no-op in shell-out mode)."""
        if self.ssh_client:
            self.ssh_client.connect(**self.ssh_params)

    def get_connection(
        self, url: str | bytes, proxies: Mapping[str, str] | None = None
    ) -> SSHConnectionPool:
        """Return a connection pool for *url*, cached per URL.

        In shell-out mode (no paramiko client) a fresh, uncached pool is
        returned on every call.
        """
        if not self.ssh_client:
            return SSHConnectionPool(
                ssh_client=self.ssh_client,
                timeout=self.timeout,
                maxsize=self.max_pool_size,
                host=self.ssh_host,
            )
        with self.pools.lock:
            cached = self.pools.get(url)
            if cached:
                return cached

            # Connection is closed try a reconnect
            if self.ssh_client and not self.ssh_client.get_transport():
                self._connect()

            cached = SSHConnectionPool(
                ssh_client=self.ssh_client,
                timeout=self.timeout,
                maxsize=self.max_pool_size,
                host=self.ssh_host,
            )
            self.pools[url] = cached

        return cached

    def close(self) -> None:
        """Dispose all cached pools, then shut down the SSH client if any."""
        super().close()
        if self.ssh_client:
            self.ssh_client.close()
|