From 8596537a54b88d91b3ceb980591c7b94de7625f4 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Tue, 16 Jun 2015 10:34:02 +0300 Subject: [PATCH] protocol: use Handler class and fix pylint warnings --- sshagent/protocol.py | 144 +++++++++++++++++++++++-------------------- sshagent/server.py | 7 ++- 2 files changed, 81 insertions(+), 70 deletions(-) diff --git a/sshagent/protocol.py b/sshagent/protocol.py index 8a40881..785b4eb 100644 --- a/sshagent/protocol.py +++ b/sshagent/protocol.py @@ -20,70 +20,80 @@ SSH2_AGENTC_REMOVE_IDENTITY = 18 SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19 -def legacy_pubs(buf, keys, signer): - code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER) - num = util.pack('L', 0) # no SSH v1 keys - return util.frame(code, num) - - -def list_pubs(buf, keys, signer): - code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER) - num = util.pack('L', len(keys)) - log.debug('available keys: %s', [k['name'] for k in keys]) - for i, k in enumerate(keys): - log.debug('%2d) %s', i+1, k['fingerprint']) - pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys] - return util.frame(code, num, *pubs) - - -def sign_message(buf, keys, signer): - key = formats.parse_pubkey(util.read_frame(buf)) - log.debug('looking for %s', key['fingerprint']) - blob = util.read_frame(buf) - - for k in keys: - if (k['fingerprint']) == (key['fingerprint']): - log.debug('using key %r (%s)', k['name'], k['fingerprint']) - key = k - break - else: - raise ValueError('key not found') - - log.debug('signing %d-byte blob', len(blob)) - r, s = signer(label=key['name'], blob=blob) - signature = (r, s) - log.debug('signature: %s', signature) - - success = key['verifying_key'].verify(signature=signature, data=blob, - sigdecode=lambda sig, _: sig) - log.info('signature status: %s', 'OK' if success else 'ERROR') - if not success: - raise ValueError('invalid signature') - - sig_bytes = io.BytesIO() - for x in signature: - sig_bytes.write(util.frame(b'\x00' + util.num2bytes(x, key['size']))) - sig_bytes = sig_bytes.getvalue() - log.debug('signature size: %d bytes', len(sig_bytes)) - - data = util.frame(util.frame(key['type']), util.frame(sig_bytes)) - code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE) - return util.frame(code, data) - - -handlers = { - SSH_AGENTC_REQUEST_RSA_IDENTITIES: legacy_pubs, - SSH2_AGENTC_REQUEST_IDENTITIES: list_pubs, - SSH2_AGENTC_SIGN_REQUEST: sign_message, -} - - -def handle_message(msg, keys, signer): - log.debug('request: %d bytes', len(msg)) - buf = io.BytesIO(msg) - code, = util.recv(buf, '>B') - handler = handlers[code] - log.debug('calling %s()', handler.__name__) - reply = handler(buf=buf, keys=keys, signer=signer) - log.debug('reply: %d bytes', len(reply)) - return reply +class Handler(object): + + def __init__(self, keys, signer): + self.public_keys = keys + self.signer = signer + + self.methods = { + SSH_AGENTC_REQUEST_RSA_IDENTITIES: Handler.legacy_pubs, + SSH2_AGENTC_REQUEST_IDENTITIES: self.list_pubs, + SSH2_AGENTC_SIGN_REQUEST: self.sign_message, + } + + def handle(self, msg): + log.debug('request: %d bytes', len(msg)) + buf = io.BytesIO(msg) + code, = util.recv(buf, '>B') + method = self.methods[code] + log.debug('calling %s()', method.__name__) + reply = method(buf=buf) + log.debug('reply: %d bytes', len(reply)) + return reply + + @staticmethod + def legacy_pubs(buf): + ''' SSH v1 public keys are not supported ''' + assert not buf.read() + code = util.pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER) + num = util.pack('L', 0) # no SSH v1 keys + return util.frame(code, num) + + def list_pubs(self, buf): + ''' SSH v2 public keys are serialized and returned. ''' + assert not buf.read() + keys = self.public_keys + code = util.pack('B', SSH2_AGENT_IDENTITIES_ANSWER) + num = util.pack('L', len(keys)) + log.debug('available keys: %s', [k['name'] for k in keys]) + for i, k in enumerate(keys): + log.debug('%2d) %s', i+1, k['fingerprint']) + pubs = [util.frame(k['blob']) + util.frame(k['name']) for k in keys] + return util.frame(code, num, *pubs) + + def sign_message(self, buf): + ''' SSH v2 public key authentication is performed. ''' + key = formats.parse_pubkey(util.read_frame(buf)) + log.debug('looking for %s', key['fingerprint']) + blob = util.read_frame(buf) + + for k in self.public_keys: + if (k['fingerprint']) == (key['fingerprint']): + log.debug('using key %r (%s)', k['name'], k['fingerprint']) + key = k + break + else: + raise ValueError('key not found') + + log.debug('signing %d-byte blob', len(blob)) + r, s = self.signer(label=key['name'], blob=blob) + signature = (r, s) + log.debug('signature: %s', signature) + + success = key['verifying_key'].verify(signature=signature, data=blob, + sigdecode=lambda sig, _: sig) + log.info('signature status: %s', 'OK' if success else 'ERROR') + if not success: + raise ValueError('invalid signature') + + sig_bytes = io.BytesIO() + for x in signature: + x_frame = util.frame(b'\x00' + util.num2bytes(x, key['size'])) + sig_bytes.write(x_frame) + sig_bytes = sig_bytes.getvalue() + log.debug('signature size: %d bytes', len(sig_bytes)) + + data = util.frame(util.frame(key['type']), util.frame(sig_bytes)) + code = util.pack('B', SSH2_AGENT_SIGN_RESPONSE) + return util.frame(code, data) diff --git a/sshagent/server.py b/sshagent/server.py index 05e6a04..488173d 100644 --- a/sshagent/server.py +++ b/sshagent/server.py @@ -31,12 +31,12 @@ def unix_domain_socket_server(sock_path): os.remove(sock_path) -def handle_connection(conn, keys, signer): +def handle_connection(conn, handler): try: log.debug('welcome agent') while True: msg = util.read_frame(conn) - reply = protocol.handle_message(msg=msg, keys=keys, signer=signer) + reply = handler.handle(msg=msg) util.send(conn, reply) except EOFError: log.debug('goodbye agent') @@ -47,6 +47,7 @@ def handle_connection(conn, keys, signer): def server_thread(server, keys, signer): log.debug('server thread started') + handler = protocol.Handler(keys=keys, signer=signer) while True: log.debug('waiting for connection on %s', server.getsockname()) try: @@ -55,7 +56,7 @@ def server_thread(server, keys, signer): log.debug('server error: %s', e, exc_info=True) break with contextlib.closing(conn): - handle_connection(conn, keys, signer) + handle_connection(conn, handler) log.debug('server thread stopped')