diff --git a/libagent/ssh/__init__.py b/libagent/ssh/__init__.py index cf4f41f..da055f4 100644 --- a/libagent/ssh/__init__.py +++ b/libagent/ssh/__init__.py @@ -23,9 +23,11 @@ log = logging.getLogger(__name__) UNIX_SOCKET_TIMEOUT = 0.1 -def ssh_args(label): +def ssh_args(conn): """Create SSH command for connecting specified server.""" - identity = device.interface.string_to_identity(label) + I, = conn.identities + identity = I.identity_dict + pubkey_tempfile, = conn.public_keys_as_files() args = [] if 'port' in identity: @@ -33,12 +35,15 @@ def ssh_args(label): if 'user' in identity: args += ['-l', identity['user']] + args += ['-o', 'IdentityFile={}'.format(pubkey_tempfile.name)] + args += ['-o', 'IdentitiesOnly=true'] return args + [identity['host']] -def mosh_args(label): +def mosh_args(conn): """Create SSH command for connecting specified server.""" - identity = device.interface.string_to_identity(label) + I, = conn.identities + identity = I.identity_dict args = [] if 'port' in identity: @@ -193,6 +198,7 @@ class JustInTimeConnection(object): self.conn_factory = conn_factory self.identities = identities self.public_keys_cache = public_keys + self.public_keys_tempfiles = [] def public_keys(self): """Return a list of SSH public keys (in textual format).""" @@ -209,6 +215,17 @@ class JustInTimeConnection(object): pk['identity'] = identity return public_keys + def public_keys_as_files(self): + """Store public keys as temporary SSH identity files.""" + if not self.public_keys_tempfiles: + for pk in self.public_keys(): + f = tempfile.NamedTemporaryFile(prefix='trezor-ssh-pubkey-', mode='w') + f.write(pk) + f.flush() + self.public_keys_tempfiles.append(f) + + return self.public_keys_tempfiles + def sign(self, blob, identity): """Sign a given blob using the specified identity on the device.""" conn = self.conn_factory() @@ -238,6 +255,7 @@ def main(device_type): util.setup_logging(verbosity=args.verbose, filename=args.log_file) public_keys = None + filename = None if args.identity.startswith('/'): filename = args.identity contents = open(filename, 'rb').read().decode('utf-8') @@ -252,14 +270,22 @@ def main(device_type): identity.identity_dict['proto'] = u'ssh' log.info('identity #%d: %s', index, identity.to_string()) - sock_path = _get_sock_path(args) + # override default PIN/passphrase entry tools (relevant for TREZOR/Keepkey): + device_type.ui = device.ui.UI(device_type=device_type, config=vars(args)) + device_type.cached_passphrase_ack = util.ExpiringCache( + args.cache_expiry_seconds) + + conn = JustInTimeConnection( + conn_factory=lambda: client.Client(device_type()), + identities=identities, public_keys=public_keys) + sock_path = _get_sock_path(args) command = args.command context = _dummy_context() if args.connect: - command = ['ssh'] + ssh_args(args.identity) + args.command + command = ['ssh'] + ssh_args(conn) + args.command elif args.mosh: - command = ['mosh'] + mosh_args(args.identity) + args.command + command = ['mosh'] + mosh_args(conn) + args.command elif args.daemonize: out = 'SSH_AUTH_SOCK={0}; export SSH_AUTH_SOCK;\n'.format(sock_path) sys.stdout.write(out) @@ -274,15 +300,6 @@ def main(device_type): command = os.environ['SHELL'] sys.stdin.close() - # override default PIN/passphrase entry tools (relevant for TREZOR/Keepkey): - device_type.ui = device.ui.UI(device_type=device_type, config=vars(args)) - device_type.cached_passphrase_ack = util.ExpiringCache( - args.cache_expiry_seconds) - - conn = JustInTimeConnection( - conn_factory=lambda: client.Client(device_type()), - identities=identities, public_keys=public_keys) - if command or args.daemonize or args.foreground: with context: return run_server(conn=conn, command=command, sock_path=sock_path,