mirror of
https://github.com/ansible/ansible.git
synced 2025-11-30 23:16:08 +07:00
Add ssh-agent launching, and ssh-agent python client (#84754)
* Add ssh-agent launching, and ssh-agent python client * Move things around, is this better?? * docs * postpone creating dir after bin lookup * fix method name * changelog ssh agent * address reviews * fix typing * do not redefine public_key * typing * more typing * Catch OSError when starting ssh agent * likely copy pasted old code * var type fix * why is this needed? ci_complete * ignoring the change for now * write out pub key file atomically * defensive timeout for the socket * _populate_agent docstring * do not allow setting these in config * check expected length before slicing blobs * test all key types * remove lock/unlock functionality * docstring * private _ssh_agent * . * launch agent in cli and ansible_ssh_* * additional info for ssh-agent comment * Add tests for remove and remove_all * comment on os.rename * hopefully mitigate agent startup/delays problems * exceptions * unused import * fix sanity * perf --------- Co-authored-by: Matt Martz <matt@sivel.net>
This commit is contained in:
6
changelogs/fragments/ssh-agent.yml
Normal file
6
changelogs/fragments/ssh-agent.yml
Normal file
@@ -0,0 +1,6 @@
|
||||
minor_changes:
|
||||
- ssh-agent - ``ansible``, ``ansible-playbook`` and ``ansible-console`` are capable of spawning or reusing an ssh-agent,
|
||||
allowing plugins to interact with the ssh-agent.
|
||||
Additionally a pure python ssh-agent client has been added, enabling easy interaction with the agent. The ssh connection plugin contains
|
||||
new functionality via ``ansible_ssh_private_key`` and ``ansible_ssh_private_key_passphrase``, for loading an SSH private key into
|
||||
the agent from a variable.
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import locale
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
|
||||
@@ -88,6 +89,7 @@ if jinja2_version < LooseVersion('3.1'):
|
||||
'Current version: %s' % jinja2_version
|
||||
)
|
||||
|
||||
import atexit
|
||||
import errno
|
||||
import getpass
|
||||
import subprocess
|
||||
@@ -111,10 +113,12 @@ from ansible.module_utils.six import string_types
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
from ansible.module_utils.common.collections import is_sequence
|
||||
from ansible.module_utils.common.file import is_executable
|
||||
from ansible.module_utils.common.process import get_bin_path
|
||||
from ansible.parsing.dataloader import DataLoader
|
||||
from ansible.parsing.vault import PromptVaultSecret, get_file_vault_secret
|
||||
from ansible.plugins.loader import add_all_plugin_dirs, init_plugin_loader
|
||||
from ansible.release import __version__
|
||||
from ansible.utils._ssh_agent import SshAgentClient
|
||||
from ansible.utils.collection_loader import AnsibleCollectionConfig
|
||||
from ansible.utils.collection_loader._collection_finder import _get_collection_name_from_path
|
||||
from ansible.utils.path import unfrackpath
|
||||
@@ -128,6 +132,77 @@ except ImportError:
|
||||
HAS_ARGCOMPLETE = False
|
||||
|
||||
|
||||
_SSH_AGENT_STDOUT_READ_TIMEOUT = 5 # seconds
|
||||
|
||||
|
||||
def _ssh_agent_timeout_handler(signum, frame):
|
||||
raise TimeoutError
|
||||
|
||||
|
||||
def _launch_ssh_agent() -> None:
|
||||
ssh_agent_cfg = C.config.get_config_value('SSH_AGENT')
|
||||
match ssh_agent_cfg:
|
||||
case 'none':
|
||||
display.debug('SSH_AGENT set to none')
|
||||
return
|
||||
case 'auto':
|
||||
try:
|
||||
ssh_agent_bin = get_bin_path('ssh-agent', required=True)
|
||||
except ValueError as e:
|
||||
raise AnsibleError('SSH_AGENT set to auto, but cannot find ssh-agent binary') from e
|
||||
ssh_agent_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'ssh_agent')
|
||||
os.mkdir(ssh_agent_dir, 0o700)
|
||||
sock = os.path.join(ssh_agent_dir, 'agent.sock')
|
||||
display.vvv('SSH_AGENT: starting...')
|
||||
try:
|
||||
p = subprocess.Popen(
|
||||
[ssh_agent_bin, '-D', '-s', '-a', sock],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
except OSError as e:
|
||||
raise AnsibleError(
|
||||
f'Could not start ssh-agent: {e}'
|
||||
) from e
|
||||
|
||||
if p.poll() is not None:
|
||||
raise AnsibleError(
|
||||
f'Could not start ssh-agent: rc={p.returncode} stderr="{p.stderr.read().decode()}"'
|
||||
)
|
||||
|
||||
old_sigalrm_handler = signal.signal(signal.SIGALRM, _ssh_agent_timeout_handler)
|
||||
signal.alarm(_SSH_AGENT_STDOUT_READ_TIMEOUT)
|
||||
try:
|
||||
stdout = p.stdout.read(13)
|
||||
except TimeoutError:
|
||||
stdout = b''
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old_sigalrm_handler)
|
||||
|
||||
if stdout != b'SSH_AUTH_SOCK':
|
||||
display.warning(
|
||||
f'The first 13 characters of stdout did not match the '
|
||||
f'expected SSH_AUTH_SOCK. This may not be the right binary, '
|
||||
f'or an incompatible agent: {stdout.decode()}'
|
||||
)
|
||||
display.vvv(f'SSH_AGENT: ssh-agent[{p.pid}] started and bound to {sock}')
|
||||
atexit.register(p.terminate)
|
||||
case _:
|
||||
sock = ssh_agent_cfg
|
||||
|
||||
try:
|
||||
with SshAgentClient(sock) as client:
|
||||
client.list()
|
||||
except Exception as e:
|
||||
raise AnsibleError(
|
||||
f'Could not communicate with ssh-agent using auth sock {sock}: {e}'
|
||||
) from e
|
||||
|
||||
os.environ['SSH_AUTH_SOCK'] = os.environ['ANSIBLE_SSH_AGENT'] = sock
|
||||
|
||||
|
||||
class CLI(ABC):
|
||||
""" code behind bin/ansible* programs """
|
||||
|
||||
@@ -137,6 +212,7 @@ class CLI(ABC):
|
||||
# -S (chop long lines) -X (disable termcap init and de-init)
|
||||
LESS_OPTS = 'FRSX'
|
||||
SKIP_INVENTORY_DEFAULTS = False
|
||||
USES_CONNECTION = False
|
||||
|
||||
def __init__(self, args, callback=None):
|
||||
"""
|
||||
@@ -528,8 +604,7 @@ class CLI(ABC):
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _play_prereqs():
|
||||
def _play_prereqs(self):
|
||||
# TODO: evaluate moving all of the code that touches ``AnsibleCollectionConfig``
|
||||
# into ``init_plugin_loader`` so that we can specifically remove
|
||||
# ``AnsibleCollectionConfig.playbook_paths`` to make it immutable after instantiation
|
||||
@@ -560,6 +635,12 @@ class CLI(ABC):
|
||||
auto_prompt=False)
|
||||
loader.set_vault_secrets(vault_secrets)
|
||||
|
||||
if self.USES_CONNECTION:
|
||||
try:
|
||||
_launch_ssh_agent()
|
||||
except Exception as e:
|
||||
raise AnsibleError('Failed to launch ssh agent', orig_exc=e)
|
||||
|
||||
# create the inventory, and filter it based on the subset specified (if any)
|
||||
inventory = InventoryManager(loader=loader, sources=options['inventory'], cache=(not options.get('flush_cache')))
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ class AdHocCLI(CLI):
|
||||
|
||||
name = 'ansible'
|
||||
|
||||
USES_CONNECTION = True
|
||||
|
||||
def init_parser(self):
|
||||
""" create an options parser for bin/ansible """
|
||||
super(AdHocCLI, self).init_parser(usage='%prog <host-pattern> [options]',
|
||||
|
||||
@@ -72,6 +72,8 @@ class ConsoleCLI(CLI, cmd.Cmd):
|
||||
# use specific to console, but fallback to highlight for backwards compatibility
|
||||
NORMAL_PROMPT = C.COLOR_CONSOLE_PROMPT or C.COLOR_HIGHLIGHT
|
||||
|
||||
USES_CONNECTION = True
|
||||
|
||||
def __init__(self, args):
|
||||
|
||||
super(ConsoleCLI, self).__init__(args)
|
||||
|
||||
@@ -34,6 +34,8 @@ class PlaybookCLI(CLI):
|
||||
|
||||
name = 'ansible-playbook'
|
||||
|
||||
USES_CONNECTION = True
|
||||
|
||||
def init_parser(self):
|
||||
|
||||
# create parser for CLI options
|
||||
|
||||
@@ -1888,6 +1888,24 @@ SHOW_CUSTOM_STATS:
|
||||
ini:
|
||||
- {key: show_custom_stats, section: defaults}
|
||||
type: bool
|
||||
SSH_AGENT:
|
||||
name: Manage an SSH Agent
|
||||
description: Manage an SSH Agent via Ansible. A configuration of ``none`` will not interact with an agent,
|
||||
``auto`` will start and destroy an agent via ``ssh-agent`` binary during the run, and a path
|
||||
to an SSH_AUTH_SOCK will allow interaction with a pre-existing agent.
|
||||
default: none
|
||||
type: string
|
||||
env: [{name: ANSIBLE_SSH_AGENT}]
|
||||
ini: [{key: ssh_agent, section: connection}]
|
||||
version_added: '2.19'
|
||||
SSH_AGENT_KEY_LIFETIME:
|
||||
name: Set a maximum lifetime when adding identities to an agent
|
||||
description: For keys inserted into an agent defined by ``SSH_AGENT``, define a lifetime, in seconds, that the key may remain
|
||||
in the agent.
|
||||
type: int
|
||||
env: [{name: ANSIBLE_SSH_AGENT_KEY_LIFETIME}]
|
||||
ini: [{key: ssh_agent_key_lifetime, section: connection}]
|
||||
version_added: '2.19'
|
||||
STRING_TYPE_FILTERS:
|
||||
name: Filters to preserve strings
|
||||
default: [string, to_json, to_nice_json, to_yaml, to_nice_yaml, ppretty, json]
|
||||
|
||||
@@ -265,7 +265,6 @@ DOCUMENTATION = """
|
||||
vars:
|
||||
- name: ansible_pipelining
|
||||
- name: ansible_ssh_pipelining
|
||||
|
||||
private_key_file:
|
||||
description:
|
||||
- Path to private key file to use for authentication.
|
||||
@@ -281,7 +280,27 @@ DOCUMENTATION = """
|
||||
cli:
|
||||
- name: private_key_file
|
||||
option: '--private-key'
|
||||
|
||||
private_key:
|
||||
description:
|
||||
- Private key contents in PEM format. Requires the C(SSH_AGENT) configuration to be enabled.
|
||||
type: string
|
||||
env:
|
||||
- name: ANSIBLE_PRIVATE_KEY
|
||||
vars:
|
||||
- name: ansible_private_key
|
||||
- name: ansible_ssh_private_key
|
||||
version_added: '2.19'
|
||||
private_key_passphrase:
|
||||
description:
|
||||
- Private key passphrase, dependent on O(private_key).
|
||||
- This does NOT have any effect when used with O(private_key_file).
|
||||
type: string
|
||||
env:
|
||||
- name: ANSIBLE_PRIVATE_KEY_PASSPHRASE
|
||||
vars:
|
||||
- name: ansible_private_key_passphrase
|
||||
- name: ansible_ssh_private_key_passphrase
|
||||
version_added: '2.19'
|
||||
control_path:
|
||||
description:
|
||||
- This is the location to save SSH's ControlPath sockets, it uses SSH's variable substitution.
|
||||
@@ -398,11 +417,13 @@ import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import typing as t
|
||||
from functools import wraps
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
|
||||
from ansible import constants as C
|
||||
from ansible.errors import (
|
||||
AnsibleAuthenticationFailure,
|
||||
AnsibleConnectionFailure,
|
||||
@@ -415,6 +436,15 @@ from ansible.plugins.connection import ConnectionBase, BUFSIZE
|
||||
from ansible.plugins.shell.powershell import _replace_stderr_clixml
|
||||
from ansible.utils.display import Display
|
||||
from ansible.utils.path import unfrackpath, makedirs_safe
|
||||
from ansible.utils._ssh_agent import SshAgentClient, _key_data_into_crypto_objects
|
||||
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
except ImportError:
|
||||
HAS_CRYPTOGRAPHY = False
|
||||
else:
|
||||
HAS_CRYPTOGRAPHY = True
|
||||
|
||||
|
||||
display = Display()
|
||||
|
||||
@@ -638,6 +668,8 @@ class Connection(ConnectionBase):
|
||||
self._tty_parser.add_argument('-t', action='count')
|
||||
self._tty_parser.add_argument('-o', action='append')
|
||||
|
||||
self._populated_agent: pathlib.Path | None = None
|
||||
|
||||
# The connection is created by running ssh/scp/sftp from the exec_command,
|
||||
# put_file, and fetch_file methods, so we don't need to do any connection
|
||||
# management here.
|
||||
@@ -712,6 +744,52 @@ class Connection(ConnectionBase):
|
||||
display.vvvvv(u'SSH: %s: (%s)' % (explanation, ')('.join(to_text(a) for a in b_args)), host=self.host)
|
||||
b_command += b_args
|
||||
|
||||
def _populate_agent(self) -> pathlib.Path:
|
||||
"""Adds configured private key identity to the SSH agent. Returns a path to a file containing the public key."""
|
||||
if self._populated_agent:
|
||||
return self._populated_agent
|
||||
|
||||
if (auth_sock := C.config.get_config_value('SSH_AGENT')) == 'none':
|
||||
raise AnsibleError('Cannot utilize private_key with SSH_AGENT disabled')
|
||||
|
||||
key_data = self.get_option('private_key')
|
||||
passphrase = self.get_option('private_key_passphrase')
|
||||
|
||||
private_key, public_key, fingerprint = _key_data_into_crypto_objects(
|
||||
to_bytes(key_data),
|
||||
to_bytes(passphrase) if passphrase else None,
|
||||
)
|
||||
|
||||
with SshAgentClient(auth_sock) as client:
|
||||
if public_key not in client:
|
||||
display.vvv(f'SSH: SSH_AGENT adding {fingerprint} to agent', host=self.host)
|
||||
client.add(
|
||||
private_key,
|
||||
f'[added by ansible: PID={os.getpid()}, UID={os.getuid()}, EUID={os.geteuid()}, TIME={time.time()}]',
|
||||
C.config.get_config_value('SSH_AGENT_KEY_LIFETIME'),
|
||||
)
|
||||
else:
|
||||
display.vvv(f'SSH: SSH_AGENT {fingerprint} exists in agent', host=self.host)
|
||||
# Write the public key to disk, to be provided as IdentityFile.
|
||||
# This allows ssh to pick an explicit key in the agent to use,
|
||||
# preventing ssh from attempting all keys in the agent.
|
||||
pubkey_path = self._populated_agent = pathlib.Path(C.DEFAULT_LOCAL_TMP).joinpath(
|
||||
fingerprint.replace('/', '-') + '.pub'
|
||||
)
|
||||
if os.path.exists(pubkey_path):
|
||||
return pubkey_path
|
||||
|
||||
with tempfile.NamedTemporaryFile(dir=C.DEFAULT_LOCAL_TMP, delete=False) as f:
|
||||
f.write(public_key.public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
format=serialization.PublicFormat.OpenSSH
|
||||
))
|
||||
# move atomically to prevent race conditions, silently succeeds if the target exists
|
||||
os.rename(f.name, pubkey_path)
|
||||
os.chmod(pubkey_path, mode=0o400)
|
||||
|
||||
return self._populated_agent
|
||||
|
||||
def _build_command(self, binary: str, subsystem: str, *other_args: bytes | str) -> list[bytes]:
|
||||
"""
|
||||
Takes a executable (ssh, scp, sftp or wrapper) and optional extra arguments and returns the remote command
|
||||
@@ -797,8 +875,17 @@ class Connection(ConnectionBase):
|
||||
b_args = (b"-o", b"Port=" + to_bytes(self.port, nonstring='simplerepr', errors='surrogate_or_strict'))
|
||||
self._add_args(b_command, b_args, u"ANSIBLE_REMOTE_PORT/remote_port/ansible_port set")
|
||||
|
||||
key = self.get_option('private_key_file')
|
||||
if key:
|
||||
if self.get_option('private_key'):
|
||||
try:
|
||||
key = self._populate_agent()
|
||||
except Exception as e:
|
||||
raise AnsibleAuthenticationFailure(
|
||||
'Failed to add configured private key into ssh-agent',
|
||||
orig_exc=e,
|
||||
)
|
||||
b_args = (b'-o', b'IdentitiesOnly=yes', b'-o', to_bytes(f'IdentityFile="{key}"', errors='surrogate_or_strict'))
|
||||
self._add_args(b_command, b_args, "ANSIBLE_PRIVATE_KEY/private_key set")
|
||||
elif key := self.get_option('private_key_file'):
|
||||
b_args = (b"-o", b'IdentityFile="' + to_bytes(os.path.expanduser(key), errors='surrogate_or_strict') + b'"')
|
||||
self._add_args(b_command, b_args, u"ANSIBLE_PRIVATE_KEY_FILE/private_key_file/ansible_ssh_private_key_file set")
|
||||
|
||||
|
||||
657
lib/ansible/utils/_ssh_agent.py
Normal file
657
lib/ansible/utils/_ssh_agent.py
Normal file
@@ -0,0 +1,657 @@
|
||||
# Copyright: Contributors to the Ansible project
|
||||
# BSD 3 Clause License (see licenses/BSD-3-Clause.txt or https://opensource.org/license/bsd-3-clause/)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import copy
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import hashlib
|
||||
import socket
|
||||
import types
|
||||
import typing as t
|
||||
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import (
|
||||
DSAParameterNumbers,
|
||||
DSAPrivateKey,
|
||||
DSAPublicKey,
|
||||
DSAPublicNumbers,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
EllipticCurve,
|
||||
EllipticCurvePrivateKey,
|
||||
EllipticCurvePublicKey,
|
||||
SECP256R1,
|
||||
SECP384R1,
|
||||
SECP521R1,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
|
||||
Ed25519PrivateKey,
|
||||
Ed25519PublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
||||
RSAPrivateKey,
|
||||
RSAPublicKey,
|
||||
RSAPublicNumbers,
|
||||
)
|
||||
except ImportError:
|
||||
HAS_CRYPTOGRAPHY = False
|
||||
else:
|
||||
HAS_CRYPTOGRAPHY = True
|
||||
|
||||
CryptoPublicKey = t.Union[
|
||||
DSAPublicKey,
|
||||
EllipticCurvePublicKey,
|
||||
Ed25519PublicKey,
|
||||
RSAPublicKey,
|
||||
]
|
||||
|
||||
CryptoPrivateKey = t.Union[
|
||||
DSAPrivateKey,
|
||||
EllipticCurvePrivateKey,
|
||||
Ed25519PrivateKey,
|
||||
RSAPrivateKey,
|
||||
]
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPrivateNumbers
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateNumbers
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateNumbers
|
||||
|
||||
|
||||
_SSH_AGENT_CLIENT_SOCKET_TIMEOUT = 10
|
||||
|
||||
|
||||
class ProtocolMsgNumbers(enum.IntEnum):
|
||||
# Responses
|
||||
SSH_AGENT_FAILURE = 5
|
||||
SSH_AGENT_SUCCESS = 6
|
||||
SSH_AGENT_IDENTITIES_ANSWER = 12
|
||||
SSH_AGENT_SIGN_RESPONSE = 14
|
||||
SSH_AGENT_EXTENSION_FAILURE = 28
|
||||
SSH_AGENT_EXTENSION_RESPONSE = 29
|
||||
|
||||
# Constraints
|
||||
SSH_AGENT_CONSTRAIN_LIFETIME = 1
|
||||
SSH_AGENT_CONSTRAIN_CONFIRM = 2
|
||||
SSH_AGENT_CONSTRAIN_EXTENSION = 255
|
||||
|
||||
# Requests
|
||||
SSH_AGENTC_REQUEST_IDENTITIES = 11
|
||||
SSH_AGENTC_SIGN_REQUEST = 13
|
||||
SSH_AGENTC_ADD_IDENTITY = 17
|
||||
SSH_AGENTC_REMOVE_IDENTITY = 18
|
||||
SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19
|
||||
SSH_AGENTC_ADD_SMARTCARD_KEY = 20
|
||||
SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21
|
||||
SSH_AGENTC_LOCK = 22
|
||||
SSH_AGENTC_UNLOCK = 23
|
||||
SSH_AGENTC_ADD_ID_CONSTRAINED = 25
|
||||
SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26
|
||||
SSH_AGENTC_EXTENSION = 27
|
||||
|
||||
def to_blob(self) -> bytes:
|
||||
return bytes([self])
|
||||
|
||||
|
||||
class SshAgentFailure(RuntimeError):
|
||||
"""Server failure or unexpected response."""
|
||||
|
||||
|
||||
# NOTE: Classes below somewhat represent "Data Type Representations Used in the SSH Protocols"
|
||||
# as specified by RFC4251
|
||||
|
||||
@t.runtime_checkable
|
||||
class SupportsToBlob(t.Protocol):
|
||||
def to_blob(self) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
@t.runtime_checkable
|
||||
class SupportsFromBlob(t.Protocol):
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
|
||||
...
|
||||
|
||||
|
||||
def _split_blob(blob: memoryview | bytes, length: int) -> tuple[memoryview | bytes, memoryview | bytes]:
|
||||
if len(blob) < length:
|
||||
raise ValueError("_split_blob: unexpected data length")
|
||||
return blob[:length], blob[length:]
|
||||
|
||||
|
||||
class VariableSized:
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
|
||||
length = uint32.from_blob(blob[:4])
|
||||
blob = blob[4:]
|
||||
data, rest = _split_blob(blob, length)
|
||||
return cls.from_blob(data), rest
|
||||
|
||||
|
||||
class uint32(int):
|
||||
def to_blob(self) -> bytes:
|
||||
return self.to_bytes(length=4, byteorder='big')
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
return cls.from_bytes(blob, byteorder='big')
|
||||
|
||||
@classmethod
|
||||
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
|
||||
length = uint32(4)
|
||||
data, rest = _split_blob(blob, length)
|
||||
return cls.from_blob(data), rest
|
||||
|
||||
|
||||
class mpint(int, VariableSized):
|
||||
def to_blob(self) -> bytes:
|
||||
if self < 0:
|
||||
raise ValueError("negative mpint not allowed")
|
||||
if not self:
|
||||
return b""
|
||||
nbytes = (self.bit_length() + 8) // 8
|
||||
ret = bytearray(self.to_bytes(length=nbytes, byteorder='big'))
|
||||
ret[:0] = uint32(len(ret)).to_blob()
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
if blob and blob[0] > 127:
|
||||
raise ValueError("Invalid data")
|
||||
return cls.from_bytes(blob, byteorder='big')
|
||||
|
||||
|
||||
class constraints(bytes):
|
||||
def to_blob(self) -> bytes:
|
||||
return self
|
||||
|
||||
|
||||
class binary_string(bytes, VariableSized):
|
||||
def to_blob(self) -> bytes:
|
||||
if length := len(self):
|
||||
return uint32(length).to_blob() + self
|
||||
else:
|
||||
return b""
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
return cls(blob)
|
||||
|
||||
|
||||
class unicode_string(str, VariableSized):
|
||||
def to_blob(self) -> bytes:
|
||||
val = self.encode('utf-8')
|
||||
if length := len(val):
|
||||
return uint32(length).to_blob() + val
|
||||
else:
|
||||
return b""
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
return cls(bytes(blob).decode('utf-8'))
|
||||
|
||||
|
||||
class KeyAlgo(str, VariableSized, enum.Enum):
|
||||
RSA = "ssh-rsa"
|
||||
DSA = "ssh-dss"
|
||||
ECDSA256 = "ecdsa-sha2-nistp256"
|
||||
SKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
|
||||
ECDSA384 = "ecdsa-sha2-nistp384"
|
||||
ECDSA521 = "ecdsa-sha2-nistp521"
|
||||
ED25519 = "ssh-ed25519"
|
||||
SKED25519 = "sk-ssh-ed25519@openssh.com"
|
||||
RSASHA256 = "rsa-sha2-256"
|
||||
RSASHA512 = "rsa-sha2-512"
|
||||
|
||||
@property
|
||||
def main_type(self) -> str:
|
||||
match self:
|
||||
case self.RSA:
|
||||
return 'RSA'
|
||||
case self.DSA:
|
||||
return 'DSA'
|
||||
case self.ECDSA256 | self.ECDSA384 | self.ECDSA521:
|
||||
return 'ECDSA'
|
||||
case self.ED25519:
|
||||
return 'ED25519'
|
||||
case _:
|
||||
raise NotImplementedError(self.name)
|
||||
|
||||
def to_blob(self) -> bytes:
|
||||
b_self = self.encode('utf-8')
|
||||
return uint32(len(b_self)).to_blob() + b_self
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
return cls(bytes(blob).decode('utf-8'))
|
||||
|
||||
|
||||
if HAS_CRYPTOGRAPHY:
|
||||
_ECDSA_KEY_TYPE: dict[KeyAlgo, type[EllipticCurve]] = {
|
||||
KeyAlgo.ECDSA256: SECP256R1,
|
||||
KeyAlgo.ECDSA384: SECP384R1,
|
||||
KeyAlgo.ECDSA521: SECP521R1,
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Msg:
|
||||
def to_blob(self) -> bytes:
|
||||
rv = bytearray()
|
||||
for field in dataclasses.fields(self):
|
||||
fv = getattr(self, field.name)
|
||||
if isinstance(fv, SupportsToBlob):
|
||||
rv.extend(fv.to_blob())
|
||||
else:
|
||||
raise NotImplementedError(field.type)
|
||||
return rv
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
args: list[t.Any] = []
|
||||
for _field_name, field_type in t.get_type_hints(cls).items():
|
||||
if isinstance(field_type, SupportsFromBlob):
|
||||
fv, blob = field_type.consume_from_blob(blob)
|
||||
args.append(fv)
|
||||
else:
|
||||
raise NotImplementedError(str(field_type))
|
||||
return cls(*args)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PrivateKeyMsg(Msg):
|
||||
@staticmethod
|
||||
def from_private_key(private_key: CryptoPrivateKey) -> PrivateKeyMsg:
|
||||
match private_key:
|
||||
case RSAPrivateKey():
|
||||
rsa_pn: RSAPrivateNumbers = private_key.private_numbers()
|
||||
return RSAPrivateKeyMsg(
|
||||
KeyAlgo.RSA,
|
||||
mpint(rsa_pn.public_numbers.n),
|
||||
mpint(rsa_pn.public_numbers.e),
|
||||
mpint(rsa_pn.d),
|
||||
mpint(rsa_pn.iqmp),
|
||||
mpint(rsa_pn.p),
|
||||
mpint(rsa_pn.q),
|
||||
)
|
||||
case DSAPrivateKey():
|
||||
dsa_pn: DSAPrivateNumbers = private_key.private_numbers()
|
||||
return DSAPrivateKeyMsg(
|
||||
KeyAlgo.DSA,
|
||||
mpint(dsa_pn.public_numbers.parameter_numbers.p),
|
||||
mpint(dsa_pn.public_numbers.parameter_numbers.q),
|
||||
mpint(dsa_pn.public_numbers.parameter_numbers.g),
|
||||
mpint(dsa_pn.public_numbers.y),
|
||||
mpint(dsa_pn.x),
|
||||
)
|
||||
case EllipticCurvePrivateKey():
|
||||
ecdsa_pn: EllipticCurvePrivateNumbers = private_key.private_numbers()
|
||||
key_size = private_key.key_size
|
||||
return EcdsaPrivateKeyMsg(
|
||||
getattr(KeyAlgo, f'ECDSA{key_size}'),
|
||||
unicode_string(f'nistp{key_size}'),
|
||||
binary_string(private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.X962,
|
||||
format=serialization.PublicFormat.UncompressedPoint
|
||||
)),
|
||||
mpint(ecdsa_pn.private_value),
|
||||
)
|
||||
case Ed25519PrivateKey():
|
||||
public_bytes = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
private_bytes = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
)
|
||||
return Ed25519PrivateKeyMsg(
|
||||
KeyAlgo.ED25519,
|
||||
binary_string(public_bytes),
|
||||
binary_string(private_bytes + public_bytes),
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(private_key)
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class RSAPrivateKeyMsg(PrivateKeyMsg):
|
||||
type: KeyAlgo
|
||||
n: mpint
|
||||
e: mpint
|
||||
d: mpint
|
||||
iqmp: mpint
|
||||
p: mpint
|
||||
q: mpint
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
constraints: constraints = dataclasses.field(default=constraints(b''))
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class DSAPrivateKeyMsg(PrivateKeyMsg):
|
||||
type: KeyAlgo
|
||||
p: mpint
|
||||
q: mpint
|
||||
g: mpint
|
||||
y: mpint
|
||||
x: mpint
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
constraints: constraints = dataclasses.field(default=constraints(b''))
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class EcdsaPrivateKeyMsg(PrivateKeyMsg):
|
||||
type: KeyAlgo
|
||||
ecdsa_curve_name: unicode_string
|
||||
Q: binary_string
|
||||
d: mpint
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
constraints: constraints = dataclasses.field(default=constraints(b''))
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class Ed25519PrivateKeyMsg(PrivateKeyMsg):
|
||||
type: KeyAlgo
|
||||
enc_a: binary_string
|
||||
k_env_a: binary_string
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
constraints: constraints = dataclasses.field(default=constraints(b''))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PublicKeyMsg(Msg):
|
||||
@staticmethod
|
||||
def get_dataclass(
|
||||
type: KeyAlgo
|
||||
) -> type[t.Union[
|
||||
RSAPublicKeyMsg,
|
||||
EcdsaPublicKeyMsg,
|
||||
Ed25519PublicKeyMsg,
|
||||
DSAPublicKeyMsg
|
||||
]]:
|
||||
match type:
|
||||
case KeyAlgo.RSA:
|
||||
return RSAPublicKeyMsg
|
||||
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
|
||||
return EcdsaPublicKeyMsg
|
||||
case KeyAlgo.ED25519:
|
||||
return Ed25519PublicKeyMsg
|
||||
case KeyAlgo.DSA:
|
||||
return DSAPublicKeyMsg
|
||||
case _:
|
||||
raise NotImplementedError(type)
|
||||
|
||||
@functools.cached_property
|
||||
def public_key(self) -> CryptoPublicKey:
|
||||
type: KeyAlgo = self.type
|
||||
match type:
|
||||
case KeyAlgo.RSA:
|
||||
return RSAPublicNumbers(
|
||||
self.e,
|
||||
self.n
|
||||
).public_key()
|
||||
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
|
||||
curve = _ECDSA_KEY_TYPE[KeyAlgo(type)]
|
||||
return EllipticCurvePublicKey.from_encoded_point(
|
||||
curve(),
|
||||
self.Q
|
||||
)
|
||||
case KeyAlgo.ED25519:
|
||||
return Ed25519PublicKey.from_public_bytes(
|
||||
self.enc_a
|
||||
)
|
||||
case KeyAlgo.DSA:
|
||||
return DSAPublicNumbers(
|
||||
self.y,
|
||||
DSAParameterNumbers(
|
||||
self.p,
|
||||
self.q,
|
||||
self.g
|
||||
)
|
||||
).public_key()
|
||||
case _:
|
||||
raise NotImplementedError(type)
|
||||
|
||||
@staticmethod
|
||||
def from_public_key(public_key: CryptoPublicKey) -> PublicKeyMsg:
|
||||
match public_key:
|
||||
case DSAPublicKey():
|
||||
dsa_pn: DSAPublicNumbers = public_key.public_numbers()
|
||||
return DSAPublicKeyMsg(
|
||||
KeyAlgo.DSA,
|
||||
mpint(dsa_pn.parameter_numbers.p),
|
||||
mpint(dsa_pn.parameter_numbers.q),
|
||||
mpint(dsa_pn.parameter_numbers.g),
|
||||
mpint(dsa_pn.y)
|
||||
)
|
||||
case EllipticCurvePublicKey():
|
||||
return EcdsaPublicKeyMsg(
|
||||
getattr(KeyAlgo, f'ECDSA{public_key.curve.key_size}'),
|
||||
unicode_string(f'nistp{public_key.curve.key_size}'),
|
||||
binary_string(public_key.public_bytes(
|
||||
encoding=serialization.Encoding.X962,
|
||||
format=serialization.PublicFormat.UncompressedPoint
|
||||
))
|
||||
)
|
||||
case Ed25519PublicKey():
|
||||
return Ed25519PublicKeyMsg(
|
||||
KeyAlgo.ED25519,
|
||||
binary_string(public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
))
|
||||
)
|
||||
case RSAPublicKey():
|
||||
rsa_pn: RSAPublicNumbers = public_key.public_numbers()
|
||||
return RSAPublicKeyMsg(
|
||||
KeyAlgo.RSA,
|
||||
mpint(rsa_pn.e),
|
||||
mpint(rsa_pn.n)
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(public_key)
|
||||
|
||||
@functools.cached_property
|
||||
def fingerprint(self) -> str:
|
||||
digest = hashlib.sha256()
|
||||
msg = copy.copy(self)
|
||||
msg.comments = unicode_string('')
|
||||
k = msg.to_blob()
|
||||
digest.update(k)
|
||||
return binascii.b2a_base64(
|
||||
digest.digest(),
|
||||
newline=False
|
||||
).rstrip(b'=').decode('utf-8')
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class RSAPublicKeyMsg(PublicKeyMsg):
|
||||
type: KeyAlgo
|
||||
e: mpint
|
||||
n: mpint
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class DSAPublicKeyMsg(PublicKeyMsg):
|
||||
type: KeyAlgo
|
||||
p: mpint
|
||||
q: mpint
|
||||
g: mpint
|
||||
y: mpint
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class EcdsaPublicKeyMsg(PublicKeyMsg):
|
||||
type: KeyAlgo
|
||||
ecdsa_curve_name: unicode_string
|
||||
Q: binary_string
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class Ed25519PublicKeyMsg(PublicKeyMsg):
|
||||
type: KeyAlgo
|
||||
enc_a: binary_string
|
||||
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class KeyList(Msg):
|
||||
nkeys: uint32
|
||||
keys: PublicKeyMsgList
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.nkeys != len(self.keys):
|
||||
raise SshAgentFailure(
|
||||
"agent: invalid number of keys received for identities list"
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(order=True, slots=True)
|
||||
class PublicKeyMsgList(Msg):
|
||||
keys: list[PublicKeyMsg]
|
||||
|
||||
def __iter__(self) -> t.Iterator[PublicKeyMsg]:
|
||||
yield from self.keys
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.keys)
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
|
||||
args: list[PublicKeyMsg] = []
|
||||
while blob:
|
||||
prev_blob = blob
|
||||
key_blob, key_blob_length, comment_blob = cls._consume_field(blob)
|
||||
|
||||
peek_key_algo, _length, _blob = cls._consume_field(key_blob)
|
||||
pub_key_msg_cls = PublicKeyMsg.get_dataclass(
|
||||
KeyAlgo(bytes(peek_key_algo).decode('utf-8'))
|
||||
)
|
||||
|
||||
_fv, comment_blob_length, blob = cls._consume_field(comment_blob)
|
||||
key_plus_comment = (
|
||||
prev_blob[4: (4 + key_blob_length) + (4 + comment_blob_length)]
|
||||
)
|
||||
|
||||
args.append(pub_key_msg_cls.from_blob(key_plus_comment))
|
||||
return cls(args), b""
|
||||
|
||||
@staticmethod
|
||||
def _consume_field(
|
||||
blob: memoryview | bytes
|
||||
) -> tuple[memoryview | bytes, uint32, memoryview | bytes]:
|
||||
length = uint32.from_blob(blob[:4])
|
||||
blob = blob[4:]
|
||||
data, rest = _split_blob(blob, length)
|
||||
return data, length, rest
|
||||
|
||||
|
||||
class SshAgentClient:
|
||||
def __init__(self, auth_sock: str) -> None:
|
||||
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self._sock.settimeout(_SSH_AGENT_CLIENT_SOCKET_TIMEOUT)
|
||||
self._sock.connect(auth_sock)
|
||||
|
||||
def close(self) -> None:
|
||||
self._sock.close()
|
||||
|
||||
def __enter__(self) -> t.Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def send(self, msg: bytes) -> bytes:
|
||||
length = uint32(len(msg)).to_blob()
|
||||
self._sock.sendall(length + msg)
|
||||
bufsize = uint32.from_blob(self._sock.recv(4))
|
||||
resp = self._sock.recv(bufsize)
|
||||
if resp[0] == ProtocolMsgNumbers.SSH_AGENT_FAILURE:
|
||||
raise SshAgentFailure('agent: failure')
|
||||
return resp
|
||||
|
||||
def remove_all(self) -> None:
|
||||
self.send(
|
||||
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_ALL_IDENTITIES.to_blob()
|
||||
)
|
||||
|
||||
def remove(self, public_key: CryptoPublicKey) -> None:
|
||||
key_blob = PublicKeyMsg.from_public_key(public_key).to_blob()
|
||||
self.send(
|
||||
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_IDENTITY.to_blob() +
|
||||
uint32(len(key_blob)).to_blob() + key_blob
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
private_key: CryptoPrivateKey,
|
||||
comments: str | None = None,
|
||||
lifetime: int | None = None,
|
||||
confirm: bool | None = None,
|
||||
) -> None:
|
||||
key_msg = PrivateKeyMsg.from_private_key(private_key)
|
||||
key_msg.comments = unicode_string(comments or '')
|
||||
if lifetime:
|
||||
key_msg.constraints += constraints(
|
||||
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_LIFETIME]
|
||||
).to_blob() + uint32(lifetime).to_blob()
|
||||
if confirm:
|
||||
key_msg.constraints += constraints(
|
||||
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_CONFIRM]
|
||||
).to_blob()
|
||||
|
||||
if key_msg.constraints:
|
||||
msg = ProtocolMsgNumbers.SSH_AGENTC_ADD_ID_CONSTRAINED.to_blob()
|
||||
else:
|
||||
msg = ProtocolMsgNumbers.SSH_AGENTC_ADD_IDENTITY.to_blob()
|
||||
msg += key_msg.to_blob()
|
||||
self.send(msg)
|
||||
|
||||
def list(self) -> KeyList:
|
||||
req = ProtocolMsgNumbers.SSH_AGENTC_REQUEST_IDENTITIES.to_blob()
|
||||
r = memoryview(bytearray(self.send(req)))
|
||||
if r[0] != ProtocolMsgNumbers.SSH_AGENT_IDENTITIES_ANSWER:
|
||||
raise SshAgentFailure(
|
||||
'agent: non-identities answer received for identities list'
|
||||
)
|
||||
return KeyList.from_blob(r[1:])
|
||||
|
||||
def __contains__(self, public_key: CryptoPublicKey) -> bool:
|
||||
msg = PublicKeyMsg.from_public_key(public_key)
|
||||
return msg in self.list().keys
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _key_data_into_crypto_objects(key_data: bytes, passphrase: bytes | None) -> tuple[CryptoPrivateKey, CryptoPublicKey, str]:
|
||||
private_key = serialization.ssh.load_ssh_private_key(key_data, passphrase)
|
||||
public_key = private_key.public_key()
|
||||
fingerprint = PublicKeyMsg.from_public_key(public_key).fingerprint
|
||||
|
||||
return private_key, public_key, fingerprint
|
||||
28
licenses/BSD-3-Clause.txt
Normal file
28
licenses/BSD-3-Clause.txt
Normal file
@@ -0,0 +1,28 @@
|
||||
Copyright (c) Contributors to the Ansible project. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its contributors
|
||||
may be used to endorse or promote products derived from this software
|
||||
without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
|
||||
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
|
||||
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
||||
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
|
||||
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGE.
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from ansible.plugins.action import ActionBase
|
||||
from ansible.utils._ssh_agent import SshAgentClient
|
||||
|
||||
from cryptography.hazmat.primitives.serialization import ssh
|
||||
|
||||
|
||||
class ActionModule(ActionBase):
|
||||
|
||||
def run(self, tmp=None, task_vars=None):
|
||||
results = super(ActionModule, self).run(tmp, task_vars)
|
||||
del tmp # tmp no longer has any effect
|
||||
match self._task.args['action']:
|
||||
case 'list':
|
||||
return self.list()
|
||||
case 'remove':
|
||||
return self.remove(self._task.args['pubkey'])
|
||||
case 'remove_all':
|
||||
return self.remove_all()
|
||||
case _:
|
||||
return {'failed': True, 'msg': 'not implemented'}
|
||||
|
||||
def remove(self, pubkey_data):
|
||||
with SshAgentClient(os.environ['SSH_AUTH_SOCK']) as client:
|
||||
public_key = ssh.load_ssh_public_key(pubkey_data.encode())
|
||||
client.remove(public_key)
|
||||
return {'failed': public_key in client}
|
||||
|
||||
def remove_all(self):
|
||||
with SshAgentClient(os.environ['SSH_AUTH_SOCK']) as client:
|
||||
nkeys_before = client.list().nkeys
|
||||
client.remove_all()
|
||||
nkeys_after = client.list().nkeys
|
||||
return {
|
||||
'failed': nkeys_after != 0,
|
||||
'nkeys_removed': nkeys_before,
|
||||
}
|
||||
|
||||
def list(self):
|
||||
result = {'keys': [], 'nkeys': 0}
|
||||
with SshAgentClient(os.environ['SSH_AUTH_SOCK']) as client:
|
||||
key_list = client.list()
|
||||
result['nkeys'] = key_list.nkeys
|
||||
for key in key_list.keys:
|
||||
public_key = key.public_key
|
||||
key_size = getattr(public_key, 'key_size', 256)
|
||||
fingerprint = key.fingerprint
|
||||
key_type = key.type.main_type
|
||||
result['keys'].append({
|
||||
'type': key_type,
|
||||
'key_size': key_size,
|
||||
'fingerprint': f'SHA256:{fingerprint}',
|
||||
'comments': key.comments,
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ansible.plugins.action import ActionBase
|
||||
from ansible.utils._ssh_agent import PublicKeyMsg
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_text
|
||||
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key as rsa_generate_private_key
|
||||
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import generate_private_key as dsa_generate_private_key
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1, generate_private_key as ecdsa_generate_private_key
|
||||
|
||||
|
||||
class ActionModule(ActionBase):
|
||||
|
||||
def run(self, tmp=None, task_vars=None):
|
||||
results = super(ActionModule, self).run(tmp, task_vars)
|
||||
del tmp # tmp no longer has any effect
|
||||
match self._task.args.get('type'):
|
||||
case 'ed25519':
|
||||
private_key = Ed25519PrivateKey.generate()
|
||||
case 'rsa':
|
||||
private_key = rsa_generate_private_key(65537, 4096)
|
||||
case 'dsa':
|
||||
private_key = dsa_generate_private_key(1024)
|
||||
case 'ecdsa':
|
||||
private_key = ecdsa_generate_private_key(SECP384R1())
|
||||
case _:
|
||||
return {'failed': True, 'msg': 'not implemented'}
|
||||
|
||||
public_key = private_key.public_key()
|
||||
public_key_msg = PublicKeyMsg.from_public_key(public_key)
|
||||
|
||||
if not (passphrase := self._task.args.get('passphrase')):
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
else:
|
||||
encryption_algorithm = serialization.BestAvailableEncryption(
|
||||
to_bytes(passphrase)
|
||||
)
|
||||
|
||||
return {
|
||||
'changed': True,
|
||||
'private_key': to_text(private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)),
|
||||
'public_key': to_text(public_key.public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
format=serialization.PublicFormat.OpenSSH,
|
||||
)),
|
||||
'fingerprint': f'SHA256:{public_key_msg.fingerprint}',
|
||||
}
|
||||
3
test/integration/targets/ssh_agent/aliases
Normal file
3
test/integration/targets/ssh_agent/aliases
Normal file
@@ -0,0 +1,3 @@
|
||||
needs/ssh
|
||||
shippable/posix/group2
|
||||
context/target
|
||||
46
test/integration/targets/ssh_agent/auto.yml
Normal file
46
test/integration/targets/ssh_agent/auto.yml
Normal file
@@ -0,0 +1,46 @@
|
||||
- hosts: testhost
|
||||
tasks:
|
||||
- set_fact:
|
||||
key_types:
|
||||
- ed25519
|
||||
- rsa
|
||||
- ecdsa
|
||||
|
||||
- set_fact:
|
||||
key_types: "{{ key_types + ['dsa'] }}"
|
||||
when: ansible_distribution == "RedHat"
|
||||
|
||||
- include_tasks: test_key.yml
|
||||
loop: "{{ key_types }}"
|
||||
loop_control:
|
||||
extended: true
|
||||
|
||||
- ssh_agent:
|
||||
action: remove
|
||||
pubkey: "{{ sshkey.public_key }}"
|
||||
|
||||
- ssh_agent:
|
||||
action: list
|
||||
register: keys
|
||||
|
||||
- assert:
|
||||
that:
|
||||
- keys.nkeys == key_types | length - 1
|
||||
|
||||
- name: remove all keys
|
||||
ssh_agent:
|
||||
action: remove_all
|
||||
register: r
|
||||
|
||||
- assert:
|
||||
that:
|
||||
- r is success
|
||||
- r.nkeys_removed == key_types | length - 1
|
||||
|
||||
- ssh_agent:
|
||||
action: list
|
||||
register: keys
|
||||
|
||||
- assert:
|
||||
that:
|
||||
- keys.nkeys == 0
|
||||
23
test/integration/targets/ssh_agent/tasks/main.yml
Normal file
23
test/integration/targets/ssh_agent/tasks/main.yml
Normal file
@@ -0,0 +1,23 @@
|
||||
- delegate_to: localhost
|
||||
block:
|
||||
- name: install bcrypt
|
||||
pip:
|
||||
name: bcrypt
|
||||
register: bcrypt
|
||||
|
||||
- tempfile:
|
||||
path: "{{ lookup('env', 'OUTPUT_DIR') }}"
|
||||
state: directory
|
||||
register: tmpdir
|
||||
|
||||
- import_tasks: tests.yml
|
||||
always:
|
||||
- name: uninstall bcrypt
|
||||
pip:
|
||||
name: bcrypt
|
||||
state: absent
|
||||
when: bcrypt is changed
|
||||
|
||||
- file:
|
||||
path: tmpdir.path
|
||||
state: absent
|
||||
49
test/integration/targets/ssh_agent/tasks/tests.yml
Normal file
49
test/integration/targets/ssh_agent/tasks/tests.yml
Normal file
@@ -0,0 +1,49 @@
|
||||
- slurp:
|
||||
path: ~/.ssh/authorized_keys
|
||||
register: akeys
|
||||
|
||||
- debug:
|
||||
msg: '{{ akeys.content|b64decode }}'
|
||||
|
||||
- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml
|
||||
environment:
|
||||
ANSIBLE_CALLBACK_RESULT_FORMAT: yaml
|
||||
ANSIBLE_SSH_AGENT: auto
|
||||
register: auto
|
||||
|
||||
- command: ps {{ ps_flags }} -opid
|
||||
register: pids
|
||||
# Some distros will exit with rc=1 if no processes were returned
|
||||
vars:
|
||||
ps_flags: '{{ "" if ansible_distribution == "Alpine" else "-x" }}'
|
||||
|
||||
- assert:
|
||||
that:
|
||||
- >-
|
||||
'started and bound to' in auto.stdout
|
||||
- >-
|
||||
'SSH: SSH_AGENT adding' in auto.stdout
|
||||
- >-
|
||||
'exists in agent' in auto.stdout
|
||||
- pids|map('trim')|select('eq', pid) == []
|
||||
vars:
|
||||
pid: '{{ auto.stdout|regex_findall("ssh-agent\[(\d+)\]")|first }}'
|
||||
|
||||
- command: ssh-agent -D -s -a '{{ tmpdir.path }}/agent.sock'
|
||||
async: 30
|
||||
poll: 0
|
||||
|
||||
- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml
|
||||
environment:
|
||||
ANSIBLE_CALLBACK_RESULT_FORMAT: yaml
|
||||
ANSIBLE_SSH_AGENT: '{{ tmpdir.path }}/agent.sock'
|
||||
register: existing
|
||||
|
||||
- assert:
|
||||
that:
|
||||
- >-
|
||||
'started and bound to' not in existing.stdout
|
||||
- >-
|
||||
'SSH: SSH_AGENT adding' in existing.stdout
|
||||
- >-
|
||||
'exists in agent' in existing.stdout
|
||||
38
test/integration/targets/ssh_agent/test_key.yml
Normal file
38
test/integration/targets/ssh_agent/test_key.yml
Normal file
@@ -0,0 +1,38 @@
|
||||
- ssh_keygen:
|
||||
type: "{{ item }}"
|
||||
passphrase: passphrase
|
||||
register: sshkey
|
||||
|
||||
- slurp:
|
||||
path: ~/.ssh/authorized_keys
|
||||
register: akeys
|
||||
|
||||
- copy:
|
||||
content: |
|
||||
{{ sshkey.public_key }}
|
||||
{{ akeys.content|b64decode }}
|
||||
dest: ~/.ssh/authorized_keys
|
||||
mode: '0400'
|
||||
|
||||
- block:
|
||||
- ping:
|
||||
|
||||
- name: list keys from agent
|
||||
ssh_agent:
|
||||
action: list
|
||||
register: keys
|
||||
|
||||
- assert:
|
||||
that:
|
||||
- keys.nkeys == ansible_loop.index
|
||||
- keys['keys'][ansible_loop.index0].fingerprint == fingerprint
|
||||
|
||||
- name: key already exists in the agent
|
||||
ping:
|
||||
vars:
|
||||
ansible_password: ~
|
||||
ansible_ssh_password: ~
|
||||
ansible_ssh_private_key_file: ~
|
||||
ansible_ssh_private_key: '{{ sshkey.private_key }}'
|
||||
ansible_ssh_private_key_passphrase: passphrase
|
||||
fingerprint: '{{ sshkey.fingerprint }}'
|
||||
Reference in New Issue
Block a user