mirror of https://github.com/chipsenkbeil/distant
Refactored listener code into a handler module, wrote support to split transport into read and write halves, implemented most of backend although process run is not working yet
parent
a707523fb5
commit
24a8cf8401
@ -1,290 +0,0 @@
|
||||
use crate::{
|
||||
data::{DirEntry, FileType, Request, RequestPayload, Response, ResponsePayload},
|
||||
net::Transport,
|
||||
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use fork::{daemon, Fork};
|
||||
use log::*;
|
||||
use orion::aead::SecretKey;
|
||||
use std::{string::FromUtf8Error, sync::Arc};
|
||||
use tokio::{
|
||||
io::{self, AsyncWriteExt},
|
||||
net::TcpListener,
|
||||
};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum Error {
|
||||
ConvertToIpAddrError(ConvertToIpAddrError),
|
||||
ForkError,
|
||||
IoError(io::Error),
|
||||
Utf8Error(FromUtf8Error),
|
||||
}
|
||||
|
||||
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
if cmd.daemon {
|
||||
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
|
||||
match daemon(false, true) {
|
||||
Ok(Fork::Child) => {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, true).await })?;
|
||||
}
|
||||
Ok(Fork::Parent(pid)) => {
|
||||
info!("[distant detached, pid = {}]", pid);
|
||||
if let Err(_) = fork::close_fd() {
|
||||
return Err(Error::ForkError);
|
||||
}
|
||||
}
|
||||
Err(_) => return Err(Error::ForkError),
|
||||
}
|
||||
} else {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, false).await })?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
|
||||
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
|
||||
let socket_addrs = cmd.port.make_socket_addrs(addr);
|
||||
|
||||
debug!("Binding to {} in range {}", addr, cmd.port);
|
||||
let listener = TcpListener::bind(socket_addrs.as_slice()).await?;
|
||||
|
||||
let port = listener.local_addr()?.port();
|
||||
debug!("Bound to port: {}", port);
|
||||
|
||||
let key = Arc::new(SecretKey::default());
|
||||
|
||||
// Print information about port, key, etc. unless told not to
|
||||
if !cmd.no_print_startup_data {
|
||||
publish_data(port, &key);
|
||||
}
|
||||
|
||||
// For the child, we want to fully disconnect it from pipes, which we do now
|
||||
if is_forked {
|
||||
if let Err(_) = fork::close_fd() {
|
||||
return Err(Error::ForkError);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for a client connection, then spawn a new task to handle
|
||||
// receiving data from the client
|
||||
while let Ok((client, _)) = listener.accept().await {
|
||||
// Grab the client's remote address for later logging purposes
|
||||
let addr_string = match client.peer_addr() {
|
||||
Ok(addr) => {
|
||||
let addr_string = addr.to_string();
|
||||
info!("<Client @ {}> Established connection", addr_string);
|
||||
addr_string
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Unable to examine client's peer address: {}", x);
|
||||
"???".to_string()
|
||||
}
|
||||
};
|
||||
|
||||
// Build a transport around the client
|
||||
let mut transport = Transport::new(client, Arc::clone(&key));
|
||||
|
||||
// Spawn a new task that loops to handle requests from the client
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match transport.receive::<Request>().await {
|
||||
Ok(Some(request)) => {
|
||||
trace!(
|
||||
"<Client @ {}> Received request of type {}",
|
||||
addr_string.as_str(),
|
||||
request.payload.as_ref()
|
||||
);
|
||||
|
||||
// Process the request, converting any error into an error response
|
||||
let response = Response::from_payload_with_origin(
|
||||
match process_request_payload(request.payload).await {
|
||||
Ok(payload) => payload,
|
||||
Err(x) => ResponsePayload::Error {
|
||||
description: x.to_string(),
|
||||
},
|
||||
},
|
||||
request.id,
|
||||
);
|
||||
|
||||
if let Err(x) = transport.send(response).await {
|
||||
error!("<Client @ {}> {}", addr_string.as_str(), x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
info!("<Client @ {}> Closed connection", addr_string.as_str());
|
||||
break;
|
||||
}
|
||||
Err(x) => {
|
||||
error!("<Client @ {}> {}", addr_string.as_str(), x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn publish_data(port: u16, key: &SecretKey) {
|
||||
// TODO: We have to share the key in some manner (maybe use k256 to arrive at the same key?)
|
||||
// For now, we do what mosh does and print out the key knowing that this is shared over
|
||||
// ssh, which should provide security
|
||||
println!(
|
||||
"DISTANT DATA {} {}",
|
||||
port,
|
||||
hex::encode(key.unprotected_as_bytes())
|
||||
);
|
||||
}
|
||||
|
||||
async fn process_request_payload(
|
||||
payload: RequestPayload,
|
||||
) -> Result<ResponsePayload, Box<dyn std::error::Error>> {
|
||||
match payload {
|
||||
RequestPayload::FileRead { path } => Ok(ResponsePayload::Blob {
|
||||
data: tokio::fs::read(path).await?,
|
||||
}),
|
||||
|
||||
RequestPayload::FileReadText { path } => Ok(ResponsePayload::Text {
|
||||
data: tokio::fs::read_to_string(path).await?,
|
||||
}),
|
||||
|
||||
RequestPayload::FileWrite {
|
||||
path,
|
||||
input: _,
|
||||
data,
|
||||
} => {
|
||||
tokio::fs::write(path, data).await?;
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
RequestPayload::FileAppend {
|
||||
path,
|
||||
input: _,
|
||||
data,
|
||||
} => {
|
||||
let mut file = tokio::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(path)
|
||||
.await?;
|
||||
file.write_all(&data).await?;
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
RequestPayload::DirRead { path, all } => {
|
||||
// Traverse, but don't include root directory in entries (hence min depth 1)
|
||||
let dir = WalkDir::new(path.as_path()).min_depth(1);
|
||||
|
||||
// If all, will recursively traverse, otherwise just return directly from dir
|
||||
let dir = if all { dir } else { dir.max_depth(1) };
|
||||
|
||||
// TODO: Support both returning errors and successfully-traversed entries
|
||||
// TODO: Support returning full paths instead of always relative?
|
||||
Ok(ResponsePayload::DirEntries {
|
||||
entries: dir
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
e.map(|e| DirEntry {
|
||||
path: e.path().strip_prefix(path.as_path()).unwrap().to_path_buf(),
|
||||
file_type: if e.file_type().is_dir() {
|
||||
FileType::Dir
|
||||
} else if e.file_type().is_file() {
|
||||
FileType::File
|
||||
} else {
|
||||
FileType::SymLink
|
||||
},
|
||||
depth: e.depth(),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<DirEntry>, walkdir::Error>>()?,
|
||||
})
|
||||
}
|
||||
|
||||
RequestPayload::DirCreate { path, all } => {
|
||||
if all {
|
||||
tokio::fs::create_dir_all(path).await?;
|
||||
} else {
|
||||
tokio::fs::create_dir(path).await?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
RequestPayload::Remove { path, force } => {
|
||||
let path_metadata = tokio::fs::metadata(path.as_path()).await?;
|
||||
if path_metadata.is_dir() {
|
||||
if force {
|
||||
tokio::fs::remove_dir_all(path).await?;
|
||||
} else {
|
||||
tokio::fs::remove_dir(path).await?;
|
||||
}
|
||||
} else {
|
||||
tokio::fs::remove_file(path).await?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
RequestPayload::Copy { src, dst } => {
|
||||
let src_metadata = tokio::fs::metadata(src.as_path()).await?;
|
||||
if src_metadata.is_dir() {
|
||||
for entry in WalkDir::new(src.as_path())
|
||||
.min_depth(1)
|
||||
.follow_links(false)
|
||||
.into_iter()
|
||||
.filter_entry(|e| e.file_type().is_file() || e.path_is_symlink())
|
||||
{
|
||||
let entry = entry?;
|
||||
|
||||
// Get unique portion of path relative to src
|
||||
// NOTE: Because we are traversing files that are all within src, this
|
||||
// should always succeed
|
||||
let local_src = entry.path().strip_prefix(src.as_path()).unwrap();
|
||||
|
||||
// Get the file without any directories
|
||||
let local_src_file_name = local_src.file_name().unwrap();
|
||||
|
||||
// Get the directory housing the file
|
||||
// NOTE: Because we enforce files/symlinks, there will always be a parent
|
||||
let local_src_dir = local_src.parent().unwrap();
|
||||
|
||||
// Map out the path to the destination
|
||||
let dst_parent_dir = dst.join(local_src_dir);
|
||||
|
||||
// Create the destination directory for the file when copying
|
||||
tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?;
|
||||
|
||||
// Perform copying from entry to destination
|
||||
let dst_file = dst_parent_dir.join(local_src_file_name);
|
||||
tokio::fs::copy(entry.path(), dst_file).await?;
|
||||
}
|
||||
} else {
|
||||
tokio::fs::copy(src, dst).await?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
RequestPayload::Rename { src, dst } => {
|
||||
tokio::fs::rename(src, dst).await?;
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
RequestPayload::ProcRun { cmd, args, detach } => todo!(),
|
||||
|
||||
RequestPayload::ProcConnect { id } => todo!(),
|
||||
|
||||
RequestPayload::ProcKill { id } => todo!(),
|
||||
|
||||
RequestPayload::ProcStdin { id, data } => todo!(),
|
||||
|
||||
RequestPayload::ProcList {} => todo!(),
|
||||
}
|
||||
}
|
@ -0,0 +1,337 @@
|
||||
use super::{Process, State};
|
||||
use crate::data::{
|
||||
DirEntry, FileType, Request, RequestPayload, Response, ResponsePayload, RunningProcess,
|
||||
};
|
||||
use log::*;
|
||||
use std::{error::Error, path::PathBuf, process::Stdio, sync::Arc};
|
||||
use tokio::{
|
||||
io::{self, AsyncReadExt, AsyncWriteExt},
|
||||
process::Command,
|
||||
sync::{mpsc, oneshot, Mutex},
|
||||
};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
pub type Reply = mpsc::Sender<Response>;
|
||||
type HState = Arc<Mutex<State>>;
|
||||
|
||||
/// Processes the provided request, sending replies using the given sender
|
||||
pub(super) async fn process(
|
||||
client_id: usize,
|
||||
state: HState,
|
||||
req: Request,
|
||||
tx: Reply,
|
||||
) -> Result<(), mpsc::error::SendError<Response>> {
|
||||
async fn inner(
|
||||
client_id: usize,
|
||||
state: HState,
|
||||
payload: RequestPayload,
|
||||
tx: Reply,
|
||||
) -> Result<ResponsePayload, Box<dyn std::error::Error>> {
|
||||
match payload {
|
||||
RequestPayload::FileRead { path } => file_read(path).await,
|
||||
RequestPayload::FileReadText { path } => file_read_text(path).await,
|
||||
RequestPayload::FileWrite { path, data, .. } => file_write(path, data).await,
|
||||
RequestPayload::FileAppend { path, data, .. } => file_append(path, data).await,
|
||||
RequestPayload::DirRead { path, all } => dir_read(path, all).await,
|
||||
RequestPayload::DirCreate { path, all } => dir_create(path, all).await,
|
||||
RequestPayload::Remove { path, force } => remove(path, force).await,
|
||||
RequestPayload::Copy { src, dst } => copy(src, dst).await,
|
||||
RequestPayload::Rename { src, dst } => rename(src, dst).await,
|
||||
RequestPayload::ProcRun { cmd, args, detach } => {
|
||||
proc_run(client_id, state, tx, cmd, args, detach).await
|
||||
}
|
||||
RequestPayload::ProcConnect { id } => proc_connect(id).await,
|
||||
RequestPayload::ProcKill { id } => proc_kill(state, id).await,
|
||||
RequestPayload::ProcStdin { id, data } => proc_stdin(state, id, data).await,
|
||||
RequestPayload::ProcList {} => proc_list(state).await,
|
||||
}
|
||||
}
|
||||
|
||||
let res = Response::from_payload_with_origin(
|
||||
match inner(client_id, state, req.payload, tx.clone()).await {
|
||||
Ok(payload) => payload,
|
||||
Err(x) => ResponsePayload::Error {
|
||||
description: x.to_string(),
|
||||
},
|
||||
},
|
||||
req.id,
|
||||
);
|
||||
|
||||
// Send out our primary response from processing the request
|
||||
tx.send(res).await
|
||||
}
|
||||
|
||||
async fn file_read(path: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
Ok(ResponsePayload::Blob {
|
||||
data: tokio::fs::read(path).await?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn file_read_text(path: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
Ok(ResponsePayload::Text {
|
||||
data: tokio::fs::read_to_string(path).await?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn file_write(path: PathBuf, data: Vec<u8>) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
tokio::fs::write(path, data).await?;
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn file_append(path: PathBuf, data: Vec<u8>) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
let mut file = tokio::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(path)
|
||||
.await?;
|
||||
file.write_all(&data).await?;
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn dir_read(path: PathBuf, all: bool) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
// Traverse, but don't include root directory in entries (hence min depth 1)
|
||||
let dir = WalkDir::new(path.as_path()).min_depth(1);
|
||||
|
||||
// If all, will recursively traverse, otherwise just return directly from dir
|
||||
let dir = if all { dir } else { dir.max_depth(1) };
|
||||
|
||||
// TODO: Support both returning errors and successfully-traversed entries
|
||||
// TODO: Support returning full paths instead of always relative?
|
||||
Ok(ResponsePayload::DirEntries {
|
||||
entries: dir
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
e.map(|e| DirEntry {
|
||||
path: e.path().strip_prefix(path.as_path()).unwrap().to_path_buf(),
|
||||
file_type: if e.file_type().is_dir() {
|
||||
FileType::Dir
|
||||
} else if e.file_type().is_file() {
|
||||
FileType::File
|
||||
} else {
|
||||
FileType::SymLink
|
||||
},
|
||||
depth: e.depth(),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<DirEntry>, walkdir::Error>>()?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn dir_create(path: PathBuf, all: bool) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
if all {
|
||||
tokio::fs::create_dir_all(path).await?;
|
||||
} else {
|
||||
tokio::fs::create_dir(path).await?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn remove(path: PathBuf, force: bool) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
let path_metadata = tokio::fs::metadata(path.as_path()).await?;
|
||||
if path_metadata.is_dir() {
|
||||
if force {
|
||||
tokio::fs::remove_dir_all(path).await?;
|
||||
} else {
|
||||
tokio::fs::remove_dir(path).await?;
|
||||
}
|
||||
} else {
|
||||
tokio::fs::remove_file(path).await?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn copy(src: PathBuf, dst: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
let src_metadata = tokio::fs::metadata(src.as_path()).await?;
|
||||
if src_metadata.is_dir() {
|
||||
for entry in WalkDir::new(src.as_path())
|
||||
.min_depth(1)
|
||||
.follow_links(false)
|
||||
.into_iter()
|
||||
.filter_entry(|e| e.file_type().is_file() || e.path_is_symlink())
|
||||
{
|
||||
let entry = entry?;
|
||||
|
||||
// Get unique portion of path relative to src
|
||||
// NOTE: Because we are traversing files that are all within src, this
|
||||
// should always succeed
|
||||
let local_src = entry.path().strip_prefix(src.as_path()).unwrap();
|
||||
|
||||
// Get the file without any directories
|
||||
let local_src_file_name = local_src.file_name().unwrap();
|
||||
|
||||
// Get the directory housing the file
|
||||
// NOTE: Because we enforce files/symlinks, there will always be a parent
|
||||
let local_src_dir = local_src.parent().unwrap();
|
||||
|
||||
// Map out the path to the destination
|
||||
let dst_parent_dir = dst.join(local_src_dir);
|
||||
|
||||
// Create the destination directory for the file when copying
|
||||
tokio::fs::create_dir_all(dst_parent_dir.as_path()).await?;
|
||||
|
||||
// Perform copying from entry to destination
|
||||
let dst_file = dst_parent_dir.join(local_src_file_name);
|
||||
tokio::fs::copy(entry.path(), dst_file).await?;
|
||||
}
|
||||
} else {
|
||||
tokio::fs::copy(src, dst).await?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn rename(src: PathBuf, dst: PathBuf) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
tokio::fs::rename(src, dst).await?;
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn proc_run(
|
||||
client_id: usize,
|
||||
state: HState,
|
||||
tx: Reply,
|
||||
cmd: String,
|
||||
args: Vec<String>,
|
||||
detach: bool,
|
||||
) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
let id = rand::random();
|
||||
|
||||
let mut child = Command::new(cmd.to_string())
|
||||
.args(args.clone())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
// Spawn a task that sends stdout as a response
|
||||
let tx_2 = tx.clone();
|
||||
let mut stdout = child.stdout.take().unwrap();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let mut data = Vec::new();
|
||||
match stdout.read_to_end(&mut data).await {
|
||||
Ok(_) => {
|
||||
if let Err(_) = tx_2
|
||||
.send(Response::from(ResponsePayload::ProcStdout { id, data }))
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn a task that sends stderr as a response
|
||||
let mut stderr = child.stderr.take().unwrap();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let mut data = Vec::new();
|
||||
match stderr.read_to_end(&mut data).await {
|
||||
Ok(_) => {
|
||||
if let Err(_) = tx
|
||||
.send(Response::from(ResponsePayload::ProcStderr { id, data }))
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn a task that sends stdin to the process
|
||||
// TODO: Should this be configurable?
|
||||
let mut stdin = child.stdin.take().unwrap();
|
||||
let (stdin_tx, mut stdin_rx) = mpsc::channel::<Vec<u8>>(1);
|
||||
tokio::spawn(async move {
|
||||
while let Some(data) = stdin_rx.recv().await {
|
||||
if let Err(x) = stdin.write_all(&data).await {
|
||||
error!("Failed to send stdin to process {}: {}", id, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn a task that kills the process when triggered
|
||||
let (kill_tx, kill_rx) = oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
let _ = kill_rx.await;
|
||||
if let Err(x) = child.kill().await {
|
||||
error!("Unable to kill process {}: {}", id, x);
|
||||
}
|
||||
});
|
||||
|
||||
// Update our state with the new process
|
||||
let process = Process {
|
||||
cmd,
|
||||
args,
|
||||
id,
|
||||
stdin_tx,
|
||||
kill_tx,
|
||||
};
|
||||
state.lock().await.processes.insert(id, process);
|
||||
|
||||
// If we are not detaching from process, we want to associate it with our client
|
||||
if !detach {
|
||||
state
|
||||
.lock()
|
||||
.await
|
||||
.client_processes
|
||||
.entry(client_id)
|
||||
.or_insert(Vec::new())
|
||||
.push(id);
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::ProcStart { id })
|
||||
}
|
||||
|
||||
async fn proc_connect(id: usize) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
todo!();
|
||||
}
|
||||
|
||||
async fn proc_kill(state: HState, id: usize) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
if let Some(process) = state.lock().await.processes.remove(&id) {
|
||||
process.kill_tx.send(()).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"Unable to send kill signal to process",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn proc_stdin(
|
||||
state: HState,
|
||||
id: usize,
|
||||
data: Vec<u8>,
|
||||
) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
if let Some(process) = state.lock().await.processes.get(&id) {
|
||||
process.stdin_tx.send(data).await.map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::BrokenPipe, "Unable to send stdin to process")
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(ResponsePayload::Ok)
|
||||
}
|
||||
|
||||
async fn proc_list(state: HState) -> Result<ResponsePayload, Box<dyn Error>> {
|
||||
Ok(ResponsePayload::ProcEntries {
|
||||
entries: state
|
||||
.lock()
|
||||
.await
|
||||
.processes
|
||||
.values()
|
||||
.map(|p| RunningProcess {
|
||||
cmd: p.cmd.to_string(),
|
||||
args: p.args.clone(),
|
||||
id: p.id,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
@ -0,0 +1,213 @@
|
||||
use crate::{
|
||||
data::{Request, Response},
|
||||
net::{Transport, TransportReadHalf, TransportWriteHalf},
|
||||
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use fork::{daemon, Fork};
|
||||
use log::*;
|
||||
use orion::aead::SecretKey;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::{
|
||||
io,
|
||||
net::TcpListener,
|
||||
sync::{mpsc, oneshot, Mutex},
|
||||
};
|
||||
|
||||
mod handler;
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum Error {
|
||||
ConvertToIpAddrError(ConvertToIpAddrError),
|
||||
ForkError,
|
||||
IoError(io::Error),
|
||||
}
|
||||
|
||||
/// Holds state relevant to the server
|
||||
#[derive(Default)]
|
||||
struct State {
|
||||
/// Map of all processes running on the server
|
||||
processes: HashMap<usize, Process>,
|
||||
|
||||
/// List of processes that will be killed when a client drops
|
||||
client_processes: HashMap<usize, Vec<usize>>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
/// Cleans up state associated with a particular client
|
||||
pub async fn cleanup_client(&mut self, id: usize) {
|
||||
if let Some(ids) = self.client_processes.remove(&id) {
|
||||
for id in ids {
|
||||
if let Some(process) = self.processes.remove(&id) {
|
||||
if let Err(_) = process.kill_tx.send(()) {
|
||||
error!(
|
||||
"Client {} failed to send process {} kill signal",
|
||||
id, process.id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an actively-running process maintained by the server
|
||||
struct Process {
|
||||
pub id: usize,
|
||||
pub cmd: String,
|
||||
pub args: Vec<String>,
|
||||
pub stdin_tx: mpsc::Sender<Vec<u8>>,
|
||||
pub kill_tx: oneshot::Sender<()>,
|
||||
}
|
||||
|
||||
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
if cmd.daemon {
|
||||
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
|
||||
match daemon(false, true) {
|
||||
Ok(Fork::Child) => {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, true).await })?;
|
||||
}
|
||||
Ok(Fork::Parent(pid)) => {
|
||||
info!("[distant detached, pid = {}]", pid);
|
||||
if let Err(_) = fork::close_fd() {
|
||||
return Err(Error::ForkError);
|
||||
}
|
||||
}
|
||||
Err(_) => return Err(Error::ForkError),
|
||||
}
|
||||
} else {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, false).await })?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
|
||||
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
|
||||
let socket_addrs = cmd.port.make_socket_addrs(addr);
|
||||
|
||||
debug!("Binding to {} in range {}", addr, cmd.port);
|
||||
let listener = TcpListener::bind(socket_addrs.as_slice()).await?;
|
||||
|
||||
let port = listener.local_addr()?.port();
|
||||
debug!("Bound to port: {}", port);
|
||||
|
||||
let key = Arc::new(SecretKey::default());
|
||||
|
||||
// Print information about port, key, etc. unless told not to
|
||||
if !cmd.no_print_startup_data {
|
||||
publish_data(port, &key);
|
||||
}
|
||||
|
||||
// For the child, we want to fully disconnect it from pipes, which we do now
|
||||
if is_forked {
|
||||
if let Err(_) = fork::close_fd() {
|
||||
return Err(Error::ForkError);
|
||||
}
|
||||
}
|
||||
|
||||
// Build our state for the server
|
||||
let state = Arc::new(Mutex::new(State::default()));
|
||||
|
||||
// Wait for a client connection, then spawn a new task to handle
|
||||
// receiving data from the client
|
||||
while let Ok((client, _)) = listener.accept().await {
|
||||
// Grab the client's remote address for later logging purposes
|
||||
let addr_string = match client.peer_addr() {
|
||||
Ok(addr) => {
|
||||
let addr_string = addr.to_string();
|
||||
info!("<Client @ {}> Established connection", addr_string);
|
||||
addr_string
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Unable to examine client's peer address: {}", x);
|
||||
"???".to_string()
|
||||
}
|
||||
};
|
||||
|
||||
// Create a unique id for the client
|
||||
let id = rand::random();
|
||||
|
||||
// Build a transport around the client, splitting into read and write halves so we can
|
||||
// handle input and output concurrently
|
||||
let (t_read, t_write) = Transport::new(client, Arc::clone(&key)).split();
|
||||
let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize);
|
||||
|
||||
// Spawn a new task that loops to handle requests from the client
|
||||
tokio::spawn({
|
||||
let f = request_loop(id, addr_string.to_string(), Arc::clone(&state), t_read, tx);
|
||||
|
||||
let state = Arc::clone(&state);
|
||||
async move {
|
||||
f.await;
|
||||
state.lock().await.cleanup_client(id).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn a new task that loops to handle responses to the client
|
||||
tokio::spawn(async move { response_loop(addr_string, t_write, rx).await });
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Repeatedly reads in new requests, processes them, and sends their responses to the
|
||||
/// response loop
|
||||
async fn request_loop(
|
||||
id: usize,
|
||||
addr: String,
|
||||
state: Arc<Mutex<State>>,
|
||||
mut transport: TransportReadHalf,
|
||||
tx: mpsc::Sender<Response>,
|
||||
) {
|
||||
loop {
|
||||
match transport.receive::<Request>().await {
|
||||
Ok(Some(req)) => {
|
||||
trace!(
|
||||
"<Client @ {}> Received request of type {}",
|
||||
addr.as_str(),
|
||||
req.payload.as_ref()
|
||||
);
|
||||
|
||||
if let Err(x) = handler::process(id, Arc::clone(&state), req, tx.clone()).await {
|
||||
error!("<Client @ {}> {}", addr.as_str(), x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
info!("<Client @ {}> Closed connection", addr.as_str());
|
||||
break;
|
||||
}
|
||||
Err(x) => {
|
||||
error!("<Client @ {}> {}", addr.as_str(), x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn response_loop(
|
||||
addr: String,
|
||||
mut transport: TransportWriteHalf,
|
||||
mut rx: mpsc::Receiver<Response>,
|
||||
) {
|
||||
while let Some(res) = rx.recv().await {
|
||||
if let Err(x) = transport.send(res).await {
|
||||
error!("<Client @ {}> {}", addr.as_str(), x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn publish_data(port: u16, key: &SecretKey) {
|
||||
// TODO: We have to share the key in some manner (maybe use k256 to arrive at the same key?)
|
||||
// For now, we do what mosh does and print out the key knowing that this is shared over
|
||||
// ssh, which should provide security
|
||||
println!(
|
||||
"DISTANT DATA {} {}",
|
||||
port,
|
||||
hex::encode(key.unprotected_as_bytes())
|
||||
);
|
||||
}
|
Loading…
Reference in New Issue