diff --git a/src/globals.rs b/src/globals.rs index cd26a49..f62e76e 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -2,6 +2,7 @@ use crate::crypto::*; use crate::dnscrypt_certs::*; use std::net::SocketAddr; +use std::sync::atomic::AtomicU32; use std::sync::Arc; use std::time::Duration; use tokio::runtime::Runtime; @@ -17,4 +18,6 @@ pub struct Globals { pub upstream_addr: SocketAddr, pub udp_timeout: Duration, pub tcp_timeout: Duration, + pub udp_concurrent_connections: Arc, + pub tcp_concurrent_connections: Arc, } diff --git a/src/main.rs b/src/main.rs index 2c6f103..b167d70 100644 --- a/src/main.rs +++ b/src/main.rs @@ -39,6 +39,7 @@ use std::convert::TryFrom; use std::mem; use std::net::SocketAddr; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::time::Duration; use tokio::net::{TcpListener, TcpStream, UdpSocket}; @@ -153,31 +154,33 @@ async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Resul let runtime = globals.runtime.clone(); let mut tcp_listener = tcp_listener.incoming(); let timeout = globals.tcp_timeout; + let concurrent_connections = globals.tcp_concurrent_connections.clone(); while let Some(client) = tcp_listener.next().await { let mut client_connection: TcpStream = match client { Ok(client_connection) => client_connection, Err(e) => bail!(e), }; + concurrent_connections.fetch_add(1, Ordering::Relaxed); client_connection.set_nodelay(true)?; let globals = globals.clone(); - runtime.spawn( - async { - let mut binlen = [0u8, 0]; - client_connection.read_exact(&mut binlen).await?; - let packet_len = BigEndian::read_u16(&binlen) as usize; - ensure!( - (DNSCRYPT_QUERY_MIN_SIZE..=DNSCRYPT_QUERY_MAX_SIZE).contains(&packet_len), - "Unexpected query size" - ); - let mut packet = vec![0u8; packet_len]; - client_connection.read_exact(&mut packet).await?; - let client_ctx = ClientCtx::Tcp(TcpClientCtx { client_connection }); - let _ = handle_client_query(globals, client_ctx, packet).await; - Ok(()) - } - .timeout(timeout) - .map(|_| ()), - ); + let concurrent_connections = concurrent_connections.clone(); + let fut = async { + let mut binlen = [0u8, 0]; + client_connection.read_exact(&mut binlen).await?; + let packet_len = BigEndian::read_u16(&binlen) as usize; + ensure!( + (DNSCRYPT_QUERY_MIN_SIZE..=DNSCRYPT_QUERY_MAX_SIZE).contains(&packet_len), + "Unexpected query size" + ); + let mut packet = vec![0u8; packet_len]; + client_connection.read_exact(&mut packet).await?; + let client_ctx = ClientCtx::Tcp(TcpClientCtx { client_connection }); + let _ = handle_client_query(globals, client_ctx, packet).await; + Ok(()) + }; + runtime.spawn(fut.timeout(timeout).map(move |_| { + concurrent_connections.fetch_sub(1, Ordering::Relaxed); + })); } Ok(()) } @@ -189,6 +192,7 @@ async fn udp_acceptor( let runtime = globals.runtime.clone(); let mut tokio_udp_socket = UdpSocket::try_from(net_udp_socket.try_clone()?)?; let timeout = globals.udp_timeout; + let concurrent_connections = globals.udp_concurrent_connections.clone(); loop { let mut packet = vec![0u8; DNSCRYPT_QUERY_MAX_SIZE]; let (packet_len, client_addr) = tokio_udp_socket.recv_from(&mut packet).await?; @@ -198,12 +202,13 @@ async fn udp_acceptor( net_udp_socket, client_addr, }); + concurrent_connections.fetch_add(1, Ordering::Relaxed); let globals = globals.clone(); - runtime.spawn( - async { handle_client_query(globals, client_ctx, packet).await } - .timeout(timeout) - .map(|_| ()), - ); + let concurrent_connections = concurrent_connections.clone(); + let fut = handle_client_query(globals, client_ctx, packet); + runtime.spawn(fut.timeout(timeout).map(move |_| { + concurrent_connections.fetch_sub(1, Ordering::Relaxed); + })); } } @@ -308,6 +313,8 @@ fn main() -> Result<(), Error> { external_addr, tcp_timeout, udp_timeout, + udp_concurrent_connections: Arc::new(AtomicU32::new(0)), + tcp_concurrent_connections: Arc::new(AtomicU32::new(0)), }); runtime.spawn(start(globals, runtime.clone()).map(|_| ())); runtime.block_on(future::pending::<()>());