Fix process cleanup happening when only half of a transport has closed

pull/38/head
Chip Senkbeil 3 years ago
parent fc1c262f55
commit b362ff5ab8

@@ -171,10 +171,15 @@ async fn on_new_conn<T>(
     auth_key: Arc<SecretKey>,
     tracker: Option<Arc<Mutex<ConnTracker>>>,
     max_msg_capacity: usize,
-) -> io::Result<(JoinHandle<()>, JoinHandle<()>)>
+) -> io::Result<JoinHandle<()>>
 where
     T: DataStream,
 {
+    // Update our tracker to reflect the new connection
+    if let Some(ct) = tracker.as_ref() {
+        ct.lock().await.increment();
+    }
+
     // Establish a proper connection via a handshake,
     // discarding the connection otherwise
     let transport = Transport::from_handshake(conn, Some(auth_key)).await?;
@@ -185,26 +190,27 @@ where
     let (tx, rx) = mpsc::channel(max_msg_capacity);

     // Spawn a new task that loops to handle requests from the client
-    let req_task = tokio::spawn({
-        let f = request_loop(conn_id, Arc::clone(&state), t_read, tx);
-        let state = Arc::clone(&state);
-        async move {
-            if let Some(ct) = tracker.as_ref() {
-                ct.lock().await.increment();
-            }
-            f.await;
-            state.lock().await.cleanup_connection(conn_id).await;
-            if let Some(ct) = tracker.as_ref() {
-                ct.lock().await.decrement();
-            }
-        }
+    let state_2 = Arc::clone(&state);
+    let req_task = tokio::spawn(async move {
+        request_loop(conn_id, state_2, t_read, tx).await;
     });

     // Spawn a new task that loops to handle responses to the client
     let res_task = tokio::spawn(async move { response_loop(conn_id, t_write, rx).await });

-    Ok((req_task, res_task))
+    // Spawn cleanup task that waits on our req & res tasks to complete
+    let cleanup_task = tokio::spawn(async move {
+        // Wait for both receiving and sending tasks to complete before marking
+        // the connection as complete
+        let _ = tokio::join!(req_task, res_task);
+
+        state.lock().await.cleanup_connection(conn_id).await;
+        if let Some(ct) = tracker.as_ref() {
+            ct.lock().await.decrement();
+        }
+    });
+
+    Ok(cleanup_task)
 }

 /// Repeatedly reads in new requests, processes them, and sends their responses to the
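
This hunk is the heart of the fix: instead of returning the request and response task handles separately (with cleanup tied to the request loop alone), on_new_conn now returns a single cleanup task that waits on both halves. A minimal standalone sketch of that pattern, with stand-in loops in place of the real request_loop/response_loop (assumes tokio 1.x with the "full" feature set; not code from this repo):

use std::time::Duration;
use tokio::task::JoinHandle;

// Stand-in for request_loop: the reading half finishes early
async fn fake_request_loop() {
    tokio::time::sleep(Duration::from_millis(10)).await;
}

// Stand-in for response_loop: the writing half keeps going
async fn fake_response_loop() {
    tokio::time::sleep(Duration::from_millis(50)).await;
}

#[tokio::main]
async fn main() {
    let req_task: JoinHandle<()> = tokio::spawn(fake_request_loop());
    let res_task: JoinHandle<()> = tokio::spawn(fake_response_loop());

    let cleanup_task = tokio::spawn(async move {
        // join! waits on both handles; each yields a Result<(), JoinError>,
        // ignored here because cleanup must run regardless of how they ended
        let _ = tokio::join!(req_task, res_task);
        println!("both halves done; safe to clean up connection state");
    });

    let _ = cleanup_task.await;
}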
@@ -234,7 +240,7 @@ async fn request_loop<T>(
                 }
             }
             Ok(None) => {
-                info!("<Conn @ {}> Closed connection", conn_id);
+                info!("<Conn @ {}> Input from connection closed", conn_id);
                 break;
             }
             Err(x) => {
@@ -243,6 +249,10 @@ async fn request_loop<T>(
             }
         }
     }
+
+    // Properly close off any associated process' stdin given that we can't get new
+    // requests to send more stdin to them
+    state.lock().await.close_stdin_for_connection(conn_id);
 }

 /// Repeatedly sends responses out over the wire
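
The reworded log message reflects the half-close semantics that motivated the commit: receiving Ok(None) only means the peer closed its write side, while the server's write half may still have responses to deliver. A small sketch (plain tokio TcpStream, unrelated to this repo's Transport type) showing a connection that keeps working after one direction closes:

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;

    let server = tokio::spawn(async move {
        let (mut conn, _) = listener.accept().await.unwrap();
        let mut buf = [0u8; 16];
        // read() returning Ok(0) means the peer closed its write half...
        while conn.read(&mut buf).await.unwrap() > 0 {}
        // ...but this side's write half still works
        conn.write_all(b"still here\n").await.unwrap();
    });

    let mut client = TcpStream::connect(addr).await?;
    client.write_all(b"hi\n").await?;
    client.shutdown().await?; // closes only the client's write half

    let mut out = String::new();
    client.read_to_string(&mut out).await?;
    println!("server said: {out:?}");
    server.await.unwrap();
    Ok(())
}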

@@ -30,6 +30,24 @@ impl State {
         self.processes.insert(process.id, process);
     }

+    /// Closes stdin for all processes associated with the connection
+    pub fn close_stdin_for_connection(&mut self, conn_id: usize) {
+        debug!("<Conn @ {:?}> Closing stdin to all processes", conn_id);
+        if let Some(ids) = self.client_processes.get(&conn_id) {
+            for id in ids {
+                if let Some(process) = self.processes.get_mut(&id) {
+                    trace!(
+                        "<Conn @ {:?}> Closing stdin for proc {}",
+                        conn_id,
+                        process.id
+                    );
+                    process.close_stdin();
+                }
+            }
+        }
+    }
+
     /// Cleans up state associated with a particular connection
     pub async fn cleanup_connection(&mut self, conn_id: usize) {
         debug!("<Conn @ {:?}> Cleaning up state", conn_id);
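
close_stdin_for_connection walks every process owned by the connection and closes its stdin so that children waiting for EOF can terminate. The commit doesn't show close_stdin itself; a plausible sketch, assuming (as with tokio::process) that stdin is held as an Option<ChildStdin> and that dropping it delivers EOF to the child. The Proc type and the `cat` child below are illustrative only:

use std::process::Stdio;
use tokio::io::AsyncWriteExt;
use tokio::process::{Child, ChildStdin, Command};

// Hypothetical stand-in for this repo's process-tracking type
struct Proc {
    child: Child,
    stdin: Option<ChildStdin>,
}

impl Proc {
    fn close_stdin(&mut self) {
        // Dropping the ChildStdin closes the pipe; the child sees EOF
        self.stdin.take();
    }
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // `cat` echoes stdin until EOF (assumes a Unix-like system)
    let mut child = Command::new("cat")
        .stdin(Stdio::piped())
        .stdout(Stdio::inherit())
        .spawn()?;
    let stdin = child.stdin.take();
    let mut proc = Proc { child, stdin };

    if let Some(stdin) = proc.stdin.as_mut() {
        stdin.write_all(b"hello\n").await?;
    }
    proc.close_stdin(); // without this, `cat` would never exit
    let status = proc.child.wait().await?;
    println!("child exited: {status}");
    Ok(())
}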

@@ -154,6 +154,7 @@ struct Conn {
     id: usize,
     req_task: JoinHandle<()>,
     res_task: JoinHandle<()>,
+    cleanup_task: JoinHandle<()>,
     res_tx: mpsc::Sender<Response>,
     state: Arc<Mutex<ConnState>>,
 }
@@ -193,24 +194,35 @@ impl Conn {
         debug!("<Conn @ {}> Initializing internal state", id);
         let state = Arc::new(Mutex::new(ConnState::default()));

+        // Mark that we have a new connection
+        if let Some(ct) = ct.as_ref() {
+            ct.lock().await.increment();
+        }
+
         // Spawn task to continually receive responses from the session that
         // may or may not be relevant to the connection, which will filter
         // by tenant and then along any response that matches
         let (res_tx, res_rx) = mpsc::channel::<Response>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
+        let (res_task_tx, res_task_rx) = oneshot::channel();
         let state_2 = Arc::clone(&state);
         let res_task = tokio::spawn(async move {
             handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await;
+            let _ = res_task_tx.send(());
         });

         // Spawn task to continually read requests from connection and forward
         // them along to be sent via the session
         let req_tx = req_tx.clone();
+        let (req_task_tx, req_task_rx) = oneshot::channel();
         let state_2 = Arc::clone(&state);
         let req_task = tokio::spawn(async move {
-            if let Some(ct) = ct.as_ref() {
-                ct.lock().await.increment();
-            }
             handle_conn_incoming(id, state_2, t_read, tenant_tx, req_tx).await;
+            let _ = req_task_tx.send(());
         });

+        let cleanup_task = tokio::spawn(async move {
+            let _ = tokio::join!(req_task_rx, res_task_rx);
+            if let Some(ct) = ct.as_ref() {
+                ct.lock().await.decrement();
+            }
@@ -221,6 +233,7 @@ impl Conn {
             id,
             req_task,
             res_task,
+            cleanup_task,
             res_tx,
             state,
         })
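
On this side the cleanup task waits on oneshot receivers rather than joining the JoinHandles directly, since the handles themselves must stay stored in Conn so abort() can cancel the tasks later. A standalone sketch of that completion-signal pattern (names are illustrative, not from this repo):

use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let (done_tx, done_rx) = oneshot::channel::<()>();

    // The worker signals completion through the oneshot; its JoinHandle
    // can therefore be stored elsewhere (e.g. for abort()) instead of
    // being consumed by the task that waits on it
    let worker = tokio::spawn(async move {
        // ... real work would go here ...
        let _ = done_tx.send(());
    });

    let cleanup = tokio::spawn(async move {
        // Ok(()) if the worker sent; Err if it was aborted or panicked
        // before sending. Either way it is finished, so cleanup proceeds.
        let _ = done_rx.await;
        println!("worker finished or was aborted; cleaning up");
    });

    let _ = tokio::join!(worker, cleanup);
}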
@@ -233,6 +246,8 @@ impl Conn {
     /// Aborts the connection from the server side
     pub fn abort(&self) {
+        // NOTE: We don't abort the cleanup task as that needs to actually happen
+        // and will even if these tasks are aborted
         self.req_task.abort();
         self.res_task.abort();
     }
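
The NOTE also holds on the server side, where the cleanup task joins the JoinHandles directly: awaiting an aborted task's handle still resolves, with a JoinError marked as cancelled, so an abort never blocks cleanup. A quick standalone demonstration:

use std::time::Duration;

#[tokio::main]
async fn main() {
    let task = tokio::spawn(async {
        tokio::time::sleep(Duration::from_secs(60)).await;
    });
    task.abort();
    // Awaiting an aborted task's handle still completes, yielding a
    // JoinError flagged as cancelled, so join!-based cleanup proceeds
    let (res,) = tokio::join!(task);
    assert!(res.unwrap_err().is_cancelled());
    println!("aborted task observed; cleanup can proceed");
}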

@@ -122,8 +122,6 @@ async fn start<T>(
 where
     T: DataStream + 'static,
 {
-    // TODO: Because lsp is being handled in a separate action, we should fail if we get
-    // a session type of lsp for a regular action
     match (cmd.interactive, cmd.operation) {
         // ProcRun request is specially handled and we ignore interactive as
         // the stdin will be used for sending ProcStdin to remote process

@@ -53,7 +53,6 @@ pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
 async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
     let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
     let socket_addrs = cmd.port.make_socket_addrs(addr);
-    let shutdown_after = cmd.to_shutdown_after_duration();

     // If specified, change the current working directory of this program
