Add ssh-agent launching, and ssh-agent python client (#84754)

* Add ssh-agent launching, and ssh-agent python client

* Move things around, is this better??

* docs

* postpone creating dir after bin lookup

* fix method name

* changelog ssh agent

* address reviews

* fix typing

* do not redefine public_key

* typing

* more typing

* Catch OSError when starting ssh agent

* likely copy pasted old code

* var type fix

* why is this needed?

ci_complete

* ignoring the change for now

* write out pub key file atomically

* defensive timeout for the socket

* _populate_agent docstring

* do not allow setting these in config

* check expected length before slicing blobs

* test all key types

* remove lock/unlock functionality

* docstring

* private _ssh_agent

* .

* launch agent in cli and ansible_ssh_*

* additional info for ssh-agent comment

* Add tests for remove and remove_all

* comment on os.rename

* hopefully mitigate agent startup/delays problems

* exceptions

* unused import

* fix sanity

* perf

---------

Co-authored-by: Matt Martz <matt@sivel.net>
This commit is contained in:
Martin Krizek
2025-04-11 00:30:34 +02:00
committed by GitHub
parent fcdf0b80b3
commit 244c2f06ed
16 changed files with 1161 additions and 6 deletions

View File

@@ -0,0 +1,6 @@
minor_changes:
- ssh-agent - ``ansible``, ``ansible-playbook`` and ``ansible-console`` are capable of spawning or reusing an ssh-agent,
allowing plugins to interact with the ssh-agent.
Additionally a pure python ssh-agent client has been added, enabling easy interaction with the agent. The ssh connection plugin contains
new functionality via ``ansible_ssh_private_key`` and ``ansible_ssh_private_key_passphrase``, for loading an SSH private key into
the agent from a variable.

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import locale
import os
import signal
import sys
@@ -88,6 +89,7 @@ if jinja2_version < LooseVersion('3.1'):
'Current version: %s' % jinja2_version
)
import atexit
import errno
import getpass
import subprocess
@@ -111,10 +113,12 @@ from ansible.module_utils.six import string_types
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.file import is_executable
from ansible.module_utils.common.process import get_bin_path
from ansible.parsing.dataloader import DataLoader
from ansible.parsing.vault import PromptVaultSecret, get_file_vault_secret
from ansible.plugins.loader import add_all_plugin_dirs, init_plugin_loader
from ansible.release import __version__
from ansible.utils._ssh_agent import SshAgentClient
from ansible.utils.collection_loader import AnsibleCollectionConfig
from ansible.utils.collection_loader._collection_finder import _get_collection_name_from_path
from ansible.utils.path import unfrackpath
@@ -128,6 +132,77 @@ except ImportError:
HAS_ARGCOMPLETE = False
_SSH_AGENT_STDOUT_READ_TIMEOUT = 5 # seconds
def _ssh_agent_timeout_handler(signum, frame):
raise TimeoutError
def _launch_ssh_agent() -> None:
ssh_agent_cfg = C.config.get_config_value('SSH_AGENT')
match ssh_agent_cfg:
case 'none':
display.debug('SSH_AGENT set to none')
return
case 'auto':
try:
ssh_agent_bin = get_bin_path('ssh-agent', required=True)
except ValueError as e:
raise AnsibleError('SSH_AGENT set to auto, but cannot find ssh-agent binary') from e
ssh_agent_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'ssh_agent')
os.mkdir(ssh_agent_dir, 0o700)
sock = os.path.join(ssh_agent_dir, 'agent.sock')
display.vvv('SSH_AGENT: starting...')
try:
p = subprocess.Popen(
[ssh_agent_bin, '-D', '-s', '-a', sock],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except OSError as e:
raise AnsibleError(
f'Could not start ssh-agent: {e}'
) from e
if p.poll() is not None:
raise AnsibleError(
f'Could not start ssh-agent: rc={p.returncode} stderr="{p.stderr.read().decode()}"'
)
old_sigalrm_handler = signal.signal(signal.SIGALRM, _ssh_agent_timeout_handler)
signal.alarm(_SSH_AGENT_STDOUT_READ_TIMEOUT)
try:
stdout = p.stdout.read(13)
except TimeoutError:
stdout = b''
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, old_sigalrm_handler)
if stdout != b'SSH_AUTH_SOCK':
display.warning(
f'The first 13 characters of stdout did not match the '
f'expected SSH_AUTH_SOCK. This may not be the right binary, '
f'or an incompatible agent: {stdout.decode()}'
)
display.vvv(f'SSH_AGENT: ssh-agent[{p.pid}] started and bound to {sock}')
atexit.register(p.terminate)
case _:
sock = ssh_agent_cfg
try:
with SshAgentClient(sock) as client:
client.list()
except Exception as e:
raise AnsibleError(
f'Could not communicate with ssh-agent using auth sock {sock}: {e}'
) from e
os.environ['SSH_AUTH_SOCK'] = os.environ['ANSIBLE_SSH_AGENT'] = sock
class CLI(ABC):
""" code behind bin/ansible* programs """
@@ -137,6 +212,7 @@ class CLI(ABC):
# -S (chop long lines) -X (disable termcap init and de-init)
LESS_OPTS = 'FRSX'
SKIP_INVENTORY_DEFAULTS = False
USES_CONNECTION = False
def __init__(self, args, callback=None):
"""
@@ -528,8 +604,7 @@ class CLI(ABC):
except KeyboardInterrupt:
pass
@staticmethod
def _play_prereqs():
def _play_prereqs(self):
# TODO: evaluate moving all of the code that touches ``AnsibleCollectionConfig``
# into ``init_plugin_loader`` so that we can specifically remove
# ``AnsibleCollectionConfig.playbook_paths`` to make it immutable after instantiation
@@ -560,6 +635,12 @@ class CLI(ABC):
auto_prompt=False)
loader.set_vault_secrets(vault_secrets)
if self.USES_CONNECTION:
try:
_launch_ssh_agent()
except Exception as e:
raise AnsibleError('Failed to launch ssh agent', orig_exc=e)
# create the inventory, and filter it based on the subset specified (if any)
inventory = InventoryManager(loader=loader, sources=options['inventory'], cache=(not options.get('flush_cache')))

View File

@@ -30,6 +30,8 @@ class AdHocCLI(CLI):
name = 'ansible'
USES_CONNECTION = True
def init_parser(self):
""" create an options parser for bin/ansible """
super(AdHocCLI, self).init_parser(usage='%prog <host-pattern> [options]',

View File

@@ -72,6 +72,8 @@ class ConsoleCLI(CLI, cmd.Cmd):
# use specific to console, but fallback to highlight for backwards compatibility
NORMAL_PROMPT = C.COLOR_CONSOLE_PROMPT or C.COLOR_HIGHLIGHT
USES_CONNECTION = True
def __init__(self, args):
super(ConsoleCLI, self).__init__(args)

View File

@@ -34,6 +34,8 @@ class PlaybookCLI(CLI):
name = 'ansible-playbook'
USES_CONNECTION = True
def init_parser(self):
# create parser for CLI options

View File

@@ -1888,6 +1888,24 @@ SHOW_CUSTOM_STATS:
ini:
- {key: show_custom_stats, section: defaults}
type: bool
SSH_AGENT:
name: Manage an SSH Agent
description: Manage an SSH Agent via Ansible. A configuration of ``none`` will not interact with an agent,
``auto`` will start and destroy an agent via ``ssh-agent`` binary during the run, and a path
to an SSH_AUTH_SOCK will allow interaction with a pre-existing agent.
default: none
type: string
env: [{name: ANSIBLE_SSH_AGENT}]
ini: [{key: ssh_agent, section: connection}]
version_added: '2.19'
SSH_AGENT_KEY_LIFETIME:
name: Set a maximum lifetime when adding identities to an agent
description: For keys inserted into an agent defined by ``SSH_AGENT``, define a lifetime, in seconds, that the key may remain
in the agent.
type: int
env: [{name: ANSIBLE_SSH_AGENT_KEY_LIFETIME}]
ini: [{key: ssh_agent_key_lifetime, section: connection}]
version_added: '2.19'
STRING_TYPE_FILTERS:
name: Filters to preserve strings
default: [string, to_json, to_nice_json, to_yaml, to_nice_yaml, ppretty, json]

View File

@@ -265,7 +265,6 @@ DOCUMENTATION = """
vars:
- name: ansible_pipelining
- name: ansible_ssh_pipelining
private_key_file:
description:
- Path to private key file to use for authentication.
@@ -281,7 +280,27 @@ DOCUMENTATION = """
cli:
- name: private_key_file
option: '--private-key'
private_key:
description:
- Private key contents in PEM format. Requires the C(SSH_AGENT) configuration to be enabled.
type: string
env:
- name: ANSIBLE_PRIVATE_KEY
vars:
- name: ansible_private_key
- name: ansible_ssh_private_key
version_added: '2.19'
private_key_passphrase:
description:
- Private key passphrase, dependent on O(private_key).
- This does NOT have any effect when used with O(private_key_file).
type: string
env:
- name: ANSIBLE_PRIVATE_KEY_PASSPHRASE
vars:
- name: ansible_private_key_passphrase
- name: ansible_ssh_private_key_passphrase
version_added: '2.19'
control_path:
description:
- This is the location to save SSH's ControlPath sockets, it uses SSH's variable substitution.
@@ -398,11 +417,13 @@ import shlex
import shutil
import subprocess
import sys
import tempfile
import time
import typing as t
from functools import wraps
from multiprocessing.shared_memory import SharedMemory
from ansible import constants as C
from ansible.errors import (
AnsibleAuthenticationFailure,
AnsibleConnectionFailure,
@@ -415,6 +436,15 @@ from ansible.plugins.connection import ConnectionBase, BUFSIZE
from ansible.plugins.shell.powershell import _replace_stderr_clixml
from ansible.utils.display import Display
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.utils._ssh_agent import SshAgentClient, _key_data_into_crypto_objects
try:
from cryptography.hazmat.primitives import serialization
except ImportError:
HAS_CRYPTOGRAPHY = False
else:
HAS_CRYPTOGRAPHY = True
display = Display()
@@ -638,6 +668,8 @@ class Connection(ConnectionBase):
self._tty_parser.add_argument('-t', action='count')
self._tty_parser.add_argument('-o', action='append')
self._populated_agent: pathlib.Path | None = None
# The connection is created by running ssh/scp/sftp from the exec_command,
# put_file, and fetch_file methods, so we don't need to do any connection
# management here.
@@ -712,6 +744,52 @@ class Connection(ConnectionBase):
display.vvvvv(u'SSH: %s: (%s)' % (explanation, ')('.join(to_text(a) for a in b_args)), host=self.host)
b_command += b_args
def _populate_agent(self) -> pathlib.Path:
"""Adds configured private key identity to the SSH agent. Returns a path to a file containing the public key."""
if self._populated_agent:
return self._populated_agent
if (auth_sock := C.config.get_config_value('SSH_AGENT')) == 'none':
raise AnsibleError('Cannot utilize private_key with SSH_AGENT disabled')
key_data = self.get_option('private_key')
passphrase = self.get_option('private_key_passphrase')
private_key, public_key, fingerprint = _key_data_into_crypto_objects(
to_bytes(key_data),
to_bytes(passphrase) if passphrase else None,
)
with SshAgentClient(auth_sock) as client:
if public_key not in client:
display.vvv(f'SSH: SSH_AGENT adding {fingerprint} to agent', host=self.host)
client.add(
private_key,
f'[added by ansible: PID={os.getpid()}, UID={os.getuid()}, EUID={os.geteuid()}, TIME={time.time()}]',
C.config.get_config_value('SSH_AGENT_KEY_LIFETIME'),
)
else:
display.vvv(f'SSH: SSH_AGENT {fingerprint} exists in agent', host=self.host)
# Write the public key to disk, to be provided as IdentityFile.
# This allows ssh to pick an explicit key in the agent to use,
# preventing ssh from attempting all keys in the agent.
pubkey_path = self._populated_agent = pathlib.Path(C.DEFAULT_LOCAL_TMP).joinpath(
fingerprint.replace('/', '-') + '.pub'
)
if os.path.exists(pubkey_path):
return pubkey_path
with tempfile.NamedTemporaryFile(dir=C.DEFAULT_LOCAL_TMP, delete=False) as f:
f.write(public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
))
# move atomically to prevent race conditions, silently succeeds if the target exists
os.rename(f.name, pubkey_path)
os.chmod(pubkey_path, mode=0o400)
return self._populated_agent
def _build_command(self, binary: str, subsystem: str, *other_args: bytes | str) -> list[bytes]:
"""
Takes a executable (ssh, scp, sftp or wrapper) and optional extra arguments and returns the remote command
@@ -797,8 +875,17 @@ class Connection(ConnectionBase):
b_args = (b"-o", b"Port=" + to_bytes(self.port, nonstring='simplerepr', errors='surrogate_or_strict'))
self._add_args(b_command, b_args, u"ANSIBLE_REMOTE_PORT/remote_port/ansible_port set")
key = self.get_option('private_key_file')
if key:
if self.get_option('private_key'):
try:
key = self._populate_agent()
except Exception as e:
raise AnsibleAuthenticationFailure(
'Failed to add configured private key into ssh-agent',
orig_exc=e,
)
b_args = (b'-o', b'IdentitiesOnly=yes', b'-o', to_bytes(f'IdentityFile="{key}"', errors='surrogate_or_strict'))
self._add_args(b_command, b_args, "ANSIBLE_PRIVATE_KEY/private_key set")
elif key := self.get_option('private_key_file'):
b_args = (b"-o", b'IdentityFile="' + to_bytes(os.path.expanduser(key), errors='surrogate_or_strict') + b'"')
self._add_args(b_command, b_args, u"ANSIBLE_PRIVATE_KEY_FILE/private_key_file/ansible_ssh_private_key_file set")

View File

@@ -0,0 +1,657 @@
# Copyright: Contributors to the Ansible project
# BSD 3 Clause License (see licenses/BSD-3-Clause.txt or https://opensource.org/license/bsd-3-clause/)
from __future__ import annotations
import binascii
import copy
import dataclasses
import enum
import functools
import hashlib
import socket
import types
import typing as t
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.dsa import (
DSAParameterNumbers,
DSAPrivateKey,
DSAPublicKey,
DSAPublicNumbers,
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurve,
EllipticCurvePrivateKey,
EllipticCurvePublicKey,
SECP256R1,
SECP384R1,
SECP521R1,
)
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
Ed25519PublicKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey,
RSAPublicKey,
RSAPublicNumbers,
)
except ImportError:
HAS_CRYPTOGRAPHY = False
else:
HAS_CRYPTOGRAPHY = True
CryptoPublicKey = t.Union[
DSAPublicKey,
EllipticCurvePublicKey,
Ed25519PublicKey,
RSAPublicKey,
]
CryptoPrivateKey = t.Union[
DSAPrivateKey,
EllipticCurvePrivateKey,
Ed25519PrivateKey,
RSAPrivateKey,
]
if t.TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPrivateNumbers
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateNumbers
_SSH_AGENT_CLIENT_SOCKET_TIMEOUT = 10
class ProtocolMsgNumbers(enum.IntEnum):
# Responses
SSH_AGENT_FAILURE = 5
SSH_AGENT_SUCCESS = 6
SSH_AGENT_IDENTITIES_ANSWER = 12
SSH_AGENT_SIGN_RESPONSE = 14
SSH_AGENT_EXTENSION_FAILURE = 28
SSH_AGENT_EXTENSION_RESPONSE = 29
# Constraints
SSH_AGENT_CONSTRAIN_LIFETIME = 1
SSH_AGENT_CONSTRAIN_CONFIRM = 2
SSH_AGENT_CONSTRAIN_EXTENSION = 255
# Requests
SSH_AGENTC_REQUEST_IDENTITIES = 11
SSH_AGENTC_SIGN_REQUEST = 13
SSH_AGENTC_ADD_IDENTITY = 17
SSH_AGENTC_REMOVE_IDENTITY = 18
SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19
SSH_AGENTC_ADD_SMARTCARD_KEY = 20
SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21
SSH_AGENTC_LOCK = 22
SSH_AGENTC_UNLOCK = 23
SSH_AGENTC_ADD_ID_CONSTRAINED = 25
SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26
SSH_AGENTC_EXTENSION = 27
def to_blob(self) -> bytes:
return bytes([self])
class SshAgentFailure(RuntimeError):
"""Server failure or unexpected response."""
# NOTE: Classes below somewhat represent "Data Type Representations Used in the SSH Protocols"
# as specified by RFC4251
@t.runtime_checkable
class SupportsToBlob(t.Protocol):
def to_blob(self) -> bytes:
...
@t.runtime_checkable
class SupportsFromBlob(t.Protocol):
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
...
@classmethod
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
...
def _split_blob(blob: memoryview | bytes, length: int) -> tuple[memoryview | bytes, memoryview | bytes]:
if len(blob) < length:
raise ValueError("_split_blob: unexpected data length")
return blob[:length], blob[length:]
class VariableSized:
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
raise NotImplementedError
@classmethod
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
length = uint32.from_blob(blob[:4])
blob = blob[4:]
data, rest = _split_blob(blob, length)
return cls.from_blob(data), rest
class uint32(int):
def to_blob(self) -> bytes:
return self.to_bytes(length=4, byteorder='big')
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
return cls.from_bytes(blob, byteorder='big')
@classmethod
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
length = uint32(4)
data, rest = _split_blob(blob, length)
return cls.from_blob(data), rest
class mpint(int, VariableSized):
def to_blob(self) -> bytes:
if self < 0:
raise ValueError("negative mpint not allowed")
if not self:
return b""
nbytes = (self.bit_length() + 8) // 8
ret = bytearray(self.to_bytes(length=nbytes, byteorder='big'))
ret[:0] = uint32(len(ret)).to_blob()
return ret
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
if blob and blob[0] > 127:
raise ValueError("Invalid data")
return cls.from_bytes(blob, byteorder='big')
class constraints(bytes):
def to_blob(self) -> bytes:
return self
class binary_string(bytes, VariableSized):
def to_blob(self) -> bytes:
if length := len(self):
return uint32(length).to_blob() + self
else:
return b""
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
return cls(blob)
class unicode_string(str, VariableSized):
def to_blob(self) -> bytes:
val = self.encode('utf-8')
if length := len(val):
return uint32(length).to_blob() + val
else:
return b""
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
return cls(bytes(blob).decode('utf-8'))
class KeyAlgo(str, VariableSized, enum.Enum):
RSA = "ssh-rsa"
DSA = "ssh-dss"
ECDSA256 = "ecdsa-sha2-nistp256"
SKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
ECDSA384 = "ecdsa-sha2-nistp384"
ECDSA521 = "ecdsa-sha2-nistp521"
ED25519 = "ssh-ed25519"
SKED25519 = "sk-ssh-ed25519@openssh.com"
RSASHA256 = "rsa-sha2-256"
RSASHA512 = "rsa-sha2-512"
@property
def main_type(self) -> str:
match self:
case self.RSA:
return 'RSA'
case self.DSA:
return 'DSA'
case self.ECDSA256 | self.ECDSA384 | self.ECDSA521:
return 'ECDSA'
case self.ED25519:
return 'ED25519'
case _:
raise NotImplementedError(self.name)
def to_blob(self) -> bytes:
b_self = self.encode('utf-8')
return uint32(len(b_self)).to_blob() + b_self
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
return cls(bytes(blob).decode('utf-8'))
if HAS_CRYPTOGRAPHY:
_ECDSA_KEY_TYPE: dict[KeyAlgo, type[EllipticCurve]] = {
KeyAlgo.ECDSA256: SECP256R1,
KeyAlgo.ECDSA384: SECP384R1,
KeyAlgo.ECDSA521: SECP521R1,
}
@dataclasses.dataclass
class Msg:
def to_blob(self) -> bytes:
rv = bytearray()
for field in dataclasses.fields(self):
fv = getattr(self, field.name)
if isinstance(fv, SupportsToBlob):
rv.extend(fv.to_blob())
else:
raise NotImplementedError(field.type)
return rv
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
args: list[t.Any] = []
for _field_name, field_type in t.get_type_hints(cls).items():
if isinstance(field_type, SupportsFromBlob):
fv, blob = field_type.consume_from_blob(blob)
args.append(fv)
else:
raise NotImplementedError(str(field_type))
return cls(*args)
@dataclasses.dataclass
class PrivateKeyMsg(Msg):
@staticmethod
def from_private_key(private_key: CryptoPrivateKey) -> PrivateKeyMsg:
match private_key:
case RSAPrivateKey():
rsa_pn: RSAPrivateNumbers = private_key.private_numbers()
return RSAPrivateKeyMsg(
KeyAlgo.RSA,
mpint(rsa_pn.public_numbers.n),
mpint(rsa_pn.public_numbers.e),
mpint(rsa_pn.d),
mpint(rsa_pn.iqmp),
mpint(rsa_pn.p),
mpint(rsa_pn.q),
)
case DSAPrivateKey():
dsa_pn: DSAPrivateNumbers = private_key.private_numbers()
return DSAPrivateKeyMsg(
KeyAlgo.DSA,
mpint(dsa_pn.public_numbers.parameter_numbers.p),
mpint(dsa_pn.public_numbers.parameter_numbers.q),
mpint(dsa_pn.public_numbers.parameter_numbers.g),
mpint(dsa_pn.public_numbers.y),
mpint(dsa_pn.x),
)
case EllipticCurvePrivateKey():
ecdsa_pn: EllipticCurvePrivateNumbers = private_key.private_numbers()
key_size = private_key.key_size
return EcdsaPrivateKeyMsg(
getattr(KeyAlgo, f'ECDSA{key_size}'),
unicode_string(f'nistp{key_size}'),
binary_string(private_key.public_key().public_bytes(
encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint
)),
mpint(ecdsa_pn.private_value),
)
case Ed25519PrivateKey():
public_bytes = private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption()
)
return Ed25519PrivateKeyMsg(
KeyAlgo.ED25519,
binary_string(public_bytes),
binary_string(private_bytes + public_bytes),
)
case _:
raise NotImplementedError(private_key)
@dataclasses.dataclass(order=True, slots=True)
class RSAPrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
n: mpint
e: mpint
d: mpint
iqmp: mpint
p: mpint
q: mpint
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass(order=True, slots=True)
class DSAPrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
p: mpint
q: mpint
g: mpint
y: mpint
x: mpint
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass(order=True, slots=True)
class EcdsaPrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
ecdsa_curve_name: unicode_string
Q: binary_string
d: mpint
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass(order=True, slots=True)
class Ed25519PrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
enc_a: binary_string
k_env_a: binary_string
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass
class PublicKeyMsg(Msg):
@staticmethod
def get_dataclass(
type: KeyAlgo
) -> type[t.Union[
RSAPublicKeyMsg,
EcdsaPublicKeyMsg,
Ed25519PublicKeyMsg,
DSAPublicKeyMsg
]]:
match type:
case KeyAlgo.RSA:
return RSAPublicKeyMsg
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
return EcdsaPublicKeyMsg
case KeyAlgo.ED25519:
return Ed25519PublicKeyMsg
case KeyAlgo.DSA:
return DSAPublicKeyMsg
case _:
raise NotImplementedError(type)
@functools.cached_property
def public_key(self) -> CryptoPublicKey:
type: KeyAlgo = self.type
match type:
case KeyAlgo.RSA:
return RSAPublicNumbers(
self.e,
self.n
).public_key()
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
curve = _ECDSA_KEY_TYPE[KeyAlgo(type)]
return EllipticCurvePublicKey.from_encoded_point(
curve(),
self.Q
)
case KeyAlgo.ED25519:
return Ed25519PublicKey.from_public_bytes(
self.enc_a
)
case KeyAlgo.DSA:
return DSAPublicNumbers(
self.y,
DSAParameterNumbers(
self.p,
self.q,
self.g
)
).public_key()
case _:
raise NotImplementedError(type)
@staticmethod
def from_public_key(public_key: CryptoPublicKey) -> PublicKeyMsg:
match public_key:
case DSAPublicKey():
dsa_pn: DSAPublicNumbers = public_key.public_numbers()
return DSAPublicKeyMsg(
KeyAlgo.DSA,
mpint(dsa_pn.parameter_numbers.p),
mpint(dsa_pn.parameter_numbers.q),
mpint(dsa_pn.parameter_numbers.g),
mpint(dsa_pn.y)
)
case EllipticCurvePublicKey():
return EcdsaPublicKeyMsg(
getattr(KeyAlgo, f'ECDSA{public_key.curve.key_size}'),
unicode_string(f'nistp{public_key.curve.key_size}'),
binary_string(public_key.public_bytes(
encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint
))
)
case Ed25519PublicKey():
return Ed25519PublicKeyMsg(
KeyAlgo.ED25519,
binary_string(public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
))
)
case RSAPublicKey():
rsa_pn: RSAPublicNumbers = public_key.public_numbers()
return RSAPublicKeyMsg(
KeyAlgo.RSA,
mpint(rsa_pn.e),
mpint(rsa_pn.n)
)
case _:
raise NotImplementedError(public_key)
@functools.cached_property
def fingerprint(self) -> str:
digest = hashlib.sha256()
msg = copy.copy(self)
msg.comments = unicode_string('')
k = msg.to_blob()
digest.update(k)
return binascii.b2a_base64(
digest.digest(),
newline=False
).rstrip(b'=').decode('utf-8')
@dataclasses.dataclass(order=True, slots=True)
class RSAPublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
e: mpint
n: mpint
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
@dataclasses.dataclass(order=True, slots=True)
class DSAPublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
p: mpint
q: mpint
g: mpint
y: mpint
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
@dataclasses.dataclass(order=True, slots=True)
class EcdsaPublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
ecdsa_curve_name: unicode_string
Q: binary_string
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
@dataclasses.dataclass(order=True, slots=True)
class Ed25519PublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
enc_a: binary_string
comments: unicode_string = dataclasses.field(default=unicode_string(''), compare=False)
@dataclasses.dataclass(order=True, slots=True)
class KeyList(Msg):
nkeys: uint32
keys: PublicKeyMsgList
def __post_init__(self) -> None:
if self.nkeys != len(self.keys):
raise SshAgentFailure(
"agent: invalid number of keys received for identities list"
)
@dataclasses.dataclass(order=True, slots=True)
class PublicKeyMsgList(Msg):
keys: list[PublicKeyMsg]
def __iter__(self) -> t.Iterator[PublicKeyMsg]:
yield from self.keys
def __len__(self) -> int:
return len(self.keys)
@classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self:
...
@classmethod
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
args: list[PublicKeyMsg] = []
while blob:
prev_blob = blob
key_blob, key_blob_length, comment_blob = cls._consume_field(blob)
peek_key_algo, _length, _blob = cls._consume_field(key_blob)
pub_key_msg_cls = PublicKeyMsg.get_dataclass(
KeyAlgo(bytes(peek_key_algo).decode('utf-8'))
)
_fv, comment_blob_length, blob = cls._consume_field(comment_blob)
key_plus_comment = (
prev_blob[4: (4 + key_blob_length) + (4 + comment_blob_length)]
)
args.append(pub_key_msg_cls.from_blob(key_plus_comment))
return cls(args), b""
@staticmethod
def _consume_field(
blob: memoryview | bytes
) -> tuple[memoryview | bytes, uint32, memoryview | bytes]:
length = uint32.from_blob(blob[:4])
blob = blob[4:]
data, rest = _split_blob(blob, length)
return data, length, rest
class SshAgentClient:
def __init__(self, auth_sock: str) -> None:
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._sock.settimeout(_SSH_AGENT_CLIENT_SOCKET_TIMEOUT)
self._sock.connect(auth_sock)
def close(self) -> None:
self._sock.close()
def __enter__(self) -> t.Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None
) -> None:
self.close()
def send(self, msg: bytes) -> bytes:
length = uint32(len(msg)).to_blob()
self._sock.sendall(length + msg)
bufsize = uint32.from_blob(self._sock.recv(4))
resp = self._sock.recv(bufsize)
if resp[0] == ProtocolMsgNumbers.SSH_AGENT_FAILURE:
raise SshAgentFailure('agent: failure')
return resp
def remove_all(self) -> None:
self.send(
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_ALL_IDENTITIES.to_blob()
)
def remove(self, public_key: CryptoPublicKey) -> None:
key_blob = PublicKeyMsg.from_public_key(public_key).to_blob()
self.send(
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_IDENTITY.to_blob() +
uint32(len(key_blob)).to_blob() + key_blob
)
def add(
self,
private_key: CryptoPrivateKey,
comments: str | None = None,
lifetime: int | None = None,
confirm: bool | None = None,
) -> None:
key_msg = PrivateKeyMsg.from_private_key(private_key)
key_msg.comments = unicode_string(comments or '')
if lifetime:
key_msg.constraints += constraints(
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_LIFETIME]
).to_blob() + uint32(lifetime).to_blob()
if confirm:
key_msg.constraints += constraints(
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_CONFIRM]
).to_blob()
if key_msg.constraints:
msg = ProtocolMsgNumbers.SSH_AGENTC_ADD_ID_CONSTRAINED.to_blob()
else:
msg = ProtocolMsgNumbers.SSH_AGENTC_ADD_IDENTITY.to_blob()
msg += key_msg.to_blob()
self.send(msg)
def list(self) -> KeyList:
req = ProtocolMsgNumbers.SSH_AGENTC_REQUEST_IDENTITIES.to_blob()
r = memoryview(bytearray(self.send(req)))
if r[0] != ProtocolMsgNumbers.SSH_AGENT_IDENTITIES_ANSWER:
raise SshAgentFailure(
'agent: non-identities answer received for identities list'
)
return KeyList.from_blob(r[1:])
def __contains__(self, public_key: CryptoPublicKey) -> bool:
msg = PublicKeyMsg.from_public_key(public_key)
return msg in self.list().keys
@functools.cache
def _key_data_into_crypto_objects(key_data: bytes, passphrase: bytes | None) -> tuple[CryptoPrivateKey, CryptoPublicKey, str]:
private_key = serialization.ssh.load_ssh_private_key(key_data, passphrase)
public_key = private_key.public_key()
fingerprint = PublicKeyMsg.from_public_key(public_key).fingerprint
return private_key, public_key, fingerprint

28
licenses/BSD-3-Clause.txt Normal file
View File

@@ -0,0 +1,28 @@
Copyright (c) Contributors to the Ansible project. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
SUCH DAMAGE.

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
import os
from ansible.plugins.action import ActionBase
from ansible.utils._ssh_agent import SshAgentClient
from cryptography.hazmat.primitives.serialization import ssh
class ActionModule(ActionBase):
def run(self, tmp=None, task_vars=None):
results = super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
match self._task.args['action']:
case 'list':
return self.list()
case 'remove':
return self.remove(self._task.args['pubkey'])
case 'remove_all':
return self.remove_all()
case _:
return {'failed': True, 'msg': 'not implemented'}
def remove(self, pubkey_data):
with SshAgentClient(os.environ['SSH_AUTH_SOCK']) as client:
public_key = ssh.load_ssh_public_key(pubkey_data.encode())
client.remove(public_key)
return {'failed': public_key in client}
def remove_all(self):
with SshAgentClient(os.environ['SSH_AUTH_SOCK']) as client:
nkeys_before = client.list().nkeys
client.remove_all()
nkeys_after = client.list().nkeys
return {
'failed': nkeys_after != 0,
'nkeys_removed': nkeys_before,
}
def list(self):
result = {'keys': [], 'nkeys': 0}
with SshAgentClient(os.environ['SSH_AUTH_SOCK']) as client:
key_list = client.list()
result['nkeys'] = key_list.nkeys
for key in key_list.keys:
public_key = key.public_key
key_size = getattr(public_key, 'key_size', 256)
fingerprint = key.fingerprint
key_type = key.type.main_type
result['keys'].append({
'type': key_type,
'key_size': key_size,
'fingerprint': f'SHA256:{fingerprint}',
'comments': key.comments,
})
return result

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
from ansible.plugins.action import ActionBase
from ansible.utils._ssh_agent import PublicKeyMsg
from ansible.module_utils.common.text.converters import to_bytes, to_text
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key as rsa_generate_private_key
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.asymmetric.dsa import generate_private_key as dsa_generate_private_key
from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1, generate_private_key as ecdsa_generate_private_key
class ActionModule(ActionBase):
def run(self, tmp=None, task_vars=None):
results = super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
match self._task.args.get('type'):
case 'ed25519':
private_key = Ed25519PrivateKey.generate()
case 'rsa':
private_key = rsa_generate_private_key(65537, 4096)
case 'dsa':
private_key = dsa_generate_private_key(1024)
case 'ecdsa':
private_key = ecdsa_generate_private_key(SECP384R1())
case _:
return {'failed': True, 'msg': 'not implemented'}
public_key = private_key.public_key()
public_key_msg = PublicKeyMsg.from_public_key(public_key)
if not (passphrase := self._task.args.get('passphrase')):
encryption_algorithm = serialization.NoEncryption()
else:
encryption_algorithm = serialization.BestAvailableEncryption(
to_bytes(passphrase)
)
return {
'changed': True,
'private_key': to_text(private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=encryption_algorithm,
)),
'public_key': to_text(public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)),
'fingerprint': f'SHA256:{public_key_msg.fingerprint}',
}

View File

@@ -0,0 +1,3 @@
needs/ssh
shippable/posix/group2
context/target

View File

@@ -0,0 +1,46 @@
- hosts: testhost
tasks:
- set_fact:
key_types:
- ed25519
- rsa
- ecdsa
- set_fact:
key_types: "{{ key_types + ['dsa'] }}"
when: ansible_distribution == "RedHat"
- include_tasks: test_key.yml
loop: "{{ key_types }}"
loop_control:
extended: true
- ssh_agent:
action: remove
pubkey: "{{ sshkey.public_key }}"
- ssh_agent:
action: list
register: keys
- assert:
that:
- keys.nkeys == key_types | length - 1
- name: remove all keys
ssh_agent:
action: remove_all
register: r
- assert:
that:
- r is success
- r.nkeys_removed == key_types | length - 1
- ssh_agent:
action: list
register: keys
- assert:
that:
- keys.nkeys == 0

View File

@@ -0,0 +1,23 @@
- delegate_to: localhost
block:
- name: install bcrypt
pip:
name: bcrypt
register: bcrypt
- tempfile:
path: "{{ lookup('env', 'OUTPUT_DIR') }}"
state: directory
register: tmpdir
- import_tasks: tests.yml
always:
- name: uninstall bcrypt
pip:
name: bcrypt
state: absent
when: bcrypt is changed
- file:
path: tmpdir.path
state: absent

View File

@@ -0,0 +1,49 @@
- slurp:
path: ~/.ssh/authorized_keys
register: akeys
- debug:
msg: '{{ akeys.content|b64decode }}'
- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml
environment:
ANSIBLE_CALLBACK_RESULT_FORMAT: yaml
ANSIBLE_SSH_AGENT: auto
register: auto
- command: ps {{ ps_flags }} -opid
register: pids
# Some distros will exit with rc=1 if no processes were returned
vars:
ps_flags: '{{ "" if ansible_distribution == "Alpine" else "-x" }}'
- assert:
that:
- >-
'started and bound to' in auto.stdout
- >-
'SSH: SSH_AGENT adding' in auto.stdout
- >-
'exists in agent' in auto.stdout
- pids|map('trim')|select('eq', pid) == []
vars:
pid: '{{ auto.stdout|regex_findall("ssh-agent\[(\d+)\]")|first }}'
- command: ssh-agent -D -s -a '{{ tmpdir.path }}/agent.sock'
async: 30
poll: 0
- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml
environment:
ANSIBLE_CALLBACK_RESULT_FORMAT: yaml
ANSIBLE_SSH_AGENT: '{{ tmpdir.path }}/agent.sock'
register: existing
- assert:
that:
- >-
'started and bound to' not in existing.stdout
- >-
'SSH: SSH_AGENT adding' in existing.stdout
- >-
'exists in agent' in existing.stdout

View File

@@ -0,0 +1,38 @@
- ssh_keygen:
type: "{{ item }}"
passphrase: passphrase
register: sshkey
- slurp:
path: ~/.ssh/authorized_keys
register: akeys
- copy:
content: |
{{ sshkey.public_key }}
{{ akeys.content|b64decode }}
dest: ~/.ssh/authorized_keys
mode: '0400'
- block:
- ping:
- name: list keys from agent
ssh_agent:
action: list
register: keys
- assert:
that:
- keys.nkeys == ansible_loop.index
- keys['keys'][ansible_loop.index0].fingerprint == fingerprint
- name: key already exists in the agent
ping:
vars:
ansible_password: ~
ansible_ssh_password: ~
ansible_ssh_private_key_file: ~
ansible_ssh_private_key: '{{ sshkey.private_key }}'
ansible_ssh_private_key_passphrase: passphrase
fingerprint: '{{ sshkey.fingerprint }}'