|
|
|
@ -1,3 +1,4 @@
|
|
|
|
|
"""UNIX-domain socket server for ssh-agent implementation."""
|
|
|
|
|
import contextlib
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
@ -14,6 +15,7 @@ UNIX_SOCKET_TIMEOUT = 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_file(path, remove=os.remove, exists=os.path.exists):
|
|
|
|
|
"""Remove file, and raise OSError if still exists."""
|
|
|
|
|
try:
|
|
|
|
|
remove(path)
|
|
|
|
|
except OSError:
|
|
|
|
@ -23,6 +25,11 @@ def remove_file(path, remove=os.remove, exists=os.path.exists):
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def unix_domain_socket_server(sock_path):
|
|
|
|
|
"""
|
|
|
|
|
Create UNIX-domain socket on specified path.
|
|
|
|
|
|
|
|
|
|
Listen on it, and delete it after the generated context is over.
|
|
|
|
|
"""
|
|
|
|
|
log.debug('serving on SSH_AUTH_SOCK=%s', sock_path)
|
|
|
|
|
remove_file(sock_path)
|
|
|
|
|
|
|
|
|
@ -36,6 +43,12 @@ def unix_domain_socket_server(sock_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_connection(conn, handler):
|
|
|
|
|
"""
|
|
|
|
|
Handle a single connection using the specified protocol handler in a loop.
|
|
|
|
|
|
|
|
|
|
Exit when EOFError is raised.
|
|
|
|
|
All other exceptions are logged as warnings.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
log.debug('welcome agent')
|
|
|
|
|
while True:
|
|
|
|
@ -49,6 +62,12 @@ def handle_connection(conn, handler):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retry(func, exception_type, quit_event):
|
|
|
|
|
"""
|
|
|
|
|
Run the function, retrying when the specified exception_type occurs.
|
|
|
|
|
|
|
|
|
|
Poll quit_event on each iteration, to be responsive to an external
|
|
|
|
|
exit request.
|
|
|
|
|
"""
|
|
|
|
|
while True:
|
|
|
|
|
if quit_event.is_set():
|
|
|
|
|
raise StopIteration
|
|
|
|
@ -58,16 +77,17 @@ def retry(func, exception_type, quit_event):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def server_thread(server, handler, quit_event):
|
|
|
|
|
def server_thread(sock, handler, quit_event):
|
|
|
|
|
"""Run a server on the specified socket."""
|
|
|
|
|
log.debug('server thread started')
|
|
|
|
|
|
|
|
|
|
def accept_connection():
|
|
|
|
|
conn, _ = server.accept()
|
|
|
|
|
conn, _ = sock.accept()
|
|
|
|
|
conn.settimeout(None)
|
|
|
|
|
return conn
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
log.debug('waiting for connection on %s', server.getsockname())
|
|
|
|
|
log.debug('waiting for connection on %s', sock.getsockname())
|
|
|
|
|
try:
|
|
|
|
|
conn = retry(accept_connection, socket.timeout, quit_event)
|
|
|
|
|
except StopIteration:
|
|
|
|
@ -80,6 +100,7 @@ def server_thread(server, handler, quit_event):
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def spawn(func, kwargs):
|
|
|
|
|
"""Spawn a thread, and join it after the context is over."""
|
|
|
|
|
t = threading.Thread(target=func, kwargs=kwargs)
|
|
|
|
|
t.start()
|
|
|
|
|
yield
|
|
|
|
@ -88,14 +109,20 @@ def spawn(func, kwargs):
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
|
|
|
|
"""
|
|
|
|
|
Start the ssh-agent server on a UNIX-domain socket.
|
|
|
|
|
|
|
|
|
|
If no connection is made during the specified timeout,
|
|
|
|
|
retry until the context is over.
|
|
|
|
|
"""
|
|
|
|
|
if sock_path is None:
|
|
|
|
|
sock_path = tempfile.mktemp(prefix='ssh-agent-')
|
|
|
|
|
|
|
|
|
|
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
|
|
|
|
with unix_domain_socket_server(sock_path) as server:
|
|
|
|
|
server.settimeout(timeout)
|
|
|
|
|
with unix_domain_socket_server(sock_path) as sock:
|
|
|
|
|
sock.settimeout(timeout)
|
|
|
|
|
quit_event = threading.Event()
|
|
|
|
|
kwargs = dict(server=server, handler=handler, quit_event=quit_event)
|
|
|
|
|
kwargs = dict(sock=sock, handler=handler, quit_event=quit_event)
|
|
|
|
|
with spawn(server_thread, kwargs):
|
|
|
|
|
try:
|
|
|
|
|
yield environ
|
|
|
|
@ -105,6 +132,11 @@ def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_process(command, environ, use_shell=False):
|
|
|
|
|
"""
|
|
|
|
|
Run the specified process and wait until it finishes.
|
|
|
|
|
|
|
|
|
|
Use environ dict for environment variables.
|
|
|
|
|
"""
|
|
|
|
|
log.info('running %r with %r', command, environ)
|
|
|
|
|
env = dict(os.environ)
|
|
|
|
|
env.update(environ)
|
|
|
|
|