diff --git a/CHANGELOG.md b/CHANGELOG.md index 05d8307..2b0a81e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 file patterns, and execute in parallel via the `ignore` crate and `num_cpus` crate to calculate thread count +### Fixed + +- Resolution of `BindAddress` now properly handles hostnames ranging from + `localhost` to `example.com` +- Parsing of `BindAddress` no longer causes a stack overflow + ## [0.19.0] - 2022-08-30 ### Added diff --git a/src/cli/commands/server.rs b/src/cli/commands/server.rs index d2ac050..420d5f9 100644 --- a/src/cli/commands/server.rs +++ b/src/cli/commands/server.rs @@ -141,8 +141,8 @@ impl ServerSubcommand { } let host = get!(host).unwrap_or(BindAddress::Any); - trace!("Starting server using unresolved host '{}'", host); - let addr = host.resolve(get!(@flag use_ipv6))?; + trace!("Starting server using unresolved host '{host}'"); + let addr = host.resolve(get!(@flag use_ipv6)).await?; // If specified, change the current working directory of this program if let Some(path) = get!(current_dir) { diff --git a/src/config/server/listen.rs b/src/config/server/listen.rs index 606883c..10bed5f 100644 --- a/src/config/server/listen.rs +++ b/src/config/server/listen.rs @@ -1,14 +1,13 @@ use anyhow::Context; use clap::Args; -use derive_more::Display; use distant_core::{ net::{PortRange, Shutdown}, - Map, + Host, HostParseError, Map, }; use serde::{Deserialize, Serialize}; use std::{ - env, - net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr}, + env, fmt, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::PathBuf, str::FromStr, }; @@ -107,17 +106,37 @@ impl From for Map { } /// Represents options for binding a server to an IP address -#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum BindAddress { - #[display = "ssh"] + /// Should read address from `SSH_CONNECTION` environment variable, which contains four + /// space-separated values: + /// + /// * client IP address + /// * client port number + /// * server IP address + /// * server port number Ssh, - #[display = "any"] + + /// Should bind to `0.0.0.0` or `::` depending on ipv6 flag Any, - Ip(IpAddr), + + /// Should bind to the specified host, which could be `example.com`, `localhost`, or an IP + /// address like `203.0.113.1` or `2001:DB8::1` + Host(Host), +} + +impl fmt::Display for BindAddress { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Ssh => write!(f, "ssh"), + Self::Any => write!(f, "any"), + Self::Host(host) => write!(f, "{host}"), + } + } } impl FromStr for BindAddress { - type Err = AddrParseError; + type Err = HostParseError; fn from_str(s: &str) -> Result { let s = s.trim(); @@ -126,7 +145,7 @@ impl FromStr for BindAddress { } else if s.eq_ignore_ascii_case("any") { Self::Any } else { - s.parse()? + Self::Host(s.parse::()?) }) } } @@ -134,7 +153,7 @@ impl FromStr for BindAddress { impl BindAddress { /// Resolves address into valid IP; in the case of "any", will leverage the /// `use_ipv6` flag to determine if binding should use ipv4 or ipv6 - pub fn resolve(self, use_ipv6: bool) -> anyhow::Result { + pub async fn resolve(self, use_ipv6: bool) -> anyhow::Result { match self { Self::Ssh => { let ssh_connection = @@ -149,7 +168,160 @@ impl BindAddress { } Self::Any if use_ipv6 => Ok(IpAddr::V6(Ipv6Addr::UNSPECIFIED)), Self::Any => Ok(IpAddr::V4(Ipv4Addr::UNSPECIFIED)), - Self::Ip(addr) => Ok(addr), + Self::Host(host) => match host { + Host::Ipv4(x) => Ok(IpAddr::V4(x)), + Host::Ipv6(x) => Ok(IpAddr::V6(x)), + + // Attempt to resolve the hostname, tacking on a :80 as a port if no colon is found + // in the hostname as we MUST have a socket address like example.com:80 in order to + // perform dns resolution + Host::Name(x) => Ok(tokio::net::lookup_host(if x.contains(':') { + x.to_string() + } else { + format!("{x}:80") + }) + .await? + .map(|addr| addr.ip()) + .find(|ip| (use_ipv6 && ip.is_ipv6()) || (!use_ipv6 && ip.is_ipv4())) + .ok_or_else(|| { + anyhow::anyhow!( + "Unable to resolve {x} to {} address", + if use_ipv6 { "ipv6" } else { "ipv4" } + ) + })?), + }, } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn to_string_should_properly_print_bind_address() { + assert_eq!(BindAddress::Any.to_string(), "any"); + assert_eq!(BindAddress::Ssh.to_string(), "ssh"); + assert_eq!( + BindAddress::Host(Host::Ipv4(Ipv4Addr::new(203, 0, 113, 1))).to_string(), + "203.0.113.1" + ); + assert_eq!( + BindAddress::Host(Host::Ipv6(Ipv6Addr::new( + 0x2001, 0x0DB8, 0, 0, 0, 0, 0, 0x0001 + ))) + .to_string(), + "2001:db8::1" + ); + assert_eq!( + BindAddress::Host(Host::Name(String::from("example.com"))).to_string(), + "example.com" + ); + } + + #[test] + fn parse_should_correctly_parse_host_or_special_cases() { + assert_eq!("any".parse::().unwrap(), BindAddress::Any); + assert_eq!("ssh".parse::().unwrap(), BindAddress::Ssh); + assert_eq!( + "203.0.113.1".parse::().unwrap(), + BindAddress::Host(Host::Ipv4(Ipv4Addr::new(203, 0, 113, 1))) + ); + assert_eq!( + "2001:DB8::1".parse::().unwrap(), + BindAddress::Host(Host::Ipv6(Ipv6Addr::new( + 0x2001, 0x0DB8, 0, 0, 0, 0, 0, 0x0001 + ))) + ); + assert_eq!( + "example.com".parse::().unwrap(), + BindAddress::Host(Host::Name(String::from("example.com"))) + ); + assert_eq!( + "localhost".parse::().unwrap(), + BindAddress::Host(Host::Name(String::from("localhost"))) + ); + } + + #[tokio::test] + async fn resolve_should_properly_resolve_bind_address() { + // For ssh, we check SSH_CONNECTION, and there are three situations where this can fail: + // + // 1. The environment variable does not exist + // 2. The environment variable does not have at least 3 (out of 4) space-separated args + // 3. The environment variable has an invalid IP address as the 3 arg + BindAddress::Ssh.resolve(false).await.unwrap_err(); + + env::set_var("SSH_CONNECTION", "127.0.0.1 1234"); + BindAddress::Ssh.resolve(false).await.unwrap_err(); + + env::set_var("SSH_CONNECTION", "127.0.0.1 1234 -notaddress 1234"); + BindAddress::Ssh.resolve(false).await.unwrap_err(); + + env::set_var("SSH_CONNECTION", "127.0.0.1 1234 127.0.0.1 1234"); + assert_eq!( + BindAddress::Ssh.resolve(false).await.unwrap(), + Ipv4Addr::new(127, 0, 0, 1) + ); + + // Any will resolve to unspecified ipv4 if the ipv6 flag is false + assert_eq!( + BindAddress::Any.resolve(false).await.unwrap(), + Ipv4Addr::UNSPECIFIED + ); + + // Any will resolve to unspecified ipv6 if the ipv6 flag is true + assert_eq!( + BindAddress::Any.resolve(true).await.unwrap(), + Ipv6Addr::UNSPECIFIED + ); + + // Host with ipv4 address should return that address + assert_eq!( + BindAddress::Host(Host::Ipv4(Ipv4Addr::UNSPECIFIED)) + .resolve(false) + .await + .unwrap(), + Ipv4Addr::UNSPECIFIED + ); + + // Host with ipv6 address should return that address + assert_eq!( + BindAddress::Host(Host::Ipv6(Ipv6Addr::UNSPECIFIED)) + .resolve(false) + .await + .unwrap(), + Ipv6Addr::UNSPECIFIED + ); + + // Host with name should attempt to resolve the name + assert_eq!( + BindAddress::Host(Host::Name(String::from("example.com"))) + .resolve(false) + .await + .unwrap(), + tokio::net::lookup_host("example.com:80") + .await + .unwrap() + .next() + .unwrap() + .ip(), + ); + + // Should support resolving localhost using ipv4/ipv6 + assert_eq!( + BindAddress::Host(Host::Name(String::from("localhost"))) + .resolve(false) + .await + .unwrap(), + Ipv4Addr::LOCALHOST + ); + assert_eq!( + BindAddress::Host(Host::Name(String::from("localhost"))) + .resolve(true) + .await + .unwrap(), + Ipv6Addr::LOCALHOST + ); + } +}