diff --git a/include/lokinet/lokinet_udp.h b/include/lokinet/lokinet_udp.h index 99d7cc83a..7e5b6ca2b 100644 --- a/include/lokinet/lokinet_udp.h +++ b/include/lokinet/lokinet_udp.h @@ -13,7 +13,7 @@ extern "C" /// remote endpoint's .loki or .snode address char remote_host[256]; /// remote endpont's port - int remote_port; + uint16_t remote_port; /// the socket id for this flow used for i/o purposes and closing this socket int socket_id; }; @@ -32,10 +32,14 @@ extern "C" void** /* flow-userdata */, int* /* timeout seconds */); + /// callback to make a new outbound flow + typedef void(lokinet_udp_create_flow_func)( + void* /*userdata*/, void** /*flow userdata*/, int* /* flowtimeout */); + /// hook function for handling packets typedef void (*lokinet_udp_flow_recv_func)( const struct lokinet_udp_flowinfo* /* remote address */, - char* /* data pointer */, + const char* /* data pointer */, size_t /* data length */, void* /* flow-userdata */); @@ -59,7 +63,7 @@ extern "C" /// @returns nonzero on error in which it is an errno value int EXPORT lokinet_udp_bind( - int exposedPort, + uint16_t exposedPort, lokinet_udp_flow_filter filter, lokinet_udp_flow_recv_func recv, lokinet_udp_flow_timeout_func timeout, @@ -69,13 +73,21 @@ extern "C" /// @brief establish a udp flow to remote endpoint /// + /// @param create_flow the callback to create the new flow if we establish one + /// + /// @param user passed to new_flow as user data + /// /// @param remote the remote address to establish to /// /// @param ctx the lokinet context to use /// /// @return 0 on success, non zero errno on fail int EXPORT - lokinet_udp_establish(const struct lokinet_udp_flowinfo* remote, struct lokinet_context* ctx); + lokinet_udp_establish( + lokinet_udp_create_flow_func create_flow, + void* user, + const struct lokinet_udp_flowinfo* remote, + struct lokinet_context* ctx); /// @brief send on an established flow to remote endpoint /// diff --git a/llarp/CMakeLists.txt b/llarp/CMakeLists.txt index 05a60bbc3..2ccc1712c 100644 --- a/llarp/CMakeLists.txt +++ b/llarp/CMakeLists.txt @@ -54,6 +54,7 @@ add_library(lokinet-platform net/net_int.cpp net/sock_addr.cpp vpn/packet_router.cpp + vpn/egres_packet_router.cpp vpn/platform.cpp ) diff --git a/llarp/handlers/null.hpp b/llarp/handlers/null.hpp index e8f209dc6..174afb60d 100644 --- a/llarp/handlers/null.hpp +++ b/llarp/handlers/null.hpp @@ -5,87 +5,114 @@ #include #include #include +#include -namespace llarp +namespace llarp::handlers { - namespace handlers + struct NullEndpoint final : public llarp::service::Endpoint, + public std::enable_shared_from_this { - struct NullEndpoint final : public llarp::service::Endpoint, - public std::enable_shared_from_this + NullEndpoint(AbstractRouter* r, llarp::service::Context* parent) + : llarp::service::Endpoint{r, parent} + , m_PacketRouter{new vpn::EgresPacketRouter{[](auto, auto) {}}} { - NullEndpoint(AbstractRouter* r, llarp::service::Context* parent) - : llarp::service::Endpoint(r, parent) + r->loop()->add_ticker([this] { Pump(Now()); }); + } + + virtual bool + HandleInboundPacket( + const service::ConvoTag tag, + const llarp_buffer_t& buf, + service::ProtocolType t, + uint64_t) override + { + LogTrace("Inbound ", t, " packet (", buf.sz, "B) on convo ", tag); + if (t == service::ProtocolType::Control) { - r->loop()->add_ticker([this] { Pump(Now()); }); + return true; } - - virtual bool - HandleInboundPacket( - const service::ConvoTag tag, - const llarp_buffer_t& buf, - service::ProtocolType t, - uint64_t) override + if (t == service::ProtocolType::TrafficV4 or t == service::ProtocolType::TrafficV6) { - LogTrace("Inbound ", t, " packet (", buf.sz, "B) on convo ", tag); - if (t == service::ProtocolType::Control) + if (auto from = GetEndpointWithConvoTag(tag)) { + net::IPPacket pkt{}; + if (not pkt.Load(buf)) + { + LogWarn("invalid ip packet from remote T=", tag); + return false; + } + m_PacketRouter->HandleIPPacketFrom(std::move(*from), std::move(pkt)); return true; } - if (t != service::ProtocolType::QUIC) - return false; - - auto* quic = GetQUICTunnel(); - if (!quic) - { - LogWarn("incoming quic packet but this endpoint is not quic capable; dropping"); - return false; - } - if (buf.sz < 4) + else { - LogWarn("invalid incoming quic packet, dropping"); + LogWarn("did not handle packet, no endpoint with convotag T=", tag); return false; } - quic->receive_packet(tag, buf); - return true; } + if (t != service::ProtocolType::QUIC) + return false; - std::string - GetIfName() const override + auto* quic = GetQUICTunnel(); + if (!quic) { - return ""; + LogWarn("incoming quic packet but this endpoint is not quic capable; dropping"); + return false; } - - path::PathSet_ptr - GetSelf() override + if (buf.sz < 4) { - return shared_from_this(); + LogWarn("invalid incoming quic packet, dropping"); + return false; } + quic->receive_packet(tag, buf); + return true; + } - std::weak_ptr - GetWeak() override - { - return weak_from_this(); - } + std::string + GetIfName() const override + { + return ""; + } - bool - SupportsV6() const override - { - return false; - } + path::PathSet_ptr + GetSelf() override + { + return shared_from_this(); + } - void - SendPacketToRemote(const llarp_buffer_t&, service::ProtocolType) override{}; + std::weak_ptr + GetWeak() override + { + return weak_from_this(); + } - huint128_t ObtainIPForAddr(std::variant) override - { - return {0}; - } + bool + SupportsV6() const override + { + return false; + } - std::optional> ObtainAddrForIP( - huint128_t) const override - { - return std::nullopt; - } - }; - } // namespace handlers -} // namespace llarp + void + SendPacketToRemote(const llarp_buffer_t&, service::ProtocolType) override{}; + + huint128_t ObtainIPForAddr(std::variant) override + { + return {0}; + } + + std::optional> ObtainAddrForIP( + huint128_t) const override + { + return std::nullopt; + } + + vpn::EgresPacketRouter* + EgresPacketRouter() override + { + return m_PacketRouter.get(); + } + + private: + std::unique_ptr m_PacketRouter; + }; +} // namespace llarp::handlers diff --git a/llarp/handlers/tun.hpp b/llarp/handlers/tun.hpp index ee7a64b5f..37fe9f37d 100644 --- a/llarp/handlers/tun.hpp +++ b/llarp/handlers/tun.hpp @@ -15,7 +15,7 @@ #include #include #include -#include "service/protocol_type.hpp" +#include namespace llarp { diff --git a/llarp/lokinet_shared.cpp b/llarp/lokinet_shared.cpp index 3020889e2..982b1eef9 100644 --- a/llarp/lokinet_shared.cpp +++ b/llarp/lokinet_shared.cpp @@ -1,7 +1,5 @@ - - -#include "lokinet.h" -#include "llarp.hpp" +#include +#include #include #include @@ -15,6 +13,8 @@ #include #include +#include +#include #ifdef _WIN32 #define EHOSTDOWN ENETDOWN @@ -32,6 +32,165 @@ namespace return std::make_shared(); } }; + + struct UDPFlow + { + using Clock_t = std::chrono::steady_clock; + void* m_FlowUserData; + std::chrono::seconds m_FlowTimeout; + std::chrono::time_point m_ExpiresAt; + lokinet_udp_flowinfo m_FlowInfo; + lokinet_udp_flow_recv_func m_Recv; + + /// call timeout hook for this flow + void + TimedOut(lokinet_udp_flow_timeout_func timeout) + { + timeout(&m_FlowInfo, m_FlowUserData); + } + + /// mark this flow as active + /// updates the expires at timestamp + void + MarkActive() + { + m_ExpiresAt = Clock_t::now() + m_FlowTimeout; + } + + /// returns true if we think this flow is expired + bool + IsExpired() const + { + return Clock_t::now() >= m_ExpiresAt; + } + + void + HandlePacket(const llarp::net::IPPacket& pkt) + { + if (auto maybe = pkt.L4Data()) + { + MarkActive(); + m_Recv(&m_FlowInfo, maybe->first, maybe->second, m_FlowUserData); + } + } + }; + + struct UDPHandler + { + using AddressVariant_t = llarp::vpn::AddressVariant_t; + int m_SocketID; + llarp::nuint16_t m_LocalPort; + lokinet_udp_flow_filter m_Filter; + lokinet_udp_flow_recv_func m_Recv; + lokinet_udp_flow_timeout_func m_Timeout; + void* m_User; + std::weak_ptr m_Endpoint; + + std::unordered_map m_Flows; + + std::mutex m_Access; + + explicit UDPHandler( + int socketid, + llarp::nuint16_t localport, + lokinet_udp_flow_filter filter, + lokinet_udp_flow_recv_func recv, + lokinet_udp_flow_timeout_func timeout, + void* user, + std::weak_ptr ep) + : m_SocketID{socketid} + , m_LocalPort{localport} + , m_Filter{filter} + , m_Recv{recv} + , m_Timeout{timeout} + , m_User{user} + , m_Endpoint{ep} + {} + + void + KillAllFlows() + { + std::unique_lock lock{m_Access}; + for (auto& item : m_Flows) + { + item.second.TimedOut(m_Timeout); + } + m_Flows.clear(); + } + + void + AddFlow( + const AddressVariant_t& from, + const lokinet_udp_flowinfo& flow_addr, + void* flow_userdata, + int flow_timeoutseconds) + { + std::unique_lock lock{m_Access}; + auto& flow = m_Flows[from]; + flow.m_FlowInfo = flow_addr; + flow.m_FlowTimeout = std::chrono::seconds{flow_timeoutseconds}; + flow.m_FlowUserData = flow_userdata; + } + + void + ExpireOldFlows() + { + std::unique_lock lock{m_Access}; + for (auto itr = m_Flows.begin(); itr != m_Flows.end();) + { + if (itr->second.IsExpired()) + { + itr->second.TimedOut(m_Timeout); + itr = m_Flows.erase(itr); + } + else + ++itr; + } + } + + void + HandlePacketFrom(AddressVariant_t from, llarp::net::IPPacket pkt) + { + bool isNewFlow{false}; + { + std::unique_lock lock{m_Access}; + isNewFlow = m_Flows.count(from) == 0; + } + if (isNewFlow) + { + lokinet_udp_flowinfo flow_addr{}; + // set flow remote address + var::visit( + [&flow_addr](auto&& from) { + const auto addr = from.ToString(); + std::copy_n( + addr.data(), + std::min(addr.size(), sizeof(flow_addr.remote_host)), + flow_addr.remote_host); + }, + from); + // set socket id + flow_addr.socket_id = m_SocketID; + // get source port + if (auto srcport = pkt.SrcPort()) + { + flow_addr.remote_port = ToHost(*srcport).h; + } + else + return; // invalid data so we bail + void* flow_userdata = nullptr; + int flow_timeoutseconds{}; + // got a new flow, let's check if we want it + if (m_Filter(m_User, &flow_addr, &flow_userdata, &flow_timeoutseconds)) + return; + AddFlow(from, flow_addr, flow_userdata, flow_timeoutseconds); + } + { + std::unique_lock lock{m_Access}; + m_Flows[from].HandlePacket(pkt); + } + } + }; } // namespace struct lokinet_context @@ -43,7 +202,10 @@ struct lokinet_context std::unique_ptr runner; - lokinet_context() : impl{std::make_shared()}, config{llarp::Config::EmbeddedConfig()} + int _socket_id; + + lokinet_context() + : impl{std::make_shared()}, config{llarp::Config::EmbeddedConfig()}, _socket_id{0} {} ~lokinet_context() @@ -52,6 +214,69 @@ struct lokinet_context runner->join(); } + int + next_socket_id() + { + int id = ++_socket_id; + // handle overflow + if (id < 0) + { + _socket_id = 0; + id = ++_socket_id; + } + return id; + } + + /// make a udp handler and hold onto it + /// return its id + [[nodiscard]] int + make_udp_handler( + const std::shared_ptr& ep, + llarp::huint16_t exposePort, + lokinet_udp_flow_filter filter, + lokinet_udp_flow_recv_func recv, + lokinet_udp_flow_timeout_func timeout, + void* user) + { + if (udp_sockets.empty()) + { + // start udp flow expiration timer + impl->router->loop()->call_every(1s, std::make_shared(0), [this]() { + std::unique_lock lock{m_access}; + for (auto& item : udp_sockets) + { + item.second->ExpireOldFlows(); + } + }); + } + + auto udp = std::make_unique( + next_socket_id(), llarp::ToNet(exposePort), filter, recv, timeout, user, std::weak_ptr{ep}); + auto id = udp->m_SocketID; + auto pkt = ep->EgresPacketRouter(); + pkt->AddUDPHandler(exposePort, [udp = udp.get(), this](auto from, auto pkt) { + udp->HandlePacketFrom(std::move(from), std::move(pkt)); + }); + udp_sockets[udp->m_SocketID] = std::move(udp); + return id; + } + + void + remove_udp_handler(int socket_id) + { + std::unique_ptr udp; + { + std::unique_lock lock{m_access}; + if (auto itr = udp_sockets.find(socket_id); itr != udp_sockets.end()) + { + udp = std::move(itr->second); + udp_sockets.erase(itr); + } + } + if (udp) + udp->KillAllFlows(); + } + /// acquire mutex for accessing this context [[nodiscard]] auto acquire() @@ -66,6 +291,7 @@ struct lokinet_context } std::unordered_map streams; + std::unordered_map> udp_sockets; void inbound_stream(int id) @@ -82,8 +308,6 @@ struct lokinet_context namespace { - std::unique_ptr g_context; - void stream_error(lokinet_stream_result* result, int err) { @@ -359,11 +583,11 @@ extern "C" return; auto lock = ctx->acquire(); - if (not ctx->impl->IsStopping()) - { - ctx->impl->CloseAsync(); - ctx->impl->Wait(); - } + if (ctx->impl->IsStopping()) + return; + + ctx->impl->CloseAsync(); + ctx->impl->Wait(); if (ctx->runner) ctx->runner->join(); @@ -626,4 +850,148 @@ extern "C" delete result->internal; result->internal = nullptr; } + + int EXPORT + lokinet_udp_bind( + uint16_t exposedPort, + lokinet_udp_flow_filter filter, + lokinet_udp_flow_recv_func recv, + lokinet_udp_flow_timeout_func timeout, + void* user, + struct lokinet_udp_bind_result* result, + struct lokinet_context* ctx) + { + if (filter == nullptr or recv == nullptr or timeout == nullptr or result == nullptr + or ctx == nullptr) + return EINVAL; + + auto lock = ctx->acquire(); + if (auto ep = ctx->endpoint()) + { + result->socket_id = + ctx->make_udp_handler(ep, llarp::huint16_t{exposedPort}, filter, recv, timeout, user); + return 0; + } + else + return EINVAL; + } + + void EXPORT + lokinet_udp_close(int socket_id, struct lokinet_context* ctx) + { + if (ctx) + ctx->remove_udp_handler(socket_id); + } + + int EXPORT + lokinet_udp_flow_send( + const struct lokinet_udp_flowinfo* remote, + const void* ptr, + size_t len, + struct lokinet_context* ctx) + { + if (remote == nullptr or remote->remote_port == 0 or ptr == nullptr or len == 0 + or ctx == nullptr) + return EINVAL; + std::shared_ptr ep; + llarp::nuint16_t srcport{0}; + llarp::nuint16_t dstport{llarp::ToNet(llarp::huint16_t{remote->remote_port})}; + { + auto lock = ctx->acquire(); + if (auto itr = ctx->udp_sockets.find(remote->socket_id); itr != ctx->udp_sockets.end()) + { + ep = itr->second->m_Endpoint.lock(); + srcport = itr->second->m_LocalPort; + } + else + return EHOSTUNREACH; + } + if (auto maybe = llarp::service::ParseAddress(std::string{remote->remote_host})) + { + llarp::net::IPPacket pkt = llarp::net::IPPacket::UDP( + llarp::nuint32_t{0}, + srcport, + llarp::nuint32_t{0}, + dstport, + llarp_buffer_t{reinterpret_cast(ptr), len}); + + if (pkt.sz == 0) + return EINVAL; + std::promise ret; + ctx->impl->router->loop()->call_soon([addr = *maybe, pkt = std::move(pkt), ep, &ret]() { + if (auto tag = ep->GetBestConvoTagFor(addr)) + { + if (ep->SendToOrQueue(*tag, pkt.ConstBuffer(), llarp::service::ProtocolType::TrafficV4)) + { + ret.set_value(0); + return; + } + } + ret.set_value(ENETUNREACH); + }); + return ret.get_future().get(); + } + return EINVAL; + } + + int EXPORT + lokinet_udp_establish( + lokinet_udp_create_flow_func create_flow, + void* user, + const struct lokinet_udp_flowinfo* remote, + struct lokinet_context* ctx) + { + if (create_flow == nullptr or remote == nullptr or ctx == nullptr) + return EINVAL; + std::shared_ptr ep; + { + auto lock = ctx->acquire(); + if (auto itr = ctx->udp_sockets.find(remote->socket_id); itr != ctx->udp_sockets.end()) + { + ep = itr->second->m_Endpoint.lock(); + } + else + return EHOSTUNREACH; + } + if (auto maybe = llarp::service::ParseAddress(std::string{remote->remote_host})) + { + { + // check for pre existing flow + auto lock = ctx->acquire(); + if (auto itr = ctx->udp_sockets.find(remote->socket_id); itr != ctx->udp_sockets.end()) + { + auto& udp = itr->second; + if (udp->m_Flows.count(*maybe)) + { + // we already have a flow. + return EADDRINUSE; + } + } + } + std::promise gotten; + ctx->impl->router->loop()->call_soon([addr = *maybe, ep, &gotten]() { + ep->EnsurePathTo( + addr, [&gotten](auto result) { gotten.set_value(result.has_value()); }, 5s); + }); + if (gotten.get_future().get()) + { + void* flow_data{nullptr}; + int flow_timeoutseconds{}; + create_flow(user, &flow_data, &flow_timeoutseconds); + { + auto lock = ctx->acquire(); + if (auto itr = ctx->udp_sockets.find(remote->socket_id); itr != ctx->udp_sockets.end()) + { + itr->second->AddFlow(*maybe, *remote, flow_data, flow_timeoutseconds); + return 0; + } + else + return EADDRINUSE; + } + } + else + return ETIMEDOUT; + } + return EINVAL; + } } diff --git a/llarp/net/ip_packet.cpp b/llarp/net/ip_packet.cpp index c31c2332d..b1deaaeea 100644 --- a/llarp/net/ip_packet.cpp +++ b/llarp/net/ip_packet.cpp @@ -128,6 +128,19 @@ namespace llarp } } + std::optional + IPPacket::SrctPort() const + { + switch (IPProtocol{Header()->protocol}) + { + case IPProtocol::TCP: + case IPProtocol::UDP: + return nuint16_t{*reinterpret_cast(buf + (Header()->ihl * 4))}; + default: + return std::nullopt; + } + } + huint32_t IPPacket::srcv4() const { @@ -571,6 +584,26 @@ namespace llarp return std::nullopt; } + std::optional> + IPPacket::L4Data() const + { + const auto* hdr = Header(); + size_t l4_HeaderSize = 0; + if (hdr->protocol == 0x11) + { + l4_HeaderSize = 8; + } + else + return std::nullopt; + + // check for invalid size + if (sz < (hdr->ihl * 4) + l4_HeaderSize) + return std::nullopt; + + const uint8_t* ptr = buf + ((hdr->ihl * 4) + l4_HeaderSize); + return std::make_pair(reinterpret_cast(ptr), std::distance(ptr, buf + sz)); + } + IPPacket IPPacket::UDP( nuint32_t srcaddr, diff --git a/llarp/net/ip_packet.hpp b/llarp/net/ip_packet.hpp index 9f4b91fb6..0987633aa 100644 --- a/llarp/net/ip_packet.hpp +++ b/llarp/net/ip_packet.hpp @@ -293,6 +293,14 @@ namespace llarp std::optional DstPort() const; + /// get source port if applicable + std::optional + SrcPort() const; + + /// get pointer and size of layer 4 data + std::optional> + L4Data() const; + void UpdateIPv4Address(nuint32_t src, nuint32_t dst); diff --git a/llarp/service/endpoint.hpp b/llarp/service/endpoint.hpp index b15add487..48308a525 100644 --- a/llarp/service/endpoint.hpp +++ b/llarp/service/endpoint.hpp @@ -28,6 +28,8 @@ #include +#include + // minimum time between introset shifts #ifndef MIN_SHIFT_INTERVAL #define MIN_SHIFT_INTERVAL 5s @@ -168,6 +170,12 @@ namespace llarp void HandlePathDied(path::Path_ptr p) override; + virtual vpn::EgresPacketRouter* + EgresPacketRouter() + { + return nullptr; + }; + bool PublishIntroSet(const EncryptedIntroSet& i, AbstractRouter* r) override; diff --git a/llarp/vpn/egres_packet_router.cpp b/llarp/vpn/egres_packet_router.cpp new file mode 100644 index 000000000..f892f0832 --- /dev/null +++ b/llarp/vpn/egres_packet_router.cpp @@ -0,0 +1,101 @@ +#include "egres_packet_router.hpp" + +namespace llarp::vpn +{ + struct EgresUDPPacketHandler : public EgresLayer4Handler + { + EgresPacketHandlerFunc m_BaseHandler; + std::unordered_map m_LocalPorts; + + explicit EgresUDPPacketHandler(EgresPacketHandlerFunc baseHandler) + : m_BaseHandler{std::move(baseHandler)} + {} + + void + AddSubHandler(nuint16_t localport, EgresPacketHandlerFunc handler) override + { + m_LocalPorts.emplace(localport, std::move(handler)); + } + + void + RemoveSubHandler(nuint16_t localport) override + { + m_LocalPorts.erase(localport); + } + + void + HandleIPPacketFrom(AddressVariant_t from, net::IPPacket pkt) override + { + const uint8_t* ptr = pkt.buf + (pkt.Header()->ihl * 4) + 2; + const nuint16_t dstPort{*reinterpret_cast(ptr)}; + if (auto itr = m_LocalPorts.find(dstPort); itr != m_LocalPorts.end()) + { + itr->second(std::move(from), std::move(pkt)); + } + else + m_BaseHandler(std::move(from), std::move(pkt)); + } + }; + + struct EgresGenericLayer4Handler : public EgresLayer4Handler + { + EgresPacketHandlerFunc m_BaseHandler; + + explicit EgresGenericLayer4Handler(EgresPacketHandlerFunc baseHandler) + : m_BaseHandler{std::move(baseHandler)} + {} + + void + HandleIPPacketFrom(AddressVariant_t from, net::IPPacket pkt) override + { + m_BaseHandler(std::move(from), std::move(pkt)); + } + }; + + EgresPacketRouter::EgresPacketRouter(EgresPacketHandlerFunc baseHandler) + : m_BaseHandler{std::move(baseHandler)} + {} + + void + EgresPacketRouter::HandleIPPacketFrom(AddressVariant_t from, net::IPPacket pkt) + { + const auto proto = pkt.Header()->protocol; + if (const auto itr = m_IPProtoHandler.find(proto); itr != m_IPProtoHandler.end()) + { + itr->second->HandleIPPacketFrom(std::move(from), std::move(pkt)); + } + else + m_BaseHandler(std::move(from), std::move(pkt)); + } + + namespace + { + constexpr byte_t udp_proto = 0x11; + } + + void + EgresPacketRouter::AddUDPHandler(huint16_t localport, EgresPacketHandlerFunc func) + { + if (m_IPProtoHandler.find(udp_proto) == m_IPProtoHandler.end()) + { + m_IPProtoHandler.emplace(udp_proto, std::make_unique(m_BaseHandler)); + } + m_IPProtoHandler[udp_proto]->AddSubHandler(ToNet(localport), func); + } + + void + EgresPacketRouter::AddIProtoHandler(uint8_t proto, EgresPacketHandlerFunc func) + { + m_IPProtoHandler[proto] = std::make_unique(std::move(func)); + } + + void + EgresPacketRouter::RemoveUDPHandler(huint16_t localport) + { + if (auto itr = m_IPProtoHandler.find(udp_proto); itr != m_IPProtoHandler.end()) + { + itr->second->RemoveSubHandler(ToNet(localport)); + } + } + +} // namespace llarp::vpn diff --git a/llarp/vpn/egres_packet_router.hpp b/llarp/vpn/egres_packet_router.hpp new file mode 100644 index 000000000..8b074267d --- /dev/null +++ b/llarp/vpn/egres_packet_router.hpp @@ -0,0 +1,49 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace llarp::vpn +{ + using AddressVariant_t = llarp::EndpointBase::AddressVariant_t; + using EgresPacketHandlerFunc = std::function; + + struct EgresLayer4Handler + { + virtual ~EgresLayer4Handler() = default; + + virtual void + HandleIPPacketFrom(AddressVariant_t from, net::IPPacket pkt) = 0; + + virtual void AddSubHandler(nuint16_t, EgresPacketHandlerFunc){}; + virtual void RemoveSubHandler(nuint16_t){}; + }; + + class EgresPacketRouter + { + EgresPacketHandlerFunc m_BaseHandler; + std::unordered_map> m_IPProtoHandler; + + public: + /// baseHandler will be called if no other handlers matches a packet + explicit EgresPacketRouter(EgresPacketHandlerFunc baseHandler); + + /// feed in an ip packet for handling + void + HandleIPPacketFrom(AddressVariant_t, net::IPPacket pkt); + + /// add a non udp packet handler using ip protocol proto + void + AddIProtoHandler(uint8_t proto, EgresPacketHandlerFunc func); + + /// helper that adds a udp packet handler for UDP destinted for localport + void + AddUDPHandler(huint16_t localport, EgresPacketHandlerFunc func); + + /// remove a udp handler that is already set up by bound port + void + RemoveUDPHandler(huint16_t localport); + }; +} // namespace llarp::vpn diff --git a/llarp/vpn/packet_router.hpp b/llarp/vpn/packet_router.hpp index e84454eae..ee0721a05 100644 --- a/llarp/vpn/packet_router.hpp +++ b/llarp/vpn/packet_router.hpp @@ -17,7 +17,6 @@ namespace llarp::vpn virtual void AddSubHandler(nuint16_t, PacketHandlerFunc){}; }; - class PacketRouter { PacketHandlerFunc m_BaseHandler; @@ -38,5 +37,9 @@ namespace llarp::vpn /// helper that adds a udp packet handler for UDP destinted for localport void AddUDPHandler(huint16_t localport, PacketHandlerFunc func); + + /// remove a udp handler that is already set up by bound port + void + RemoveUDPHandler(huint16_t localport); }; } // namespace llarp::vpn