mirror of https://github.com/chipsenkbeil/distant
Progress towards stateful framed transport
parent
a2ec96e556
commit
e95370589f
@ -1,44 +0,0 @@
|
||||
use super::{HandshakeClientChoice, HandshakeServerOptions};
|
||||
use std::fmt;
|
||||
|
||||
/// Callback invoked when a client receives server options during a handshake
|
||||
pub struct OnHandshakeClientChoice(
|
||||
pub(super) Box<dyn Fn(HandshakeServerOptions) -> HandshakeClientChoice>,
|
||||
);
|
||||
|
||||
impl OnHandshakeClientChoice {
|
||||
/// Wraps a function `f` as a callback
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
|
||||
{
|
||||
Self(Box::new(f))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> From<F> for OnHandshakeClientChoice
|
||||
where
|
||||
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
|
||||
{
|
||||
fn from(f: F) -> Self {
|
||||
Self::new(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for OnHandshakeClientChoice {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("OnHandshakeClientChoice").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OnHandshakeClientChoice {
|
||||
/// Implements choice selection that picks first available of encryption and nothing of
|
||||
/// compression
|
||||
fn default() -> Self {
|
||||
Self::new(|options| HandshakeClientChoice {
|
||||
compression: None,
|
||||
compression_level: None,
|
||||
encryption: options.encryption.first().copied(),
|
||||
})
|
||||
}
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
use super::FramedTransport;
|
||||
use std::{fmt, future::Future, io, pin::Pin};
|
||||
|
||||
/// Boxed function representing `on_handshake` callback
|
||||
pub type BoxedOnHandshakeFn<T, const CAPACITY: usize> = Box<
|
||||
dyn FnMut(&mut FramedTransport<T, CAPACITY>) -> Pin<Box<dyn Future<Output = io::Result<()>>>>,
|
||||
>;
|
||||
|
||||
/// Callback invoked when a handshake occurs
|
||||
pub struct OnHandshake<T, const CAPACITY: usize>(pub(super) BoxedOnHandshakeFn<T, CAPACITY>);
|
||||
|
||||
impl<T, const CAPACITY: usize> OnHandshake<T, CAPACITY> {
|
||||
/// Wraps a function `f` as a callback for a handshake
|
||||
pub fn new<F>(f: F) -> Self
|
||||
where
|
||||
F: FnMut(
|
||||
&mut FramedTransport<T, CAPACITY>,
|
||||
) -> Pin<Box<dyn Future<Output = io::Result<()>>>>,
|
||||
{
|
||||
Self(Box::new(f))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F, const CAPACITY: usize> From<F> for OnHandshake<T, CAPACITY>
|
||||
where
|
||||
F: FnMut(&mut FramedTransport<T, CAPACITY>) -> Pin<Box<dyn Future<Output = io::Result<()>>>>,
|
||||
{
|
||||
fn from(f: F) -> Self {
|
||||
Self::new(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const CAPACITY: usize> fmt::Debug for OnHandshake<T, CAPACITY> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("OnHandshake").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const CAPACITY: usize> Default for OnHandshake<T, CAPACITY> {
|
||||
/// Implements handshake callback that does nothing
|
||||
fn default() -> Self {
|
||||
Self::new(|_| Box::pin(async { Ok(()) }))
|
||||
}
|
||||
}
|
@ -0,0 +1,81 @@
|
||||
use super::{FramedTransport, HeapSecretKey, Reconnectable, Transport};
|
||||
use async_trait::async_trait;
|
||||
use std::io;
|
||||
|
||||
mod handshake;
|
||||
pub use handshake::*;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum State {
|
||||
NotAuthenticated,
|
||||
Authenticated {
|
||||
key: HeapSecretKey,
|
||||
handshake_options: HandshakeOptions,
|
||||
},
|
||||
}
|
||||
|
||||
/// Represents an stateful framed transport that is capable of peforming handshakes and
|
||||
/// reconnecting using an authenticated state
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StatefulFramedTransport<T, const CAPACITY: usize> {
|
||||
inner: FramedTransport<T, CAPACITY>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl<T, const CAPACITY: usize> StatefulFramedTransport<T, CAPACITY> {
|
||||
/// Creates a new stateful framed transport that is not yet authenticated
|
||||
pub fn new(inner: FramedTransport<T, CAPACITY>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
state: State::NotAuthenticated,
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs an authentication handshake, moving the state to be authenticated.
|
||||
///
|
||||
/// Does nothing if already authenticated
|
||||
pub async fn authenticate(&mut self, handshake_options: HandshakeOptions) -> io::Result<()> {
|
||||
if self.is_authenticated() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
todo!();
|
||||
}
|
||||
|
||||
/// Returns true if in an authenticated state
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
matches!(self.state, State::Authenticated { .. })
|
||||
}
|
||||
|
||||
/// Returns a reference to the [`HandshakeOptions`] used during authentication. Returns `None`
|
||||
/// if not authenticated.
|
||||
pub fn handshake_options(&self) -> Option<&HandshakeOptions> {
|
||||
match &self.state {
|
||||
State::NotAuthenticated => None,
|
||||
State::Authenticated {
|
||||
handshake_options, ..
|
||||
} => Some(handshake_options),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, const CAPACITY: usize> Reconnectable for StatefulFramedTransport<T, CAPACITY>
|
||||
where
|
||||
T: Transport + Send + Sync,
|
||||
{
|
||||
async fn reconnect(&mut self) -> io::Result<()> {
|
||||
match self.state {
|
||||
// If not authenticated, we simply perform a raw reconnect
|
||||
State::NotAuthenticated => Reconnectable::reconnect(&mut self.inner).await,
|
||||
|
||||
// If authenticated, we perform a reconnect followed by re-authentication using our
|
||||
// previously-derived key to skip the need to do another authentication
|
||||
State::Authenticated { key, .. } => {
|
||||
Reconnectable::reconnect(&mut self.inner).await?;
|
||||
|
||||
todo!("do handshake with key");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,106 +0,0 @@
|
||||
use super::{Interest, Ready, Reconnectable, TypedTransport};
|
||||
use async_trait::async_trait;
|
||||
use std::{io, sync::Mutex};
|
||||
use tokio::sync::mpsc::{
|
||||
self,
|
||||
error::{TryRecvError, TrySendError},
|
||||
};
|
||||
|
||||
/// Represents a [`TypedTransport`] of data across the network that uses tokio's mpsc [`Sender`]
|
||||
/// and [`Receiver`] underneath.
|
||||
///
|
||||
/// [`Sender`]: mpsc::Sender
|
||||
/// [`Receiver`]: mpsc::Receiver
|
||||
#[derive(Debug)]
|
||||
pub struct InmemoryTypedTransport<T, U> {
|
||||
tx: mpsc::Sender<T>,
|
||||
rx: Mutex<mpsc::Receiver<U>>,
|
||||
}
|
||||
|
||||
impl<T, U> InmemoryTypedTransport<T, U> {
|
||||
pub fn new(tx: mpsc::Sender<T>, rx: mpsc::Receiver<U>) -> Self {
|
||||
Self {
|
||||
tx,
|
||||
rx: Mutex::new(rx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a pair of connected transports using `buffer` as maximum
|
||||
/// channel capacity for each
|
||||
pub fn pair(buffer: usize) -> (InmemoryTypedTransport<T, U>, InmemoryTypedTransport<U, T>) {
|
||||
let (t_tx, t_rx) = mpsc::channel(buffer);
|
||||
let (u_tx, u_rx) = mpsc::channel(buffer);
|
||||
(
|
||||
InmemoryTypedTransport::new(t_tx, u_rx),
|
||||
InmemoryTypedTransport::new(u_tx, t_rx),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> Reconnectable for InmemoryTypedTransport<T, U>
|
||||
where
|
||||
T: Send,
|
||||
U: Send,
|
||||
{
|
||||
/// Once the underlying channels have closed, there is no way for this transport to
|
||||
/// re-establish those channels; therefore, reconnecting will always fail with
|
||||
/// [`ErrorKind::Unsupported`]
|
||||
///
|
||||
/// [`ErrorKind::Unsupported`]: io::ErrorKind::Unsupported
|
||||
async fn reconnect(&mut self) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> TypedTransport for InmemoryTypedTransport<T, U>
|
||||
where
|
||||
T: Send,
|
||||
U: Send,
|
||||
{
|
||||
type Input = U;
|
||||
type Output = T;
|
||||
|
||||
fn try_read(&self) -> io::Result<Option<Self::Input>> {
|
||||
match self.rx.lock().unwrap().try_recv() {
|
||||
Ok(x) => Ok(Some(x)),
|
||||
Err(TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
Err(TryRecvError::Disconnected) => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_write(&self, value: Self::Output) -> io::Result<()> {
|
||||
match self.tx.try_send(value) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
Err(TrySendError::Closed(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn ready(&self, interest: Interest) -> io::Result<Ready> {
|
||||
let mut status = Ready::EMPTY;
|
||||
|
||||
if interest.is_readable() {
|
||||
// TODO: Replace `self.is_rx_closed()` with `self.rx.is_closed()` once the tokio issue
|
||||
// is resolved that adds `is_closed` to the `mpsc::Receiver`
|
||||
//
|
||||
// See https://github.com/tokio-rs/tokio/issues/4638
|
||||
status |= if self.is_rx_closed() && self.buf.lock().unwrap().is_none() {
|
||||
Ready::READ_CLOSED
|
||||
} else {
|
||||
Ready::READABLE
|
||||
};
|
||||
}
|
||||
|
||||
if interest.is_writable() {
|
||||
status |= if self.tx.is_closed() {
|
||||
Ready::WRITE_CLOSED
|
||||
} else {
|
||||
Ready::WRITABLE
|
||||
};
|
||||
}
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue