More progress by moving handshake logic to establish a codec to the FramedTransport and separating out authentication to the StatefulFramedTransport

pull/146/head
Chip Senkbeil 2 years ago
parent e95370589f
commit af0fc81187
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,13 +1,20 @@
use super::{Interest, Ready, Reconnectable, Transport};
use crate::utils;
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use log::*;
use serde::{Deserialize, Serialize};
use std::{fmt, io};
mod codec;
pub use codec::*;
mod exchange;
mod frame;
mod handshake;
pub use codec::*;
pub use exchange::*;
pub use frame::*;
pub use handshake::*;
/// By default, framed transport's initial capacity (and max single-read) will be 8 KiB
const DEFAULT_CAPACITY: usize = 8 * 1024;
@ -115,6 +122,192 @@ where
Ok(())
}
/// Performs a handshake in order to establish a new codec to use between this transport and
/// the other side. The parameter `handshake` defines how the transport will handle the
/// handshake with `Client` being used to pick the compression and encryption used while
/// `Server` defines what the choices are for compression and encryption.
///
/// This will reset the framed transport's codec to [`PlainCodec`] in order to communicate
/// which compression and encryption to use. Upon selecting an encryption type, a shared secret
/// key will be derived on both sides and used to establish the [`EncryptionCodec`], which in
/// combination with the [`CompressionCodec`] (if any) will replace this transport's codec.
///
/// ### Client
///
/// 1. Wait for options from server
/// 2. Send to server a compression and encryption choice
/// 3. Configure framed transport using selected choices
/// 4. Invoke on_handshake function
///
/// ### Server
///
/// 1. Send options to client
/// 2. Receive choices from client
/// 3. Configure framed transport using client's choices
/// 4. Invoke on_handshake function
///
/// ### Failure
///
/// The handshake will fail in several cases:
///
/// * If any frame during the handshake fails to be serialized
/// * If any unexpected frame is received during the handshake
/// * If using encryption and unable to derive a shared secret key
///
/// If a failure happens, the codec will be reset to what it was prior to the handshake
/// request, and all internal buffers will be cleared to avoid corruption.
///
pub async fn handshake(&mut self, handshake: Handshake) -> io::Result<()>
where
T: Transport,
{
// Place transport in plain text communication mode for start of handshake, and clear any
// data that is lingering within internal buffers
//
// NOTE: We grab the old codec in case we encounter an error and need to reset it
let old_codec = std::mem::replace(&mut self.codec, Box::new(PlainCodec::new()));
self.clear();
// Transform the transport's codec to abide by the choice. In the case of an error, we
// reset the codec back to what it was prior to attempting the handshake and clear the
// internal buffers as they may be corrupt.
match self.handshake_impl(handshake).await {
Ok(codec) => {
self.set_codec(codec);
Ok(())
}
Err(x) => {
self.set_codec(old_codec);
self.clear();
Err(x)
}
}
}
async fn handshake_impl(&mut self, handshake: Handshake) -> io::Result<BoxedCodec> {
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
compression_level: Option<CompressionLevel>,
compression_type: Option<CompressionType>,
encryption_type: Option<EncryptionType>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Options {
compression_types: Vec<CompressionType>,
encryption_types: Vec<EncryptionType>,
}
macro_rules! write_frame {
($data:expr) => {{
self.write_frame(utils::serialize_to_vec(&$data)?).await?
}};
}
macro_rules! next_frame_as {
($type:ty) => {{
let frame = self.read_frame().await?.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
})?;
utils::deserialize_from_slice::<$type>(frame.as_item())?
}};
}
// Determine compression and encryption to apply to framed transport
let choice = match handshake {
Handshake::Client {
preferred_compression_type,
preferred_compression_level,
preferred_encryption_type,
} => {
// Receive options from the server and pick one
debug!("[Handshake] Client waiting on server options");
let options = next_frame_as!(Options);
// Choose a compression and encryption option from the options
debug!("[Handshake] Client selecting from server options: {options:#?}");
let choice = Choice {
compression_type: preferred_compression_type
.filter(|ty| options.compression_types.contains(ty)),
compression_level: preferred_compression_level,
encryption_type: preferred_encryption_type
.filter(|ty| options.encryption_types.contains(ty)),
};
// Report back to the server the choice
debug!("[Handshake] Client reporting choice: {choice:#?}");
write_frame!(choice);
choice
}
Handshake::Server {
compression_types,
encryption_types,
} => {
let options = Options {
compression_types: compression_types.to_vec(),
encryption_types: encryption_types.to_vec(),
};
// Send options to the client
debug!("[Handshake] Server sending options: {options:#?}");
write_frame!(options);
// Get client's response with selected compression and encryption
debug!("[Handshake] Server waiting on client choice");
next_frame_as!(Choice)
}
};
debug!("[Handshake] Building compression & encryption codecs based on {choice:#?}");
let compression_level = choice.compression_level.unwrap_or_default();
// Acquire a codec for the compression type
let compression_codec = choice
.compression_type
.map(|ty| ty.new_codec(compression_level))
.transpose()?;
// In the case that we are using encryption, we derive a shared secret key to use with the
// encryption type
let encryption_codec = match choice.encryption_type {
Some(ty) => {
#[derive(Serialize, Deserialize)]
struct KeyExchangeData {
/// Bytes of the public key
#[serde(with = "serde_bytes")]
public_key: PublicKeyBytes,
/// Randomly generated salt
#[serde(with = "serde_bytes")]
salt: Salt,
}
let exchange = KeyExchange::default();
write_frame!(KeyExchangeData {
public_key: exchange.pk_bytes(),
salt: *exchange.salt(),
});
let data = next_frame_as!(KeyExchangeData);
let key = exchange.derive_shared_secret(data.public_key, data.salt)?;
Some(ty.new_codec(key.unprotected_as_bytes())?)
}
None => None,
};
// Bundle our compression and encryption codecs into a single, chained codec
let codec: BoxedCodec = match (compression_codec, encryption_codec) {
(Some(c), Some(e)) => Box::new(ChainCodec::new(c, e)),
(Some(c), None) => Box::new(c),
(None, Some(e)) => Box::new(e),
(None, None) => Box::new(PlainCodec::new()),
};
Ok(codec)
}
}
impl<T, const CAPACITY: usize> FramedTransport<T, CAPACITY>

@ -66,6 +66,11 @@ impl CompressionType {
CompressionType::Zlib,
]
}
/// Creates a new [`CompressionCodec`] for this type, failing if this type is unknown
pub fn new_codec(&self, level: CompressionLevel) -> io::Result<CompressionCodec> {
CompressionCodec::from_type_and_level(*self, level)
}
}
/// Represents a codec that applies compression during encoding and decompression during decoding

@ -26,6 +26,7 @@ impl EncryptionType {
match self {
Self::XChaCha20Poly1305 => Ok(SecretKey::<32>::generate()
.unwrap()
.into_heap_secret_key()
.unprotected_into_bytes()),
Self::Unknown => Err(io::Error::new(
io::ErrorKind::InvalidInput,
@ -38,6 +39,12 @@ impl EncryptionType {
pub const fn known_variants() -> &'static [EncryptionType] {
&[EncryptionType::XChaCha20Poly1305]
}
/// Creates a new [`EncryptionCodec`] for this type, failing if this type is unknown or the key
/// is an invalid length
pub fn new_codec(&self, key: &[u8]) -> io::Result<EncryptionCodec> {
EncryptionCodec::from_type_and_key(*self, key)
}
}
/// Represents the codec that encodes & decodes frames by encrypting/decrypting them
@ -100,6 +107,7 @@ impl EncryptionCodec {
match self {
Self::XChaCha20Poly1305 { .. } => SecretKey::<24>::generate()
.unwrap()
.into_heap_secret_key()
.unprotected_into_bytes(),
}
}

@ -152,6 +152,12 @@ impl From<Vec<u8>> for HeapSecretKey {
}
}
impl<const N: usize> From<[u8; N]> for HeapSecretKey {
fn from(arr: [u8; N]) -> Self {
Self::from(arr.to_vec())
}
}
impl FromStr for HeapSecretKey {
type Err = SecretKeyError;

@ -1,3 +1,4 @@
use super::HeapSecretKey;
use p256::{ecdh::EphemeralSecret, PublicKey};
use rand::rngs::OsRng;
use sha2::Sha256;
@ -9,16 +10,14 @@ pub use pkb::PublicKeyBytes;
mod salt;
pub use salt::Salt;
/// 32-byte key shared by handshake
pub type SharedKey = [u8; 32];
/// Utility to perform a handshake
pub struct Handshake {
/// Utility to support performing an exchange of public keys and salts in order to derive a shared
/// key between two separate entities
pub struct KeyExchange {
secret: EphemeralSecret,
salt: Salt,
}
impl Default for Handshake {
impl Default for KeyExchange {
// Create a new handshake instance with a secret and salt
fn default() -> Self {
let secret = EphemeralSecret::random(&mut OsRng);
@ -28,7 +27,7 @@ impl Default for Handshake {
}
}
impl Handshake {
impl KeyExchange {
// Return encoded bytes of public key
pub fn pk_bytes(&self) -> PublicKeyBytes {
PublicKeyBytes::from(self.secret.public_key())
@ -39,8 +38,13 @@ impl Handshake {
&self.salt
}
pub fn handshake(&self, public_key: PublicKeyBytes, salt: Salt) -> io::Result<SharedKey> {
// Decode the public key of the client
/// Derives a shared secret using another key exchange's public key and salt
pub fn derive_shared_secret(
&self,
public_key: PublicKeyBytes,
salt: Salt,
) -> io::Result<HeapSecretKey> {
// Decode the public key of the other side
let decoded_public_key = PublicKey::try_from(public_key)?;
// Produce a salt that is consistent with what the other side will do
@ -55,7 +59,7 @@ impl Handshake {
// Derive a shared key (32 bytes)
let mut shared_key = [0u8; 32];
match hkdf.expand(&[], &mut shared_key) {
Ok(_) => Ok(shared_key),
Ok(_) => Ok(HeapSecretKey::from(shared_key)),
Err(x) => Err(io::Error::new(io::ErrorKind::InvalidData, x.to_string())),
}
}

@ -0,0 +1,47 @@
use super::{CompressionLevel, CompressionType, EncryptionType};
/// Definition of the handshake to perform for a transport
#[derive(Clone, Debug)]
pub enum Handshake {
/// Indicates that the handshake is being performed from the client-side
Client {
/// Preferred compression algorithm when presented options by server
preferred_compression_type: Option<CompressionType>,
/// Preferred compression level when presented options by server
preferred_compression_level: Option<CompressionLevel>,
/// Preferred encryption algorithm when presented options by server
preferred_encryption_type: Option<EncryptionType>,
},
/// Indicates that the handshake is being performed from the server-side
Server {
/// List of available compression algorithms for use between client and server
compression_types: Vec<CompressionType>,
/// List of available encryption algorithms for use between client and server
encryption_types: Vec<EncryptionType>,
},
}
impl Handshake {
/// Creates a new client handshake definition, providing defaults for the preferred compression
/// type, compression level, and encryption type
pub fn client() -> Self {
Self::Client {
preferred_compression_type: None,
preferred_compression_level: None,
preferred_encryption_type: Some(EncryptionType::XChaCha20Poly1305),
}
}
/// Creates a new server handshake definition, providing defaults for the compression types and
/// encryption types by including all known variants
pub fn server() -> Self {
Self::Server {
compression_types: CompressionType::known_variants().to_vec(),
encryption_types: EncryptionType::known_variants().to_vec(),
}
}
}

@ -1,21 +1,34 @@
use super::{FramedTransport, HeapSecretKey, Reconnectable, Transport};
use async_trait::async_trait;
use std::io;
mod handshake;
pub use handshake::*;
use std::{
io,
ops::{Deref, DerefMut},
};
/// Internal state for our transport
#[derive(Clone, Debug)]
enum State {
/// Transport is not authenticated and has not begun the process of authenticating
NotAuthenticated,
/// Transport is in the state of currently authenticating, either by issuing challenges or
/// responding with answers to challenges
Authenticating,
/// Transport has finished authenticating successfully
Authenticated {
/// Unique key that marks the transport as authenticated for use in shortcutting
/// authentication when the transport reconnects. This is NOT the key used for encryption
/// and is instead meant to be shared (secretly) between transports that are aware of a
/// previously-successful authentication.
key: HeapSecretKey,
handshake_options: HandshakeOptions,
},
}
/// Represents an stateful framed transport that is capable of peforming handshakes and
/// reconnecting using an authenticated state
/// Represents an stateful [`FramedTransport`] that is capable of performing authentication with
/// another [`FramedTransport`] in order to properly encrypt messages by deriving an appropriate
/// encryption codec. When authenticated, reconnecting will skip authentication unless the
/// transport on the other side declines the authenticated state.
#[derive(Clone, Debug)]
pub struct StatefulFramedTransport<T, const CAPACITY: usize> {
inner: FramedTransport<T, CAPACITY>,
@ -31,10 +44,10 @@ impl<T, const CAPACITY: usize> StatefulFramedTransport<T, CAPACITY> {
}
}
/// Performs an authentication handshake, moving the state to be authenticated.
/// Performs authentication with the other side, moving the state to be authenticated.
///
/// Does nothing if already authenticated
pub async fn authenticate(&mut self, handshake_options: HandshakeOptions) -> io::Result<()> {
/// NOTE: Does nothing if already authenticated!
pub async fn authenticate(&mut self) -> io::Result<()> {
if self.is_authenticated() {
return Ok(());
}
@ -42,20 +55,38 @@ impl<T, const CAPACITY: usize> StatefulFramedTransport<T, CAPACITY> {
todo!();
}
/// Returns true if in an authenticated state
/// Returns true if has not started the authentication process
///
/// NOTE: This will return false if in the process of authenticating, but not finished! To
/// check if not authenticated or actively authenticating, use ![`is_authenticated`].
///
/// [`is_authenticated`]: StatefulFramedTransport::is_authenticated
pub fn is_not_authenticated(&self) -> bool {
matches!(self.state, State::NotAuthenticated)
}
/// Returns true if actively authenticating
pub fn is_authenticating(&self) -> bool {
matches!(self.state, State::Authenticating)
}
/// Returns true if has authenticated successfully
pub fn is_authenticated(&self) -> bool {
matches!(self.state, State::Authenticated { .. })
}
}
/// Returns a reference to the [`HandshakeOptions`] used during authentication. Returns `None`
/// if not authenticated.
pub fn handshake_options(&self) -> Option<&HandshakeOptions> {
match &self.state {
State::NotAuthenticated => None,
State::Authenticated {
handshake_options, ..
} => Some(handshake_options),
}
impl<T, const CAPACITY: usize> Deref for StatefulFramedTransport<T, CAPACITY> {
type Target = FramedTransport<T, CAPACITY>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T, const CAPACITY: usize> DerefMut for StatefulFramedTransport<T, CAPACITY> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
@ -66,11 +97,17 @@ where
{
async fn reconnect(&mut self) -> io::Result<()> {
match self.state {
// If not authenticated, we simply perform a raw reconnect
State::NotAuthenticated => Reconnectable::reconnect(&mut self.inner).await,
// If not authenticated or in the process of authenticating, we simply perform a raw
// reconnect and reset to not being authenticated
State::NotAuthenticated | State::Authenticating => {
self.state = State::NotAuthenticated;
Reconnectable::reconnect(&mut self.inner).await
}
// If authenticated, we perform a reconnect followed by re-authentication using our
// previously-derived key to skip the need to do another authentication
// previously-acquired key to skip the need to do another authentication. Note that
// this can still change the underlying codec used by the transport if an alternative
// compression or encryption codec is picked.
State::Authenticated { key, .. } => {
Reconnectable::reconnect(&mut self.inner).await?;

@ -1,200 +0,0 @@
use super::{
BoxedCodec, ChainCodec, CompressionCodec, CompressionLevel, CompressionType, EncryptionCodec,
EncryptionType, FramedTransport, HeapSecretKey, PlainCodec, Transport,
};
use crate::utils;
use log::*;
use serde::{Deserialize, Serialize};
use std::io;
/// Options from the server representing available methods to configure a framed transport
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HandshakeServerOptions {
#[serde(rename = "c")]
compression_types: Vec<CompressionType>,
#[serde(rename = "e")]
encryption_types: Vec<EncryptionType>,
}
/// Client choice representing the selected configuration for a framed transport
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HandshakeClientChoice {
#[serde(rename = "c")]
compression_type: Option<CompressionType>,
#[serde(rename = "cl")]
compression_level: Option<CompressionLevel>,
#[serde(rename = "e")]
encryption_type: Option<EncryptionType>,
}
/// Definition of the handshake to perform for a transport
#[derive(Clone, Debug)]
pub enum HandshakeOptions {
/// Indicates that the handshake is being performed from the client-side
Client {
/// Preferred compression algorithm when presented options by server
preferred_compression_type: Option<CompressionType>,
/// Preferred compression level when presented options by server
preferred_compression_level: Option<CompressionLevel>,
/// Preferred encryption algorithm when presented options by server
preferred_encryption_type: Option<EncryptionType>,
},
/// Indicates that the handshake is being performed from the server-side
Server {
/// List of available compression algorithms for use between client and server
compression_types: Vec<CompressionType>,
/// List of available encryption algorithms for use between client and server
encryption_types: Vec<EncryptionType>,
},
}
impl HandshakeOptions {
/// Creates a new client handshake definition, providing defaults for the preferred compression
/// type, compression level, and encryption type
pub fn client() -> Self {
Self::Client {
preferred_compression_type: None,
preferred_compression_level: None,
preferred_encryption_type: Some(EncryptionType::XChaCha20Poly1305),
}
}
/// Creates a new server handshake definition, providing defaults for the compression types and
/// encryption types by including all known variants
pub fn server() -> Self {
Self::Server {
compression_types: CompressionType::known_variants().to_vec(),
encryption_types: EncryptionType::known_variants().to_vec(),
}
}
}
/// Helper method to perform a handshake
///
/// ### Client
///
/// 1. Wait for options from server
/// 2. Send to server a compression and encryption choice
/// 3. Configure framed transport using selected choices
/// 4. Invoke on_handshake function
///
/// ### Server
///
/// 1. Send options to client
/// 2. Receive choices from client
/// 3. Configure framed transport using client's choices
/// 4. Invoke on_handshake function
///
pub(crate) async fn do_handshake<T, const CAPACITY: usize>(
transport: &mut FramedTransport<T, CAPACITY>,
) -> io::Result<()>
where
T: Transport,
{
// Place transport in plain text communication mode for start of handshake, and clear any data
// that is lingering within internal buffers
transport.set_codec(Box::new(PlainCodec::new()));
transport.clear();
macro_rules! write_frame {
($data:expr) => {{
transport
.write_frame(utils::serialize_to_vec(&$data)?)
.await?
}};
}
macro_rules! next_frame_as {
($type:ty) => {{
let frame = transport.read_frame().await?.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
})?;
utils::deserialize_from_slice::<$type>(frame.as_item())?
}};
}
match transport.handshake.clone() {
HandshakeOptions::Client {
access_token,
preferred_compression_type,
preferred_compression_level,
preferred_encryption_type,
} => {
// Receive options from the server and pick one
debug!("[Handshake] Client waiting on server options");
let options = next_frame_as!(HandshakeServerOptions);
// Choose a compression and encryption option from the options
debug!("[Handshake] Client selecting from server options: {options:#?}");
let choice = HandshakeClientChoice {
compression_type: preferred_compression_type
.filter(|ty| options.compression_types.contains(ty)),
compression_level: preferred_compression_level,
encryption_type: preferred_encryption_type
.filter(|ty| options.encryption_types.contains(ty)),
};
// Report back to the server the choice
debug!("[Handshake] Client reporting choice: {choice:#?}");
write_frame!(choice);
// Transform the transport's codec to abide by the choice
debug!("[Handshake] Client updating codec based on {choice:#?}");
transform_transport(transport, choice, &access_token)
}
HandshakeOptions::Server {
key,
compression_types,
encryption_types,
} => {
let options = HandshakeServerOptions {
compression_types: compression_types.to_vec(),
encryption_types: encryption_types.to_vec(),
};
// Send options to the client
debug!("[Handshake] Server sending options: {options:#?}");
write_frame!(options);
// Get client's response with selected compression and encryption
debug!("[Handshake] Server waiting on client choice");
let choice = next_frame_as!(HandshakeClientChoice);
// Transform the transport's codec to abide by the choice
debug!("[Handshake] Server updating codec based on {choice:#?}");
transform_transport(transport, choice, &key)
}
}
}
fn transform_transport<T, const CAPACITY: usize>(
transport: &mut FramedTransport<T, CAPACITY>,
choice: HandshakeClientChoice,
secret_key: &HeapSecretKey,
) -> io::Result<()> {
let codec: BoxedCodec = match (choice.compression_type, choice.encryption_type) {
(Some(compression), Some(encryption)) => Box::new(ChainCodec::new(
EncryptionCodec::from_type_and_key(encryption, secret_key.unprotected_as_bytes())?,
CompressionCodec::from_type_and_level(
compression,
choice.compression_level.unwrap_or_default(),
)?,
)),
(None, Some(encryption)) => Box::new(EncryptionCodec::from_type_and_key(
encryption,
secret_key.unprotected_as_bytes(),
)?),
(Some(compression), None) => Box::new(CompressionCodec::from_type_and_level(
compression,
choice.compression_level.unwrap_or_default(),
)?),
(None, None) => Box::new(PlainCodec::new()),
};
Ok(transport.set_codec(codec))
}
Loading…
Cancel
Save