diff --git a/llarp/dns/message.cpp b/llarp/dns/message.cpp index f1f628fd7..577debb4c 100644 --- a/llarp/dns/message.cpp +++ b/llarp/dns/message.cpp @@ -42,7 +42,9 @@ namespace llarp return false; if(!buf->read_uint16(ns_count)) return false; - return buf->read_uint16(ar_count); + if(!buf->read_uint16(ar_count)) + return false; + return true; } Message::Message(Message&& other) @@ -66,13 +68,12 @@ namespace llarp } Message::Message(const MessageHeader& hdr) - : hdr_id(hdr.id) - , hdr_fields(hdr.fields) - , questions(size_t(hdr.qd_count)) - , answers(size_t(hdr.an_count)) - , authorities(size_t(hdr.ns_count)) - , additional(size_t(hdr.ar_count)) + : hdr_id(hdr.id), hdr_fields(hdr.fields) { + questions.resize(size_t(hdr.qd_count)); + answers.resize(size_t(hdr.an_count)); + authorities.resize(size_t(hdr.ns_count)); + additional.resize(size_t(hdr.ar_count)); } bool @@ -120,36 +121,32 @@ namespace llarp } llarp::LogDebug(qd); } - for(auto& an : answers) { - if(!an.Decode(buf)) + if(not an.Decode(buf)) { llarp::LogError("failed to decode answer"); return false; } - llarp::LogDebug(an); } - - for(auto& ns : authorities) + /* + for(auto& auth : authorities) { - if(!ns.Decode(buf)) + if(!auth.Decode(buf)) { - llarp::LogError("failed to decode authority"); + llarp::LogError("failed to decode auth"); return false; } - llarp::LogDebug(ns); } - - for(auto& ar : additional) + for(auto& rr : additional) { - if(!ar.Decode(buf)) + if(!rr.Decode(buf)) { - llarp::LogError("failed to decode additonal"); + llarp::LogError("failed to decode additional"); return false; } - llarp::LogDebug(ar); } + */ return true; } @@ -171,9 +168,8 @@ namespace llarp if(questions.size()) { hdr_fields |= flags_QR | flags_AA | flags_RA; - const auto& question = questions[0]; ResourceRecord rec; - rec.rr_name = question.qname; + rec.rr_name = questions[0].qname; rec.rr_class = qClassIN; rec.ttl = ttl; if(isV6) diff --git a/llarp/dns/name.cpp b/llarp/dns/name.cpp index ad51ad6c3..3ae7ed8e3 100644 --- a/llarp/dns/name.cpp +++ b/llarp/dns/name.cpp @@ -23,12 +23,6 @@ namespace llarp buf->cur++; if(l) { - if(l > 63) - { - llarp::LogError("decode name failed, field too big: ", l, " > 63"); - llarp::DumpBuffer(*buf); - return false; - } if(buf->size_left() < l) return false; diff --git a/llarp/dns/rr.cpp b/llarp/dns/rr.cpp index 94795dc6b..97511f4cc 100644 --- a/llarp/dns/rr.cpp +++ b/llarp/dns/rr.cpp @@ -1,5 +1,6 @@ #include - +#include +#include #include #include @@ -28,10 +29,8 @@ namespace llarp bool ResourceRecord::Encode(llarp_buffer_t* buf) const { - if(!EncodeName(buf, rr_name)) - { + if(not EncodeName(buf, rr_name)) return false; - } if(!buf->put_uint16(rr_type)) { return false; @@ -54,11 +53,9 @@ namespace llarp bool ResourceRecord::Decode(llarp_buffer_t* buf) { - if(!DecodeName(buf, rr_name)) - { - llarp::LogError("failed to decode rr name"); + uint16_t discard; + if(!buf->read_uint16(discard)) return false; - } if(!buf->read_uint16(rr_type)) { llarp::LogError("failed to decode rr type"); @@ -76,7 +73,7 @@ namespace llarp } if(!DecodeRData(buf, rData)) { - llarp::LogError("failed to decode rr rdata"); + llarp::LogError("failed to decode rr rdata ", *this); return false; } return true; @@ -86,7 +83,7 @@ namespace llarp ResourceRecord::print(std::ostream& stream, int level, int spaces) const { Printer printer(stream, level, spaces); - printer.printAttribute("RR name", rr_name); + printer.printAttribute("name", rr_name); printer.printAttribute("type", rr_type); printer.printAttribute("class", rr_class); printer.printAttribute("ttl", ttl); @@ -94,5 +91,19 @@ namespace llarp return stream; } + + bool + ResourceRecord::HasCNameForTLD(const std::string& tld) const + { + if(rr_type != qTypeCNAME) + return false; + Name_t name; + llarp_buffer_t buf(rData); + if(not DecodeName(&buf, name)) + return false; + return name.find(tld) != std::string::npos + && name.rfind(tld) == (name.size() - tld.size()) - 1; + } + } // namespace dns } // namespace llarp diff --git a/llarp/dns/rr.hpp b/llarp/dns/rr.hpp index 6bfda15d5..71a32ece4 100644 --- a/llarp/dns/rr.hpp +++ b/llarp/dns/rr.hpp @@ -32,6 +32,9 @@ namespace llarp std::ostream& print(std::ostream& stream, int level, int spaces) const; + bool + HasCNameForTLD(const std::string& tld) const; + Name_t rr_name; RRType_t rr_type; RRClass_t rr_class; diff --git a/llarp/dns/server.cpp b/llarp/dns/server.cpp index 75373282b..f37768554 100644 --- a/llarp/dns/server.cpp +++ b/llarp/dns/server.cpp @@ -1,5 +1,5 @@ #include - +#include #include #include #include @@ -137,22 +137,35 @@ namespace llarp void Proxy::HandlePktClient(llarp::Addr from, Buffer_t buf) { + llarp_buffer_t pkt(buf); MessageHeader hdr; + if(!hdr.Decode(&pkt)) { - llarp_buffer_t pkt(buf); - if(!hdr.Decode(&pkt)) - { - llarp::LogWarn("failed to parse dns header from ", from); - return; - } + llarp::LogWarn("failed to parse dns header from ", from); + return; } TX tx = {hdr.id, from}; auto itr = m_Forwarded.find(tx); if(itr == m_Forwarded.end()) return; - const Addr requester = itr->second; auto self = shared_from_this(); + Message msg(hdr); + if(msg.Decode(&pkt)) + { + if(m_QueryHandler && m_QueryHandler->ShouldHookDNSMessage(msg)) + { + msg.hdr_id = itr->first.txid; + if(!m_QueryHandler->HandleHookedDNSMessage( + std::move(msg), + std::bind(&Proxy::SendServerMessageTo, self, requester, + std::placeholders::_1))) + { + llarp::LogWarn("failed to handle hooked dns"); + } + return; + } + } LogicCall(m_ServerLogic, [=]() { // forward reply to requester via server const llarp_buffer_t tmpbuf(buf); diff --git a/llarp/handlers/tun.cpp b/llarp/handlers/tun.cpp index db9f3a430..ee5be943d 100644 --- a/llarp/handlers/tun.cpp +++ b/llarp/handlers/tun.cpp @@ -397,12 +397,54 @@ namespace llarp { // llarp::LogInfo("Tun.HandleHookedDNSMessage ", msg.questions[0].qname, " // of type", msg.questions[0].qtype); + std::string qname; + if(msg.answers.size() > 0) + { + const auto &answer = msg.answers[0]; + if(answer.HasCNameForTLD(".loki")) + { + dns::Name_t qname; + llarp_buffer_t buf(answer.rData); + if(not dns::DecodeName(&buf, qname, true)) + return false; + service::Address addr; + if(not addr.FromString(qname)) + { + LogError("bad name ", qname); + return false; + } + msg.authorities.resize(0); + msg.additional.resize(0); + msg.answers.resize(0); + msg.hdr_fields &= ~dns::flags_RCODENameError; + if(HasAddress(addr)) + { + huint128_t ip = ObtainIPForAddr(addr, false); + msg.AddINReply(ip, false); + reply(msg); + return true; + } + else + { + auto replyMsg = std::make_shared< dns::Message >(std::move(msg)); + using service::Address; + using service::OutboundContext; + return EnsurePathToService( + addr, + [=](const Address &, OutboundContext *ctx) { + SendDNSReply(addr, ctx, replyMsg, reply, false, false); + }, + 2000); + } + } + } if(msg.questions.size() != 1) { llarp::LogWarn("bad number of dns questions: ", msg.questions.size()); return false; } - const std::string qname = msg.questions[0].Name(); + qname = msg.questions[0].Name(); + if(msg.questions[0].qtype == dns::qTypeMX) { // mx record @@ -595,6 +637,13 @@ namespace llarp return m_OurRange.Contains(ip); } } + for(const auto &answer : msg.answers) + { + if(answer.HasCNameForTLD(".loki")) + return true; + if(answer.HasCNameForTLD(".snode")) + return true; + } return false; }