Still working on it

pull/146/head
Chip Senkbeil 2 years ago
parent 7cf4c39ac8
commit cf48d48b03
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -184,20 +184,13 @@ mod tests {
use super::*;
use crate::data::ChangeKind;
use crate::DistantClient;
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response,
TypedAsyncRead, TypedAsyncWrite,
};
use distant_net::{Client, FramedTransport, InmemoryTransport, Response};
use std::sync::Arc;
use tokio::sync::Mutex;
fn make_session() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
DistantClient,
) {
let (t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
(t1, Client::new(writer, reader).unwrap())
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::test_pair(100);
(t1, Client::new(t2))
}
#[tokio::test]

@ -1,122 +0,0 @@
use derive_more::Display;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
mod client;
pub use client::*;
mod handshake;
pub use handshake::*;
mod server;
pub use server::*;
/// Represents authentication messages that can be sent over the wire
///
/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will
/// cause deserialization to fail
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
pub enum Auth {
/// Represents a request to perform an authentication handshake,
/// providing the public key and salt from one side in order to
/// derive the shared key
#[serde(rename = "auth_handshake")]
Handshake {
/// Bytes of the public key
#[serde(with = "serde_bytes")]
public_key: PublicKeyBytes,
/// Randomly generated salt
#[serde(with = "serde_bytes")]
salt: Salt,
},
/// Represents the bytes of an encrypted message
///
/// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`]
#[serde(rename = "auth_msg")]
Msg {
#[serde(with = "serde_bytes")]
encrypted_payload: Vec<u8>,
},
}
/// Represents authentication messages that act as initiators such as providing
/// a challenge, verifying information, presenting information, or highlighting an error
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthRequest {
/// Represents a challenge comprising a series of questions to be presented
Challenge {
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
},
/// Represents an ask to verify some information
Verify { kind: AuthVerifyKind, text: String },
/// Represents some information to be presented
Info { text: String },
/// Represents some error that occurred
Error { kind: AuthErrorKind, text: String },
}
/// Represents authentication messages that are responses to auth requests such
/// as answers to challenges or verifying information
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthResponse {
/// Represents the answers to a previously-asked challenge
Challenge { answers: Vec<String> },
/// Represents the answer to a previously-asked verify
Verify { valid: bool },
}
/// Represents the type of verification being requested
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AuthVerifyKind {
/// An ask to verify the host such as with SSH
#[display(fmt = "host")]
Host,
}
/// Represents a single question in a challenge
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthQuestion {
/// The text of the question
pub text: String,
/// Any options information specific to a particular auth domain
/// such as including a username and instructions for SSH authentication
pub options: HashMap<String, String>,
}
impl AuthQuestion {
/// Creates a new question without any options data
pub fn new(text: impl Into<String>) -> Self {
Self {
text: text.into(),
options: HashMap::new(),
}
}
}
/// Represents the type of error encountered during authentication
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthErrorKind {
/// When the answer(s) to a challenge do not pass authentication
FailedChallenge,
/// When verification during authentication fails
/// (e.g. a host is not allowed or blocked)
FailedVerification,
/// When the error is unknown
Unknown,
}

@ -1,817 +0,0 @@
use crate::{
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client,
Codec, Handshake, XChaCha20Poly1305Codec,
};
use bytes::BytesMut;
use log::*;
use std::{collections::HashMap, io};
pub struct AuthClient {
inner: Client<Auth, Auth>,
codec: Option<XChaCha20Poly1305Codec>,
jit_handshake: bool,
}
impl From<Client<Auth, Auth>> for AuthClient {
fn from(client: Client<Auth, Auth>) -> Self {
Self {
inner: client,
codec: None,
jit_handshake: false,
}
}
}
impl AuthClient {
/// Sends a request to the server to establish an encrypted connection
pub async fn handshake(&mut self) -> io::Result<()> {
let handshake = Handshake::default();
let response = self
.inner
.send(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
})
.await?;
match response.payload {
Auth::Handshake { public_key, salt } => {
let key = handshake.handshake(public_key, salt)?;
self.codec.replace(XChaCha20Poly1305Codec::new(&key));
Ok(())
}
Auth::Msg { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected encrypted message during handshake",
)),
}
}
/// Perform a handshake only if jit is enabled and no handshake has succeeded yet
async fn jit_handshake(&mut self) -> io::Result<()> {
if self.will_jit_handshake() && !self.is_ready() {
self.handshake().await
} else {
Ok(())
}
}
/// Returns true if client has successfully performed a handshake
/// and is ready to communicate with the server
pub fn is_ready(&self) -> bool {
self.codec.is_some()
}
/// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a
/// request in the scenario where the client has not already performed a handshake
#[inline]
pub fn will_jit_handshake(&self) -> bool {
self.jit_handshake
}
/// Sets the jit flag on this client with `true` indicating that this client will perform a
/// handshake just-in-time (JIT) prior to making a request in the scenario where the client has
/// not already performed a handshake
#[inline]
pub fn set_jit_handshake(&mut self, flag: bool) {
self.jit_handshake = flag;
}
/// Provides a challenge to the server and returns the answers to the questions
/// asked by the client
pub async fn challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>> {
trace!(
"AuthClient::challenge(questions = {:?}, options = {:?})",
questions,
options
);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Challenge { questions, options };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
match response.payload {
Auth::Msg { encrypted_payload } => {
match self.decrypt_and_deserialize(&encrypted_payload)? {
AuthResponse::Challenge { answers } => Ok(answers),
AuthResponse::Verify { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected verify response during challenge",
)),
}
}
Auth::Handshake { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected handshake during challenge",
)),
}
}
/// Provides a verification request to the server and returns whether or not
/// the server approved
pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Verify { kind, text };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
match response.payload {
Auth::Msg { encrypted_payload } => {
match self.decrypt_and_deserialize(&encrypted_payload)? {
AuthResponse::Verify { valid } => Ok(valid),
AuthResponse::Challenge { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected challenge response during verify",
)),
}
}
Auth::Handshake { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected handshake during verify",
)),
}
}
/// Provides information to the server to use as it pleases with no response expected
pub async fn info(&mut self, text: String) -> io::Result<()> {
trace!("AuthClient::info(text = {:?})", text);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Info { text };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
self.inner.fire(Auth::Msg { encrypted_payload }).await
}
/// Provides an error to the server to use as it pleases with no response expected
pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Error { kind, text };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
self.inner.fire(Auth::Msg { encrypted_payload }).await
}
fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result<Vec<u8>> {
let codec = self.codec.as_mut().ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (client encrypt message)",
)
})?;
let mut encryped_payload = BytesMut::new();
let payload = utils::serialize_to_vec(payload)?;
codec.encode(&payload, &mut encryped_payload)?;
Ok(encryped_payload.freeze().to_vec())
}
fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result<AuthResponse> {
let codec = self.codec.as_mut().ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (client decrypt message)",
)
})?;
let mut payload = BytesMut::from(payload);
match codec.decode(&mut payload)? {
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
None => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite};
use serde::{de::DeserializeOwned, Serialize};
const TIMEOUT_MILLIS: u64 = 100;
#[tokio::test]
async fn handshake_should_fail_if_get_unexpected_response_from_server() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move { client.handshake().await });
// Get the request, but send a bad response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Handshake { .. } => server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: Vec::new(),
},
))
.await
.unwrap(),
_ => panic!("Server received unexpected payload"),
}
let result = task.await.unwrap();
assert!(result.is_err(), "Handshake succeeded unexpectedly")
}
#[tokio::test]
async fn challenge_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await });
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Challenge succeeded unexpectedly")
}
#[tokio::test]
async fn challenge_should_fail_if_receive_wrong_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.challenge(
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key2".to_string(), "value2".to_string())]
.into_iter()
.collect(),
},
],
vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
)
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a challenge request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Challenge { .. } => {
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Verify { valid: true },
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Challenge succeeded unexpectedly")
}
#[tokio::test]
async fn challenge_should_return_answers_received_from_server() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.challenge(
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key2".to_string(), "value2".to_string())]
.into_iter()
.collect(),
},
],
vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
)
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a challenge request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Challenge { questions, options } => {
assert_eq!(
questions,
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key2".to_string(), "value2".to_string())]
.into_iter()
.collect(),
},
],
);
assert_eq!(
options,
vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
);
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Challenge {
answers: vec![
"answer1".to_string(),
"answer2".to_string(),
],
},
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
let answers = task.await.unwrap().unwrap();
assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]);
}
#[tokio::test]
async fn verify_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client
.verify(AuthVerifyKind::Host, "some text".to_string())
.await
});
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Verify succeeded unexpectedly")
}
#[tokio::test]
async fn verify_should_fail_if_receive_wrong_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.verify(AuthVerifyKind::Host, "some text".to_string())
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a verify request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Verify { .. } => {
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Challenge {
answers: Vec::new(),
},
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Verify succeeded unexpectedly")
}
#[tokio::test]
async fn verify_should_return_valid_bool_received_from_server() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.verify(AuthVerifyKind::Host, "some text".to_string())
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a challenge request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Verify { kind, text } => {
assert_eq!(kind, AuthVerifyKind::Host);
assert_eq!(text, "some text");
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Verify { valid: true },
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
let valid = task.await.unwrap().unwrap();
assert!(valid, "Got verify response, but valid was set incorrectly");
}
#[tokio::test]
async fn info_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move { client.info("some text".to_string()).await });
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Info succeeded unexpectedly")
}
#[tokio::test]
async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client.info("some text".to_string()).await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a request
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Info { text } => {
assert_eq!(text, "some text");
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn error_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
.await
});
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Error succeeded unexpectedly")
}
#[tokio::test]
async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a request
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Error { kind, text } => {
assert_eq!(kind, AuthErrorKind::FailedChallenge);
assert_eq!(text, "some text");
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
task.await.unwrap().unwrap();
}
async fn wait_ms(ms: u64) {
use std::time::Duration;
tokio::time::sleep(Duration::from_millis(ms)).await;
}
fn serialize_and_encrypt<T: Serialize>(
codec: &mut XChaCha20Poly1305Codec,
payload: &T,
) -> io::Result<Vec<u8>> {
let mut encryped_payload = BytesMut::new();
let payload = utils::serialize_to_vec(payload)?;
codec.encode(&payload, &mut encryped_payload)?;
Ok(encryped_payload.freeze().to_vec())
}
fn decrypt_and_deserialize<T: DeserializeOwned>(
codec: &mut XChaCha20Poly1305Codec,
payload: &[u8],
) -> io::Result<T> {
let mut payload = BytesMut::from(payload);
match codec.decode(&mut payload)? {
Some(payload) => utils::deserialize_from_slice::<T>(&payload),
None => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
}
}
}

@ -1,654 +0,0 @@
use crate::{
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec,
Handshake, Server, ServerCtx, XChaCha20Poly1305Codec,
};
use async_trait::async_trait;
use bytes::BytesMut;
use log::*;
use std::{collections::HashMap, io};
use tokio::sync::RwLock;
/// Type signature for a dynamic on_challenge function
pub type AuthChallengeFn =
dyn Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync;
/// Type signature for a dynamic on_verify function
pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync;
/// Type signature for a dynamic on_info function
pub type AuthInfoFn = dyn Fn(String) + Send + Sync;
/// Type signature for a dynamic on_error function
pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync;
/// Represents an [`AuthServer`] where all handlers are stored on the heap
pub type HeapAuthServer =
AuthServer<Box<AuthChallengeFn>, Box<AuthVerifyFn>, Box<AuthInfoFn>, Box<AuthErrorFn>>;
/// Server that handles authentication
pub struct AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
where
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
InfoFn: Fn(String) + Send + Sync,
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
{
pub on_challenge: ChallengeFn,
pub on_verify: VerifyFn,
pub on_info: InfoFn,
pub on_error: ErrorFn,
}
#[async_trait]
impl<ChallengeFn, VerifyFn, InfoFn, ErrorFn> Server
for AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
where
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
InfoFn: Fn(String) + Send + Sync,
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
{
type Request = Auth;
type Response = Auth;
type LocalData = RwLock<Option<XChaCha20Poly1305Codec>>;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
let reply = ctx.reply.clone();
match ctx.request.payload {
Auth::Handshake { public_key, salt } => {
trace!(
"Received handshake request from client, request id = {}",
ctx.request.id
);
let handshake = Handshake::default();
match handshake.handshake(public_key, salt) {
Ok(key) => {
ctx.local_data
.write()
.await
.replace(XChaCha20Poly1305Codec::new(&key));
trace!(
"Sending reciprocal handshake to client, response origin id = {}",
ctx.request.id
);
if let Err(x) = reply
.send(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
})
.await
{
error!("[Conn {}] {}", ctx.connection_id, x);
}
}
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
}
Auth::Msg {
ref encrypted_payload,
} => {
trace!(
"Received auth msg, encrypted payload size = {}",
encrypted_payload.len()
);
// Attempt to decrypt the message so we can understand what to do
let request = match ctx.local_data.write().await.as_mut() {
Some(codec) => {
let mut payload = BytesMut::from(encrypted_payload.as_slice());
match codec.decode(&mut payload) {
Ok(Some(payload)) => {
utils::deserialize_from_slice::<AuthRequest>(&payload)
}
Ok(None) => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
Err(x) => Err(x),
}
}
None => Err(io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (server decrypt message)",
)),
};
let response = match request {
Ok(request) => match request {
AuthRequest::Challenge { questions, options } => {
trace!("Received challenge request");
trace!("questions = {:?}", questions);
trace!("options = {:?}", options);
let answers = (self.on_challenge)(questions, options);
AuthResponse::Challenge { answers }
}
AuthRequest::Verify { kind, text } => {
trace!("Received verify request");
trace!("kind = {:?}", kind);
trace!("text = {:?}", text);
let valid = (self.on_verify)(kind, text);
AuthResponse::Verify { valid }
}
AuthRequest::Info { text } => {
trace!("Received info request");
trace!("text = {:?}", text);
(self.on_info)(text);
return;
}
AuthRequest::Error { kind, text } => {
trace!("Received error request");
trace!("kind = {:?}", kind);
trace!("text = {:?}", text);
(self.on_error)(kind, text);
return;
}
},
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
};
// Serialize and encrypt the message before sending it back
let encrypted_payload = match ctx.local_data.write().await.as_mut() {
Some(codec) => {
let mut encrypted_payload = BytesMut::new();
// Convert the response into bytes for us to send back
match utils::serialize_to_vec(&response) {
Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) {
Ok(_) => Ok(encrypted_payload.freeze().to_vec()),
Err(x) => Err(x),
},
Err(x) => Err(x),
}
}
None => Err(io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (server encrypt messaage)",
)),
};
match encrypted_payload {
Ok(encrypted_payload) => {
if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
InmemoryTypedTransport, IntoSplit, MpscListener, Request, Response, ServerExt, ServerRef,
TypedAsyncRead, TypedAsyncWrite,
};
use tokio::sync::mpsc;
const TIMEOUT_MILLIS: u64 = 100;
#[tokio::test]
async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() {
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send an encrypted message before establishing a handshake
t.write(Request::new(Auth::Msg {
encrypted_payload: Vec::new(),
}))
.await
.expect("Failed to send request to server");
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_reply_to_handshake_request_with_new_public_key_and_salt() {
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Wait for a handshake response
tokio::select! {
x = t.read() => {
let response = x.expect("Request failed").expect("Response missing");
match response.payload {
Auth::Handshake { .. } => {},
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
}
}
_ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"),
}
}
#[tokio::test]
async fn should_not_reply_if_receive_invalid_encrypted_msg() {
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send a bad chunk of data
let _codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: vec![1, 2, 3, 4],
}))
.await
.unwrap();
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */
move |questions, options| {
tx.try_send((questions, options)).unwrap();
vec!["answer1".to_string(), "answer2".to_string()]
},
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Challenge {
questions: vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
},
],
options: vec![("hello".to_string(), "world".to_string())]
.into_iter()
.collect(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let (questions, options) = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(
questions,
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
}
]
);
assert_eq!(
options,
vec![("hello".to_string(), "world".to_string())]
.into_iter()
.collect()
);
// Wait for a response and verify that it matches what we expect
tokio::select! {
x = t.read() => {
let response = x.expect("Request failed").expect("Response missing");
match response.payload {
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthResponse::Challenge { answers } =>
assert_eq!(
answers,
vec!["answer1".to_string(), "answer2".to_string()]
),
_ => panic!("Got wrong response for verify"),
}
},
}
}
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */
move |kind, text| {
tx.try_send((kind, text)).unwrap();
true
},
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Verify {
kind: AuthVerifyKind::Host,
text: "some text".to_string(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(kind, AuthVerifyKind::Host);
assert_eq!(text, "some text");
// Wait for a response and verify that it matches what we expect
tokio::select! {
x = t.read() => {
let response = x.expect("Request failed").expect("Response missing");
match response.payload {
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthResponse::Verify { valid } =>
assert!(valid, "Got verify, but valid was wrong"),
_ => panic!("Got wrong response for verify"),
}
},
}
}
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_info_request() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */
move |text| {
tx.try_send(text).unwrap();
},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Info {
text: "some text".to_string(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let text = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(text, "some text");
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_error_request() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */
move |kind, text| {
tx.try_send((kind, text)).unwrap();
},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Error {
kind: AuthErrorKind::FailedChallenge,
text: "some text".to_string(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(kind, AuthErrorKind::FailedChallenge);
assert_eq!(text, "some text");
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
async fn wait_ms(ms: u64) {
use std::time::Duration;
tokio::time::sleep(Duration::from_millis(ms)).await;
}
fn serialize_and_encrypt(
codec: &mut XChaCha20Poly1305Codec,
payload: &AuthRequest,
) -> io::Result<Vec<u8>> {
let mut encryped_payload = BytesMut::new();
let payload = utils::serialize_to_vec(payload)?;
codec.encode(&payload, &mut encryped_payload)?;
Ok(encryped_payload.freeze().to_vec())
}
fn decrypt_and_deserialize(
codec: &mut XChaCha20Poly1305Codec,
payload: &[u8],
) -> io::Result<AuthResponse> {
let mut payload = BytesMut::from(payload);
match codec.decode(&mut payload)? {
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
None => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
}
}
async fn spawn_auth_server<ChallengeFn, VerifyFn, InfoFn, ErrorFn>(
on_challenge: ChallengeFn,
on_verify: VerifyFn,
on_info: InfoFn,
on_error: ErrorFn,
) -> io::Result<(
InmemoryTypedTransport<Request<Auth>, Response<Auth>>,
Box<dyn ServerRef>,
)>
where
ChallengeFn:
Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync + 'static,
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static,
InfoFn: Fn(String) + Send + Sync + 'static,
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static,
{
let server = AuthServer {
on_challenge,
on_verify,
on_info,
on_error,
};
// Create a test listener where we will forward a connection
let (tx, listener) = MpscListener::channel(100);
// Make bounded transport pair and send off one of them to act as our connection
let (transport, connection) =
InmemoryTypedTransport::<Request<Auth>, Response<Auth>>::pair(100);
tx.send(connection.into_split())
.await
.expect("Failed to feed listener a connection");
let server = server.start(listener)?;
Ok((transport, server))
}
}

@ -1,4 +1,7 @@
use crate::{FramedTransport, Interest, Reconnectable, Request, Transport, UntypedResponse};
use crate::{
FramedTransport, Interest, Reconnectable, Request, StatefulFramedTransport, Transport,
UntypedResponse,
};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -35,23 +38,30 @@ pub struct Client<T, U> {
impl<T, U> Client<T, U>
where
T: Send + Sync + Serialize,
U: Send + Sync + DeserializeOwned,
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
/// Initializes a client using the provided [`FramedTransport`]
pub fn new<V, const CAPACITY: usize>(transport: FramedTransport<V, CAPACITY>) -> Self
where
V: Transport + Send + Sync,
V: Transport + Send + Sync + 'static,
{
let post_office = Arc::new(PostOffice::default());
let weak_post_office = Arc::downgrade(&post_office);
let (tx, mut rx) = mpsc::channel::<Request<T>>(1);
let (reconnect_tx, reconnect_rx) = mpsc::channel::<oneshot::Sender<io::Result<()>>>(1);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let (reconnect_tx, mut reconnect_rx) = mpsc::channel::<oneshot::Sender<io::Result<()>>>(1);
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
let mut transport = StatefulFramedTransport::new(transport);
// Start a task that continually checks for responses and delivers them using the
// post office
let task = tokio::spawn(async move {
transport
.authenticate()
.await
.expect("Failed to authenticate with the remote server");
loop {
let ready = tokio::select! {
_ = shutdown_rx.recv() => {
@ -59,7 +69,7 @@ where
}
cb = reconnect_rx.recv() => {
if let Some(cb) = cb {
cb.send(Reconnectable::reconnect(&mut transport).await);
let _ = cb.send(Reconnectable::reconnect(&mut transport).await);
continue;
} else {
break;
@ -166,6 +176,11 @@ impl<T, U> Client<T, U> {
self.task.abort();
}
/// Signal for the client to shutdown its connection cleanly
pub async fn shutdown(&self) -> bool {
self.shutdown_tx.send(()).await.is_ok()
}
/// Returns true if client's underlying event processing has finished/terminated
pub fn is_finished(&self) -> bool {
self.task.is_finished()

@ -203,7 +203,7 @@ mod tests {
}
let frame = t2.try_read_frame().unwrap().unwrap();
let _req: Request<u8> = Request::from_slice(&frame.as_item()).unwrap();
let _req: Request<u8> = Request::from_slice(frame.as_item()).unwrap();
}
#[tokio::test]
@ -219,6 +219,6 @@ mod tests {
}
let frame = t2.try_read_frame().unwrap().unwrap();
let _req: Request<u8> = Request::from_slice(&frame.as_item()).unwrap();
let _req: Request<u8> = Request::from_slice(frame.as_item()).unwrap();
}
}

@ -1,4 +1,3 @@
use crate::FramedTransport;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::io;
@ -41,15 +40,6 @@ pub trait Server: Send {
ServerConfig::default()
}
/// Invoked to facilitate a handshake between server and client upon establishing a connection,
/// returning an updated [`FramedTransport`] once the handshake is complete
async fn on_handshake<T: Send, const CAPACITY: usize>(
&self,
transport: FramedTransport<T, CAPACITY>,
) -> io::Result<FramedTransport<T, CAPACITY>> {
Ok(transport)
}
/// Invoked upon a new connection becoming established, which provides a mutable reference to
/// the data created for the connection. This can be useful in performing some additional
/// initialization on the data prior to it being used anywhere else.

@ -360,7 +360,7 @@ mod tests {
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (mut transport, connection) = InmemoryTransport::pair(100);
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");

@ -43,7 +43,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, PlainCodec, Request, ServerCtx, UnixSocketClientExt};
use crate::{Client, Request, ServerCtx, UnixSocketClientExt};
use tempfile::NamedTempFile;
pub struct TestServer;
@ -75,7 +75,7 @@ mod tests {
.await
.expect("Failed to start Unix socket server");
let mut client: Client<String, String> = Client::connect(server.path(), PlainCodec)
let mut client: Client<String, String> = Client::connect(server.path())
.await
.expect("Client failed to connect");

@ -558,7 +558,7 @@ mod tests {
#[test]
fn try_read_frame_should_return_would_block_if_fails_to_read_frame_before_blocking() {
// Should fail if immediately blocks
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
@ -572,7 +572,7 @@ mod tests {
);
// Should fail if not read enough bytes before blocking
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |cnt| cnt == 1),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
@ -588,7 +588,7 @@ mod tests {
#[test]
fn try_read_frame_should_return_error_if_encountered_error_with_reading_bytes() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
@ -604,7 +604,7 @@ mod tests {
#[test]
fn try_read_frame_should_return_error_if_encountered_error_during_decode() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |_| false),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
@ -626,7 +626,7 @@ mod tests {
data.freeze()
};
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_read: Box::new(move |buf| {
buf[..data.len()].copy_from_slice(data.as_ref());
@ -644,7 +644,7 @@ mod tests {
fn try_read_frame_should_keep_reading_until_a_frame_is_found() {
const STEP_SIZE: usize = Frame::HEADER_SIZE + 7;
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_read: simulate_try_read(
vec![Frame::new(b"hello world"), Frame::new(b"test hello")],
@ -668,7 +668,7 @@ mod tests {
#[test]
fn try_write_frame_should_return_would_block_if_fails_to_write_frame_before_blocking() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
@ -689,7 +689,7 @@ mod tests {
#[test]
fn try_write_frame_should_return_error_if_encountered_error_with_writing_bytes() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
@ -708,7 +708,7 @@ mod tests {
#[test]
fn try_write_frame_should_return_error_if_encountered_error_during_encode() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(|buf| Ok(buf.len())),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
@ -728,7 +728,7 @@ mod tests {
#[test]
fn try_write_frame_should_write_entire_frame_if_possible() {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(move |buf| {
let len = buf.len();
@ -754,7 +754,7 @@ mod tests {
fn try_write_frame_should_write_any_prior_queued_bytes_before_writing_next_frame() {
const STEP_SIZE: usize = Frame::HEADER_SIZE + 5;
let (tx, rx) = std::sync::mpsc::sync_channel(10);
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(move |buf| {
static mut CNT: usize = 0;
@ -809,7 +809,7 @@ mod tests {
#[test]
fn try_flush_should_return_error_if_try_write_fails() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
@ -830,7 +830,7 @@ mod tests {
#[test]
fn try_flush_should_return_error_if_try_write_returns_0_bytes_written() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(|_| Ok(0)),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
@ -851,7 +851,7 @@ mod tests {
#[test]
fn try_flush_should_be_noop_if_nothing_to_flush() {
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
@ -868,7 +868,7 @@ mod tests {
fn try_flush_should_continually_call_try_write_until_outgoing_buffer_is_empty() {
const STEP_SIZE: usize = 5;
let (tx, rx) = std::sync::mpsc::sync_channel(10);
let mut transport = FramedTransport::new(
let mut transport = FramedTransport::<_>::new(
TestTransport {
f_try_write: Box::new(move |buf| {
let len = std::cmp::min(STEP_SIZE, buf.len());

@ -125,35 +125,34 @@ impl fmt::Debug for EncryptionCodec {
impl Codec for EncryptionCodec {
fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
let frame = match self {
let nonce_bytes = self.generate_nonce_bytes();
Ok(match self {
Self::XChaCha20Poly1305 { cipher } => {
use chacha20poly1305::{aead::Aead, XNonce};
let nonce_bytes = self.generate_nonce_bytes();
let item = frame.into_item();
let nonce = XNonce::from_slice(&nonce_bytes);
// Encrypt the frame's item as our ciphertext
let ciphertext = cipher
.encrypt(nonce, frame.as_item())
.encrypt(nonce, item.as_ref())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
// Frame is now comprised of the nonce and ciphertext in sequence
let mut frame = Frame::new(&nonce_bytes);
// Start our frame with the nonce at the beginning
let mut frame = Frame::from(nonce_bytes);
frame.extend(ciphertext);
frame
}
};
Ok(frame.into_owned())
})
}
fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
if frame.len() <= self.nonce_size() {
let nonce_size = self.nonce_size();
if frame.len() <= nonce_size {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Frame cannot have length less than {}",
self.nonce_size() + 1
),
format!("Frame cannot have length less than {}", nonce_size + 1),
));
}
@ -162,9 +161,9 @@ impl Codec for EncryptionCodec {
let item = match self {
Self::XChaCha20Poly1305 { cipher } => {
use chacha20poly1305::{aead::Aead, XNonce};
let nonce = XNonce::from_slice(&frame.as_item()[..self.nonce_size()]);
let nonce = XNonce::from_slice(&frame.as_item()[..nonce_size]);
cipher
.decrypt(nonce, &frame.as_item()[self.nonce_size()..])
.decrypt(nonce, &frame.as_item()[nonce_size..])
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?
}
};

@ -5,6 +5,9 @@ use std::{
ops::{Deref, DerefMut},
};
mod auth;
pub use auth::*;
/// Internal state for our transport
#[derive(Clone, Debug)]
enum State {
@ -45,8 +48,6 @@ impl<T, const CAPACITY: usize> StatefulFramedTransport<T, CAPACITY> {
}
/// Performs authentication with the other side, moving the state to be authenticated.
///
/// NOTE: Does nothing if already authenticated!
pub async fn authenticate(&mut self) -> io::Result<()> {
if self.is_authenticated() {
return Ok(());

@ -0,0 +1,152 @@
use async_trait::async_trait;
use derive_more::Display;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, io};
/// Interface for a handler of authentication requests
#[async_trait]
pub trait AuthHandler {
/// Callback when a challenge is received, returning answers to the given questions.
async fn on_challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>>;
/// Callback when a verification request is received, returning true if approvided or false if
/// unapproved.
async fn on_verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool>;
/// Callback when authentication is finished and no more requests will be received
async fn on_done(&mut self) -> io::Result<()> {
Ok(())
}
/// Callback when information is received. To fail, return an error from this function.
#[allow(unused_variables)]
async fn on_info(&mut self, text: String) -> io::Result<()> {
Ok(())
}
/// Callback when an error is received. To fail, return an error from this function.
async fn on_error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
Err(match kind {
AuthErrorKind::FailedChallenge => io::Error::new(
io::ErrorKind::PermissionDenied,
format!("Failed challenge: {text}"),
),
AuthErrorKind::FailedVerification => io::Error::new(
io::ErrorKind::PermissionDenied,
format!("Failed verification: {text}"),
),
AuthErrorKind::Unknown => {
io::Error::new(io::ErrorKind::Other, format!("Unknown error: {text}"))
}
})
}
}
/// Represents authentication messages that act as initiators such as providing
/// a challenge, verifying information, presenting information, or highlighting an error
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthRequest {
Challenge(AuthChallengeRequest),
Verify(AuthVerifyRequest),
Info(AuthInfo),
Error(AuthError),
}
/// Represents a challenge comprising a series of questions to be presented
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthChallengeRequest {
pub questions: Vec<AuthQuestion>,
pub options: HashMap<String, String>,
}
/// Represents an ask to verify some information
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthVerifyRequest {
pub kind: AuthVerifyKind,
pub text: String,
}
/// Represents some information to be presented
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthInfo {
pub text: String,
}
/// Represents some error that occurred
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthError {
pub kind: AuthErrorKind,
pub text: String,
}
/// Represents authentication messages that are responses to auth requests such
/// as answers to challenges or verifying information
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthResponse {
Challenge(AuthChallengeResponse),
Verify(AuthVerifyResponse),
}
/// Represents the answers to a previously-asked challenge
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthChallengeResponse {
pub answers: Vec<String>,
}
/// Represents the answer to a previously-asked verify
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthVerifyResponse {
pub valid: bool,
}
/// Represents the type of verification being requested
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AuthVerifyKind {
/// An ask to verify the host such as with SSH
#[display(fmt = "host")]
Host,
}
/// Represents a single question in a challenge
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthQuestion {
/// The text of the question
pub text: String,
/// Any options information specific to a particular auth domain
/// such as including a username and instructions for SSH authentication
pub options: HashMap<String, String>,
}
impl AuthQuestion {
/// Creates a new question without any options data
pub fn new(text: impl Into<String>) -> Self {
Self {
text: text.into(),
options: HashMap::new(),
}
}
}
/// Represents the type of error encountered during authentication
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthErrorKind {
/// When the answer(s) to a challenge do not pass authentication
FailedChallenge,
/// When verification during authentication fails
/// (e.g. a host is not allowed or blocked)
FailedVerification,
/// When the error is unknown
Unknown,
}
Loading…
Cancel
Save