Add tests for Transport::write_all and Transport::read_exact

pull/146/head
Chip Senkbeil 2 years ago
parent df9405dcf4
commit c7b920e251
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -3,8 +3,8 @@ use std::io;
// mod router;
/* mod framed;
pub use framed::*; */
mod framed;
pub use framed::*;
mod inmemory;
pub use inmemory::*;
@ -58,13 +58,13 @@ pub trait Transport: Reconnectable {
/// Waits for the transport to be readable to follow up with `try_read`
async fn readable(&self) -> io::Result<()> {
let _ = self.ready(Interest::READABLE).await?;
self.ready(Interest::READABLE).await?;
Ok(())
}
/// Waits for the transport to be writeable to follow up with `try_write`
async fn writeable(&self) -> io::Result<()> {
let _ = self.ready(Interest::WRITABLE).await?;
self.ready(Interest::WRITABLE).await?;
Ok(())
}
@ -115,12 +115,12 @@ pub trait Transport: Reconnectable {
match self.try_write(&buf[i..]) {
// If we get 0 bytes written, this usually means that the underlying writer
// has closed, so we will return a broken pipe error to reflect that
// has closed, so we will return a write zero error to reflect that
//
// NOTE: `try_write` can also return 0 if the buf len is zero, but because we check
// that our index is < len, the situation where we call try_write with a buf
// of len 0 will never happen
Ok(0) => return Err(io::Error::from(io::ErrorKind::BrokenPipe)),
Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
Ok(n) => i += n,
@ -135,3 +135,217 @@ pub trait Transport: Reconnectable {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
struct TestTransport {
f_try_read: Box<dyn Fn(&mut [u8]) -> io::Result<usize> + Send + Sync>,
f_try_write: Box<dyn Fn(&[u8]) -> io::Result<usize> + Send + Sync>,
f_ready: Box<dyn Fn(Interest) -> io::Result<Ready> + Send + Sync>,
}
impl Default for TestTransport {
fn default() -> Self {
Self {
f_try_read: Box::new(|_| unimplemented!()),
f_try_write: Box::new(|_| unimplemented!()),
f_ready: Box::new(|_| unimplemented!()),
}
}
}
#[async_trait]
impl Reconnectable for TestTransport {
async fn reconnect(&mut self) -> io::Result<()> {
unimplemented!();
}
}
#[async_trait]
impl Transport for TestTransport {
fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
(self.f_try_read)(buf)
}
fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
(self.f_try_write)(buf)
}
async fn ready(&self, interest: Interest) -> io::Result<Ready> {
(self.f_ready)(interest)
}
}
#[tokio::test]
async fn read_exact_should_fail_if_try_read_encounters_error_other_than_would_block() {
let transport = TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 1];
assert_eq!(
transport.read_exact(&mut buf).await.unwrap_err().kind(),
io::ErrorKind::NotConnected
);
}
#[tokio::test]
async fn read_exact_should_fail_if_try_read_returns_0_before_necessary_bytes_read() {
let transport = TestTransport {
f_try_read: Box::new(|_| Ok(0)),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 1];
assert_eq!(
transport.read_exact(&mut buf).await.unwrap_err().kind(),
io::ErrorKind::UnexpectedEof
);
}
#[tokio::test]
async fn read_exact_should_continue_to_call_try_read_until_buffer_is_filled() {
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
buf[0] = b'a' + CNT;
CNT += 1;
}
Ok(1)
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 3];
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
assert_eq!(&buf, b"abc");
}
#[tokio::test]
async fn read_exact_should_continue_to_call_try_read_while_it_returns_would_block() {
// Configure `try_read` to alternate between reading a byte and WouldBlock
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
buf[0] = b'a' + CNT;
CNT += 1;
if CNT % 2 == 1 {
Ok(1)
} else {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
}
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 3];
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
assert_eq!(&buf, b"ace");
}
#[tokio::test]
async fn read_exact_should_return_0_if_given_a_buffer_of_0_len() {
let transport = TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 0];
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 0);
}
#[tokio::test]
async fn write_all_should_fail_if_try_write_encounters_error_other_than_would_block() {
let transport = TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
assert_eq!(
transport.write_all(b"abc").await.unwrap_err().kind(),
io::ErrorKind::NotConnected
);
}
#[tokio::test]
async fn write_all_should_fail_if_try_write_returns_0_before_all_bytes_written() {
let transport = TestTransport {
f_try_write: Box::new(|_| Ok(0)),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
assert_eq!(
transport.write_all(b"abc").await.unwrap_err().kind(),
io::ErrorKind::WriteZero
);
}
#[tokio::test]
async fn write_all_should_continue_to_call_try_write_until_all_bytes_written() {
// Configure `try_write` to alternate between writing a byte and WouldBlock
let transport = TestTransport {
f_try_write: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
assert_eq!(buf[0], b'a' + CNT);
CNT += 1;
Ok(1)
}
}),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
transport.write_all(b"abc").await.unwrap();
}
#[tokio::test]
async fn write_all_should_continue_to_call_try_write_while_it_returns_would_block() {
// Configure `try_write` to alternate between writing a byte and WouldBlock
let transport = TestTransport {
f_try_write: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
if CNT % 2 == 0 {
assert_eq!(buf[0], b'a' + CNT);
CNT += 1;
Ok(1)
} else {
CNT += 1;
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
}
}),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
transport.write_all(b"ace").await.unwrap();
}
#[tokio::test]
async fn write_all_should_return_immediately_if_given_buffer_of_0_len() {
let transport = TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
// No error takes place as we never call try_write
let buf = [0; 0];
transport.write_all(&buf).await.unwrap();
}
}

@ -1,51 +1,156 @@
use super::{Interest, Transport, Ready, Reconnectable};
use super::{Interest, Ready, Reconnectable, Transport};
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use std::io;
mod codec;
pub use codec::*;
/// Represents a [`Transport`] that reads and writes using frames defined by a [`Codec`],
/// which provides the ability to guarantee that data is read and written completely and also
/// follows the format of the given codec such as encryption and authentication of bytes
pub struct FramedTransport<T, C>
where
T: Transport,
C: Codec,
{
/// By default, framed transport's initial capacity will be 64 KiB
const DEFAULT_CAPACITY: usize = 64 * 1024;
/// Represents a wrapper around a [`Transport`] that reads and writes using frames defined by a
/// [`Codec`]
pub struct FramedTransport<T, C> {
inner: T,
codec: C,
incoming: BytesMut,
outgoing: BytesMut,
}
#[async_trait]
impl<T, C> Reconnectable for FramedTransport<T, C>
impl<T, C> FramedTransport<T, C>
where
T: Transport,
C: Codec,
{
async fn reconnect(&mut self) -> io::Result<()> {
Reconnectable::reconnect(&mut self.inner).await
pub fn new(inner: T, codec: C) -> Self {
Self {
inner,
codec,
incoming: BytesMut::with_capacity(DEFAULT_CAPACITY),
outgoing: BytesMut::with_capacity(DEFAULT_CAPACITY),
}
}
}
#[async_trait]
impl<T, C> Transport for FramedTransport<T, C>
where
T: Transport,
C: Codec,
{
/// Tries to read a frame of data into `buf`
fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
todo!();
/// Reads a frame of bytes by using the [`Codec`] tied to this transport. Returns
/// `Ok(Some(frame))` upon reading a frame, or `Ok(None)` if the underlying transport has
/// closed.
///
/// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
/// is not ready to read data or has not received a full frame before waiting.
///
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub fn try_read_frame(&mut self) -> io::Result<Option<Vec<u8>>> {
// Continually read bytes into the incoming queue and then attempt to tease out a frame
let mut buf = [0; DEFAULT_CAPACITY];
loop {
match self.inner.try_read(&mut buf) {
// Getting 0 bytes on read indicates the channel has closed. If we were still
// expecting more bytes for our frame, then this is an error, otherwise if we
// have nothing remaining if our queue then this is an expected end and we
// return None
Ok(0) if self.incoming.is_empty() => return Ok(None),
Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
// Got some additional bytes, which we will add to our queue and then attempt to
// decode into a frame
Ok(n) => {
self.incoming.extend_from_slice(&buf[..n]);
// Attempt to decode a frame, returning the frame if we get one, continuing to
// try to read more bytes if we don't find a frame, and returing any error that
// is encountered from the decode call
match self.codec.decode(&mut self.incoming) {
Ok(Some(frame)) => return Ok(Some(frame)),
Ok(None) => continue,
// TODO: tokio-util's decoder would cause Framed to return Ok(None)
// if the decoder failed as that indicated a corrupt stream.
//
// Should we continue mirroring this behavior?
Err(x) => return Err(x),
}
}
// Any error (including WouldBlock) will get bubbled up
Err(x) => return Err(x),
}
}
}
/// Tries to write `buf` as a frame of data
fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
todo!();
/// Writes an `item` of bytes as a frame by using the [`Codec`] tied to this transport.
///
/// This is accomplished by continually calling the inner transport's `try_write`. If 0 is
/// returned from a call to `try_write`, this will fail with [`ErrorKind::WriteZero`].
///
/// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
/// is not ready to write data or has not written the entire frame before waiting.
///
/// [`ErrorKind::WriteZero`]: io::ErrorKind::WriteZero
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub fn try_write_frame(&mut self, item: &[u8]) -> io::Result<()> {
// Queue up the item as a new frame of bytes
self.codec.encode(item, &mut self.outgoing)?;
// Attempt to write everything in our queue
self.try_flush()
}
async fn ready(&self, interest: Interest) -> io::Result<Ready> {
todo!();
/// Attempts to flush any remaining bytes in the outgoing queue.
///
/// This is accomplished by continually calling the inner transport's `try_write`. If 0 is
/// returned from a call to `try_write`, this will fail with [`ErrorKind::WriteZero`].
///
/// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
/// is not ready to write data.
///
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub fn try_flush(&mut self) -> io::Result<()> {
// Continue to send from the outgoing buffer until we either finish or fail
while !self.outgoing.is_empty() {
match self.inner.try_write(self.outgoing.as_ref()) {
// Getting 0 bytes on write indicates the channel has closed
Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
// Successful write will advance the outgoing buffer
Ok(n) => self.outgoing.advance(n),
// Any error (including WouldBlock) will get bubbled up
Err(x) => return Err(x),
}
}
Ok(())
}
/// Waits for the transport to be ready based on the given interest, returning the ready status
pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
Transport::ready(&self.inner, interest).await
}
/// Waits for the transport to be readable to follow up with `try_read`
pub async fn readable(&self) -> io::Result<()> {
let _ = self.ready(Interest::READABLE).await?;
Ok(())
}
/// Waits for the transport to be writeable to follow up with `try_write`
pub async fn writeable(&self) -> io::Result<()> {
let _ = self.ready(Interest::WRITABLE).await?;
Ok(())
}
}
#[async_trait]
impl<T, C> Reconnectable for FramedTransport<T, C>
where
T: Transport + Send,
C: Codec + Send,
{
async fn reconnect(&mut self) -> io::Result<()> {
Reconnectable::reconnect(&mut self.inner).await
}
}
@ -66,3 +171,53 @@ impl FramedTransport<super::InmemoryTransport, PlainCodec> {
(a, b)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn try_read_frame_should_return_would_block_if_fails_to_read_frame_before_blocking() {
todo!();
}
#[test]
fn try_read_frame_should_return_error_if_encountered_error_with_reading_bytes() {
todo!();
}
#[test]
fn try_read_frame_should_return_none_if_encountered_error_during_decode() {
todo!();
}
#[test]
fn try_read_frame_should_return_next_available_frame() {
todo!();
}
#[test]
fn try_write_frame_should_return_would_block_if_fails_to_write_frame_before_blocking() {
todo!();
}
#[test]
fn try_write_frame_should_return_error_if_encountered_error_with_writing_bytes() {
todo!();
}
#[test]
fn try_write_frame_should_return_error_if_encountered_error_during_encode() {
todo!();
}
#[test]
fn try_write_frame_should_write_entire_frame_if_possible() {
todo!();
}
#[test]
fn try_write_frame_should_write_any_prior_queued_bytes_before_writing_next_frame() {
todo!();
}
}

@ -7,8 +7,14 @@ pub use plain::PlainCodec;
mod xchacha20poly1305;
pub use xchacha20poly1305::XChaCha20Poly1305Codec;
/// Represents abstraction of a codec that implements specific encoder and decoder for distant
/// Represents abstraction that implements specific encoder and decoder logic to transform an
/// arbitrary collection of bytes. This can be used to encrypt and authenticate bytes sent and
/// received by transports.
pub trait Codec: Clone {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()>;
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>>;
/// Encodes some `item` as a frame, placing the result at the end of `dst`
fn encode(&self, item: &[u8], dst: &mut BytesMut) -> io::Result<()>;
/// Attempts to decode a frame from `src`, returning `Some(Frame)` if a frame was found
/// or `None` if the current `src` does not contain a frame
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>>;
}

@ -16,7 +16,7 @@ impl PlainCodec {
}
impl Codec for PlainCodec {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
fn encode(&self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
// Validate that we can fit the message plus nonce +
if item.is_empty() {
return Err(io::Error::new(
@ -34,7 +34,7 @@ impl Codec for PlainCodec {
Ok(())
}
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
// First, check if we have more data than just our frame's message length
if src.len() <= LEN_SIZE {
return Ok(None);
@ -73,7 +73,7 @@ mod tests {
#[test]
fn encode_should_fail_when_item_is_zero_bytes() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
let mut buf = BytesMut::new();
let result = codec.encode(&[], &mut buf);
@ -86,7 +86,7 @@ mod tests {
#[test]
fn encode_should_build_a_frame_containing_a_length_and_item() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
let mut buf = BytesMut::new();
codec
@ -100,7 +100,7 @@ mod tests {
#[test]
fn decode_should_return_none_if_data_smaller_than_or_equal_to_item_length_field() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_bytes(0, LEN_SIZE);
@ -115,7 +115,7 @@ mod tests {
#[test]
fn decode_should_return_none_if_not_enough_data_for_frame() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
@ -130,7 +130,7 @@ mod tests {
#[test]
fn decode_should_fail_if_encoded_item_length_is_zero() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
@ -145,7 +145,7 @@ mod tests {
#[test]
fn decode_should_advance_src_by_frame_size_even_if_item_length_is_zero() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
@ -160,7 +160,7 @@ mod tests {
#[test]
fn decode_should_advance_src_by_frame_size_when_successful() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
@ -175,7 +175,7 @@ mod tests {
#[test]
fn decode_should_return_some_byte_vec_when_successful() {
let mut codec = PlainCodec::new();
let codec = PlainCodec::new();
let mut buf = BytesMut::new();
codec

@ -41,7 +41,7 @@ impl fmt::Debug for XChaCha20Poly1305Codec {
}
impl Codec for XChaCha20Poly1305Codec {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
fn encode(&self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
// Validate that we can fit the message plus nonce +
if item.is_empty() {
return Err(io::Error::new(
@ -70,7 +70,7 @@ impl Codec for XChaCha20Poly1305Codec {
Ok(())
}
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
// First, check if we have more data than just our frame's message length
if src.len() <= LEN_SIZE {
return Ok(None);
@ -120,7 +120,7 @@ mod tests {
#[test]
fn encode_should_fail_when_item_is_zero_bytes() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
let result = codec.encode(&[], &mut buf);
@ -134,7 +134,7 @@ mod tests {
#[test]
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
codec
@ -149,7 +149,7 @@ mod tests {
#[test]
fn decode_should_return_none_if_data_smaller_than_or_equal_to_frame_length_field() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
buf.put_bytes(0, LEN_SIZE);
@ -165,7 +165,7 @@ mod tests {
#[test]
fn decode_should_return_none_if_not_enough_data_for_frame() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
buf.put_u64(0);
@ -181,7 +181,7 @@ mod tests {
#[test]
fn decode_should_fail_if_encoded_frame_length_is_smaller_than_nonce_plus_data() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
// NONCE_SIZE + 1 is minimum for frame length
let mut buf = BytesMut::new();
@ -198,7 +198,7 @@ mod tests {
#[test]
fn decode_should_advance_src_by_frame_size_even_if_frame_length_is_too_small() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
let mut buf = BytesMut::new();
@ -216,7 +216,7 @@ mod tests {
#[test]
fn decode_should_advance_src_by_frame_size_even_if_decryption_fails() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
let mut buf = BytesMut::new();
@ -235,7 +235,7 @@ mod tests {
#[test]
fn decode_should_advance_src_by_frame_size_when_successful() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
@ -251,7 +251,7 @@ mod tests {
#[test]
fn decode_should_return_some_byte_vec_when_successful() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
codec

Loading…
Cancel
Save