Refactor codec, add Frame struct

pull/146/head
Chip Senkbeil 2 years ago
parent 9e8306c634
commit dfc5c08cfd
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

35
Cargo.lock generated

@ -2,6 +2,12 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aead"
version = "0.5.0"
@ -505,6 +511,15 @@ dependencies = [
"libc",
]
[[package]]
name = "crc32fast"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.4"
@ -795,6 +810,7 @@ dependencies = [
"bytes",
"chacha20poly1305",
"derive_more",
"flate2",
"hex",
"hkdf",
"log",
@ -1026,6 +1042,16 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "flexi_logger"
version = "0.23.0"
@ -1647,6 +1673,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34"
dependencies = [
"adler",
]
[[package]]
name = "mio"
version = "0.8.3"

@ -16,6 +16,7 @@ async-trait = "0.1.57"
bytes = "1.2.1"
chacha20poly1305 = "0.10.0"
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
flate2 = "1.0.24"
hex = "0.4.3"
hkdf = "0.12.3"
log = "0.4.17"

@ -6,11 +6,15 @@ use std::io;
mod codec;
pub use codec::*;
mod frame;
pub use frame::*;
/// By default, framed transport's initial capacity (and max single-read) will be 8 KiB
const DEFAULT_CAPACITY: usize = 8 * 1024;
/// Represents a wrapper around a [`Transport`] that reads and writes using frames defined by a
/// [`Codec`]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FramedTransport<T, C> {
inner: T,
codec: C,
@ -41,7 +45,7 @@ where
/// is not ready to read data or has not received a full frame before waiting.
///
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub fn try_read_frame(&mut self) -> io::Result<Option<Vec<u8>>> {
pub fn try_read_frame(&mut self) -> io::Result<Option<OwnedFrame>> {
// Continually read bytes into the incoming queue and then attempt to tease out a frame
let mut buf = [0; DEFAULT_CAPACITY];
@ -59,19 +63,16 @@ where
Ok(n) => {
self.incoming.extend_from_slice(&buf[..n]);
// Attempt to decode a frame, returning the frame if we get one, continuing to
// try to read more bytes if we don't find a frame, and returing any error that
// is encountered from the decode call
match self.codec.decode(&mut self.incoming) {
Ok(Some(frame)) => return Ok(Some(frame)),
// Attempt to read a frame, returning the decoded frame if we get one,
// continuing to try to read more bytes if we don't find a frame, and returning
// any error that is encountered from reading frames or failing to decode
let frame = match Frame::read(&mut self.incoming) {
Ok(Some(frame)) => frame,
Ok(None) => continue,
// TODO: tokio-util's decoder would cause Framed to return Ok(None)
// if the decoder failed as that indicated a corrupt stream.
//
// Should we continue mirroring this behavior?
Err(x) => return Err(x),
}
};
return Ok(Some(self.codec.decode(frame)?.into_owned()));
}
// Any error (including WouldBlock) will get bubbled up
@ -90,9 +91,10 @@ where
///
/// [`ErrorKind::WriteZero`]: io::ErrorKind::WriteZero
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub fn try_write_frame(&mut self, item: &[u8]) -> io::Result<()> {
// Queue up the item as a new frame of bytes
self.codec.encode(item, &mut self.outgoing)?;
pub fn try_write_frame<'a>(&mut self, frame: impl Into<Frame<'a>>) -> io::Result<()> {
// Encode the frame and store it in our outgoing queue
let frame = self.codec.encode(frame.into())?;
frame.write(&mut self.outgoing)?;
// Attempt to write everything in our queue
self.try_flush()
@ -178,47 +180,76 @@ mod tests {
use crate::TestTransport;
use bytes::BufMut;
/// Test codec makes a frame be {len}{bytes}, where len has a max size of 255
/// Codec that always succeeds without altering the frame
#[derive(Clone)]
struct TestCodec;
struct OkCodec;
impl Codec for TestCodec {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
dst.put_u8(item.len() as u8);
dst.extend_from_slice(item);
Ok(())
impl Codec for OkCodec {
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
Ok(frame)
}
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
if src.is_empty() {
return Ok(None);
}
let len = src[0] as usize;
if src.len() - 1 < len {
return Ok(None);
}
let frame = src.split_to(len + 1);
let frame = frame[1..].to_vec();
Ok(Some(frame))
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
Ok(frame)
}
}
/// Codec that always fails
#[derive(Clone)]
struct ErrorCodec;
struct ErrCodec;
impl Codec for ErrorCodec {
fn encode(&mut self, _item: &[u8], _dst: &mut BytesMut) -> io::Result<()> {
impl Codec for ErrCodec {
fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
Err(io::Error::from(io::ErrorKind::Other))
}
fn decode(&mut self, _src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
Err(io::Error::from(io::ErrorKind::Other))
}
}
/// Simulate calls to try_read by feeding back `data` in `step` increments, triggering a block
/// if `block_on` returns true where `block_on` is provided a counter value that is incremented
/// every time the simulated `try_read` function is called
///
/// NOTE: This will inject the frame len in front of the provided data to properly simulate
/// receiving a frame of data
fn simulate_try_read(
frames: Vec<Frame>,
step: usize,
block_on: impl Fn(usize) -> bool + Send + Sync + 'static,
) -> Box<dyn Fn(&mut [u8]) -> io::Result<usize> + Send + Sync> {
use std::sync::atomic::{AtomicUsize, Ordering};
// Stuff all of our frames into a single byte collection
let data = {
let mut buf = BytesMut::new();
for frame in frames {
frame.write(&mut buf).unwrap();
}
buf.to_vec()
};
let idx = AtomicUsize::new(0);
let cnt = AtomicUsize::new(0);
Box::new(move |buf| {
if block_on(cnt.fetch_add(1, Ordering::Relaxed)) {
return Err(io::Error::from(io::ErrorKind::WouldBlock));
}
let start = idx.fetch_add(step, Ordering::Relaxed);
let end = start + step;
let end = if end > data.len() { data.len() } else { end };
let len = if start > end { 0 } else { end - start };
buf[..len].copy_from_slice(&data[start..end]);
Ok(len)
})
}
#[test]
fn try_read_frame_should_return_would_block_if_fails_to_read_frame_before_blocking() {
// Should fail if immediately blocks
@ -228,7 +259,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -238,23 +269,11 @@ mod tests {
// Should fail if not read enough bytes before blocking
let mut transport = FramedTransport::new(
TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
CNT += 1;
if CNT == 2 {
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
buf[0] = CNT;
Ok(1)
}
}
}),
f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |cnt| cnt == 1),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -270,7 +289,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -282,14 +301,11 @@ mod tests {
fn try_read_frame_should_return_error_if_encountered_error_during_decode() {
let mut transport = FramedTransport::new(
TestTransport {
f_try_read: Box::new(|buf| {
buf[0] = b'a';
Ok(1)
}),
f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |_| false),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
ErrorCodec,
ErrCodec,
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -301,7 +317,7 @@ mod tests {
fn try_read_frame_should_return_next_available_frame() {
let data = {
let mut data = BytesMut::new();
TestCodec.encode(b"hello world", &mut data).unwrap();
Frame::new(b"hello world").write(&mut data).unwrap();
data.freeze()
};
@ -314,51 +330,35 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
}
#[test]
fn try_read_frame_should_keep_reading_until_a_frame_is_found() {
const STEP_SIZE: usize = 7;
let data_1 = {
let mut data = BytesMut::new();
TestCodec.encode(b"hello world", &mut data).unwrap();
data.freeze()
};
let data_2 = {
let mut data = BytesMut::new();
TestCodec.encode(b"test hello", &mut data).unwrap();
data.freeze()
};
let data = [data_1, data_2].concat();
const STEP_SIZE: usize = Frame::HEADER_SIZE + 7;
let mut transport = FramedTransport::new(
TestTransport {
f_try_read: Box::new(move |buf| {
static mut IDX: usize = 0;
unsafe {
let len: usize = IDX + STEP_SIZE;
let len = if len > data.len() { data.len() } else { len };
buf[..STEP_SIZE].copy_from_slice(&data[IDX..len]);
IDX += STEP_SIZE;
Ok(STEP_SIZE)
}
}),
f_try_read: simulate_try_read(
vec![Frame::new(b"hello world"), Frame::new(b"test hello")],
STEP_SIZE,
|_| false,
),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
// Should have leftover bytes from next frame; for our test encoder
// we have a single byte length (10 for "test hello") and the first character
assert_eq!(transport.incoming.to_vec(), [10, b't']);
// Should have leftover bytes from next frame
// where len = 10, "tes"
assert_eq!(
transport.incoming.to_vec(),
[0, 0, 0, 0, 0, 0, 0, 10, b't', b'e', b's']
);
}
#[test]
@ -369,7 +369,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
// First call will only write part of the frame and then return WouldBlock
@ -390,7 +390,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
assert_eq!(
transport
@ -409,7 +409,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
ErrorCodec,
ErrCodec,
);
assert_eq!(
transport
@ -433,7 +433,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
transport.try_write_frame(b"hello world").unwrap();
@ -441,13 +441,13 @@ mod tests {
// Transmitted data should be encoded using the framed transport's codec
assert_eq!(
rx.try_recv().unwrap(),
[&[11], b"hello world".as_slice()].concat()
[11u64.to_be_bytes().as_slice(), b"hello world".as_slice()].concat()
);
}
#[test]
fn try_write_frame_should_write_any_prior_queued_bytes_before_writing_next_frame() {
const STEP_SIZE: usize = 5;
const STEP_SIZE: usize = Frame::HEADER_SIZE + 5;
let (tx, rx) = std::sync::mpsc::sync_channel(10);
let mut transport = FramedTransport::new(
TestTransport {
@ -467,7 +467,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
// First call will only write part of the frame and then return WouldBlock
@ -480,7 +480,10 @@ mod tests {
);
// Transmitted data should be encoded using the framed transport's codec
assert_eq!(rx.try_recv().unwrap(), [&[11], b"hell".as_slice()].concat());
assert_eq!(
rx.try_recv().unwrap(),
[11u64.to_be_bytes().as_slice(), b"hello".as_slice()].concat()
);
assert_eq!(
rx.try_recv().unwrap_err(),
std::sync::mpsc::TryRecvError::Empty
@ -488,12 +491,11 @@ mod tests {
// Next call will keep writing successfully until done
transport.try_write_frame(b"test").unwrap();
assert_eq!(rx.try_recv().unwrap(), b"o wor");
assert_eq!(
rx.try_recv().unwrap(),
[b"ld".as_slice(), &[4], b"te".as_slice()].concat()
[b' ', b'w', b'o', b'r', b'l', b'd', 0, 0, 0, 0, 0, 0, 0]
);
assert_eq!(rx.try_recv().unwrap(), b"st");
assert_eq!(rx.try_recv().unwrap(), [4, b't', b'e', b's', b't']);
assert_eq!(
rx.try_recv().unwrap_err(),
std::sync::mpsc::TryRecvError::Empty
@ -508,7 +510,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
// Set our outgoing buffer to flush
@ -529,7 +531,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
// Set our outgoing buffer to flush
@ -550,7 +552,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
// Perform flush and verify nothing happens
@ -571,7 +573,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
TestCodec,
OkCodec,
);
// Set our outgoing buffer to flush

@ -1,20 +1,19 @@
use bytes::BytesMut;
use super::Frame;
use std::io;
mod plain;
pub use plain::PlainCodec;
mod xchacha20poly1305;
pub use xchacha20poly1305::XChaCha20Poly1305Codec;
pub use plain::*;
pub use xchacha20poly1305::*;
/// Represents abstraction that implements specific encoder and decoder logic to transform an
/// arbitrary collection of bytes. This can be used to encrypt and authenticate bytes sent and
/// received by transports.
pub trait Codec: Clone {
/// Encodes some `item` as a frame, placing the result at the end of `dst`
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()>;
/// Encodes a frame's item
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>>;
/// Attempts to decode a frame from `src`, returning `Some(Frame)` if a frame was found
/// or `None` if the current `src` does not contain a frame
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>>;
/// Decodes a frame's item
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>>;
}

@ -1,11 +1,7 @@
use crate::Codec;
use bytes::{Buf, BufMut, BytesMut};
use std::{convert::TryInto, io};
use super::{Codec, Frame};
use std::io;
/// Total bytes to use as the len field denoting a frame's size
const LEN_SIZE: usize = 8;
/// Represents a codec that just ships messages back and forth with no encryption or authentication
/// Represents a codec that does not alter the frame (synonymous with "plain text")
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct PlainCodec;
@ -16,176 +12,11 @@ impl PlainCodec {
}
impl Codec for PlainCodec {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
// Validate that we can fit the message plus nonce +
if item.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Empty item provided",
));
}
dst.reserve(8 + item.len());
// Add data in form of {LEN}{ITEM}
dst.put_u64((item.len()) as u64);
dst.put_slice(item);
Ok(())
}
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
// First, check if we have more data than just our frame's message length
if src.len() <= LEN_SIZE {
return Ok(None);
}
// Second, retrieve total size of our frame's message
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize;
if msg_len == 0 {
// Ensure we advance to remove the frame
src.advance(LEN_SIZE);
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame's msg cannot have length of 0",
));
}
// Third, check if we have all data for our frame; if not, exit early
if src.len() < msg_len + LEN_SIZE {
return Ok(None);
}
// Fourth, get and return our item
let item = src[LEN_SIZE..(LEN_SIZE + msg_len)].to_vec();
// Fifth, advance so frame is no longer kept around
src.advance(LEN_SIZE + msg_len);
Ok(Some(item))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encode_should_fail_when_item_is_zero_bytes() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
let result = codec.encode(&[], &mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn encode_should_build_a_frame_containing_a_length_and_item() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
let len = buf.get_u64() as usize;
assert_eq!(len, 12, "Wrong length encoded");
assert_eq!(buf.as_ref(), b"hello, world");
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
Ok(frame)
}
#[test]
fn decode_should_return_none_if_data_smaller_than_or_equal_to_item_length_field() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_bytes(0, LEN_SIZE);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn decode_should_return_none_if_not_enough_data_for_frame() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn decode_should_fail_if_encoded_item_length_is_zero() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
buf.put_u8(255);
let result = codec.decode(&mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_advance_src_by_frame_size_even_if_item_length_is_zero() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
buf.put_bytes(0, 3);
assert!(
codec.decode(&mut buf).is_err(),
"Decode unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_advance_src_by_frame_size_when_successful() {
let mut codec = PlainCodec::new();
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
buf.put_bytes(0, 3);
assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed");
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_return_some_byte_vec_when_successful() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
let item = codec
.decode(&mut buf)
.expect("Failed to decode")
.expect("Item not properly captured");
assert_eq!(item, b"hello, world");
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
Ok(frame)
}
}

@ -1,17 +1,15 @@
use crate::{Codec, SecretKey, SecretKey32};
use bytes::{Buf, BufMut, BytesMut};
use super::{Codec, Frame};
use crate::{SecretKey, SecretKey32};
use chacha20poly1305::{aead::Aead, Key, KeyInit, XChaCha20Poly1305, XNonce};
use std::{convert::TryInto, fmt, io};
/// Total bytes to use as the len field denoting a frame's size
const LEN_SIZE: usize = 8;
use std::{fmt, io};
/// Total bytes to use for nonce
const NONCE_SIZE: usize = 24;
/// Represents the codec to encode & decode data while also encrypting/decrypting it
/// Represents the codec that encodes & decodes frames by encrypting/decrypting them using
/// [`XChaCha20Poly1305`].
///
/// Uses a 32-byte key internally
/// NOTE: Uses a 32-byte key internally.
#[derive(Clone)]
pub struct XChaCha20Poly1305Codec {
cipher: XChaCha20Poly1305,
@ -41,75 +39,43 @@ impl fmt::Debug for XChaCha20Poly1305Codec {
}
impl Codec for XChaCha20Poly1305Codec {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
// Validate that we can fit the message plus nonce +
if item.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Empty item provided",
));
}
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
// NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
// maintaining a stateful counter due to its size (24-byte secret key generation
// will never panic)
let nonce_key = SecretKey::<NONCE_SIZE>::generate().unwrap();
let nonce = XNonce::from_slice(nonce_key.unprotected_as_bytes());
// Encrypt the frame's item as our ciphertext
let ciphertext = self
.cipher
.encrypt(nonce, item)
.encrypt(nonce, frame.as_item())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
dst.reserve(8 + nonce.len() + ciphertext.len());
// Frame is now comprised of the nonce and ciphertext in sequence
let mut frame = Frame::new(nonce.as_slice());
frame.extend(ciphertext);
// Add data in form of {LEN}{NONCE}{CIPHER TEXT}
dst.put_u64((nonce_key.len() + ciphertext.len()) as u64);
dst.put_slice(nonce.as_slice());
dst.extend(ciphertext);
Ok(())
Ok(frame.into_owned())
}
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
// First, check if we have more data than just our frame's message length
if src.len() <= LEN_SIZE {
return Ok(None);
}
// Second, retrieve total size of our frame's message
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize;
if msg_len <= NONCE_SIZE {
// Ensure we advance to remove the frame
src.advance(LEN_SIZE + msg_len);
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
if frame.len() <= NONCE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame's msg cannot have length less than 25",
format!("Frame cannot have length less than {}", frame.len()),
));
}
// Third, check if we have all data for our frame; if not, exit early
if src.len() < msg_len + LEN_SIZE {
return Ok(None);
}
// Fourth, retrieve the nonce used with the ciphertext
let nonce = XNonce::from_slice(&src[LEN_SIZE..(NONCE_SIZE + LEN_SIZE)]);
// Fifth, acquire the encrypted & signed ciphertext
let ciphertext = &src[(NONCE_SIZE + LEN_SIZE)..(msg_len + LEN_SIZE)];
// Sixth, convert ciphertext back into our item
let item = self.cipher.decrypt(nonce, ciphertext);
// Seventh, advance so frame is no longer kept around
src.advance(LEN_SIZE + msg_len);
// Eighth, report an error if there is one
let item =
item.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?;
// Grab the nonce from the front of the frame, and then use it with the remainder
// of the frame to tease out the decrypted frame item
let nonce = XNonce::from_slice(&frame.as_item()[..NONCE_SIZE]);
let item = self
.cipher
.decrypt(nonce, &frame.as_item()[NONCE_SIZE..])
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?;
Ok(Some(item))
Ok(Frame::from(item))
}
}
@ -117,78 +83,37 @@ impl Codec for XChaCha20Poly1305Codec {
mod tests {
use super::*;
#[test]
fn encode_should_fail_when_item_is_zero_bytes() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
let result = codec.encode(&[], &mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut codec = XChaCha20Poly1305Codec::from(key.clone());
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
let frame = codec
.encode(Frame::new(b"hello world"))
.expect("Failed to encode");
let len = buf.get_u64() as usize;
assert!(buf.len() > NONCE_SIZE, "Msg size not big enough");
assert_eq!(len, buf.len(), "Msg size does not match attached size");
}
#[test]
fn decode_should_return_none_if_data_smaller_than_or_equal_to_frame_length_field() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
buf.put_bytes(0, LEN_SIZE);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn decode_should_return_none_if_not_enough_data_for_frame() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let nonce = XNonce::from_slice(&frame.as_item()[..NONCE_SIZE]);
let ciphertext = &frame.as_item()[NONCE_SIZE..];
let mut buf = BytesMut::new();
buf.put_u64(0);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
// Manually build our key & cipher so we can decrypt the frame manually to ensure it is
// correct
let key = Key::from_slice(key.unprotected_as_bytes());
let cipher = XChaCha20Poly1305::new(key);
let item = cipher
.decrypt(nonce, ciphertext)
.expect("Failed to decrypt");
assert_eq!(item, b"hello world");
}
#[test]
fn decode_should_fail_if_encoded_frame_length_is_smaller_than_nonce_plus_data() {
fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// NONCE_SIZE + 1 is minimum for frame length
let mut buf = BytesMut::new();
buf.put_u64(NONCE_SIZE as u64);
buf.put_bytes(0, NONCE_SIZE);
let frame = Frame::from(b"a".repeat(NONCE_SIZE));
let result = codec.decode(&mut buf);
let result = codec.decode(frame);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
@ -196,72 +121,30 @@ mod tests {
}
#[test]
fn decode_should_advance_src_by_frame_size_even_if_frame_length_is_too_small() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
let mut buf = BytesMut::new();
buf.put_u64(NONCE_SIZE as u64);
buf.put_bytes(0, NONCE_SIZE);
buf.put_bytes(0, 3);
assert!(
codec.decode(&mut buf).is_err(),
"Decode unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_advance_src_by_frame_size_even_if_decryption_fails() {
fn decode_should_fail_if_unable_to_decrypt_frame_item() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
let mut buf = BytesMut::new();
buf.put_u64((NONCE_SIZE + 12) as u64);
buf.put_bytes(0, NONCE_SIZE);
buf.put_slice(b"hello, world");
buf.put_bytes(0, 3);
assert!(
codec.decode(&mut buf).is_err(),
"Decode unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_advance_src_by_frame_size_when_successful() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
buf.put_bytes(0, 3);
// NONCE_SIZE + 1 is minimum for frame length
let frame = Frame::from(b"a".repeat(NONCE_SIZE + 1));
assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed");
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
let result = codec.decode(frame);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_return_some_byte_vec_when_successful() {
fn decode_should_return_decrypted_frame_when_successful() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
let frame = codec
.encode(Frame::new(b"hello, world"))
.expect("Failed to encode");
let item = codec
.decode(&mut buf)
.expect("Failed to decode")
.expect("Item not properly captured");
assert_eq!(item, b"hello, world");
let frame = codec.decode(frame).expect("Failed to decode");
assert_eq!(frame, b"hello, world");
}
}

@ -0,0 +1,317 @@
use bytes::{Buf, BufMut, BytesMut};
use std::{borrow::Cow, io};
/// Represents a frame whose lifetime is static
pub type OwnedFrame = Frame<'static>;
/// Represents some data wrapped in a frame in order to ship it over the network. The format is
/// simple and follows `{len}{item}` where `len` is the length of the item as a `u64`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Frame<'a> {
/// Represents the item that will be shipped across the network
item: Cow<'a, [u8]>,
}
impl<'a> Frame<'a> {
/// Creates a new frame wrapping the `item` that will be shipped across the network
pub fn new(item: &'a [u8]) -> Self {
Self {
item: Cow::Borrowed(item),
}
}
/// Consumes the frame and returns its underlying item.
pub fn into_item(self) -> Cow<'a, [u8]> {
self.item
}
}
impl Frame<'_> {
/// Total bytes to use as the header field denoting a frame's size
pub const HEADER_SIZE: usize = 8;
/// Returns the len (in bytes) of the item wrapped by the frame
pub fn len(&self) -> usize {
self.item.len()
}
/// Returns true if the frame is comprised of zero bytes
pub fn is_empty(&self) -> bool {
self.item.is_empty()
}
/// Returns a reference to the bytes of the frame's item
pub fn as_item(&self) -> &[u8] {
&self.item
}
/// Writes the frame to the end of `dst`, including the header representing the length of the
/// item as part of the written bytes
pub fn write(&self, dst: &mut BytesMut) -> io::Result<()> {
if self.item.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Empty item provided",
));
}
dst.reserve(Self::HEADER_SIZE + self.item.len());
// Add data in form of {LEN}{ITEM}
dst.put_u64((self.item.len()) as u64);
dst.put_slice(&self.item);
Ok(())
}
/// Attempts to read a frame from `src`, returning `Some(Frame)` if a frame was found
/// (including the header) or `None` if the current `src` does not contain a frame
pub fn read(src: &mut BytesMut) -> io::Result<Option<OwnedFrame>> {
// First, check if we have more data than just our frame's message length
if src.len() <= Self::HEADER_SIZE {
return Ok(None);
}
// Second, retrieve total size of our frame's message
let item_len = u64::from_be_bytes(src[..Self::HEADER_SIZE].try_into().unwrap()) as usize;
// In the case that our item len is 0, we skip over the invalid frame
if item_len == 0 {
// Ensure we advance to remove the frame
src.advance(Self::HEADER_SIZE);
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame's msg cannot have length of 0",
));
}
// Third, check if we have all data for our frame; if not, exit early
if src.len() < item_len + Self::HEADER_SIZE {
return Ok(None);
}
// Fourth, get and return our item
let item = src[Self::HEADER_SIZE..(Self::HEADER_SIZE + item_len)].to_vec();
// Fifth, advance so frame is no longer kept around
src.advance(Self::HEADER_SIZE + item_len);
Ok(Some(Frame::from(item)))
}
/// Returns a new frame which is identical but has a lifetime tied to this frame.
pub fn as_borrowed(&self) -> Frame<'_> {
let item = match &self.item {
Cow::Borrowed(x) => x,
Cow::Owned(x) => x.as_slice(),
};
Frame {
item: Cow::Borrowed(item),
}
}
/// Converts the [`Frame`] into an owned copy.
///
/// If you construct the frame from an item with a non-static lifetime, you may run into
/// lifetime problems due to the way the struct is designed. Calling this function will ensure
/// that the returned value has a static lifetime.
///
/// This is different from just cloning. Cloning the frame will just copy the references, and
/// thus the lifetime will remain the same.
pub fn into_owned(self) -> OwnedFrame {
Frame {
item: Cow::from(self.item.into_owned()),
}
}
}
impl<'a> From<&'a [u8]> for Frame<'a> {
fn from(item: &'a [u8]) -> Self {
Self {
item: Cow::Borrowed(item),
}
}
}
impl<'a, const N: usize> From<&'a [u8; N]> for Frame<'a> {
fn from(item: &'a [u8; N]) -> Self {
Self {
item: Cow::Borrowed(item),
}
}
}
impl<const N: usize> From<[u8; N]> for OwnedFrame {
fn from(item: [u8; N]) -> Self {
Self {
item: Cow::Owned(item.to_vec()),
}
}
}
impl From<Vec<u8>> for OwnedFrame {
fn from(item: Vec<u8>) -> Self {
Self {
item: Cow::Owned(item),
}
}
}
impl AsRef<[u8]> for Frame<'_> {
fn as_ref(&self) -> &[u8] {
AsRef::as_ref(&self.item)
}
}
impl Extend<u8> for Frame<'_> {
fn extend<T: IntoIterator<Item = u8>>(&mut self, iter: T) {
match &mut self.item {
// If we only have a borrowed item, we need to allocate it into a new vec so we can
// extend it with additional bytes
Cow::Borrowed(item) => {
let mut item = item.to_vec();
item.extend(iter);
self.item = Cow::Owned(item);
}
// Othewise, if we already have an owned allocation of bytes, we just extend it
Cow::Owned(item) => {
item.extend(iter);
}
}
}
}
impl PartialEq<[u8]> for Frame<'_> {
fn eq(&self, item: &[u8]) -> bool {
self.item.as_ref().eq(item)
}
}
impl<'a> PartialEq<&'a [u8]> for Frame<'_> {
fn eq(&self, item: &&'a [u8]) -> bool {
self.item.as_ref().eq(*item)
}
}
impl<const N: usize> PartialEq<[u8; N]> for Frame<'_> {
fn eq(&self, item: &[u8; N]) -> bool {
self.item.as_ref().eq(item)
}
}
impl<'a, const N: usize> PartialEq<&'a [u8; N]> for Frame<'_> {
fn eq(&self, item: &&'a [u8; N]) -> bool {
self.item.as_ref().eq(*item)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn write_should_fail_when_item_is_zero_bytes() {
let frame = Frame::new(&[]);
let mut buf = BytesMut::new();
let result = frame.write(&mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn write_should_build_a_frame_containing_a_length_and_item() {
let frame = Frame::new(b"hello, world");
let mut buf = BytesMut::new();
frame.write(&mut buf).expect("Failed to write");
let len = buf.get_u64() as usize;
assert_eq!(len, 12, "Wrong length writed");
assert_eq!(buf.as_ref(), b"hello, world");
}
#[test]
fn read_should_return_none_if_data_smaller_than_or_equal_to_item_length_field() {
let mut buf = BytesMut::new();
buf.put_bytes(0, Frame::HEADER_SIZE);
let result = Frame::read(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn read_should_return_none_if_not_enough_data_for_frame() {
let mut buf = BytesMut::new();
buf.put_u64(0);
let result = Frame::read(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn read_should_fail_if_writed_item_length_is_zero() {
let mut buf = BytesMut::new();
buf.put_u64(0);
buf.put_u8(255);
let result = Frame::read(&mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn read_should_advance_src_by_frame_size_even_if_item_length_is_zero() {
let mut buf = BytesMut::new();
buf.put_u64(0);
buf.put_bytes(0, 3);
assert!(
Frame::read(&mut buf).is_err(),
"read unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn read_should_advance_src_by_frame_size_when_successful() {
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
Frame::new(b"hello, world")
.write(&mut buf)
.expect("Failed to write");
buf.put_bytes(0, 3);
assert!(Frame::read(&mut buf).is_ok(), "read unexpectedly failed");
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn read_should_return_some_byte_vec_when_successful() {
let mut buf = BytesMut::new();
Frame::new(b"hello, world")
.write(&mut buf)
.expect("Failed to write");
let item = Frame::read(&mut buf)
.expect("Failed to read")
.expect("Item not properly captured");
assert_eq!(item, b"hello, world");
}
}
Loading…
Cancel
Save