From 24a8cf84019a76d1289e0b3caa7c466f9afeac2e Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Wed, 28 Jul 2021 01:32:20 -0500 Subject: [PATCH] Refactored listener code into a handler module, wrote support to split transport into read and write halves, implemented most of backend although process run is not working yet --- src/data.rs | 2 +- src/net/mod.rs | 2 +- src/net/transport/codec.rs | 1 + src/net/transport/mod.rs | 75 ++++++- src/opt.rs | 30 ++- src/subcommand/execute.rs | 69 ++++++- src/subcommand/listen.rs | 290 -------------------------- src/subcommand/listen/handler.rs | 337 +++++++++++++++++++++++++++++++ src/subcommand/listen/mod.rs | 213 +++++++++++++++++++ 9 files changed, 708 insertions(+), 311 deletions(-) delete mode 100644 src/subcommand/listen.rs create mode 100644 src/subcommand/listen/handler.rs create mode 100644 src/subcommand/listen/mod.rs diff --git a/src/data.rs b/src/data.rs index 2de6f69..bd80890 100644 --- a/src/data.rs +++ b/src/data.rs @@ -279,7 +279,7 @@ pub enum ResponsePayload { }, /// Response to retrieving a list of managed processes - ProcList { + ProcEntries { /// List of managed processes entries: Vec, }, diff --git a/src/net/mod.rs b/src/net/mod.rs index bc1c447..e4a8995 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,5 +1,5 @@ mod transport; -pub use transport::{Transport, TransportError}; +pub use transport::{Transport, TransportError, TransportReadHalf, TransportWriteHalf}; use crate::{ data::{Request, Response, ResponsePayload}, diff --git a/src/net/transport/codec.rs b/src/net/transport/codec.rs index 316c824..8050609 100644 --- a/src/net/transport/codec.rs +++ b/src/net/transport/codec.rs @@ -25,6 +25,7 @@ pub enum DistantCodecError { } /// Represents the codec to encode and decode data for transmission +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct DistantCodec; impl<'a> Encoder<&'a [u8]> for DistantCodec { diff --git a/src/net/transport/mod.rs b/src/net/transport/mod.rs index 92634aa..2b34c84 100644 --- a/src/net/transport/mod.rs +++ b/src/net/transport/mod.rs @@ -8,9 +8,12 @@ use orion::{ }; use serde::{de::DeserializeOwned, Serialize}; use std::sync::Arc; -use tokio::{io, net::TcpStream}; +use tokio::{ + io, + net::{tcp, TcpStream}, +}; use tokio_stream::StreamExt; -use tokio_util::codec::Framed; +use tokio_util::codec::{Framed, FramedRead, FramedWrite}; mod codec; @@ -72,4 +75,72 @@ impl Transport { Ok(None) } } + + /// Splits transport into read and write halves + pub fn split(self) -> (TransportReadHalf, TransportWriteHalf) { + let key = self.key; + let parts = self.inner.into_parts(); + let (read_half, write_half) = parts.io.into_split(); + + // TODO: I can't figure out a way to re-inject the read/write buffers from parts + // into the new framed instances. This means we are dropping our old buffer + // data (I think). This shouldn't be a problem since we are splitting + // immediately, but it would be nice to cover this properly one day + let t_read = TransportReadHalf { + inner: FramedRead::new(read_half, parts.codec), + key: Arc::clone(&key), + }; + let t_write = TransportWriteHalf { + inner: FramedWrite::new(write_half, parts.codec), + key, + }; + + (t_read, t_write) + } +} + +/// Represents a transport of data out to the network +pub struct TransportWriteHalf { + inner: FramedWrite, + key: Arc, +} + +impl TransportWriteHalf { + /// Sends some data across the wire + pub async fn send(&mut self, data: T) -> Result<(), TransportError> { + // Serialize, encrypt, and then (TODO) sign + // NOTE: Cannot used packed implementation for now due to issues with deserialization + let data = serde_cbor::to_vec(&data)?; + let data = aead::seal(&self.key, &data)?; + + self.inner + .send(&data) + .await + .map_err(TransportError::CodecError) + } +} + +/// Represents a transport of data in from the network +pub struct TransportReadHalf { + inner: FramedRead, + key: Arc, +} + +impl TransportReadHalf { + /// Receives some data from out on the wire, waiting until it's available, + /// returning none if the transport is now closed + pub async fn receive(&mut self) -> Result, TransportError> { + // If data is received, we process like usual + if let Some(data) = self.inner.next().await { + // Validate (TODO) signature, decrypt, and then deserialize + let data = data?; + let data = aead::open(&self.key, &data)?; + let data = serde_cbor::from_slice(&data)?; + Ok(Some(data)) + + // Otherwise, if no data is received, this means that our socket has closed + } else { + Ok(None) + } + } } diff --git a/src/opt.rs b/src/opt.rs index 1fedc32..882a4e3 100644 --- a/src/opt.rs +++ b/src/opt.rs @@ -1,5 +1,5 @@ use crate::{subcommand, data::RequestPayload}; -use derive_more::{Display, Error, From}; +use derive_more::{Display, Error, From, IsVariant}; use lazy_static::lazy_static; use std::{ env, @@ -72,12 +72,20 @@ impl Subcommand { } } -#[derive(Copy, Clone, Debug, Display, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, IsVariant)] pub enum ExecuteFormat { - #[display(fmt = "shell")] - Shell, + /// Output responses in JSON format #[display(fmt = "json")] Json, + + /// Provides special formatting to stdout & stderr that only + /// outputs that of the remotely-executed program + #[display(fmt = "program")] + Program, + + /// Output responses in human-readable format for shells + #[display(fmt = "shell")] + Shell, } #[derive(Clone, Debug, Display, From, Error, PartialEq, Eq)] @@ -90,8 +98,9 @@ impl FromStr for ExecuteFormat { fn from_str(s: &str) -> Result { match s.trim() { - "shell" => Ok(Self::Shell), "json" => Ok(Self::Json), + "program" => Ok(Self::Program), + "shell" => Ok(Self::Shell), x => Err(ExecuteFormatParseError::InvalidVariant(x.to_string())), } } @@ -104,12 +113,13 @@ pub struct ExecuteSubcommand { /// Represents the format that results should be returned /// /// Currently, there are two possible formats: - /// 1. "shell": printing out human-readable results for interactive shell usage - /// 2. "json": printing our JSON for external program usage + /// 1. "json": printing out JSON for external program usage + /// 2. "program": printing out verbatim all stdout and stderr of remotely-executed program + /// 3. "shell": printing out human-readable results for interactive shell usage #[structopt( short, long, - value_name = "shell|json", + value_name = "json|program|shell", default_value = "shell", possible_values = &["shell", "json"] )] @@ -318,6 +328,10 @@ pub struct ListenSubcommand { #[structopt(short = "6", long)] pub use_ipv6: bool, + /// Maximum capacity for concurrent message handling by the server + #[structopt(long, default_value = "1000")] + pub max_msg_capacity: u16, + /// Set the port(s) that the server will attempt to bind to /// /// This can be in the form of PORT1 or PORT1:PORTN to provide a range of ports. diff --git a/src/subcommand/execute.rs b/src/subcommand/execute.rs index c2a69ed..8090367 100644 --- a/src/subcommand/execute.rs +++ b/src/subcommand/execute.rs @@ -1,11 +1,12 @@ use crate::{ - data::{Request, Response, ResponsePayload}, + data::{Request, RequestPayload, Response, ResponsePayload}, net::{Client, TransportError}, opt::{CommonOpt, ExecuteFormat, ExecuteSubcommand}, utils::{Session, SessionError}, }; use derive_more::{Display, Error, From}; use tokio::io; +use tokio_stream::StreamExt; #[derive(Debug, Display, Error, From)] pub enum Error { @@ -26,18 +27,68 @@ async fn run_async(cmd: ExecuteSubcommand, _opt: CommonOpt) -> Result<(), Error> let req = Request::from(cmd.operation); + // Special conditions for continuing to process responses + let is_proc_req = req.payload.is_proc_run() || req.payload.is_proc_connect(); + let not_detach = if let RequestPayload::ProcRun { detach, .. } = req.payload { + !detach + } else { + false + }; + let res = client.send(req).await?; - let res_string = match cmd.format { + print_response(cmd.format, res)?; + + // If we are executing a process and not detaching, we want to continue receiving + // responses sent to us + if is_proc_req && not_detach { + let mut stream = client.to_response_stream(); + while let Some(res) = stream.next().await { + print_response(cmd.format, res)?; + } + } + + Ok(()) +} + +fn print_response(fmt: ExecuteFormat, res: Response) -> io::Result<()> { + // If we are not program format or we are program format and got stdout/stderr, we want + // to print out the results + let is_fmt_program = fmt.is_program(); + let is_type_stderr = res.payload.is_proc_stderr(); + let do_print = !is_fmt_program || is_type_stderr || res.payload.is_proc_stdout(); + + let out = format_response(fmt, res)?; + + // Print out our response if flagged to do so + if do_print { + // If we are program format and got stderr, write it to stderr + if is_fmt_program && is_type_stderr { + eprintln!("{}", out); + + // Otherwise, always go to stdout + } else { + println!("{}", out); + } + } + + Ok(()) +} + +fn format_response(fmt: ExecuteFormat, res: Response) -> io::Result { + Ok(match fmt { ExecuteFormat::Json => serde_json::to_string(&res) .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?, + ExecuteFormat::Program => format_program(res), ExecuteFormat::Shell => format_human(res), - }; - println!("{}", res_string); - - // TODO: Process result to determine if we want to create a watch stream and continue - // to examine results + }) +} - Ok(()) +fn format_program(res: Response) -> String { + match res.payload { + ResponsePayload::ProcStdout { data, .. } => String::from_utf8_lossy(&data).to_string(), + ResponsePayload::ProcStderr { data, .. } => String::from_utf8_lossy(&data).to_string(), + _ => String::new(), + } } fn format_human(res: Response) -> String { @@ -61,7 +112,7 @@ fn format_human(res: Response) -> String { }) .collect::>() .join("\n"), - ResponsePayload::ProcList { entries } => entries + ResponsePayload::ProcEntries { entries } => entries .into_iter() .map(|entry| format!("{}: {} {}", entry.id, entry.cmd, entry.args.join(" "))) .collect::>() diff --git a/src/subcommand/listen.rs b/src/subcommand/listen.rs deleted file mode 100644 index 6f574c4..0000000 --- a/src/subcommand/listen.rs +++ /dev/null @@ -1,290 +0,0 @@ -use crate::{ - data::{DirEntry, FileType, Request, RequestPayload, Response, ResponsePayload}, - net::Transport, - opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand}, -}; -use derive_more::{Display, Error, From}; -use fork::{daemon, Fork}; -use log::*; -use orion::aead::SecretKey; -use std::{string::FromUtf8Error, sync::Arc}; -use tokio::{ - io::{self, AsyncWriteExt}, - net::TcpListener, -}; -use walkdir::WalkDir; - -#[derive(Debug, Display, Error, From)] -pub enum Error { - ConvertToIpAddrError(ConvertToIpAddrError), - ForkError, - IoError(io::Error), - Utf8Error(FromUtf8Error), -} - -pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> { - if cmd.daemon { - // NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent - match daemon(false, true) { - Ok(Fork::Child) => { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { run_async(cmd, opt, true).await })?; - } - Ok(Fork::Parent(pid)) => { - info!("[distant detached, pid = {}]", pid); - if let Err(_) = fork::close_fd() { - return Err(Error::ForkError); - } - } - Err(_) => return Err(Error::ForkError), - } - } else { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { run_async(cmd, opt, false).await })?; - } - - Ok(()) -} - -async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> { - let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?; - let socket_addrs = cmd.port.make_socket_addrs(addr); - - debug!("Binding to {} in range {}", addr, cmd.port); - let listener = TcpListener::bind(socket_addrs.as_slice()).await?; - - let port = listener.local_addr()?.port(); - debug!("Bound to port: {}", port); - - let key = Arc::new(SecretKey::default()); - - // Print information about port, key, etc. unless told not to - if !cmd.no_print_startup_data { - publish_data(port, &key); - } - - // For the child, we want to fully disconnect it from pipes, which we do now - if is_forked { - if let Err(_) = fork::close_fd() { - return Err(Error::ForkError); - } - } - - // Wait for a client connection, then spawn a new task to handle - // receiving data from the client - while let Ok((client, _)) = listener.accept().await { - // Grab the client's remote address for later logging purposes - let addr_string = match client.peer_addr() { - Ok(addr) => { - let addr_string = addr.to_string(); - info!(" Established connection", addr_string); - addr_string - } - Err(x) => { - error!("Unable to examine client's peer address: {}", x); - "???".to_string() - } - }; - - // Build a transport around the client - let mut transport = Transport::new(client, Arc::clone(&key)); - - // Spawn a new task that loops to handle requests from the client - tokio::spawn(async move { - loop { - match transport.receive::().await { - Ok(Some(request)) => { - trace!( - " Received request of type {}", - addr_string.as_str(), - request.payload.as_ref() - ); - - // Process the request, converting any error into an error response - let response = Response::from_payload_with_origin( - match process_request_payload(request.payload).await { - Ok(payload) => payload, - Err(x) => ResponsePayload::Error { - description: x.to_string(), - }, - }, - request.id, - ); - - if let Err(x) = transport.send(response).await { - error!(" {}", addr_string.as_str(), x); - break; - } - } - Ok(None) => { - info!(" Closed connection", addr_string.as_str()); - break; - } - Err(x) => { - error!(" {}", addr_string.as_str(), x); - break; - } - } - } - }); - } - - Ok(()) -} - -fn publish_data(port: u16, key: &SecretKey) { - // TODO: We have to share the key in some manner (maybe use k256 to arrive at the same key?) - // For now, we do what mosh does and print out the key knowing that this is shared over - // ssh, which should provide security - println!( - "DISTANT DATA {} {}", - port, - hex::encode(key.unprotected_as_bytes()) - ); -} - -async fn process_request_payload( - payload: RequestPayload, -) -> Result> { - match payload { - RequestPayload::FileRead { path } => Ok(ResponsePayload::Blob { - data: tokio::fs::read(path).await?, - }), - - RequestPayload::FileReadText { path } => Ok(ResponsePayload::Text { - data: tokio::fs::read_to_string(path).await?, - }), - - RequestPayload::FileWrite { - path, - input: _, - data, - } => { - tokio::fs::write(path, data).await?; - Ok(ResponsePayload::Ok) - } - - RequestPayload::FileAppend { - path, - input: _, - data, - } => { - let mut file = tokio::fs::OpenOptions::new() - .append(true) - .open(path) - .await?; - file.write_all(&data).await?; - Ok(ResponsePayload::Ok) - } - - RequestPayload::DirRead { path, all } => { - // Traverse, but don't include root directory in entries (hence min depth 1) - let dir = WalkDir::new(path.as_path()).min_depth(1); - - // If all, will recursively traverse, otherwise just return directly from dir - let dir = if all { dir } else { dir.max_depth(1) }; - - // TODO: Support both returning errors and successfully-traversed entries - // TODO: Support returning full paths instead of always relative? - Ok(ResponsePayload::DirEntries { - entries: dir - .into_iter() - .map(|e| { - e.map(|e| DirEntry { - path: e.path().strip_prefix(path.as_path()).unwrap().to_path_buf(), - file_type: if e.file_type().is_dir() { - FileType::Dir - } else if e.file_type().is_file() { - FileType::File - } else { - FileType::SymLink - }, - depth: e.depth(), - }) - }) - .collect::, walkdir::Error>>()?, - }) - } - - RequestPayload::DirCreate { path, all } => { - if all { - tokio::fs::create_dir_all(path).await?; - } else { - tokio::fs::create_dir(path).await?; - } - - Ok(ResponsePayload::Ok) - } - - RequestPayload::Remove { path, force } => { - let path_metadata = tokio::fs::metadata(path.as_path()).await?; - if path_metadata.is_dir() { - if force { - tokio::fs::remove_dir_all(path).await?; - } else { - tokio::fs::remove_dir(path).await?; - } - } else { - tokio::fs::remove_file(path).await?; - } - - Ok(ResponsePayload::Ok) - } - - RequestPayload::Copy { src, dst } => { - let src_metadata = tokio::fs::metadata(src.as_path()).await?; - if src_metadata.is_dir() { - for entry in WalkDir::new(src.as_path()) - .min_depth(1) - .follow_links(false) - .into_iter() - .filter_entry(|e| e.file_type().is_file() || e.path_is_symlink()) - { - let entry = entry?; - - // Get unique portion of path relative to src - // NOTE: Because we are traversing files that are all within src, this - // should always succeed - let local_src = entry.path().strip_prefix(src.as_path()).unwrap(); - - // Get the file without any directories - let local_src_file_name = local_src.file_name().unwrap(); - - // Get the directory housing the file - // NOTE: Because we enforce files/symlinks, there will always be a parent - let local_src_dir = local_src.parent().unwrap(); - - // Map out the path to the destination - let dst_parent_dir = dst.join(local_src_dir); - - // Create the destination directory for the file when copying - tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?; - - // Perform copying from entry to destination - let dst_file = dst_parent_dir.join(local_src_file_name); - tokio::fs::copy(entry.path(), dst_file).await?; - } - } else { - tokio::fs::copy(src, dst).await?; - } - - Ok(ResponsePayload::Ok) - } - - RequestPayload::Rename { src, dst } => { - tokio::fs::rename(src, dst).await?; - - Ok(ResponsePayload::Ok) - } - - RequestPayload::ProcRun { cmd, args, detach } => todo!(), - - RequestPayload::ProcConnect { id } => todo!(), - - RequestPayload::ProcKill { id } => todo!(), - - RequestPayload::ProcStdin { id, data } => todo!(), - - RequestPayload::ProcList {} => todo!(), - } -} diff --git a/src/subcommand/listen/handler.rs b/src/subcommand/listen/handler.rs new file mode 100644 index 0000000..b19c0cd --- /dev/null +++ b/src/subcommand/listen/handler.rs @@ -0,0 +1,337 @@ +use super::{Process, State}; +use crate::data::{ + DirEntry, FileType, Request, RequestPayload, Response, ResponsePayload, RunningProcess, +}; +use log::*; +use std::{error::Error, path::PathBuf, process::Stdio, sync::Arc}; +use tokio::{ + io::{self, AsyncReadExt, AsyncWriteExt}, + process::Command, + sync::{mpsc, oneshot, Mutex}, +}; +use walkdir::WalkDir; + +pub type Reply = mpsc::Sender; +type HState = Arc>; + +/// Processes the provided request, sending replies using the given sender +pub(super) async fn process( + client_id: usize, + state: HState, + req: Request, + tx: Reply, +) -> Result<(), mpsc::error::SendError> { + async fn inner( + client_id: usize, + state: HState, + payload: RequestPayload, + tx: Reply, + ) -> Result> { + match payload { + RequestPayload::FileRead { path } => file_read(path).await, + RequestPayload::FileReadText { path } => file_read_text(path).await, + RequestPayload::FileWrite { path, data, .. } => file_write(path, data).await, + RequestPayload::FileAppend { path, data, .. } => file_append(path, data).await, + RequestPayload::DirRead { path, all } => dir_read(path, all).await, + RequestPayload::DirCreate { path, all } => dir_create(path, all).await, + RequestPayload::Remove { path, force } => remove(path, force).await, + RequestPayload::Copy { src, dst } => copy(src, dst).await, + RequestPayload::Rename { src, dst } => rename(src, dst).await, + RequestPayload::ProcRun { cmd, args, detach } => { + proc_run(client_id, state, tx, cmd, args, detach).await + } + RequestPayload::ProcConnect { id } => proc_connect(id).await, + RequestPayload::ProcKill { id } => proc_kill(state, id).await, + RequestPayload::ProcStdin { id, data } => proc_stdin(state, id, data).await, + RequestPayload::ProcList {} => proc_list(state).await, + } + } + + let res = Response::from_payload_with_origin( + match inner(client_id, state, req.payload, tx.clone()).await { + Ok(payload) => payload, + Err(x) => ResponsePayload::Error { + description: x.to_string(), + }, + }, + req.id, + ); + + // Send out our primary response from processing the request + tx.send(res).await +} + +async fn file_read(path: PathBuf) -> Result> { + Ok(ResponsePayload::Blob { + data: tokio::fs::read(path).await?, + }) +} + +async fn file_read_text(path: PathBuf) -> Result> { + Ok(ResponsePayload::Text { + data: tokio::fs::read_to_string(path).await?, + }) +} + +async fn file_write(path: PathBuf, data: Vec) -> Result> { + tokio::fs::write(path, data).await?; + Ok(ResponsePayload::Ok) +} + +async fn file_append(path: PathBuf, data: Vec) -> Result> { + let mut file = tokio::fs::OpenOptions::new() + .append(true) + .open(path) + .await?; + file.write_all(&data).await?; + Ok(ResponsePayload::Ok) +} + +async fn dir_read(path: PathBuf, all: bool) -> Result> { + // Traverse, but don't include root directory in entries (hence min depth 1) + let dir = WalkDir::new(path.as_path()).min_depth(1); + + // If all, will recursively traverse, otherwise just return directly from dir + let dir = if all { dir } else { dir.max_depth(1) }; + + // TODO: Support both returning errors and successfully-traversed entries + // TODO: Support returning full paths instead of always relative? + Ok(ResponsePayload::DirEntries { + entries: dir + .into_iter() + .map(|e| { + e.map(|e| DirEntry { + path: e.path().strip_prefix(path.as_path()).unwrap().to_path_buf(), + file_type: if e.file_type().is_dir() { + FileType::Dir + } else if e.file_type().is_file() { + FileType::File + } else { + FileType::SymLink + }, + depth: e.depth(), + }) + }) + .collect::, walkdir::Error>>()?, + }) +} + +async fn dir_create(path: PathBuf, all: bool) -> Result> { + if all { + tokio::fs::create_dir_all(path).await?; + } else { + tokio::fs::create_dir(path).await?; + } + + Ok(ResponsePayload::Ok) +} + +async fn remove(path: PathBuf, force: bool) -> Result> { + let path_metadata = tokio::fs::metadata(path.as_path()).await?; + if path_metadata.is_dir() { + if force { + tokio::fs::remove_dir_all(path).await?; + } else { + tokio::fs::remove_dir(path).await?; + } + } else { + tokio::fs::remove_file(path).await?; + } + + Ok(ResponsePayload::Ok) +} + +async fn copy(src: PathBuf, dst: PathBuf) -> Result> { + let src_metadata = tokio::fs::metadata(src.as_path()).await?; + if src_metadata.is_dir() { + for entry in WalkDir::new(src.as_path()) + .min_depth(1) + .follow_links(false) + .into_iter() + .filter_entry(|e| e.file_type().is_file() || e.path_is_symlink()) + { + let entry = entry?; + + // Get unique portion of path relative to src + // NOTE: Because we are traversing files that are all within src, this + // should always succeed + let local_src = entry.path().strip_prefix(src.as_path()).unwrap(); + + // Get the file without any directories + let local_src_file_name = local_src.file_name().unwrap(); + + // Get the directory housing the file + // NOTE: Because we enforce files/symlinks, there will always be a parent + let local_src_dir = local_src.parent().unwrap(); + + // Map out the path to the destination + let dst_parent_dir = dst.join(local_src_dir); + + // Create the destination directory for the file when copying + tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?; + + // Perform copying from entry to destination + let dst_file = dst_parent_dir.join(local_src_file_name); + tokio::fs::copy(entry.path(), dst_file).await?; + } + } else { + tokio::fs::copy(src, dst).await?; + } + + Ok(ResponsePayload::Ok) +} + +async fn rename(src: PathBuf, dst: PathBuf) -> Result> { + tokio::fs::rename(src, dst).await?; + + Ok(ResponsePayload::Ok) +} + +async fn proc_run( + client_id: usize, + state: HState, + tx: Reply, + cmd: String, + args: Vec, + detach: bool, +) -> Result> { + let id = rand::random(); + + let mut child = Command::new(cmd.to_string()) + .args(args.clone()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + // Spawn a task that sends stdout as a response + let tx_2 = tx.clone(); + let mut stdout = child.stdout.take().unwrap(); + tokio::spawn(async move { + loop { + let mut data = Vec::new(); + match stdout.read_to_end(&mut data).await { + Ok(_) => { + if let Err(_) = tx_2 + .send(Response::from(ResponsePayload::ProcStdout { id, data })) + .await + { + break; + } + } + Err(_) => break, + } + } + }); + + // Spawn a task that sends stderr as a response + let mut stderr = child.stderr.take().unwrap(); + tokio::spawn(async move { + loop { + let mut data = Vec::new(); + match stderr.read_to_end(&mut data).await { + Ok(_) => { + if let Err(_) = tx + .send(Response::from(ResponsePayload::ProcStderr { id, data })) + .await + { + break; + } + } + Err(_) => break, + } + } + }); + + // Spawn a task that sends stdin to the process + // TODO: Should this be configurable? + let mut stdin = child.stdin.take().unwrap(); + let (stdin_tx, mut stdin_rx) = mpsc::channel::>(1); + tokio::spawn(async move { + while let Some(data) = stdin_rx.recv().await { + if let Err(x) = stdin.write_all(&data).await { + error!("Failed to send stdin to process {}: {}", id, x); + break; + } + } + }); + + // Spawn a task that kills the process when triggered + let (kill_tx, kill_rx) = oneshot::channel(); + tokio::spawn(async move { + let _ = kill_rx.await; + if let Err(x) = child.kill().await { + error!("Unable to kill process {}: {}", id, x); + } + }); + + // Update our state with the new process + let process = Process { + cmd, + args, + id, + stdin_tx, + kill_tx, + }; + state.lock().await.processes.insert(id, process); + + // If we are not detaching from process, we want to associate it with our client + if !detach { + state + .lock() + .await + .client_processes + .entry(client_id) + .or_insert(Vec::new()) + .push(id); + } + + Ok(ResponsePayload::ProcStart { id }) +} + +async fn proc_connect(id: usize) -> Result> { + todo!(); +} + +async fn proc_kill(state: HState, id: usize) -> Result> { + if let Some(process) = state.lock().await.processes.remove(&id) { + process.kill_tx.send(()).map_err(|_| { + io::Error::new( + io::ErrorKind::BrokenPipe, + "Unable to send kill signal to process", + ) + })?; + } + + Ok(ResponsePayload::Ok) +} + +async fn proc_stdin( + state: HState, + id: usize, + data: Vec, +) -> Result> { + if let Some(process) = state.lock().await.processes.get(&id) { + process.stdin_tx.send(data).await.map_err(|_| { + io::Error::new(io::ErrorKind::BrokenPipe, "Unable to send stdin to process") + })?; + } + + Ok(ResponsePayload::Ok) +} + +async fn proc_list(state: HState) -> Result> { + Ok(ResponsePayload::ProcEntries { + entries: state + .lock() + .await + .processes + .values() + .map(|p| RunningProcess { + cmd: p.cmd.to_string(), + args: p.args.clone(), + id: p.id, + }) + .collect(), + }) +} diff --git a/src/subcommand/listen/mod.rs b/src/subcommand/listen/mod.rs new file mode 100644 index 0000000..a0b2c78 --- /dev/null +++ b/src/subcommand/listen/mod.rs @@ -0,0 +1,213 @@ +use crate::{ + data::{Request, Response}, + net::{Transport, TransportReadHalf, TransportWriteHalf}, + opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand}, +}; +use derive_more::{Display, Error, From}; +use fork::{daemon, Fork}; +use log::*; +use orion::aead::SecretKey; +use std::{collections::HashMap, sync::Arc}; +use tokio::{ + io, + net::TcpListener, + sync::{mpsc, oneshot, Mutex}, +}; + +mod handler; + +#[derive(Debug, Display, Error, From)] +pub enum Error { + ConvertToIpAddrError(ConvertToIpAddrError), + ForkError, + IoError(io::Error), +} + +/// Holds state relevant to the server +#[derive(Default)] +struct State { + /// Map of all processes running on the server + processes: HashMap, + + /// List of processes that will be killed when a client drops + client_processes: HashMap>, +} + +impl State { + /// Cleans up state associated with a particular client + pub async fn cleanup_client(&mut self, id: usize) { + if let Some(ids) = self.client_processes.remove(&id) { + for id in ids { + if let Some(process) = self.processes.remove(&id) { + if let Err(_) = process.kill_tx.send(()) { + error!( + "Client {} failed to send process {} kill signal", + id, process.id + ); + } + } + } + } + } +} + +/// Represents an actively-running process maintained by the server +struct Process { + pub id: usize, + pub cmd: String, + pub args: Vec, + pub stdin_tx: mpsc::Sender>, + pub kill_tx: oneshot::Sender<()>, +} + +pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> { + if cmd.daemon { + // NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent + match daemon(false, true) { + Ok(Fork::Child) => { + let rt = tokio::runtime::Runtime::new()?; + rt.block_on(async { run_async(cmd, opt, true).await })?; + } + Ok(Fork::Parent(pid)) => { + info!("[distant detached, pid = {}]", pid); + if let Err(_) = fork::close_fd() { + return Err(Error::ForkError); + } + } + Err(_) => return Err(Error::ForkError), + } + } else { + let rt = tokio::runtime::Runtime::new()?; + rt.block_on(async { run_async(cmd, opt, false).await })?; + } + + Ok(()) +} + +async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> { + let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?; + let socket_addrs = cmd.port.make_socket_addrs(addr); + + debug!("Binding to {} in range {}", addr, cmd.port); + let listener = TcpListener::bind(socket_addrs.as_slice()).await?; + + let port = listener.local_addr()?.port(); + debug!("Bound to port: {}", port); + + let key = Arc::new(SecretKey::default()); + + // Print information about port, key, etc. unless told not to + if !cmd.no_print_startup_data { + publish_data(port, &key); + } + + // For the child, we want to fully disconnect it from pipes, which we do now + if is_forked { + if let Err(_) = fork::close_fd() { + return Err(Error::ForkError); + } + } + + // Build our state for the server + let state = Arc::new(Mutex::new(State::default())); + + // Wait for a client connection, then spawn a new task to handle + // receiving data from the client + while let Ok((client, _)) = listener.accept().await { + // Grab the client's remote address for later logging purposes + let addr_string = match client.peer_addr() { + Ok(addr) => { + let addr_string = addr.to_string(); + info!(" Established connection", addr_string); + addr_string + } + Err(x) => { + error!("Unable to examine client's peer address: {}", x); + "???".to_string() + } + }; + + // Create a unique id for the client + let id = rand::random(); + + // Build a transport around the client, splitting into read and write halves so we can + // handle input and output concurrently + let (t_read, t_write) = Transport::new(client, Arc::clone(&key)).split(); + let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize); + + // Spawn a new task that loops to handle requests from the client + tokio::spawn({ + let f = request_loop(id, addr_string.to_string(), Arc::clone(&state), t_read, tx); + + let state = Arc::clone(&state); + async move { + f.await; + state.lock().await.cleanup_client(id).await; + } + }); + + // Spawn a new task that loops to handle responses to the client + tokio::spawn(async move { response_loop(addr_string, t_write, rx).await }); + } + + Ok(()) +} + +/// Repeatedly reads in new requests, processes them, and sends their responses to the +/// response loop +async fn request_loop( + id: usize, + addr: String, + state: Arc>, + mut transport: TransportReadHalf, + tx: mpsc::Sender, +) { + loop { + match transport.receive::().await { + Ok(Some(req)) => { + trace!( + " Received request of type {}", + addr.as_str(), + req.payload.as_ref() + ); + + if let Err(x) = handler::process(id, Arc::clone(&state), req, tx.clone()).await { + error!(" {}", addr.as_str(), x); + break; + } + } + Ok(None) => { + info!(" Closed connection", addr.as_str()); + break; + } + Err(x) => { + error!(" {}", addr.as_str(), x); + break; + } + } + } +} + +async fn response_loop( + addr: String, + mut transport: TransportWriteHalf, + mut rx: mpsc::Receiver, +) { + while let Some(res) = rx.recv().await { + if let Err(x) = transport.send(res).await { + error!(" {}", addr.as_str(), x); + break; + } + } +} + +fn publish_data(port: u16, key: &SecretKey) { + // TODO: We have to share the key in some manner (maybe use k256 to arrive at the same key?) + // For now, we do what mosh does and print out the key knowing that this is shared over + // ssh, which should provide security + println!( + "DISTANT DATA {} {}", + port, + hex::encode(key.unprotected_as_bytes()) + ); +}