From 4a12bfa0b70ec3c6b8840ec56a8f8a7fa61f9f93 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 15 Feb 2018 15:10:34 +0200 Subject: [PATCH] Allow SSH agent to daemonize when invoked with `-d` flag This change adds the support for "eval `trezor-agent -d`" invocation. --- libagent/ssh/__init__.py | 44 +++++++++++++++++++++++++++++++--------- setup.py | 1 + 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/libagent/ssh/__init__.py b/libagent/ssh/__init__.py index 2411307..462af4d 100644 --- a/libagent/ssh/__init__.py +++ b/libagent/ssh/__init__.py @@ -5,6 +5,7 @@ import io import logging import os import re +import signal import subprocess import sys import tempfile @@ -12,6 +13,7 @@ import threading import pkg_resources import configargparse +import daemon from .. import device, formats, server, util from . import client, protocol @@ -80,6 +82,8 @@ def create_agent_parser(device_type): help='log SSH protocol messages for debugging.') g = p.add_mutually_exclusive_group() + g.add_argument('-d', '--daemonize', default=False, action='store_true', + help='Daemonize the agent and print its UNIX socket path') g.add_argument('-s', '--shell', default=False, action='store_true', help=('run ${SHELL} as subprocess under SSH agent, allowing ' 'regular SSH-based tools to be used in the shell')) @@ -96,7 +100,7 @@ def create_agent_parser(device_type): @contextlib.contextmanager -def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): +def serve(handler, sock_path, timeout=UNIX_SOCKET_TIMEOUT): """ Start the ssh-agent server on a UNIX-domain socket. @@ -106,9 +110,6 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): ssh_version = subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT) log.debug('local SSH version: %r', ssh_version) - if sock_path is None: - sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-') - environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} device_mutex = threading.Lock() with server.unix_domain_socket_server(sock_path) as sock: @@ -128,12 +129,17 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): quit_event.set() -def run_server(conn, command, debug, timeout): +def run_server(conn, command, sock_path, debug, timeout): """Common code for run_agent and run_git below.""" try: handler = protocol.Handler(conn=conn, debug=debug) - with serve(handler=handler, timeout=timeout) as env: - return server.run_process(command=command, environ=env) + with serve(handler=handler, sock_path=sock_path, + timeout=timeout) as env: + if command is None: + signal.pause() # wait for signal + return 0 + else: + return server.run_process(command=command, environ=env) except KeyboardInterrupt: log.info('server stopped') @@ -195,6 +201,11 @@ class JustInTimeConnection(object): return conn.sign_ssh_challenge(blob=blob, identity=identity) +@contextlib.contextmanager +def _dummy_context(): + yield + + @handle_connection_error def main(device_type): """Run ssh-agent using given hardware client factory.""" @@ -216,10 +227,21 @@ def main(device_type): identity.identity_dict['proto'] = u'ssh' log.info('identity #%d: %s', index, identity.to_string()) + sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-') + + command = None + context = _dummy_context() if args.connect: command = ['ssh'] + ssh_args(args.identity) + args.command elif args.mosh: command = ['mosh'] + mosh_args(args.identity) + args.command + elif args.daemonize: + msg = ('SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n' + 'SSH_AGENT_PID={1}; export SSH_AGENT_PID;\n' + 'echo Agent pid {1};\n'.format(sock_path, os.getpid())) + sys.stdout.write(msg) + sys.stdout.flush() + context = daemon.DaemonContext() else: command = args.command @@ -231,9 +253,11 @@ def main(device_type): conn = JustInTimeConnection( conn_factory=lambda: client.Client(device_type()), identities=identities, public_keys=public_keys) - if command: - return run_server(conn=conn, command=command, debug=args.debug, - timeout=args.timeout) + + if command or args.daemonize: + with context: + return run_server(conn=conn, command=command, sock_path=sock_path, + debug=args.debug, timeout=args.timeout) else: for pk in conn.public_keys(): sys.stdout.write(pk) diff --git a/setup.py b/setup.py index b53b45e..e2161a9 100755 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ setup( install_requires=[ 'backports.shutil_which>=3.5.1', 'ConfigArgParse>=0.12.0', + 'python-daemon>=2.1.2', 'ecdsa>=0.13', 'ed25519>=1.4', 'pymsgbox>=1.0.6',