Unfinished authentication support -- and we're pegging cpu again probably still from our loops on read and write frame

pull/146/head
Chip Senkbeil 2 years ago
parent 6dcb943422
commit fd684185cb
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,76 +1,7 @@
use super::{FramedTransport, HeapSecretKey, Reconnectable, Transport};
use async_trait::async_trait;
use std::{
collections::HashMap,
io,
ops::{Deref, DerefMut},
};
mod authenticator;
mod data;
pub use data::*;
/// Internal state for a singular transport
#[derive(Clone, Debug)]
enum State {
/// Transport is not authenticated and has not begun the process of authenticating
NotAuthenticated,
/// Transport is in the state of currently authenticating, either by issuing challenges or
/// responding with answers to challenges
Authenticating,
/// Transport has finished authenticating successfully
Authenticated {
/// Unique key that marks the transport as authenticated for use in shortcutting
/// authentication when a transport needs to reconnect. This is NOT the key used for
/// encryption and is instead meant to be shared (secretly) between client-server that are
/// aware of a previously-successful authentication.
key: HeapSecretKey,
},
}
/// Represents a stateful authenticator that is capable of performing authentication with
/// another [`Authenticator`] by communicating through a [`FramedTransport`].
///
/// ### Details
///
/// The authenticator manages a mapping of `ClientId` -> `Key` upon successful authentication which
/// can be used to verify re-authentication without needing to perform full authentication again.
/// This is particularly useful in re-connecting a `FramedTransport` post-handshake after a network
/// disruption.
#[derive(Clone, Debug)]
pub struct Authenticator {
authenticated: HashMap<String, State>,
}
mod handler;
impl Authenticator {
pub fn new() -> Self {
Self {
authenticated: HashMap::new(),
}
}
/// Performs authentication with the other side, moving the state to be authenticated.
pub async fn authenticate<T, const CAPACITY: usize>(
&mut self,
transport: &mut FramedTransport<T, CAPACITY>,
authentication: Authentication,
) -> io::Result<()> {
if self.is_authenticated() {
return Ok(());
}
todo!();
}
/// Clears out any tracked clients
pub fn clear(&mut self) {
self.authenticated.clear();
}
}
impl Default for Authenticator {
fn default() -> Self {
Self::new()
}
}
pub use authenticator::*;
pub use data::*;
pub use handler::*;

@ -0,0 +1,186 @@
use super::{data::*, AuthHandler};
use crate::{utils, FramedTransport, Transport};
use async_trait::async_trait;
use log::*;
use std::{collections::HashMap, io};
/// Represents an interface for authenticating or submitting challenges for authentication.
#[async_trait]
pub trait Authenticator: Send {
/// Performs authentication by leveraging the `handler` for any received challenge.
async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()>;
/// Issues a challenge and returns the answers to the `questions` asked.
async fn challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>>;
/// Requests verification of some `kind` and `text`, returning true if passed verification.
async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool>;
/// Reports information with no response expected.
async fn info(&mut self, text: String) -> io::Result<()>;
/// Reports an error occurred during authentication, consuming the authenticator since no more
/// challenges should be issued.
async fn error(self, kind: AuthErrorKind, text: String) -> io::Result<()>;
/// Reports that the authentication has finished successfully, consuming the authenticator
/// since no more challenges should be issued.
async fn finished(self) -> io::Result<()>;
}
/// Wraps a [`FramedTransport`] in order to perform challenge-based communication through the
/// transport to authenticate it. The authenticator is capable of conducting challenges or
/// leveraging an [`AuthHandler`] to process challenges.
pub struct FramedAuthenticator<'a, T: Send, const CAPACITY: usize> {
transport: &'a mut FramedTransport<T, CAPACITY>,
}
impl<'a, T: Send, const CAPACITY: usize> FramedAuthenticator<'a, T, CAPACITY> {
pub fn new(transport: &'a mut FramedTransport<T, CAPACITY>) -> Self {
Self { transport }
}
}
macro_rules! write_frame {
($transport:expr, $data:expr) => {{
$transport
.write_frame(utils::serialize_to_vec(&$data)?)
.await?
}};
}
macro_rules! next_frame_as {
($transport:expr, $type:ident, $variant:ident) => {{
match { next_frame_as!($transport, $type) } {
$type::$variant(x) => x,
x => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Unexpected frame: {x:?}"),
))
}
}
}};
($transport:expr, $type:ident) => {{
let frame = $transport.read_frame().await?.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
})?;
utils::deserialize_from_slice::<$type>(frame.as_item())?
}};
}
#[async_trait]
impl<'a, T, const CAPACITY: usize> Authenticator for FramedAuthenticator<'a, T, CAPACITY>
where
T: Transport + Send + Sync,
{
/// Performs authentication by leveraging the `handler` for any received challenge.
async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> {
loop {
match next_frame_as!(self.transport, AuthRequest) {
AuthRequest::Challenge(x) => {
let answers = handler.on_challenge(x.questions, x.options).await?;
write_frame!(
self.transport,
AuthResponse::Challenge(AuthChallengeResponse { answers })
);
}
AuthRequest::Verify(x) => {
let valid = handler.on_verify(x.kind, x.text).await?;
write_frame!(
self.transport,
AuthResponse::Verify(AuthVerifyResponse { valid })
);
}
AuthRequest::Info(x) => {
handler.on_info(x.text).await?;
}
AuthRequest::Error(x) => {
let kind = x.kind;
let text = x.text;
handler.on_error(kind, &text).await?;
return Err(match kind {
AuthErrorKind::FailedChallenge => io::Error::new(
io::ErrorKind::PermissionDenied,
format!("Failed challenge: {text}"),
),
AuthErrorKind::FailedVerification => io::Error::new(
io::ErrorKind::PermissionDenied,
format!("Failed verification: {text}"),
),
AuthErrorKind::Unknown => {
io::Error::new(io::ErrorKind::Other, format!("Unknown error: {text}"))
}
});
}
AuthRequest::Finished => return Ok(()),
}
}
}
/// Issues a challenge and returns the answers to the `questions` asked.
async fn challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>> {
trace!(
"Authenticator::challenge(questions = {:?}, options = {:?})",
questions,
options
);
write_frame!(
self.transport,
AuthRequest::from(AuthChallengeRequest { questions, options })
);
let response = next_frame_as!(self.transport, AuthResponse, Challenge);
Ok(response.answers)
}
/// Requests verification of some `kind` and `text`, returning true if passed verification.
async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
trace!(
"Authenticator::verify(kind = {:?}, text = {:?})",
kind,
text
);
write_frame!(
self.transport,
AuthRequest::from(AuthVerifyRequest { kind, text })
);
let response = next_frame_as!(self.transport, AuthResponse, Verify);
Ok(response.valid)
}
/// Reports information with no response expected.
async fn info(&mut self, text: String) -> io::Result<()> {
trace!("Authenticator::info(text = {:?})", text);
write_frame!(self.transport, AuthRequest::from(AuthInfo { text }));
Ok(())
}
/// Reports an error occurred during authentication, consuming the authenticator since no more
/// challenges should be issued.
async fn error(self, kind: AuthErrorKind, text: String) -> io::Result<()> {
trace!("Authenticator::error(kind = {:?}, text = {:?})", kind, text);
write_frame!(self.transport, AuthRequest::from(AuthError { kind, text }));
Ok(())
}
/// Reports that the authentication has finished successfully, consuming the authenticator
/// since no more challenges should be issued.
async fn finished(self) -> io::Result<()> {
trace!("Authenticator::finished()");
write_frame!(self.transport, AuthRequest::Finished);
Ok(())
}
}

@ -1,60 +1,26 @@
use async_trait::async_trait;
use derive_more::Display;
use derive_more::{Display, From};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, io};
/// Interface for a handler of authentication requests
#[async_trait]
pub trait AuthHandler {
/// Callback when a challenge is received, returning answers to the given questions.
async fn on_challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>>;
/// Callback when a verification request is received, returning true if approvided or false if
/// unapproved.
async fn on_verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool>;
/// Callback when authentication is finished and no more requests will be received
async fn on_done(&mut self) -> io::Result<()> {
Ok(())
}
/// Callback when information is received. To fail, return an error from this function.
#[allow(unused_variables)]
async fn on_info(&mut self, text: String) -> io::Result<()> {
Ok(())
}
/// Callback when an error is received. To fail, return an error from this function.
async fn on_error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
Err(match kind {
AuthErrorKind::FailedChallenge => io::Error::new(
io::ErrorKind::PermissionDenied,
format!("Failed challenge: {text}"),
),
AuthErrorKind::FailedVerification => io::Error::new(
io::ErrorKind::PermissionDenied,
format!("Failed verification: {text}"),
),
AuthErrorKind::Unknown => {
io::Error::new(io::ErrorKind::Other, format!("Unknown error: {text}"))
}
})
}
}
use std::collections::HashMap;
/// Represents authentication messages that act as initiators such as providing
/// a challenge, verifying information, presenting information, or highlighting an error
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, From, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthRequest {
/// Issues a challenge to be answered
Challenge(AuthChallengeRequest),
/// Requests verification of some text
Verify(AuthVerifyRequest),
/// Reports some information
Info(AuthInfo),
/// Reports an error occurrred
Error(AuthError),
/// Indicates that the authentication is finished
Finished,
}
/// Represents a challenge comprising a series of questions to be presented
@ -86,10 +52,13 @@ pub struct AuthError {
/// Represents authentication messages that are responses to auth requests such
/// as answers to challenges or verifying information
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, From, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthResponse {
/// Contains answers to challenge request
Challenge(AuthChallengeResponse),
/// Contains response to a verification request
Verify(AuthVerifyResponse),
}
@ -113,6 +82,11 @@ pub enum AuthVerifyKind {
/// An ask to verify the host such as with SSH
#[display(fmt = "host")]
Host,
/// When the verification is unknown (happens when other side is unaware of the kind)
#[display(fmt = "unknown")]
#[serde(other)]
Unknown,
}
/// Represents a single question in a challenge

@ -0,0 +1,64 @@
use super::data::*;
use async_trait::async_trait;
use std::{collections::HashMap, io};
/// Interface for a handler of authentication requests
#[async_trait]
pub trait AuthHandler {
/// Callback when a challenge is received, returning answers to the given questions.
async fn on_challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>>;
/// Callback when a verification request is received, returning true if approvided or false if
/// unapproved.
async fn on_verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool>;
/// Callback when authentication is finished and no more requests will be received
async fn on_finished(&mut self) -> io::Result<()> {
Ok(())
}
/// Callback when information is received. To fail, return an error from this function.
#[allow(unused_variables)]
async fn on_info(&mut self, text: String) -> io::Result<()> {
Ok(())
}
/// Callback when an error is received. Regardless of the result returned, this will terminate
/// the authenticator. In the situation where a custom error would be preferred, have this
/// callback return an error.
#[allow(unused_variables)]
async fn on_error(&mut self, kind: AuthErrorKind, text: &str) -> io::Result<()> {
Ok(())
}
}
#[async_trait]
impl<H: AuthHandler + Send> AuthHandler for &mut H {
async fn on_challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>> {
AuthHandler::on_challenge(self, questions, options).await
}
async fn on_verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
AuthHandler::on_verify(self, kind, text).await
}
async fn on_finished(&mut self) -> io::Result<()> {
AuthHandler::on_finished(self).await
}
async fn on_info(&mut self, text: String) -> io::Result<()> {
AuthHandler::on_info(self, text).await
}
async fn on_error(&mut self, kind: AuthErrorKind, text: &str) -> io::Result<()> {
AuthHandler::on_error(self, kind, text).await
}
}

@ -1,7 +1,4 @@
use crate::{
FramedTransport, Interest, Reconnectable, Request, StatefulFramedTransport, Transport,
UntypedResponse,
};
use crate::{FramedTransport, Interest, Reconnectable, Request, Transport, UntypedResponse};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -41,8 +38,13 @@ where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
/// Initializes a client using the provided [`FramedTransport`]
pub fn new<V, const CAPACITY: usize>(transport: FramedTransport<V, CAPACITY>) -> Self
/// Creates a client using the provided [`FramedTransport`].
///
/// ### Note
///
/// It is assumed that the provided transport has performed any necessary handshake and is
/// fully authenticated.
pub fn new<V, const CAPACITY: usize>(mut transport: FramedTransport<V, CAPACITY>) -> Self
where
V: Transport + Send + Sync + 'static,
{
@ -52,16 +54,9 @@ where
let (reconnect_tx, mut reconnect_rx) = mpsc::channel::<oneshot::Sender<io::Result<()>>>(1);
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
let mut transport = StatefulFramedTransport::new(transport);
// Start a task that continually checks for responses and delivers them using the
// post office
let task = tokio::spawn(async move {
transport
.authenticate()
.await
.expect("Failed to authenticate with the remote server");
loop {
let ready = tokio::select! {
_ = shutdown_rx.recv() => {

@ -1,5 +1,5 @@
mod any;
mod auth;
pub mod auth;
mod client;
mod id;
mod listener;
@ -10,7 +10,6 @@ mod transport;
mod utils;
pub use any::*;
pub use auth::*;
pub use client::*;
pub use id::*;
pub use listener::*;

@ -1,3 +1,4 @@
use crate::auth::Authenticator;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::io;
@ -40,11 +41,21 @@ pub trait Server: Send {
ServerConfig::default()
}
/// Invoked upon a new connection becoming established, which provides a mutable reference to
/// the data created for the connection. This can be useful in performing some additional
/// initialization on the data prior to it being used anywhere else.
#[allow(unused_variables)]
async fn on_accept(&self, local_data: &mut Self::LocalData) {}
/// Invoked upon a new connection becoming established.
///
/// ### Note
///
/// This can be useful in performing some additional initialization on the connection's local
/// data prior to it being used anywhere else.
///
/// Additionally, the context contains an authenticator which can be used to issue challenges
/// to the connection to validate its access.
async fn on_accept<A: Authenticator>(
&self,
ctx: ConnectionCtx<'_, A, Self::LocalData>,
) -> io::Result<()> {
ctx.authenticator.finished().await
}
/// Invoked upon receiving a request from a client. The server should process this
/// request, which can be found in `ctx`, and send one or more replies in response.

@ -1,4 +1,4 @@
use crate::{ConnectionId, Request, ServerReply};
use crate::{auth::Authenticator, ConnectionId, Request, ServerReply};
use std::sync::Arc;
/// Represents contextual information for working with an inbound request
@ -15,3 +15,18 @@ pub struct ServerCtx<RequestData, ResponseData, LocalData> {
/// Reference to the connection's local data
pub local_data: Arc<LocalData>,
}
/// Represents contextual information for working with an inbound connection
pub struct ConnectionCtx<'a, A, D>
where
A: Authenticator,
{
/// Unique identifer associated with the connection
pub connection_id: ConnectionId,
/// Authenticator to use to issue challenges to the connection to ensure it is valid
pub authenticator: A,
/// Reference to the connection's local data
pub local_data: &'a mut D,
}

@ -1,7 +1,7 @@
use crate::{
utils::Timer, ConnectionId, FramedTransport, GenericServerRef, Interest, Listener, Response,
Server, ServerConnection, ServerCtx, ServerRef, ServerReply, ServerState, Shutdown, Transport,
UntypedRequest,
auth::FramedAuthenticator, utils::Timer, ConnectionCtx, ConnectionId, FramedTransport,
GenericServerRef, Interest, Listener, Response, Server, ServerConnection, ServerCtx, ServerRef,
ServerReply, ServerState, Shutdown, Transport, UntypedRequest,
};
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -148,21 +148,12 @@ where
timer.lock().await.stop();
}
// Create some default data for the new connection and pass it
// to the callback prior to processing new requests
let local_data = {
let mut data = S::LocalData::default();
server.on_accept(&mut data).await;
Arc::new(data)
};
connection.task = Some(
ConnectionTask {
id: connection_id,
server,
state: Arc::downgrade(&state),
transport,
local_data,
shutdown_timer: shutdown_timer
.as_ref()
.map(Arc::downgrade)
@ -179,21 +170,20 @@ where
}
}
struct ConnectionTask<S, T, D> {
struct ConnectionTask<S, T> {
id: ConnectionId,
server: Arc<S>,
state: Weak<ServerState>,
transport: T,
local_data: Arc<D>,
shutdown_timer: Weak<Mutex<Timer<()>>>,
}
impl<S, T, D> ConnectionTask<S, T, D>
impl<S, T> ConnectionTask<S, T>
where
S: Server<LocalData = D> + Sync + 'static,
S: Server + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
D: Default + Send + Sync + 'static,
S::LocalData: Default + Send + Sync + 'static,
T: Transport + Send + Sync + 'static,
{
pub fn spawn(self) -> JoinHandle<()> {
@ -213,6 +203,24 @@ where
return;
}
// Create local data for the connection and then process it as well as perform
// authentication and any other tasks on first connecting
let mut local_data = S::LocalData::default();
if let Err(x) = self
.server
.on_accept(ConnectionCtx {
connection_id,
authenticator: FramedAuthenticator::new(&mut transport),
local_data: &mut local_data,
})
.await
{
error!("[Conn {connection_id}] Accepting connection failed: {x}");
return;
}
let local_data = Arc::new(local_data);
loop {
let ready = transport
.ready(Interest::READABLE | Interest::WRITABLE)
@ -233,7 +241,7 @@ where
connection_id,
request,
reply: reply.clone(),
local_data: Arc::clone(&self.local_data),
local_data: Arc::clone(&local_data),
};
self.server.on_request(ctx).await;

@ -1,4 +1,4 @@
use crate::{ConnectionId, ServerConnection};
use crate::{ConnectionId, HeapSecretKey, ServerConnection};
use std::collections::HashMap;
use tokio::sync::RwLock;
@ -6,12 +6,16 @@ use tokio::sync::RwLock;
pub struct ServerState {
/// Mapping of connection ids to their transports
pub connections: RwLock<HashMap<ConnectionId, ServerConnection>>,
/// Mapping of connection ids to their authenticated keys
pub authenticated: RwLock<HashMap<ConnectionId, HeapSecretKey>>,
}
impl ServerState {
pub fn new() -> Self {
Self {
connections: RwLock::new(HashMap::new()),
authenticated: RwLock::new(HashMap::new()),
}
}
}

@ -218,9 +218,17 @@ where
///
/// [`ErrorKind::WriteZero`]: io::ErrorKind::WriteZero
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub fn try_write_frame<'a>(&mut self, frame: impl Into<Frame<'a>>) -> io::Result<()> {
pub fn try_write_frame<'a, F>(&mut self, frame: F) -> io::Result<()>
where
F: TryInto<Frame<'a>>,
F::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
// Encode the frame and store it in our outgoing queue
let frame = self.codec.encode(frame.into())?;
let frame = self.codec.encode(
frame
.try_into()
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?,
)?;
frame.write(&mut self.outgoing)?;
// Attempt to write everything in our queue
@ -234,7 +242,11 @@ where
/// [`try_write_frame`]: FramedTransport::try_write_frame
/// [`try_flush`]: FramedTransport::try_flush
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub async fn write_frame<'a>(&mut self, frame: impl Into<Frame<'a>>) -> io::Result<()> {
pub async fn write_frame<'a, F>(&mut self, frame: F) -> io::Result<()>
where
F: TryInto<Frame<'a>>,
F::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
self.writeable().await?;
match self.try_write_frame(frame) {

Loading…
Cancel
Save