Still unfinished changes

pull/146/head
Chip Senkbeil 2 years ago
parent ad9f1ac05a
commit 36c05c4283
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,4 +1,5 @@
use crate::{FramedTransport, Interest, Request, Transport, UntypedResponse};
use crate::{FramedTransport, Interest, Reconnectable, Request, Transport, UntypedResponse};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{
@ -7,7 +8,7 @@ use std::{
};
use tokio::{
io,
sync::mpsc,
sync::{mpsc, oneshot},
task::{JoinError, JoinHandle},
};
@ -18,14 +19,16 @@ mod ext;
pub use ext::*;
/// Represents a client that can be used to send requests & receive responses from a server
pub struct Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
pub struct Client<T, U> {
/// Used to send requests to a server
channel: Channel<T, U>,
/// Used to send reconnect request to inner transport
reconnect_tx: mpsc::Sender<oneshot::Sender<io::Result<()>>>,
/// Used to send shutdown request to inner transport
shutdown_tx: mpsc::Sender<()>,
/// Contains the task that is running to send requests and receive responses from a server
task: JoinHandle<()>,
}
@ -35,27 +38,37 @@ where
T: Send + Sync + Serialize,
U: Send + Sync + DeserializeOwned,
{
/// Initializes a client using the provided transport
pub fn new<V>(transport: V) -> io::Result<Self>
/// Initializes a client using the provided [`FramedTransport`]
pub fn new<V, const CAPACITY: usize>(transport: FramedTransport<V, CAPACITY>) -> Self
where
V: Transport,
V: Transport + Send + Sync,
{
let post_office = Arc::new(PostOffice::default());
let weak_post_office = Arc::downgrade(&post_office);
let (tx, mut rx) = mpsc::channel::<Request<T>>(1);
// Do handshake with the server
// TODO: Support user configuration
let mut transport: FramedTransport<_, _> = todo!();
let (reconnect_tx, reconnect_rx) = mpsc::channel::<oneshot::Sender<io::Result<()>>>(1);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
// Start a task that continually checks for responses and delivers them using the
// post office
let task = tokio::spawn(async move {
loop {
let ready = transport
.ready(Interest::READABLE | Interest::WRITABLE)
.await
.expect("Failed to examine ready state");
let ready = tokio::select! {
_ = shutdown_rx.recv() => {
break;
}
cb = reconnect_rx.recv() => {
if let Some(cb) = cb {
cb.send(Reconnectable::reconnect(&mut transport).await);
continue;
} else {
break;
}
}
result = transport.ready(Interest::READABLE | Interest::WRITABLE) => {
result.expect("Failed to examine ready state")
}
};
if ready.is_readable() {
match transport.try_read_frame() {
@ -64,6 +77,7 @@ where
match response.to_typed_response() {
Ok(response) => {
// Try to send response to appropriate mailbox
// TODO: This will block if full... is that a problem?
// TODO: How should we handle false response? Did logging in past
post_office.deliver_response(response).await;
}
@ -121,9 +135,16 @@ where
post_office: weak_post_office,
};
Ok(Self { channel, task })
Self {
channel,
reconnect_tx,
shutdown_tx,
task,
}
}
}
impl<T, U> Client<T, U> {
/// Convert into underlying channel
pub fn into_channel(self) -> Channel<T, U> {
self.channel
@ -151,11 +172,27 @@ where
}
}
impl<T, U> Deref for Client<T, U>
#[async_trait]
impl<T, U> Reconnectable for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
T: Send,
U: Send + Sync,
{
async fn reconnect(&mut self) -> io::Result<()> {
let (tx, rx) = oneshot::channel();
if self.reconnect_tx.send(tx).await.is_ok() {
rx.await
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Callback lost"))?
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Client internal task dead",
))
}
}
}
impl<T, U> Deref for Client<T, U> {
type Target = Channel<T, U>;
fn deref(&self) -> &Self::Target {
@ -163,21 +200,13 @@ where
}
}
impl<T, U> DerefMut for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
impl<T, U> DerefMut for Client<T, U> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.channel
}
}
impl<T, U> From<Client<T, U>> for Channel<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
impl<T, U> From<Client<T, U>> for Channel<T, U> {
fn from(client: Client<T, U>) -> Self {
client.channel
}

@ -11,11 +11,7 @@ const CHANNEL_MAILBOX_CAPACITY: usize = 10000;
/// Represents a sender of requests tied to a session, holding onto a weak reference of
/// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped,
/// any sent request will no longer be able to receive responses
pub struct Channel<T, U>
where
T: Send + Sync,
U: Send + Sync,
{
pub struct Channel<T, U> {
/// Used to send requests to a server
pub(crate) tx: mpsc::Sender<Request<T>>,
@ -24,11 +20,7 @@ where
}
// NOTE: Implemented manually to avoid needing clone to be defined on generic types
impl<T, U> Clone for Channel<T, U>
where
T: Send + Sync,
U: Send + Sync,
{
impl<T, U> Clone for Channel<T, U> {
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),

@ -1,4 +1,4 @@
use crate::{Client, Codec, FramedTransport, TcpTransport};
use crate::{BoxedCodec, Client, FramedTransport, TcpTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, net::SocketAddr};
@ -11,19 +11,15 @@ where
U: DeserializeOwned + Send + Sync,
{
/// Connect to a remote TCP server using the provided information
async fn connect<C>(addr: SocketAddr, codec: C) -> io::Result<Client<T, U>>
where
C: Codec + Send + 'static;
async fn connect(addr: SocketAddr, codec: impl Into<BoxedCodec>) -> io::Result<Client<T, U>>;
/// Connect to a remote TCP server, timing out after duration has passed
async fn connect_timeout<C>(
addr: SocketAddr,
codec: C,
codec: impl Into<BoxedCodec> + Send,
duration: Duration,
) -> io::Result<Client<T, U>>
where
C: Codec + Send + 'static,
{
) -> io::Result<Client<T, U>> {
let codec = codec.into();
tokio::time::timeout(duration, Self::connect(addr, codec))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
@ -38,12 +34,9 @@ where
U: Send + Sync + DeserializeOwned + 'static,
{
/// Connect to a remote TCP server using the provided information
async fn connect<C>(addr: SocketAddr, codec: C) -> io::Result<Client<T, U>>
where
C: Codec + Send + 'static,
{
async fn connect(addr: SocketAddr, codec: impl Into<BoxedCodec>) -> io::Result<Client<T, U>> {
let transport = TcpTransport::connect(addr).await?;
let transport = FramedTransport::new(transport, codec);
Self::from_framed_transport(transport)
Self::new(transport)
}
}

@ -1,55 +0,0 @@
use crate::{BoxedCodec, FramedTransport, PlainCodec, Request, Response, Transport};
use serde::{Deserialize, Serialize};
use std::io;
/// Represents options that the server has available for a connection
#[derive(Serialize, Deserialize)]
struct ServerConnectionOptions {
/// Choices for encryption as string labels
pub encryption: Vec<String>,
/// Choices for compression as string labels
pub compression: Vec<String>,
}
/// Represents the choice that the client has made regarding server connection options
struct ClientConnectionChoice {
/// Selected encryption
pub encryption: String,
/// Selected compression
pub compression: String,
}
/// Performs the client-side of a handshake
pub async fn client_handshake<T>(transport: T) -> io::Result<FramedTransport<T, BoxedCodec>>
where
T: Transport,
{
let transport = FramedTransport::new(transport, PlainCodec::new());
// Wait for the server to send us choices for communication
let frame = transport.read_frame().await?.ok_or_else(|| {
io::Error::new(
io::ErrorKind::ConnectionAborted,
"Connection aborted before receiving server communication",
)
})?;
// Parse the frame as the request for the client
let request = Request::<ServerConnectionOptions>::from_slice(frame.as_item())?;
// Select an encryption and compression choice
let encryption = request.payload.encryption[0];
let compression = request.payload.compression[0];
// Respond back with choices
}
/// Performs the server-side of a handshake
pub async fn server_handshake<T>(transport: T) -> io::Result<FramedTransport<T, BoxedCodec>>
where
T: Transport,
{
let transport = FramedTransport::new(transport, PlainCodec::new());
}

@ -1,7 +1,6 @@
mod any;
// mod auth;
mod client;
mod handshake;
mod id;
mod key;
mod listener;

@ -1,4 +1,4 @@
use crate::{BoxedCodec, FramedTransport};
use crate::FramedTransport;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::io;
@ -43,10 +43,10 @@ pub trait Server: Send {
/// Invoked to facilitate a handshake between server and client upon establishing a connection,
/// returning an updated [`FramedTransport`] once the handshake is complete
async fn on_handshake<T: Send>(
async fn on_handshake<T: Send, const CAPACITY: usize>(
&self,
transport: FramedTransport<T, BoxedCodec>,
) -> io::Result<FramedTransport<T, BoxedCodec>> {
transport: FramedTransport<T, CAPACITY>,
) -> io::Result<FramedTransport<T, CAPACITY>> {
Ok(transport)
}

@ -1,7 +1,7 @@
use super::{Interest, Ready, Reconnectable, Transport};
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use std::io;
use std::{fmt, io};
mod codec;
pub use codec::*;
@ -9,30 +9,55 @@ pub use codec::*;
mod frame;
pub use frame::*;
mod handshake;
pub use handshake::*;
/// By default, framed transport's initial capacity (and max single-read) will be 8 KiB
const DEFAULT_CAPACITY: usize = 8 * 1024;
/// Represents a wrapper around a [`Transport`] that reads and writes using frames defined by a
/// [`Codec`]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FramedTransport<T, U> {
/// [`Codec`]. `CAPACITY` represents both the initial capacity of incoming and outgoing buffers as
/// well as the maximum bytes read per call to [`try_read`].
///
/// [`try_read`]: Transport::try_read
#[derive(Clone)]
pub struct FramedTransport<T, const CAPACITY: usize = DEFAULT_CAPACITY> {
inner: T,
codec: U,
codec: BoxedCodec,
incoming: BytesMut,
outgoing: BytesMut,
}
impl<T, U> FramedTransport<T, U> {
pub fn new(inner: T, codec: U) -> Self {
impl<T, const CAPACITY: usize> FramedTransport<T, CAPACITY> {
fn new(inner: T, codec: BoxedCodec) -> Self {
Self {
inner,
codec,
incoming: BytesMut::with_capacity(DEFAULT_CAPACITY),
outgoing: BytesMut::with_capacity(DEFAULT_CAPACITY),
incoming: BytesMut::with_capacity(CAPACITY),
outgoing: BytesMut::with_capacity(CAPACITY),
}
}
/// Performs a handshake with the other side of the `transport` in order to determine which
/// [`Codec`] to use as well as perform any additional logic to prepare the framed transport.
///
/// Will use the handshake criteria provided in `handshake`
pub async fn from_handshake(
transport: T,
handshake: Handshake<T, CAPACITY>,
) -> io::Result<FramedTransport<T, CAPACITY>>
where
T: Transport,
{
handshake::do_handshake(transport, &handshake).await
}
/// Creates a new [`FramedTransport`] using the [`PlainCodec`]
pub fn plain(inner: T) -> Self {
Self::new(inner, Box::new(PlainCodec::new()))
}
/// Consumes the current transport, replacing it's codec with the provided codec,
/// and returning it. Note that any bytes in the incoming or outgoing buffers will
/// remain in the transport, meaning that this can cause corruption if the bytes
@ -41,10 +66,10 @@ impl<T, U> FramedTransport<T, U> {
/// For safety, use [`clear`] to wipe the buffers before further use.
///
/// [`clear`]: FramedTransport::clear
pub fn with_codec<C>(self, codec: C) -> FramedTransport<T, C> {
pub fn with_codec(self, codec: impl Into<BoxedCodec>) -> FramedTransport<T, CAPACITY> {
FramedTransport {
inner: self.inner,
codec,
codec: codec.into(),
incoming: self.incoming,
outgoing: self.outgoing,
}
@ -57,7 +82,17 @@ impl<T, U> FramedTransport<T, U> {
}
}
impl<T, U> FramedTransport<T, U>
impl<T, const CAPACITY: usize> fmt::Debug for FramedTransport<T, CAPACITY> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FramedTransport")
.field("capacity", &CAPACITY)
.field("incoming", &self.incoming)
.field("outgoing", &self.outgoing)
.finish()
}
}
impl<T, const CAPACITY: usize> FramedTransport<T, CAPACITY>
where
T: Transport,
{
@ -106,10 +141,9 @@ where
}
}
impl<T, U> FramedTransport<T, U>
impl<T, const CAPACITY: usize> FramedTransport<T, CAPACITY>
where
T: Transport,
U: Codec,
{
/// Reads a frame of bytes by using the [`Codec`] tied to this transport. Returns
/// `Ok(Some(frame))` upon reading a frame, or `Ok(None)` if the underlying transport has
@ -121,7 +155,7 @@ where
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub fn try_read_frame(&mut self) -> io::Result<Option<OwnedFrame>> {
// Continually read bytes into the incoming queue and then attempt to tease out a frame
let mut buf = [0; DEFAULT_CAPACITY];
let mut buf = [0; CAPACITY];
loop {
match self.inner.try_read(&mut buf) {
@ -217,17 +251,28 @@ where
}
#[async_trait]
impl<T, U> Reconnectable for FramedTransport<T, U>
impl<T, const CAPACITY: usize> Reconnectable for FramedTransport<T, CAPACITY>
where
T: Transport + Send,
U: Codec + Send,
T: Transport + Send + Sync,
{
async fn reconnect(&mut self) -> io::Result<()> {
Reconnectable::reconnect(&mut self.inner).await
// Establish a new connection
Reconnectable::reconnect(&mut self.inner).await?;
// Perform handshake again, which can result in the underlying codec
// changing based on the exchange; so, we want to clear out any lingering
// bytes in the incoming and outgoing queues
self.clear();
let FramedTransport { inner, codec, .. } =
handshake::do_handshake(self.inner, &self.handshake).await?;
self.inner = inner;
self.codec = codec;
Ok(())
}
}
impl FramedTransport<super::InmemoryTransport, PlainCodec> {
impl<const CAPACITY: usize> FramedTransport<super::InmemoryTransport, CAPACITY> {
/// Produces a pair of inmemory transports that are connected to each other using
/// a standard codec
///
@ -235,8 +280,8 @@ impl FramedTransport<super::InmemoryTransport, PlainCodec> {
pub fn pair(
buffer: usize,
) -> (
FramedTransport<super::InmemoryTransport, PlainCodec>,
FramedTransport<super::InmemoryTransport, PlainCodec>,
FramedTransport<super::InmemoryTransport, CAPACITY>,
FramedTransport<super::InmemoryTransport, CAPACITY>,
) {
let (a, b) = super::InmemoryTransport::pair(buffer);
let a = FramedTransport::new(a, PlainCodec::new());

@ -4,18 +4,15 @@ use std::io;
mod chain;
mod compression;
mod encryption;
mod plain;
mod predicate;
mod xchacha20poly1305;
pub use chain::*;
pub use compression::*;
pub use encryption::*;
pub use plain::*;
pub use predicate::*;
pub use xchacha20poly1305::*;
/// Represents a [`Box`]ed version of [`Codec`]
pub type BoxedCodec = Box<dyn Codec + Send + Sync>;
/// Represents abstraction that implements specific encoder and decoder logic to transform an
/// arbitrary collection of bytes. This can be used to encrypt and authenticate bytes sent and
@ -28,6 +25,9 @@ pub trait Codec: DynClone {
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>>;
}
/// Represents a [`Box`]ed version of [`Codec`]
pub type BoxedCodec = Box<dyn Codec + Send + Sync>;
macro_rules! impl_traits {
($($x:tt)+) => {
impl Clone for Box<dyn $($x)+> {

@ -3,10 +3,11 @@ use flate2::{
bufread::{DeflateDecoder, DeflateEncoder, GzDecoder, GzEncoder, ZlibDecoder, ZlibEncoder},
Compression,
};
use serde::{Deserialize, Serialize};
use std::io::{self, Read};
/// Represents the level of compression to apply to data
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum CompressionLevel {
/// Use no compression (can potentially inflate data)
Zero = 0,
@ -37,6 +38,36 @@ impl CompressionLevel {
pub const BEST: Self = Self::Nine;
}
impl Default for CompressionLevel {
/// Standard compression level used in zlib library is 6, which is also used here
fn default() -> Self {
Self::Six
}
}
/// Represents the type of compression for a [`CompressionCodec`]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CompressionType {
Deflate,
Gzip,
Zlib,
/// Indicates an unknown compression type for use in handshakes
#[serde(other)]
Unknown,
}
impl CompressionType {
/// Returns a list of all variants of the type *except* unknown.
pub const fn known_variants() -> &'static [CompressionType] {
&[
CompressionType::Deflate,
CompressionType::Gzip,
CompressionType::Zlib,
]
}
}
/// Represents a codec that applies compression during encoding and decompression during decoding
/// of a frame's item
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
@ -52,6 +83,23 @@ pub enum CompressionCodec {
}
impl CompressionCodec {
/// Makes a new [`CompressionCodec`] based on the [`CompressionType`] and [`CompressionLevel`],
/// returning error if the type is unknown
pub fn from_type_and_level(
ty: CompressionType,
level: CompressionLevel,
) -> io::Result<CompressionCodec> {
match ty {
CompressionType::Deflate => Ok(Self::Deflate { level }),
CompressionType::Gzip => Ok(Self::Gzip { level }),
CompressionType::Zlib => Ok(Self::Zlib { level }),
CompressionType::Unknown => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Unknown compression type",
)),
}
}
/// Create a new deflate compression codec with the specified `level`
pub fn deflate(level: impl Into<CompressionLevel>) -> Self {
Self::Deflate {
@ -81,6 +129,15 @@ impl CompressionCodec {
Self::Zlib { level } => *level,
}
}
/// Returns the compression type associated with the codec
pub fn ty(&self) -> CompressionType {
match self {
Self::Deflate { .. } => CompressionType::Deflate,
Self::Gzip { .. } => CompressionType::Gzip,
Self::Zlib { .. } => CompressionType::Zlib,
}
}
}
impl Codec for CompressionCodec {

@ -0,0 +1,242 @@
use super::{Codec, Frame};
use derive_more::Display;
use std::{fmt, io};
mod key;
pub use key::*;
/// Represents the type of encryption for a [`EncryptionCodec`]
#[derive(
Copy, Clone, Debug, Display, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize,
)]
pub enum EncryptionType {
/// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce
#[display(fmt = "xchacha20poly1305")]
XChaCha20Poly1305,
/// Indicates an unknown encryption type for use in handshakes
#[display(fmt = "unknown")]
#[serde(other)]
Unknown,
}
impl EncryptionType {
/// Generates bytes for a secret key based on the encryption type
pub fn generate_secret_key_bytes(&self) -> io::Result<Vec<u8>> {
match self {
Self::XChaCha20Poly1305 => Ok(SecretKey::<32>::generate()
.unwrap()
.unprotected_into_bytes()),
Self::Unknown => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Unknown encryption type",
)),
}
}
/// Returns a list of all variants of the type *except* unknown.
pub const fn known_variants() -> &'static [EncryptionType] {
&[EncryptionType::XChaCha20Poly1305]
}
}
/// Represents the codec that encodes & decodes frames by encrypting/decrypting them
#[derive(Clone)]
pub enum EncryptionCodec {
/// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce, using
/// [`XChaCha20Poly1305`] underneath
XChaCha20Poly1305 {
cipher: chacha20poly1305::XChaCha20Poly1305,
},
}
impl EncryptionCodec {
/// Makes a new [`EncryptionCodec`] based on the [`EncryptionType`] and `key`, returning an
/// error if the key is invalid for the encryption type or the type is unknown
pub fn from_type_and_key(ty: EncryptionType, key: &[u8]) -> io::Result<EncryptionCodec> {
match ty {
EncryptionType::XChaCha20Poly1305 => {
use chacha20poly1305::{KeyInit, XChaCha20Poly1305};
let cipher = XChaCha20Poly1305::new_from_slice(key)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
Ok(Self::XChaCha20Poly1305 { cipher })
}
EncryptionType::Unknown => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Encryption type is unknown",
)),
}
}
pub fn new_xchacha20poly1305(secret_key: SecretKey32) -> EncryptionCodec {
// NOTE: This should never fail as we are enforcing the key size at compile time
Self::from_type_and_key(
EncryptionType::XChaCha20Poly1305,
secret_key.unprotected_as_bytes(),
)
.unwrap()
}
/// Returns the encryption type associa ted with the codec
pub fn ty(&self) -> EncryptionType {
match self {
Self::XChaCha20Poly1305 { .. } => EncryptionType::XChaCha20Poly1305,
}
}
/// Size of nonce (in bytes) associated with the encryption algorithm
pub const fn nonce_size(&self) -> usize {
match self {
// XChaCha20Poly1305 uses a 192-bit (24-byte) key
Self::XChaCha20Poly1305 { .. } => 24,
}
}
/// Generates a new nonce for the encryption algorithm
fn generate_nonce_bytes(&self) -> Vec<u8> {
// NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
// maintaining a stateful counter due to its size (24-byte secret key generation
// will never panic)
match self {
Self::XChaCha20Poly1305 { .. } => SecretKey::<24>::generate()
.unwrap()
.unprotected_into_bytes(),
}
}
}
impl fmt::Debug for EncryptionCodec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("EncryptionCodec")
.field("cipher", &"**OMITTED**".to_string())
.field("nonce_size", &self.nonce_size())
.field("ty", &self.ty().to_string())
.finish()
}
}
impl Codec for EncryptionCodec {
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
let frame = match self {
Self::XChaCha20Poly1305 { cipher } => {
use chacha20poly1305::{aead::Aead, XNonce};
let nonce_bytes = self.generate_nonce_bytes();
let nonce = XNonce::from_slice(&nonce_bytes);
// Encrypt the frame's item as our ciphertext
let ciphertext = cipher
.encrypt(nonce, frame.as_item())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
// Frame is now comprised of the nonce and ciphertext in sequence
let mut frame = Frame::new(&nonce_bytes);
frame.extend(ciphertext);
frame
}
};
Ok(frame.into_owned())
}
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
if frame.len() <= self.nonce_size() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Frame cannot have length less than {}",
self.nonce_size() + 1
),
));
}
// Grab the nonce from the front of the frame, and then use it with the remainder
// of the frame to tease out the decrypted frame item
let item = match self {
Self::XChaCha20Poly1305 { cipher } => {
use chacha20poly1305::{aead::Aead, XNonce};
let nonce = XNonce::from_slice(&frame.as_item()[..self.nonce_size()]);
cipher
.decrypt(nonce, &frame.as_item()[self.nonce_size()..])
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?
}
};
Ok(Frame::from(item))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
let ty = EncryptionType::XChaCha20Poly1305;
let key = ty.generate_secret_key_bytes().unwrap();
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
let frame = codec
.encode(Frame::new(b"hello world"))
.expect("Failed to encode");
let nonce = &frame.as_item()[..codec.nonce_size()];
let ciphertext = &frame.as_item()[codec.nonce_size()..];
// Manually build our key & cipher so we can decrypt the frame manually to ensure it is
// correct
let item = {
use chacha20poly1305::{aead::Aead, KeyInit, XChaCha20Poly1305, XNonce};
let cipher = XChaCha20Poly1305::new_from_slice(&key).unwrap();
cipher
.decrypt(XNonce::from_slice(nonce), ciphertext)
.expect("Failed to decrypt")
};
assert_eq!(item, b"hello world");
}
#[test]
fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() {
let ty = EncryptionType::XChaCha20Poly1305;
let key = ty.generate_secret_key_bytes().unwrap();
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
// NONCE_SIZE + 1 is minimum for frame length
let frame = Frame::from(b"a".repeat(codec.nonce_size()));
let result = codec.decode(frame);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_fail_if_unable_to_decrypt_frame_item() {
let ty = EncryptionType::XChaCha20Poly1305;
let key = ty.generate_secret_key_bytes().unwrap();
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
// NONCE_SIZE + 1 is minimum for frame length
let frame = Frame::from(b"a".repeat(codec.nonce_size() + 1));
let result = codec.decode(frame);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_return_decrypted_frame_when_successful() {
let ty = EncryptionType::XChaCha20Poly1305;
let key = ty.generate_secret_key_bytes().unwrap();
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
let frame = codec
.encode(Frame::new(b"hello, world"))
.expect("Failed to encode");
let frame = codec.decode(frame).expect("Failed to decode");
assert_eq!(frame, b"hello, world");
}
}

@ -5,7 +5,13 @@ use std::{fmt, str::FromStr};
#[derive(Debug, Display, Error)]
pub struct SecretKeyError;
/// Represents a 32-byte secret key
/// Represents a 16-byte (128-bit) secret key
pub type SecretKey16 = SecretKey<16>;
/// Represents a 24-byte (192-bit) secret key
pub type SecretKey24 = SecretKey<24>;
/// Represents a 32-byte (256-bit) secret key
pub type SecretKey32 = SecretKey<32>;
/// Represents a secret key used with transport encryption and authentication
@ -42,6 +48,16 @@ impl<const N: usize> SecretKey<N> {
&self.0
}
/// Consumes the secret key and returns the array of key's bytes
pub fn unprotected_into_byte_array(self) -> [u8; N] {
self.0
}
/// Consumes the secret key and returns the key's bytes as a [`HeapSecretKey`]
pub fn into_heap_secret_key(self) -> HeapSecretKey {
HeapSecretKey(self.0.to_vec())
}
/// Returns the length of the key
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
@ -98,3 +114,56 @@ impl<const N: usize> fmt::Display for SecretKey<N> {
write!(f, "{}", hex::encode(self.unprotected_as_bytes()))
}
}
/// Represents a secret key used with transport encryption and authentication that is stored on the
/// heap
#[derive(Clone, PartialEq, Eq)]
pub struct HeapSecretKey(Vec<u8>);
impl fmt::Debug for HeapSecretKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("HeapSecretKey")
.field(&"**OMITTED**".to_string())
.finish()
}
}
impl HeapSecretKey {
/// Returns byte slice to the key's bytes
pub fn unprotected_as_bytes(&self) -> &[u8] {
&self.0
}
/// Consumes the secret key and returns the key's bytes
pub fn unprotected_into_bytes(self) -> Vec<u8> {
self.0.to_vec()
}
/// Returns the length of the key
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.0.len()
}
}
impl From<Vec<u8>> for HeapSecretKey {
fn from(bytes: Vec<u8>) -> Self {
Self(bytes)
}
}
impl FromStr for HeapSecretKey {
type Err = SecretKeyError;
/// Parse a str of hex as secret key on heap
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(hex::decode(s).map_err(|_| SecretKeyError)?))
}
}
impl fmt::Display for HeapSecretKey {
/// Display an N-byte secret key as a hex string
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", hex::encode(self.unprotected_as_bytes()))
}
}

@ -1,150 +0,0 @@
use super::{Codec, Frame};
use crate::{SecretKey, SecretKey32};
use chacha20poly1305::{aead::Aead, Key, KeyInit, XChaCha20Poly1305, XNonce};
use std::{fmt, io};
/// Total bytes to use for nonce
const NONCE_SIZE: usize = 24;
/// Represents the codec that encodes & decodes frames by encrypting/decrypting them using
/// [`XChaCha20Poly1305`].
///
/// NOTE: Uses a 32-byte key internally.
#[derive(Clone)]
pub struct XChaCha20Poly1305Codec {
cipher: XChaCha20Poly1305,
}
impl XChaCha20Poly1305Codec {
pub fn new(key: &[u8]) -> Self {
let key = Key::from_slice(key);
let cipher = XChaCha20Poly1305::new(key);
Self { cipher }
}
}
impl From<SecretKey32> for XChaCha20Poly1305Codec {
/// Create a new XChaCha20Poly1305 codec with a 32-byte key
fn from(secret_key: SecretKey32) -> Self {
Self::new(secret_key.unprotected_as_bytes())
}
}
impl fmt::Debug for XChaCha20Poly1305Codec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("XChaCha20Poly1305Codec")
.field("cipher", &"**OMITTED**".to_string())
.finish()
}
}
impl Codec for XChaCha20Poly1305Codec {
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
// NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
// maintaining a stateful counter due to its size (24-byte secret key generation
// will never panic)
let nonce_key = SecretKey::<NONCE_SIZE>::generate().unwrap();
let nonce = XNonce::from_slice(nonce_key.unprotected_as_bytes());
// Encrypt the frame's item as our ciphertext
let ciphertext = self
.cipher
.encrypt(nonce, frame.as_item())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
// Frame is now comprised of the nonce and ciphertext in sequence
let mut frame = Frame::new(nonce.as_slice());
frame.extend(ciphertext);
Ok(frame.into_owned())
}
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
if frame.len() <= NONCE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Frame cannot have length less than {}", frame.len()),
));
}
// Grab the nonce from the front of the frame, and then use it with the remainder
// of the frame to tease out the decrypted frame item
let nonce = XNonce::from_slice(&frame.as_item()[..NONCE_SIZE]);
let item = self
.cipher
.decrypt(nonce, &frame.as_item()[NONCE_SIZE..])
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?;
Ok(Frame::from(item))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key.clone());
let frame = codec
.encode(Frame::new(b"hello world"))
.expect("Failed to encode");
let nonce = XNonce::from_slice(&frame.as_item()[..NONCE_SIZE]);
let ciphertext = &frame.as_item()[NONCE_SIZE..];
// Manually build our key & cipher so we can decrypt the frame manually to ensure it is
// correct
let key = Key::from_slice(key.unprotected_as_bytes());
let cipher = XChaCha20Poly1305::new(key);
let item = cipher
.decrypt(nonce, ciphertext)
.expect("Failed to decrypt");
assert_eq!(item, b"hello world");
}
#[test]
fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// NONCE_SIZE + 1 is minimum for frame length
let frame = Frame::from(b"a".repeat(NONCE_SIZE));
let result = codec.decode(frame);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_fail_if_unable_to_decrypt_frame_item() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// NONCE_SIZE + 1 is minimum for frame length
let frame = Frame::from(b"a".repeat(NONCE_SIZE + 1));
let result = codec.decode(frame);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_return_decrypted_frame_when_successful() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let frame = codec
.encode(Frame::new(b"hello, world"))
.expect("Failed to encode");
let frame = codec.decode(frame).expect("Failed to decode");
assert_eq!(frame, b"hello, world");
}
}

@ -0,0 +1,218 @@
use super::{
BoxedCodec, ChainCodec, CompressionCodec, CompressionLevel, CompressionType, EncryptionCodec,
EncryptionType, FramedTransport, HeapSecretKey, PlainCodec, Transport,
};
use crate::utils;
use log::*;
use serde::{Deserialize, Serialize};
use std::io;
mod on_choice;
mod on_handshake;
pub use on_choice::*;
pub use on_handshake::*;
/// Options from the server representing available methods to configure a framed transport
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HandshakeServerOptions {
#[serde(rename = "c")]
compression: Vec<CompressionType>,
#[serde(rename = "e")]
encryption: Vec<EncryptionType>,
}
/// Client choice representing the selected configuration for a framed transport
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HandshakeClientChoice {
#[serde(rename = "c")]
compression: Option<CompressionType>,
#[serde(rename = "cl")]
compression_level: Option<CompressionLevel>,
#[serde(rename = "e")]
encryption: Option<EncryptionType>,
}
/// Definition of the handshake to perform for a transport
#[derive(Debug)]
pub enum Handshake<T, const CAPACITY: usize> {
/// Indicates that the handshake is being performed from the client-side
Client {
/// Secret key to use with encryption
key: HeapSecretKey,
/// Callback to invoke when receiving server options
on_choice: OnHandshakeClientChoice,
/// Callback to invoke when the handshake has completed, providing a user-level handshake
/// operations
on_handshake: OnHandshake<T, CAPACITY>,
},
/// Indicates that the handshake is being performed from the server-side
Server {
/// List of available compression algorithms for use between client and server
compression: Vec<CompressionType>,
/// List of available encryption algorithms for use between client and server
encryption: Vec<EncryptionType>,
/// Secret key to use with encryption
key: HeapSecretKey,
/// Callback to invoke when the handshake has completed, providing a user-level handshake
/// operations
on_handshake: OnHandshake<T, CAPACITY>,
},
}
impl<T, const CAPACITY: usize> Handshake<T, CAPACITY> {
/// Creates a new client handshake definition with `on_handshake` as a callback when the
/// handshake has completed to enable user-level handshake operations
pub fn client(
key: HeapSecretKey,
on_choice: impl Into<OnHandshakeClientChoice>,
on_handshake: impl Into<OnHandshake<T, CAPACITY>>,
) -> Self {
Self::Client {
key,
on_choice: on_choice.into(),
on_handshake: on_handshake.into(),
}
}
/// Creates a new server handshake definition with `on_handshake` as a callback when the
/// handshake has completed to enable user-level handshake operations
pub fn server(key: HeapSecretKey, on_handshake: impl Into<OnHandshake<T, CAPACITY>>) -> Self {
Self::Server {
compression: CompressionType::known_variants().to_vec(),
encryption: EncryptionType::known_variants().to_vec(),
key,
on_handshake: on_handshake.into(),
}
}
}
/// Helper method to perform a handshake
///
/// ### Client
///
/// 1. Wait for options from server
/// 2. Send to server a compression and encryption choice
/// 3. Configure framed transport using selected choices
/// 4. Invoke on_handshake function
///
/// ### Server
///
/// 1. Send options to client
/// 2. Receive choices from client
/// 3. Configure framed transport using client's choices
/// 4. Invoke on_handshake function
///
pub(crate) async fn do_handshake<T, const CAPACITY: usize>(
transport: T,
handshake: &Handshake<T, CAPACITY>,
) -> io::Result<FramedTransport<T, CAPACITY>>
where
T: Transport,
{
let mut transport = FramedTransport::plain(transport);
macro_rules! write_frame {
($data:expr) => {{
transport
.write_frame(utils::serialize_to_vec(&$data)?)
.await?
}};
}
macro_rules! next_frame_as {
($type:ty) => {{
let frame = transport.read_frame().await?.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
})?;
utils::deserialize_from_slice::<$type>(frame.as_item())?
}};
}
match handshake {
Handshake::Client {
key,
on_choice,
on_handshake,
} => {
// Receive options from the server and pick one
debug!("[Handshake] Client waiting on server options");
let options = next_frame_as!(HandshakeServerOptions);
// Choose a compression and encryption option from the options
debug!("[Handshake] Client selecting from server options: {options:#?}");
let choice = (on_choice.0)(options);
// Report back to the server the choice
debug!("[Handshake] Client reporting choice: {choice:#?}");
write_frame!(choice);
// Transform the transport's codec to abide by the choice
let transport = transform_transport(transport, choice, &key)?;
// Invoke callback to signal completion of handshake
debug!("[Handshake] Standard client handshake done, invoking callback");
(on_handshake.0)(transport).await
}
Handshake::Server {
compression,
encryption,
key,
on_handshake,
} => {
let options = HandshakeServerOptions {
compression: compression.to_vec(),
encryption: encryption.to_vec(),
};
// Send options to the client
debug!("[Handshake] Server sending options: {options:#?}");
write_frame!(options);
// Get client's response with selected compression and encryption
debug!("[Handshake] Server waiting on client choice");
let choice = next_frame_as!(HandshakeClientChoice);
// Transform the transport's codec to abide by the choice
let transport = transform_transport(transport, choice, &key)?;
// Invoke callback to signal completion of handshake
debug!("[Handshake] Standard server handshake done, invoking callback");
(on_handshake.0)(transport).await
}
}
}
fn transform_transport<T, const CAPACITY: usize>(
transport: FramedTransport<T, CAPACITY>,
choice: HandshakeClientChoice,
secret_key: &HeapSecretKey,
) -> io::Result<FramedTransport<T, CAPACITY>> {
let codec: BoxedCodec = match (choice.compression, choice.encryption) {
(Some(compression), Some(encryption)) => Box::new(ChainCodec::new(
EncryptionCodec::from_type_and_key(encryption, secret_key.unprotected_as_bytes())?,
CompressionCodec::from_type_and_level(
compression,
choice.compression_level.unwrap_or_default(),
)?,
)),
(None, Some(encryption)) => Box::new(EncryptionCodec::from_type_and_key(
encryption,
secret_key.unprotected_as_bytes(),
)?),
(Some(compression), None) => Box::new(CompressionCodec::from_type_and_level(
compression,
choice.compression_level.unwrap_or_default(),
)?),
(None, None) => Box::new(PlainCodec::new()),
};
Ok(transport.with_codec(codec))
}

@ -0,0 +1,44 @@
use super::{HandshakeClientChoice, HandshakeServerOptions};
use std::fmt;
/// Callback invoked when a client receives server options during a handshake
pub struct OnHandshakeClientChoice(
pub(super) Box<dyn Fn(HandshakeServerOptions) -> HandshakeClientChoice>,
);
impl OnHandshakeClientChoice {
/// Wraps a function `f` as a callback
pub fn new<F>(f: F) -> Self
where
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
{
Self(Box::new(f))
}
}
impl<F> From<F> for OnHandshakeClientChoice
where
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
{
fn from(f: F) -> Self {
Self::new(f)
}
}
impl fmt::Debug for OnHandshakeClientChoice {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OnHandshakeClientChoice").finish()
}
}
impl Default for OnHandshakeClientChoice {
/// Implements choice selection that picks first available of encryption and nothing of
/// compression
fn default() -> Self {
Self::new(|options| HandshakeClientChoice {
compression: None,
compression_level: None,
encryption: options.encryption.first().copied(),
})
}
}

@ -0,0 +1,48 @@
use super::FramedTransport;
use std::{fmt, future::Future, io, pin::Pin};
/// Boxed function representing `on_handshake` callback
pub type BoxedOnHandshakeFn<T, const CAPACITY: usize> = Box<
dyn Fn(
FramedTransport<T, CAPACITY>,
) -> Pin<Box<dyn Future<Output = io::Result<FramedTransport<T, CAPACITY>>>>>,
>;
/// Callback invoked when a handshake occurs
pub struct OnHandshake<T, const CAPACITY: usize>(pub(super) BoxedOnHandshakeFn<T, CAPACITY>);
impl<T, const CAPACITY: usize> OnHandshake<T, CAPACITY> {
/// Wraps a function `f` as a callback for a handshake
pub fn new<F>(f: F) -> Self
where
F: Fn(
FramedTransport<T, CAPACITY>,
) -> Pin<Box<dyn Future<Output = io::Result<FramedTransport<T, CAPACITY>>>>>,
{
Self(Box::new(f))
}
}
impl<T, F, const CAPACITY: usize> From<F> for OnHandshake<T, CAPACITY>
where
F: Fn(
FramedTransport<T, CAPACITY>,
) -> Pin<Box<dyn Future<Output = io::Result<FramedTransport<T, CAPACITY>>>>>,
{
fn from(f: F) -> Self {
Self::new(f)
}
}
impl<T, const CAPACITY: usize> fmt::Debug for OnHandshake<T, CAPACITY> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OnHandshake").finish()
}
}
impl<T, const CAPACITY: usize> Default for OnHandshake<T, CAPACITY> {
/// Implements handshake callback that does nothing
fn default() -> Self {
Self::new(|transport| Box::pin(async { Ok(transport) }))
}
}
Loading…
Cancel
Save