diff --git a/trezorlib/client.py b/trezorlib/client.py index aa6bcf9..3964c78 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -88,6 +88,18 @@ class expect(object): return ret return wrapped_f +def session(f): + # Decorator wraps a BaseClient method + # with session activation / deactivation + def wrapped_f(*args, **kwargs): + client = args[0] + try: + client.transport.session_begin() + return f(*args, **kwargs) + finally: + client.transport.session_end() + return wrapped_f + def normalize_nfc(txt): if sys.version_info[0] < 3: if isinstance(txt, unicode): @@ -112,33 +124,22 @@ class BaseClient(object): def cancel(self): self.transport.write(proto.Cancel()) + @session def call_raw(self, msg): - try: - self.transport.session_begin() - self.transport.write(msg) - resp = self.transport.read_blocking() - finally: - self.transport.session_end() - - return resp + self.transport.write(msg) + return self.transport.read_blocking() + @session def call(self, msg): - try: - self.transport.session_begin() - - resp = self.call_raw(msg) - handler_name = "callback_%s" % resp.__class__.__name__ - handler = getattr(self, handler_name, None) + resp = self.call_raw(msg) + handler_name = "callback_%s" % resp.__class__.__name__ + handler = getattr(self, handler_name, None) - if handler != None: - msg = handler(resp) - if msg == None: - raise Exception("Callback %s must return protobuf message, not None" % handler) - - resp = self.call(msg) - - finally: - self.transport.session_end() + if handler != None: + msg = handler(resp) + if msg == None: + raise Exception("Callback %s must return protobuf message, not None" % handler) + resp = self.call(msg) return resp @@ -423,6 +424,7 @@ class ProtocolMixin(object): n = self._convert_prime(n) return self.call(proto.EthereumGetAddress(address_n=n, show_display=show_display)) + @session def ethereum_sign_tx(self, n, nonce, gas_price, gas_limit, to, value, data=None): def int_to_big_endian(value): import rlp.utils @@ -432,35 +434,30 @@ class ProtocolMixin(object): n = self._convert_prime(n) - try: - self.transport.session_begin() - - msg = proto.EthereumSignTx( - address_n=n, - nonce=int_to_big_endian(nonce), - gas_price=int_to_big_endian(gas_price), - gas_limit=int_to_big_endian(gas_limit), - value=int_to_big_endian(value)) + msg = proto.EthereumSignTx( + address_n=n, + nonce=int_to_big_endian(nonce), + gas_price=int_to_big_endian(gas_price), + gas_limit=int_to_big_endian(gas_limit), + value=int_to_big_endian(value)) - if to: - msg.to = to + if to: + msg.to = to - if data: - msg.data_length = len(data) - data, chunk = data[1024:], data[:1024] - msg.data_initial_chunk = chunk + if data: + msg.data_length = len(data) + data, chunk = data[1024:], data[:1024] + msg.data_initial_chunk = chunk - response = self.call(msg) + response = self.call(msg) - while response.HasField('data_length'): - data_length = response.data_length - data, chunk = data[data_length:], data[:data_length] - response = self.call(proto.EthereumTxAck(data_chunk=chunk)) + while response.HasField('data_length'): + data_length = response.data_length + data, chunk = data[data_length:], data[:data_length] + response = self.call(proto.EthereumTxAck(data_chunk=chunk)) - return response.signature_v, response.signature_r, response.signature_s + return response.signature_v, response.signature_r, response.signature_s - finally: - self.transport.session_end() @field('entropy') @expect(proto.Entropy) @@ -634,88 +631,83 @@ class ProtocolMixin(object): return txes + @session def sign_tx(self, coin_name, inputs, outputs, debug_processor=None): start = time.time() txes = self._prepare_sign_tx(coin_name, inputs, outputs) - try: - self.transport.session_begin() - - # Prepare and send initial message - tx = proto.SignTx() - tx.inputs_count = len(inputs) - tx.outputs_count = len(outputs) - tx.coin_name = coin_name - res = self.call(tx) - - # Prepare structure for signatures - signatures = [None] * len(inputs) - serialized_tx = b'' - - counter = 0 - while True: - counter += 1 - - if isinstance(res, proto.Failure): - raise CallException("Signing failed") - - if not isinstance(res, proto.TxRequest): - raise CallException("Unexpected message") - - # If there's some part of signed transaction, let's add it - if res.HasField('serialized') and res.serialized.HasField('serialized_tx'): - log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx)) - serialized_tx += res.serialized.serialized_tx - - if res.HasField('serialized') and res.serialized.HasField('signature_index'): - if signatures[res.serialized.signature_index] != None: - raise Exception("Signature for index %d already filled" % res.serialized.signature_index) - signatures[res.serialized.signature_index] = res.serialized.signature - - if res.request_type == types.TXFINISHED: - # Device didn't ask for more information, finish workflow - break - - # Device asked for one more information, let's process it. - current_tx = txes[res.details.tx_hash] - - if res.request_type == types.TXMETA: - msg = types.TransactionType() - msg.version = current_tx.version - msg.lock_time = current_tx.lock_time - msg.inputs_cnt = len(current_tx.inputs) - if res.details.tx_hash: - msg.outputs_cnt = len(current_tx.bin_outputs) - else: - msg.outputs_cnt = len(current_tx.outputs) - res = self.call(proto.TxAck(tx=msg)) - continue - - elif res.request_type == types.TXINPUT: - msg = types.TransactionType() - msg.inputs.extend([current_tx.inputs[res.details.request_index], ]) - res = self.call(proto.TxAck(tx=msg)) - continue - - elif res.request_type == types.TXOUTPUT: - msg = types.TransactionType() - if res.details.tx_hash: - msg.bin_outputs.extend([current_tx.bin_outputs[res.details.request_index], ]) - else: - msg.outputs.extend([current_tx.outputs[res.details.request_index], ]) - - if debug_processor != None: - # If debug_processor function is provided, - # pass thru it the request and prepared response. - # This is useful for unit tests, see test_msg_signtx - msg = debug_processor(res, msg) - - res = self.call(proto.TxAck(tx=msg)) - continue + # Prepare and send initial message + tx = proto.SignTx() + tx.inputs_count = len(inputs) + tx.outputs_count = len(outputs) + tx.coin_name = coin_name + res = self.call(tx) + + # Prepare structure for signatures + signatures = [None] * len(inputs) + serialized_tx = b'' + + counter = 0 + while True: + counter += 1 + + if isinstance(res, proto.Failure): + raise CallException("Signing failed") + + if not isinstance(res, proto.TxRequest): + raise CallException("Unexpected message") + + # If there's some part of signed transaction, let's add it + if res.HasField('serialized') and res.serialized.HasField('serialized_tx'): + log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx)) + serialized_tx += res.serialized.serialized_tx + + if res.HasField('serialized') and res.serialized.HasField('signature_index'): + if signatures[res.serialized.signature_index] != None: + raise Exception("Signature for index %d already filled" % res.serialized.signature_index) + signatures[res.serialized.signature_index] = res.serialized.signature + + if res.request_type == types.TXFINISHED: + # Device didn't ask for more information, finish workflow + break + + # Device asked for one more information, let's process it. + current_tx = txes[res.details.tx_hash] + + if res.request_type == types.TXMETA: + msg = types.TransactionType() + msg.version = current_tx.version + msg.lock_time = current_tx.lock_time + msg.inputs_cnt = len(current_tx.inputs) + if res.details.tx_hash: + msg.outputs_cnt = len(current_tx.bin_outputs) + else: + msg.outputs_cnt = len(current_tx.outputs) + res = self.call(proto.TxAck(tx=msg)) + continue - finally: - self.transport.session_end() + elif res.request_type == types.TXINPUT: + msg = types.TransactionType() + msg.inputs.extend([current_tx.inputs[res.details.request_index], ]) + res = self.call(proto.TxAck(tx=msg)) + continue + + elif res.request_type == types.TXOUTPUT: + msg = types.TransactionType() + if res.details.tx_hash: + msg.bin_outputs.extend([current_tx.bin_outputs[res.details.request_index], ]) + else: + msg.outputs.extend([current_tx.outputs[res.details.request_index], ]) + + if debug_processor != None: + # If debug_processor function is provided, + # pass thru it the request and prepared response. + # This is useful for unit tests, see test_msg_signtx + msg = debug_processor(res, msg) + + res = self.call(proto.TxAck(tx=msg)) + continue if None in signatures: raise Exception("Some signatures are missing!") @@ -753,6 +745,7 @@ class ProtocolMixin(object): @field('message') @expect(proto.Success) + @session def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language): if self.features.initialized: raise Exception("Device is initialized already. Call wipe_device() and try again.") @@ -843,6 +836,7 @@ class ProtocolMixin(object): self.init_device() return resp + @session def firmware_update(self, fp): if self.features.bootloader_mode == False: raise Exception("Device must be in bootloader mode") diff --git a/trezorlib/transport.py b/trezorlib/transport.py index ecc19be..619049f 100644 --- a/trezorlib/transport.py +++ b/trezorlib/transport.py @@ -71,9 +71,10 @@ class Transport(object): def _parse_message(self, data): (session_id, msg_type, data) = data - # Raise exception if we get the response with - # unexpected session ID - self._check_session_id(session_id) + # Raise exception if we get the response with unexpected session ID + if session_id != self.session_id: + raise Exception("Session ID mismatch. Have %d, got %d" % + (self.session_id, session_id)) if msg_type == 'protobuf': return data @@ -82,14 +83,6 @@ class Transport(object): inst.ParseFromString(bytes(data)) return inst - def _check_session_id(self, session_id): - if self.session_id == 0: - # Let the device set the session ID - self.session_id = session_id - elif session_id != self.session_id: - # Session ID has been already set, but it differs from response - raise Exception("Session ID mismatch. Have %d, got %d" % (self.session_id, session_id)) - # Functions to be implemented in specific transports: def _open(self): raise NotImplementedException("Not implemented") @@ -237,6 +230,28 @@ class TransportV2(Transport): data = chunk[1 + headerlen:] return (session_id, data) + def parse_session(self, chunk): + if chunk[0:1] != b"!": + raise Exception("Unexpected magic character") + + try: + headerlen = struct.calcsize(">LL") + (null_session_id, new_session_id) = struct.unpack( + ">LL", bytes(chunk[1:1 + headerlen])) + except: + raise Exception("Cannot parse header") + + if null_session_id != 0: + raise Exception("Session response needs to use session ID 0") + return new_session_id + + def _session_begin(self): + self._write_chunk(b'!' + b'\0' * 63) + self.session_id = self.parse_session(self._read_chunk()) + + def _session_end(self): + pass + ''' def read_headers(self, read_f): c = read_f.read(2)