mirror of https://github.com/chipsenkbeil/distant
Still unfinished changes
parent
ad9f1ac05a
commit
36c05c4283
@ -1,55 +0,0 @@
|
||||
use crate::{BoxedCodec, FramedTransport, PlainCodec, Request, Response, Transport};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io;
|
||||
|
||||
/// Represents options that the server has available for a connection
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct ServerConnectionOptions {
|
||||
/// Choices for encryption as string labels
|
||||
pub encryption: Vec<String>,
|
||||
|
||||
/// Choices for compression as string labels
|
||||
pub compression: Vec<String>,
|
||||
}
|
||||
|
||||
/// Represents the choice that the client has made regarding server connection options
|
||||
struct ClientConnectionChoice {
|
||||
/// Selected encryption
|
||||
pub encryption: String,
|
||||
|
||||
/// Selected compression
|
||||
pub compression: String,
|
||||
}
|
||||
|
||||
/// Performs the client-side of a handshake
|
||||
pub async fn client_handshake<T>(transport: T) -> io::Result<FramedTransport<T, BoxedCodec>>
|
||||
where
|
||||
T: Transport,
|
||||
{
|
||||
let transport = FramedTransport::new(transport, PlainCodec::new());
|
||||
|
||||
// Wait for the server to send us choices for communication
|
||||
let frame = transport.read_frame().await?.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"Connection aborted before receiving server communication",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Parse the frame as the request for the client
|
||||
let request = Request::<ServerConnectionOptions>::from_slice(frame.as_item())?;
|
||||
|
||||
// Select an encryption and compression choice
|
||||
let encryption = request.payload.encryption[0];
|
||||
let compression = request.payload.compression[0];
|
||||
|
||||
// Respond back with choices
|
||||
}
|
||||
|
||||
/// Performs the server-side of a handshake
|
||||
pub async fn server_handshake<T>(transport: T) -> io::Result<FramedTransport<T, BoxedCodec>>
|
||||
where
|
||||
T: Transport,
|
||||
{
|
||||
let transport = FramedTransport::new(transport, PlainCodec::new());
|
||||
}
|
@ -0,0 +1,242 @@
|
||||
use super::{Codec, Frame};
|
||||
use derive_more::Display;
|
||||
use std::{fmt, io};
|
||||
|
||||
mod key;
|
||||
pub use key::*;
|
||||
|
||||
/// Represents the type of encryption for a [`EncryptionCodec`]
|
||||
#[derive(
|
||||
Copy, Clone, Debug, Display, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize,
|
||||
)]
|
||||
pub enum EncryptionType {
|
||||
/// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce
|
||||
#[display(fmt = "xchacha20poly1305")]
|
||||
XChaCha20Poly1305,
|
||||
|
||||
/// Indicates an unknown encryption type for use in handshakes
|
||||
#[display(fmt = "unknown")]
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl EncryptionType {
|
||||
/// Generates bytes for a secret key based on the encryption type
|
||||
pub fn generate_secret_key_bytes(&self) -> io::Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::XChaCha20Poly1305 => Ok(SecretKey::<32>::generate()
|
||||
.unwrap()
|
||||
.unprotected_into_bytes()),
|
||||
Self::Unknown => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unknown encryption type",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a list of all variants of the type *except* unknown.
|
||||
pub const fn known_variants() -> &'static [EncryptionType] {
|
||||
&[EncryptionType::XChaCha20Poly1305]
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the codec that encodes & decodes frames by encrypting/decrypting them
|
||||
#[derive(Clone)]
|
||||
pub enum EncryptionCodec {
|
||||
/// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce, using
|
||||
/// [`XChaCha20Poly1305`] underneath
|
||||
XChaCha20Poly1305 {
|
||||
cipher: chacha20poly1305::XChaCha20Poly1305,
|
||||
},
|
||||
}
|
||||
|
||||
impl EncryptionCodec {
|
||||
/// Makes a new [`EncryptionCodec`] based on the [`EncryptionType`] and `key`, returning an
|
||||
/// error if the key is invalid for the encryption type or the type is unknown
|
||||
pub fn from_type_and_key(ty: EncryptionType, key: &[u8]) -> io::Result<EncryptionCodec> {
|
||||
match ty {
|
||||
EncryptionType::XChaCha20Poly1305 => {
|
||||
use chacha20poly1305::{KeyInit, XChaCha20Poly1305};
|
||||
let cipher = XChaCha20Poly1305::new_from_slice(key)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
|
||||
Ok(Self::XChaCha20Poly1305 { cipher })
|
||||
}
|
||||
EncryptionType::Unknown => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Encryption type is unknown",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_xchacha20poly1305(secret_key: SecretKey32) -> EncryptionCodec {
|
||||
// NOTE: This should never fail as we are enforcing the key size at compile time
|
||||
Self::from_type_and_key(
|
||||
EncryptionType::XChaCha20Poly1305,
|
||||
secret_key.unprotected_as_bytes(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Returns the encryption type associa ted with the codec
|
||||
pub fn ty(&self) -> EncryptionType {
|
||||
match self {
|
||||
Self::XChaCha20Poly1305 { .. } => EncryptionType::XChaCha20Poly1305,
|
||||
}
|
||||
}
|
||||
|
||||
/// Size of nonce (in bytes) associated with the encryption algorithm
|
||||
pub const fn nonce_size(&self) -> usize {
|
||||
match self {
|
||||
// XChaCha20Poly1305 uses a 192-bit (24-byte) key
|
||||
Self::XChaCha20Poly1305 { .. } => 24,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a new nonce for the encryption algorithm
|
||||
fn generate_nonce_bytes(&self) -> Vec<u8> {
|
||||
// NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
|
||||
// maintaining a stateful counter due to its size (24-byte secret key generation
|
||||
// will never panic)
|
||||
match self {
|
||||
Self::XChaCha20Poly1305 { .. } => SecretKey::<24>::generate()
|
||||
.unwrap()
|
||||
.unprotected_into_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for EncryptionCodec {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("EncryptionCodec")
|
||||
.field("cipher", &"**OMITTED**".to_string())
|
||||
.field("nonce_size", &self.nonce_size())
|
||||
.field("ty", &self.ty().to_string())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Codec for EncryptionCodec {
|
||||
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
|
||||
let frame = match self {
|
||||
Self::XChaCha20Poly1305 { cipher } => {
|
||||
use chacha20poly1305::{aead::Aead, XNonce};
|
||||
let nonce_bytes = self.generate_nonce_bytes();
|
||||
let nonce = XNonce::from_slice(&nonce_bytes);
|
||||
|
||||
// Encrypt the frame's item as our ciphertext
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, frame.as_item())
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
|
||||
|
||||
// Frame is now comprised of the nonce and ciphertext in sequence
|
||||
let mut frame = Frame::new(&nonce_bytes);
|
||||
frame.extend(ciphertext);
|
||||
frame
|
||||
}
|
||||
};
|
||||
|
||||
Ok(frame.into_owned())
|
||||
}
|
||||
|
||||
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
|
||||
if frame.len() <= self.nonce_size() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Frame cannot have length less than {}",
|
||||
self.nonce_size() + 1
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Grab the nonce from the front of the frame, and then use it with the remainder
|
||||
// of the frame to tease out the decrypted frame item
|
||||
let item = match self {
|
||||
Self::XChaCha20Poly1305 { cipher } => {
|
||||
use chacha20poly1305::{aead::Aead, XNonce};
|
||||
let nonce = XNonce::from_slice(&frame.as_item()[..self.nonce_size()]);
|
||||
cipher
|
||||
.decrypt(nonce, &frame.as_item()[self.nonce_size()..])
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Frame::from(item))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
|
||||
let ty = EncryptionType::XChaCha20Poly1305;
|
||||
let key = ty.generate_secret_key_bytes().unwrap();
|
||||
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
|
||||
|
||||
let frame = codec
|
||||
.encode(Frame::new(b"hello world"))
|
||||
.expect("Failed to encode");
|
||||
|
||||
let nonce = &frame.as_item()[..codec.nonce_size()];
|
||||
let ciphertext = &frame.as_item()[codec.nonce_size()..];
|
||||
|
||||
// Manually build our key & cipher so we can decrypt the frame manually to ensure it is
|
||||
// correct
|
||||
let item = {
|
||||
use chacha20poly1305::{aead::Aead, KeyInit, XChaCha20Poly1305, XNonce};
|
||||
let cipher = XChaCha20Poly1305::new_from_slice(&key).unwrap();
|
||||
cipher
|
||||
.decrypt(XNonce::from_slice(nonce), ciphertext)
|
||||
.expect("Failed to decrypt")
|
||||
};
|
||||
assert_eq!(item, b"hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() {
|
||||
let ty = EncryptionType::XChaCha20Poly1305;
|
||||
let key = ty.generate_secret_key_bytes().unwrap();
|
||||
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
|
||||
|
||||
// NONCE_SIZE + 1 is minimum for frame length
|
||||
let frame = Frame::from(b"a".repeat(codec.nonce_size()));
|
||||
|
||||
let result = codec.decode(frame);
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_fail_if_unable_to_decrypt_frame_item() {
|
||||
let ty = EncryptionType::XChaCha20Poly1305;
|
||||
let key = ty.generate_secret_key_bytes().unwrap();
|
||||
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
|
||||
|
||||
// NONCE_SIZE + 1 is minimum for frame length
|
||||
let frame = Frame::from(b"a".repeat(codec.nonce_size() + 1));
|
||||
|
||||
let result = codec.decode(frame);
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_decrypted_frame_when_successful() {
|
||||
let ty = EncryptionType::XChaCha20Poly1305;
|
||||
let key = ty.generate_secret_key_bytes().unwrap();
|
||||
let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
|
||||
|
||||
let frame = codec
|
||||
.encode(Frame::new(b"hello, world"))
|
||||
.expect("Failed to encode");
|
||||
|
||||
let frame = codec.decode(frame).expect("Failed to decode");
|
||||
assert_eq!(frame, b"hello, world");
|
||||
}
|
||||
}
|
@ -1,150 +0,0 @@
|
||||
use super::{Codec, Frame};
|
||||
use crate::{SecretKey, SecretKey32};
|
||||
use chacha20poly1305::{aead::Aead, Key, KeyInit, XChaCha20Poly1305, XNonce};
|
||||
use std::{fmt, io};
|
||||
|
||||
/// Total bytes to use for nonce
|
||||
const NONCE_SIZE: usize = 24;
|
||||
|
||||
/// Represents the codec that encodes & decodes frames by encrypting/decrypting them using
|
||||
/// [`XChaCha20Poly1305`].
|
||||
///
|
||||
/// NOTE: Uses a 32-byte key internally.
|
||||
#[derive(Clone)]
|
||||
pub struct XChaCha20Poly1305Codec {
|
||||
cipher: XChaCha20Poly1305,
|
||||
}
|
||||
|
||||
impl XChaCha20Poly1305Codec {
|
||||
pub fn new(key: &[u8]) -> Self {
|
||||
let key = Key::from_slice(key);
|
||||
let cipher = XChaCha20Poly1305::new(key);
|
||||
Self { cipher }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SecretKey32> for XChaCha20Poly1305Codec {
|
||||
/// Create a new XChaCha20Poly1305 codec with a 32-byte key
|
||||
fn from(secret_key: SecretKey32) -> Self {
|
||||
Self::new(secret_key.unprotected_as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for XChaCha20Poly1305Codec {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("XChaCha20Poly1305Codec")
|
||||
.field("cipher", &"**OMITTED**".to_string())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Codec for XChaCha20Poly1305Codec {
|
||||
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
|
||||
// NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
|
||||
// maintaining a stateful counter due to its size (24-byte secret key generation
|
||||
// will never panic)
|
||||
let nonce_key = SecretKey::<NONCE_SIZE>::generate().unwrap();
|
||||
let nonce = XNonce::from_slice(nonce_key.unprotected_as_bytes());
|
||||
|
||||
// Encrypt the frame's item as our ciphertext
|
||||
let ciphertext = self
|
||||
.cipher
|
||||
.encrypt(nonce, frame.as_item())
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
|
||||
|
||||
// Frame is now comprised of the nonce and ciphertext in sequence
|
||||
let mut frame = Frame::new(nonce.as_slice());
|
||||
frame.extend(ciphertext);
|
||||
|
||||
Ok(frame.into_owned())
|
||||
}
|
||||
|
||||
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
|
||||
if frame.len() <= NONCE_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Frame cannot have length less than {}", frame.len()),
|
||||
));
|
||||
}
|
||||
|
||||
// Grab the nonce from the front of the frame, and then use it with the remainder
|
||||
// of the frame to tease out the decrypted frame item
|
||||
let nonce = XNonce::from_slice(&frame.as_item()[..NONCE_SIZE]);
|
||||
let item = self
|
||||
.cipher
|
||||
.decrypt(nonce, &frame.as_item()[NONCE_SIZE..])
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?;
|
||||
|
||||
Ok(Frame::from(item))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key.clone());
|
||||
|
||||
let frame = codec
|
||||
.encode(Frame::new(b"hello world"))
|
||||
.expect("Failed to encode");
|
||||
|
||||
let nonce = XNonce::from_slice(&frame.as_item()[..NONCE_SIZE]);
|
||||
let ciphertext = &frame.as_item()[NONCE_SIZE..];
|
||||
|
||||
// Manually build our key & cipher so we can decrypt the frame manually to ensure it is
|
||||
// correct
|
||||
let key = Key::from_slice(key.unprotected_as_bytes());
|
||||
let cipher = XChaCha20Poly1305::new(key);
|
||||
let item = cipher
|
||||
.decrypt(nonce, ciphertext)
|
||||
.expect("Failed to decrypt");
|
||||
assert_eq!(item, b"hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
// NONCE_SIZE + 1 is minimum for frame length
|
||||
let frame = Frame::from(b"a".repeat(NONCE_SIZE));
|
||||
|
||||
let result = codec.decode(frame);
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_fail_if_unable_to_decrypt_frame_item() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
// NONCE_SIZE + 1 is minimum for frame length
|
||||
let frame = Frame::from(b"a".repeat(NONCE_SIZE + 1));
|
||||
|
||||
let result = codec.decode(frame);
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_decrypted_frame_when_successful() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
let frame = codec
|
||||
.encode(Frame::new(b"hello, world"))
|
||||
.expect("Failed to encode");
|
||||
|
||||
let frame = codec.decode(frame).expect("Failed to decode");
|
||||
assert_eq!(frame, b"hello, world");
|
||||
}
|
||||
}
|
@ -0,0 +1,218 @@
|
||||
use super::{
|
||||
BoxedCodec, ChainCodec, CompressionCodec, CompressionLevel, CompressionType, EncryptionCodec,
|
||||
EncryptionType, FramedTransport, HeapSecretKey, PlainCodec, Transport,
|
||||
};
|
||||
use crate::utils;
|
||||
use log::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io;
|
||||
|
||||
mod on_choice;
|
||||
mod on_handshake;
|
||||
|
||||
pub use on_choice::*;
|
||||
pub use on_handshake::*;
|
||||
|
||||
/// Options from the server representing available methods to configure a framed transport
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct HandshakeServerOptions {
|
||||
#[serde(rename = "c")]
|
||||
compression: Vec<CompressionType>,
|
||||
#[serde(rename = "e")]
|
||||
encryption: Vec<EncryptionType>,
|
||||
}
|
||||
|
||||
/// Client choice representing the selected configuration for a framed transport
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct HandshakeClientChoice {
|
||||
#[serde(rename = "c")]
|
||||
compression: Option<CompressionType>,
|
||||
#[serde(rename = "cl")]
|
||||
compression_level: Option<CompressionLevel>,
|
||||
#[serde(rename = "e")]
|
||||
encryption: Option<EncryptionType>,
|
||||
}
|
||||
|
||||
/// Definition of the handshake to perform for a transport
|
||||
#[derive(Debug)]
|
||||
pub enum Handshake<T, const CAPACITY: usize> {
|
||||
/// Indicates that the handshake is being performed from the client-side
|
||||
Client {
|
||||
/// Secret key to use with encryption
|
||||
key: HeapSecretKey,
|
||||
|
||||
/// Callback to invoke when receiving server options
|
||||
on_choice: OnHandshakeClientChoice,
|
||||
|
||||
/// Callback to invoke when the handshake has completed, providing a user-level handshake
|
||||
/// operations
|
||||
on_handshake: OnHandshake<T, CAPACITY>,
|
||||
},
|
||||
|
||||
/// Indicates that the handshake is being performed from the server-side
|
||||
Server {
|
||||
/// List of available compression algorithms for use between client and server
|
||||
compression: Vec<CompressionType>,
|
||||
|
||||
/// List of available encryption algorithms for use between client and server
|
||||
encryption: Vec<EncryptionType>,
|
||||
|
||||
/// Secret key to use with encryption
|
||||
key: HeapSecretKey,
|
||||
|
||||
/// Callback to invoke when the handshake has completed, providing a user-level handshake
|
||||
/// operations
|
||||
on_handshake: OnHandshake<T, CAPACITY>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<T, const CAPACITY: usize> Handshake<T, CAPACITY> {
|
||||
/// Creates a new client handshake definition with `on_handshake` as a callback when the
|
||||
/// handshake has completed to enable user-level handshake operations
|
||||
pub fn client(
|
||||
key: HeapSecretKey,
|
||||
on_choice: impl Into<OnHandshakeClientChoice>,
|
||||
on_handshake: impl Into<OnHandshake<T, CAPACITY>>,
|
||||
) -> Self {
|
||||
Self::Client {
|
||||
key,
|
||||
on_choice: on_choice.into(),
|
||||
on_handshake: on_handshake.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new server handshake definition with `on_handshake` as a callback when the
|
||||
/// handshake has completed to enable user-level handshake operations
|
||||
pub fn server(key: HeapSecretKey, on_handshake: impl Into<OnHandshake<T, CAPACITY>>) -> Self {
|
||||
Self::Server {
|
||||
compression: CompressionType::known_variants().to_vec(),
|
||||
encryption: EncryptionType::known_variants().to_vec(),
|
||||
key,
|
||||
on_handshake: on_handshake.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper method to perform a handshake
|
||||
///
|
||||
/// ### Client
|
||||
///
|
||||
/// 1. Wait for options from server
|
||||
/// 2. Send to server a compression and encryption choice
|
||||
/// 3. Configure framed transport using selected choices
|
||||
/// 4. Invoke on_handshake function
|
||||
///
|
||||
/// ### Server
|
||||
///
|
||||
/// 1. Send options to client
|
||||
/// 2. Receive choices from client
|
||||
/// 3. Configure framed transport using client's choices
|
||||
/// 4. Invoke on_handshake function
|
||||
///
|
||||
pub(crate) async fn do_handshake<T, const CAPACITY: usize>(
|
||||
transport: T,
|
||||
handshake: &Handshake<T, CAPACITY>,
|
||||
) -> io::Result<FramedTransport<T, CAPACITY>>
|
||||
where
|
||||
T: Transport,
|
||||
{
|
||||
let mut transport = FramedTransport::plain(transport);
|
||||
|
||||
macro_rules! write_frame {
|
||||
($data:expr) => {{
|
||||
transport
|
||||
.write_frame(utils::serialize_to_vec(&$data)?)
|
||||
.await?
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! next_frame_as {
|
||||
($type:ty) => {{
|
||||
let frame = transport.read_frame().await?.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
|
||||
})?;
|
||||
|
||||
utils::deserialize_from_slice::<$type>(frame.as_item())?
|
||||
}};
|
||||
}
|
||||
|
||||
match handshake {
|
||||
Handshake::Client {
|
||||
key,
|
||||
on_choice,
|
||||
on_handshake,
|
||||
} => {
|
||||
// Receive options from the server and pick one
|
||||
debug!("[Handshake] Client waiting on server options");
|
||||
let options = next_frame_as!(HandshakeServerOptions);
|
||||
|
||||
// Choose a compression and encryption option from the options
|
||||
debug!("[Handshake] Client selecting from server options: {options:#?}");
|
||||
let choice = (on_choice.0)(options);
|
||||
|
||||
// Report back to the server the choice
|
||||
debug!("[Handshake] Client reporting choice: {choice:#?}");
|
||||
write_frame!(choice);
|
||||
|
||||
// Transform the transport's codec to abide by the choice
|
||||
let transport = transform_transport(transport, choice, &key)?;
|
||||
|
||||
// Invoke callback to signal completion of handshake
|
||||
debug!("[Handshake] Standard client handshake done, invoking callback");
|
||||
(on_handshake.0)(transport).await
|
||||
}
|
||||
Handshake::Server {
|
||||
compression,
|
||||
encryption,
|
||||
key,
|
||||
on_handshake,
|
||||
} => {
|
||||
let options = HandshakeServerOptions {
|
||||
compression: compression.to_vec(),
|
||||
encryption: encryption.to_vec(),
|
||||
};
|
||||
|
||||
// Send options to the client
|
||||
debug!("[Handshake] Server sending options: {options:#?}");
|
||||
write_frame!(options);
|
||||
|
||||
// Get client's response with selected compression and encryption
|
||||
debug!("[Handshake] Server waiting on client choice");
|
||||
let choice = next_frame_as!(HandshakeClientChoice);
|
||||
|
||||
// Transform the transport's codec to abide by the choice
|
||||
let transport = transform_transport(transport, choice, &key)?;
|
||||
|
||||
// Invoke callback to signal completion of handshake
|
||||
debug!("[Handshake] Standard server handshake done, invoking callback");
|
||||
(on_handshake.0)(transport).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn transform_transport<T, const CAPACITY: usize>(
|
||||
transport: FramedTransport<T, CAPACITY>,
|
||||
choice: HandshakeClientChoice,
|
||||
secret_key: &HeapSecretKey,
|
||||
) -> io::Result<FramedTransport<T, CAPACITY>> {
|
||||
let codec: BoxedCodec = match (choice.compression, choice.encryption) {
|
||||
(Some(compression), Some(encryption)) => Box::new(ChainCodec::new(
|
||||
EncryptionCodec::from_type_and_key(encryption, secret_key.unprotected_as_bytes())?,
|
||||
CompressionCodec::from_type_and_level(
|
||||
compression,
|
||||
choice.compression_level.unwrap_or_default(),
|
||||
)?,
|
||||
)),
|
||||
(None, Some(encryption)) => Box::new(EncryptionCodec::from_type_and_key(
|
||||
encryption,
|
||||
secret_key.unprotected_as_bytes(),
|
||||
)?),
|
||||
(Some(compression), None) => Box::new(CompressionCodec::from_type_and_level(
|
||||
compression,
|
||||
choice.compression_level.unwrap_or_default(),
|
||||
)?),
|
||||
(None, None) => Box::new(PlainCodec::new()),
|
||||
};
|
||||
|
||||
Ok(transport.with_codec(codec))
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
use super::{HandshakeClientChoice, HandshakeServerOptions};
|
||||
use std::fmt;
|
||||
|
||||
/// Callback invoked when a client receives server options during a handshake
|
||||
pub struct OnHandshakeClientChoice(
|
||||
pub(super) Box<dyn Fn(HandshakeServerOptions) -> HandshakeClientChoice>,
|
||||
);
|
||||
|
||||
impl OnHandshakeClientChoice {
|
||||
/// Wraps a function `f` as a callback
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
|
||||
{
|
||||
Self(Box::new(f))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> From<F> for OnHandshakeClientChoice
|
||||
where
|
||||
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
|
||||
{
|
||||
fn from(f: F) -> Self {
|
||||
Self::new(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for OnHandshakeClientChoice {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("OnHandshakeClientChoice").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OnHandshakeClientChoice {
|
||||
/// Implements choice selection that picks first available of encryption and nothing of
|
||||
/// compression
|
||||
fn default() -> Self {
|
||||
Self::new(|options| HandshakeClientChoice {
|
||||
compression: None,
|
||||
compression_level: None,
|
||||
encryption: options.encryption.first().copied(),
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
use super::FramedTransport;
|
||||
use std::{fmt, future::Future, io, pin::Pin};
|
||||
|
||||
/// Boxed function representing `on_handshake` callback
|
||||
pub type BoxedOnHandshakeFn<T, const CAPACITY: usize> = Box<
|
||||
dyn Fn(
|
||||
FramedTransport<T, CAPACITY>,
|
||||
) -> Pin<Box<dyn Future<Output = io::Result<FramedTransport<T, CAPACITY>>>>>,
|
||||
>;
|
||||
|
||||
/// Callback invoked when a handshake occurs
|
||||
pub struct OnHandshake<T, const CAPACITY: usize>(pub(super) BoxedOnHandshakeFn<T, CAPACITY>);
|
||||
|
||||
impl<T, const CAPACITY: usize> OnHandshake<T, CAPACITY> {
|
||||
/// Wraps a function `f` as a callback for a handshake
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: Fn(
|
||||
FramedTransport<T, CAPACITY>,
|
||||
) -> Pin<Box<dyn Future<Output = io::Result<FramedTransport<T, CAPACITY>>>>>,
|
||||
{
|
||||
Self(Box::new(f))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F, const CAPACITY: usize> From<F> for OnHandshake<T, CAPACITY>
|
||||
where
|
||||
F: Fn(
|
||||
FramedTransport<T, CAPACITY>,
|
||||
) -> Pin<Box<dyn Future<Output = io::Result<FramedTransport<T, CAPACITY>>>>>,
|
||||
{
|
||||
fn from(f: F) -> Self {
|
||||
Self::new(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const CAPACITY: usize> fmt::Debug for OnHandshake<T, CAPACITY> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("OnHandshake").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const CAPACITY: usize> Default for OnHandshake<T, CAPACITY> {
|
||||
/// Implements handshake callback that does nothing
|
||||
fn default() -> Self {
|
||||
Self::new(|transport| Box::pin(async { Ok(transport) }))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue