mirror of https://github.com/chipsenkbeil/distant
pull/146/head
parent
4a4360abd1
commit
8c9cdbbb97
@ -1,39 +1,299 @@
|
||||
use crate::ConnectionId;
|
||||
use tokio::task::JoinHandle;
|
||||
use crate::{
|
||||
ConnectionCtx, FramedTransport, Interest, Response, ServerCtx, ServerHandler, ServerReply,
|
||||
ServerState, ShutdownTimer, Transport, UntypedRequest,
|
||||
};
|
||||
use log::*;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{
|
||||
io,
|
||||
sync::{Arc, Weak},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
sync::{mpsc, RwLock},
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
/// Time to wait inbetween connection read/write when nothing was read or written on last pass
|
||||
const SLEEP_DURATION: Duration = Duration::from_millis(50);
|
||||
|
||||
/// Id associated with an active connection
|
||||
pub type ConnectionId = u64;
|
||||
|
||||
/// Represents an individual connection on the server
|
||||
pub struct ServerConnection {
|
||||
pub struct Connection {
|
||||
/// Unique identifier tied to the connection
|
||||
pub id: ConnectionId,
|
||||
id: ConnectionId,
|
||||
|
||||
/// Task that is processing requests and responses
|
||||
pub(crate) task: Option<JoinHandle<()>>,
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Default for ServerConnection {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
impl Connection {
|
||||
/// Starts building a new connection
|
||||
pub fn build() -> ConnectionBuilder<(), ()> {
|
||||
let id: ConnectionId = rand::random();
|
||||
ConnectionBuilder {
|
||||
id,
|
||||
handler: Weak::new(),
|
||||
state: Weak::new(),
|
||||
transport: (),
|
||||
shutdown_timer: Weak::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerConnection {
|
||||
/// Creates a new connection, generating a unique id to represent the connection
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: rand::random(),
|
||||
task: None,
|
||||
}
|
||||
/// Returns the id associated with the connection
|
||||
pub fn id(&self) -> ConnectionId {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Returns true if connection is still processing incoming or outgoing messages
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.task.is_some() && !self.task.as_ref().unwrap().is_finished()
|
||||
!self.task.is_finished()
|
||||
}
|
||||
|
||||
/// Aborts the connection
|
||||
pub fn abort(&self) {
|
||||
if let Some(task) = self.task.as_ref() {
|
||||
task.abort();
|
||||
self.task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnectionBuilder<H, T> {
|
||||
id: ConnectionId,
|
||||
handler: Weak<H>,
|
||||
state: Weak<ServerState>,
|
||||
transport: T,
|
||||
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
|
||||
}
|
||||
|
||||
impl<H, T> ConnectionBuilder<H, T> {
|
||||
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionBuilder<U, T> {
|
||||
ConnectionBuilder {
|
||||
id: self.id,
|
||||
handler,
|
||||
state: self.state,
|
||||
transport: self.transport,
|
||||
shutdown_timer: self.shutdown_timer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn state(self, state: Weak<ServerState>) -> ConnectionBuilder<H, T> {
|
||||
ConnectionBuilder {
|
||||
id: self.id,
|
||||
handler: self.handler,
|
||||
state,
|
||||
transport: self.transport,
|
||||
shutdown_timer: self.shutdown_timer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn transport<U>(self, transport: U) -> ConnectionBuilder<H, U> {
|
||||
ConnectionBuilder {
|
||||
id: self.id,
|
||||
handler: self.handler,
|
||||
state: self.state,
|
||||
transport,
|
||||
shutdown_timer: self.shutdown_timer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shutdown_timer(
|
||||
self,
|
||||
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
|
||||
) -> ConnectionBuilder<H, T> {
|
||||
ConnectionBuilder {
|
||||
id: self.id,
|
||||
handler: self.handler,
|
||||
state: self.state,
|
||||
transport: self.transport,
|
||||
shutdown_timer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<H, T> ConnectionBuilder<H, T>
|
||||
where
|
||||
H: ServerHandler + Sync + 'static,
|
||||
H::Request: DeserializeOwned + Send + Sync + 'static,
|
||||
H::Response: Serialize + Send + 'static,
|
||||
H::LocalData: Default + Send + Sync + 'static,
|
||||
T: Transport + Send + Sync + 'static,
|
||||
{
|
||||
pub fn spawn(self) -> Connection {
|
||||
let id = self.id;
|
||||
|
||||
Connection {
|
||||
id,
|
||||
task: tokio::spawn(self.run()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run(self) {
|
||||
let ConnectionBuilder {
|
||||
id,
|
||||
handler,
|
||||
state,
|
||||
transport,
|
||||
shutdown_timer,
|
||||
} = self;
|
||||
|
||||
// Attempt to upgrade our handler for use with the connection going forward
|
||||
let handler = match Weak::upgrade(&handler) {
|
||||
Some(handler) => handler,
|
||||
None => {
|
||||
error!("[Conn {id}] Handler has been dropped");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Construct a queue of outgoing responses
|
||||
let (tx, mut rx) = mpsc::channel::<Response<H::Response>>(1);
|
||||
|
||||
// Perform a handshake to ensure that the connection is properly established
|
||||
let mut transport: FramedTransport<T> = FramedTransport::plain(transport);
|
||||
if let Err(x) = transport.server_handshake().await {
|
||||
error!("[Conn {id}] Handshake failed: {x}");
|
||||
return;
|
||||
}
|
||||
|
||||
// Create local data for the connection and then process it as well as perform
|
||||
// authentication and any other tasks on first connecting
|
||||
let mut local_data = H::LocalData::default();
|
||||
if let Err(x) = handler
|
||||
.on_accept(ConnectionCtx {
|
||||
connection_id: id,
|
||||
authenticator: &mut transport,
|
||||
local_data: &mut local_data,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("[Conn {id}] Accepting connection failed: {x}");
|
||||
return;
|
||||
}
|
||||
|
||||
let local_data = Arc::new(local_data);
|
||||
|
||||
loop {
|
||||
let ready = transport
|
||||
.ready(Interest::READABLE | Interest::WRITABLE)
|
||||
.await
|
||||
.expect("[Conn {connection_id}] Failed to examine ready state");
|
||||
|
||||
// Keep track of whether we read or wrote anything
|
||||
let mut read_blocked = false;
|
||||
let mut write_blocked = false;
|
||||
|
||||
if ready.is_readable() {
|
||||
match transport.try_read_frame() {
|
||||
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
|
||||
Ok(request) => match request.to_typed_request() {
|
||||
Ok(request) => {
|
||||
let reply = ServerReply {
|
||||
origin_id: request.id.clone(),
|
||||
tx: tx.clone(),
|
||||
};
|
||||
|
||||
let ctx = ServerCtx {
|
||||
connection_id: id,
|
||||
request,
|
||||
reply: reply.clone(),
|
||||
local_data: Arc::clone(&local_data),
|
||||
};
|
||||
|
||||
handler.on_request(ctx).await;
|
||||
}
|
||||
Err(x) => {
|
||||
if log::log_enabled!(Level::Trace) {
|
||||
trace!(
|
||||
"[Conn {id}] Failed receiving {}",
|
||||
String::from_utf8_lossy(&request.payload),
|
||||
);
|
||||
}
|
||||
|
||||
error!("[Conn {id}] Invalid request: {x}");
|
||||
}
|
||||
},
|
||||
Err(x) => {
|
||||
error!("[Conn {id}] Invalid request: {x}");
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
debug!("[Conn {id}] Connection closed");
|
||||
|
||||
// Remove the connection from our state if it has closed
|
||||
if let Some(state) = Weak::upgrade(&state) {
|
||||
state.connections.write().await.remove(&self.id);
|
||||
|
||||
// If we have no more connections, start the timer
|
||||
if let Some(timer) = Weak::upgrade(&shutdown_timer) {
|
||||
if state.connections.read().await.is_empty() {
|
||||
timer.write().await.start();
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
|
||||
Err(x) => {
|
||||
// NOTE: We do NOT break out of the loop, as this could happen
|
||||
// if someone sends bad data at any point, but does not
|
||||
// mean that the reader itself has failed. This can
|
||||
// happen from getting non-compliant typed data
|
||||
error!("[Conn {id}] {x}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If our socket is ready to be written to, we try to get the next item from
|
||||
// the queue and process it
|
||||
if ready.is_writable() {
|
||||
// If we get more data to write, attempt to write it, which will result in writing
|
||||
// any queued bytes as well. Othewise, we attempt to flush any pending outgoing
|
||||
// bytes that weren't sent earlier.
|
||||
if let Ok(response) = rx.try_recv() {
|
||||
// Log our message as a string, which can be expensive
|
||||
if log_enabled!(Level::Trace) {
|
||||
trace!(
|
||||
"[Conn {id}] Sending {}",
|
||||
&response
|
||||
.to_vec()
|
||||
.map(|x| String::from_utf8_lossy(&x).to_string())
|
||||
.unwrap_or_else(|_| "<Cannot serialize>".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
match response.to_vec() {
|
||||
Ok(data) => match transport.try_write_frame(data) {
|
||||
Ok(()) => (),
|
||||
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
|
||||
Err(x) => error!("[Conn {id}] Send failed: {x}"),
|
||||
},
|
||||
Err(x) => {
|
||||
error!("[Conn {id}] Unable to serialize outgoing response: {x}");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// In the case of flushing, there are two scenarios in which we want to
|
||||
// mark no write occurring:
|
||||
//
|
||||
// 1. When flush did not write any bytes, which can happen when the buffer
|
||||
// is empty
|
||||
// 2. When the call to write bytes blocks
|
||||
match transport.try_flush() {
|
||||
Ok(0) => write_blocked = true,
|
||||
Ok(_) => (),
|
||||
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
|
||||
Err(x) => {
|
||||
error!("[Conn {id}] Failed to flush outgoing data: {x}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we did not read or write anything, sleep a bit to offload CPU usage
|
||||
if read_blocked && write_blocked {
|
||||
tokio::time::sleep(SLEEP_DURATION).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,100 @@
|
||||
use super::Shutdown;
|
||||
use crate::utils::Timer;
|
||||
use log::*;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch;
|
||||
|
||||
/// Cloneable notification for when a [`ShutdownTimer`] has completed.
|
||||
#[derive(Clone)]
|
||||
pub struct ShutdownNotification(watch::Receiver<()>);
|
||||
|
||||
impl ShutdownNotification {
|
||||
/// Waits to receive a notification that the shutdown timer has concluded
|
||||
pub async fn wait(&mut self) {
|
||||
let _ = self.0.changed().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around [`Timer`] to support shutdown-specific notifications.
|
||||
pub struct ShutdownTimer {
|
||||
timer: Timer<()>,
|
||||
watcher: ShutdownNotification,
|
||||
}
|
||||
|
||||
impl ShutdownTimer {
|
||||
pub fn new(shutdown: Shutdown) -> Self {
|
||||
// Create the timer that will be used shutdown the server after duration elapsed
|
||||
let (tx, mut rx) = watch::channel(());
|
||||
|
||||
// NOTE: We do a manual map such that the shutdown sender is not captured and dropped when
|
||||
// there is no shutdown after configured. This is because we need the future for the
|
||||
// shutdown receiver to last forever in the event that there is no shutdown configured,
|
||||
// not return immediately, which is what would happen if the sender was dropped.
|
||||
#[allow(clippy::manual_map)]
|
||||
let mut timer = match shutdown {
|
||||
// Create a timer that will complete after `duration`, dropping it to ensure that it
|
||||
// will always happen no matter if stop/abort is called
|
||||
Shutdown::After(duration) => {
|
||||
info!(
|
||||
"Server shutdown timer configured: terminate after {}s",
|
||||
duration.as_secs_f32()
|
||||
);
|
||||
Timer::new(duration, async move {
|
||||
let _ = tx.send(());
|
||||
})
|
||||
}
|
||||
|
||||
// Create a timer that will complete after `duration`
|
||||
Shutdown::Lonely(duration) => {
|
||||
info!(
|
||||
"Server shutdown timer configured: terminate after no activity in {}s",
|
||||
duration.as_secs_f32()
|
||||
);
|
||||
Timer::new(duration, async move {
|
||||
let _ = tx.send(());
|
||||
})
|
||||
}
|
||||
|
||||
// Create a timer that will never complete (max timeout possible) so we hold on to the
|
||||
// sender to avoid the receiver from completing
|
||||
Shutdown::Never => {
|
||||
info!("Server shutdown timer configured: never terminate");
|
||||
Timer::new(Duration::MAX, async move {
|
||||
let _ = tx.send(());
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
timer.start();
|
||||
|
||||
Self {
|
||||
timer,
|
||||
watcher: ShutdownNotification(rx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clones the notification
|
||||
pub fn clone_notification(&self) -> ShutdownNotification {
|
||||
self.watcher.clone()
|
||||
}
|
||||
|
||||
/// Wait for the timer to complete
|
||||
pub async fn wait(&mut self) {
|
||||
self.watcher.wait().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ShutdownTimer {
|
||||
type Target = Timer<()>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.timer
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for ShutdownTimer {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.timer
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue