mirror of https://github.com/chipsenkbeil/distant
Still working on it
parent
7cf4c39ac8
commit
cf48d48b03
@ -1,122 +0,0 @@
|
||||
use derive_more::Display;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
mod client;
|
||||
pub use client::*;
|
||||
|
||||
mod handshake;
|
||||
pub use handshake::*;
|
||||
|
||||
mod server;
|
||||
pub use server::*;
|
||||
|
||||
/// Represents authentication messages that can be sent over the wire
|
||||
///
|
||||
/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will
|
||||
/// cause deserialization to fail
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
|
||||
pub enum Auth {
|
||||
/// Represents a request to perform an authentication handshake,
|
||||
/// providing the public key and salt from one side in order to
|
||||
/// derive the shared key
|
||||
#[serde(rename = "auth_handshake")]
|
||||
Handshake {
|
||||
/// Bytes of the public key
|
||||
#[serde(with = "serde_bytes")]
|
||||
public_key: PublicKeyBytes,
|
||||
|
||||
/// Randomly generated salt
|
||||
#[serde(with = "serde_bytes")]
|
||||
salt: Salt,
|
||||
},
|
||||
|
||||
/// Represents the bytes of an encrypted message
|
||||
///
|
||||
/// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`]
|
||||
#[serde(rename = "auth_msg")]
|
||||
Msg {
|
||||
#[serde(with = "serde_bytes")]
|
||||
encrypted_payload: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Represents authentication messages that act as initiators such as providing
|
||||
/// a challenge, verifying information, presenting information, or highlighting an error
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthRequest {
|
||||
/// Represents a challenge comprising a series of questions to be presented
|
||||
Challenge {
|
||||
questions: Vec<AuthQuestion>,
|
||||
options: HashMap<String, String>,
|
||||
},
|
||||
|
||||
/// Represents an ask to verify some information
|
||||
Verify { kind: AuthVerifyKind, text: String },
|
||||
|
||||
/// Represents some information to be presented
|
||||
Info { text: String },
|
||||
|
||||
/// Represents some error that occurred
|
||||
Error { kind: AuthErrorKind, text: String },
|
||||
}
|
||||
|
||||
/// Represents authentication messages that are responses to auth requests such
|
||||
/// as answers to challenges or verifying information
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthResponse {
|
||||
/// Represents the answers to a previously-asked challenge
|
||||
Challenge { answers: Vec<String> },
|
||||
|
||||
/// Represents the answer to a previously-asked verify
|
||||
Verify { valid: bool },
|
||||
}
|
||||
|
||||
/// Represents the type of verification being requested
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum AuthVerifyKind {
|
||||
/// An ask to verify the host such as with SSH
|
||||
#[display(fmt = "host")]
|
||||
Host,
|
||||
}
|
||||
|
||||
/// Represents a single question in a challenge
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AuthQuestion {
|
||||
/// The text of the question
|
||||
pub text: String,
|
||||
|
||||
/// Any options information specific to a particular auth domain
|
||||
/// such as including a username and instructions for SSH authentication
|
||||
pub options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl AuthQuestion {
|
||||
/// Creates a new question without any options data
|
||||
pub fn new(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: text.into(),
|
||||
options: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type of error encountered during authentication
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AuthErrorKind {
|
||||
/// When the answer(s) to a challenge do not pass authentication
|
||||
FailedChallenge,
|
||||
|
||||
/// When verification during authentication fails
|
||||
/// (e.g. a host is not allowed or blocked)
|
||||
FailedVerification,
|
||||
|
||||
/// When the error is unknown
|
||||
Unknown,
|
||||
}
|
@ -1,817 +0,0 @@
|
||||
use crate::{
|
||||
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client,
|
||||
Codec, Handshake, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
|
||||
pub struct AuthClient {
|
||||
inner: Client<Auth, Auth>,
|
||||
codec: Option<XChaCha20Poly1305Codec>,
|
||||
jit_handshake: bool,
|
||||
}
|
||||
|
||||
impl From<Client<Auth, Auth>> for AuthClient {
|
||||
fn from(client: Client<Auth, Auth>) -> Self {
|
||||
Self {
|
||||
inner: client,
|
||||
codec: None,
|
||||
jit_handshake: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthClient {
|
||||
/// Sends a request to the server to establish an encrypted connection
|
||||
pub async fn handshake(&mut self) -> io::Result<()> {
|
||||
let handshake = Handshake::default();
|
||||
|
||||
let response = self
|
||||
.inner
|
||||
.send(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let key = handshake.handshake(public_key, salt)?;
|
||||
self.codec.replace(XChaCha20Poly1305Codec::new(&key));
|
||||
Ok(())
|
||||
}
|
||||
Auth::Msg { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected encrypted message during handshake",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a handshake only if jit is enabled and no handshake has succeeded yet
|
||||
async fn jit_handshake(&mut self) -> io::Result<()> {
|
||||
if self.will_jit_handshake() && !self.is_ready() {
|
||||
self.handshake().await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if client has successfully performed a handshake
|
||||
/// and is ready to communicate with the server
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.codec.is_some()
|
||||
}
|
||||
|
||||
/// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a
|
||||
/// request in the scenario where the client has not already performed a handshake
|
||||
#[inline]
|
||||
pub fn will_jit_handshake(&self) -> bool {
|
||||
self.jit_handshake
|
||||
}
|
||||
|
||||
/// Sets the jit flag on this client with `true` indicating that this client will perform a
|
||||
/// handshake just-in-time (JIT) prior to making a request in the scenario where the client has
|
||||
/// not already performed a handshake
|
||||
#[inline]
|
||||
pub fn set_jit_handshake(&mut self, flag: bool) {
|
||||
self.jit_handshake = flag;
|
||||
}
|
||||
|
||||
/// Provides a challenge to the server and returns the answers to the questions
|
||||
/// asked by the client
|
||||
pub async fn challenge(
|
||||
&mut self,
|
||||
questions: Vec<AuthQuestion>,
|
||||
options: HashMap<String, String>,
|
||||
) -> io::Result<Vec<String>> {
|
||||
trace!(
|
||||
"AuthClient::challenge(questions = {:?}, options = {:?})",
|
||||
questions,
|
||||
options
|
||||
);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Challenge { questions, options };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Challenge { answers } => Ok(answers),
|
||||
AuthResponse::Verify { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected verify response during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a verification request to the server and returns whether or not
|
||||
/// the server approved
|
||||
pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
|
||||
trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Verify { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Verify { valid } => Ok(valid),
|
||||
AuthResponse::Challenge { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected challenge response during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides information to the server to use as it pleases with no response expected
|
||||
pub async fn info(&mut self, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::info(text = {:?})", text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Info { text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
/// Provides an error to the server to use as it pleases with no response expected
|
||||
pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Error { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result<Vec<u8>> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client encrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result<AuthResponse> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client decrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
const TIMEOUT_MILLIS: u64 = 100;
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_should_fail_if_get_unexpected_response_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.handshake().await });
|
||||
|
||||
// Get the request, but send a bad response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Handshake { .. } => server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: Vec::new(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap(),
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
}
|
||||
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Handshake succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await });
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Challenge succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_fail_if_receive_wrong_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.challenge(
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Challenge { .. } => {
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Verify { valid: true },
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Challenge succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_return_answers_received_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.challenge(
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Challenge { questions, options } => {
|
||||
assert_eq!(
|
||||
questions,
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
options,
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
);
|
||||
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Challenge {
|
||||
answers: vec![
|
||||
"answer1".to_string(),
|
||||
"answer2".to_string(),
|
||||
],
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
let answers = task.await.unwrap().unwrap();
|
||||
assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Verify succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_fail_if_receive_wrong_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a verify request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Verify { .. } => {
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Challenge {
|
||||
answers: Vec::new(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Verify succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_return_valid_bool_received_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Verify { kind, text } => {
|
||||
assert_eq!(kind, AuthVerifyKind::Host);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Verify { valid: true },
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
let valid = task.await.unwrap().unwrap();
|
||||
assert!(valid, "Got verify response, but valid was set incorrectly");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.info("some text".to_string()).await });
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Info succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client.info("some text".to_string()).await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a request
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Info { text } => {
|
||||
assert_eq!(text, "some text");
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
task.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client
|
||||
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Error succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a request
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Error { kind, text } => {
|
||||
assert_eq!(kind, AuthErrorKind::FailedChallenge);
|
||||
assert_eq!(text, "some text");
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
task.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
async fn wait_ms(ms: u64) {
|
||||
use std::time::Duration;
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt<T: Serialize>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &T,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize<T: DeserializeOwned>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &[u8],
|
||||
) -> io::Result<T> {
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<T>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,654 +0,0 @@
|
||||
use crate::{
|
||||
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec,
|
||||
Handshake, Server, ServerCtx, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Type signature for a dynamic on_challenge function
|
||||
pub type AuthChallengeFn =
|
||||
dyn Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_verify function
|
||||
pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_info function
|
||||
pub type AuthInfoFn = dyn Fn(String) + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_error function
|
||||
pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync;
|
||||
|
||||
/// Represents an [`AuthServer`] where all handlers are stored on the heap
|
||||
pub type HeapAuthServer =
|
||||
AuthServer<Box<AuthChallengeFn>, Box<AuthVerifyFn>, Box<AuthInfoFn>, Box<AuthErrorFn>>;
|
||||
|
||||
/// Server that handles authentication
|
||||
pub struct AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
|
||||
where
|
||||
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
|
||||
InfoFn: Fn(String) + Send + Sync,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
|
||||
{
|
||||
pub on_challenge: ChallengeFn,
|
||||
pub on_verify: VerifyFn,
|
||||
pub on_info: InfoFn,
|
||||
pub on_error: ErrorFn,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<ChallengeFn, VerifyFn, InfoFn, ErrorFn> Server
|
||||
for AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
|
||||
where
|
||||
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
|
||||
InfoFn: Fn(String) + Send + Sync,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
|
||||
{
|
||||
type Request = Auth;
|
||||
type Response = Auth;
|
||||
type LocalData = RwLock<Option<XChaCha20Poly1305Codec>>;
|
||||
|
||||
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
|
||||
let reply = ctx.reply.clone();
|
||||
|
||||
match ctx.request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
trace!(
|
||||
"Received handshake request from client, request id = {}",
|
||||
ctx.request.id
|
||||
);
|
||||
let handshake = Handshake::default();
|
||||
match handshake.handshake(public_key, salt) {
|
||||
Ok(key) => {
|
||||
ctx.local_data
|
||||
.write()
|
||||
.await
|
||||
.replace(XChaCha20Poly1305Codec::new(&key));
|
||||
|
||||
trace!(
|
||||
"Sending reciprocal handshake to client, response origin id = {}",
|
||||
ctx.request.id
|
||||
);
|
||||
if let Err(x) = reply
|
||||
.send(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Auth::Msg {
|
||||
ref encrypted_payload,
|
||||
} => {
|
||||
trace!(
|
||||
"Received auth msg, encrypted payload size = {}",
|
||||
encrypted_payload.len()
|
||||
);
|
||||
|
||||
// Attempt to decrypt the message so we can understand what to do
|
||||
let request = match ctx.local_data.write().await.as_mut() {
|
||||
Some(codec) => {
|
||||
let mut payload = BytesMut::from(encrypted_payload.as_slice());
|
||||
match codec.decode(&mut payload) {
|
||||
Ok(Some(payload)) => {
|
||||
utils::deserialize_from_slice::<AuthRequest>(&payload)
|
||||
}
|
||||
Ok(None) => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
Err(x) => Err(x),
|
||||
}
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (server decrypt message)",
|
||||
)),
|
||||
};
|
||||
|
||||
let response = match request {
|
||||
Ok(request) => match request {
|
||||
AuthRequest::Challenge { questions, options } => {
|
||||
trace!("Received challenge request");
|
||||
trace!("questions = {:?}", questions);
|
||||
trace!("options = {:?}", options);
|
||||
|
||||
let answers = (self.on_challenge)(questions, options);
|
||||
AuthResponse::Challenge { answers }
|
||||
}
|
||||
AuthRequest::Verify { kind, text } => {
|
||||
trace!("Received verify request");
|
||||
trace!("kind = {:?}", kind);
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
let valid = (self.on_verify)(kind, text);
|
||||
AuthResponse::Verify { valid }
|
||||
}
|
||||
AuthRequest::Info { text } => {
|
||||
trace!("Received info request");
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
(self.on_info)(text);
|
||||
return;
|
||||
}
|
||||
AuthRequest::Error { kind, text } => {
|
||||
trace!("Received error request");
|
||||
trace!("kind = {:?}", kind);
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
(self.on_error)(kind, text);
|
||||
return;
|
||||
}
|
||||
},
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize and encrypt the message before sending it back
|
||||
let encrypted_payload = match ctx.local_data.write().await.as_mut() {
|
||||
Some(codec) => {
|
||||
let mut encrypted_payload = BytesMut::new();
|
||||
|
||||
// Convert the response into bytes for us to send back
|
||||
match utils::serialize_to_vec(&response) {
|
||||
Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) {
|
||||
Ok(_) => Ok(encrypted_payload.freeze().to_vec()),
|
||||
Err(x) => Err(x),
|
||||
},
|
||||
Err(x) => Err(x),
|
||||
}
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (server encrypt messaage)",
|
||||
)),
|
||||
};
|
||||
|
||||
match encrypted_payload {
|
||||
Ok(encrypted_payload) => {
|
||||
if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
InmemoryTypedTransport, IntoSplit, MpscListener, Request, Response, ServerExt, ServerRef,
|
||||
TypedAsyncRead, TypedAsyncWrite,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
const TIMEOUT_MILLIS: u64 = 100;
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send an encrypted message before establishing a handshake
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: Vec::new(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_reply_to_handshake_request_with_new_public_key_and_salt() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Wait for a handshake response
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => {},
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_not_reply_if_receive_invalid_encrypted_msg() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send a bad chunk of data
|
||||
let _codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: vec![1, 2, 3, 4],
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */
|
||||
move |questions, options| {
|
||||
tx.try_send((questions, options)).unwrap();
|
||||
vec!["answer1".to_string(), "answer2".to_string()]
|
||||
},
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Challenge {
|
||||
questions: vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
options: vec![("hello".to_string(), "world".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (questions, options) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(
|
||||
questions,
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
options,
|
||||
vec![("hello".to_string(), "world".to_string())]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
// Wait for a response and verify that it matches what we expect
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthResponse::Challenge { answers } =>
|
||||
assert_eq!(
|
||||
answers,
|
||||
vec!["answer1".to_string(), "answer2".to_string()]
|
||||
),
|
||||
_ => panic!("Got wrong response for verify"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */
|
||||
move |kind, text| {
|
||||
tx.try_send((kind, text)).unwrap();
|
||||
true
|
||||
},
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Verify {
|
||||
kind: AuthVerifyKind::Host,
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(kind, AuthVerifyKind::Host);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response and verify that it matches what we expect
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthResponse::Verify { valid } =>
|
||||
assert!(valid, "Got verify, but valid was wrong"),
|
||||
_ => panic!("Got wrong response for verify"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_info_request() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */
|
||||
move |text| {
|
||||
tx.try_send(text).unwrap();
|
||||
},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Info {
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let text = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_error_request() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */
|
||||
move |kind, text| {
|
||||
tx.try_send((kind, text)).unwrap();
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Error {
|
||||
kind: AuthErrorKind::FailedChallenge,
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(kind, AuthErrorKind::FailedChallenge);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_ms(ms: u64) {
|
||||
use std::time::Duration;
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &AuthRequest,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &[u8],
|
||||
) -> io::Result<AuthResponse> {
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_auth_server<ChallengeFn, VerifyFn, InfoFn, ErrorFn>(
|
||||
on_challenge: ChallengeFn,
|
||||
on_verify: VerifyFn,
|
||||
on_info: InfoFn,
|
||||
on_error: ErrorFn,
|
||||
) -> io::Result<(
|
||||
InmemoryTypedTransport<Request<Auth>, Response<Auth>>,
|
||||
Box<dyn ServerRef>,
|
||||
)>
|
||||
where
|
||||
ChallengeFn:
|
||||
Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync + 'static,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static,
|
||||
InfoFn: Fn(String) + Send + Sync + 'static,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static,
|
||||
{
|
||||
let server = AuthServer {
|
||||
on_challenge,
|
||||
on_verify,
|
||||
on_info,
|
||||
on_error,
|
||||
};
|
||||
|
||||
// Create a test listener where we will forward a connection
|
||||
let (tx, listener) = MpscListener::channel(100);
|
||||
|
||||
// Make bounded transport pair and send off one of them to act as our connection
|
||||
let (transport, connection) =
|
||||
InmemoryTypedTransport::<Request<Auth>, Response<Auth>>::pair(100);
|
||||
tx.send(connection.into_split())
|
||||
.await
|
||||
.expect("Failed to feed listener a connection");
|
||||
|
||||
let server = server.start(listener)?;
|
||||
Ok((transport, server))
|
||||
}
|
||||
}
|
@ -0,0 +1,152 @@
|
||||
use async_trait::async_trait;
|
||||
use derive_more::Display;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, io};
|
||||
|
||||
/// Interface for a handler of authentication requests
|
||||
#[async_trait]
|
||||
pub trait AuthHandler {
|
||||
/// Callback when a challenge is received, returning answers to the given questions.
|
||||
async fn on_challenge(
|
||||
&mut self,
|
||||
questions: Vec<AuthQuestion>,
|
||||
options: HashMap<String, String>,
|
||||
) -> io::Result<Vec<String>>;
|
||||
|
||||
/// Callback when a verification request is received, returning true if approvided or false if
|
||||
/// unapproved.
|
||||
async fn on_verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool>;
|
||||
|
||||
/// Callback when authentication is finished and no more requests will be received
|
||||
async fn on_done(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Callback when information is received. To fail, return an error from this function.
|
||||
#[allow(unused_variables)]
|
||||
async fn on_info(&mut self, text: String) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Callback when an error is received. To fail, return an error from this function.
|
||||
async fn on_error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
|
||||
Err(match kind {
|
||||
AuthErrorKind::FailedChallenge => io::Error::new(
|
||||
io::ErrorKind::PermissionDenied,
|
||||
format!("Failed challenge: {text}"),
|
||||
),
|
||||
AuthErrorKind::FailedVerification => io::Error::new(
|
||||
io::ErrorKind::PermissionDenied,
|
||||
format!("Failed verification: {text}"),
|
||||
),
|
||||
AuthErrorKind::Unknown => {
|
||||
io::Error::new(io::ErrorKind::Other, format!("Unknown error: {text}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents authentication messages that act as initiators such as providing
|
||||
/// a challenge, verifying information, presenting information, or highlighting an error
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthRequest {
|
||||
Challenge(AuthChallengeRequest),
|
||||
Verify(AuthVerifyRequest),
|
||||
Info(AuthInfo),
|
||||
Error(AuthError),
|
||||
}
|
||||
|
||||
/// Represents a challenge comprising a series of questions to be presented
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthChallengeRequest {
|
||||
pub questions: Vec<AuthQuestion>,
|
||||
pub options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Represents an ask to verify some information
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthVerifyRequest {
|
||||
pub kind: AuthVerifyKind,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
/// Represents some information to be presented
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthInfo {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
/// Represents some error that occurred
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthError {
|
||||
pub kind: AuthErrorKind,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
/// Represents authentication messages that are responses to auth requests such
|
||||
/// as answers to challenges or verifying information
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthResponse {
|
||||
Challenge(AuthChallengeResponse),
|
||||
Verify(AuthVerifyResponse),
|
||||
}
|
||||
|
||||
/// Represents the answers to a previously-asked challenge
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthChallengeResponse {
|
||||
pub answers: Vec<String>,
|
||||
}
|
||||
|
||||
/// Represents the answer to a previously-asked verify
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthVerifyResponse {
|
||||
pub valid: bool,
|
||||
}
|
||||
|
||||
/// Represents the type of verification being requested
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum AuthVerifyKind {
|
||||
/// An ask to verify the host such as with SSH
|
||||
#[display(fmt = "host")]
|
||||
Host,
|
||||
}
|
||||
|
||||
/// Represents a single question in a challenge
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AuthQuestion {
|
||||
/// The text of the question
|
||||
pub text: String,
|
||||
|
||||
/// Any options information specific to a particular auth domain
|
||||
/// such as including a username and instructions for SSH authentication
|
||||
pub options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl AuthQuestion {
|
||||
/// Creates a new question without any options data
|
||||
pub fn new(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: text.into(),
|
||||
options: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type of error encountered during authentication
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AuthErrorKind {
|
||||
/// When the answer(s) to a challenge do not pass authentication
|
||||
FailedChallenge,
|
||||
|
||||
/// When verification during authentication fails
|
||||
/// (e.g. a host is not allowed or blocked)
|
||||
FailedVerification,
|
||||
|
||||
/// When the error is unknown
|
||||
Unknown,
|
||||
}
|
Loading…
Reference in New Issue