mirror of https://github.com/chipsenkbeil/distant
Implemented broken framed logic
parent
a52fb82fbf
commit
3c7561bef8
@ -0,0 +1,84 @@
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use derive_more::{Display, Error, From};
|
||||
use tokio::io;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
/// Represents a marker to indicate the beginning of the next message
|
||||
static MSG_START: &'static [u8] = b";start;";
|
||||
|
||||
/// Represents a marker to indicate the end of the next message
|
||||
static MSG_END: &'static [u8] = b";end;";
|
||||
|
||||
#[inline]
|
||||
fn packet_size(msg_size: usize) -> usize {
|
||||
MSG_START.len() + msg_size + MSG_END.len()
|
||||
}
|
||||
|
||||
/// Possible errors that can occur during encoding and decoding
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum DistantCodecError {
|
||||
#[display(fmt = "Corrupt Marker: {:?}", _0)]
|
||||
CorruptMarker(#[error(not(source))] Bytes),
|
||||
IoError(io::Error),
|
||||
}
|
||||
|
||||
/// Represents the codec to encode and decode data for transmission
|
||||
pub struct DistantCodec;
|
||||
|
||||
impl<'a> Encoder<&'a [u8]> for DistantCodec {
|
||||
type Error = DistantCodecError;
|
||||
|
||||
fn encode(&mut self, item: &'a [u8], dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
// Add our full packet to the bytes
|
||||
dst.reserve(packet_size(item.len()));
|
||||
dst.put(MSG_START);
|
||||
dst.put(item);
|
||||
dst.put(MSG_END);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for DistantCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Error = DistantCodecError;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
// First, check if we have more data than just our markers, if not we say that it's okay
|
||||
// but that we're waiting
|
||||
if src.len() <= (MSG_START.len() + MSG_END.len()) {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Second, verify that our first N bytes match our start marker
|
||||
let marker_start = &src[..MSG_START.len()];
|
||||
if marker_start != MSG_START {
|
||||
return Err(DistantCodecError::CorruptMarker(Bytes::copy_from_slice(
|
||||
marker_start,
|
||||
)));
|
||||
}
|
||||
|
||||
// Third, find end of message marker by scanning the available bytes, and
|
||||
// consume a full packet of bytes
|
||||
let mut maybe_frame = None;
|
||||
for i in (MSG_START.len() + 1)..(src.len() - MSG_END.len()) {
|
||||
let marker_end = &src[i..(i + MSG_END.len())];
|
||||
if marker_end == MSG_END {
|
||||
maybe_frame = Some(src.split_to(i + MSG_END.len()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Fourth, return our msg if it's available, stripping it of the start and end markers
|
||||
if let Some(frame) = maybe_frame {
|
||||
let data = &frame[MSG_START.len()..(frame.len() - MSG_END.len())];
|
||||
|
||||
// Advance so frame is no longer kept around
|
||||
src.advance(frame.len());
|
||||
|
||||
Ok(Some(data.to_vec()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
use crate::utils::Session;
|
||||
use codec::{DistantCodec, DistantCodecError};
|
||||
use derive_more::{Display, Error, From};
|
||||
use futures::SinkExt;
|
||||
use orion::{
|
||||
aead::{self, SecretKey},
|
||||
errors::UnknownCryptoError,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::{io, net::TcpStream};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
mod codec;
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum TransportError {
|
||||
CodecError(DistantCodecError),
|
||||
EncryptError(UnknownCryptoError),
|
||||
IoError(io::Error),
|
||||
SerializeError(serde_cbor::Error),
|
||||
}
|
||||
|
||||
/// Represents a transport of data across the network
|
||||
pub struct Transport {
|
||||
inner: Framed<TcpStream, DistantCodec>,
|
||||
key: Arc<SecretKey>,
|
||||
}
|
||||
|
||||
impl Transport {
|
||||
/// Wraps a `TcpStream` and associated credentials in a transport layer
|
||||
pub fn new(stream: TcpStream, key: Arc<SecretKey>) -> Self {
|
||||
Self {
|
||||
inner: Framed::new(stream, DistantCodec),
|
||||
key,
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a connection using the provided session
|
||||
pub async fn connect(session: Session) -> io::Result<Self> {
|
||||
let stream = TcpStream::connect(session.to_socket_addr().await?).await?;
|
||||
Ok(Self::new(stream, Arc::new(session.key)))
|
||||
}
|
||||
|
||||
/// Sends some data across the wire
|
||||
pub async fn send<T: Serialize>(&mut self, data: T) -> Result<(), TransportError> {
|
||||
// Serialize, encrypt, and then (TODO) sign
|
||||
let data = serde_cbor::ser::to_vec_packed(&data)?;
|
||||
let data = aead::seal(&self.key, &data)?;
|
||||
|
||||
self.inner
|
||||
.send(&data)
|
||||
.await
|
||||
.map_err(TransportError::CodecError)
|
||||
}
|
||||
|
||||
/// Receives some data from out on the wire, waiting until it's available
|
||||
pub async fn receive<T: DeserializeOwned>(&mut self) -> Result<T, TransportError> {
|
||||
loop {
|
||||
if let Some(data) = self.try_receive().await? {
|
||||
break Ok(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to receive some data from out on the wire, returning that data if available
|
||||
/// or none if unavailable
|
||||
pub async fn try_receive<T: DeserializeOwned>(&mut self) -> Result<Option<T>, TransportError> {
|
||||
if let Some(data) = self.inner.next().await {
|
||||
// Validate (TODO), decrypt, and then deserialize
|
||||
let data = data?;
|
||||
let data = aead::open(&self.key, &data)?;
|
||||
let data = serde_cbor::from_slice(&data)?;
|
||||
Ok(Some(data))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
use crate::utils::Session;
|
||||
use tokio::io;
|
||||
|
||||
pub fn run() -> Result<(), io::Error> {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { Session::clear().await })
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
pub mod clear_session;
|
||||
pub mod execute;
|
||||
pub mod launch;
|
||||
pub mod listen;
|
||||
|
@ -0,0 +1,126 @@
|
||||
use crate::{PROJECT_DIRS, SESSION_PATH};
|
||||
use derive_more::{Display, Error, From};
|
||||
use orion::aead::SecretKey;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use tokio::{io, net::lookup_host};
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum SessionError {
|
||||
#[display(fmt = "Bad hex key for session")]
|
||||
BadSessionHexKey,
|
||||
|
||||
#[display(fmt = "Invalid address for session")]
|
||||
InvalidSessionAddr,
|
||||
|
||||
#[display(fmt = "Invalid key for session")]
|
||||
InvalidSessionKey,
|
||||
|
||||
#[display(fmt = "Invalid port for session")]
|
||||
InvalidSessionPort,
|
||||
|
||||
IoError(io::Error),
|
||||
|
||||
#[display(fmt = "Missing address for session")]
|
||||
MissingSessionAddr,
|
||||
|
||||
#[display(fmt = "Missing key for session")]
|
||||
MissingSessionKey,
|
||||
|
||||
#[display(fmt = "Missing port for session")]
|
||||
MissingSessionPort,
|
||||
|
||||
#[display(fmt = "No session file: {:?}", SESSION_PATH.as_path())]
|
||||
NoSessionFile,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct Session {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub key: SecretKey,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Returns a string representing the secret key as hex
|
||||
pub fn to_hex_key(&self) -> String {
|
||||
hex::encode(self.key.unprotected_as_bytes())
|
||||
}
|
||||
|
||||
/// Returns the ip address associated with the session based on the host
|
||||
pub async fn to_ip_addr(&self) -> io::Result<IpAddr> {
|
||||
let addr = match self.host.parse::<IpAddr>() {
|
||||
Ok(addr) => addr,
|
||||
Err(_) => lookup_host((self.host.as_str(), self.port))
|
||||
.await?
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::NotFound, SessionError::InvalidSessionAddr)
|
||||
})?
|
||||
.ip(),
|
||||
};
|
||||
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
/// Returns socket address associated with the session
|
||||
pub async fn to_socket_addr(&self) -> io::Result<SocketAddr> {
|
||||
let addr = self.to_ip_addr().await?;
|
||||
Ok(SocketAddr::from((addr, self.port)))
|
||||
}
|
||||
|
||||
/// Clears the global session file
|
||||
pub async fn clear() -> io::Result<()> {
|
||||
tokio::fs::remove_file(SESSION_PATH.as_path()).await
|
||||
}
|
||||
|
||||
/// Saves a session to disk
|
||||
pub async fn save(&self) -> io::Result<()> {
|
||||
let key_hex_str = self.to_hex_key();
|
||||
|
||||
// Ensure our cache directory exists
|
||||
let cache_dir = PROJECT_DIRS.cache_dir();
|
||||
tokio::fs::create_dir_all(cache_dir).await?;
|
||||
|
||||
// Write our session file
|
||||
let addr = self.to_ip_addr().await?;
|
||||
tokio::fs::write(
|
||||
SESSION_PATH.as_path(),
|
||||
format!("{} {} {}", addr, self.port, key_hex_str),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Loads a session's information into memory
|
||||
pub async fn load() -> Result<Self, SessionError> {
|
||||
let text = tokio::fs::read_to_string(SESSION_PATH.as_path())
|
||||
.await
|
||||
.map_err(|_| SessionError::NoSessionFile)?;
|
||||
let mut tokens = text.split(' ').take(3);
|
||||
|
||||
// First, load up the address without parsing it
|
||||
let host = tokens
|
||||
.next()
|
||||
.ok_or(SessionError::MissingSessionAddr)?
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Second, load up the port and parse it into a number
|
||||
let port = tokens
|
||||
.next()
|
||||
.ok_or(SessionError::MissingSessionPort)?
|
||||
.trim()
|
||||
.parse::<u16>()
|
||||
.map_err(|_| SessionError::InvalidSessionPort)?;
|
||||
|
||||
// Third, load up the key and convert it back into a secret key from a hex slice
|
||||
let key = SecretKey::from_slice(
|
||||
&hex::decode(tokens.next().ok_or(SessionError::MissingSessionKey)?.trim())
|
||||
.map_err(|_| SessionError::BadSessionHexKey)?,
|
||||
)
|
||||
.map_err(|_| SessionError::InvalidSessionKey)?;
|
||||
|
||||
Ok(Session { host, port, key })
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue