Implement access control

pull/34/head
Frank Denis 4 years ago
parent 691129eec2
commit d5b06a6653

@ -109,6 +109,10 @@ Putting it in a directory that is only readable by the super-user is not a bad i
Domains can be filtered directly by the proxy, see the `[filtering]` section of the configuration file. Domains can be filtered directly by the proxy, see the `[filtering]` section of the configuration file.
## Access control
Access control can be enabled in the `[access_control]` section and configured with the `query_meta` configuration value of `dnscrypt-proxy`.
## Prometheus metrics ## Prometheus metrics
Prometheus metrics can optionally be enabled in order to monitor performance, cache efficiency, and more. Prometheus metrics can optionally be enabled in order to monitor performance, cache efficiency, and more.

@ -221,3 +221,24 @@ allow_non_reserved_ports = false
# Blacklisted upstream IP addresses # Blacklisted upstream IP addresses
blacklisted_ips = [ "93.184.216.34" ] blacklisted_ips = [ "93.184.216.34" ]
################################
# Access control #
################################
[access_control]
# Enabled access control
enabled = false
# Only allow access to client queries including one of these random tokens
# Tokens can be configured in the `query_meta` section of `dnscrypt-proxy` as
# `query_meta = ["token:..."]` -- Replace ... with the token to use by the client.
# Example: `query_meta = ["token:Y2oHkDJNHz"]`
tokens = ["Y2oHkDJNHz", "G5zY3J5cHQtY", "C5zZWN1cmUuZG5z"]

@ -9,6 +9,12 @@ use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use tokio::prelude::*; use tokio::prelude::*;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AccessControlConfig {
pub enabled: bool,
pub tokens: Vec<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AnonymizedDNSConfig { pub struct AnonymizedDNSConfig {
pub enabled: bool, pub enabled: bool,
@ -79,6 +85,7 @@ pub struct Config {
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
pub metrics: Option<MetricsConfig>, pub metrics: Option<MetricsConfig>,
pub anonymized_dns: Option<AnonymizedDNSConfig>, pub anonymized_dns: Option<AnonymizedDNSConfig>,
pub access_control: Option<AccessControlConfig>,
} }
impl Config { impl Config {

@ -104,6 +104,12 @@ fn arcount_inc(packet: &mut [u8]) -> Result<(), Error> {
Ok(()) Ok(())
} }
#[inline]
fn arcount_clear(packet: &mut [u8]) -> Result<(), Error> {
BigEndian::write_u16(&mut packet[10..], 0);
Ok(())
}
#[inline] #[inline]
pub fn an_ns_ar_count_clear(packet: &mut [u8]) { pub fn an_ns_ar_count_clear(packet: &mut [u8]) {
packet[6..12].iter_mut().for_each(|x| *x = 0); packet[6..12].iter_mut().for_each(|x| *x = 0);
@ -312,13 +318,13 @@ fn traverse_rrs<F: FnMut(usize) -> Result<(), Error>>(
for _ in 0..rrcount { for _ in 0..rrcount {
offset = skip_name(packet, offset)?; offset = skip_name(packet, offset)?;
ensure!(packet_len - offset >= 10, "Short packet"); ensure!(packet_len - offset >= 10, "Short packet");
cb(offset)?;
let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize; let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize;
offset += 10;
ensure!( ensure!(
packet_len - offset >= rdlen, packet_len - offset >= 10 + rdlen,
"Record length would exceed packet length" "Record length would exceed packet length"
); );
cb(offset)?;
offset += 10;
offset += rdlen; offset += rdlen;
} }
Ok(offset) Ok(offset)
@ -334,13 +340,13 @@ fn traverse_rrs_mut<F: FnMut(&mut [u8], usize) -> Result<(), Error>>(
for _ in 0..rrcount { for _ in 0..rrcount {
offset = skip_name(packet, offset)?; offset = skip_name(packet, offset)?;
ensure!(packet_len - offset >= 10, "Short packet"); ensure!(packet_len - offset >= 10, "Short packet");
cb(packet, offset)?;
let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize; let rdlen = BigEndian::read_u16(&packet[offset + 8..]) as usize;
offset += 10;
ensure!( ensure!(
packet_len - offset >= rdlen, packet_len - offset >= 10 + rdlen,
"Record length would exceed packet length" "Record length would exceed packet length"
); );
cb(packet, offset)?;
offset += 10;
offset += rdlen; offset += rdlen;
} }
Ok(offset) Ok(offset)
@ -502,6 +508,73 @@ pub fn qtype_qclass(packet: &[u8]) -> Result<(u16, u16), Error> {
Ok((qtype, qclass)) Ok((qtype, qclass))
} }
fn parse_txt_rrdata<F: FnMut(&str) -> Result<(), Error>>(
rrdata: &[u8],
mut cb: F,
) -> Result<(), Error> {
let rrdata_len = rrdata.len();
let mut offset = 0;
while offset < rrdata_len {
let part_len = rrdata[offset] as usize;
if part_len == 0 {
break;
}
ensure!(rrdata_len - offset > part_len, "Short TXT RR data");
offset += 1;
let part_bin = &rrdata[offset..offset + part_len];
let part = std::str::from_utf8(part_bin)?;
cb(part)?;
offset += part_len;
}
Ok(())
}
pub fn query_meta(packet: &mut Vec<u8>) -> Result<Option<String>, Error> {
let packet_len = packet.len();
ensure!(packet_len > DNS_OFFSET_QUESTION, "Short packet");
ensure!(packet_len <= DNS_MAX_PACKET_SIZE, "Large packet");
ensure!(qdcount(packet) == 1, "No question");
let mut offset = skip_name(packet, DNS_OFFSET_QUESTION)?;
assert!(offset > DNS_OFFSET_QUESTION);
ensure!(packet_len - offset >= 4, "Short packet");
offset += 4;
let (ancount, nscount, arcount) = (ancount(packet), nscount(packet), arcount(packet));
offset = traverse_rrs(
packet,
offset,
ancount as usize + nscount as usize,
|_offset| Ok(()),
)?;
let mut token = None;
traverse_rrs(packet, offset, arcount as _, |mut offset| {
let qtype = BigEndian::read_u16(&packet[offset..]);
let qclass = BigEndian::read_u16(&packet[offset + 2..]);
if qtype != DNS_TYPE_TXT || qclass != DNS_CLASS_INET {
return Ok(());
}
let len = BigEndian::read_u16(&packet[offset + 8..]) as usize;
offset += 10;
ensure!(packet_len - offset >= len, "Short packet");
let rrdata = &packet[offset..offset + len];
parse_txt_rrdata(rrdata, |txt| {
if txt.len() < 7 || !txt.starts_with("token:") {
return Ok(());
}
ensure!(token.is_none(), "Duplicate token");
let found_token = &txt[6..];
let found_token = found_token.to_owned();
token = Some(found_token);
Ok(())
})?;
Ok(())
})?;
if token.is_some() {
arcount_clear(packet)?;
packet.truncate(offset);
}
Ok(token)
}
pub fn serve_nxdomain_response(client_packet: Vec<u8>) -> Result<Vec<u8>, Error> { pub fn serve_nxdomain_response(client_packet: Vec<u8>) -> Result<Vec<u8>, Error> {
ensure!(client_packet.len() >= DNS_HEADER_SIZE, "Short packet"); ensure!(client_packet.len() >= DNS_HEADER_SIZE, "Short packet");
ensure!(qdcount(&client_packet) == 1, "No question"); ensure!(qdcount(&client_packet) == 1, "No question");

@ -48,6 +48,7 @@ pub struct Globals {
pub anonymized_dns_allowed_ports: Vec<u16>, pub anonymized_dns_allowed_ports: Vec<u16>,
pub anonymized_dns_allow_non_reserved_ports: bool, pub anonymized_dns_allow_non_reserved_ports: bool,
pub anonymized_dns_blacklisted_ips: Vec<IpAddr>, pub anonymized_dns_blacklisted_ips: Vec<IpAddr>,
pub access_control_tokens: Option<Vec<String>>,
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
#[derivative(Debug = "ignore")] #[derivative(Debug = "ignore")]
pub varz: Varz, pub varz: Varz,

@ -218,6 +218,12 @@ async fn handle_client_query(
!dns::is_response(&packet), !dns::is_response(&packet),
"Question expected, but got a response instead" "Question expected, but got a response instead"
); );
if let Some(tokens) = &globals.access_control_tokens {
match query_meta(&mut packet)? {
None => bail!("No access token"),
Some(token) => ensure!(tokens.contains(&token), "Access token not found"),
}
}
let response = resolver::get_cached_response_or_resolve(&globals, &mut packet).await?; let response = resolver::get_cached_response_or_resolve(&globals, &mut packet).await?;
encrypt_and_respond_to_query( encrypt_and_respond_to_query(
globals, globals,
@ -652,6 +658,13 @@ fn main() -> Result<(), Error> {
anonymized_dns.blacklisted_ips, anonymized_dns.blacklisted_ips,
), ),
}; };
let access_control_tokens = match config.access_control {
None => None,
Some(access_control) => match access_control.enabled {
false => None,
true => Some(access_control.tokens),
},
};
let runtime_handle = runtime.handle(); let runtime_handle = runtime.handle();
let globals = Arc::new(Globals { let globals = Arc::new(Globals {
runtime_handle: runtime_handle.clone(), runtime_handle: runtime_handle.clone(),
@ -689,6 +702,7 @@ fn main() -> Result<(), Error> {
anonymized_dns_allowed_ports, anonymized_dns_allowed_ports,
anonymized_dns_allow_non_reserved_ports, anonymized_dns_allow_non_reserved_ports,
anonymized_dns_blacklisted_ips, anonymized_dns_blacklisted_ips,
access_control_tokens,
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
varz: Varz::default(), varz: Varz::default(),
}); });

Loading…
Cancel
Save