diff --git a/include/llarp.hpp b/include/llarp.hpp index 33fe6bfc1..0d6e5449c 100644 --- a/include/llarp.hpp +++ b/include/llarp.hpp @@ -47,9 +47,6 @@ namespace llarp virtual ~Context() = default; - void - Close(); - void Setup(const RuntimeOptions& opts); @@ -73,6 +70,9 @@ namespace llarp bool LooksAlive() const; + bool + IsStopping() const; + /// close async void CloseAsync(); @@ -111,6 +111,9 @@ namespace llarp void SigINT(); + void + Close(); + std::unique_ptr> closeWaiter; }; diff --git a/include/lokinet.h b/include/lokinet.h index 92a754ff0..9135c3f97 100644 --- a/include/lokinet.h +++ b/include/lokinet.h @@ -28,7 +28,7 @@ extern "C" lokinet_context_stop(struct lokinet_context*); /// get default lokinet context - /// does not need to be freed by lokinet_context_free + /// not to be freed by lokinet_context_free struct lokinet_context* lokinet_default(); @@ -40,35 +40,44 @@ extern "C" int errno; /// the local ip address we mapped the remote endpoint to - char* local_address; + /// null terminated + char local_address[256]; /// the local port we mapped the remote endpoint to int local_port; + /// the id of the stream we created + int stream_id; }; /// connect out to a remote endpoint - /// remote is in the form of "name:port" - /// returns NULL if context was NULL or not started - /// returns a free()-able lokinet_stream_result * that contains the result + /// remoteAddr is in the form of "name:port" + /// localAddr is either NULL for any or in the form of "ip:port" to bind to an explicit address + /// returns a lokinet_stream_result * that contains the result that can be free()'d struct lokinet_stream_result* - lokinet_outbound_stream(const char* remote, struct lokinet_context* context); + lokinet_outbound_stream( + const char* remoteAddr, const char* localAddr, struct lokinet_context* context); /// stream accept filter determines if we should accept a stream or not /// return 0 to accept /// return -1 to explicitly reject /// return -2 to silently drop - typedef int (*lokinet_stream_filter)(const char*, uint16_t, struct sockaddr* const, void*); + typedef int (*lokinet_stream_filter)(const char* remote, uint16_t port, void*); /// set stream accepter filter /// passes user parameter into stream filter as void * - void + /// returns stream id + int lokinet_inbound_stream_filter( lokinet_stream_filter acceptFilter, void* user, struct lokinet_context* context); /// simple stream acceptor /// simple variant of lokinet_inbound_stream_filter that maps port to localhost:port - void + int lokinet_inbound_stream(uint16_t port, struct lokinet_context* context); + /// close a stream by id + void + lokinet_close_stream(int stream_id, struct lokinet_context* context); + #ifdef __cplusplus } #endif diff --git a/llarp/config/config.cpp b/llarp/config/config.cpp index e9ad25752..437c25365 100644 --- a/llarp/config/config.cpp +++ b/llarp/config/config.cpp @@ -1353,4 +1353,12 @@ namespace llarp return def.generateINIConfig(true); } + std::shared_ptr + Config::EmbeddedConfig() + { + auto config = std::make_shared(fs::path{""}); + config->network.m_endpointType = "null"; + return config; + } + } // namespace llarp diff --git a/llarp/config/config.hpp b/llarp/config/config.hpp index 9d4072c7c..9c389001b 100644 --- a/llarp/config/config.hpp +++ b/llarp/config/config.hpp @@ -251,6 +251,10 @@ namespace llarp void AddDefault(std::string section, std::string key, std::string value); + /// create a config with the default parameters for an embedded lokinet + static std::shared_ptr + EmbeddedConfig(); + private: /// Load (initialize) a default config. /// diff --git a/llarp/context.cpp b/llarp/context.cpp index 5ec161b4e..5cb4aa901 100644 --- a/llarp/context.cpp +++ b/llarp/context.cpp @@ -129,13 +129,19 @@ namespace llarp Context::CloseAsync() { /// already closing - if (closeWaiter) + if (IsStopping()) return; if (CallSafe(std::bind(&Context::HandleSignal, this, SIGTERM))) closeWaiter = std::make_unique>(); } + bool + Context::IsStopping() const + { + return closeWaiter.operator bool(); + } + void Context::Wait() { diff --git a/llarp/handlers/null.hpp b/llarp/handlers/null.hpp index b038cbdfc..4f52ab1cc 100644 --- a/llarp/handlers/null.hpp +++ b/llarp/handlers/null.hpp @@ -16,36 +16,15 @@ namespace llarp NullEndpoint(AbstractRouter* r, llarp::service::Context* parent) : llarp::service::Endpoint(r, parent) { - r->loop()->add_ticker([this] { - while (not m_InboundQuic.empty()) - { - m_InboundQuic.top().process(); - m_InboundQuic.pop(); - } - Pump(Now()); - }); + r->loop()->add_ticker([this] { Pump(Now()); }); } - struct QUICEvent - { - uint64_t seqno; - std::function process; - - bool - operator<(const QUICEvent& other) const - { - return other.seqno < seqno; - } - }; - - std::priority_queue m_InboundQuic; - virtual bool HandleInboundPacket( const service::ConvoTag tag, const llarp_buffer_t& buf, service::ProtocolType t, - uint64_t seqno) override + uint64_t) override { LogTrace("Inbound ", t, " packet (", buf.sz, "B) on convo ", tag); if (t == service::ProtocolType::Control) @@ -68,11 +47,7 @@ namespace llarp return false; } MarkConvoTagActive(tag); - std::vector copy; - copy.resize(buf.sz); - std::copy_n(buf.base, buf.sz, copy.data()); - m_InboundQuic.push({seqno, [quic, buf = copy, tag]() { quic->receive_packet(tag, buf); }}); - m_router->loop()->wakeup(); + quic->receive_packet(tag, buf); return true; } diff --git a/llarp/lokinet_shared.cpp b/llarp/lokinet_shared.cpp index 0bbdb5287..c21b0b2a2 100644 --- a/llarp/lokinet_shared.cpp +++ b/llarp/lokinet_shared.cpp @@ -4,8 +4,16 @@ #include "llarp.hpp" #include "config/config.hpp" +#include +#include +#include + +#include + struct lokinet_context { + std::mutex m_access; + std::shared_ptr impl; std::unique_ptr runner; @@ -18,10 +26,85 @@ struct lokinet_context if (runner) runner->join(); } + + /// acquire mutex for accessing this context + [[nodiscard]] auto + acquire() + { + return std::unique_lock{m_access}; + } + + std::unordered_map streams; + + void + inbound_stream(int id) + { + streams[id] = true; + } + + void + outbound_stream(int id) + { + streams[id] = false; + } }; -struct lokinet_context g_context -{}; +namespace +{ + struct lokinet_context g_context + {}; + + lokinet_stream_result* + stream_error(int err) + { + return new lokinet_stream_result{err, {0}, 0, 0}; + } + + lokinet_stream_result* + stream_okay(std::string host, int port, int stream_id) + { + auto* result = new lokinet_stream_result{}; + std::copy_n( + host.c_str(), + std::min(host.size(), sizeof(result->local_address) - 1), + result->local_address); + result->local_port = port; + result->stream_id = stream_id; + return result; + } + + std::pair + split_host_port(std::string data, std::string proto = "tcp") + { + std::string host, portStr; + if (auto pos = data.find(":"); pos != std::string::npos) + { + host = data.substr(0, pos); + portStr = data.substr(pos + 1); + } + else + throw EINVAL; + + if (auto* serv = getservbyname(portStr.c_str(), proto.c_str())) + { + return {host, serv->s_port}; + } + else + throw(errno ? errno : EINVAL); + } + + int + accept_port(const char* remote, uint16_t port, void* ptr) + { + (void)remote; + if (port == *static_cast(ptr)) + { + return 0; + } + return -1; + } + +} // namespace extern "C" { @@ -40,24 +123,236 @@ extern "C" void lokinet_context_free(struct lokinet_context* ctx) { + lokinet_context_stop(ctx); delete ctx; } void lokinet_context_start(struct lokinet_context* ctx) { + if (not ctx) + return; + auto lock = ctx->acquire(); ctx->runner = std::make_unique([ctx]() { - auto config = std::make_shared(fs::path{""}); - ctx->impl->Configure(config); + ctx->impl->Configure(llarp::Config::EmbeddedConfig()); const llarp::RuntimeOptions opts{}; ctx->impl->Setup(opts); + ctx->impl->Run(opts); }); } void lokinet_context_stop(struct lokinet_context* ctx) { - ctx->impl->CloseAsync(); - ctx->impl->Wait(); + if (not ctx) + return; + auto lock = ctx->acquire(); + + if (not ctx->impl->IsStopping()) + { + ctx->impl->CloseAsync(); + ctx->impl->Wait(); + } + + if (ctx->runner) + ctx->runner->join(); + + ctx->runner.reset(); + } + + struct lokinet_stream_result* + lokinet_outbound_stream(const char* remote, const char* local, struct lokinet_context* ctx) + { + if (ctx == nullptr) + return stream_error(EHOSTDOWN); + + std::promise promise; + + { + auto lock = ctx->acquire(); + + if (not ctx->impl->IsUp()) + return stream_error(EHOSTDOWN); + + std::string remotehost; + int remoteport; + try + { + auto [h, p] = split_host_port(remote); + remotehost = h; + remoteport = p; + } + catch (int err) + { + return stream_error(err); + } + // TODO: make configurable (?) + std::string endpoint{"default"}; + + llarp::SockAddr localAddr; + try + { + if (local) + localAddr = llarp::SockAddr{std::string{local}}; + else + localAddr = llarp::SockAddr{"127.0.0.1:0"}; + } + catch (std::exception& ex) + { + return stream_error(EINVAL); + } + auto call = [&promise, + ctx, + router = ctx->impl->router, + remotehost, + remoteport, + endpoint, + localAddr]() { + auto ep = router->hiddenServiceContext().GetEndpointByName(endpoint); + if (ep == nullptr) + { + promise.set_value(stream_error(EHOSTUNREACH)); + return; + } + auto* quic = ep->GetQUICTunnel(); + if (quic == nullptr) + { + promise.set_value(stream_error(ENOTSUP)); + return; + } + try + { + auto [addr, id] = quic->open( + remotehost, remoteport, [](auto&&) {}, localAddr); + auto [host, port] = split_host_port(addr.toString()); + ctx->outbound_stream(id); + promise.set_value(stream_okay(host, port, id)); + } + catch (std::exception& ex) + { + promise.set_value(stream_error(ECANCELED)); + } + catch (int err) + { + promise.set_value(stream_error(err)); + } + }; + + ctx->impl->CallSafe([call]() { + // we dont want the mainloop to die in case setting the value on the promise fails + try + { + call(); + } + catch (...) + {} + }); + } + + auto future = promise.get_future(); + try + { + if (auto status = future.wait_for(std::chrono::seconds{10}); + status == std::future_status::ready) + { + return future.get(); + } + else + { + promise.set_value(stream_error(ETIMEDOUT)); + return future.get(); + } + } + catch (std::exception& ex) + { + return stream_error(EBADF); + } + } + + int + lokinet_inbound_stream(uint16_t port, struct lokinet_context* ctx) + { + /// FIXME: delete pointer later + return lokinet_inbound_stream_filter(&accept_port, (void*)new std::uintptr_t{port}, ctx); + } + + int + lokinet_inbound_stream_filter( + lokinet_stream_filter acceptFilter, void* user, struct lokinet_context* ctx) + { + if (acceptFilter == nullptr) + { + acceptFilter = [](auto, auto, auto) { return 0; }; + } + if (not ctx) + return -1; + std::promise promise; + { + auto lock = ctx->acquire(); + if (not ctx->impl->IsUp()) + { + return -1; + } + + ctx->impl->CallSafe([router = ctx->impl->router, acceptFilter, user, &promise]() { + auto ep = router->hiddenServiceContext().GetEndpointByName("default"); + auto* quic = ep->GetQUICTunnel(); + auto id = quic->listen( + [acceptFilter, user](auto remoteAddr, auto port) -> std::optional { + std::string remote{remoteAddr}; + if (auto result = acceptFilter(remote.c_str(), port, user)) + { + if (result == -1) + { + throw std::invalid_argument{"rejected"}; + } + } + else + return llarp::SockAddr{"127.0.0.1:" + std::to_string(port)}; + return std::nullopt; + }); + promise.set_value(id); + }); + } + auto ftr = promise.get_future(); + auto id = ftr.get(); + { + auto lock = ctx->acquire(); + ctx->inbound_stream(id); + } + return id; + } + + void + lokinet_close_stream(int stream_id, struct lokinet_context* ctx) + { + if (not ctx) + return; + auto lock = ctx->acquire(); + if (not ctx->impl->IsUp()) + return; + + try + { + std::promise promise; + bool inbound = ctx->streams.at(stream_id); + ctx->impl->CallSafe([stream_id, inbound, router = ctx->impl->router, &promise]() { + auto ep = router->hiddenServiceContext().GetEndpointByName("default"); + auto* quic = ep->GetQUICTunnel(); + try + { + if (inbound) + quic->forget(stream_id); + else + quic->close(stream_id); + } + catch (...) + {} + promise.set_value(); + }); + promise.get_future().get(); + } + catch (...) + {} } } diff --git a/llarp/service/endpoint.cpp b/llarp/service/endpoint.cpp index ee5db4432..5eb5ea3b2 100644 --- a/llarp/service/endpoint.cpp +++ b/llarp/service/endpoint.cpp @@ -1378,6 +1378,7 @@ namespace llarp } if (not SendToOrQueue(*maybe, pkt, t)) return false; + MarkConvoTagActive(tag); Loop()->wakeup(); return true; } @@ -1425,7 +1426,11 @@ namespace llarp msg.payload.size(), " bytes seqno=", msg.seqno); - if (not HandleInboundPacket(msg.tag, msg.payload, msg.proto, msg.seqno)) + if (HandleInboundPacket(msg.tag, msg.payload, msg.proto, msg.seqno)) + { + MarkConvoTagActive(msg.tag); + } + else { LogWarn("Failed to handle inbound message"); }