You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Comrad/p2p/kademlia/protocol.py

180 lines
6.3 KiB
Python

import random
import asyncio
import logging
from rpcudp.protocol import RPCProtocol
from kademlia.node import Node
from kademlia.routing import RoutingTable
from kademlia.utils import digest
# Module-level logger.
log = logging.getLogger(__name__) # pylint: disable=invalid-name
#### PROXY PROTOCOL
# UDP relay: each client address gets its own outbound endpoint connected to
# one fixed remote address, and replies are sent back to that client.
class ProxyDatagramProtocol(asyncio.DatagramProtocol):
    """UDP proxy relaying client datagrams to a fixed remote address.

    For every new client address a ``RemoteDatagramProtocol`` is created,
    together with its own outbound endpoint connected to ``remote_address``.
    Replies arriving on that endpoint are sent back to the client through
    this protocol's transport.
    """

    def __init__(self, remote_address):
        self.remote_address = remote_address
        # client address -> RemoteDatagramProtocol relaying for that client
        self.remotes = {}
        super().__init__()

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        remote = self.remotes.get(addr)
        if remote is not None:
            # The outbound endpoint may still be connecting; ``send`` queues
            # the datagram until its transport exists.
            remote.send(data)
            return
        loop = asyncio.get_event_loop()
        remote = RemoteDatagramProtocol(self, addr, data)
        self.remotes[addr] = remote
        coro = loop.create_datagram_endpoint(
            lambda: remote, remote_addr=self.remote_address)
        future = asyncio.ensure_future(coro)
        future.add_done_callback(
            lambda fut: self._endpoint_created(fut, addr, remote))

    def _endpoint_created(self, future, addr, remote):
        """Forget ``remote`` if its outbound endpoint could not be opened.

        Without this, a failed endpoint would stay in ``remotes`` forever and
        every later datagram from ``addr`` would be queued and never sent.
        """
        if future.cancelled():
            exc = None
        else:
            exc = future.exception()  # also marks the exception as retrieved
            if exc is None:
                return
            logging.getLogger(__name__).warning(
                "could not open UDP endpoint to %s for client %s: %r",
                self.remote_address, addr, exc)
        # Only drop the entry if it still refers to this attempt.
        if self.remotes.get(addr) is remote:
            del self.remotes[addr]


class RemoteDatagramProtocol(asyncio.DatagramProtocol):
    """Outbound side of ``ProxyDatagramProtocol`` for a single client.

    Sends the client's first datagram (``data``) once the endpoint is
    connected, then flushes any datagrams queued in the meantime. Replies
    are relayed back to ``addr`` via the proxy's transport.
    """

    def __init__(self, proxy, addr, data):
        self.proxy = proxy
        self.addr = addr
        self.data = data
        # None until connection_made; datagrams sent before then are queued.
        self.transport = None
        self._pending = []
        super().__init__()

    def send(self, data):
        """Send ``data`` to the remote, queueing it if not yet connected."""
        if self.transport is None:
            self._pending.append(data)
        else:
            self.transport.sendto(data)

    def connection_made(self, transport):
        self.transport = transport
        self.transport.sendto(self.data)
        pending, self._pending = self._pending, []
        for datagram in pending:
            transport.sendto(datagram)

    def datagram_received(self, data, _):
        self.proxy.transport.sendto(data, self.addr)

    def connection_lost(self, exc):
        # Remove our own entry only; a replacement created for the same
        # client address in the meantime is left intact.
        if self.proxy.remotes.get(self.addr) is self:
            del self.proxy.remotes[self.addr]
##### KADEMLIA PROTOCOL
# NOTE(review): `logging` is already imported and `log` is already bound to the
# same logger near the top of this module; these two lines are redundant.
import logging
log = logging.getLogger(__name__) # pylint: disable=invalid-name
class KademliaProtocol(RPCProtocol):
    """Kademlia node RPCs on top of rpcudp's ``RPCProtocol``.

    ``rpc_*`` methods answer incoming requests. They are presumably
    dispatched by ``RPCProtocol`` on the ``rpc_`` prefix (confirm against
    rpcudp). The ``call_*`` coroutines send the matching outgoing requests
    and feed the outcome into the routing table via
    ``handle_call_response``.
    """

    def __init__(self, source_node, storage, ksize, log=None):
        """
        @param source_node: Node for this server; its id is sent with every
            outgoing request and returned from ``rpc_ping``.
        @param storage: dict-like key -> value store served to peers.
        @param ksize: bucket size handed to ``RoutingTable``.
        @param log: optional callable taking one preformatted message;
            defaults to this module logger's ``debug`` method.
        """
        RPCProtocol.__init__(self)
        self.router = RoutingTable(self, ksize, source_node)
        self.storage = storage
        self.source_node = source_node
        # The ``log`` parameter shadows the module-level logger, so the
        # default is fetched explicitly (``log.debug`` here would be
        # ``None.debug`` and raise AttributeError).
        if log is None:
            log = logging.getLogger(__name__).debug
        self.log = log

    def get_refresh_ids(self):
        """
        Get ids to search for to keep old buckets up to date.

        Returns one random id per bucket from ``router.lonely_buckets()``,
        drawn from ``bucket.range`` and encoded as 20 big-endian bytes.
        """
        return [
            random.randint(*bucket.range).to_bytes(20, byteorder='big')
            for bucket in self.router.lonely_buckets()
        ]

    def rpc_stun(self, sender):  # pylint: disable=no-self-use
        """Echo the caller's address back to it."""
        return sender

    def rpc_ping(self, sender, nodeid):
        """Register the caller if new and answer with this node's id."""
        source = Node(nodeid, sender[0], sender[1])
        self.welcome_if_new(source)
        return self.source_node.id

    def rpc_store(self, sender, nodeid, key, value):
        """Register the caller if new, then store ``value`` under ``key``."""
        source = Node(nodeid, sender[0], sender[1])
        self.welcome_if_new(source)
        self.log("got a store request from %s, storing '%s' -> %s" %
                 (sender, key.hex(), value[:10]))
        self.storage[key] = value
        return True

    def rpc_find_node(self, sender, nodeid, key):
        """Return the caller-excluded neighbours of ``key`` as tuples."""
        self.log("finding neighbors of %i in local table" %
                 int(nodeid.hex(), 16))
        source = Node(nodeid, sender[0], sender[1])
        self.welcome_if_new(source)
        node = Node(key)
        neighbors = self.router.find_neighbors(node, exclude=source)
        return list(map(tuple, neighbors))

    def rpc_find_value(self, sender, nodeid, key):
        """Return ``{'value': ...}`` if stored locally, else neighbours."""
        source = Node(nodeid, sender[0], sender[1])
        self.welcome_if_new(source)
        value = self.storage.get(key, None)
        if value is None:
            return self.rpc_find_node(sender, nodeid, key)
        return {'value': value}

    # ``find_node`` / ``find_value`` / ``ping`` / ``store`` are not defined
    # in this class; presumably rpcudp's RPCProtocol turns them into remote
    # calls to the matching ``rpc_*`` handler on the peer (confirm).

    async def call_find_node(self, node_to_ask, node_to_find):
        address = (node_to_ask.ip, node_to_ask.port)
        result = await self.find_node(address, self.source_node.id,
                                      node_to_find.id)
        return self.handle_call_response(result, node_to_ask)

    async def call_find_value(self, node_to_ask, node_to_find):
        address = (node_to_ask.ip, node_to_ask.port)
        result = await self.find_value(address, self.source_node.id,
                                       node_to_find.id)
        return self.handle_call_response(result, node_to_ask)

    async def call_ping(self, node_to_ask):
        address = (node_to_ask.ip, node_to_ask.port)
        result = await self.ping(address, self.source_node.id)
        return self.handle_call_response(result, node_to_ask)

    async def call_store(self, node_to_ask, key, value):
        address = (node_to_ask.ip, node_to_ask.port)
        result = await self.store(address, self.source_node.id, key, value)
        return self.handle_call_response(result, node_to_ask)

    def welcome_if_new(self, node):
        """
        Given a new node, send it all the keys/values it should be storing,
        then add it to the routing table.

        @param node: A new node that just joined (or that we just found out
                     about).

        Process:
        For each key in storage, get k closest nodes.  If newnode is closer
        than the furthest in that list, and the node for this server
        is closer than the closest in that list, then store the key/value
        on the new node (per section 2.5 of the paper)
        """
        if not self.router.is_new_node(node):
            return

        self.log("never seen %s before, adding to router" % node)
        for key in self.storage.keys():
            value = self.storage[key]
            # NOTE(review): if storage keys are already digests, this hashes
            # them a second time -- confirm against the callers of store.
            keynode = Node(digest(key))
            neighbors = self.router.find_neighbors(keynode)
            if neighbors:
                last = neighbors[-1].distance_to(keynode)
                new_node_close = node.distance_to(keynode) < last
                first = neighbors[0].distance_to(keynode)
                this_closest = self.source_node.distance_to(keynode) < first
                if not (new_node_close and this_closest):
                    continue
            # Fire-and-forget: the store RPC runs in the background.
            asyncio.ensure_future(self.call_store(node, key, value))
        self.router.add_contact(node)

    def handle_call_response(self, result, node):
        """
        If we get a response, add the node to the routing table.  If
        we get no response, make sure it's removed from the routing table.

        ``result[0]`` is treated as the "did the peer respond" flag; the
        result is returned unchanged in both cases.
        """
        if not result[0]:
            # Preformat: ``self.log`` may be a one-argument callable.
            self.log("!! no response from %s, removing from router" % node)
            self.router.remove_contact(node)
            return result

        self.log("got successful response from %s" % node)
        self.welcome_if_new(node)
        return result