Refactor session to use mpsc instead of broadcast channel, add LSP command

pull/38/head
Chip Senkbeil 3 years ago
parent ba6ebcfcb8
commit 260cb0e99d

@@ -88,7 +88,6 @@ impl ExitCodeError for RemoteProcessError {
match self {
Self::BadResponse => ExitCode::DataErr,
Self::ChannelDead => ExitCode::Unavailable,
- Self::Overloaded => ExitCode::Software,
Self::TransportError(x) => x.to_exit_code(),
Self::UnexpectedEof => ExitCode::IoError,
Self::WaitFailed(_) => ExitCode::Software,

@@ -95,6 +95,9 @@ pub enum Subcommand {
/// Begins listening for incoming requests
Listen(ListenSubcommand),
/// Specialized treatment of running a remote LSP process
Lsp(LspSubcommand),
}
impl Subcommand {
@@ -104,6 +107,7 @@ impl Subcommand {
Self::Action(cmd) => subcommand::action::run(cmd, opt)?,
Self::Launch(cmd) => subcommand::launch::run(cmd, opt)?,
Self::Listen(cmd) => subcommand::listen::run(cmd, opt)?,
Self::Lsp(cmd) => subcommand::lsp::run(cmd, opt)?,
}
Ok(())
@@ -499,3 +503,42 @@ impl ListenSubcommand {
.map(Duration::from_secs_f32)
}
}
/// Represents subcommand to execute some LSP server on a remote machine
#[derive(Debug, StructOpt)]
#[structopt(verbatim_doc_comment)]
pub struct LspSubcommand {
/// Represents the format in which results should be returned
///
/// Currently, there are two possible formats:
///
/// 1. "json": printing out JSON for external program usage
///
/// 2. "shell": printing out human-readable results for interactive shell usage
#[structopt(
short,
long,
case_insensitive = true,
default_value = Format::Shell.into(),
possible_values = Format::VARIANTS
)]
pub format: Format,
/// Represents the medium for retrieving a session to use when running a remote LSP server
#[structopt(
long,
default_value = SessionInput::default().into(),
possible_values = SessionInput::VARIANTS
)]
pub session: SessionInput,
/// Contains additional information related to sessions
#[structopt(flatten)]
pub session_data: SessionOpt,
/// Command to run on the remote machine that represents an LSP server
pub cmd: String,
/// Additional arguments to supply to the remote machine
pub args: Vec<String>,
}
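For reference, here is a minimal, hypothetical sketch (using the same structopt crate; `Sketch` and its fields stand in for LspSubcommand) of how a positional `cmd` plus trailing `args` parse:

```rust
use structopt::StructOpt;

/// Stand-in mirroring the cmd/args tail of LspSubcommand
#[derive(Debug, StructOpt)]
struct Sketch {
    /// Command to run on the remote machine
    cmd: String,
    /// Additional arguments to supply to it
    args: Vec<String>,
}

fn main() {
    // First element is the binary name, as with std::env::args()
    let s = Sketch::from_iter(&["sketch", "rust-analyzer"]);
    assert_eq!(s.cmd, "rust-analyzer");
    assert!(s.args.is_empty());
}
```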

@@ -11,7 +11,6 @@ use log::*;
use std::{io, thread};
use structopt::StructOpt;
use tokio::{sync::mpsc, task::JoinHandle};
- use tokio_stream::{wrappers::BroadcastStream, StreamExt};
/// Represents a wrapper around a session that provides CLI functionality such as reading from
/// stdin and piping results back out to stdout
@@ -29,9 +28,11 @@ impl CliSession {
let (stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE);
let (exit_tx, exit_rx) = mpsc::channel(1);
- let stream = session.to_response_broadcast_stream();
let broadcast = session.broadcast.take().unwrap();
let res_task =
- tokio::spawn(async move { process_incoming_responses(stream, format, exit_rx).await });
tokio::spawn(
async move { process_incoming_responses(broadcast, format, exit_rx).await },
);
let map_line = move |line: &str| match format {
Format::Json => serde_json::from_str(&line)
@@ -76,16 +77,15 @@ impl CliSession {
/// Helper function that loops, processing incoming responses not tied to a request to be sent out
/// over stdout/stderr
async fn process_incoming_responses(
- mut stream: BroadcastStream<Response>,
mut broadcast: mpsc::Receiver<Response>,
format: Format,
mut exit: mpsc::Receiver<()>,
) -> io::Result<()> {
loop {
tokio::select! {
- res = stream.next() => {
res = broadcast.recv() => {
match res {
- Some(Ok(res)) => ResponseOut::new(format, res)?.print(),
- Some(Err(x)) => return Err(io::Error::new(io::ErrorKind::BrokenPipe, x)),
Some(res) => ResponseOut::new(format, res)?.print(),
None => return Ok(()),
}
}
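The loop above is the usual tokio::select! drain pattern: read from the mpsc until the sender side closes, or bail when the exit channel fires. A standalone sketch of the same pattern (names here are illustrative, not from this codebase):

```rust
use tokio::sync::mpsc;

async fn pump(mut responses: mpsc::Receiver<String>, mut exit: mpsc::Receiver<()>) {
    loop {
        tokio::select! {
            res = responses.recv() => match res {
                Some(msg) => println!("{}", msg),
                None => break, // all senders dropped, nothing more to print
            },
            _ = exit.recv() => break, // explicit shutdown requested
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(8);
    let (_exit_tx, exit_rx) = mpsc::channel(1);
    tx.send("hello".to_string()).await.unwrap();
    drop(tx); // closing the channel ends the loop
    pump(rx, exit_rx).await;
}
```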

@@ -157,7 +157,7 @@ async fn socket_loop(
debug!("Binding to unix socket: {:?}", socket_path.as_ref());
let listener = tokio::net::UnixListener::bind(socket_path)?;
- let server = RelayServer::initialize(session, listener, shutdown_after).await?;
let server = RelayServer::initialize(session, listener, shutdown_after)?;
server
.wait()
.await

@@ -0,0 +1,141 @@
use crate::{
cli::{
link::RemoteProcessLink,
opt::{CommonOpt, LspSubcommand, SessionInput},
ExitCode, ExitCodeError,
},
core::{
client::{
self, LspData, RemoteLspProcess, RemoteProcessError, Session, SessionInfo,
SessionInfoFile,
},
net::DataStream,
},
};
use derive_more::{Display, Error, From};
use tokio::io;
#[derive(Debug, Display, Error, From)]
pub enum Error {
#[display(fmt = "Process failed with exit code: {}", _0)]
BadProcessExit(#[error(not(source))] i32),
IoError(io::Error),
RemoteProcessError(RemoteProcessError),
}
impl ExitCodeError for Error {
fn to_exit_code(&self) -> ExitCode {
match self {
Self::BadProcessExit(x) => ExitCode::Custom(*x),
Self::IoError(x) => x.to_exit_code(),
Self::RemoteProcessError(x) => x.to_exit_code(),
}
}
}
pub fn run(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt).await })
}
async fn run_async(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> {
let timeout = opt.to_timeout_duration();
match cmd.session {
SessionInput::Environment => {
start(
cmd,
Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?,
None,
)
.await
}
SessionInput::File => {
let path = cmd.session_data.session_file.clone();
start(
cmd,
Session::tcp_connect_timeout(
SessionInfoFile::load_from(path).await?.into(),
timeout,
)
.await?,
None,
)
.await
}
SessionInput::Pipe => {
start(
cmd,
Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?,
None,
)
.await
}
SessionInput::Lsp => {
let mut data =
LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?;
let info = data.take_session_info().map_err(io::Error::from)?;
start(
cmd,
Session::tcp_connect_timeout(info, timeout).await?,
Some(data),
)
.await
}
#[cfg(unix)]
SessionInput::Socket => {
let path = cmd.session_data.session_socket.clone();
start(
cmd,
Session::unix_connect_timeout(path, None, timeout).await?,
None,
)
.await
}
}
}
async fn start<T>(
cmd: LspSubcommand,
session: Session<T>,
lsp_data: Option<LspData>,
) -> Result<(), Error>
where
T: DataStream + 'static,
{
let mut proc =
RemoteLspProcess::spawn(client::new_tenant(), session, cmd.cmd, cmd.args).await?;
// If we also parsed an LSP's initialize request for its session, we want to forward
// it along in the case of a process call
if let Some(data) = lsp_data {
proc.stdin
.as_mut()
.unwrap()
.write(&data.to_string())
.await?;
}
// Now, map the remote LSP server's stdin/stdout/stderr to our own process
let link = RemoteProcessLink::from_remote_lsp_pipes(
proc.stdin.take().unwrap(),
proc.stdout.take().unwrap(),
proc.stderr.take().unwrap(),
);
let (success, exit_code) = proc.wait().await?;
// Shut down our link
link.shutdown().await;
if !success {
if let Some(code) = exit_code {
return Err(Error::BadProcessExit(code));
} else {
return Err(Error::BadProcessExit(1));
}
}
Ok(())
}
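LspData::from_buf_reader is not shown in this diff. For orientation only: the LSP base protocol frames each message as Content-Length header(s), a blank line, then a JSON body, so the kind of parsing it has to do might look like this hypothetical sketch:

```rust
use std::io::{self, BufRead, Read};

fn read_lsp_message<R: BufRead>(r: &mut R) -> io::Result<String> {
    let mut len = 0usize;
    loop {
        let mut line = String::new();
        if r.read_line(&mut line)? == 0 {
            return Err(io::ErrorKind::UnexpectedEof.into());
        }
        let line = line.trim_end();
        if line.is_empty() {
            break; // blank line terminates the header section
        }
        if let Some(v) = line.strip_prefix("Content-Length:") {
            len = v
                .trim()
                .parse()
                .map_err(|_| io::Error::from(io::ErrorKind::InvalidData))?;
        }
    }
    // Body is exactly Content-Length bytes of (JSON) payload
    let mut body = vec![0u8; len];
    r.read_exact(&mut body)?;
    String::from_utf8(body).map_err(|_| io::ErrorKind::InvalidData.into())
}

fn main() -> io::Result<()> {
    let mut input: &[u8] = b"Content-Length: 17\r\n\r\n{\"method\":\"init\"}";
    println!("{}", read_lsp_message(&mut input)?);
    Ok(())
}
```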

@@ -1,3 +1,4 @@
pub mod action;
pub mod launch;
pub mod listen;
pub mod lsp;

@@ -41,6 +41,11 @@ impl RemoteLspProcess {
stderr,
})
}
/// Waits for the process to terminate, returning the success status and an optional exit code
pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> {
self.inner.wait().await
}
}
impl Deref for RemoteLspProcess {
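Here wait() is forwarded explicitly while everything else rides on the Deref into the inner RemoteProcess. A generic sketch of that wrapper pattern (types below are stand-ins, not the crate's):

```rust
use std::ops::Deref;

struct Inner;
impl Inner {
    fn status(&self) -> &'static str {
        "running"
    }
}

struct Wrapper {
    inner: Inner,
}

impl Deref for Wrapper {
    type Target = Inner;
    fn deref(&self) -> &Inner {
        &self.inner
    }
}

fn main() {
    let w = Wrapper { inner: Inner };
    assert_eq!(w.status(), "running"); // resolved through Deref
}
```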

@@ -10,7 +10,6 @@ use tokio::{
sync::mpsc,
task::{JoinError, JoinHandle},
};
- use tokio_stream::{wrappers::BroadcastStream, StreamExt};
#[derive(Debug, Display, Error, From)]
pub enum RemoteProcessError {
@@ -20,10 +19,6 @@ pub enum RemoteProcessError {
/// When attempting to relay stdout/stderr over channels, but the channels fail
ChannelDead,
- /// When process is unable to read stdout/stderr from the server
- /// fast enough, resulting in dropped data
- Overloaded,
/// When the communication over the wire has issues
TransportError(TransportError),
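Dropping Overloaded follows from the channel swap: a tokio broadcast receiver that falls behind gets RecvError::Lagged and loses messages, while an mpsc sender simply waits for capacity. A small sketch demonstrating the difference (capacity and values are arbitrary):

```rust
use tokio::sync::{broadcast, mpsc};

#[tokio::main]
async fn main() {
    // broadcast: a slow receiver is overrun and observes Lagged
    let (btx, mut brx) = broadcast::channel(1);
    btx.send(1).unwrap();
    btx.send(2).unwrap(); // overwrites the unread value
    assert!(matches!(
        brx.recv().await,
        Err(broadcast::error::RecvError::Lagged(_))
    ));

    // mpsc: send().await applies backpressure instead of dropping
    let (mtx, mut mrx) = mpsc::channel(1);
    mtx.send(1).await.unwrap(); // a second send would wait, not drop
    assert_eq!(mrx.recv().await, Some(1));
}
```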
@@ -97,9 +92,9 @@ impl RemoteProcess {
// Now we spawn a task to handle future responses that are async
// such as ProcStdout, ProcStderr, and ProcDone
- let stream = session.to_response_broadcast_stream();
let broadcast = session.broadcast.take().unwrap();
let res_task = tokio::spawn(async move {
- process_incoming_responses(id, stream, stdout_tx, stderr_tx).await
process_incoming_responses(id, broadcast, stdout_tx, stderr_tx).await
});
// Spawn a task that takes stdin from our channel and forwards it to the remote process
@@ -234,53 +229,45 @@ where
/// Helper function that loops, processing incoming stdout & stderr requests from a remote process
async fn process_incoming_responses(
proc_id: usize,
- mut stream: BroadcastStream<Response>,
mut broadcast: mpsc::Receiver<Response>,
stdout_tx: mpsc::Sender<String>,
stderr_tx: mpsc::Sender<String>,
) -> Result<(bool, Option<i32>), RemoteProcessError> {
let mut result = Err(RemoteProcessError::UnexpectedEof);
- while let Some(res) = stream.next().await {
- match res {
- Ok(res) => {
- // Check if any of the payload data is the termination
- let exit_status = res.payload.iter().find_map(|data| match data {
- ResponseData::ProcDone { id, success, code } if *id == proc_id => {
- Some((*success, *code))
- }
- _ => None,
- });
- // Next, check for stdout/stderr and send them along our channels
- // TODO: What should we do about unexpected data? For now, just ignore
- for data in res.payload {
- match data {
- ResponseData::ProcStdout { id, data } if id == proc_id => {
- if let Err(_) = stdout_tx.send(data).await {
- result = Err(RemoteProcessError::ChannelDead);
- break;
- }
- }
- ResponseData::ProcStderr { id, data } if id == proc_id => {
- if let Err(_) = stderr_tx.send(data).await {
- result = Err(RemoteProcessError::ChannelDead);
- break;
- }
- }
- _ => {}
- }
- }
- // If we got a termination, then exit accordingly
- if let Some((success, code)) = exit_status {
- result = Ok((success, code));
- break;
- }
- }
- Err(_) => {
- result = Err(RemoteProcessError::Overloaded);
- break;
- }
- }
- }
while let Some(res) = broadcast.recv().await {
// Check if any of the payload data is the termination
let exit_status = res.payload.iter().find_map(|data| match data {
ResponseData::ProcDone { id, success, code } if *id == proc_id => {
Some((*success, *code))
}
_ => None,
});
// Next, check for stdout/stderr and send them along our channels
// TODO: What should we do about unexpected data? For now, just ignore
for data in res.payload {
match data {
ResponseData::ProcStdout { id, data } if id == proc_id => {
if let Err(_) = stdout_tx.send(data).await {
result = Err(RemoteProcessError::ChannelDead);
break;
}
}
ResponseData::ProcStderr { id, data } if id == proc_id => {
if let Err(_) = stderr_tx.send(data).await {
result = Err(RemoteProcessError::ChannelDead);
break;
}
}
_ => {}
}
}
// If we got a termination, then exit accordingly
if let Some((success, code)) = exit_status {
result = Ok((success, code));
break;
}
}

@@ -13,11 +13,10 @@ use std::{
use tokio::{
io,
net::TcpStream,
- sync::{broadcast, oneshot},
sync::{mpsc, oneshot},
task::{JoinError, JoinHandle},
time::Duration,
};
- use tokio_stream::wrappers::BroadcastStream;
mod info;
pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError};
@@ -35,16 +34,11 @@ where
/// Collection of callbacks to be invoked upon receiving a response to a request
callbacks: Callbacks,
- /// Callback to trigger when a response is received without an origin or with an origin
- /// not found in the list of callbacks
- broadcast: broadcast::Sender<Response>,
- /// Represents an initial receiver for broadcasted responses that can capture responses
- /// prior to a stream being established and consumed
- init_broadcast_receiver: Option<broadcast::Receiver<Response>>,
/// Contains the task that is running to receive responses from a server
response_task: JoinHandle<()>,
/// Represents the receiver for broadcasted responses (ones with no callback)
pub broadcast: Option<mpsc::Receiver<Response>>,
}
impl Session<InmemoryStream> {
@@ -116,12 +110,10 @@ where
pub async fn initialize(transport: Transport<T>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
- let (broadcast, init_broadcast_receiver) =
- broadcast::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (broadcast_tx, broadcast_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
// Start a task that continually checks for responses and triggers callbacks
let callbacks_2 = Arc::clone(&callbacks);
- let broadcast_2 = broadcast.clone();
let response_task = tokio::spawn(async move {
loop {
match t_read.receive::<Response>().await {
@@ -142,7 +134,7 @@ where
// Otherwise, this goes into the junk drawer of response handlers
} else {
trace!("Callback missing for response! Broadcasting!");
- if let Err(x) = broadcast_2.send(res) {
if let Err(x) = broadcast_tx.send(res).await {
error!("Failed to trigger broadcast: {}", x);
}
}
@@ -159,8 +151,7 @@ where
Ok(Self {
t_write,
callbacks,
- broadcast,
- init_broadcast_receiver: Some(init_broadcast_receiver),
broadcast: Some(broadcast_rx),
response_task,
})
}
@@ -220,21 +211,6 @@ where
.map_err(TransportError::from)
.and_then(convert::identity)
}
- /// Clones a new instance of the broadcaster used by the session
- pub fn to_response_broadcaster(&self) -> broadcast::Sender<Response> {
- self.broadcast.clone()
- }
- /// Creates and returns a new stream of responses that are received that do not match the
- /// response to a `send` request
- pub fn to_response_broadcast_stream(&mut self) -> BroadcastStream<Response> {
- BroadcastStream::new(
- self.init_broadcast_receiver
- .take()
- .unwrap_or_else(|| self.broadcast.subscribe()),
- )
- }
}
#[cfg(test)]
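Replacing the broadcaster handles with `pub broadcast: Option<mpsc::Receiver<Response>>` encodes the single-consumer semantics of mpsc: the receiver can be taken exactly once. Minimal sketch of that take-once pattern (SessionLike is a stand-in, not the crate's type):

```rust
use tokio::sync::mpsc;

struct SessionLike {
    // Option because an mpsc::Receiver has exactly one consumer:
    // whoever takes it owns all subsequent broadcasts.
    broadcast: Option<mpsc::Receiver<String>>,
}

fn main() {
    let (_tx, rx) = mpsc::channel::<String>(16);
    let mut session = SessionLike { broadcast: Some(rx) };
    assert!(session.broadcast.take().is_some()); // first caller gets it
    assert!(session.broadcast.take().is_none()); // second caller cannot
}
```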

@@ -46,3 +46,26 @@ impl Listener for tokio::net::UnixListener {
Box::pin(accept(self))
}
}
#[cfg(test)]
impl<T: DataStream + Send + Sync> Listener for tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>> {
type Conn = T;
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
where
Self: Sync + 'a,
{
async fn accept<T>(
_self: &tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>,
) -> io::Result<T> {
_self
.lock()
.await
.recv()
.await
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
}
Box::pin(accept(self))
}
}
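A rough sketch of how a test might drive this Listener impl, substituting String for a real DataStream type just to show the accept flow:

```rust
use tokio::sync::{mpsc, Mutex};

#[tokio::main]
async fn main() {
    // Stand-in: a channel of "connections" (String instead of a DataStream impl)
    let (tx, rx) = mpsc::channel::<String>(4);
    let listener = Mutex::new(rx);
    tx.send("fake-connection".to_string()).await.unwrap();
    // accept() above boils down to: lock, then recv the next queued connection
    let conn = listener.lock().await.recv().await.unwrap();
    assert_eq!(conn, "fake-connection");
}
```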

@@ -457,7 +457,7 @@ async fn proc_run(
// Spawn a task that sends stdin to the process
let mut stdin = child.stdin.take().unwrap();
let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(1);
- tokio::spawn(async move {
let stdin_task = tokio::spawn(async move {
while let Some(line) = stdin_rx.recv().await {
if let Err(x) = stdin.write_all(line.as_bytes()).await {
error!("Failed to send stdin to process {}: {}", id, x);
@@ -469,9 +469,13 @@ async fn proc_run(
// Spawn a task that waits on the process to exit but can also
// kill the process when triggered
let (kill_tx, kill_rx) = oneshot::channel();
- tokio::spawn(async move {
let wait_task = tokio::spawn(async move {
tokio::select! {
status = child.wait() => {
if let Err(x) = stdin_task.await {
error!("Join on stdin task failed: {}", x);
}
if let Err(x) = stderr_task.await {
error!("Join on stderr task failed: {}", x);
}
@@ -554,6 +558,7 @@ async fn proc_run(
id,
stdin_tx,
kill_tx,
wait_task,
};
state.lock().await.push_process(conn_id, process);
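The new wait task pairs child.wait() with a oneshot kill trigger. A self-contained sketch of that wait-or-kill select (assumes a Unix `sleep` binary; error handling trimmed):

```rust
use tokio::{process::Command, sync::oneshot};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let mut child = Command::new("sleep").arg("30").spawn()?;
    let (kill_tx, kill_rx) = oneshot::channel::<()>();

    let wait_task = tokio::spawn(async move {
        tokio::select! {
            // Either the process exits on its own...
            status = child.wait() => println!("exited: {:?}", status),
            // ...or the kill channel fires first and we terminate it
            _ = kill_rx => {
                let _ = child.kill().await;
                println!("killed");
            }
        }
    });

    let _ = kill_tx.send(());
    wait_task.await.unwrap();
    Ok(())
}
```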
@@ -1280,4 +1285,155 @@ mod tests {
// Also verify that the directory was actually created
assert!(path.exists(), "Directory not created");
}
#[tokio::test]
async fn remove_should_send_error_on_failure() {
todo!();
}
#[tokio::test]
async fn remove_should_support_deleting_a_directory() {
todo!();
}
#[tokio::test]
async fn remove_should_delete_nonempty_directory_if_force_is_true() {
todo!();
}
#[tokio::test]
async fn remove_should_support_deleting_a_single_file() {
todo!();
}
#[tokio::test]
async fn copy_should_send_error_on_failure() {
todo!();
}
#[tokio::test]
async fn copy_should_support_copying_an_entire_directory() {
todo!();
}
#[tokio::test]
async fn copy_should_support_copying_a_single_file() {
todo!();
}
#[tokio::test]
async fn rename_should_send_error_on_failure() {
todo!();
}
#[tokio::test]
async fn rename_should_support_renaming_an_entire_directory() {
todo!();
}
#[tokio::test]
async fn rename_should_support_renaming_a_single_file() {
todo!();
}
#[tokio::test]
async fn exists_should_send_error_on_failure() {
todo!();
}
#[tokio::test]
async fn exists_should_send_true_if_path_exists() {
todo!();
}
#[tokio::test]
async fn exists_should_send_false_if_path_does_not_exist() {
todo!();
}
#[tokio::test]
async fn metadata_should_send_error_on_failure() {
todo!();
}
#[tokio::test]
async fn metadata_should_send_back_metadata_on_file_if_exists() {
todo!();
}
#[tokio::test]
async fn metadata_should_send_back_metadata_on_dir_if_exists() {
todo!();
}
#[tokio::test]
async fn metadata_should_include_canonicalized_path_if_flag_specified() {
todo!();
}
#[tokio::test]
async fn proc_run_should_send_error_on_failure() {
todo!();
}
#[tokio::test]
async fn proc_run_should_send_back_proc_start_on_success() {
todo!();
}
#[tokio::test]
async fn proc_run_should_send_back_stdout_periodically_when_available() {
todo!();
}
#[tokio::test]
async fn proc_run_should_send_back_stderr_periodically_when_available() {
todo!();
}
#[tokio::test]
async fn proc_run_should_send_back_done_when_proc_finishes() {
// Make sure to verify that process also removed from state
todo!();
}
#[tokio::test]
async fn proc_run_should_send_back_done_when_killed() {
// Make sure to verify that process also removed from state
todo!();
}
#[tokio::test]
async fn proc_kill_should_send_error_on_failure() {
// Can verify that if the process is not running, will fail
todo!();
}
#[tokio::test]
async fn proc_kill_should_send_ok_on_success() {
// Verify that we trigger sending done
todo!();
}
#[tokio::test]
async fn proc_stdin_should_send_error_on_failure() {
// Can verify that if the process is not running, will fail
todo!();
}
#[tokio::test]
async fn proc_stdin_should_send_ok_on_success() {
// Verify that we trigger sending stdin to process
todo!();
}
#[tokio::test]
async fn proc_list_should_send_proc_entry_list() {
todo!();
}
#[tokio::test]
async fn system_info_should_send_system_info_based_on_binary() {
todo!();
}
}

@ -1,6 +1,14 @@
use log::*;
- use std::collections::HashMap;
- use tokio::sync::{mpsc, oneshot};
use std::{
collections::HashMap,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
sync::{mpsc, oneshot},
task::{JoinError, JoinHandle},
};
/// Holds state related to multiple clients managed by a server
#[derive(Default)]
@@ -62,4 +70,22 @@ pub struct Process {
/// Transport channel to report that the process should be killed
pub kill_tx: oneshot::Sender<()>,
/// Task used to wait on the process to complete or be killed
pub wait_task: JoinHandle<()>,
}
impl Process {
pub async fn kill_and_wait(self) -> Result<(), JoinError> {
let _ = self.kill_tx.send(());
self.wait_task.await
}
}
impl Future for Process {
type Output = Result<(), JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.wait_task).poll(cx)
}
}
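Since poll() just delegates to the JoinHandle, a Process value can be awaited directly. Standalone sketch of the same delegation with a dummy task (Waiter is illustrative):

```rust
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::task::{JoinError, JoinHandle};

struct Waiter {
    handle: JoinHandle<()>,
}

impl Future for Waiter {
    type Output = Result<(), JoinError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // JoinHandle is Unpin, so re-pinning the field is safe here
        Pin::new(&mut self.handle).poll(cx)
    }
}

#[tokio::main]
async fn main() {
    let w = Waiter { handle: tokio::spawn(async {}) };
    w.await.unwrap(); // resolves when the spawned task finishes
}
```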

@@ -9,7 +9,7 @@ use log::*;
use std::{collections::HashMap, marker::Unpin, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
- sync::{broadcast, mpsc, oneshot, Mutex},
sync::{mpsc, oneshot, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
@@ -17,13 +17,14 @@ use tokio::{
/// Represents a server that relays requests & responses between connections and the
/// actual server
pub struct RelayServer {
- forward_task: JoinHandle<()>,
accept_task: JoinHandle<()>,
broadcast_task: JoinHandle<()>,
forward_task: JoinHandle<()>,
conns: Arc<Mutex<HashMap<usize, Conn>>>,
}
impl RelayServer {
- pub async fn initialize<T1, T2, L>(
pub fn initialize<T1, T2, L>(
mut session: Session<T1>,
listener: L,
shutdown_after: Option<Duration>,
@@ -33,10 +34,34 @@ impl RelayServer {
T2: DataStream + Send + 'static,
L: Listener<Conn = T2> + 'static,
{
- // Get a copy of our session's broadcaster so we can have each connection
- // subscribe to it for new messages filtered by tenant
- debug!("Acquiring session broadcaster");
- let broadcaster = session.to_response_broadcaster();
let conns: Arc<Mutex<HashMap<usize, Conn>>> = Arc::new(Mutex::new(HashMap::new()));
// Spawn task to send server responses to the appropriate connections
let conns_2 = Arc::clone(&conns);
debug!("Spawning response broadcast task");
let mut broadcast = session.broadcast.take().unwrap();
let broadcast_task = tokio::spawn(async move {
while let Some(res) = broadcast.recv().await {
// Search for all connections with a tenant that matches the response's tenant
for conn in conns_2.lock().await.values_mut() {
if conn.state.lock().await.tenant.as_deref() == Some(res.tenant.as_str()) {
debug!(
"Forwarding response of type{} {} to connection {}",
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string(),
conn.id
);
if let Err(x) = conn.forward_response(res).await {
error!("Failed to pass forwarding message: {}", x);
}
// NOTE: We assume that tenant is unique, so we can break after
// forwarding the message to the first matching tenant
break;
}
}
}
});
// Spawn task to send to the server requests from connections
debug!("Spawning request forwarding task");
@@ -56,7 +81,6 @@ impl RelayServer {
});
let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
- let conns = Arc::new(Mutex::new(HashMap::new()));
let conns_2 = Arc::clone(&conns);
let accept_task = tokio::spawn(async move {
let inner = async move {
@@ -66,7 +90,6 @@ impl RelayServer {
let result = Conn::initialize(
stream,
req_tx.clone(),
- broadcaster.clone(),
tracker.as_ref().map(Arc::clone),
)
.await;
@@ -96,22 +119,24 @@ impl RelayServer {
});
Ok(Self {
- forward_task,
accept_task,
broadcast_task,
forward_task,
conns,
})
}
pub async fn wait(self) -> Result<(), JoinError> {
- match tokio::try_join!(self.forward_task, self.accept_task) {
match tokio::try_join!(self.accept_task, self.broadcast_task, self.forward_task) {
Ok(_) => Ok(()),
Err(x) => Err(x),
}
}
pub async fn abort(&self) {
- self.forward_task.abort();
self.accept_task.abort();
self.broadcast_task.abort();
self.forward_task.abort();
self.conns
.lock()
.await
@@ -124,11 +149,14 @@ struct Conn {
id: usize,
req_task: JoinHandle<()>,
res_task: JoinHandle<()>,
res_tx: mpsc::Sender<Response>,
state: Arc<Mutex<ConnState>>,
}
/// Represents state associated with a connection
#[derive(Default)]
struct ConnState {
tenant: Option<String>,
processes: Vec<usize>,
}
@@ -136,7 +164,6 @@ impl Conn {
pub async fn initialize<T>(
stream: T,
req_tx: mpsc::Sender<Request>,
- res_broadcaster: broadcast::Sender<Response>,
ct: Option<Arc<Mutex<ConnTracker>>>,
) -> io::Result<Self>
where
@@ -164,7 +191,7 @@ impl Conn {
// Spawn task to continually receive responses from the session that
// may or may not be relevant to the connection, which will filter
// by tenant and then along any response that matches
- let res_rx = res_broadcaster.subscribe();
let (res_tx, res_rx) = mpsc::channel::<Response>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let state_2 = Arc::clone(&state);
let res_task = tokio::spawn(async move {
handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await;
@@ -173,11 +200,12 @@ impl Conn {
// Spawn task to continually read requests from connection and forward
// them along to be sent via the session
let req_tx = req_tx.clone();
let state_2 = Arc::clone(&state);
let req_task = tokio::spawn(async move {
if let Some(ct) = ct.as_ref() {
ct.lock().await.increment();
}
- handle_conn_incoming(id, state, t_read, tenant_tx, req_tx).await;
handle_conn_incoming(id, state_2, t_read, tenant_tx, req_tx).await;
if let Some(ct) = ct.as_ref() {
ct.lock().await.decrement();
}
@@ -188,6 +216,8 @@ impl Conn {
id,
req_task,
res_task,
res_tx,
state,
})
}
@@ -201,6 +231,14 @@ impl Conn {
self.req_task.abort();
self.res_task.abort();
}
/// Forwards a response back through this connection
pub async fn forward_response(
&mut self,
res: Response,
) -> Result<(), mpsc::error::SendError<Response>> {
self.res_tx.send(res).await
}
}
/// Conn::Request -> Session::Fire
@@ -284,7 +322,7 @@ async fn handle_conn_outgoing<T>(
state: Arc<Mutex<ConnState>>,
mut writer: TransportWriteHalf<T>,
tenant_rx: oneshot::Receiver<String>,
- mut res_rx: broadcast::Receiver<Response>,
mut res_rx: mpsc::Receiver<Response>,
) where
T: AsyncWrite + Unpin,
{
@@ -294,43 +332,81 @@ async fn handle_conn_outgoing<T>(
// all responses before we know the tenant
if let Ok(tenant) = tenant_rx.await {
debug!("Associated tenant {} with conn {}", tenant, conn_id);
state.lock().await.tenant = Some(tenant.clone());
- loop {
- match res_rx.recv().await {
- // Forward along responses that are for our connection
- Ok(res) if res.tenant == tenant => {
- debug!(
- "Conn {} being sent response of type{} {}",
- conn_id,
- if res.payload.len() > 1 { "s" } else { "" },
- res.to_payload_type_string(),
- );
- // If a new process was started, we want to capture the id and
- // associate it with the connection
- let ids = res.payload.iter().filter_map(|x| match x {
- ResponseData::ProcStart { id } => Some(*id),
- _ => None,
- });
- for id in ids {
- debug!("Tracking proc {} for conn {}", id, conn_id);
- state.lock().await.processes.push(id);
- }
- if let Err(x) = writer.send(res).await {
- error!("Failed to send response through unix connection: {}", x);
- break;
- }
- }
- // Skip responses that are not for our connection
- Ok(_) => {}
- Err(x) => {
- error!(
- "Conn {} failed to receive broadcast response: {}",
- conn_id, x
- );
- break;
- }
- }
- }
while let Some(res) = res_rx.recv().await {
debug!(
"Conn {} being sent response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string(),
);
// If a new process was started, we want to capture the id and
// associate it with the connection
let ids = res.payload.iter().filter_map(|x| match x {
ResponseData::ProcStart { id } => Some(*id),
_ => None,
});
for id in ids {
debug!("Tracking proc {} for conn {}", id, conn_id);
state.lock().await.processes.push(id);
}
if let Err(x) = writer.send(res).await {
error!("Failed to send response through unix connection: {}", x);
break;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn wait_should_return_ok_when_all_inner_tasks_complete() {
todo!();
}
#[test]
fn wait_should_return_error_when_server_aborted() {
todo!();
}
#[test]
fn abort_should_abort_inner_tasks_and_all_connections() {
todo!();
}
#[test]
fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
todo!();
}
#[test]
fn server_shutdown_should_abort_all_connections() {
todo!();
}
#[test]
fn server_should_forward_connection_requests_to_session() {
todo!();
}
#[test]
fn server_should_forward_session_responses_to_connection_with_matching_tenant() {
todo!();
}
#[test]
fn connection_abort_should_abort_inner_tasks() {
todo!();
}
#[test]
fn connection_abort_should_send_process_kill_requests_through_session() {
todo!();
}
}
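The tenant-routing loop in broadcast_task condenses to: find the first connection with a matching tenant and forward, relying on the uniqueness assumption noted above. Hypothetical sketch (FakeConn and route are illustrative):

```rust
use std::collections::HashMap;
use tokio::sync::mpsc;

struct FakeConn {
    tenant: String,
    tx: mpsc::Sender<String>,
}

// Deliver a response to the first connection whose tenant matches;
// tenants are assumed unique, mirroring the NOTE in broadcast_task.
async fn route(conns: &mut HashMap<usize, FakeConn>, tenant: &str, msg: String) {
    if let Some(conn) = conns.values_mut().find(|c| c.tenant == tenant) {
        let _ = conn.tx.send(msg).await;
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(1);
    let mut conns = HashMap::new();
    conns.insert(0, FakeConn { tenant: "alpha".into(), tx });
    route(&mut conns, "alpha", "response".into()).await;
    assert_eq!(rx.recv().await.as_deref(), Some("response"));
}
```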
