Fix InmemoryStreamWriteHalf AsyncWrite to properly yield Pending when the channel is full

pull/96/head
Chip Senkbeil 3 years ago
parent f1e0f82df5
commit 7c27c24636
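The gist of the change: a write against a full channel used to report Ok(0), which callers treat as a closed stream, and now reports Poll::Pending so the write is retried once capacity frees up. For context, a minimal self-contained sketch of the resulting poll_write (only the match arms come from the diff below; the newtype field and the poll_flush/poll_shutdown stubs are assumed for illustration):

use std::{
    pin::Pin,
    task::{Context, Poll},
};
use tokio::{
    io::{self, AsyncWrite},
    sync::mpsc::{self, error::TrySendError},
};

struct InmemoryStreamWriteHalf(mpsc::Sender<Vec<u8>>);

impl AsyncWrite for InmemoryStreamWriteHalf {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.0.try_send(buf.to_vec()) {
            // Message queued: the full buffer counts as written
            Ok(_) => Poll::Ready(Ok(buf.len())),
            // Channel full: ask the caller to retry later rather than
            // reporting 0 bytes written as before
            Err(TrySendError::Full(_)) => Poll::Pending,
            // Receiver dropped: 0 bytes written signals closure
            Err(TrySendError::Closed(_)) => Poll::Ready(Ok(0)),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}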

@@ -5,7 +5,7 @@ use std::{
};
use tokio::{
io::{self, AsyncRead, AsyncWrite, ReadBuf},
sync::mpsc,
sync::mpsc::{self, error::TrySendError},
};
/// Represents a data stream comprised of two inmemory channels
@@ -147,7 +147,8 @@ impl AsyncWrite for InmemoryStreamWriteHalf {
) -> Poll<io::Result<usize>> {
match self.0.try_send(buf.to_vec()) {
Ok(_) => Poll::Ready(Ok(buf.len())),
Err(_) => Poll::Ready(Ok(0)),
Err(TrySendError::Full(_)) => Poll::Pending,
Err(TrySendError::Closed(_)) => Poll::Ready(Ok(0)),
}
}
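One caveat, noted as an observation rather than part of the commit: AsyncWrite's contract expects that a poll_write returning Poll::Pending arranges for the task's waker to be invoked once progress is possible, but try_send never sees the Context at all. Manual polling, as the new tests below do via futures::poll!, re-polls explicitly; a plain write(...).await against a full channel could otherwise stall. A defensive variant of the Full arm from the sketch above (assuming the context parameter is named cx) would self-wake as a crude retry signal:

Err(TrySendError::Full(_)) => {
    // No waker was handed to the channel, so schedule ourselves to be
    // polled again; this degrades to a busy retry but cannot stall
    cx.waker().wake_by_ref();
    Poll::Pending
}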
@@ -309,4 +310,233 @@ mod tests {
assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
}
#[tokio::test]
async fn read_half_should_fail_if_buf_has_no_space_remaining() {
let (_tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 0];
match t_read.read(&mut buf).await {
Err(x) if x.kind() == io::ErrorKind::Other => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_all_overflow_from_last_read_if_it_all_fits() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
tx.send(vec![1, 2, 3]).await.expect("Failed to send");
let mut buf = [0u8; 2];
// First, read part of the data (first two bytes)
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]),
x => panic!("Unexpected result: {:?}", x),
}
// Second, we send more data because the last message was placed in overflow
tx.send(vec![4, 5, 6]).await.expect("Failed to send");
// Third, read remainder of the overflow from first message (third byte)
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[3]),
x => panic!("Unexpected result: {:?}", x),
}
// Fourth, verify that we start to receive the next overflow
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[4, 5]),
x => panic!("Unexpected result: {:?}", x),
}
// Fifth, verify that we get the last bit of overflow
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[6]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_some_of_overflow_that_can_fit() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
let mut buf = [0u8; 2];
// First, read part of the data (first two bytes)
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]),
x => panic!("Unexpected result: {:?}", x),
}
// Second, we send more data because the last message was placed in overflow
tx.send(vec![6]).await.expect("Failed to send");
// Third, read next chunk of the overflow from first message (next two bytes)
match t_read.read(&mut buf).await {
Ok(n) if n == 2 => assert_eq!(&buf[..n], &[3, 4]),
x => panic!("Unexpected result: {:?}", x),
}
// Fourth, read last chunk of the overflow from first message (fifth byte)
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[5]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_all_of_inner_channel_when_it_fits() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 5];
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
// First, read all of data that fits exactly
match t_read.read(&mut buf).await {
Ok(n) if n == 5 => assert_eq!(&buf[..n], &[1, 2, 3, 4, 5]),
x => panic!("Unexpected result: {:?}", x),
}
tx.send(vec![6, 7, 8]).await.expect("Failed to send");
// Second, read data that fits within buf
match t_read.read(&mut buf).await {
Ok(n) if n == 3 => assert_eq!(&buf[..n], &[6, 7, 8]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_update_buf_with_some_of_inner_channel_that_can_fit_and_add_rest_to_overflow(
) {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 1];
tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send");
// Attempt a read that places more in overflow
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[1]),
x => panic!("Unexpected result: {:?}", x),
}
// Verify overflow contains the rest
assert_eq!(&t_read.overflow, &[2, 3, 4, 5]);
// Queue up extra data that will not be read until overflow is finished
tx.send(vec![6, 7, 8]).await.expect("Failed to send");
// Read next data point
match t_read.read(&mut buf).await {
Ok(n) if n == 1 => assert_eq!(&buf[..n], &[2]),
x => panic!("Unexpected result: {:?}", x),
}
// Verify overflow contains the rest without having added extra data
assert_eq!(&t_read.overflow, &[3, 4, 5]);
}
#[tokio::test]
async fn read_half_should_yield_pending_if_no_data_available_on_inner_channel() {
let (_tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 1];
// Attempt a read that should yield pending since no data is available
// yet on the inner channel
let f = t_read.read(&mut buf);
tokio::pin!(f);
match futures::poll!(f) {
Poll::Pending => {}
x => panic!("Unexpected poll result: {:?}", x),
}
}
#[tokio::test]
async fn read_half_should_not_update_buf_if_inner_channel_closed() {
let (tx, _rx, stream) = InmemoryStream::make(1);
let (mut t_read, _t_write) = stream.into_split();
let mut buf = [0u8; 1];
// Drop the channel that would be sending data to the transport
drop(tx);
// Attempt a read that should yield ok with no change, which is what should
// happen when nothing is read into buf
match t_read.read(&mut buf).await {
Ok(n) if n == 0 => assert_eq!(&buf, &[0]),
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn write_half_should_return_buf_len_if_can_send_immediately() {
let (_tx, mut rx, stream) = InmemoryStream::make(1);
let (_t_read, mut t_write) = stream.into_split();
// Write that is not waiting should always succeed with full contents
let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
assert_eq!(n, 3, "Unexpected byte count returned");
// Verify we actually had the data sent
let data = rx.try_recv().expect("Failed to recv data");
assert_eq!(data, &[1, 2, 3]);
}
#[tokio::test]
async fn write_half_should_support_eventually_sending_by_retrying_when_not_ready() {
let (_tx, mut rx, stream) = InmemoryStream::make(1);
let (_t_read, mut t_write) = stream.into_split();
// Queue a write already so that we block on the next one
t_write.write(&[1, 2, 3]).await.expect("Failed to write");
// Verify that the next write is pending
let f = t_write.write(&[4, 5]);
tokio::pin!(f);
match futures::poll!(&mut f) {
Poll::Pending => {}
x => panic!("Unexpected poll result: {:?}", x),
}
// Consume first batch of data so future of second can continue
let data = rx.try_recv().expect("Failed to recv data");
assert_eq!(data, &[1, 2, 3]);
// Verify that poll now returns success
match futures::poll!(f) {
Poll::Ready(Ok(n)) if n == 2 => {}
x => panic!("Unexpected poll result: {:?}", x),
}
// Consume second batch of data
let data = rx.try_recv().expect("Failed to recv data");
assert_eq!(data, &[4, 5]);
}
#[tokio::test]
async fn write_half_should_zero_if_inner_channel_closed() {
let (_tx, rx, stream) = InmemoryStream::make(1);
let (_t_read, mut t_write) = stream.into_split();
// Drop receiving end that transport would talk to
drop(rx);
// Channel is dropped, so return 0 to indicate no bytes sent
let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write");
assert_eq!(n, 0, "Unexpected byte count returned");
}
}
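The pending assertions in these tests rely on polling a future exactly once without awaiting it, using tokio::pin! plus the futures::poll! macro. A standalone illustration of that pattern (the future here is hypothetical, not from this crate):

use std::task::Poll;

#[tokio::main]
async fn main() {
    let fut = async { 42 };
    tokio::pin!(fut);
    // poll! advances the pinned future by a single step and returns the
    // raw Poll result instead of suspending the current task
    match futures::poll!(&mut fut) {
        Poll::Ready(v) => assert_eq!(v, 42),
        Poll::Pending => panic!("immediately-ready future cannot be pending"),
    }
}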

@@ -103,8 +103,8 @@ pub(super) async fn process(
args,
detached,
} => proc_run(conn_id, state, reply, cmd, args, detached).await,
RequestData::ProcKill { id } => proc_kill(state, id).await,
RequestData::ProcStdin { id, data } => proc_stdin(state, id, data).await,
RequestData::ProcKill { id } => proc_kill(conn_id, state, id).await,
RequestData::ProcStdin { id, data } => proc_stdin(conn_id, state, id, data).await,
RequestData::ProcList {} => proc_list(state).await,
RequestData::SystemInfo {} => system_info().await,
}
@@ -458,7 +458,7 @@ where
Ok(data) => {
let payload = vec![ResponseData::ProcStdout { id, data }];
if !reply_2(payload).await {
error!("<Conn @ {}> Stdout channel closed", conn_id);
error!("<Conn @ {} | Proc {}> Stdout channel closed", conn_id, id);
break;
}
@@ -470,12 +470,21 @@ where
.await;
}
Err(x) => {
error!("Invalid data read from stdout pipe: {}", x);
error!(
"<Conn @ {} | Proc {}> Invalid data read from stdout pipe: {}",
conn_id, id, x
);
break;
}
},
Ok(_) => break,
Err(_) => break,
Err(x) => {
error!(
"<Conn @ {} | Proc {}> Reading stdout failed: {}",
conn_id, id, x
);
break;
}
}
}
});
@@ -491,7 +500,7 @@ where
Ok(data) => {
let payload = vec![ResponseData::ProcStderr { id, data }];
if !reply_2(payload).await {
error!("<Conn @ {}> Stderr channel closed", conn_id);
error!("<Conn @ {} | Proc {}> Stderr channel closed", conn_id, id);
break;
}
@@ -503,12 +512,21 @@ where
.await;
}
Err(x) => {
error!("Invalid data read from stdout pipe: {}", x);
error!(
"<Conn @ {} | Proc {}> Invalid data read from stdout pipe: {}",
conn_id, id, x
);
break;
}
},
Ok(_) => break,
Err(_) => break,
Err(x) => {
error!(
"<Conn @ {} | Proc {}> Reading stderr failed: {}",
conn_id, id, x
);
break;
}
}
}
});
@@ -520,7 +538,7 @@ where
while let Some(line) = stdin_rx.recv().await {
if let Err(x) = stdin.write_all(line.as_bytes()).await {
error!(
"<Conn @ {}> Failed to send stdin to process {}: {}",
"<Conn @ {} | Proc {}> Failed to send stdin: {}",
conn_id, id, x
);
break;
@@ -536,18 +554,22 @@ where
let wait_task = tokio::spawn(async move {
tokio::select! {
status = child.wait() => {
debug!("<Conn @ {}> Process {} done", conn_id, id);
debug!(
"<Conn @ {} | Proc {}> Completed and waiting on stdout & stderr tasks",
conn_id,
id,
);
// Force the stdin task to abort if it hasn't exited, as there is no
// point in sending any more stdin
stdin_task.abort();
if let Err(x) = stderr_task.await {
error!("<Conn @ {}> Join on stderr task failed: {}", conn_id, x);
error!("<Conn @ {} | Proc {}> Join on stderr task failed: {}", conn_id, id, x);
}
if let Err(x) = stdout_task.await {
error!("<Conn @ {}> Join on stdout task failed: {}", conn_id, x);
error!("<Conn @ {} | Proc {}> Join on stdout task failed: {}", conn_id, id, x);
}
state_2.lock().await.remove_process(conn_id, id);
@@ -559,7 +581,7 @@ where
let payload = vec![ResponseData::ProcDone { id, success, code }];
if !reply_2(payload).await {
error!(
"<Conn @ {}> Failed to send done for process {}!",
"<Conn @ {} | Proc {}> Failed to send done",
conn_id,
id,
);
@@ -569,7 +591,7 @@ where
let payload = vec![ResponseData::from(x)];
if !reply_2(payload).await {
error!(
"<Conn @ {}> Failed to send error for waiting on process {}!",
"<Conn @ {} | Proc {}> Failed to send error for waiting",
conn_id,
id,
);
@@ -579,10 +601,10 @@ where
},
_ = kill_rx => {
debug!("<Conn @ {}> Process {} killed", conn_id, id);
debug!("<Conn @ {} | Proc {}> Killing", conn_id, id);
if let Err(x) = child.kill().await {
error!("<Conn @ {}> Unable to kill process {}: {}", conn_id, id, x);
error!("<Conn @ {} | Proc {}> Unable to kill: {}", conn_id, id, x);
}
// Force stdin task to abort if it hasn't exited as there is no
@@ -590,24 +612,24 @@ where
stdin_task.abort();
if let Err(x) = stderr_task.await {
error!("<Conn @ {}> Join on stderr task failed: {}", conn_id, x);
error!("<Conn @ {} | Proc {}> Join on stderr task failed: {}", conn_id, id, x);
}
if let Err(x) = stdout_task.await {
error!("<Conn @ {}> Join on stdout task failed: {}", conn_id, x);
error!("<Conn @ {} | Proc {}> Join on stdout task failed: {}", conn_id, id, x);
}
// Wait for the child after being killed to ensure that it has been cleaned
// up at the operating system level
if let Err(x) = child.wait().await {
error!("<Conn @ {}> Failed to wait on killed process {}: {}", conn_id, id, x);
error!("<Conn @ {} | Proc {}> Failed to wait after killed: {}", conn_id, id, x);
}
state_2.lock().await.remove_process(conn_id, id);
let payload = vec![ResponseData::ProcDone { id, success: false, code: None }];
if !reply_2(payload).await {
error!("<Conn @ {}> Failed to send done for process {}!", conn_id, id);
error!("<Conn @ {} | Proc {}> Failed to send done", conn_id, id);
}
}
}
@@ -625,7 +647,7 @@ where
})
}
async fn proc_kill(state: HState, id: usize) -> Result<Outgoing, ServerError> {
async fn proc_kill(conn_id: usize, state: HState, id: usize) -> Result<Outgoing, ServerError> {
if let Some(process) = state.lock().await.processes.remove(&id) {
if process.kill() {
return Ok(Outgoing::from(ResponseData::Ok));
@@ -634,11 +656,19 @@ async fn proc_kill(state: HState, id: usize) -> Result<Outgoing, ServerError> {
Err(ServerError::IoError(io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to send kill signal to process",
format!(
"<Conn @ {} | Proc {}> Unable to send kill signal to process",
conn_id, id
),
)))
}
async fn proc_stdin(state: HState, id: usize, data: String) -> Result<Outgoing, ServerError> {
async fn proc_stdin(
conn_id: usize,
state: HState,
id: usize,
data: String,
) -> Result<Outgoing, ServerError> {
if let Some(process) = state.lock().await.processes.get(&id) {
if process.send_stdin(data).await {
return Ok(Outgoing::from(ResponseData::Ok));
@@ -647,7 +677,10 @@ async fn proc_stdin(state: HState, id: usize, data: String) -> Result<Outgoing,
Err(ServerError::IoError(io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to send stdin to process",
format!(
"<Conn @ {} | Proc {}> Unable to send stdin to process",
conn_id, id,
),
)))
}

@@ -672,7 +672,7 @@ where
Ok(data) => {
let payload = vec![ResponseData::ProcStdout { id, data }];
if !reply_2(payload).await {
error!("<Ssh: Proc {}> Stdout channel closed", id);
error!("<Ssh | Proc {}> Stdout channel closed", id);
break;
}
@@ -685,7 +685,7 @@ where
}
Err(x) => {
error!(
"<Ssh: Proc {}> Invalid data read from stdout pipe: {}",
"<Ssh | Proc {}> Invalid data read from stdout pipe: {}",
id, x
);
break;
@@ -698,7 +698,10 @@ where
tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS))
.await;
}
Err(_) => break,
Err(x) => {
error!("<Ssh | Proc {}> Stdout unexpectedly closed: {}", id, x);
break;
}
}
}
});
@@ -713,7 +716,7 @@ where
Ok(data) => {
let payload = vec![ResponseData::ProcStderr { id, data }];
if !reply_2(payload).await {
error!("<Ssh: Proc {}> Stderr channel closed", id);
error!("<Ssh | Proc {}> Stderr channel closed", id);
break;
}
@@ -726,7 +729,7 @@ where
}
Err(x) => {
error!(
"<Ssh: Proc {}> Invalid data read from stderr pipe: {}",
"<Ssh | Proc {}> Invalid data read from stderr pipe: {}",
id, x
);
break;
@@ -739,7 +742,10 @@ where
tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS))
.await;
}
Err(_) => break,
Err(x) => {
error!("<Ssh | Proc {}> Stderr unexpectedly closed: {}", id, x);
break;
}
}
}
});
@@ -747,7 +753,7 @@ where
let stdin_task = tokio::spawn(async move {
while let Some(line) = stdin_rx.recv().await {
if let Err(x) = stdin.write_all(line.as_bytes()) {
error!("<Ssh: Proc {}> Failed to send stdin: {}", id, x);
error!("<Ssh | Proc {}> Failed to send stdin: {}", id, x);
break;
}
}
@@ -770,7 +776,7 @@ where
success = status.success();
}
Err(x) => {
error!("<Ssh: Proc {}> Waiting on process failed: {}", id, x);
error!("<Ssh | Proc {}> Waiting on process failed: {}", id, x);
}
}
}
@@ -781,10 +787,10 @@ where
stdin_task.abort();
if should_kill {
debug!("<Ssh: Proc {}> Process killed", id);
debug!("<Ssh | Proc {}> Killing", id);
if let Err(x) = child.kill() {
error!("<Ssh: Proc {}> Unable to kill process: {}", id, x);
error!("<Ssh | Proc {}> Unable to kill process: {}", id, x);
}
// NOTE: At the moment, child.kill does nothing for wezterm_ssh::SshChildProcess;
@@ -801,15 +807,18 @@ where
.await;
}
} else {
debug!("<Ssh: Proc {}> Process done", id);
debug!(
"<Ssh | Proc {}> Completed and waiting on stdout & stderr tasks",
id
);
}
if let Err(x) = stderr_task.await {
error!("<Ssh: Proc {}> Join on stderr task failed: {}", id, x);
error!("<Ssh | Proc {}> Join on stderr task failed: {}", id, x);
}
if let Err(x) = stdout_task.await {
error!("<Ssh: Proc {}> Join on stdout task failed: {}", id, x);
error!("<Ssh | Proc {}> Join on stdout task failed: {}", id, x);
}
state_2.lock().await.processes.remove(&id);
@@ -821,7 +830,7 @@ where
}];
if !reply_2(payload).await {
error!("<Ssh: Proc {}> Failed to send done!", id,);
error!("<Ssh | Proc {}> Failed to send done", id,);
}
});
});
@@ -845,7 +854,7 @@ async fn proc_kill(
Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to send kill signal to process",
format!("<Ssh | Proc {}> Unable to send kill signal to process", id),
))
}
@@ -863,7 +872,7 @@ async fn proc_stdin(
Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to send stdin to process",
format!("<Ssh | Proc {}> Unable to send stdin to process", id),
))
}

@@ -507,17 +507,20 @@ impl Ssh2Session {
if let Err(x) =
handler::process(wez_session.clone(), Arc::clone(&state), req, tx.clone()).await
{
error!("{}", x);
error!("Ssh session receiver handler failed: {}", x);
}
}
debug!("Ssh receiver task is now closed");
});
tokio::spawn(async move {
while let Some(res) = rx.recv().await {
if t_write.send(res).await.is_err() {
if let Err(x) = t_write.send(res).await {
error!("Ssh session sender failed: {}", x);
break;
}
}
debug!("Ssh sender task is now closed");
});
Ok(session)
