diff --git a/README.md b/README.md index 625c220..a3e2a58 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,10 @@ Putting it in a directory that is only readable by the super-user is not a bad i Domains can be filtered directly by the proxy, see the `[filtering]` section of the configuration file. +## Access control + +Access control can be enabled in the `[access_control]` section and configured with the `query_meta` configuration value of `dnscrypt-proxy`. + ## Prometheus metrics Prometheus metrics can optionally be enabled in order to monitor performance, cache efficiency, and more. diff --git a/example-encrypted-dns.toml b/example-encrypted-dns.toml index a8226e2..099ed54 100644 --- a/example-encrypted-dns.toml +++ b/example-encrypted-dns.toml @@ -221,3 +221,24 @@ allow_non_reserved_ports = false # Blacklisted upstream IP addresses blacklisted_ips = [ "93.184.216.34" ] + + + + +################################ +# Access control # +################################ + +[access_control] + +# Enabled access control + +enabled = false + +# Only allow access to client queries including one of these random tokens +# Tokens can be configured in the `query_meta` section of `dnscrypt-proxy` as +# `query_meta = ["token:..."]` -- Replace ... with the token to use by the client. +# Example: `query_meta = ["token:Y2oHkDJNHz"]` + +tokens = ["Y2oHkDJNHz", "G5zY3J5cHQtY", "C5zZWN1cmUuZG5z"] + diff --git a/src/config.rs b/src/config.rs index a59249a..8a1e3e6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,6 +9,12 @@ use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; use tokio::prelude::*; +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct AccessControlConfig { + pub enabled: bool, + pub tokens: Vec, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct AnonymizedDNSConfig { pub enabled: bool, @@ -79,6 +85,7 @@ pub struct Config { #[cfg(feature = "metrics")] pub metrics: Option, pub anonymized_dns: Option, + pub access_control: Option, } impl Config { diff --git a/src/dns.rs b/src/dns.rs index 2b92cd4..3c41721 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -104,6 +104,12 @@ fn arcount_inc(packet: &mut [u8]) -> Result<(), Error> { Ok(()) } +#[inline] +fn arcount_clear(packet: &mut [u8]) -> Result<(), Error> { + BigEndian::write_u16(&mut packet[10..], 0); + Ok(()) +} + #[inline] pub fn an_ns_ar_count_clear(packet: &mut [u8]) { packet[6..12].iter_mut().for_each(|x| *x = 0); @@ -312,13 +318,13 @@ fn traverse_rrs Result<(), Error>>( for _ in 0..rrcount { offset = skip_name(packet, offset)?; ensure!(packet_len - offset >= 10, "Short packet"); - cb(offset)?; let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize; - offset += 10; ensure!( - packet_len - offset >= rdlen, + packet_len - offset >= 10 + rdlen, "Record length would exceed packet length" ); + cb(offset)?; + offset += 10; offset += rdlen; } Ok(offset) @@ -334,13 +340,13 @@ fn traverse_rrs_mut Result<(), Error>>( for _ in 0..rrcount { offset = skip_name(packet, offset)?; ensure!(packet_len - offset >= 10, "Short packet"); - cb(packet, offset)?; let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize; - offset += 10; ensure!( - packet_len - offset >= rdlen, + packet_len - offset >= 10 + rdlen, "Record length would exceed packet length" ); + cb(packet, offset)?; + offset += 10; offset += rdlen; } Ok(offset) @@ -502,6 +508,73 @@ pub fn qtype_qclass(packet: &[u8]) -> Result<(u16, u16), Error> { Ok((qtype, qclass)) } +fn parse_txt_rrdata Result<(), Error>>( + rrdata: &[u8], + mut cb: F, +) -> Result<(), Error> { + let rrdata_len = rrdata.len(); + let mut offset = 0; + while offset < rrdata_len { + let part_len = rrdata[offset] as usize; + if part_len == 0 { + break; + } + ensure!(rrdata_len - offset > part_len, "Short TXT RR data"); + offset += 1; + let part_bin = &rrdata[offset..offset + part_len]; + let part = std::str::from_utf8(part_bin)?; + cb(part)?; + offset += part_len; + } + Ok(()) +} + +pub fn query_meta(packet: &mut Vec) -> Result, Error> { + let packet_len = packet.len(); + ensure!(packet_len > DNS_OFFSET_QUESTION, "Short packet"); + ensure!(packet_len <= DNS_MAX_PACKET_SIZE, "Large packet"); + ensure!(qdcount(packet) == 1, "No question"); + let mut offset = skip_name(packet, DNS_OFFSET_QUESTION)?; + assert!(offset > DNS_OFFSET_QUESTION); + ensure!(packet_len - offset >= 4, "Short packet"); + offset += 4; + let (ancount, nscount, arcount) = (ancount(packet), nscount(packet), arcount(packet)); + offset = traverse_rrs( + packet, + offset, + ancount as usize + nscount as usize, + |_offset| Ok(()), + )?; + let mut token = None; + traverse_rrs(packet, offset, arcount as _, |mut offset| { + let qtype = BigEndian::read_u16(&packet[offset..]); + let qclass = BigEndian::read_u16(&packet[offset + 2..]); + if qtype != DNS_TYPE_TXT || qclass != DNS_CLASS_INET { + return Ok(()); + } + let len = BigEndian::read_u16(&packet[offset + 8..]) as usize; + offset += 10; + ensure!(packet_len - offset >= len, "Short packet"); + let rrdata = &packet[offset..offset + len]; + parse_txt_rrdata(rrdata, |txt| { + if txt.len() < 7 || !txt.starts_with("token:") { + return Ok(()); + } + ensure!(token.is_none(), "Duplicate token"); + let found_token = &txt[6..]; + let found_token = found_token.to_owned(); + token = Some(found_token); + Ok(()) + })?; + Ok(()) + })?; + if token.is_some() { + arcount_clear(packet)?; + packet.truncate(offset); + } + Ok(token) +} + pub fn serve_nxdomain_response(client_packet: Vec) -> Result, Error> { ensure!(client_packet.len() >= DNS_HEADER_SIZE, "Short packet"); ensure!(qdcount(&client_packet) == 1, "No question"); diff --git a/src/globals.rs b/src/globals.rs index 21e36ce..beb25ce 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -48,6 +48,7 @@ pub struct Globals { pub anonymized_dns_allowed_ports: Vec, pub anonymized_dns_allow_non_reserved_ports: bool, pub anonymized_dns_blacklisted_ips: Vec, + pub access_control_tokens: Option>, #[cfg(feature = "metrics")] #[derivative(Debug = "ignore")] pub varz: Varz, diff --git a/src/main.rs b/src/main.rs index b9c2b8d..d341a7b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -218,6 +218,12 @@ async fn handle_client_query( !dns::is_response(&packet), "Question expected, but got a response instead" ); + if let Some(tokens) = &globals.access_control_tokens { + match query_meta(&mut packet)? { + None => bail!("No access token"), + Some(token) => ensure!(tokens.contains(&token), "Access token not found"), + } + } let response = resolver::get_cached_response_or_resolve(&globals, &mut packet).await?; encrypt_and_respond_to_query( globals, @@ -652,6 +658,13 @@ fn main() -> Result<(), Error> { anonymized_dns.blacklisted_ips, ), }; + let access_control_tokens = match config.access_control { + None => None, + Some(access_control) => match access_control.enabled { + false => None, + true => Some(access_control.tokens), + }, + }; let runtime_handle = runtime.handle(); let globals = Arc::new(Globals { runtime_handle: runtime_handle.clone(), @@ -689,6 +702,7 @@ fn main() -> Result<(), Error> { anonymized_dns_allowed_ports, anonymized_dns_allow_non_reserved_ports, anonymized_dns_blacklisted_ips, + access_control_tokens, #[cfg(feature = "metrics")] varz: Varz::default(), });