Refactored listener code into a handler module, wrote support to split transport into read and write halves, implemented most of backend although process run is not working yet

pull/38/head
Chip Senkbeil 3 years ago
parent a707523fb5
commit 24a8cf8401
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -279,7 +279,7 @@ pub enum ResponsePayload {
},
/// Response to retrieving a list of managed processes
ProcList {
ProcEntries {
/// List of managed processes
entries: Vec<RunningProcess>,
},

@ -1,5 +1,5 @@
mod transport;
pub use transport::{Transport, TransportError};
pub use transport::{Transport, TransportError, TransportReadHalf, TransportWriteHalf};
use crate::{
data::{Request, Response, ResponsePayload},

@ -25,6 +25,7 @@ pub enum DistantCodecError {
}
/// Represents the codec to encode and decode data for transmission
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct DistantCodec;
impl<'a> Encoder<&'a [u8]> for DistantCodec {

@ -8,9 +8,12 @@ use orion::{
};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use tokio::{io, net::TcpStream};
use tokio::{
io,
net::{tcp, TcpStream},
};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
use tokio_util::codec::{Framed, FramedRead, FramedWrite};
mod codec;
@ -72,4 +75,72 @@ impl Transport {
Ok(None)
}
}
/// Splits transport into read and write halves
pub fn split(self) -> (TransportReadHalf, TransportWriteHalf) {
let key = self.key;
let parts = self.inner.into_parts();
let (read_half, write_half) = parts.io.into_split();
// TODO: I can't figure out a way to re-inject the read/write buffers from parts
// into the new framed instances. This means we are dropping our old buffer
// data (I think). This shouldn't be a problem since we are splitting
// immediately, but it would be nice to cover this properly one day
let t_read = TransportReadHalf {
inner: FramedRead::new(read_half, parts.codec),
key: Arc::clone(&key),
};
let t_write = TransportWriteHalf {
inner: FramedWrite::new(write_half, parts.codec),
key,
};
(t_read, t_write)
}
}
/// Represents a transport of data out to the network
pub struct TransportWriteHalf {
inner: FramedWrite<tcp::OwnedWriteHalf, DistantCodec>,
key: Arc<SecretKey>,
}
impl TransportWriteHalf {
/// Sends some data across the wire
pub async fn send<T: Serialize>(&mut self, data: T) -> Result<(), TransportError> {
// Serialize, encrypt, and then (TODO) sign
// NOTE: Cannot used packed implementation for now due to issues with deserialization
let data = serde_cbor::to_vec(&data)?;
let data = aead::seal(&self.key, &data)?;
self.inner
.send(&data)
.await
.map_err(TransportError::CodecError)
}
}
/// Represents a transport of data in from the network
pub struct TransportReadHalf {
inner: FramedRead<tcp::OwnedReadHalf, DistantCodec>,
key: Arc<SecretKey>,
}
impl TransportReadHalf {
/// Receives some data from out on the wire, waiting until it's available,
/// returning none if the transport is now closed
pub async fn receive<T: DeserializeOwned>(&mut self) -> Result<Option<T>, TransportError> {
// If data is received, we process like usual
if let Some(data) = self.inner.next().await {
// Validate (TODO) signature, decrypt, and then deserialize
let data = data?;
let data = aead::open(&self.key, &data)?;
let data = serde_cbor::from_slice(&data)?;
Ok(Some(data))
// Otherwise, if no data is received, this means that our socket has closed
} else {
Ok(None)
}
}
}

@ -1,5 +1,5 @@
use crate::{subcommand, data::RequestPayload};
use derive_more::{Display, Error, From};
use derive_more::{Display, Error, From, IsVariant};
use lazy_static::lazy_static;
use std::{
env,
@ -72,12 +72,20 @@ impl Subcommand {
}
}
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq)]
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, IsVariant)]
pub enum ExecuteFormat {
#[display(fmt = "shell")]
Shell,
/// Output responses in JSON format
#[display(fmt = "json")]
Json,
/// Provides special formatting to stdout & stderr that only
/// outputs that of the remotely-executed program
#[display(fmt = "program")]
Program,
/// Output responses in human-readable format for shells
#[display(fmt = "shell")]
Shell,
}
#[derive(Clone, Debug, Display, From, Error, PartialEq, Eq)]
@ -90,8 +98,9 @@ impl FromStr for ExecuteFormat {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.trim() {
"shell" => Ok(Self::Shell),
"json" => Ok(Self::Json),
"program" => Ok(Self::Program),
"shell" => Ok(Self::Shell),
x => Err(ExecuteFormatParseError::InvalidVariant(x.to_string())),
}
}
@ -104,12 +113,13 @@ pub struct ExecuteSubcommand {
/// Represents the format that results should be returned
///
/// Currently, there are two possible formats:
/// 1. "shell": printing out human-readable results for interactive shell usage
/// 2. "json": printing our JSON for external program usage
/// 1. "json": printing out JSON for external program usage
/// 2. "program": printing out verbatim all stdout and stderr of remotely-executed program
/// 3. "shell": printing out human-readable results for interactive shell usage
#[structopt(
short,
long,
value_name = "shell|json",
value_name = "json|program|shell",
default_value = "shell",
possible_values = &["shell", "json"]
)]
@ -318,6 +328,10 @@ pub struct ListenSubcommand {
#[structopt(short = "6", long)]
pub use_ipv6: bool,
/// Maximum capacity for concurrent message handling by the server
#[structopt(long, default_value = "1000")]
pub max_msg_capacity: u16,
/// Set the port(s) that the server will attempt to bind to
///
/// This can be in the form of PORT1 or PORT1:PORTN to provide a range of ports.

@ -1,11 +1,12 @@
use crate::{
data::{Request, Response, ResponsePayload},
data::{Request, RequestPayload, Response, ResponsePayload},
net::{Client, TransportError},
opt::{CommonOpt, ExecuteFormat, ExecuteSubcommand},
utils::{Session, SessionError},
};
use derive_more::{Display, Error, From};
use tokio::io;
use tokio_stream::StreamExt;
#[derive(Debug, Display, Error, From)]
pub enum Error {
@ -26,18 +27,68 @@ async fn run_async(cmd: ExecuteSubcommand, _opt: CommonOpt) -> Result<(), Error>
let req = Request::from(cmd.operation);
// Special conditions for continuing to process responses
let is_proc_req = req.payload.is_proc_run() || req.payload.is_proc_connect();
let not_detach = if let RequestPayload::ProcRun { detach, .. } = req.payload {
!detach
} else {
false
};
let res = client.send(req).await?;
let res_string = match cmd.format {
print_response(cmd.format, res)?;
// If we are executing a process and not detaching, we want to continue receiving
// responses sent to us
if is_proc_req && not_detach {
let mut stream = client.to_response_stream();
while let Some(res) = stream.next().await {
print_response(cmd.format, res)?;
}
}
Ok(())
}
fn print_response(fmt: ExecuteFormat, res: Response) -> io::Result<()> {
// If we are not program format or we are program format and got stdout/stderr, we want
// to print out the results
let is_fmt_program = fmt.is_program();
let is_type_stderr = res.payload.is_proc_stderr();
let do_print = !is_fmt_program || is_type_stderr || res.payload.is_proc_stdout();
let out = format_response(fmt, res)?;
// Print out our response if flagged to do so
if do_print {
// If we are program format and got stderr, write it to stderr
if is_fmt_program && is_type_stderr {
eprintln!("{}", out);
// Otherwise, always go to stdout
} else {
println!("{}", out);
}
}
Ok(())
}
fn format_response(fmt: ExecuteFormat, res: Response) -> io::Result<String> {
Ok(match fmt {
ExecuteFormat::Json => serde_json::to_string(&res)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?,
ExecuteFormat::Program => format_program(res),
ExecuteFormat::Shell => format_human(res),
};
println!("{}", res_string);
// TODO: Process result to determine if we want to create a watch stream and continue
// to examine results
})
}
Ok(())
fn format_program(res: Response) -> String {
match res.payload {
ResponsePayload::ProcStdout { data, .. } => String::from_utf8_lossy(&data).to_string(),
ResponsePayload::ProcStderr { data, .. } => String::from_utf8_lossy(&data).to_string(),
_ => String::new(),
}
}
fn format_human(res: Response) -> String {
@ -61,7 +112,7 @@ fn format_human(res: Response) -> String {
})
.collect::<Vec<String>>()
.join("\n"),
ResponsePayload::ProcList { entries } => entries
ResponsePayload::ProcEntries { entries } => entries
.into_iter()
.map(|entry| format!("{}: {} {}", entry.id, entry.cmd, entry.args.join(" ")))
.collect::<Vec<String>>()

@ -1,290 +0,0 @@
use crate::{
data::{DirEntry, FileType, Request, RequestPayload, Response, ResponsePayload},
net::Transport,
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
};
use derive_more::{Display, Error, From};
use fork::{daemon, Fork};
use log::*;
use orion::aead::SecretKey;
use std::{string::FromUtf8Error, sync::Arc};
use tokio::{
io::{self, AsyncWriteExt},
net::TcpListener,
};
use walkdir::WalkDir;
#[derive(Debug, Display, Error, From)]
pub enum Error {
ConvertToIpAddrError(ConvertToIpAddrError),
ForkError,
IoError(io::Error),
Utf8Error(FromUtf8Error),
}
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
if cmd.daemon {
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
match daemon(false, true) {
Ok(Fork::Child) => {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, true).await })?;
}
Ok(Fork::Parent(pid)) => {
info!("[distant detached, pid = {}]", pid);
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
Err(_) => return Err(Error::ForkError),
}
} else {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, false).await })?;
}
Ok(())
}
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
let socket_addrs = cmd.port.make_socket_addrs(addr);
debug!("Binding to {} in range {}", addr, cmd.port);
let listener = TcpListener::bind(socket_addrs.as_slice()).await?;
let port = listener.local_addr()?.port();
debug!("Bound to port: {}", port);
let key = Arc::new(SecretKey::default());
// Print information about port, key, etc. unless told not to
if !cmd.no_print_startup_data {
publish_data(port, &key);
}
// For the child, we want to fully disconnect it from pipes, which we do now
if is_forked {
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
// Wait for a client connection, then spawn a new task to handle
// receiving data from the client
while let Ok((client, _)) = listener.accept().await {
// Grab the client's remote address for later logging purposes
let addr_string = match client.peer_addr() {
Ok(addr) => {
let addr_string = addr.to_string();
info!("<Client @ {}> Established connection", addr_string);
addr_string
}
Err(x) => {
error!("Unable to examine client's peer address: {}", x);
"???".to_string()
}
};
// Build a transport around the client
let mut transport = Transport::new(client, Arc::clone(&key));
// Spawn a new task that loops to handle requests from the client
tokio::spawn(async move {
loop {
match transport.receive::<Request>().await {
Ok(Some(request)) => {
trace!(
"<Client @ {}> Received request of type {}",
addr_string.as_str(),
request.payload.as_ref()
);
// Process the request, converting any error into an error response
let response = Response::from_payload_with_origin(
match process_request_payload(request.payload).await {
Ok(payload) => payload,
Err(x) => ResponsePayload::Error {
description: x.to_string(),
},
},
request.id,
);
if let Err(x) = transport.send(response).await {
error!("<Client @ {}> {}", addr_string.as_str(), x);
break;
}
}
Ok(None) => {
info!("<Client @ {}> Closed connection", addr_string.as_str());
break;
}
Err(x) => {
error!("<Client @ {}> {}", addr_string.as_str(), x);
break;
}
}
}
});
}
Ok(())
}
fn publish_data(port: u16, key: &SecretKey) {
// TODO: We have to share the key in some manner (maybe use k256 to arrive at the same key?)
// For now, we do what mosh does and print out the key knowing that this is shared over
// ssh, which should provide security
println!(
"DISTANT DATA {} {}",
port,
hex::encode(key.unprotected_as_bytes())
);
}
async fn process_request_payload(
payload: RequestPayload,
) -> Result<ResponsePayload, Box<dyn std::error::Error>> {
match payload {
RequestPayload::FileRead { path } => Ok(ResponsePayload::Blob {
data: tokio::fs::read(path).await?,
}),
RequestPayload::FileReadText { path } => Ok(ResponsePayload::Text {
data: tokio::fs::read_to_string(path).await?,
}),
RequestPayload::FileWrite {
path,
input: _,
data,
} => {
tokio::fs::write(path, data).await?;
Ok(ResponsePayload::Ok)
}
RequestPayload::FileAppend {
path,
input: _,
data,
} => {
let mut file = tokio::fs::OpenOptions::new()
.append(true)
.open(path)
.await?;
file.write_all(&data).await?;
Ok(ResponsePayload::Ok)
}
RequestPayload::DirRead { path, all } => {
// Traverse, but don't include root directory in entries (hence min depth 1)
let dir = WalkDir::new(path.as_path()).min_depth(1);
// If all, will recursively traverse, otherwise just return directly from dir
let dir = if all { dir } else { dir.max_depth(1) };
// TODO: Support both returning errors and successfully-traversed entries
// TODO: Support returning full paths instead of always relative?
Ok(ResponsePayload::DirEntries {
entries: dir
.into_iter()
.map(|e| {
e.map(|e| DirEntry {
path: e.path().strip_prefix(path.as_path()).unwrap().to_path_buf(),
file_type: if e.file_type().is_dir() {
FileType::Dir
} else if e.file_type().is_file() {
FileType::File
} else {
FileType::SymLink
},
depth: e.depth(),
})
})
.collect::<Result<Vec<DirEntry>, walkdir::Error>>()?,
})
}
RequestPayload::DirCreate { path, all } => {
if all {
tokio::fs::create_dir_all(path).await?;
} else {
tokio::fs::create_dir(path).await?;
}
Ok(ResponsePayload::Ok)
}
RequestPayload::Remove { path, force } => {
let path_metadata = tokio::fs::metadata(path.as_path()).await?;
if path_metadata.is_dir() {
if force {
tokio::fs::remove_dir_all(path).await?;
} else {
tokio::fs::remove_dir(path).await?;
}
} else {
tokio::fs::remove_file(path).await?;
}
Ok(ResponsePayload::Ok)
}
RequestPayload::Copy { src, dst } => {
let src_metadata = tokio::fs::metadata(src.as_path()).await?;
if src_metadata.is_dir() {
for entry in WalkDir::new(src.as_path())
.min_depth(1)
.follow_links(false)
.into_iter()
.filter_entry(|e| e.file_type().is_file() || e.path_is_symlink())
{
let entry = entry?;
// Get unique portion of path relative to src
// NOTE: Because we are traversing files that are all within src, this
// should always succeed
let local_src = entry.path().strip_prefix(src.as_path()).unwrap();
// Get the file without any directories
let local_src_file_name = local_src.file_name().unwrap();
// Get the directory housing the file
// NOTE: Because we enforce files/symlinks, there will always be a parent
let local_src_dir = local_src.parent().unwrap();
// Map out the path to the destination
let dst_parent_dir = dst.join(local_src_dir);
// Create the destination directory for the file when copying
tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?;
// Perform copying from entry to destination
let dst_file = dst_parent_dir.join(local_src_file_name);
tokio::fs::copy(entry.path(), dst_file).await?;
}
} else {
tokio::fs::copy(src, dst).await?;
}
Ok(ResponsePayload::Ok)
}
RequestPayload::Rename { src, dst } => {
tokio::fs::rename(src, dst).await?;
Ok(ResponsePayload::Ok)
}
RequestPayload::ProcRun { cmd, args, detach } => todo!(),
RequestPayload::ProcConnect { id } => todo!(),
RequestPayload::ProcKill { id } => todo!(),
RequestPayload::ProcStdin { id, data } => todo!(),
RequestPayload::ProcList {} => todo!(),
}
}

@ -0,0 +1,337 @@
use super::{Process, State};
use crate::data::{
DirEntry, FileType, Request, RequestPayload, Response, ResponsePayload, RunningProcess,
};
use log::*;
use std::{error::Error, path::PathBuf, process::Stdio, sync::Arc};
use tokio::{
io::{self, AsyncReadExt, AsyncWriteExt},
process::Command,
sync::{mpsc, oneshot, Mutex},
};
use walkdir::WalkDir;
pub type Reply = mpsc::Sender<Response>;
type HState = Arc<Mutex<State>>;
/// Processes the provided request, sending replies using the given sender
pub(super) async fn process(
client_id: usize,
state: HState,
req: Request,
tx: Reply,
) -> Result<(), mpsc::error::SendError<Response>> {
async fn inner(
client_id: usize,
state: HState,
payload: RequestPayload,
tx: Reply,
) -> Result<ResponsePayload, Box<dyn std::error::Error>> {
match payload {
RequestPayload::FileRead { path } => file_read(path).await,
RequestPayload::FileReadText { path } => file_read_text(path).await,
RequestPayload::FileWrite { path, data, .. } => file_write(path, data).await,
RequestPayload::FileAppend { path, data, .. } => file_append(path, data).await,
RequestPayload::DirRead { path, all } => dir_read(path, all).await,
RequestPayload::DirCreate { path, all } => dir_create(path, all).await,
RequestPayload::Remove { path, force } => remove(path, force).await,
RequestPayload::Copy { src, dst } => copy(src, dst).await,
RequestPayload::Rename { src, dst } => rename(src, dst).await,
RequestPayload::ProcRun { cmd, args, detach } => {
proc_run(client_id, state, tx, cmd, args, detach).await
}
RequestPayload::ProcConnect { id } => proc_connect(id).await,
RequestPayload::ProcKill { id } => proc_kill(state, id).await,
RequestPayload::ProcStdin { id, data } => proc_stdin(state, id, data).await,
RequestPayload::ProcList {} => proc_list(state).await,
}
}
let res = Response::from_payload_with_origin(
match inner(client_id, state, req.payload, tx.clone()).await {
Ok(payload) => payload,
Err(x) => ResponsePayload::Error {
description: x.to_string(),
},
},
req.id,
);
// Send out our primary response from processing the request
tx.send(res).await
}
async fn file_read(path: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
Ok(ResponsePayload::Blob {
data: tokio::fs::read(path).await?,
})
}
async fn file_read_text(path: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
Ok(ResponsePayload::Text {
data: tokio::fs::read_to_string(path).await?,
})
}
async fn file_write(path: PathBuf, data: Vec<u8>) -> Result<ResponsePayload, Box<dyn Error>> {
tokio::fs::write(path, data).await?;
Ok(ResponsePayload::Ok)
}
async fn file_append(path: PathBuf, data: Vec<u8>) -> Result<ResponsePayload, Box<dyn Error>> {
let mut file = tokio::fs::OpenOptions::new()
.append(true)
.open(path)
.await?;
file.write_all(&data).await?;
Ok(ResponsePayload::Ok)
}
async fn dir_read(path: PathBuf, all: bool) -> Result<ResponsePayload, Box<dyn Error>> {
// Traverse, but don't include root directory in entries (hence min depth 1)
let dir = WalkDir::new(path.as_path()).min_depth(1);
// If all, will recursively traverse, otherwise just return directly from dir
let dir = if all { dir } else { dir.max_depth(1) };
// TODO: Support both returning errors and successfully-traversed entries
// TODO: Support returning full paths instead of always relative?
Ok(ResponsePayload::DirEntries {
entries: dir
.into_iter()
.map(|e| {
e.map(|e| DirEntry {
path: e.path().strip_prefix(path.as_path()).unwrap().to_path_buf(),
file_type: if e.file_type().is_dir() {
FileType::Dir
} else if e.file_type().is_file() {
FileType::File
} else {
FileType::SymLink
},
depth: e.depth(),
})
})
.collect::<Result<Vec<DirEntry>, walkdir::Error>>()?,
})
}
async fn dir_create(path: PathBuf, all: bool) -> Result<ResponsePayload, Box<dyn Error>> {
if all {
tokio::fs::create_dir_all(path).await?;
} else {
tokio::fs::create_dir(path).await?;
}
Ok(ResponsePayload::Ok)
}
async fn remove(path: PathBuf, force: bool) -> Result<ResponsePayload, Box<dyn Error>> {
let path_metadata = tokio::fs::metadata(path.as_path()).await?;
if path_metadata.is_dir() {
if force {
tokio::fs::remove_dir_all(path).await?;
} else {
tokio::fs::remove_dir(path).await?;
}
} else {
tokio::fs::remove_file(path).await?;
}
Ok(ResponsePayload::Ok)
}
async fn copy(src: PathBuf, dst: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
let src_metadata = tokio::fs::metadata(src.as_path()).await?;
if src_metadata.is_dir() {
for entry in WalkDir::new(src.as_path())
.min_depth(1)
.follow_links(false)
.into_iter()
.filter_entry(|e| e.file_type().is_file() || e.path_is_symlink())
{
let entry = entry?;
// Get unique portion of path relative to src
// NOTE: Because we are traversing files that are all within src, this
// should always succeed
let local_src = entry.path().strip_prefix(src.as_path()).unwrap();
// Get the file without any directories
let local_src_file_name = local_src.file_name().unwrap();
// Get the directory housing the file
// NOTE: Because we enforce files/symlinks, there will always be a parent
let local_src_dir = local_src.parent().unwrap();
// Map out the path to the destination
let dst_parent_dir = dst.join(local_src_dir);
// Create the destination directory for the file when copying
tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?;
// Perform copying from entry to destination
let dst_file = dst_parent_dir.join(local_src_file_name);
tokio::fs::copy(entry.path(), dst_file).await?;
}
} else {
tokio::fs::copy(src, dst).await?;
}
Ok(ResponsePayload::Ok)
}
async fn rename(src: PathBuf, dst: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
tokio::fs::rename(src, dst).await?;
Ok(ResponsePayload::Ok)
}
async fn proc_run(
client_id: usize,
state: HState,
tx: Reply,
cmd: String,
args: Vec<String>,
detach: bool,
) -> Result<ResponsePayload, Box<dyn Error>> {
let id = rand::random();
let mut child = Command::new(cmd.to_string())
.args(args.clone())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
// Spawn a task that sends stdout as a response
let tx_2 = tx.clone();
let mut stdout = child.stdout.take().unwrap();
tokio::spawn(async move {
loop {
let mut data = Vec::new();
match stdout.read_to_end(&mut data).await {
Ok(_) => {
if let Err(_) = tx_2
.send(Response::from(ResponsePayload::ProcStdout { id, data }))
.await
{
break;
}
}
Err(_) => break,
}
}
});
// Spawn a task that sends stderr as a response
let mut stderr = child.stderr.take().unwrap();
tokio::spawn(async move {
loop {
let mut data = Vec::new();
match stderr.read_to_end(&mut data).await {
Ok(_) => {
if let Err(_) = tx
.send(Response::from(ResponsePayload::ProcStderr { id, data }))
.await
{
break;
}
}
Err(_) => break,
}
}
});
// Spawn a task that sends stdin to the process
// TODO: Should this be configurable?
let mut stdin = child.stdin.take().unwrap();
let (stdin_tx, mut stdin_rx) = mpsc::channel::<Vec<u8>>(1);
tokio::spawn(async move {
while let Some(data) = stdin_rx.recv().await {
if let Err(x) = stdin.write_all(&data).await {
error!("Failed to send stdin to process {}: {}", id, x);
break;
}
}
});
// Spawn a task that kills the process when triggered
let (kill_tx, kill_rx) = oneshot::channel();
tokio::spawn(async move {
let _ = kill_rx.await;
if let Err(x) = child.kill().await {
error!("Unable to kill process {}: {}", id, x);
}
});
// Update our state with the new process
let process = Process {
cmd,
args,
id,
stdin_tx,
kill_tx,
};
state.lock().await.processes.insert(id, process);
// If we are not detaching from process, we want to associate it with our client
if !detach {
state
.lock()
.await
.client_processes
.entry(client_id)
.or_insert(Vec::new())
.push(id);
}
Ok(ResponsePayload::ProcStart { id })
}
async fn proc_connect(id: usize) -> Result<ResponsePayload, Box<dyn Error>> {
todo!();
}
async fn proc_kill(state: HState, id: usize) -> Result<ResponsePayload, Box<dyn Error>> {
if let Some(process) = state.lock().await.processes.remove(&id) {
process.kill_tx.send(()).map_err(|_| {
io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to send kill signal to process",
)
})?;
}
Ok(ResponsePayload::Ok)
}
async fn proc_stdin(
state: HState,
id: usize,
data: Vec<u8>,
) -> Result<ResponsePayload, Box<dyn Error>> {
if let Some(process) = state.lock().await.processes.get(&id) {
process.stdin_tx.send(data).await.map_err(|_| {
io::Error::new(io::ErrorKind::BrokenPipe, "Unable to send stdin to process")
})?;
}
Ok(ResponsePayload::Ok)
}
async fn proc_list(state: HState) -> Result<ResponsePayload, Box<dyn Error>> {
Ok(ResponsePayload::ProcEntries {
entries: state
.lock()
.await
.processes
.values()
.map(|p| RunningProcess {
cmd: p.cmd.to_string(),
args: p.args.clone(),
id: p.id,
})
.collect(),
})
}

@ -0,0 +1,213 @@
use crate::{
data::{Request, Response},
net::{Transport, TransportReadHalf, TransportWriteHalf},
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
};
use derive_more::{Display, Error, From};
use fork::{daemon, Fork};
use log::*;
use orion::aead::SecretKey;
use std::{collections::HashMap, sync::Arc};
use tokio::{
io,
net::TcpListener,
sync::{mpsc, oneshot, Mutex},
};
mod handler;
#[derive(Debug, Display, Error, From)]
pub enum Error {
ConvertToIpAddrError(ConvertToIpAddrError),
ForkError,
IoError(io::Error),
}
/// Holds state relevant to the server
#[derive(Default)]
struct State {
/// Map of all processes running on the server
processes: HashMap<usize, Process>,
/// List of processes that will be killed when a client drops
client_processes: HashMap<usize, Vec<usize>>,
}
impl State {
/// Cleans up state associated with a particular client
pub async fn cleanup_client(&mut self, id: usize) {
if let Some(ids) = self.client_processes.remove(&id) {
for id in ids {
if let Some(process) = self.processes.remove(&id) {
if let Err(_) = process.kill_tx.send(()) {
error!(
"Client {} failed to send process {} kill signal",
id, process.id
);
}
}
}
}
}
}
/// Represents an actively-running process maintained by the server
struct Process {
pub id: usize,
pub cmd: String,
pub args: Vec<String>,
pub stdin_tx: mpsc::Sender<Vec<u8>>,
pub kill_tx: oneshot::Sender<()>,
}
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
if cmd.daemon {
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
match daemon(false, true) {
Ok(Fork::Child) => {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, true).await })?;
}
Ok(Fork::Parent(pid)) => {
info!("[distant detached, pid = {}]", pid);
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
Err(_) => return Err(Error::ForkError),
}
} else {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, false).await })?;
}
Ok(())
}
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
let socket_addrs = cmd.port.make_socket_addrs(addr);
debug!("Binding to {} in range {}", addr, cmd.port);
let listener = TcpListener::bind(socket_addrs.as_slice()).await?;
let port = listener.local_addr()?.port();
debug!("Bound to port: {}", port);
let key = Arc::new(SecretKey::default());
// Print information about port, key, etc. unless told not to
if !cmd.no_print_startup_data {
publish_data(port, &key);
}
// For the child, we want to fully disconnect it from pipes, which we do now
if is_forked {
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
// Build our state for the server
let state = Arc::new(Mutex::new(State::default()));
// Wait for a client connection, then spawn a new task to handle
// receiving data from the client
while let Ok((client, _)) = listener.accept().await {
// Grab the client's remote address for later logging purposes
let addr_string = match client.peer_addr() {
Ok(addr) => {
let addr_string = addr.to_string();
info!("<Client @ {}> Established connection", addr_string);
addr_string
}
Err(x) => {
error!("Unable to examine client's peer address: {}", x);
"???".to_string()
}
};
// Create a unique id for the client
let id = rand::random();
// Build a transport around the client, splitting into read and write halves so we can
// handle input and output concurrently
let (t_read, t_write) = Transport::new(client, Arc::clone(&key)).split();
let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize);
// Spawn a new task that loops to handle requests from the client
tokio::spawn({
let f = request_loop(id, addr_string.to_string(), Arc::clone(&state), t_read, tx);
let state = Arc::clone(&state);
async move {
f.await;
state.lock().await.cleanup_client(id).await;
}
});
// Spawn a new task that loops to handle responses to the client
tokio::spawn(async move { response_loop(addr_string, t_write, rx).await });
}
Ok(())
}
/// Repeatedly reads in new requests, processes them, and sends their responses to the
/// response loop
async fn request_loop(
id: usize,
addr: String,
state: Arc<Mutex<State>>,
mut transport: TransportReadHalf,
tx: mpsc::Sender<Response>,
) {
loop {
match transport.receive::<Request>().await {
Ok(Some(req)) => {
trace!(
"<Client @ {}> Received request of type {}",
addr.as_str(),
req.payload.as_ref()
);
if let Err(x) = handler::process(id, Arc::clone(&state), req, tx.clone()).await {
error!("<Client @ {}> {}", addr.as_str(), x);
break;
}
}
Ok(None) => {
info!("<Client @ {}> Closed connection", addr.as_str());
break;
}
Err(x) => {
error!("<Client @ {}> {}", addr.as_str(), x);
break;
}
}
}
}
async fn response_loop(
addr: String,
mut transport: TransportWriteHalf,
mut rx: mpsc::Receiver<Response>,
) {
while let Some(res) = rx.recv().await {
if let Err(x) = transport.send(res).await {
error!("<Client @ {}> {}", addr.as_str(), x);
break;
}
}
}
fn publish_data(port: u16, key: &SecretKey) {
// TODO: We have to share the key in some manner (maybe use k256 to arrive at the same key?)
// For now, we do what mosh does and print out the key knowing that this is shared over
// ssh, which should provide security
println!(
"DISTANT DATA {} {}",
port,
hex::encode(key.unprotected_as_bytes())
);
}
Loading…
Cancel
Save