pull/146/head
Chip Senkbeil 2 years ago
parent 4a4360abd1
commit 8c9cdbbb97
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -6,7 +6,7 @@ use crate::{
ConnectionId, DistantMsg, DistantRequestData, DistantResponseData,
};
use async_trait::async_trait;
use distant_net::{Reply, Server, ServerConfig, ServerCtx};
use distant_net::{Reply, ServerConfig, ServerCtx, ServerHandler};
use log::*;
use std::{io, path::PathBuf, sync::Arc};
@ -420,7 +420,7 @@ pub trait DistantApi {
}
#[async_trait]
impl<T, D> Server for DistantApiServer<T, D>
impl<T, D> ServerHandler for DistantApiServer<T, D>
where
T: DistantApi<LocalData = D> + Send + Sync,
D: Send + Sync,

@ -4,7 +4,7 @@ use crate::{
};
use async_trait::async_trait;
use distant_net::{
Client, Listener, MpscListener, Request, Response, Server, ServerCtx, ServerExt,
Client, Listener, MpscListener, Request, Response, ServerCtx, ServerExt, ServerHandler,
};
use log::*;
use std::{collections::HashMap, io, sync::Arc};
@ -267,7 +267,7 @@ pub struct DistantManagerServerConnection {
}
#[async_trait]
impl Server for DistantManager {
impl ServerHandler for DistantManager {
type Request = ManagerRequest;
type Response = ManagerResponse;
type LocalData = DistantManagerServerConnection;

@ -22,7 +22,7 @@ pub use channel::*;
/// Time to wait inbetween connection read/write when nothing was read or written on last pass
const SLEEP_DURATION: Duration = Duration::from_millis(50);
/// Represents a client that can be used to send requests & receive responses from a server
/// Represents a client that can be used to send requests & receive responses from a server.
pub struct Client<T, U> {
/// Used to send requests to a server
channel: Channel<T, U>,

@ -1,16 +1,9 @@
use crate::{auth::Authenticator, utils::Timer, ConnectionId, Listener, Transport};
use crate::{auth::Authenticator, Listener, Transport};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{
io,
sync::{Arc, Weak},
time::Duration,
};
use tokio::{
sync::{mpsc, Mutex},
task::JoinHandle,
};
use std::{io, sync::Arc};
use tokio::sync::RwLock;
mod config;
pub use config::*;
@ -33,25 +26,29 @@ pub use reply::*;
mod state;
pub use state::*;
/// Time to wait inbetween connection read/write when nothing was read or written on last pass
const SLEEP_DURATION: Duration = Duration::from_millis(50);
mod shutdown_timer;
pub use shutdown_timer::*;
/// Interface for a general-purpose server that receives requests to handle
/// Represents a server that can be used to receive requests & send responses to clients.
pub struct Server<T> {
/// Custom configuration details associated with the server
config: ServerConfig,
/// Handler used to process various server events
handler: T,
}
/// Interface for a handler that receives connections and requests
#[async_trait]
pub trait Server: Send {
pub trait ServerHandler: Send {
/// Type of data received by the server
type Request: DeserializeOwned + Send + Sync;
type Request;
/// Type of data sent back by the server
type Response: Serialize + Send;
type Response;
/// Type of data to store locally tied to the specific connection
type LocalData: Send + Sync;
/// Returns configuration tied to server instance
fn config(&self) -> ServerConfig {
ServerConfig::default()
}
type LocalData;
/// Invoked upon a new connection becoming established.
///
@ -70,298 +67,79 @@ pub trait Server: Send {
/// Invoked upon receiving a request from a client. The server should process this
/// request, which can be found in `ctx`, and send one or more replies in response.
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>);
}
fn start<L>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
impl<T> Server<T>
where
T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{
/// Consumes the server, starting a task to process connections from the `listener` and
/// returning a [`ServerRef`] that can be used to control the active server instance.
pub fn start<L>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
where
L: Listener + 'static,
L::Output: Transport + Send + Sync + 'static,
{
let server = Arc::new(self);
let state = Arc::new(ServerState::new());
let task = tokio::spawn(task(server, Arc::clone(&state), listener));
let task = tokio::spawn(self.task(Arc::clone(&state), listener));
Ok(Box::new(GenericServerRef { state, task }))
}
}
async fn task<S, L>(server: Arc<S>, state: Arc<ServerState>, mut listener: L)
where
S: Server + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
S::LocalData: Default + Send + Sync + 'static,
L: Listener + 'static,
L::Output: Transport + Send + Sync + 'static,
{
// Grab a copy of our server's configuration so we can leverage it below
let config = server.config();
// Create the timer that will be used shutdown the server after duration elapsed
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
// NOTE: We do a manual map such that the shutdown sender is not captured and dropped when
// there is no shutdown after configured. This is because we need the future for the
// shutdown receiver to last forever in the event that there is no shutdown configured,
// not return immediately, which is what would happen if the sender was dropped.
#[allow(clippy::manual_map)]
let mut shutdown_timer = match config.shutdown {
// Create a timer, start it, and drop it so it will always happen
Shutdown::After(duration) => {
Timer::new(duration, async move {
let _ = shutdown_tx.send(()).await;
})
.start();
None
}
Shutdown::Lonely(duration) => Some(Timer::new(duration, async move {
let _ = shutdown_tx.send(()).await;
})),
Shutdown::Never => None,
};
if let Some(timer) = shutdown_timer.as_mut() {
info!(
"Server shutdown timer configured: {}s",
timer.duration().as_secs_f32()
);
timer.start();
}
let mut shutdown_timer = shutdown_timer.map(|timer| Arc::new(Mutex::new(timer)));
loop {
let server = Arc::clone(&server);
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
// signal has been received
let transport = tokio::select! {
result = listener.accept() => {
match result {
Ok(x) => x,
Err(x) => {
error!("Server no longer accepting connections: {x}");
if let Some(timer) = shutdown_timer.take() {
timer.lock().await.abort();
}
break;
}
}
}
_ = shutdown_rx.recv() => {
info!(
"Server shutdown triggered after {}s",
config.shutdown.duration().unwrap_or_default().as_secs_f32(),
);
break;
}
};
let mut connection = ServerConnection::new();
let connection_id = connection.id;
let state = Arc::clone(&state);
// Ensure that the shutdown timer is cancelled now that we have a connection
if let Some(timer) = shutdown_timer.as_ref() {
timer.lock().await.stop();
}
connection.task = Some(
ConnectionTask {
id: connection_id,
server,
state: Arc::downgrade(&state),
transport,
shutdown_timer: shutdown_timer
.as_ref()
.map(Arc::downgrade)
.unwrap_or_default(),
}
.spawn(),
);
state
.connections
.write()
.await
.insert(connection_id, connection);
}
}
struct ConnectionTask<S, T> {
id: ConnectionId,
server: Arc<S>,
state: Weak<ServerState>,
transport: T,
shutdown_timer: Weak<Mutex<Timer<()>>>,
}
impl<S, T> ConnectionTask<S, T>
where
S: Server + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
S::LocalData: Default + Send + Sync + 'static,
T: Transport + Send + Sync + 'static,
{
pub fn spawn(self) -> JoinHandle<()> {
tokio::spawn(self.run())
}
async fn run(self) {
let connection_id = self.id;
// Construct a queue of outgoing responses
let (tx, mut rx) = mpsc::channel::<Response<S::Response>>(1);
// Perform a handshake to ensure that the connection is properly established
let mut transport: FramedTransport<T> = FramedTransport::plain(self.transport);
if let Err(x) = transport.server_handshake().await {
error!("[Conn {connection_id}] Handshake failed: {x}");
return;
}
// Create local data for the connection and then process it as well as perform
// authentication and any other tasks on first connecting
let mut local_data = S::LocalData::default();
if let Err(x) = self
.server
.on_accept(ConnectionCtx {
connection_id,
authenticator: &mut transport,
local_data: &mut local_data,
})
.await
{
error!("[Conn {connection_id}] Accepting connection failed: {x}");
return;
}
/// Internal task that is run to receive connections and spawn connection tasks
async fn task<L>(self, state: Arc<ServerState>, mut listener: L)
where
L: Listener + 'static,
L::Output: Transport + Send + Sync + 'static,
{
let Server { config, handler } = self;
let local_data = Arc::new(local_data);
let handler = Arc::new(handler);
let timer = ShutdownTimer::new(config.shutdown);
let mut notification = timer.clone_notification();
let timer = Arc::new(RwLock::new(timer));
loop {
let ready = transport
.ready(Interest::READABLE | Interest::WRITABLE)
.await
.expect("[Conn {connection_id}] Failed to examine ready state");
// Keep track of whether we read or wrote anything
let mut read_blocked = false;
let mut write_blocked = false;
if ready.is_readable() {
match transport.try_read_frame() {
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
Ok(request) => match request.to_typed_request() {
Ok(request) => {
let reply = ServerReply {
origin_id: request.id.clone(),
tx: tx.clone(),
};
let ctx = ServerCtx {
connection_id,
request,
reply: reply.clone(),
local_data: Arc::clone(&local_data),
};
self.server.on_request(ctx).await;
}
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
"[Conn {connection_id}] Failed receiving {}",
String::from_utf8_lossy(&request.payload),
);
}
error!("[Conn {connection_id}] Invalid request: {x}");
}
},
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
// signal has been received
let transport = tokio::select! {
result = listener.accept() => {
match result {
Ok(x) => x,
Err(x) => {
error!("[Conn {connection_id}] Invalid request: {x}");
}
},
Ok(None) => {
debug!("[Conn {connection_id}] Connection closed");
// Remove the connection from our state if it has closed
if let Some(state) = Weak::upgrade(&self.state) {
state.connections.write().await.remove(&self.id);
// If we have no more connections, start the timer
if let Some(timer) = Weak::upgrade(&self.shutdown_timer) {
if state.connections.read().await.is_empty() {
timer.lock().await.start();
}
}
error!("Server no longer accepting connections: {x}");
timer.read().await.abort();
break;
}
break;
}
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
Err(x) => {
// NOTE: We do NOT break out of the loop, as this could happen
// if someone sends bad data at any point, but does not
// mean that the reader itself has failed. This can
// happen from getting non-compliant typed data
error!("[Conn {connection_id}] {x}");
}
}
}
// If our socket is ready to be written to, we try to get the next item from
// the queue and process it
if ready.is_writable() {
// If we get more data to write, attempt to write it, which will result in writing
// any queued bytes as well. Othewise, we attempt to flush any pending outgoing
// bytes that weren't sent earlier.
if let Ok(response) = rx.try_recv() {
// Log our message as a string, which can be expensive
if log_enabled!(Level::Trace) {
trace!(
"[Conn {connection_id}] Sending {}",
&response
.to_vec()
.map(|x| String::from_utf8_lossy(&x).to_string())
.unwrap_or_else(|_| "<Cannot serialize>".to_string())
);
}
match response.to_vec() {
Ok(data) => match transport.try_write_frame(data) {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => error!("[Conn {connection_id}] Send failed: {x}"),
},
Err(x) => {
error!(
"[Conn {connection_id}] Unable to serialize outgoing response: {x}"
);
}
}
} else {
// In the case of flushing, there are two scenarios in which we want to
// mark no write occurring:
//
// 1. When flush did not write any bytes, which can happen when the buffer
// is empty
// 2. When the call to write bytes blocks
match transport.try_flush() {
Ok(0) => write_blocked = true,
Ok(_) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => {
error!("[Conn {connection_id}] Failed to flush outgoing data: {x}");
}
}
_ = notification.wait() => {
info!(
"Server shutdown triggered after {}s",
config.shutdown.duration().unwrap_or_default().as_secs_f32(),
);
break;
}
}
};
// Ensure that the shutdown timer is cancelled now that we have a connection
timer.read().await.stop();
// If we did not read or write anything, sleep a bit to offload CPU usage
if read_blocked && write_blocked {
tokio::time::sleep(SLEEP_DURATION).await;
}
let connection = Connection::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.transport(transport)
.shutdown_timer(Arc::downgrade(&timer))
.spawn();
state
.connections
.write()
.await
.insert(connection.id(), connection);
}
}
}
@ -369,22 +147,21 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{auth::Authenticator, InmemoryTransport, MpscListener, Request, ServerConfig};
use crate::{
auth::Authenticator, InmemoryTransport, MpscListener, Request, Response, ServerConfig,
};
use async_trait::async_trait;
use std::time::Duration;
use tokio::sync::mpsc;
pub struct TestServer(ServerConfig);
pub struct TestServerHandler(ServerConfig);
#[async_trait]
impl Server for TestServer {
impl ServerHandler for TestServerHandler {
type Request = u16;
type Response = String;
type LocalData = ();
fn config(&self) -> ServerConfig {
self.0.clone()
}
async fn on_accept<A: Authenticator>(
&self,
ctx: ConnectionCtx<'_, A, Self::LocalData>,
@ -419,7 +196,7 @@ mod tests {
.await
.expect("Failed to feed listener a connection");
let _server = ServerExt::start(TestServer(ServerConfig::default()), listener)
let _server = ServerExt::start(TestServerHandler(ServerConfig::default()), listener)
.expect("Failed to start server");
transport
@ -437,7 +214,7 @@ mod tests {
let (_tx, listener) = make_listener(100);
let server = ServerExt::start(
TestServer(ServerConfig {
TestServerHandler(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
}),
@ -464,7 +241,7 @@ mod tests {
.expect("Failed to feed listener a connection");
let server = ServerExt::start(
TestServer(ServerConfig {
TestServerHandler(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
}),
@ -493,7 +270,7 @@ mod tests {
.expect("Failed to feed listener a connection");
let server = ServerExt::start(
TestServer(ServerConfig {
TestServerHandler(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
}),
@ -518,7 +295,7 @@ mod tests {
.expect("Failed to feed listener a connection");
let server = ServerExt::start(
TestServer(ServerConfig {
TestServerHandler(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
}),
@ -537,7 +314,7 @@ mod tests {
let (_tx, listener) = make_listener(100);
let server = ServerExt::start(
TestServer(ServerConfig {
TestServerHandler(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
}),
@ -556,7 +333,7 @@ mod tests {
let (_tx, listener) = make_listener(100);
let server = ServerExt::start(
TestServer(ServerConfig {
TestServerHandler(ServerConfig {
shutdown: Shutdown::Never,
..Default::default()
}),

@ -1,39 +1,299 @@
use crate::ConnectionId;
use tokio::task::JoinHandle;
use crate::{
ConnectionCtx, FramedTransport, Interest, Response, ServerCtx, ServerHandler, ServerReply,
ServerState, ShutdownTimer, Transport, UntypedRequest,
};
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{
io,
sync::{Arc, Weak},
time::Duration,
};
use tokio::{
sync::{mpsc, RwLock},
task::JoinHandle,
};
/// Time to wait inbetween connection read/write when nothing was read or written on last pass
const SLEEP_DURATION: Duration = Duration::from_millis(50);
/// Id associated with an active connection
pub type ConnectionId = u64;
/// Represents an individual connection on the server
pub struct ServerConnection {
pub struct Connection {
/// Unique identifier tied to the connection
pub id: ConnectionId,
id: ConnectionId,
/// Task that is processing requests and responses
pub(crate) task: Option<JoinHandle<()>>,
task: JoinHandle<()>,
}
impl Default for ServerConnection {
fn default() -> Self {
Self::new()
impl Connection {
/// Starts building a new connection
pub fn build() -> ConnectionBuilder<(), ()> {
let id: ConnectionId = rand::random();
ConnectionBuilder {
id,
handler: Weak::new(),
state: Weak::new(),
transport: (),
shutdown_timer: Weak::new(),
}
}
}
impl ServerConnection {
/// Creates a new connection, generating a unique id to represent the connection
pub fn new() -> Self {
Self {
id: rand::random(),
task: None,
}
/// Returns the id associated with the connection
pub fn id(&self) -> ConnectionId {
self.id
}
/// Returns true if connection is still processing incoming or outgoing messages
pub fn is_active(&self) -> bool {
self.task.is_some() && !self.task.as_ref().unwrap().is_finished()
!self.task.is_finished()
}
/// Aborts the connection
pub fn abort(&self) {
if let Some(task) = self.task.as_ref() {
task.abort();
self.task.abort();
}
}
pub struct ConnectionBuilder<H, T> {
id: ConnectionId,
handler: Weak<H>,
state: Weak<ServerState>,
transport: T,
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
}
impl<H, T> ConnectionBuilder<H, T> {
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionBuilder<U, T> {
ConnectionBuilder {
id: self.id,
handler,
state: self.state,
transport: self.transport,
shutdown_timer: self.shutdown_timer,
}
}
pub fn state(self, state: Weak<ServerState>) -> ConnectionBuilder<H, T> {
ConnectionBuilder {
id: self.id,
handler: self.handler,
state,
transport: self.transport,
shutdown_timer: self.shutdown_timer,
}
}
pub fn transport<U>(self, transport: U) -> ConnectionBuilder<H, U> {
ConnectionBuilder {
id: self.id,
handler: self.handler,
state: self.state,
transport,
shutdown_timer: self.shutdown_timer,
}
}
pub fn shutdown_timer(
self,
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
) -> ConnectionBuilder<H, T> {
ConnectionBuilder {
id: self.id,
handler: self.handler,
state: self.state,
transport: self.transport,
shutdown_timer,
}
}
}
impl<H, T> ConnectionBuilder<H, T>
where
H: ServerHandler + Sync + 'static,
H::Request: DeserializeOwned + Send + Sync + 'static,
H::Response: Serialize + Send + 'static,
H::LocalData: Default + Send + Sync + 'static,
T: Transport + Send + Sync + 'static,
{
pub fn spawn(self) -> Connection {
let id = self.id;
Connection {
id,
task: tokio::spawn(self.run()),
}
}
async fn run(self) {
let ConnectionBuilder {
id,
handler,
state,
transport,
shutdown_timer,
} = self;
// Attempt to upgrade our handler for use with the connection going forward
let handler = match Weak::upgrade(&handler) {
Some(handler) => handler,
None => {
error!("[Conn {id}] Handler has been dropped");
return;
}
};
// Construct a queue of outgoing responses
let (tx, mut rx) = mpsc::channel::<Response<H::Response>>(1);
// Perform a handshake to ensure that the connection is properly established
let mut transport: FramedTransport<T> = FramedTransport::plain(transport);
if let Err(x) = transport.server_handshake().await {
error!("[Conn {id}] Handshake failed: {x}");
return;
}
// Create local data for the connection and then process it as well as perform
// authentication and any other tasks on first connecting
let mut local_data = H::LocalData::default();
if let Err(x) = handler
.on_accept(ConnectionCtx {
connection_id: id,
authenticator: &mut transport,
local_data: &mut local_data,
})
.await
{
error!("[Conn {id}] Accepting connection failed: {x}");
return;
}
let local_data = Arc::new(local_data);
loop {
let ready = transport
.ready(Interest::READABLE | Interest::WRITABLE)
.await
.expect("[Conn {connection_id}] Failed to examine ready state");
// Keep track of whether we read or wrote anything
let mut read_blocked = false;
let mut write_blocked = false;
if ready.is_readable() {
match transport.try_read_frame() {
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
Ok(request) => match request.to_typed_request() {
Ok(request) => {
let reply = ServerReply {
origin_id: request.id.clone(),
tx: tx.clone(),
};
let ctx = ServerCtx {
connection_id: id,
request,
reply: reply.clone(),
local_data: Arc::clone(&local_data),
};
handler.on_request(ctx).await;
}
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
"[Conn {id}] Failed receiving {}",
String::from_utf8_lossy(&request.payload),
);
}
error!("[Conn {id}] Invalid request: {x}");
}
},
Err(x) => {
error!("[Conn {id}] Invalid request: {x}");
}
},
Ok(None) => {
debug!("[Conn {id}] Connection closed");
// Remove the connection from our state if it has closed
if let Some(state) = Weak::upgrade(&state) {
state.connections.write().await.remove(&self.id);
// If we have no more connections, start the timer
if let Some(timer) = Weak::upgrade(&shutdown_timer) {
if state.connections.read().await.is_empty() {
timer.write().await.start();
}
}
}
break;
}
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
Err(x) => {
// NOTE: We do NOT break out of the loop, as this could happen
// if someone sends bad data at any point, but does not
// mean that the reader itself has failed. This can
// happen from getting non-compliant typed data
error!("[Conn {id}] {x}");
}
}
}
// If our socket is ready to be written to, we try to get the next item from
// the queue and process it
if ready.is_writable() {
// If we get more data to write, attempt to write it, which will result in writing
// any queued bytes as well. Othewise, we attempt to flush any pending outgoing
// bytes that weren't sent earlier.
if let Ok(response) = rx.try_recv() {
// Log our message as a string, which can be expensive
if log_enabled!(Level::Trace) {
trace!(
"[Conn {id}] Sending {}",
&response
.to_vec()
.map(|x| String::from_utf8_lossy(&x).to_string())
.unwrap_or_else(|_| "<Cannot serialize>".to_string())
);
}
match response.to_vec() {
Ok(data) => match transport.try_write_frame(data) {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => error!("[Conn {id}] Send failed: {x}"),
},
Err(x) => {
error!("[Conn {id}] Unable to serialize outgoing response: {x}");
}
}
} else {
// In the case of flushing, there are two scenarios in which we want to
// mark no write occurring:
//
// 1. When flush did not write any bytes, which can happen when the buffer
// is empty
// 2. When the call to write bytes blocks
match transport.try_flush() {
Ok(0) => write_blocked = true,
Ok(_) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => {
error!("[Conn {id}] Failed to flush outgoing data: {x}");
}
}
}
}
// If we did not read or write anything, sleep a bit to offload CPU usage
if read_blocked && write_blocked {
tokio::time::sleep(SLEEP_DURATION).await;
}
}
}
}

@ -1,6 +1,6 @@
use crate::{
utils::Timer, ConnectionCtx, ConnectionId, FramedTransport, GenericServerRef, Interest,
Listener, Response, Server, ServerConnection, ServerCtx, ServerRef, ServerReply, ServerState,
utils::Timer, Connection, ConnectionCtx, ConnectionId, FramedTransport, GenericServerRef,
Interest, Listener, Response, ServerCtx, ServerHandler, ServerRef, ServerReply, ServerState,
Shutdown, Transport, UntypedRequest,
};
use log::*;
@ -49,7 +49,7 @@ pub trait ServerExt {
impl<S> ServerExt for S
where
S: Server + Sync + 'static,
S: ServerHandler + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
S::LocalData: Default + Send + Sync + 'static,
@ -73,7 +73,7 @@ where
async fn task<S, L>(server: Arc<S>, state: Arc<ServerState>, mut listener: L)
where
S: Server + Sync + 'static,
S: ServerHandler + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
S::LocalData: Default + Send + Sync + 'static,
@ -143,7 +143,7 @@ where
}
};
let mut connection = ServerConnection::new();
let mut connection = Connection::new();
let connection_id = connection.id;
let state = Arc::clone(&state);
@ -184,7 +184,7 @@ struct ConnectionTask<S, T> {
impl<S, T> ConnectionTask<S, T>
where
S: Server + Sync + 'static,
S: ServerHandler + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
S::LocalData: Default + Send + Sync + 'static,
@ -362,7 +362,7 @@ mod tests {
pub struct TestServer(ServerConfig);
#[async_trait]
impl Server for TestServer {
impl ServerHandler for TestServer {
type Request = u16;
type Response = String;
type LocalData = ();

@ -1,4 +1,4 @@
use crate::{PortRange, Server, ServerExt, TcpListener, TcpServerRef};
use crate::{PortRange, ServerExt, ServerHandler, TcpListener, TcpServerRef};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{io, net::IpAddr};
@ -19,7 +19,7 @@ pub trait TcpServerExt {
#[async_trait]
impl<S> TcpServerExt for S
where
S: Server + Sync + 'static,
S: ServerHandler + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
S::LocalData: Default + Send + Sync + 'static,
@ -56,7 +56,7 @@ mod tests {
pub struct TestServer;
#[async_trait]
impl Server for TestServer {
impl ServerHandler for TestServer {
type Request = String;
type Response = String;
type LocalData = ();

@ -1,4 +1,4 @@
use crate::{Server, ServerExt, UnixSocketListener, UnixSocketServerRef};
use crate::{ServerExt, ServerHandler, UnixSocketListener, UnixSocketServerRef};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{io, path::Path};
@ -19,7 +19,7 @@ pub trait UnixSocketServerExt {
#[async_trait]
impl<S> UnixSocketServerExt for S
where
S: Server + Sync + 'static,
S: ServerHandler + Sync + 'static,
S::Request: DeserializeOwned + Send + Sync + 'static,
S::Response: Serialize + Send + 'static,
S::LocalData: Default + Send + Sync + 'static,
@ -58,7 +58,7 @@ mod tests {
pub struct TestServer;
#[async_trait]
impl Server for TestServer {
impl ServerHandler for TestServer {
type Request = String;
type Response = String;
type LocalData = ();

@ -0,0 +1,100 @@
use super::Shutdown;
use crate::utils::Timer;
use log::*;
use std::ops::{Deref, DerefMut};
use std::time::Duration;
use tokio::sync::watch;
/// Cloneable notification for when a [`ShutdownTimer`] has completed.
#[derive(Clone)]
pub struct ShutdownNotification(watch::Receiver<()>);
impl ShutdownNotification {
/// Waits to receive a notification that the shutdown timer has concluded
pub async fn wait(&mut self) {
let _ = self.0.changed().await;
}
}
/// Wrapper around [`Timer`] to support shutdown-specific notifications.
pub struct ShutdownTimer {
timer: Timer<()>,
watcher: ShutdownNotification,
}
impl ShutdownTimer {
pub fn new(shutdown: Shutdown) -> Self {
// Create the timer that will be used shutdown the server after duration elapsed
let (tx, mut rx) = watch::channel(());
// NOTE: We do a manual map such that the shutdown sender is not captured and dropped when
// there is no shutdown after configured. This is because we need the future for the
// shutdown receiver to last forever in the event that there is no shutdown configured,
// not return immediately, which is what would happen if the sender was dropped.
#[allow(clippy::manual_map)]
let mut timer = match shutdown {
// Create a timer that will complete after `duration`, dropping it to ensure that it
// will always happen no matter if stop/abort is called
Shutdown::After(duration) => {
info!(
"Server shutdown timer configured: terminate after {}s",
duration.as_secs_f32()
);
Timer::new(duration, async move {
let _ = tx.send(());
})
}
// Create a timer that will complete after `duration`
Shutdown::Lonely(duration) => {
info!(
"Server shutdown timer configured: terminate after no activity in {}s",
duration.as_secs_f32()
);
Timer::new(duration, async move {
let _ = tx.send(());
})
}
// Create a timer that will never complete (max timeout possible) so we hold on to the
// sender to avoid the receiver from completing
Shutdown::Never => {
info!("Server shutdown timer configured: never terminate");
Timer::new(Duration::MAX, async move {
let _ = tx.send(());
})
}
};
timer.start();
Self {
timer,
watcher: ShutdownNotification(rx),
}
}
/// Clones the notification
pub fn clone_notification(&self) -> ShutdownNotification {
self.watcher.clone()
}
/// Wait for the timer to complete
pub async fn wait(&mut self) {
self.watcher.wait().await
}
}
impl Deref for ShutdownTimer {
type Target = Timer<()>;
fn deref(&self) -> &Self::Target {
&self.timer
}
}
impl DerefMut for ShutdownTimer {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.timer
}
}

@ -1,11 +1,11 @@
use crate::{ConnectionId, HeapSecretKey, ServerConnection};
use crate::{Connection, ConnectionId, HeapSecretKey};
use std::collections::HashMap;
use tokio::sync::RwLock;
/// Contains all top-level state for the server
pub struct ServerState {
/// Mapping of connection ids to their transports
pub connections: RwLock<HashMap<ConnectionId, ServerConnection>>,
pub connections: RwLock<HashMap<ConnectionId, Connection>>,
/// Mapping of connection ids to their authenticated keys
pub authenticated: RwLock<HashMap<ConnectionId, HeapSecretKey>>,

@ -65,6 +65,7 @@ where
pub fn start(&mut self) {
// Cancel the active timer task
self.stop();
self.active_timer = None;
// Exit early if callback completed as starting will do nothing
if self.callback.is_finished() {
@ -82,9 +83,8 @@ where
/// Stops the timer, cancelling the internal task, but leaving the callback in place in case
/// the timer is re-started later
pub fn stop(&mut self) {
// Delete the active timer task
if let Some(task) = self.active_timer.take() {
pub fn stop(&self) {
if let Some(task) = self.active_timer.as_ref() {
task.abort();
}
}
@ -92,10 +92,7 @@ where
/// Aborts the timer's callback task and internal task to trigger the callback, which means
/// that the timer will never complete the callback and starting will have no effect
pub fn abort(&self) {
if let Some(task) = self.active_timer.as_ref() {
task.abort();
}
self.stop();
self.callback.abort();
}
}

Loading…
Cancel
Save