mirror of https://github.com/chipsenkbeil/distant
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
4.5 KiB
Rust
163 lines
4.5 KiB
Rust
use super::{Codec, DataStream, Transport};
|
|
use futures::stream::Stream;
|
|
use log::*;
|
|
use std::{future::Future, pin::Pin};
|
|
use tokio::{
|
|
io,
|
|
net::{TcpListener, TcpStream},
|
|
sync::mpsc,
|
|
task::JoinHandle,
|
|
};
|
|
|
|
/// Represents a [`Stream`] consisting of newly-connected [`DataStream`] instances that
|
|
/// have been wrapped in [`Transport`]
|
|
pub struct TransportListener<T, U>
|
|
where
|
|
T: DataStream,
|
|
U: Codec,
|
|
{
|
|
listen_task: JoinHandle<()>,
|
|
accept_task: JoinHandle<()>,
|
|
rx: mpsc::Receiver<Transport<T, U>>,
|
|
}
|
|
|
|
impl<T, U> TransportListener<T, U>
|
|
where
|
|
T: DataStream + Send + 'static,
|
|
U: Codec + Send + 'static,
|
|
{
|
|
pub fn initialize<L, F>(listener: L, mut make_transport: F) -> Self
|
|
where
|
|
L: Listener<Output = T> + 'static,
|
|
F: FnMut(T) -> Transport<T, U> + Send + 'static,
|
|
{
|
|
let (stream_tx, mut stream_rx) = mpsc::channel::<T>(1);
|
|
let listen_task = tokio::spawn(async move {
|
|
loop {
|
|
match listener.accept().await {
|
|
Ok(stream) => {
|
|
if stream_tx.send(stream).await.is_err() {
|
|
error!("Listener failed to pass along stream");
|
|
break;
|
|
}
|
|
}
|
|
Err(x) => {
|
|
error!("Listener failed to accept stream: {}", x);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
let (tx, rx) = mpsc::channel::<Transport<T, U>>(1);
|
|
let accept_task = tokio::spawn(async move {
|
|
// Check if we have a new connection. If so, wrap it in a transport and forward
|
|
// it along to
|
|
while let Some(stream) = stream_rx.recv().await {
|
|
let transport = make_transport(stream);
|
|
if let Err(x) = tx.send(transport).await {
|
|
error!("Failed to forward transport: {}", x);
|
|
}
|
|
}
|
|
});
|
|
|
|
Self {
|
|
listen_task,
|
|
accept_task,
|
|
rx,
|
|
}
|
|
}
|
|
|
|
pub fn abort(&self) {
|
|
self.listen_task.abort();
|
|
self.accept_task.abort();
|
|
}
|
|
|
|
/// Waits for the next fully-initialized transport for an incoming stream to be available,
|
|
/// returning none if no longer accepting new connections
|
|
pub async fn accept(&mut self) -> Option<Transport<T, U>> {
|
|
self.rx.recv().await
|
|
}
|
|
|
|
/// Converts into a stream of transport-wrapped connections
|
|
pub fn into_stream(self) -> impl Stream<Item = Transport<T, U>> {
|
|
futures::stream::unfold(self, |mut _self| async move {
|
|
_self
|
|
.accept()
|
|
.await
|
|
.map(move |transport| (transport, _self))
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Boxed, pinned, [`Send`] future valid for `'a` that resolves to an `io::Result<T>`;
/// this is the type returned by [`Listener::accept`]
pub type AcceptFuture<'a, T> = Pin<Box<dyn Future<Output = io::Result<T>> + Send + 'a>>;
|
|
|
|
/// Represents a type that has a listen interface for receiving raw streams
|
|
pub trait Listener: Send + Sync {
|
|
type Output;
|
|
|
|
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
|
|
where
|
|
Self: Sync + 'a;
|
|
}
|
|
|
|
impl Listener for TcpListener {
|
|
type Output = TcpStream;
|
|
|
|
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
|
|
where
|
|
Self: Sync + 'a,
|
|
{
|
|
async fn accept(_self: &TcpListener) -> io::Result<TcpStream> {
|
|
_self.accept().await.map(|(stream, _)| stream)
|
|
}
|
|
|
|
Box::pin(accept(self))
|
|
}
|
|
}
|
|
|
|
#[cfg(unix)]
|
|
impl Listener for tokio::net::UnixListener {
|
|
type Output = tokio::net::UnixStream;
|
|
|
|
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
|
|
where
|
|
Self: Sync + 'a,
|
|
{
|
|
async fn accept(_self: &tokio::net::UnixListener) -> io::Result<tokio::net::UnixStream> {
|
|
_self.accept().await.map(|(stream, _)| stream)
|
|
}
|
|
|
|
Box::pin(accept(self))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
impl<T> Listener for tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>
where
    T: DataStream + Send + Sync + 'static,
{
    type Output = T;

    /// Test-only listener backed by a channel: each value received on the channel is
    /// treated as an accepted stream, and a closed channel is reported as a
    /// [`io::ErrorKind::BrokenPipe`] error
    fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
    where
        Self: Sync + 'a,
    {
        Box::pin(async move {
            // The receiver needs exclusive access, so hold the lock while waiting
            let mut receiver = self.lock().await;
            match receiver.recv().await {
                Some(stream) => Ok(stream),
                None => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
            }
        })
    }
}
|