mirror of https://github.com/chipsenkbeil/distant
Add native ssh (#57)
* Bump to 0.15.0 * Add new distant-ssh2 subcrate to provide an alternate session as an ssh client * Add rpassword & wezterm-ssh dependencies * Rename core -> distant-core in project directory structure and move ssh2 feature into distant-ssh2 crate * Upgrade tokio to 1.12, * Update github actions to detect changes and apply testing for only those changes * Add method parameter to support distant & ssh methods for action and lsp subcommands * Add ssh-host, ssh-port, and ssh-user parameters to specify information for ssh methodpull/59/head
parent
32639166bc
commit
0a11ec65a2
@ -1,3 +1,4 @@
|
||||
/target
|
||||
.DS_Store
|
||||
/core/Cargo.lock
|
||||
**/.DS_Store
|
||||
/distant-core/Cargo.lock
|
||||
/distant-ssh2/Cargo.lock
|
||||
|
@ -0,0 +1,32 @@
|
||||
[package]
|
||||
name = "distant-ssh2"
|
||||
description = "Library to enable native ssh-2 protocol for use with distant sessions"
|
||||
categories = ["network-programming"]
|
||||
version = "0.15.0"
|
||||
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
|
||||
edition = "2018"
|
||||
homepage = "https://github.com/chipsenkbeil/distant"
|
||||
repository = "https://github.com/chipsenkbeil/distant"
|
||||
readme = "README.md"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
async-compat = "0.2.1"
|
||||
distant-core = { version = "=0.15.0", path = "../distant-core" }
|
||||
futures = "0.3.16"
|
||||
log = "0.4.14"
|
||||
rand = { version = "0.8.4", features = ["getrandom"] }
|
||||
rpassword = "5.0.1"
|
||||
smol = "1.2"
|
||||
tokio = { version = "1.12.0", features = ["full"] }
|
||||
wezterm-ssh = { version = "0.2.0", features = ["vendored-openssl"], git = "https://github.com/chipsenkbeil/wezterm" }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2.0.0"
|
||||
assert_fs = "1.0.4"
|
||||
flexi_logger = "0.19.4"
|
||||
indoc = "1.0.3"
|
||||
once_cell = "1.8.0"
|
||||
predicates = "2.0.2"
|
||||
rstest = "0.11.0"
|
||||
whoami = "1.1.4"
|
@ -0,0 +1,881 @@
|
||||
use async_compat::CompatExt;
|
||||
use distant_core::{
|
||||
data::{DirEntry, Error as DistantError, FileType, RunningProcess},
|
||||
Request, RequestData, Response, ResponseData,
|
||||
};
|
||||
use futures::future;
|
||||
use log::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
future::Future,
|
||||
io::{self, Read, Write},
|
||||
path::{Component, Path, PathBuf},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use wezterm_ssh::{Child, ExecResult, OpenFileType, OpenOptions, Session as WezSession, WriteMode};
|
||||
|
||||
const MAX_PIPE_CHUNK_SIZE: usize = 8192;
|
||||
const READ_PAUSE_MILLIS: u64 = 50;
|
||||
|
||||
fn to_other_error<E>(err: E) -> io::Error
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
io::Error::new(io::ErrorKind::Other, err)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct State {
|
||||
processes: HashMap<usize, Process>,
|
||||
}
|
||||
|
||||
struct Process {
|
||||
id: usize,
|
||||
cmd: String,
|
||||
args: Vec<String>,
|
||||
stdin_tx: mpsc::Sender<String>,
|
||||
kill_tx: mpsc::Sender<()>,
|
||||
}
|
||||
|
||||
type ReplyRet = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;
|
||||
|
||||
type PostHook = Box<dyn FnOnce() + Send>;
|
||||
struct Outgoing {
|
||||
data: ResponseData,
|
||||
post_hook: Option<PostHook>,
|
||||
}
|
||||
|
||||
impl From<ResponseData> for Outgoing {
|
||||
fn from(data: ResponseData) -> Self {
|
||||
Self {
|
||||
data,
|
||||
post_hook: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes the provided request, sending replies using the given sender
|
||||
pub(super) async fn process(
|
||||
session: WezSession,
|
||||
state: Arc<Mutex<State>>,
|
||||
req: Request,
|
||||
tx: mpsc::Sender<Response>,
|
||||
) -> Result<(), mpsc::error::SendError<Response>> {
|
||||
async fn inner<F>(
|
||||
session: WezSession,
|
||||
state: Arc<Mutex<State>>,
|
||||
data: RequestData,
|
||||
reply: F,
|
||||
) -> io::Result<Outgoing>
|
||||
where
|
||||
F: FnMut(Vec<ResponseData>) -> ReplyRet + Clone + Send + 'static,
|
||||
{
|
||||
match data {
|
||||
RequestData::FileRead { path } => file_read(session, path).await,
|
||||
RequestData::FileReadText { path } => file_read_text(session, path).await,
|
||||
RequestData::FileWrite { path, data } => file_write(session, path, data).await,
|
||||
RequestData::FileWriteText { path, text } => file_write(session, path, text).await,
|
||||
RequestData::FileAppend { path, data } => file_append(session, path, data).await,
|
||||
RequestData::FileAppendText { path, text } => file_append(session, path, text).await,
|
||||
RequestData::DirRead {
|
||||
path,
|
||||
depth,
|
||||
absolute,
|
||||
canonicalize,
|
||||
include_root,
|
||||
} => dir_read(session, path, depth, absolute, canonicalize, include_root).await,
|
||||
RequestData::DirCreate { path, all } => dir_create(session, path, all).await,
|
||||
RequestData::Remove { path, force } => remove(session, path, force).await,
|
||||
RequestData::Copy { src, dst } => copy(session, src, dst).await,
|
||||
RequestData::Rename { src, dst } => rename(session, src, dst).await,
|
||||
RequestData::Exists { path } => exists(session, path).await,
|
||||
RequestData::Metadata {
|
||||
path,
|
||||
canonicalize,
|
||||
resolve_file_type,
|
||||
} => metadata(session, path, canonicalize, resolve_file_type).await,
|
||||
RequestData::ProcRun { cmd, args } => proc_run(session, state, reply, cmd, args).await,
|
||||
RequestData::ProcKill { id } => proc_kill(session, state, id).await,
|
||||
RequestData::ProcStdin { id, data } => proc_stdin(session, state, id, data).await,
|
||||
RequestData::ProcList {} => proc_list(session, state).await,
|
||||
RequestData::SystemInfo {} => system_info(session).await,
|
||||
}
|
||||
}
|
||||
|
||||
let reply = {
|
||||
let origin_id = req.id;
|
||||
let tenant = req.tenant.clone();
|
||||
let tx_2 = tx.clone();
|
||||
move |payload: Vec<ResponseData>| -> ReplyRet {
|
||||
let tx = tx_2.clone();
|
||||
let res = Response::new(tenant.to_string(), origin_id, payload);
|
||||
Box::pin(async move { tx.send(res).await.is_ok() })
|
||||
}
|
||||
};
|
||||
|
||||
// Build up a collection of tasks to run independently
|
||||
let mut payload_tasks = Vec::new();
|
||||
for data in req.payload {
|
||||
let state_2 = Arc::clone(&state);
|
||||
let reply_2 = reply.clone();
|
||||
let session = session.clone();
|
||||
payload_tasks.push(tokio::spawn(async move {
|
||||
match inner(session, state_2, data, reply_2).await {
|
||||
Ok(outgoing) => outgoing,
|
||||
Err(x) => Outgoing::from(ResponseData::from(x)),
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Collect the results of our tasks into the payload entries
|
||||
let mut outgoing: Vec<Outgoing> = future::join_all(payload_tasks)
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|x| match x {
|
||||
Ok(outgoing) => outgoing,
|
||||
Err(x) => Outgoing::from(ResponseData::from(x)),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let post_hooks: Vec<PostHook> = outgoing
|
||||
.iter_mut()
|
||||
.filter_map(|x| x.post_hook.take())
|
||||
.collect();
|
||||
|
||||
let payload = outgoing.into_iter().map(|x| x.data).collect();
|
||||
let res = Response::new(req.tenant, req.id, payload);
|
||||
// Send out our primary response from processing the request
|
||||
let result = tx.send(res).await;
|
||||
|
||||
// Invoke all post hooks
|
||||
for hook in post_hooks {
|
||||
hook();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
async fn file_read(session: WezSession, path: PathBuf) -> io::Result<Outgoing> {
|
||||
use smol::io::AsyncReadExt;
|
||||
let mut file = session
|
||||
.sftp()
|
||||
.open(path)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
let mut contents = String::new();
|
||||
file.read_to_string(&mut contents).compat().await?;
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Blob {
|
||||
data: contents.into_bytes(),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn file_read_text(session: WezSession, path: PathBuf) -> io::Result<Outgoing> {
|
||||
use smol::io::AsyncReadExt;
|
||||
let mut file = session
|
||||
.sftp()
|
||||
.open(path)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
let mut contents = String::new();
|
||||
file.read_to_string(&mut contents).compat().await?;
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Text { data: contents }))
|
||||
}
|
||||
|
||||
async fn file_write(
|
||||
session: WezSession,
|
||||
path: PathBuf,
|
||||
data: impl AsRef<[u8]>,
|
||||
) -> io::Result<Outgoing> {
|
||||
use smol::io::AsyncWriteExt;
|
||||
let mut file = session
|
||||
.sftp()
|
||||
.create(path)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
file.write_all(data.as_ref()).compat().await?;
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Ok))
|
||||
}
|
||||
|
||||
async fn file_append(
|
||||
session: WezSession,
|
||||
path: PathBuf,
|
||||
data: impl AsRef<[u8]>,
|
||||
) -> io::Result<Outgoing> {
|
||||
use smol::io::AsyncWriteExt;
|
||||
let mut file = session
|
||||
.sftp()
|
||||
.open_mode(
|
||||
path,
|
||||
OpenOptions {
|
||||
read: false,
|
||||
write: Some(WriteMode::Append),
|
||||
// Using 644 as this mirrors "ssh <host> touch ..."
|
||||
// 644: rw-r--r--
|
||||
mode: 0o644,
|
||||
ty: OpenFileType::File,
|
||||
},
|
||||
)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
file.write_all(data.as_ref()).compat().await?;
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Ok))
|
||||
}
|
||||
|
||||
async fn dir_read(
|
||||
session: WezSession,
|
||||
path: PathBuf,
|
||||
depth: usize,
|
||||
absolute: bool,
|
||||
canonicalize: bool,
|
||||
include_root: bool,
|
||||
) -> io::Result<Outgoing> {
|
||||
let sftp = session.sftp();
|
||||
|
||||
// Canonicalize our provided path to ensure that it is exists, not a loop, and absolute
|
||||
let root_path = sftp.realpath(path).compat().await.map_err(to_other_error)?;
|
||||
|
||||
// Build up our entry list
|
||||
let mut entries = Vec::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
let mut to_traverse = vec![DirEntry {
|
||||
path: root_path.to_path_buf(),
|
||||
file_type: FileType::Dir,
|
||||
depth: 0,
|
||||
}];
|
||||
|
||||
while let Some(entry) = to_traverse.pop() {
|
||||
let is_root = entry.depth == 0;
|
||||
let next_depth = entry.depth + 1;
|
||||
let ft = entry.file_type;
|
||||
let path = if entry.path.is_relative() {
|
||||
root_path.join(&entry.path)
|
||||
} else {
|
||||
entry.path.to_path_buf()
|
||||
};
|
||||
|
||||
// Always include any non-root in our traverse list, but only include the
|
||||
// root directory if flagged to do so
|
||||
if !is_root || include_root {
|
||||
entries.push(entry);
|
||||
}
|
||||
|
||||
let is_dir = match ft {
|
||||
FileType::Dir => true,
|
||||
FileType::File => false,
|
||||
FileType::Symlink => match sftp.stat(&path).await {
|
||||
Ok(stat) => stat.is_dir(),
|
||||
Err(x) => {
|
||||
errors.push(DistantError::from(to_other_error(x)));
|
||||
continue;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Determine if we continue traversing or stop
|
||||
if is_dir && (depth == 0 || next_depth <= depth) {
|
||||
match sftp.readdir(&path).compat().await.map_err(to_other_error) {
|
||||
Ok(entries) => {
|
||||
for (mut path, stat) in entries {
|
||||
// Canonicalize the path if specified, otherwise just return
|
||||
// the path as is
|
||||
path = if canonicalize {
|
||||
match sftp.realpath(path).compat().await {
|
||||
Ok(path) => path,
|
||||
Err(x) => {
|
||||
errors.push(DistantError::from(to_other_error(x)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
path
|
||||
};
|
||||
|
||||
// Strip the path of its prefix based if not flagged as absolute
|
||||
if !absolute {
|
||||
// NOTE: In the situation where we canonicalized the path earlier,
|
||||
// there is no guarantee that our root path is still the parent of
|
||||
// the symlink's destination; so, in that case we MUST just return
|
||||
// the path if the strip_prefix fails
|
||||
path = path
|
||||
.strip_prefix(root_path.as_path())
|
||||
.map(Path::to_path_buf)
|
||||
.unwrap_or(path);
|
||||
};
|
||||
|
||||
let ft = stat.ty;
|
||||
to_traverse.push(DirEntry {
|
||||
path,
|
||||
file_type: if ft.is_dir() {
|
||||
FileType::Dir
|
||||
} else if ft.is_file() {
|
||||
FileType::File
|
||||
} else {
|
||||
FileType::Symlink
|
||||
},
|
||||
depth: next_depth,
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(x) if is_root => return Err(io::Error::new(io::ErrorKind::Other, x)),
|
||||
Err(x) => errors.push(DistantError::from(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort entries by filename
|
||||
entries.sort_unstable_by_key(|e| e.path.to_path_buf());
|
||||
|
||||
Ok(Outgoing::from(ResponseData::DirEntries { entries, errors }))
|
||||
}
|
||||
|
||||
async fn dir_create(session: WezSession, path: PathBuf, all: bool) -> io::Result<Outgoing> {
|
||||
let sftp = session.sftp();
|
||||
|
||||
// Makes the immediate directory, failing if given a path with missing components
|
||||
async fn mkdir(sftp: &wezterm_ssh::Sftp, path: PathBuf) -> io::Result<()> {
|
||||
// Using 755 as this mirrors "ssh <host> mkdir ..."
|
||||
// 755: rwxr-xr-x
|
||||
sftp.mkdir(path, 0o755)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)
|
||||
}
|
||||
|
||||
if all {
|
||||
// Keep trying to create a directory, moving up to parent each time a failure happens
|
||||
let mut failed_paths = Vec::new();
|
||||
let mut cur_path = path.as_path();
|
||||
loop {
|
||||
let failed = mkdir(&sftp, cur_path.to_path_buf()).await.is_err();
|
||||
if failed {
|
||||
failed_paths.push(cur_path);
|
||||
if let Some(path) = cur_path.parent() {
|
||||
cur_path = path;
|
||||
} else {
|
||||
return Err(io::Error::from(io::ErrorKind::PermissionDenied));
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we've successfully created a parent component (or the directory), proceed
|
||||
// to attempt to create each failed directory
|
||||
while let Some(path) = failed_paths.pop() {
|
||||
mkdir(&sftp, path.to_path_buf()).await?;
|
||||
}
|
||||
} else {
|
||||
mkdir(&sftp, path).await?;
|
||||
}
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Ok))
|
||||
}
|
||||
|
||||
async fn remove(session: WezSession, path: PathBuf, force: bool) -> io::Result<Outgoing> {
|
||||
let sftp = session.sftp();
|
||||
|
||||
// Determine if we are dealing with a file or directory
|
||||
let stat = sftp
|
||||
.stat(path.to_path_buf())
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
// If a file or symlink, we just unlink (easy)
|
||||
if stat.is_file() || stat.is_symlink() {
|
||||
sftp.unlink(path)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?;
|
||||
// If directory and not forcing, we just rmdir (easy)
|
||||
} else if !force {
|
||||
sftp.rmdir(path)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?;
|
||||
// Otherwise, we need to find all files and directories, keep track of their depth, and
|
||||
// then attempt to remove them all
|
||||
} else {
|
||||
let mut entries = Vec::new();
|
||||
let mut to_traverse = vec![DirEntry {
|
||||
path,
|
||||
file_type: FileType::Dir,
|
||||
depth: 0,
|
||||
}];
|
||||
|
||||
// Collect all entries within directory
|
||||
while let Some(entry) = to_traverse.pop() {
|
||||
if entry.file_type == FileType::Dir {
|
||||
let path = entry.path.to_path_buf();
|
||||
let depth = entry.depth;
|
||||
|
||||
entries.push(entry);
|
||||
|
||||
for (path, stat) in sftp.readdir(path).await.map_err(to_other_error)? {
|
||||
to_traverse.push(DirEntry {
|
||||
path,
|
||||
file_type: if stat.is_dir() {
|
||||
FileType::Dir
|
||||
} else if stat.is_file() {
|
||||
FileType::File
|
||||
} else {
|
||||
FileType::Symlink
|
||||
},
|
||||
depth: depth + 1,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by depth such that deepest are last as we will be popping
|
||||
// off entries from end to remove first
|
||||
entries.sort_unstable_by_key(|e| e.depth);
|
||||
|
||||
while let Some(entry) = entries.pop() {
|
||||
if entry.file_type == FileType::Dir {
|
||||
sftp.rmdir(entry.path)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?;
|
||||
} else {
|
||||
sftp.unlink(entry.path)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::PermissionDenied, x))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Ok))
|
||||
}
|
||||
|
||||
async fn copy(session: WezSession, src: PathBuf, dst: PathBuf) -> io::Result<Outgoing> {
|
||||
// NOTE: SFTP does not provide a remote-to-remote copy method, so we instead execute
|
||||
// a program and hope that it applies, starting with the Unix/BSD/GNU cp method
|
||||
// and switch to Window's xcopy if the former fails
|
||||
|
||||
// Unix cp -R <src> <dst>
|
||||
let unix_result = session
|
||||
.exec(&format!("cp -R {:?} {:?}", src, dst), None)
|
||||
.compat()
|
||||
.await;
|
||||
|
||||
let failed = unix_result.is_err() || {
|
||||
let exit_status = unix_result.unwrap().child.async_wait().compat().await;
|
||||
exit_status.is_err() || !exit_status.unwrap().success()
|
||||
};
|
||||
|
||||
// Windows xcopy <src> <dst> /s /e
|
||||
if failed {
|
||||
let exit_status = session
|
||||
.exec(&format!("xcopy {:?} {:?} /s /e", src, dst), None)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?
|
||||
.child
|
||||
.async_wait()
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
if !exit_status.success() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Unix and windows copy commands failed",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Ok))
|
||||
}
|
||||
|
||||
async fn rename(session: WezSession, src: PathBuf, dst: PathBuf) -> io::Result<Outgoing> {
|
||||
session
|
||||
.sftp()
|
||||
.rename(src, dst, Default::default())
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Ok))
|
||||
}
|
||||
|
||||
async fn exists(session: WezSession, path: PathBuf) -> io::Result<Outgoing> {
|
||||
// NOTE: SFTP does not provide a means to check if a path exists that can be performed
|
||||
// separately from getting permission errors; so, we just assume any error means that the path
|
||||
// does not exist
|
||||
let exists = session.sftp().lstat(path).compat().await.is_ok();
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Exists(exists)))
|
||||
}
|
||||
|
||||
async fn metadata(
|
||||
session: WezSession,
|
||||
path: PathBuf,
|
||||
canonicalize: bool,
|
||||
resolve_file_type: bool,
|
||||
) -> io::Result<Outgoing> {
|
||||
let sftp = session.sftp();
|
||||
let canonicalized_path = if canonicalize {
|
||||
Some(
|
||||
sftp.realpath(path.to_path_buf())
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let stat = if resolve_file_type {
|
||||
sftp.stat(path).compat().await.map_err(to_other_error)?
|
||||
} else {
|
||||
sftp.lstat(path).compat().await.map_err(to_other_error)?
|
||||
};
|
||||
|
||||
let file_type = if stat.is_dir() {
|
||||
FileType::Dir
|
||||
} else if stat.is_file() {
|
||||
FileType::File
|
||||
} else {
|
||||
FileType::Symlink
|
||||
};
|
||||
|
||||
Ok(Outgoing::from(ResponseData::Metadata {
|
||||
canonicalized_path,
|
||||
file_type,
|
||||
len: stat.len(),
|
||||
// Check that owner, group, or other has write permission (if not, then readonly)
|
||||
readonly: stat.is_readonly(),
|
||||
accessed: stat.accessed.map(u128::from),
|
||||
modified: stat.modified.map(u128::from),
|
||||
created: None,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn proc_run<F>(
|
||||
session: WezSession,
|
||||
state: Arc<Mutex<State>>,
|
||||
reply: F,
|
||||
cmd: String,
|
||||
args: Vec<String>,
|
||||
) -> io::Result<Outgoing>
|
||||
where
|
||||
F: FnMut(Vec<ResponseData>) -> ReplyRet + Clone + Send + 'static,
|
||||
{
|
||||
let id = rand::random();
|
||||
let cmd_string = format!("{} {}", cmd, args.join(" "));
|
||||
|
||||
let ExecResult {
|
||||
mut stdin,
|
||||
mut stdout,
|
||||
mut stderr,
|
||||
mut child,
|
||||
} = session
|
||||
.exec(&cmd_string, None)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
// Force stdin, stdout, and stderr to be nonblocking
|
||||
stdin
|
||||
.set_non_blocking(true)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
|
||||
stdout
|
||||
.set_non_blocking(true)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
|
||||
stderr
|
||||
.set_non_blocking(true)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
|
||||
|
||||
// Check if the process died immediately and report
|
||||
// an error if that's the case
|
||||
if let Ok(Some(exit_status)) = child.try_wait() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("Process exited early: {:?}", exit_status),
|
||||
));
|
||||
}
|
||||
|
||||
let (stdin_tx, mut stdin_rx) = mpsc::channel(1);
|
||||
let (kill_tx, mut kill_rx) = mpsc::channel(1);
|
||||
state.lock().await.processes.insert(
|
||||
id,
|
||||
Process {
|
||||
id,
|
||||
cmd,
|
||||
args,
|
||||
stdin_tx,
|
||||
kill_tx,
|
||||
},
|
||||
);
|
||||
|
||||
let post_hook = Box::new(move || {
|
||||
// Spawn a task that sends stdout as a response
|
||||
let mut reply_2 = reply.clone();
|
||||
let stdout_task = tokio::spawn(async move {
|
||||
let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE];
|
||||
loop {
|
||||
match stdout.read(&mut buf) {
|
||||
Ok(n) if n > 0 => match String::from_utf8(buf[..n].to_vec()) {
|
||||
Ok(data) => {
|
||||
let payload = vec![ResponseData::ProcStdout { id, data }];
|
||||
if !reply_2(payload).await {
|
||||
error!("<Ssh: Proc {}> Stdout channel closed", id);
|
||||
break;
|
||||
}
|
||||
|
||||
// Pause to allow buffer to fill up a little bit, avoiding
|
||||
// spamming with a lot of smaller responses
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(
|
||||
READ_PAUSE_MILLIS,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
Err(x) => {
|
||||
error!(
|
||||
"<Ssh: Proc {}> Invalid data read from stdout pipe: {}",
|
||||
id, x
|
||||
);
|
||||
break;
|
||||
}
|
||||
},
|
||||
Ok(_) => break,
|
||||
Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
|
||||
// Pause to allow buffer to fill up a little bit, avoiding
|
||||
// spamming with a lot of smaller responses
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS))
|
||||
.await;
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn a task that sends stderr as a response
|
||||
let mut reply_2 = reply.clone();
|
||||
let stderr_task = tokio::spawn(async move {
|
||||
let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE];
|
||||
loop {
|
||||
match stderr.read(&mut buf) {
|
||||
Ok(n) if n > 0 => match String::from_utf8(buf[..n].to_vec()) {
|
||||
Ok(data) => {
|
||||
let payload = vec![ResponseData::ProcStderr { id, data }];
|
||||
if !reply_2(payload).await {
|
||||
error!("<Ssh: Proc {}> Stderr channel closed", id);
|
||||
break;
|
||||
}
|
||||
|
||||
// Pause to allow buffer to fill up a little bit, avoiding
|
||||
// spamming with a lot of smaller responses
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(
|
||||
READ_PAUSE_MILLIS,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
Err(x) => {
|
||||
error!(
|
||||
"<Ssh: Proc {}> Invalid data read from stderr pipe: {}",
|
||||
id, x
|
||||
);
|
||||
break;
|
||||
}
|
||||
},
|
||||
Ok(_) => break,
|
||||
Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
|
||||
// Pause to allow buffer to fill up a little bit, avoiding
|
||||
// spamming with a lot of smaller responses
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS))
|
||||
.await;
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stdin_task = tokio::spawn(async move {
|
||||
while let Some(line) = stdin_rx.recv().await {
|
||||
if let Err(x) = stdin.write_all(line.as_bytes()) {
|
||||
error!("<Ssh: Proc {}> Failed to send stdin: {}", id, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn a task that waits on the process to exit but can also
|
||||
// kill the process when triggered
|
||||
let state_2 = Arc::clone(&state);
|
||||
let mut reply_2 = reply.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut should_kill = false;
|
||||
let mut success = false;
|
||||
tokio::select! {
|
||||
_ = kill_rx.recv() => {
|
||||
should_kill = true;
|
||||
}
|
||||
result = child.async_wait().compat() => {
|
||||
match result {
|
||||
Ok(status) => {
|
||||
success = status.success();
|
||||
}
|
||||
Err(x) => {
|
||||
error!("<Ssh: Proc {}> Waiting on process failed: {}", id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Force stdin task to abort if it hasn't exited as there is no
|
||||
// point to sending any more stdin
|
||||
stdin_task.abort();
|
||||
|
||||
if should_kill {
|
||||
debug!("<Ssh: Proc {}> Process killed", id);
|
||||
|
||||
if let Err(x) = child.kill() {
|
||||
error!("<Ssh: Proc {}> Unable to kill process: {}", id, x);
|
||||
}
|
||||
|
||||
// NOTE: At the moment, child.kill does nothing for wezterm_ssh::SshChildProcess;
|
||||
// so, we need to manually run kill/taskkill to make sure that the
|
||||
// process is sent a kill signal
|
||||
if let Some(pid) = child.process_id() {
|
||||
let _ = session
|
||||
.exec(&format!("kill -9 {}", pid), None)
|
||||
.compat()
|
||||
.await;
|
||||
let _ = session
|
||||
.exec(&format!("taskkill /F /PID {}", pid), None)
|
||||
.compat()
|
||||
.await;
|
||||
}
|
||||
} else {
|
||||
debug!("<Ssh: Proc {}> Process done", id);
|
||||
}
|
||||
|
||||
if let Err(x) = stderr_task.await {
|
||||
error!("<Ssh: Proc {}> Join on stderr task failed: {}", id, x);
|
||||
}
|
||||
|
||||
if let Err(x) = stdout_task.await {
|
||||
error!("<Ssh: Proc {}> Join on stdout task failed: {}", id, x);
|
||||
}
|
||||
|
||||
state_2.lock().await.processes.remove(&id);
|
||||
|
||||
let payload = vec![ResponseData::ProcDone {
|
||||
id,
|
||||
success: !should_kill && success,
|
||||
code: None,
|
||||
}];
|
||||
|
||||
if !reply_2(payload).await {
|
||||
error!("<Ssh: Proc {}> Failed to send done!", id,);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Ok(Outgoing {
|
||||
data: ResponseData::ProcStart { id },
|
||||
post_hook: Some(post_hook),
|
||||
})
|
||||
}
|
||||
|
||||
async fn proc_kill(
|
||||
_session: WezSession,
|
||||
state: Arc<Mutex<State>>,
|
||||
id: usize,
|
||||
) -> io::Result<Outgoing> {
|
||||
if let Some(process) = state.lock().await.processes.remove(&id) {
|
||||
if process.kill_tx.send(()).await.is_ok() {
|
||||
return Ok(Outgoing::from(ResponseData::Ok));
|
||||
}
|
||||
}
|
||||
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"Unable to send kill signal to process",
|
||||
))
|
||||
}
|
||||
|
||||
async fn proc_stdin(
|
||||
_session: WezSession,
|
||||
state: Arc<Mutex<State>>,
|
||||
id: usize,
|
||||
data: String,
|
||||
) -> io::Result<Outgoing> {
|
||||
if let Some(process) = state.lock().await.processes.get_mut(&id) {
|
||||
if process.stdin_tx.send(data).await.is_ok() {
|
||||
return Ok(Outgoing::from(ResponseData::Ok));
|
||||
}
|
||||
}
|
||||
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"Unable to send stdin to process",
|
||||
))
|
||||
}
|
||||
|
||||
async fn proc_list(_session: WezSession, state: Arc<Mutex<State>>) -> io::Result<Outgoing> {
|
||||
Ok(Outgoing::from(ResponseData::ProcEntries {
|
||||
entries: state
|
||||
.lock()
|
||||
.await
|
||||
.processes
|
||||
.values()
|
||||
.map(|p| RunningProcess {
|
||||
cmd: p.cmd.to_string(),
|
||||
args: p.args.clone(),
|
||||
id: p.id,
|
||||
})
|
||||
.collect(),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn system_info(session: WezSession) -> io::Result<Outgoing> {
|
||||
let current_dir = session
|
||||
.sftp()
|
||||
.realpath(".")
|
||||
.compat()
|
||||
.await
|
||||
.map_err(to_other_error)?;
|
||||
|
||||
let first_component = current_dir.components().next();
|
||||
let is_windows =
|
||||
first_component.is_some() && matches!(first_component.unwrap(), Component::Prefix(_));
|
||||
let is_unix = current_dir.as_os_str().to_string_lossy().starts_with('/');
|
||||
|
||||
let family = if is_windows {
|
||||
"windows"
|
||||
} else if is_unix {
|
||||
"unix"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
.to_string();
|
||||
|
||||
Ok(Outgoing::from(ResponseData::SystemInfo {
|
||||
family,
|
||||
os: "".to_string(),
|
||||
arch: "".to_string(),
|
||||
current_dir,
|
||||
main_separator: if is_windows { '\\' } else { '/' },
|
||||
}))
|
||||
}
|
@ -0,0 +1,255 @@
|
||||
use async_compat::CompatExt;
|
||||
use distant_core::{Request, Session, Transport};
|
||||
use log::*;
|
||||
use smol::channel::Receiver as SmolReceiver;
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use wezterm_ssh::{Config as WezConfig, Session as WezSession, SessionEvent as WezSessionEvent};
|
||||
|
||||
mod handler;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Ssh2AuthPrompt {
|
||||
/// The label to show when prompting the user
|
||||
pub prompt: String,
|
||||
|
||||
/// If true, the response that the user inputs should be displayed as they type. If false then
|
||||
/// treat it as a password entry and do not display what is typed in response to this prompt.
|
||||
pub echo: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Ssh2AuthEvent {
|
||||
/// Represents the name of the user to be authenticated. This may be empty!
|
||||
pub username: String,
|
||||
|
||||
/// Informational text to be displayed to the user prior to the prompt
|
||||
pub instructions: String,
|
||||
|
||||
/// Prompts to be conveyed to the user, each representing a single answer needed
|
||||
pub prompts: Vec<Ssh2AuthPrompt>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Ssh2SessionOpts {
|
||||
pub identity_files: Vec<PathBuf>,
|
||||
pub identities_only: Option<bool>,
|
||||
pub port: Option<u16>,
|
||||
pub proxy_command: Option<String>,
|
||||
pub user: Option<String>,
|
||||
pub user_known_hosts_files: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
pub struct Ssh2AuthHandler {
|
||||
pub on_authenticate: Box<dyn FnMut(Ssh2AuthEvent) -> io::Result<Vec<String>>>,
|
||||
pub on_banner: Box<dyn FnMut(&str)>,
|
||||
pub on_host_verify: Box<dyn FnMut(&str) -> io::Result<bool>>,
|
||||
pub on_error: Box<dyn FnMut(&str)>,
|
||||
}
|
||||
|
||||
impl Default for Ssh2AuthHandler {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
on_authenticate: Box::new(|ev| {
|
||||
if !ev.username.is_empty() {
|
||||
eprintln!("Authentication for {}", ev.username);
|
||||
}
|
||||
|
||||
if !ev.instructions.is_empty() {
|
||||
eprintln!("{}", ev.instructions);
|
||||
}
|
||||
|
||||
let mut answers = Vec::new();
|
||||
for prompt in &ev.prompts {
|
||||
// Contains all prompt lines including same line
|
||||
let mut prompt_lines = prompt.prompt.split('\n').collect::<Vec<_>>();
|
||||
|
||||
// Line that is prompt on same line as answer
|
||||
let prompt_line = prompt_lines.pop().unwrap();
|
||||
|
||||
// Go ahead and display all other lines
|
||||
for line in prompt_lines.into_iter() {
|
||||
eprintln!("{}", line);
|
||||
}
|
||||
|
||||
let answer = if prompt.echo {
|
||||
eprint!("{}", prompt_line);
|
||||
std::io::stderr().lock().flush()?;
|
||||
|
||||
let mut answer = String::new();
|
||||
std::io::stdin().read_line(&mut answer)?;
|
||||
answer
|
||||
} else {
|
||||
rpassword::prompt_password_stderr(prompt_line)?
|
||||
};
|
||||
|
||||
answers.push(answer);
|
||||
}
|
||||
Ok(answers)
|
||||
}),
|
||||
on_banner: Box::new(|_| {}),
|
||||
on_host_verify: Box::new(|message| {
|
||||
eprintln!("{}", message);
|
||||
match rpassword::prompt_password_stderr("Enter [y/N]> ")?.as_str() {
|
||||
"y" | "Y" | "yes" | "YES" => Ok(true),
|
||||
_ => Ok(false),
|
||||
}
|
||||
}),
|
||||
on_error: Box::new(|_| {}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Ssh2Session {
|
||||
session: WezSession,
|
||||
events: SmolReceiver<WezSessionEvent>,
|
||||
}
|
||||
|
||||
impl Ssh2Session {
|
||||
/// Connect to a remote TCP server using SSH
|
||||
pub fn connect(host: impl AsRef<str>, opts: Ssh2SessionOpts) -> io::Result<Self> {
|
||||
let mut config = WezConfig::new();
|
||||
config.add_default_config_files();
|
||||
|
||||
// Grab the config for the specific host
|
||||
let mut config = config.for_host(host.as_ref());
|
||||
|
||||
// Override config with any settings provided by session opts
|
||||
if let Some(port) = opts.port.as_ref() {
|
||||
config.insert("port".to_string(), port.to_string());
|
||||
}
|
||||
if let Some(user) = opts.user.as_ref() {
|
||||
config.insert("user".to_string(), user.to_string());
|
||||
}
|
||||
if !opts.identity_files.is_empty() {
|
||||
config.insert(
|
||||
"identityfile".to_string(),
|
||||
opts.identity_files
|
||||
.iter()
|
||||
.filter_map(|p| p.to_str())
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
);
|
||||
}
|
||||
if let Some(yes) = opts.identities_only.as_ref() {
|
||||
let value = if *yes {
|
||||
"yes".to_string()
|
||||
} else {
|
||||
"no".to_string()
|
||||
};
|
||||
config.insert("identitiesonly".to_string(), value);
|
||||
}
|
||||
if let Some(cmd) = opts.proxy_command.as_ref() {
|
||||
config.insert("proxycommand".to_string(), cmd.to_string());
|
||||
}
|
||||
if !opts.user_known_hosts_files.is_empty() {
|
||||
config.insert(
|
||||
"userknownhostsfile".to_string(),
|
||||
opts.user_known_hosts_files
|
||||
.iter()
|
||||
.filter_map(|p| p.to_str())
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
);
|
||||
}
|
||||
|
||||
// Establish a connection
|
||||
let (session, events) =
|
||||
WezSession::connect(config).map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
|
||||
|
||||
Ok(Self { session, events })
|
||||
}
|
||||
|
||||
/// Authenticates the [`Ssh2Session`] and produces a [`Session`]
|
||||
pub async fn authenticate(self, mut handler: Ssh2AuthHandler) -> io::Result<Session> {
|
||||
// Perform the authentication by listening for events and continuing to handle them
|
||||
// until authenticated
|
||||
while let Ok(event) = self.events.recv().await {
|
||||
match event {
|
||||
WezSessionEvent::Banner(banner) => {
|
||||
if let Some(banner) = banner {
|
||||
(handler.on_banner)(banner.as_ref());
|
||||
}
|
||||
}
|
||||
WezSessionEvent::HostVerify(verify) => {
|
||||
let verified = (handler.on_host_verify)(verify.message.as_str())?;
|
||||
verify
|
||||
.answer(verified)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
|
||||
}
|
||||
WezSessionEvent::Authenticate(mut auth) => {
|
||||
let ev = Ssh2AuthEvent {
|
||||
username: auth.username.clone(),
|
||||
instructions: auth.instructions.clone(),
|
||||
prompts: auth
|
||||
.prompts
|
||||
.drain(..)
|
||||
.map(|p| Ssh2AuthPrompt {
|
||||
prompt: p.prompt,
|
||||
echo: p.echo,
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let answers = (handler.on_authenticate)(ev)?;
|
||||
auth.answer(answers)
|
||||
.compat()
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
|
||||
}
|
||||
WezSessionEvent::Error(err) => {
|
||||
(handler.on_error)(&err);
|
||||
return Err(io::Error::new(io::ErrorKind::PermissionDenied, err));
|
||||
}
|
||||
WezSessionEvent::Authenticated => break,
|
||||
}
|
||||
}
|
||||
|
||||
// We are now authenticated, so convert into a distant session that wraps our ssh2 session
|
||||
self.into_session()
|
||||
}
|
||||
|
||||
/// Consume [`Ssh2Session`] and produce a distant [`Session`]
|
||||
fn into_session(self) -> io::Result<Session> {
|
||||
let (t1, t2) = Transport::pair(1);
|
||||
let session = Session::initialize(t1)?;
|
||||
|
||||
// Spawn tasks that forward requests to the ssh session
|
||||
// and send back responses from the ssh session
|
||||
let (mut t_read, mut t_write) = t2.into_split();
|
||||
let Self {
|
||||
session: wez_session,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
tokio::spawn(async move {
|
||||
let state = Arc::new(Mutex::new(handler::State::default()));
|
||||
while let Ok(Some(req)) = t_read.receive::<Request>().await {
|
||||
if let Err(x) =
|
||||
handler::process(wez_session.clone(), Arc::clone(&state), req, tx.clone()).await
|
||||
{
|
||||
error!("{}", x);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(res) = rx.recv().await {
|
||||
if t_write.send(res).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
mod ssh2;
|
||||
mod sshd;
|
@ -0,0 +1 @@
|
||||
mod session;
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,432 @@
|
||||
use assert_fs::{prelude::*, TempDir};
|
||||
use distant_core::Session;
|
||||
use distant_ssh2::{Ssh2AuthHandler, Ssh2Session, Ssh2SessionOpts};
|
||||
use once_cell::sync::{Lazy, OnceCell};
|
||||
use rstest::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt, io,
|
||||
path::Path,
|
||||
process::{Child, Command},
|
||||
sync::atomic::{AtomicU16, Ordering},
|
||||
thread,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
/// NOTE: OpenSSH's sshd requires absolute path
|
||||
const BIN_PATH_STR: &str = "/usr/sbin/sshd";
|
||||
|
||||
/// Port range to use when finding a port to bind to (using IANA guidance)
|
||||
const PORT_RANGE: (u16, u16) = (49152, 65535);
|
||||
|
||||
static USERNAME: Lazy<String> = Lazy::new(whoami::username);
|
||||
|
||||
pub struct SshKeygen;
|
||||
|
||||
impl SshKeygen {
|
||||
// ssh-keygen -t rsa -f $ROOT/id_rsa -N "" -q
|
||||
pub fn generate_rsa(path: impl AsRef<Path>, passphrase: impl AsRef<str>) -> io::Result<bool> {
|
||||
let res = Command::new("ssh-keygen")
|
||||
.args(&["-m", "PEM"])
|
||||
.args(&["-t", "rsa"])
|
||||
.arg("-f")
|
||||
.arg(path.as_ref())
|
||||
.arg("-N")
|
||||
.arg(passphrase.as_ref())
|
||||
.arg("-q")
|
||||
.status()
|
||||
.map(|status| status.success())?;
|
||||
|
||||
#[cfg(unix)]
|
||||
if res {
|
||||
// chmod 600 id_rsa* -> ida_rsa + ida_rsa.pub
|
||||
std::fs::metadata(path.as_ref().with_extension("pub"))?
|
||||
.permissions()
|
||||
.set_mode(0o600);
|
||||
std::fs::metadata(path)?.permissions().set_mode(0o600);
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SshAgent;
|
||||
|
||||
impl SshAgent {
|
||||
pub fn generate_shell_env() -> io::Result<HashMap<String, String>> {
|
||||
let output = Command::new("ssh-agent").arg("-s").output()?;
|
||||
let stdout = String::from_utf8(output.stdout)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;
|
||||
Ok(stdout
|
||||
.split(';')
|
||||
.map(str::trim)
|
||||
.filter(|s| s.contains('='))
|
||||
.map(|s| {
|
||||
let mut tokens = s.split('=');
|
||||
let key = tokens.next().unwrap().trim().to_string();
|
||||
let rest = tokens
|
||||
.map(str::trim)
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<String>>()
|
||||
.join("=");
|
||||
(key, rest)
|
||||
})
|
||||
.collect::<HashMap<String, String>>())
|
||||
}
|
||||
|
||||
pub fn update_tests_with_shell_env() -> io::Result<()> {
|
||||
let env_map = Self::generate_shell_env()?;
|
||||
for (key, value) in env_map {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SshdConfig(HashMap<String, Vec<String>>);
|
||||
|
||||
impl Default for SshdConfig {
|
||||
fn default() -> Self {
|
||||
let mut config = Self::new();
|
||||
|
||||
config.set_authentication_methods(vec!["publickey".to_string()]);
|
||||
config.set_use_privilege_separation(false);
|
||||
config.set_subsystem(true, true);
|
||||
config.set_use_pam(false);
|
||||
config.set_x11_forwarding(true);
|
||||
config.set_print_motd(true);
|
||||
config.set_permit_tunnel(true);
|
||||
config.set_kbd_interactive_authentication(true);
|
||||
config.set_allow_tcp_forwarding(true);
|
||||
config.set_max_startups(500, None);
|
||||
config.set_strict_modes(false);
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
impl SshdConfig {
|
||||
pub fn new() -> Self {
|
||||
Self(HashMap::new())
|
||||
}
|
||||
|
||||
pub fn set_authentication_methods(&mut self, methods: Vec<String>) {
|
||||
self.0.insert("AuthenticationMethods".to_string(), methods);
|
||||
}
|
||||
|
||||
pub fn set_authorized_keys_file(&mut self, path: impl AsRef<Path>) {
|
||||
self.0.insert(
|
||||
"AuthorizedKeysFile".to_string(),
|
||||
vec![path.as_ref().to_string_lossy().to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_host_key(&mut self, path: impl AsRef<Path>) {
|
||||
self.0.insert(
|
||||
"HostKey".to_string(),
|
||||
vec![path.as_ref().to_string_lossy().to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_pid_file(&mut self, path: impl AsRef<Path>) {
|
||||
self.0.insert(
|
||||
"PidFile".to_string(),
|
||||
vec![path.as_ref().to_string_lossy().to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_subsystem(&mut self, sftp: bool, internal_sftp: bool) {
|
||||
let mut values = Vec::new();
|
||||
if sftp {
|
||||
values.push("sftp".to_string());
|
||||
}
|
||||
if internal_sftp {
|
||||
values.push("internal-sftp".to_string());
|
||||
}
|
||||
|
||||
self.0.insert("Subsystem".to_string(), values);
|
||||
}
|
||||
|
||||
pub fn set_use_pam(&mut self, yes: bool) {
|
||||
self.0.insert("UsePAM".to_string(), Self::yes_value(yes));
|
||||
}
|
||||
|
||||
pub fn set_x11_forwarding(&mut self, yes: bool) {
|
||||
self.0
|
||||
.insert("X11Forwarding".to_string(), Self::yes_value(yes));
|
||||
}
|
||||
|
||||
pub fn set_use_privilege_separation(&mut self, yes: bool) {
|
||||
self.0
|
||||
.insert("UsePrivilegeSeparation".to_string(), Self::yes_value(yes));
|
||||
}
|
||||
|
||||
pub fn set_print_motd(&mut self, yes: bool) {
|
||||
self.0.insert("PrintMotd".to_string(), Self::yes_value(yes));
|
||||
}
|
||||
|
||||
pub fn set_permit_tunnel(&mut self, yes: bool) {
|
||||
self.0
|
||||
.insert("PermitTunnel".to_string(), Self::yes_value(yes));
|
||||
}
|
||||
|
||||
pub fn set_kbd_interactive_authentication(&mut self, yes: bool) {
|
||||
self.0.insert(
|
||||
"KbdInteractiveAuthentication".to_string(),
|
||||
Self::yes_value(yes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_allow_tcp_forwarding(&mut self, yes: bool) {
|
||||
self.0
|
||||
.insert("AllowTcpForwarding".to_string(), Self::yes_value(yes));
|
||||
}
|
||||
|
||||
pub fn set_max_startups(&mut self, start: u16, rate_full: Option<(u16, u16)>) {
|
||||
let value = format!(
|
||||
"{}{}",
|
||||
start,
|
||||
rate_full
|
||||
.map(|(r, f)| format!(":{}:{}", r, f))
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
self.0.insert("MaxStartups".to_string(), vec![value]);
|
||||
}
|
||||
|
||||
pub fn set_strict_modes(&mut self, yes: bool) {
|
||||
self.0
|
||||
.insert("StrictModes".to_string(), Self::yes_value(yes));
|
||||
}
|
||||
|
||||
fn yes_value(yes: bool) -> Vec<String> {
|
||||
vec![Self::yes_string(yes)]
|
||||
}
|
||||
|
||||
fn yes_string(yes: bool) -> String {
|
||||
Self::yes_str(yes).to_string()
|
||||
}
|
||||
|
||||
const fn yes_str(yes: bool) -> &'static str {
|
||||
if yes {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SshdConfig {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
for (keyword, values) in self.0.iter() {
|
||||
writeln!(
|
||||
f,
|
||||
"{} {}",
|
||||
keyword,
|
||||
values
|
||||
.iter()
|
||||
.map(|v| {
|
||||
let v = v.trim();
|
||||
if v.contains(|c: char| c.is_whitespace()) {
|
||||
format!("\"{}\"", v)
|
||||
} else {
|
||||
v.to_string()
|
||||
}
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Context for some sshd instance
|
||||
pub struct Sshd {
|
||||
child: Child,
|
||||
|
||||
/// Port that sshd is listening on
|
||||
pub port: u16,
|
||||
|
||||
/// Temporary directory used to hold resources for sshd such as its config, keys, and log
|
||||
pub tmp: TempDir,
|
||||
}
|
||||
|
||||
impl Sshd {
|
||||
pub fn spawn(mut config: SshdConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let tmp = TempDir::new()?;
|
||||
|
||||
// Ensure that everything needed for interacting with ssh-agent is set
|
||||
SshAgent::update_tests_with_shell_env()?;
|
||||
|
||||
// ssh-keygen -t rsa -f $ROOT/id_rsa -N "" -q
|
||||
let id_rsa_file = tmp.child("id_rsa");
|
||||
assert!(
|
||||
SshKeygen::generate_rsa(id_rsa_file.path(), "")?,
|
||||
"Failed to ssh-keygen id_rsa"
|
||||
);
|
||||
|
||||
// cp $ROOT/id_rsa.pub $ROOT/authorized_keys
|
||||
let authorized_keys_file = tmp.child("authorized_keys");
|
||||
std::fs::copy(
|
||||
id_rsa_file.path().with_extension("pub"),
|
||||
authorized_keys_file.path(),
|
||||
)?;
|
||||
|
||||
// ssh-keygen -t rsa -f $ROOT/ssh_host_rsa_key -N "" -q
|
||||
let ssh_host_rsa_key_file = tmp.child("ssh_host_rsa_key");
|
||||
assert!(
|
||||
SshKeygen::generate_rsa(ssh_host_rsa_key_file.path(), "")?,
|
||||
"Failed to ssh-keygen ssh_host_rsa_key"
|
||||
);
|
||||
|
||||
config.set_authorized_keys_file(id_rsa_file.path().with_extension("pub"));
|
||||
config.set_host_key(ssh_host_rsa_key_file.path());
|
||||
|
||||
let sshd_pid_file = tmp.child("sshd.pid");
|
||||
config.set_pid_file(sshd_pid_file.path());
|
||||
|
||||
// Generate $ROOT/sshd_config based on config
|
||||
let sshd_config_file = tmp.child("sshd_config");
|
||||
sshd_config_file.write_str(&config.to_string())?;
|
||||
|
||||
let sshd_log_file = tmp.child("sshd.log");
|
||||
|
||||
let (child, port) = Self::try_spawn_next(sshd_config_file.path(), sshd_log_file.path())
|
||||
.expect("No open port available for sshd");
|
||||
|
||||
Ok(Self { child, port, tmp })
|
||||
}
|
||||
|
||||
fn try_spawn_next(
|
||||
config_path: impl AsRef<Path>,
|
||||
log_path: impl AsRef<Path>,
|
||||
) -> io::Result<(Child, u16)> {
|
||||
static PORT: AtomicU16 = AtomicU16::new(PORT_RANGE.0);
|
||||
|
||||
loop {
|
||||
let port = PORT.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
match Self::try_spawn(port, config_path.as_ref(), log_path.as_ref()) {
|
||||
// If successful, return our spawned server child process
|
||||
Ok(Ok(child)) => break Ok((child, port)),
|
||||
|
||||
// If the server died when spawned and we reached the final port, we want to exit
|
||||
Ok(Err((code, msg))) if port == PORT_RANGE.1 => {
|
||||
break Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!(
|
||||
"{} failed [{}]: {}",
|
||||
BIN_PATH_STR,
|
||||
code.map(|x| x.to_string())
|
||||
.unwrap_or_else(|| String::from("???")),
|
||||
msg
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
// If we've reached the final port in our range to try, we want to exit
|
||||
Err(x) if port == PORT_RANGE.1 => break Err(x),
|
||||
|
||||
// Otherwise, try next port
|
||||
Err(_) | Ok(Err(_)) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_spawn(
|
||||
port: u16,
|
||||
config_path: impl AsRef<Path>,
|
||||
log_path: impl AsRef<Path>,
|
||||
) -> io::Result<Result<Child, (Option<i32>, String)>> {
|
||||
let mut child = Command::new(BIN_PATH_STR)
|
||||
.arg("-D")
|
||||
.arg("-p")
|
||||
.arg(port.to_string())
|
||||
.arg("-f")
|
||||
.arg(config_path.as_ref())
|
||||
.arg("-E")
|
||||
.arg(log_path.as_ref())
|
||||
.spawn()?;
|
||||
|
||||
// Pause for a little bit to make sure that the server didn't die due to an error
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
|
||||
if let Some(exit_status) = child.try_wait()? {
|
||||
let output = child.wait_with_output()?;
|
||||
Ok(Err((
|
||||
exit_status.code(),
|
||||
format!(
|
||||
"{}\n{}",
|
||||
String::from_utf8(output.stdout).unwrap(),
|
||||
String::from_utf8(output.stderr).unwrap(),
|
||||
),
|
||||
)))
|
||||
} else {
|
||||
Ok(Ok(child))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Sshd {
|
||||
/// Kills server upon drop
|
||||
fn drop(&mut self) {
|
||||
let _ = self.child.kill();
|
||||
}
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
pub fn logger() -> &'static flexi_logger::LoggerHandle {
|
||||
static LOGGER: OnceCell<flexi_logger::LoggerHandle> = OnceCell::new();
|
||||
|
||||
LOGGER.get_or_init(|| {
|
||||
// flexi_logger::Logger::try_with_str("off, distant_core=trace, distant_ssh2=trace")
|
||||
flexi_logger::Logger::try_with_str("off, distant_core=warn, distant_ssh2=warn")
|
||||
.expect("Failed to load env")
|
||||
.start()
|
||||
.expect("Failed to start logger")
|
||||
})
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
pub fn sshd() -> &'static Sshd {
|
||||
static SSHD: OnceCell<Sshd> = OnceCell::new();
|
||||
|
||||
SSHD.get_or_init(|| Sshd::spawn(Default::default()).unwrap())
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
pub async fn session(sshd: &'_ Sshd, _logger: &'_ flexi_logger::LoggerHandle) -> Session {
|
||||
let port = sshd.port;
|
||||
|
||||
Ssh2Session::connect(
|
||||
"127.0.0.1",
|
||||
Ssh2SessionOpts {
|
||||
port: Some(port),
|
||||
identity_files: vec![sshd.tmp.child("id_rsa").path().to_path_buf()],
|
||||
identities_only: Some(true),
|
||||
user: Some(USERNAME.to_string()),
|
||||
user_known_hosts_files: vec![sshd.tmp.child("known_hosts").path().to_path_buf()],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
.authenticate(Ssh2AuthHandler {
|
||||
on_authenticate: Box::new(|ev| {
|
||||
println!("on_authenticate: {:?}", ev);
|
||||
Ok(vec![String::new(); ev.prompts.len()])
|
||||
}),
|
||||
on_host_verify: Box::new(|host| {
|
||||
println!("on_host_verify: {}", host);
|
||||
Ok(true)
|
||||
}),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
Loading…
Reference in New Issue