# This code is part of the Ansible collection community.docker, but is an independent component. # This particular file, and this file only, is based on the Docker SDK for Python (https://github.com/docker/docker-py/) # # Copyright (c) 2016-2022 Docker, Inc. # # It is licensed under the Apache 2.0 license (see LICENSES/Apache-2.0.txt in this collection) # SPDX-License-Identifier: Apache-2.0 # Note that this module util is **PRIVATE** to the collection. It can have breaking changes at any time. # Do not use this from other collections or standalone plugins/modules! from __future__ import annotations import logging import os import signal import socket import subprocess import traceback import typing as t from queue import Empty from urllib.parse import urlparse from .. import constants from .._import_helper import HTTPAdapter, urllib3, urllib3_connection from .basehttpadapter import BaseHTTPAdapter PARAMIKO_IMPORT_ERROR: str | None # pylint: disable=invalid-name try: import paramiko except ImportError: PARAMIKO_IMPORT_ERROR = traceback.format_exc() # pylint: disable=invalid-name else: PARAMIKO_IMPORT_ERROR = None # pylint: disable=invalid-name if t.TYPE_CHECKING: from collections.abc import Buffer, Mapping RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer class SSHSocket(socket.socket): def __init__(self, host: str) -> None: super().__init__(socket.AF_INET, socket.SOCK_STREAM) self.host = host self.port = None self.user = None if ":" in self.host: self.host, self.port = self.host.split(":") if "@" in self.host: self.user, self.host = self.host.split("@") self.proc: subprocess.Popen | None = None def connect(self, *args_: t.Any, **kwargs: t.Any) -> None: args = ["ssh"] if self.user: args = args + ["-l", self.user] if self.port: args = args + ["-p", self.port] args = args + ["--", self.host, "docker system dial-stdio"] preexec_func = None if not constants.IS_WINDOWS_PLATFORM: def f() -> None: signal.signal(signal.SIGINT, signal.SIG_IGN) preexec_func = f env = dict(os.environ) # drop LD_LIBRARY_PATH and SSL_CERT_FILE env.pop("LD_LIBRARY_PATH", None) env.pop("SSL_CERT_FILE", None) self.proc = subprocess.Popen( # pylint: disable=consider-using-with args, env=env, stdout=subprocess.PIPE, stdin=subprocess.PIPE, preexec_fn=preexec_func, ) def _write(self, data: Buffer) -> int: if not self.proc: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first." ) assert self.proc.stdin is not None if self.proc.stdin.closed: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first after close()." ) written = self.proc.stdin.write(data) self.proc.stdin.flush() return written def sendall(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> None: self._write(data) def send(self, data: Buffer, *args: t.Any, **kwargs: t.Any) -> int: return self._write(data) def recv(self, n: int, *args: t.Any, **kwargs: t.Any) -> bytes: if not self.proc: raise RuntimeError( "SSH subprocess not initiated. connect() must be called first." ) assert self.proc.stdout is not None return self.proc.stdout.read(n) def makefile(self, mode: str, *args: t.Any, **kwargs: t.Any) -> t.IO: # type: ignore if not self.proc: self.connect() assert self.proc is not None assert self.proc.stdout is not None self.proc.stdout.channel = self # type: ignore return self.proc.stdout def close(self) -> None: if not self.proc: return assert self.proc.stdin is not None if self.proc.stdin.closed: return self.proc.stdin.write(b"\n\n") self.proc.stdin.flush() self.proc.terminate() class SSHConnection(urllib3_connection.HTTPConnection): def __init__( self, *, ssh_transport: paramiko.Transport | None = None, timeout: int | float = 60, host: str, ) -> None: super().__init__("localhost", timeout=timeout) self.ssh_transport = ssh_transport self.timeout = timeout self.ssh_host = host self.sock: paramiko.Channel | SSHSocket | None = None def connect(self) -> None: if self.ssh_transport: channel = self.ssh_transport.open_session() channel.settimeout(self.timeout) channel.exec_command("docker system dial-stdio") self.sock = channel else: sock = SSHSocket(self.ssh_host) sock.settimeout(self.timeout) sock.connect() self.sock = sock class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): scheme = "ssh" def __init__( self, *, ssh_client: paramiko.SSHClient | None = None, timeout: int | float = 60, maxsize: int = 10, host: str, ) -> None: super().__init__("localhost", timeout=timeout, maxsize=maxsize) self.ssh_transport: paramiko.Transport | None = None self.timeout = timeout if ssh_client: self.ssh_transport = ssh_client.get_transport() self.ssh_host = host def _new_conn(self) -> SSHConnection: return SSHConnection( ssh_transport=self.ssh_transport, timeout=self.timeout, host=self.ssh_host, ) # When re-using connections, urllib3 calls fileno() on our # SSH channel instance, quickly overloading our fd limit. To avoid this, # we override _get_conn def _get_conn(self, timeout: int | float) -> SSHConnection: conn = None try: conn = self.pool.get(block=self.block, timeout=timeout) except AttributeError as exc: # self.pool is None raise urllib3.exceptions.ClosedPoolError(self, "Pool is closed.") from exc except Empty as exc: if self.block: raise urllib3.exceptions.EmptyPoolError( self, "Pool reached maximum size and no more connections are allowed.", ) from exc # Oh well, we'll create a new connection then return conn or self._new_conn() class SSHHTTPAdapter(BaseHTTPAdapter): __attrs__ = HTTPAdapter.__attrs__ + [ "pools", "timeout", "ssh_client", "ssh_params", "max_pool_size", ] def __init__( self, base_url: str, timeout: int | float = 60, pool_connections: int = constants.DEFAULT_NUM_POOLS, max_pool_size: int = constants.DEFAULT_MAX_POOL_SIZE, shell_out: bool = False, ) -> None: self.ssh_client: paramiko.SSHClient | None = None if not shell_out: self._create_paramiko_client(base_url) self._connect() self.ssh_host = base_url if base_url.startswith("ssh://"): self.ssh_host = base_url[len("ssh://") :] self.timeout = timeout self.max_pool_size = max_pool_size self.pools = RecentlyUsedContainer( pool_connections, dispose_func=lambda p: p.close() ) super().__init__() def _create_paramiko_client(self, base_url: str) -> None: logging.getLogger("paramiko").setLevel(logging.WARNING) self.ssh_client = paramiko.SSHClient() base_url_p = urlparse(base_url) assert base_url_p.hostname is not None self.ssh_params: dict[str, t.Any] = { "hostname": base_url_p.hostname, "port": base_url_p.port, "username": base_url_p.username, } ssh_config_file = os.path.expanduser("~/.ssh/config") if os.path.exists(ssh_config_file): conf = paramiko.SSHConfig() with open(ssh_config_file, "rt", encoding="utf-8") as f: conf.parse(f) host_config = conf.lookup(base_url_p.hostname) if "proxycommand" in host_config: self.ssh_params["sock"] = paramiko.ProxyCommand( host_config["proxycommand"] ) if "hostname" in host_config: self.ssh_params["hostname"] = host_config["hostname"] if base_url_p.port is None and "port" in host_config: self.ssh_params["port"] = host_config["port"] if base_url_p.username is None and "user" in host_config: self.ssh_params["username"] = host_config["user"] if "identityfile" in host_config: self.ssh_params["key_filename"] = host_config["identityfile"] self.ssh_client.load_system_host_keys() self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy()) def _connect(self) -> None: if self.ssh_client: self.ssh_client.connect(**self.ssh_params) def get_connection( self, url: str | bytes, proxies: Mapping[str, str] | None = None ) -> SSHConnectionPool: if not self.ssh_client: return SSHConnectionPool( ssh_client=self.ssh_client, timeout=self.timeout, maxsize=self.max_pool_size, host=self.ssh_host, ) with self.pools.lock: pool = self.pools.get(url) if pool: return pool # Connection is closed try a reconnect if self.ssh_client and not self.ssh_client.get_transport(): self._connect() pool = SSHConnectionPool( ssh_client=self.ssh_client, timeout=self.timeout, maxsize=self.max_pool_size, host=self.ssh_host, ) self.pools[url] = pool return pool def close(self) -> None: super().close() if self.ssh_client: self.ssh_client.close()