Move around some net structs and impls to their own modules, add some client tests

pull/38/head
Chip Senkbeil 3 years ago
parent aded5fd16f
commit f12c3428eb
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -17,6 +17,10 @@ pub const MAX_PIPE_CHUNK_SIZE: usize = 1024;
/// Represents the length of the salt to use for encryption
pub const SALT_LEN: usize = 16;
/// Test-only constant for channel buffer size
#[cfg(test)]
pub const TEST_BUFFER_SIZE: usize = 100;
lazy_static::lazy_static! {
pub static ref TIMEOUT_STR: String = TIMEOUT.to_string();

@ -225,3 +225,79 @@ where
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{
constants::TEST_BUFFER_SIZE,
data::{RequestData, ResponseData},
};
use crate::net::InmemoryStream;
use orion::aead::SecretKey;
use std::time::Duration;
const TEST_TENANT: &str = "test-tenant";
/// Makes a connected pair of transports with matching auth and crypt keys
pub fn make_transport_pair() -> (Transport<InmemoryStream>, Transport<InmemoryStream>) {
let auth_key = Arc::new(SecretKey::default());
let crypt_key = Arc::new(SecretKey::default());
let (a, b) = InmemoryStream::pair(TEST_BUFFER_SIZE);
let a = Transport::new(a, Some(Arc::clone(&auth_key)), Arc::clone(&crypt_key));
let b = Transport::new(b, Some(auth_key), crypt_key);
(a, b)
}
#[tokio::test]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = make_transport_pair();
let mut client = Client::inner_connect(t1).await.unwrap();
let req = Request::new(TEST_TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(
TEST_TENANT,
Some(req.id),
vec![ResponseData::ProcEntries {
entries: Vec::new(),
}],
);
let (actual, _) = tokio::join!(client.send(req), t2.send(res.clone()));
match actual {
Ok(actual) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (t1, mut t2) = make_transport_pair();
let mut client = Client::inner_connect(t1).await.unwrap();
let req = Request::new(TEST_TENANT, vec![RequestData::ProcList {}]);
match client.send_timeout(req, Duration::from_millis(30)).await {
Err(TransportError::IoError(x)) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
x => panic!("Unexpected response: {:?}", x),
}
let req = t2.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.tenant, TEST_TENANT);
}
#[tokio::test]
async fn fire_should_send_request_and_not_wait_for_response() {
let (t1, mut t2) = make_transport_pair();
let mut client = Client::inner_connect(t1).await.unwrap();
let req = Request::new(TEST_TENANT, vec![RequestData::ProcList {}]);
match client.fire(req).await {
Ok(_) => {}
x => panic!("Unexpected response: {:?}", x),
}
let req = t2.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.tenant, TEST_TENANT);
}
}

@ -1,5 +1,5 @@
mod transport;
pub use transport::{DataStream, Transport, TransportError, TransportReadHalf, TransportWriteHalf};
pub use transport::*;
mod client;
pub use client::Client;

@ -0,0 +1,123 @@
use super::DataStream;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{self, AsyncRead, AsyncWrite, ReadBuf},
sync::mpsc,
};
/// Represents a data stream comprised of two inmemory channels
pub struct InmemoryStream {
incoming: InmemoryStreamReadHalf,
outgoing: InmemoryStreamWriteHalf,
}
impl InmemoryStream {
pub fn new(incoming: mpsc::Receiver<Vec<u8>>, outgoing: mpsc::Sender<Vec<u8>>) -> Self {
Self {
incoming: InmemoryStreamReadHalf(incoming),
outgoing: InmemoryStreamWriteHalf(outgoing),
}
}
/// Returns (incoming_tx, outgoing_rx, stream)
pub fn make(buffer: usize) -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);
(
incoming_tx,
outgoing_rx,
Self::new(incoming_rx, outgoing_tx),
)
}
/// Returns pair of streams that are connected such that one sends to the other and
/// vice versa
pub fn pair(buffer: usize) -> (Self, Self) {
let (tx, rx, stream) = Self::make(buffer);
(stream, Self::new(rx, tx))
}
}
impl AsyncRead for InmemoryStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.incoming).poll_read(cx, buf)
}
}
impl AsyncWrite for InmemoryStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.outgoing).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.outgoing).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.outgoing).poll_shutdown(cx)
}
}
pub struct InmemoryStreamReadHalf(mpsc::Receiver<Vec<u8>>);
impl AsyncRead for InmemoryStreamReadHalf {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.0.poll_recv(cx).map(|x| match x {
Some(x) => {
buf.put_slice(&x);
Ok(())
}
None => Ok(()),
})
}
}
pub struct InmemoryStreamWriteHalf(mpsc::Sender<Vec<u8>>);
impl AsyncWrite for InmemoryStreamWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.0.try_send(buf.to_vec()) {
Ok(_) => Poll::Ready(Ok(buf.len())),
Err(_) => Poll::Ready(Ok(0)),
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.poll_flush(cx)
}
}
impl DataStream for InmemoryStream {
type Read = InmemoryStreamReadHalf;
type Write = InmemoryStreamWriteHalf;
fn to_connection_tag(&self) -> String {
String::from("test-stream")
}
fn into_split(self) -> (Self::Read, Self::Write) {
(self.incoming, self.outgoing)
}
}

@ -1,4 +1,4 @@
use crate::core::{constants::SALT_LEN, session::Session};
use crate::core::constants::SALT_LEN;
use codec::DistantCodec;
use derive_more::{Display, Error, From};
use futures::SinkExt;
@ -12,16 +12,25 @@ use orion::{
pwhash::Password,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{marker::Unpin, net::SocketAddr, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{self, tcp, TcpStream},
};
use std::{marker::Unpin, sync::Arc};
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio_stream::StreamExt;
use tokio_util::codec::{Framed, FramedRead, FramedWrite};
mod codec;
mod inmemory;
pub use inmemory::*;
mod tcp;
pub use tcp::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
#[derive(Debug, Display, Error, From)]
pub enum TransportError {
#[from(ignore)]
@ -48,37 +57,6 @@ pub trait DataStream: AsyncRead + AsyncWrite + Unpin {
fn into_split(self) -> (Self::Read, Self::Write);
}
impl DataStream for TcpStream {
type Read = tcp::OwnedReadHalf;
type Write = tcp::OwnedWriteHalf;
fn to_connection_tag(&self) -> String {
self.peer_addr()
.map(|addr| format!("{}", addr))
.unwrap_or_else(|_| String::from("--"))
}
fn into_split(self) -> (Self::Read, Self::Write) {
TcpStream::into_split(self)
}
}
#[cfg(unix)]
impl DataStream for net::UnixStream {
type Read = net::unix::OwnedReadHalf;
type Write = net::unix::OwnedWriteHalf;
fn to_connection_tag(&self) -> String {
self.peer_addr()
.map(|addr| format!("{:?}", addr))
.unwrap_or_else(|_| String::from("--"))
}
fn into_split(self) -> (Self::Read, Self::Write) {
net::UnixStream::into_split(self)
}
}
/// Sends some data across the wire, waiting for it to completely send
macro_rules! send {
($conn:expr, $crypt_key:expr, $auth_key:expr, $data:expr) => {
@ -180,6 +158,15 @@ impl<T> Transport<T>
where
T: DataStream,
{
/// Creates a new instance of the transport, wrapping the stream in a `Framed<T, DistantCodec>`
pub fn new(stream: T, auth_key: Option<Arc<SecretKey>>, crypt_key: Arc<SecretKey>) -> Self {
Self {
conn: Framed::new(stream, DistantCodec),
auth_key,
crypt_key,
}
}
/// Takes a pre-existing connection and performs a handshake to build out the encryption key
/// with the remote system, returning a transport ready to communicate with the other side
///
@ -267,14 +254,12 @@ where
}
/// Sends some data across the wire, waiting for it to completely send
#[allow(dead_code)]
pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
send!(self.conn, self.crypt_key, self.auth_key.as_ref(), data).await
}
/// Receives some data from out on the wire, waiting until it's available,
/// returning none if the transport is now closed
#[allow(dead_code)]
pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
recv!(self.conn, self.crypt_key, self.auth_key).await
}
@ -308,42 +293,6 @@ where
}
}
impl Transport<TcpStream> {
/// Establishes a connection using the provided session and performs a handshake to establish
/// means of encryption, returning a transport ready to communicate with the other side
///
/// TCP Streams will always use a session's authentication key
pub async fn connect(session: Session) -> io::Result<Self> {
let stream = TcpStream::connect(session.to_socket_addr().await?).await?;
Self::from_handshake(stream, Some(Arc::new(session.auth_key))).await
}
/// Returns the address of the peer the transport is connected to
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.conn.get_ref().peer_addr()
}
}
#[cfg(unix)]
impl Transport<net::UnixStream> {
/// Establishes a connection using the provided session and performs a handshake to establish
/// means of encryption, returning a transport ready to communicate with the other side
///
/// Takes an optional authentication key
pub async fn connect(
path: impl AsRef<std::path::Path>,
auth_key: Option<Arc<SecretKey>>,
) -> io::Result<Self> {
let stream = net::UnixStream::connect(path.as_ref()).await?;
Self::from_handshake(stream, auth_key).await
}
/// Returns the address of the peer the transport is connected to
pub fn peer_addr(&self) -> io::Result<net::unix::SocketAddr> {
self.conn.get_ref().peer_addr()
}
}
/// Represents a transport of data out to the network
pub struct TransportWriteHalf<T>
where
@ -398,133 +347,13 @@ where
#[cfg(test)]
mod tests {
use super::*;
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use tokio::{io::ReadBuf, sync::mpsc};
pub const TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE: usize = 100;
/// Represents a data stream comprised of two inmemory buffers of data
pub struct TestDataStream {
incoming: TestDataStreamReadHalf,
outgoing: TestDataStreamWriteHalf,
}
impl TestDataStream {
pub fn new(incoming: mpsc::Receiver<Vec<u8>>, outgoing: mpsc::Sender<Vec<u8>>) -> Self {
Self {
incoming: TestDataStreamReadHalf(incoming),
outgoing: TestDataStreamWriteHalf(outgoing),
}
}
/// Returns (incoming_tx, outgoing_rx, stream)
pub fn make() -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
let (incoming_tx, incoming_rx) = mpsc::channel(TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE);
let (outgoing_tx, outgoing_rx) = mpsc::channel(TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE);
(
incoming_tx,
outgoing_rx,
Self::new(incoming_rx, outgoing_tx),
)
}
/// Returns pair of streams that are connected such that one sends to the other and
/// vice versa
pub fn pair() -> (Self, Self) {
let (tx, rx, stream) = Self::make();
(stream, Self::new(rx, tx))
}
}
impl AsyncRead for TestDataStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.incoming).poll_read(cx, buf)
}
}
impl AsyncWrite for TestDataStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.outgoing).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.outgoing).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.outgoing).poll_shutdown(cx)
}
}
pub struct TestDataStreamReadHalf(mpsc::Receiver<Vec<u8>>);
impl AsyncRead for TestDataStreamReadHalf {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.0.poll_recv(cx).map(|x| match x {
Some(x) => {
buf.put_slice(&x);
Ok(())
}
None => Ok(()),
})
}
}
pub struct TestDataStreamWriteHalf(mpsc::Sender<Vec<u8>>);
impl AsyncWrite for TestDataStreamWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.0.try_send(buf.to_vec()) {
Ok(_) => Poll::Ready(Ok(buf.len())),
Err(_) => Poll::Ready(Ok(0)),
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.poll_flush(cx)
}
}
impl DataStream for TestDataStream {
type Read = TestDataStreamReadHalf;
type Write = TestDataStreamWriteHalf;
fn to_connection_tag(&self) -> String {
String::from("test-stream")
}
fn into_split(self) -> (Self::Read, Self::Write) {
(self.incoming, self.outgoing)
}
}
use crate::core::constants::TEST_BUFFER_SIZE;
use std::io;
#[tokio::test]
async fn transport_from_handshake_should_fail_if_connection_reached_eof() {
// Cause nothing left incoming to stream by _
let (_, mut rx, stream) = TestDataStream::make();
let (_, mut rx, stream) = InmemoryStream::make(TEST_BUFFER_SIZE);
let result = Transport::from_handshake(stream, None).await;
// Verify that a salt and public key were sent out first
@ -548,7 +377,7 @@ mod tests {
#[tokio::test]
async fn transport_from_handshake_should_fail_if_response_data_is_too_small() {
let (tx, _rx, stream) = TestDataStream::make();
let (tx, _rx, stream) = InmemoryStream::make(TEST_BUFFER_SIZE);
// Need SALT + PUB KEY where salt has a defined size; so, at least 1 larger than salt
// would succeed, whereas we are providing exactly salt, which will fail
@ -569,7 +398,7 @@ mod tests {
#[tokio::test]
async fn transport_from_handshake_should_fail_if_bad_foreign_public_key_received() {
let (tx, _rx, stream) = TestDataStream::make();
let (tx, _rx, stream) = InmemoryStream::make(TEST_BUFFER_SIZE);
// Send {SALT LEN}{SALT}{PUB KEY} where public key is bad;
// normally public key bytes would be {LEN}{KEY} where len is first byte;
@ -600,7 +429,7 @@ mod tests {
#[tokio::test]
async fn transport_should_be_able_to_send_encrypted_data_to_other_side_to_decrypt() {
let (src, dst) = TestDataStream::pair();
let (src, dst) = InmemoryStream::pair(TEST_BUFFER_SIZE);
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
let (src, dst) = tokio::join!(
@ -623,7 +452,7 @@ mod tests {
#[tokio::test]
async fn transport_should_be_able_to_sign_and_validate_signature_if_auth_key_included() {
let (src, dst) = TestDataStream::pair();
let (src, dst) = InmemoryStream::pair(TEST_BUFFER_SIZE);
let auth_key = Arc::new(SecretKey::default());
@ -648,7 +477,7 @@ mod tests {
#[tokio::test]
async fn transport_receive_should_fail_if_auth_key_differs_from_other_end() {
let (src, dst) = TestDataStream::pair();
let (src, dst) = InmemoryStream::pair(TEST_BUFFER_SIZE);
// Make two transports with different auth keys
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
@ -669,7 +498,7 @@ mod tests {
#[tokio::test]
async fn transport_receive_should_fail_if_has_auth_key_while_sender_did_not_use_one() {
let (src, dst) = TestDataStream::pair();
let (src, dst) = InmemoryStream::pair(TEST_BUFFER_SIZE);
// Make two transports with different auth keys
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
@ -695,7 +524,7 @@ mod tests {
#[tokio::test]
async fn transport_receive_should_fail_if_has_no_auth_key_while_sender_used_one() {
let (src, dst) = TestDataStream::pair();
let (src, dst) = InmemoryStream::pair(TEST_BUFFER_SIZE);
// Make two transports with different auth keys
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!

@ -0,0 +1,41 @@
use super::{DataStream, Transport};
use crate::core::session::Session;
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io,
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
},
};
impl DataStream for TcpStream {
type Read = OwnedReadHalf;
type Write = OwnedWriteHalf;
fn to_connection_tag(&self) -> String {
self.peer_addr()
.map(|addr| format!("{}", addr))
.unwrap_or_else(|_| String::from("--"))
}
fn into_split(self) -> (Self::Read, Self::Write) {
TcpStream::into_split(self)
}
}
impl Transport<TcpStream> {
/// Establishes a connection using the provided session and performs a handshake to establish
/// means of encryption, returning a transport ready to communicate with the other side
///
/// TCP Streams will always use a session's authentication key
pub async fn connect(session: Session) -> io::Result<Self> {
let stream = TcpStream::connect(session.to_socket_addr().await?).await?;
Self::from_handshake(stream, Some(Arc::new(session.auth_key))).await
}
/// Returns the address of the peer the transport is connected to
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.conn.get_ref().peer_addr()
}
}

@ -0,0 +1,44 @@
use super::{DataStream, Transport};
use orion::aead::SecretKey;
use std::sync::Arc;
use tokio::{
io,
net::{
unix::{OwnedReadHalf, OwnedWriteHalf, SocketAddr},
UnixStream,
},
};
impl DataStream for UnixStream {
type Read = OwnedReadHalf;
type Write = OwnedWriteHalf;
fn to_connection_tag(&self) -> String {
self.peer_addr()
.map(|addr| format!("{:?}", addr))
.unwrap_or_else(|_| String::from("--"))
}
fn into_split(self) -> (Self::Read, Self::Write) {
UnixStream::into_split(self)
}
}
impl Transport<UnixStream> {
/// Establishes a connection using the provided session and performs a handshake to establish
/// means of encryption, returning a transport ready to communicate with the other side
///
/// Takes an optional authentication key
pub async fn connect(
path: impl AsRef<std::path::Path>,
auth_key: Option<Arc<SecretKey>>,
) -> io::Result<Self> {
let stream = UnixStream::connect(path.as_ref()).await?;
Self::from_handshake(stream, auth_key).await
}
/// Returns the address of the peer the transport is connected to
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.conn.get_ref().peer_addr()
}
}

@ -1,6 +1,7 @@
mod cli;
mod core;
pub use self::core::{data, net};
use log::error;
/// Represents an error that can be converted into an exit code

Loading…
Cancel
Save