diff --git a/.travis.yml b/.travis.yml index fa3739c..f58c4a3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,12 +2,22 @@ sudo: false language: python python: - "2.7" - - "3.4" - - "3.5" + +cache: + directories: + - $HOME/.cache/pip + +addons: + apt: + packages: + - libudev-dev + - libusb-1.0-0-dev + +before_install: + - pip install -U setuptools pylint coverage pep8 pydocstyle "pip>=7.0" wheel install: - - pip install ecdsa ed25519 semver # test without trezorlib for now - - pip install -U pylint coverage pep8 pydocstyle # use latest tools + - pip install -e . script: - pep8 trezor_agent diff --git a/setup.py b/setup.py index be356c8..87fc7d1 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,8 @@ setup( author_email='roman.zeyde@gmail.com', url='http://github.com/romanz/trezor-agent', packages=['trezor_agent', 'trezor_agent.gpg'], - install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'protobuf>=3.0.0', 'trezor>=0.7.4', 'semver>=2.2'], + install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'protobuf>=3.0.0', 'semver>=2.2', + 'trezor>=0.7.6', 'keepkey>=0.7.3', 'ledgerblue>=0.1.8'], platforms=['POSIX'], classifiers=[ 'Environment :: Console', @@ -27,10 +28,6 @@ setup( 'Topic :: Security', 'Topic :: Utilities', ], - extras_require={ - 'trezorlib': ['python-trezor>=0.7.6'], - 'keepkeylib': ['keepkey>=0.7.3'], - }, entry_points={'console_scripts': [ 'trezor-agent = trezor_agent.__main__:run_agent', 'trezor-git = trezor_agent.__main__:run_git', diff --git a/trezor_agent/__main__.py b/trezor_agent/__main__.py index fd86baf..0a291c3 100644 --- a/trezor_agent/__main__.py +++ b/trezor_agent/__main__.py @@ -7,14 +7,14 @@ import re import subprocess import sys -from . import client, formats, protocol, server, util +from . import client, device, formats, protocol, server, util log = logging.getLogger(__name__) def ssh_args(label): """Create SSH command for connecting specified server.""" - identity = util.string_to_identity(label, identity_type=dict) + identity = device.interface.string_to_identity(label) args = [] if 'port' in identity: @@ -125,27 +125,28 @@ def run_agent(client_factory=client.Client): args = create_agent_parser().parse_args() util.setup_logging(verbosity=args.verbose) - with client_factory(curve=args.ecdsa_curve_name) as conn: - label = args.identity - command = args.command + d = device.detect(identity_str=args.identity, + curve_name=args.ecdsa_curve_name) + conn = client_factory(device=d) - public_key = conn.get_public_key(label=label) + command = args.command + public_key = conn.get_public_key() - if args.connect: - command = ssh_args(label) + args.command - log.debug('SSH connect: %r', command) + if args.connect: + command = ssh_args(args.identity) + args.command + log.debug('SSH connect: %r', command) - use_shell = bool(args.shell) - if use_shell: - command = os.environ['SHELL'] - log.debug('using shell: %r', command) + use_shell = bool(args.shell) + if use_shell: + command = os.environ['SHELL'] + log.debug('using shell: %r', command) - if not command: - sys.stdout.write(public_key) - return + if not command: + sys.stdout.write(public_key) + return - return run_server(conn=conn, public_key=public_key, command=command, - debug=args.debug, timeout=args.timeout) + return run_server(conn=conn, public_key=public_key, command=command, + debug=args.debug, timeout=args.timeout) @handle_connection_error diff --git a/trezor_agent/_ledger.py b/trezor_agent/_ledger.py deleted file mode 100644 index 1d159de..0000000 --- a/trezor_agent/_ledger.py +++ /dev/null @@ -1,148 +0,0 @@ -"""TREZOR-like interface for Ledger hardware wallet.""" -import binascii -import struct - -from trezorlib.types_pb2 import IdentityType # pylint: disable=import-error,unused-import -from . import util - - -class LedgerClientConnection(object): - """Mock for TREZOR-like connection object.""" - - def __init__(self, dongle): - """Create connection.""" - self.dongle = dongle - - @staticmethod - def expand_path(path): - """Convert BIP32 path into bytes.""" - return b''.join((struct.pack('>I', e) for e in path)) - - @staticmethod - def convert_public_key(ecdsa_curve_name, result): - """Convert Ledger reply into PublicKey object.""" - from trezorlib.messages_pb2 import PublicKey # pylint: disable=import-error - if ecdsa_curve_name == 'nist256p1': - if (result[64] & 1) != 0: - result = bytearray([0x03]) + result[1:33] - else: - result = bytearray([0x02]) + result[1:33] - else: - result = result[1:] - keyX = bytearray(result[0:32]) - keyY = bytearray(result[32:][::-1]) - if (keyX[31] & 1) != 0: - keyY[31] |= 0x80 - result = b'\x00' + bytes(keyY) - publicKey = PublicKey() - publicKey.node.public_key = bytes(result) - return publicKey - - # pylint: disable=unused-argument - def get_public_node(self, n, ecdsa_curve_name='secp256k1', show_display=False): - """Get PublicKey object for specified BIP32 address and elliptic curve.""" - donglePath = LedgerClientConnection.expand_path(n) - if ecdsa_curve_name == 'nist256p1': - p2 = '01' - else: - p2 = '02' - apdu = '800200' + p2 - apdu = binascii.unhexlify(apdu) - apdu += bytearray([len(donglePath) + 1, len(donglePath) // 4]) - apdu += donglePath - result = bytearray(self.dongle.exchange(bytes(apdu)))[1:] - return LedgerClientConnection.convert_public_key(ecdsa_curve_name, result) - - # pylint: disable=too-many-locals - def sign_identity(self, identity, challenge_hidden, challenge_visual, - ecdsa_curve_name='secp256k1'): - """Sign specified challenges using secret key derived from given identity.""" - from trezorlib.messages_pb2 import SignedIdentity # pylint: disable=import-error - n = util.get_bip32_address(identity) - donglePath = LedgerClientConnection.expand_path(n) - if identity.proto == 'ssh': - ins = '04' - p1 = '00' - else: - ins = '08' - p1 = '00' - if ecdsa_curve_name == 'nist256p1': - p2 = '81' if identity.proto == 'ssh' else '01' - else: - p2 = '82' if identity.proto == 'ssh' else '02' - apdu = '80' + ins + p1 + p2 - apdu = binascii.unhexlify(apdu) - apdu += bytearray([len(challenge_hidden) + len(donglePath) + 1]) - apdu += bytearray([len(donglePath) // 4]) + donglePath - apdu += challenge_hidden - result = bytearray(self.dongle.exchange(bytes(apdu))) - if ecdsa_curve_name == 'nist256p1': - offset = 3 - length = result[offset] - r = result[offset+1:offset+1+length] - if r[0] == 0: - r = r[1:] - offset = offset + 1 + length + 1 - length = result[offset] - s = result[offset+1:offset+1+length] - if s[0] == 0: - s = s[1:] - offset = offset + 1 + length - signature = SignedIdentity() - signature.signature = b'\x00' + bytes(r) + bytes(s) - if identity.proto == 'ssh': - keyData = result[offset:] - pk = LedgerClientConnection.convert_public_key(ecdsa_curve_name, keyData) - signature.public_key = pk.node.public_key - return signature - else: - signature = SignedIdentity() - signature.signature = b'\x00' + bytes(result[0:64]) - if identity.proto == 'ssh': - keyData = result[64:] - pk = LedgerClientConnection.convert_public_key(ecdsa_curve_name, keyData) - signature.public_key = pk.node.public_key - return signature - - def get_ecdh_session_key(self, identity, peer_public_key, ecdsa_curve_name='secp256k1'): - """Create shared secret key for GPG decryption.""" - from trezorlib.messages_pb2 import ECDHSessionKey # pylint: disable=import-error - n = util.get_bip32_address(identity, True) - donglePath = LedgerClientConnection.expand_path(n) - if ecdsa_curve_name == 'nist256p1': - p2 = '01' - else: - p2 = '02' - apdu = '800a00' + p2 - apdu = binascii.unhexlify(apdu) - apdu += bytearray([len(peer_public_key) + len(donglePath) + 1]) - apdu += bytearray([len(donglePath) // 4]) + donglePath - apdu += peer_public_key - result = bytearray(self.dongle.exchange(bytes(apdu))) - sessionKey = ECDHSessionKey() - sessionKey.session_key = bytes(result) - return sessionKey - - def clear_session(self): - """Mock for TREZOR interface compatibility.""" - pass - - def close(self): - """Close connection.""" - self.dongle.close() - - # pylint: disable=unused-argument - # pylint: disable=no-self-use - def ping(self, msg, button_protection=False, pin_protection=False, - passphrase_protection=False): - """Mock for TREZOR interface compatibility.""" - return msg - - -class CallException(Exception): - """Ledger-related error (mainly for TREZOR compatibility).""" - - def __init__(self, code, message): - """Create an error.""" - super(CallException, self).__init__() - self.args = [code, message] diff --git a/trezor_agent/client.py b/trezor_agent/client.py index b4624c1..2e5b074 100644 --- a/trezor_agent/client.py +++ b/trezor_agent/client.py @@ -3,11 +3,10 @@ Connection to hardware authentication device. It is used for getting SSH public keys and ECDSA signing of server requests. """ -import binascii import io import logging -from . import factory, formats, util +from . import formats, util log = logging.getLogger(__name__) @@ -15,79 +14,36 @@ log = logging.getLogger(__name__) class Client(object): """Client wrapper for SSH authentication device.""" - def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256): + def __init__(self, device): """Connect to hardware device.""" - client_wrapper = loader() - self.client = client_wrapper.connection - self.identity_type = client_wrapper.identity_type - self.device_name = client_wrapper.device_name - self.call_exception = client_wrapper.call_exception - self.curve = curve + device.identity_dict['proto'] = 'ssh' + self.device = device - def __enter__(self): - """Start a session, and test connection.""" - msg = 'Hello World!' - assert self.client.ping(msg) == msg - return self + def get_public_key(self): + """Get SSH public key from the device.""" + with self.device: + pubkey = self.device.pubkey() - def __exit__(self, *args): - """Keep the session open (doesn't forget PIN).""" - log.info('disconnected from %s', self.device_name) - self.client.close() + vk = formats.decompress_pubkey(pubkey=pubkey, + curve_name=self.device.curve_name) + return formats.export_public_key(vk=vk, + label=self.device.identity_str()) - def get_identity(self, label, index=0): - """Parse label string into Identity protobuf.""" - identity = util.string_to_identity(label, self.identity_type) - identity.proto = 'ssh' - identity.index = index - return identity - - def get_public_key(self, label): - """Get SSH public key corresponding to specified by label.""" - identity = self.get_identity(label=label) - label = util.identity_to_string(identity) # canonize key label - log.info('getting "%s" public key (%s) from %s...', - label, self.curve, self.device_name) - addr = util.get_bip32_address(identity) - node = self.client.get_public_node(n=addr, - ecdsa_curve_name=self.curve) - - pubkey = node.node.public_key - vk = formats.decompress_pubkey(pubkey=pubkey, curve_name=self.curve) - return formats.export_public_key(vk=vk, label=label) - - def sign_ssh_challenge(self, label, blob): - """Sign given blob using a private key, specified by the label.""" - identity = self.get_identity(label=label) + def sign_ssh_challenge(self, blob): + """Sign given blob using a private key on the device.""" msg = _parse_ssh_blob(blob) log.debug('%s: user %r via %r (%r)', msg['conn'], msg['user'], msg['auth'], msg['key_type']) - log.debug('nonce: %s', binascii.hexlify(msg['nonce'])) + log.debug('nonce: %r', msg['nonce']) log.debug('fingerprint: %s', msg['public_key']['fingerprint']) log.debug('hidden challenge size: %d bytes', len(blob)) log.info('please confirm user "%s" login to "%s" using %s...', - msg['user'].decode('ascii'), label, self.device_name) - - try: - result = self.client.sign_identity(identity=identity, - challenge_hidden=blob, - challenge_visual='', - ecdsa_curve_name=self.curve) - except self.call_exception as e: - code, msg = e.args - log.warning('%s error #%s: %s', self.device_name, code, msg) - raise IOError(msg) # close current connection, keep server open - - verifying_key = formats.decompress_pubkey(pubkey=result.public_key, - curve_name=self.curve) - key_type, blob = formats.serialize_verifying_key(verifying_key) - assert blob == msg['public_key']['blob'] - assert key_type == msg['key_type'] - assert len(result.signature) == 65 - assert result.signature[:1] == bytearray([0]) + msg['user'].decode('ascii'), self.device.identity_str(), + self.device) - return result.signature[1:] + with self.device: + return self.device.sign(blob=blob) def _parse_ssh_blob(data): diff --git a/trezor_agent/device/__init__.py b/trezor_agent/device/__init__.py new file mode 100644 index 0000000..65e915c --- /dev/null +++ b/trezor_agent/device/__init__.py @@ -0,0 +1,28 @@ +"""Cryptographic hardware device management.""" + +import logging + +from . import trezor +from . import keepkey +from . import ledger +from . import interface + +log = logging.getLogger(__name__) + +DEVICE_TYPES = [ + trezor.Trezor, + keepkey.KeepKey, + ledger.LedgerNanoS, +] + + +def detect(identity_str, curve_name): + """Detect the first available device and return it to the user.""" + for device_type in DEVICE_TYPES: + try: + with device_type(identity_str, curve_name) as d: + return d + except interface.NotFoundError as e: + log.debug('device not found: %s', e) + raise IOError('No device found: "{}" ({})'.format(identity_str, + curve_name)) diff --git a/trezor_agent/device/interface.py b/trezor_agent/device/interface.py new file mode 100644 index 0000000..5baa21e --- /dev/null +++ b/trezor_agent/device/interface.py @@ -0,0 +1,125 @@ +"""Device abstraction layer.""" + +import hashlib +import io +import logging +import re +import struct + +from .. import formats, util + +log = logging.getLogger(__name__) + +_identity_regexp = re.compile(''.join([ + '^' + r'(?:(?P.*)://)?', + r'(?:(?P.*)@)?', + r'(?P.*?)', + r'(?::(?P\w*))?', + r'(?P/.*)?', + '$' +])) + + +def string_to_identity(identity_str): + """Parse string into Identity dictionary.""" + m = _identity_regexp.match(identity_str) + result = m.groupdict() + log.debug('parsed identity: %s', result) + return {k: v for k, v in result.items() if v} + + +def identity_to_string(identity_dict): + """Dump Identity dictionary into its string representation.""" + result = [] + if identity_dict.get('proto'): + result.append(identity_dict['proto'] + '://') + if identity_dict.get('user'): + result.append(identity_dict['user'] + '@') + result.append(identity_dict['host']) + if identity_dict.get('port'): + result.append(':' + identity_dict['port']) + if identity_dict.get('path'): + result.append(identity_dict['path']) + log.debug('identity parts: %s', result) + return ''.join(result) + + +def get_bip32_address(identity_dict, ecdh=False): + """Compute BIP32 derivation address according to SLIP-0013/0017.""" + index = struct.pack('I', e) for e in path)) + + +def _convert_public_key(ecdsa_curve_name, result): + """Convert Ledger reply into PublicKey object.""" + if ecdsa_curve_name == 'nist256p1': + if (result[64] & 1) != 0: + result = bytearray([0x03]) + result[1:33] + else: + result = bytearray([0x02]) + result[1:33] + else: + result = result[1:] + keyX = bytearray(result[0:32]) + keyY = bytearray(result[32:][::-1]) + if (keyX[31] & 1) != 0: + keyY[31] |= 0x80 + result = b'\x00' + bytes(keyY) + return bytes(result) + + +class LedgerNanoS(interface.Device): + """Connection to Ledger Nano S device.""" + + def connect(self): + """Enumerate and connect to the first USB HID interface.""" + try: + return comm.getDongle() + except comm.CommException as e: + raise interface.NotFoundError( + '{} not connected: "{}"'.format(self, e)) + + def pubkey(self, ecdh=False): + """Get PublicKey object for specified BIP32 address and elliptic curve.""" + curve_name = self.get_curve_name(ecdh) + path = _expand_path(interface.get_bip32_address(self.identity_dict, + ecdh=ecdh)) + if curve_name == 'nist256p1': + p2 = '01' + else: + p2 = '02' + apdu = '800200' + p2 + apdu = binascii.unhexlify(apdu) + apdu += bytearray([len(path) + 1, len(path) // 4]) + apdu += path + result = bytearray(self.conn.exchange(bytes(apdu)))[1:] + return _convert_public_key(curve_name, result) + + def sign(self, blob): + """Sign given blob and return the signature (as bytes).""" + path = _expand_path(interface.get_bip32_address(self.identity_dict, + ecdh=False)) + if self.identity_dict['proto'] == 'ssh': + ins = '04' + p1 = '00' + else: + ins = '08' + p1 = '00' + if self.curve_name == 'nist256p1': + p2 = '81' if self.identity_dict['proto'] == 'ssh' else '01' + else: + p2 = '82' if self.identity_dict['proto'] == 'ssh' else '02' + apdu = '80' + ins + p1 + p2 + apdu = binascii.unhexlify(apdu) + apdu += bytearray([len(blob) + len(path) + 1]) + apdu += bytearray([len(path) // 4]) + path + apdu += blob + result = bytearray(self.conn.exchange(bytes(apdu))) + if self.curve_name == 'nist256p1': + offset = 3 + length = result[offset] + r = result[offset+1:offset+1+length] + if r[0] == 0: + r = r[1:] + offset = offset + 1 + length + 1 + length = result[offset] + s = result[offset+1:offset+1+length] + if s[0] == 0: + s = s[1:] + offset = offset + 1 + length + return bytes(r) + bytes(s) + else: + return bytes(result[:64]) + + def ecdh(self, pubkey): + """Get shared session key using Elliptic Curve Diffie-Hellman.""" + path = _expand_path(interface.get_bip32_address(self.identity_dict, + ecdh=True)) + if self.curve_name == 'nist256p1': + p2 = '01' + else: + p2 = '02' + apdu = '800a00' + p2 + apdu = binascii.unhexlify(apdu) + apdu += bytearray([len(pubkey) + len(path) + 1]) + apdu += bytearray([len(path) // 4]) + path + apdu += pubkey + result = bytearray(self.conn.exchange(bytes(apdu))) + assert result[0] == 0x04 + return bytes(result) diff --git a/trezor_agent/device/trezor.py b/trezor_agent/device/trezor.py new file mode 100644 index 0000000..1f538a3 --- /dev/null +++ b/trezor_agent/device/trezor.py @@ -0,0 +1,108 @@ +"""TREZOR-related code (see http://bitcointrezor.com/).""" + +import binascii +import logging +import semver + +from . import interface + +log = logging.getLogger(__name__) + + +class Trezor(interface.Device): + """Connection to TREZOR device.""" + + from . import trezor_defs as defs + + required_version = '>=1.4.0' + + def connect(self): + """Enumerate and connect to the first USB HID interface.""" + def empty_passphrase_handler(_): + return self.defs.PassphraseAck(passphrase='') + + for d in self.defs.HidTransport.enumerate(): + log.debug('endpoint: %s', d) + transport = self.defs.HidTransport(d) + connection = self.defs.Client(transport) + connection.callback_PassphraseRequest = empty_passphrase_handler + f = connection.features + log.debug('connected to %s %s', self, f.device_id) + log.debug('label : %s', f.label) + log.debug('vendor : %s', f.vendor) + current_version = '{}.{}.{}'.format(f.major_version, + f.minor_version, + f.patch_version) + log.debug('version : %s', current_version) + log.debug('revision : %s', binascii.hexlify(f.revision)) + if not semver.match(current_version, self.required_version): + fmt = ('Please upgrade your {} firmware to {} version' + ' (current: {})') + raise ValueError(fmt.format(self, self.required_version, + current_version)) + connection.ping(msg='', pin_protection=True) # unlock PIN + return connection + raise interface.NotFoundError('{} not connected'.format(self)) + + def close(self): + """Close connection.""" + self.conn.close() + + def pubkey(self, ecdh=False): + """Return public key.""" + curve_name = self.get_curve_name(ecdh=ecdh) + log.debug('"%s" getting public key (%s) from %s', + interface.identity_to_string(self.identity_dict), + curve_name, self) + addr = interface.get_bip32_address(self.identity_dict, ecdh=ecdh) + result = self.conn.get_public_node(n=addr, + ecdsa_curve_name=curve_name) + log.debug('result: %s', result) + return result.node.public_key + + def _identity_proto(self): + result = self.defs.IdentityType() + for name, value in self.identity_dict.items(): + setattr(result, name, value) + return result + + def sign(self, blob): + """Sign given blob and return the signature (as bytes).""" + curve_name = self.get_curve_name(ecdh=False) + log.debug('"%s" signing %r (%s) on %s', + interface.identity_to_string(self.identity_dict), blob, + curve_name, self) + try: + result = self.conn.sign_identity( + identity=self._identity_proto(), + challenge_hidden=blob, + challenge_visual='', + ecdsa_curve_name=curve_name) + log.debug('result: %s', result) + assert len(result.signature) == 65 + assert result.signature[:1] == b'\x00' + return result.signature[1:] + except self.defs.CallException as e: + msg = '{} error: {}'.format(self, e) + log.debug(msg, exc_info=True) + raise interface.DeviceError(msg) + + def ecdh(self, pubkey): + """Get shared session key using Elliptic Curve Diffie-Hellman.""" + curve_name = self.get_curve_name(ecdh=True) + log.debug('"%s" shared session key (%s) for %r from %s', + interface.identity_to_string(self.identity_dict), + curve_name, pubkey, self) + try: + result = self.conn.get_ecdh_session_key( + identity=self._identity_proto(), + peer_public_key=pubkey, + ecdsa_curve_name=curve_name) + log.debug('result: %s', result) + assert len(result.session_key) in {65, 33} # NIST256 or Curve25519 + assert result.session_key[:1] == b'\x04' + return result.session_key + except self.defs.CallException as e: + msg = '{} error: {}'.format(self, e) + log.debug(msg, exc_info=True) + raise interface.DeviceError(msg) diff --git a/trezor_agent/device/trezor_defs.py b/trezor_agent/device/trezor_defs.py new file mode 100644 index 0000000..2dff8ee --- /dev/null +++ b/trezor_agent/device/trezor_defs.py @@ -0,0 +1,8 @@ +"""TREZOR-related definitions.""" + +# pylint: disable=unused-import +from trezorlib.client import TrezorClient as Client +from trezorlib.client import CallException +from trezorlib.transport_hid import HidTransport +from trezorlib.messages_pb2 import PassphraseAck +from trezorlib.types_pb2 import IdentityType diff --git a/trezor_agent/factory.py b/trezor_agent/factory.py deleted file mode 100644 index 1af7eb1..0000000 --- a/trezor_agent/factory.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Thin wrapper around trezor/keepkey libraries.""" -from __future__ import absolute_import - -import binascii -import collections -import logging - -import semver - -log = logging.getLogger(__name__) - -ClientWrapper = collections.namedtuple( - 'ClientWrapper', - ['connection', 'identity_type', 'device_name', 'call_exception']) - - -# pylint: disable=too-many-arguments -def _load_client(name, client_type, hid_transport, - passphrase_ack, identity_type, - required_version, call_exception): - - def empty_passphrase_handler(_): - return passphrase_ack(passphrase='') - - for d in hid_transport.enumerate(): - connection = client_type(hid_transport(d)) - connection.callback_PassphraseRequest = empty_passphrase_handler - f = connection.features - log.debug('connected to %s %s', name, f.device_id) - log.debug('label : %s', f.label) - log.debug('vendor : %s', f.vendor) - current_version = '{}.{}.{}'.format(f.major_version, - f.minor_version, - f.patch_version) - log.debug('version : %s', current_version) - log.debug('revision : %s', binascii.hexlify(f.revision)) - if not semver.match(current_version, required_version): - fmt = 'Please upgrade your {} firmware to {} version (current: {})' - raise ValueError(fmt.format(name, - required_version, - current_version)) - yield ClientWrapper(connection=connection, - identity_type=identity_type, - device_name=name, - call_exception=call_exception) - return - - -def _load_trezor(): - try: - from trezorlib.client import TrezorClient, CallException - from trezorlib.transport_hid import HidTransport - from trezorlib.messages_pb2 import PassphraseAck - from trezorlib.types_pb2 import IdentityType - return _load_client(name='Trezor', - client_type=TrezorClient, - hid_transport=HidTransport, - passphrase_ack=PassphraseAck, - identity_type=IdentityType, - required_version='>=1.4.0', - call_exception=CallException) - except ImportError as e: - log.warning('%s: install via "pip install trezor" ' - 'if you need to support this device', e) - - -def _load_keepkey(): - try: - from keepkeylib.client import KeepKeyClient, CallException - from keepkeylib.transport_hid import HidTransport - from keepkeylib.messages_pb2 import PassphraseAck - from keepkeylib.types_pb2 import IdentityType - return _load_client(name='KeepKey', - client_type=KeepKeyClient, - hid_transport=HidTransport, - passphrase_ack=PassphraseAck, - identity_type=IdentityType, - required_version='>=1.0.4', - call_exception=CallException) - except ImportError as e: - log.warning('%s: install via "pip install keepkey" ' - 'if you need to support this device', e) - - -def _load_ledger(): - from ._ledger import LedgerClientConnection, CallException, IdentityType - try: - from ledgerblue.comm import getDongle, CommException - except ImportError as e: - log.warning('%s: install via "pip install ledgerblue" ' - 'if you need to support this device', e) - return - try: - dongle = getDongle() - except CommException: - return - - yield ClientWrapper(connection=LedgerClientConnection(dongle), - identity_type=IdentityType, - device_name="ledger", - call_exception=CallException) - - -LOADERS = [ - _load_trezor, - _load_keepkey, - _load_ledger -] - - -def load(loaders=None): - """Load a single device, via specified loaders' list.""" - loaders = loaders if loaders is not None else LOADERS - device_list = [] - for loader in loaders: - device = loader() - if device: - device_list.extend(device) - - if len(device_list) == 1: - return device_list[0] - - msg = '{:d} devices found'.format(len(device_list)) - raise IOError(msg) diff --git a/trezor_agent/gpg/__main__.py b/trezor_agent/gpg/__main__.py index d069eb5..e9b8596 100755 --- a/trezor_agent/gpg/__main__.py +++ b/trezor_agent/gpg/__main__.py @@ -29,10 +29,10 @@ def run_create(args): log.warning('NOTE: in order to re-generate the exact same GPG key later, ' 'run this command with "--time=%d" commandline flag (to set ' 'the timestamp of the GPG key manually).', args.time) - conn = device.HardwareSigner(user_id=args.user_id, - curve_name=args.ecdsa_curve) - verifying_key = conn.pubkey(ecdh=False) - decryption_key = conn.pubkey(ecdh=True) + d = device.HardwareSigner(user_id=args.user_id, + curve_name=args.ecdsa_curve) + verifying_key = d.pubkey(ecdh=False) + decryption_key = d.pubkey(ecdh=True) if key_exists(args.user_id): # add as subkey log.info('adding %s GPG subkey for "%s" to existing key', @@ -48,10 +48,10 @@ def run_create(args): primary_bytes = keyring.export_public_key(args.user_id) result = encode.create_subkey(primary_bytes=primary_bytes, subkey=signing_key, - signer_func=conn.sign) + signer_func=d.sign) result = encode.create_subkey(primary_bytes=result, subkey=encryption_key, - signer_func=conn.sign) + signer_func=d.sign) else: # add as primary log.info('creating new %s GPG primary key for "%s"', args.ecdsa_curve, args.user_id) @@ -66,10 +66,10 @@ def run_create(args): result = encode.create_primary(user_id=args.user_id, pubkey=primary, - signer_func=conn.sign) + signer_func=d.sign) result = encode.create_subkey(primary_bytes=result, subkey=subkey, - signer_func=conn.sign) + signer_func=d.sign) sys.stdout.write(protocol.armor(result, 'PUBLIC KEY BLOCK')) diff --git a/trezor_agent/gpg/agent.py b/trezor_agent/gpg/agent.py index e1d3f9c..575f8d2 100644 --- a/trezor_agent/gpg/agent.py +++ b/trezor_agent/gpg/agent.py @@ -1,6 +1,5 @@ """GPG-agent utilities.""" import binascii -import contextlib import logging from . import decode, device, keyring, protocol @@ -37,7 +36,6 @@ def sig_encode(r, s): return b'(7:sig-val(5:ecdsa(1:r32:' + r + b')(1:s32:' + s + b')))' -@contextlib.contextmanager def open_connection(keygrip_bytes): """ Connect to the device for the specified keygrip. @@ -49,29 +47,28 @@ def open_connection(keygrip_bytes): pubkey_bytes=keyring.export_public_keys(), keygrip=keygrip_bytes) # We assume the first user ID is used to generate TREZOR-based GPG keys. - user_id = user_ids[0]['value'] + user_id = user_ids[0]['value'].decode('ascii') curve_name = protocol.get_curve_name_by_oid(pubkey_dict['curve_oid']) ecdh = (pubkey_dict['algo'] == protocol.ECDH_ALGO_ID) conn = device.HardwareSigner(user_id, curve_name=curve_name) - with contextlib.closing(conn): - pubkey = protocol.PublicKey( - curve_name=curve_name, created=pubkey_dict['created'], - verifying_key=conn.pubkey(ecdh=ecdh), ecdh=ecdh) - assert pubkey.key_id() == pubkey_dict['key_id'] - assert pubkey.keygrip == keygrip_bytes - yield conn + pubkey = protocol.PublicKey( + curve_name=curve_name, created=pubkey_dict['created'], + verifying_key=conn.pubkey(ecdh=ecdh), ecdh=ecdh) + assert pubkey.key_id() == pubkey_dict['key_id'] + assert pubkey.keygrip == keygrip_bytes + return conn def pksign(keygrip, digest, algo): """Sign a message digest using a private EC key.""" log.debug('signing %r digest (algo #%s)', digest, algo) keygrip_bytes = binascii.unhexlify(keygrip) - with open_connection(keygrip_bytes) as conn: - r, s = conn.sign(binascii.unhexlify(digest)) - result = sig_encode(r, s) - log.debug('result: %r', result) - return result + conn = open_connection(keygrip_bytes) + r, s = conn.sign(binascii.unhexlify(digest)) + result = sig_encode(r, s) + log.debug('result: %r', result) + return result def _serialize_point(data): @@ -105,8 +102,8 @@ def pkdecrypt(keygrip, conn): remote_pubkey = parse_ecdh(line) keygrip_bytes = binascii.unhexlify(keygrip) - with open_connection(keygrip_bytes) as conn: - return _serialize_point(conn.ecdh(remote_pubkey)) + conn = open_connection(keygrip_bytes) + return _serialize_point(conn.ecdh(remote_pubkey)) def handle_connection(conn): diff --git a/trezor_agent/gpg/device.py b/trezor_agent/gpg/device.py index 123f7f8..04e6dba 100644 --- a/trezor_agent/gpg/device.py +++ b/trezor_agent/gpg/device.py @@ -2,7 +2,7 @@ import logging -from .. import factory, formats, util +from .. import device, formats, util log = logging.getLogger(__name__) @@ -12,55 +12,33 @@ class HardwareSigner(object): def __init__(self, user_id, curve_name): """Connect to the device and retrieve required public key.""" - self.client_wrapper = factory.load() - self.identity = self.client_wrapper.identity_type() - self.identity.proto = 'gpg' - self.identity.host = user_id - self.curve_name = curve_name + self.device = device.detect(identity_str='', + curve_name=curve_name) + self.device.identity_dict['proto'] = 'gpg' + self.device.identity_dict['host'] = user_id self.user_id = user_id def pubkey(self, ecdh=False): """Return public key as VerifyingKey object.""" - addr = util.get_bip32_address(identity=self.identity, ecdh=ecdh) - if ecdh: - curve_name = formats.get_ecdh_curve_name(self.curve_name) - else: - curve_name = self.curve_name - public_node = self.client_wrapper.connection.get_public_node( - n=addr, ecdsa_curve_name=curve_name) - + with self.device: + pubkey = self.device.pubkey(ecdh=ecdh) return formats.decompress_pubkey( - pubkey=public_node.node.public_key, - curve_name=curve_name) + pubkey=pubkey, curve_name=self.device.curve_name) def sign(self, digest): """Sign the digest and return a serialized signature.""" log.info('please confirm GPG signature on %s for "%s"...', - self.client_wrapper.device_name, self.user_id) - if self.curve_name == formats.CURVE_NIST256: + self.device, self.user_id) + if self.device.curve_name == formats.CURVE_NIST256: digest = digest[:32] # sign the first 256 bits log.debug('signing digest: %s', util.hexlify(digest)) - result = self.client_wrapper.connection.sign_identity( - identity=self.identity, - challenge_hidden=digest, - challenge_visual='', - ecdsa_curve_name=self.curve_name) - assert result.signature[:1] == b'\x00' - sig = result.signature[1:] + with self.device: + sig = self.device.sign(blob=digest) return (util.bytes2num(sig[:32]), util.bytes2num(sig[32:])) def ecdh(self, pubkey): """Derive shared secret using ECDH from remote public key.""" log.info('please confirm GPG decryption on %s for "%s"...', - self.client_wrapper.device_name, self.user_id) - result = self.client_wrapper.connection.get_ecdh_session_key( - identity=self.identity, - peer_public_key=pubkey, - ecdsa_curve_name=formats.get_ecdh_curve_name(self.curve_name)) - assert len(result.session_key) in {65, 33} # NIST256 or Curve25519 - assert result.session_key[:1] == b'\x04' - return result.session_key - - def close(self): - """Close the connection to the device.""" - self.client_wrapper.connection.close() + self.device, self.user_id) + with self.device: + return self.device.ecdh(pubkey=pubkey) diff --git a/trezor_agent/protocol.py b/trezor_agent/protocol.py index efa2fea..e1a763f 100644 --- a/trezor_agent/protocol.py +++ b/trezor_agent/protocol.py @@ -7,7 +7,6 @@ for more details. The server's source code can be found here: https://github.com/openssh/openssh-portable/blob/master/authfd.c """ -import binascii import io import logging @@ -138,13 +137,13 @@ class Handler(object): else: raise KeyError('key not found') - log.debug('signing %d-byte blob', len(blob)) label = key['name'].decode('ascii') # label should be a string + log.debug('signing %d-byte blob with "%s" key', len(blob), label) try: - signature = self.signer(label=label, blob=blob) + signature = self.signer(blob=blob) except IOError: return failure() - log.debug('signature: %s', binascii.hexlify(signature)) + log.debug('signature: %r', signature) try: sig_bytes = key['verifier'](sig=signature, msg=blob) diff --git a/trezor_agent/tests/test_client.py b/trezor_agent/tests/test_client.py index 6a8273f..b3f4bad 100644 --- a/trezor_agent/tests/test_client.py +++ b/trezor_agent/tests/test_client.py @@ -3,7 +3,7 @@ import io import mock import pytest -from .. import client, factory, formats, util +from .. import client, device, formats, util ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040] CURVE = 'nist256p1' @@ -15,29 +15,23 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd' 'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n') -class FakeConnection(object): +class MockDevice(device.interface.Device): # pylint: disable=abstract-method - def __init__(self): - self.closed = False + def connect(self): # pylint: disable=no-self-use + return mock.Mock() def close(self): - self.closed = True + self.conn = None - def clear_session(self): - self.closed = True + def pubkey(self, ecdh=False): # pylint: disable=unused-argument + assert self.conn + return PUBKEY - def get_public_node(self, n, ecdsa_curve_name=b'secp256k1'): - assert not self.closed - assert n == ADDR - assert ecdsa_curve_name in {'secp256k1', 'nist256p1'} - result = mock.Mock(spec=[]) - result.node = mock.Mock(spec=[]) - result.node.public_key = PUBKEY - return result - - def ping(self, msg): - assert not self.closed - return msg + def sign(self, blob): + """Sign given blob and return the signature (as bytes).""" + assert self.conn + assert blob == BLOB + return SIG def identity_type(**kwargs): @@ -50,13 +44,6 @@ def identity_type(**kwargs): return result -def load_client(): - return factory.ClientWrapper(connection=FakeConnection(), - identity_type=identity_type, - device_name='DEVICE_NAME', - call_exception=Exception) - - BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0' b'\x8e;R\xd3)m\x96\x1b\xb4\xd8s\xf1\x99\x16\xaa2\x00\x00\x00\x05roman' b'\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey' @@ -66,71 +53,33 @@ BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0' b'\xdd\xbc+\xfar~\x9dAis4\xc1\x10yeT~\x1b\xeb\x1aX\xd1\xd9\x9f\xc21' b'\x13\x8dc\xa7\xa3\x07\xefO\x9e\x95\x0e>\xec\xd8\xaa/') -SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!' +SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!' b'\x862@cx\xb8\xb9i@1\x1b3#\x938\x86]\x97*Y\xb2\x02Xa\xdf@\xecK' b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2') def test_ssh_agent(): - label = 'localhost:22' - c = client.Client(loader=load_client) - ident = c.get_identity(label=label) - assert ident.host == 'localhost' - assert ident.proto == 'ssh' - assert ident.port == '22' - assert ident.user is None - assert ident.path is None - assert ident.index == 0 - - with c: - assert c.get_public_key(label) == PUBKEY_TEXT - - def ssh_sign_identity(identity, challenge_hidden, - challenge_visual, ecdsa_curve_name): - assert (util.identity_to_string(identity) == - util.identity_to_string(ident)) - assert challenge_hidden == BLOB - assert challenge_visual == '' - assert ecdsa_curve_name == 'nist256p1' - - result = mock.Mock(spec=[]) - result.public_key = PUBKEY - result.signature = SIG - return result - - c.client.sign_identity = ssh_sign_identity - signature = c.sign_ssh_challenge(label=label, blob=BLOB) - - key = formats.import_public_key(PUBKEY_TEXT) - serialized_sig = key['verifier'](sig=signature, msg=BLOB) - - stream = io.BytesIO(serialized_sig) - r = util.read_frame(stream) - s = util.read_frame(stream) - assert not stream.read() - assert r[:1] == b'\x00' - assert s[:1] == b'\x00' - assert r[1:] + s[1:] == SIG[1:] - - c.client.call_exception = ValueError - - # pylint: disable=unused-argument - def cancel_sign_identity(identity, challenge_hidden, - challenge_visual, ecdsa_curve_name): - raise c.client.call_exception(42, 'ERROR') - - c.client.sign_identity = cancel_sign_identity - with pytest.raises(IOError): - c.sign_ssh_challenge(label=label, blob=BLOB) - - -def test_utils(): - identity = mock.Mock(spec=[]) - identity.proto = 'https' - identity.user = 'user' - identity.host = 'host' - identity.port = '443' - identity.path = '/path' - - url = 'https://user@host:443/path' - assert util.identity_to_string(identity) == url + identity_str = 'localhost:22' + c = client.Client(device=MockDevice(identity_str=identity_str, + curve_name=CURVE)) + assert c.get_public_key() == PUBKEY_TEXT + signature = c.sign_ssh_challenge(blob=BLOB) + + key = formats.import_public_key(PUBKEY_TEXT) + serialized_sig = key['verifier'](sig=signature, msg=BLOB) + + stream = io.BytesIO(serialized_sig) + r = util.read_frame(stream) + s = util.read_frame(stream) + assert not stream.read() + assert r[:1] == b'\x00' + assert s[:1] == b'\x00' + assert r[1:] + s[1:] == SIG + + # pylint: disable=unused-argument + def cancel_sign(blob): + raise IOError(42, 'ERROR') + + c.device.sign = cancel_sign + with pytest.raises(IOError): + c.sign_ssh_challenge(blob=BLOB) diff --git a/trezor_agent/tests/test_factory.py b/trezor_agent/tests/test_factory.py deleted file mode 100644 index b904666..0000000 --- a/trezor_agent/tests/test_factory.py +++ /dev/null @@ -1,97 +0,0 @@ -import mock -import pytest - -from .. import factory - - -def test_load(): - - def single(): - return [0] - - def nothing(): - return [] - - def double(): - return [1, 2] - - assert factory.load(loaders=[single]) == 0 - assert factory.load(loaders=[single, nothing]) == 0 - assert factory.load(loaders=[nothing, single]) == 0 - - with pytest.raises(IOError): - factory.load(loaders=[]) - - with pytest.raises(IOError): - factory.load(loaders=[single, single]) - - with pytest.raises(IOError): - factory.load(loaders=[double]) - - -def factory_load_client(**kwargs): - # pylint: disable=protected-access - return list(factory._load_client(**kwargs)) - - -def test_load_nothing(): - hid_transport = mock.Mock(spec_set=['enumerate']) - hid_transport.enumerate.return_value = [] - result = factory_load_client( - name=None, - client_type=None, - hid_transport=hid_transport, - passphrase_ack=None, - identity_type=None, - required_version=None, - call_exception=None) - assert result == [] - - -def create_client_type(version): - conn = mock.Mock(spec=[]) - conn.features = mock.Mock(spec=[]) - major, minor, patch = version.split('.') - conn.features.device_id = 'DEVICE_ID' - conn.features.label = 'LABEL' - conn.features.vendor = 'VENDOR' - conn.features.major_version = major - conn.features.minor_version = minor - conn.features.patch_version = patch - conn.features.revision = b'\x12\x34\x56\x78' - return mock.Mock(spec_set=[], return_value=conn) - - -def test_load_single(): - hid_transport = mock.Mock(spec_set=['enumerate']) - hid_transport.enumerate.return_value = [0] - for version in ('1.3.4', '1.3.5', '1.4.0', '2.0.0'): - passphrase_ack = mock.Mock(spec_set=[]) - client_type = create_client_type(version) - client_wrapper, = factory_load_client( - name='DEVICE_NAME', - client_type=client_type, - hid_transport=hid_transport, - passphrase_ack=passphrase_ack, - identity_type=None, - required_version='>=1.3.4', - call_exception=None) - assert client_wrapper.connection is client_type.return_value - assert client_wrapper.device_name == 'DEVICE_NAME' - client_wrapper.connection.callback_PassphraseRequest('MESSAGE') - assert passphrase_ack.mock_calls == [mock.call(passphrase='')] - - -def test_load_old(): - hid_transport = mock.Mock(spec_set=['enumerate']) - hid_transport.enumerate.return_value = [0] - for version in ('1.3.3', '1.2.5', '1.1.0', '0.9.9'): - with pytest.raises(ValueError): - factory_load_client( - name='DEVICE_NAME', - client_type=create_client_type(version), - hid_transport=hid_transport, - passphrase_ack=None, - identity_type=None, - required_version='>=1.3.4', - call_exception=None) diff --git a/trezor_agent/tests/test_protocol.py b/trezor_agent/tests/test_protocol.py index 541fecb..17a1001 100644 --- a/trezor_agent/tests/test_protocol.py +++ b/trezor_agent/tests/test_protocol.py @@ -28,8 +28,7 @@ def test_unsupported(): assert reply == b'\x00\x00\x00\x01\x05' -def ecdsa_signer(label, blob): - assert label == 'ssh://localhost' +def ecdsa_signer(blob): assert blob == NIST256_BLOB return NIST256_SIG @@ -49,8 +48,7 @@ def test_sign_missing(): def test_sign_wrong(): - def wrong_signature(label, blob): - assert label == 'ssh://localhost' + def wrong_signature(blob): assert blob == NIST256_BLOB return b'\x00' * 64 @@ -62,7 +60,7 @@ def test_sign_wrong(): def test_sign_cancel(): - def cancel_signature(label, blob): # pylint: disable=unused-argument + def cancel_signature(blob): # pylint: disable=unused-argument raise IOError() key = formats.import_public_key(NIST256_KEY) @@ -79,8 +77,7 @@ ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\x ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3.*)://)?', - r'(?:(?P.*)@)?', - r'(?P.*?)', - r'(?::(?P\w*))?', - r'(?P/.*)?', - '$' -])) - - -def string_to_identity(s, identity_type): - """Parse string into Identity protobuf.""" - m = _identity_regexp.match(s) - result = m.groupdict() - log.debug('parsed identity: %s', result) - kwargs = {k: v for k, v in result.items() if v} - return identity_type(**kwargs) - - -def identity_to_string(identity): - """Dump Identity protobuf into its string representation.""" - result = [] - if identity.proto: - result.append(identity.proto + '://') - if identity.user: - result.append(identity.user + '@') - result.append(identity.host) - if identity.port: - result.append(':' + identity.port) - if identity.path: - result.append(identity.path) - return ''.join(result) - - -def get_bip32_address(identity, ecdh=False): - """Compute BIP32 derivation address according to SLIP-0013/0017.""" - index = struct.pack('