|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
use crate::{
|
|
|
|
|
utils::Timer, GenericServerRef, Listener, Request, Response, Server, ServerConnection,
|
|
|
|
|
ServerCtx, ServerRef, ServerReply, ServerState, Shutdown, TypedAsyncRead, TypedAsyncWrite,
|
|
|
|
|
utils::Timer, ConnectionId, GenericServerRef, Interest, Listener, Request, Response, Server,
|
|
|
|
|
ServerConnection, ServerCtx, ServerRef, ServerReply, ServerState, Shutdown, TypedTransport,
|
|
|
|
|
};
|
|
|
|
|
use log::*;
|
|
|
|
|
use serde::{de::DeserializeOwned, Serialize};
|
|
|
|
@ -8,7 +8,10 @@ use std::{
|
|
|
|
|
io,
|
|
|
|
|
sync::{Arc, Weak},
|
|
|
|
|
};
|
|
|
|
|
use tokio::sync::{mpsc, Mutex};
|
|
|
|
|
use tokio::{
|
|
|
|
|
sync::{mpsc, Mutex},
|
|
|
|
|
task::JoinHandle,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
mod tcp;
|
|
|
|
|
pub use tcp::*;
|
|
|
|
@ -33,28 +36,26 @@ pub trait ServerExt {
|
|
|
|
|
type Response;
|
|
|
|
|
|
|
|
|
|
/// Start a new server using the provided listener
|
|
|
|
|
fn start<L, R, W>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
|
|
|
|
|
fn start<L, T>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
|
|
|
|
|
where
|
|
|
|
|
L: Listener<Output = (W, R)> + 'static,
|
|
|
|
|
R: TypedAsyncRead<Request<Self::Request>> + Send + 'static,
|
|
|
|
|
W: TypedAsyncWrite<Response<Self::Response>> + Send + 'static;
|
|
|
|
|
L: Listener<Output = T> + 'static,
|
|
|
|
|
T: TypedTransport<Input = Self::Request, Output = Self::Response> + Send + 'static;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<S, Req, Res, Data> ServerExt for S
|
|
|
|
|
impl<S> ServerExt for S
|
|
|
|
|
where
|
|
|
|
|
S: Server<Request = Req, Response = Res, LocalData = Data> + Sync + 'static,
|
|
|
|
|
Req: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
Res: Serialize + Send + 'static,
|
|
|
|
|
Data: Default + Send + Sync + 'static,
|
|
|
|
|
S: Server + Sync + 'static,
|
|
|
|
|
S::Request: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
S::Response: Serialize + Send + 'static,
|
|
|
|
|
S::LocalData: Default + Send + Sync + 'static,
|
|
|
|
|
{
|
|
|
|
|
type Request = Req;
|
|
|
|
|
type Response = Res;
|
|
|
|
|
type Request = S::Request;
|
|
|
|
|
type Response = S::Response;
|
|
|
|
|
|
|
|
|
|
fn start<L, R, W>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
|
|
|
|
|
fn start<L, T>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
|
|
|
|
|
where
|
|
|
|
|
L: Listener<Output = (W, R)> + 'static,
|
|
|
|
|
R: TypedAsyncRead<Request<Self::Request>> + Send + 'static,
|
|
|
|
|
W: TypedAsyncWrite<Response<Self::Response>> + Send + 'static,
|
|
|
|
|
L: Listener<Output = T> + 'static,
|
|
|
|
|
T: TypedTransport<Input = Self::Request, Output = Self::Response> + Send + 'static,
|
|
|
|
|
{
|
|
|
|
|
let server = Arc::new(self);
|
|
|
|
|
let state = Arc::new(ServerState::new());
|
|
|
|
@ -65,15 +66,14 @@ where
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn task<S, Req, Res, Data, L, R, W>(server: Arc<S>, state: Arc<ServerState>, mut listener: L)
|
|
|
|
|
async fn task<S, L, T>(server: Arc<S>, state: Arc<ServerState>, mut listener: L)
|
|
|
|
|
where
|
|
|
|
|
S: Server<Request = Req, Response = Res, LocalData = Data> + Sync + 'static,
|
|
|
|
|
Req: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
Res: Serialize + Send + 'static,
|
|
|
|
|
Data: Default + Send + Sync + 'static,
|
|
|
|
|
L: Listener<Output = (W, R)> + 'static,
|
|
|
|
|
R: TypedAsyncRead<Request<Req>> + Send + 'static,
|
|
|
|
|
W: TypedAsyncWrite<Response<Res>> + Send + 'static,
|
|
|
|
|
S: Server<Request = T::Input, Response = T::Output> + Sync + 'static,
|
|
|
|
|
S::LocalData: Default + Send + Sync + 'static,
|
|
|
|
|
L: Listener<Output = T> + 'static,
|
|
|
|
|
T: TypedTransport + Send + 'static,
|
|
|
|
|
T::Input: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
T::Output: Serialize + Send + 'static,
|
|
|
|
|
{
|
|
|
|
|
// Grab a copy of our server's configuration so we can leverage it below
|
|
|
|
|
let config = server.config();
|
|
|
|
@ -116,7 +116,7 @@ where
|
|
|
|
|
|
|
|
|
|
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
|
|
|
|
|
// signal has been received
|
|
|
|
|
let (mut writer, mut reader) = tokio::select! {
|
|
|
|
|
let mut transport = tokio::select! {
|
|
|
|
|
result = listener.accept() => {
|
|
|
|
|
match result {
|
|
|
|
|
Ok(x) => x,
|
|
|
|
@ -150,44 +150,70 @@ where
|
|
|
|
|
// Create some default data for the new connection and pass it
|
|
|
|
|
// to the callback prior to processing new requests
|
|
|
|
|
let local_data = {
|
|
|
|
|
let mut data = Data::default();
|
|
|
|
|
let mut data = S::LocalData::default();
|
|
|
|
|
server.on_accept(&mut data).await;
|
|
|
|
|
Arc::new(data)
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Start a writer task that reads from a channel and forwards all
|
|
|
|
|
// data through the writer
|
|
|
|
|
let (tx, mut rx) = mpsc::channel::<Response<Res>>(1);
|
|
|
|
|
connection.writer_task = Some(tokio::spawn(async move {
|
|
|
|
|
while let Some(data) = rx.recv().await {
|
|
|
|
|
// Log our message as a string, which can be expensive
|
|
|
|
|
if log_enabled!(Level::Trace) {
|
|
|
|
|
trace!(
|
|
|
|
|
"[Conn {connection_id}] Sending {}",
|
|
|
|
|
&data
|
|
|
|
|
.to_vec()
|
|
|
|
|
.map(|x| String::from_utf8_lossy(&x).to_string())
|
|
|
|
|
.unwrap_or_else(|_| "<Cannot serialize>".to_string())
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if let Err(x) = writer.write(data).await {
|
|
|
|
|
error!("[Conn {connection_id}] Failed to send {x}");
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
connection.task = Some(
|
|
|
|
|
ConnectionTask {
|
|
|
|
|
id: connection_id,
|
|
|
|
|
server,
|
|
|
|
|
state: Arc::downgrade(&state),
|
|
|
|
|
transport,
|
|
|
|
|
local_data,
|
|
|
|
|
shutdown_timer: shutdown_timer
|
|
|
|
|
.as_ref()
|
|
|
|
|
.map(Arc::downgrade)
|
|
|
|
|
.unwrap_or_default(),
|
|
|
|
|
}
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
// Start a reader task that reads requests and processes them
|
|
|
|
|
// using the provided handler
|
|
|
|
|
let weak_state = Arc::downgrade(&state);
|
|
|
|
|
let weak_shutdown_timer = shutdown_timer
|
|
|
|
|
.as_ref()
|
|
|
|
|
.map(Arc::downgrade)
|
|
|
|
|
.unwrap_or_default();
|
|
|
|
|
connection.reader_task = Some(tokio::spawn(async move {
|
|
|
|
|
loop {
|
|
|
|
|
match reader.read().await {
|
|
|
|
|
.spawn(),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
state
|
|
|
|
|
.connections
|
|
|
|
|
.write()
|
|
|
|
|
.await
|
|
|
|
|
.insert(connection_id, connection);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct ConnectionTask<S, T, D> {
|
|
|
|
|
id: ConnectionId,
|
|
|
|
|
server: Arc<S>,
|
|
|
|
|
state: Weak<ServerState>,
|
|
|
|
|
transport: T,
|
|
|
|
|
local_data: D,
|
|
|
|
|
shutdown_timer: Weak<Mutex<Timer<()>>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<S, T, D> ConnectionTask<S, T, D>
|
|
|
|
|
where
|
|
|
|
|
S: Server<Request = T::Input, Response = T::Output, LocalData = D> + Sync + 'static,
|
|
|
|
|
D: Default + Send + Sync + 'static,
|
|
|
|
|
T: TypedTransport + Send + 'static,
|
|
|
|
|
T::Input: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
T::Output: Serialize + Send + 'static,
|
|
|
|
|
{
|
|
|
|
|
pub fn spawn(self) -> JoinHandle<()> {
|
|
|
|
|
tokio::spawn(self.run())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn run(self) {
|
|
|
|
|
let connection_id = self.id;
|
|
|
|
|
|
|
|
|
|
// Construct a queue of outgoing responses
|
|
|
|
|
let (tx, mut rx) = mpsc::channel::<Response<T::Output>>(1);
|
|
|
|
|
|
|
|
|
|
loop {
|
|
|
|
|
let ready = self
|
|
|
|
|
.transport
|
|
|
|
|
.ready(Interest::READABLE | Interest::WRITABLE)
|
|
|
|
|
.await
|
|
|
|
|
.expect("[Conn {connection_id}] Failed to examine ready state");
|
|
|
|
|
|
|
|
|
|
if ready.is_readable() {
|
|
|
|
|
match self.transport.try_read() {
|
|
|
|
|
Ok(Some(request)) => {
|
|
|
|
|
let reply = ServerReply {
|
|
|
|
|
origin_id: request.id.clone(),
|
|
|
|
@ -198,20 +224,20 @@ where
|
|
|
|
|
connection_id,
|
|
|
|
|
request,
|
|
|
|
|
reply: reply.clone(),
|
|
|
|
|
local_data: Arc::clone(&local_data),
|
|
|
|
|
local_data: Arc::clone(&self.local_data),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
server.on_request(ctx).await;
|
|
|
|
|
self.server.on_request(ctx).await;
|
|
|
|
|
}
|
|
|
|
|
Ok(None) => {
|
|
|
|
|
debug!("[Conn {connection_id}] Connection closed");
|
|
|
|
|
|
|
|
|
|
// Remove the connection from our state if it has closed
|
|
|
|
|
if let Some(state) = Weak::upgrade(&weak_state) {
|
|
|
|
|
state.connections.write().await.remove(&connection_id);
|
|
|
|
|
if let Some(state) = Weak::upgrade(&self.weak_state) {
|
|
|
|
|
state.connections.write().await.remove(&self.connection_id);
|
|
|
|
|
|
|
|
|
|
// If we have no more connections, start the timer
|
|
|
|
|
if let Some(timer) = Weak::upgrade(&weak_shutdown_timer) {
|
|
|
|
|
if let Some(timer) = Weak::upgrade(&self.weak_shutdown_timer) {
|
|
|
|
|
if state.connections.read().await.is_empty() {
|
|
|
|
|
timer.lock().await.start();
|
|
|
|
|
}
|
|
|
|
@ -219,6 +245,7 @@ where
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
Err(x) if x.kind() == io::ErrorKind::WouldBlock => continue,
|
|
|
|
|
Err(x) => {
|
|
|
|
|
// NOTE: We do NOT break out of the loop, as this could happen
|
|
|
|
|
// if someone sends bad data at any point, but does not
|
|
|
|
@ -228,13 +255,35 @@ where
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
state
|
|
|
|
|
.connections
|
|
|
|
|
.write()
|
|
|
|
|
.await
|
|
|
|
|
.insert(connection_id, connection);
|
|
|
|
|
// If our socket is ready to be written to, we try to get the next item from
|
|
|
|
|
// the queue and process it
|
|
|
|
|
if ready.is_writable() {
|
|
|
|
|
match rx.try_recv() {
|
|
|
|
|
Ok(data) => {
|
|
|
|
|
// Log our message as a string, which can be expensive
|
|
|
|
|
if log_enabled!(Level::Trace) {
|
|
|
|
|
trace!(
|
|
|
|
|
"[Conn {connection_id}] Sending {}",
|
|
|
|
|
&data
|
|
|
|
|
.to_vec()
|
|
|
|
|
.map(|x| String::from_utf8_lossy(&x).to_string())
|
|
|
|
|
.unwrap_or_else(|_| "<Cannot serialize>".to_string())
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
match self.transport.try_write(data) {
|
|
|
|
|
Ok(()) => continue,
|
|
|
|
|
Err(x) if x.kind() == io::ErrorKind::WouldBlock => continue,
|
|
|
|
|
Err(x) => error!("[Conn {connection_id}] Send failed: {x}"),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If we don't have data, we skip
|
|
|
|
|
Err(_) => continue,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|