mirror of https://github.com/chipsenkbeil/distant
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
286 lines
8.7 KiB
Rust
286 lines
8.7 KiB
Rust
use std::collections::HashMap;
|
|
use std::{fmt, io};
|
|
|
|
use log::*;
|
|
use tokio::sync::{mpsc, oneshot};
|
|
use tokio::task::JoinHandle;
|
|
|
|
use crate::client::{Mailbox, UntypedClient};
|
|
use crate::common::{ConnectionId, Destination, Map, UntypedRequest, UntypedResponse};
|
|
use crate::manager::data::{ManagerChannelId, ManagerResponse};
|
|
use crate::server::ServerReply;
|
|
|
|
/// Represents a connection a distant manager has with some distant-compatible server
|
|
pub struct ManagerConnection {
|
|
pub id: ConnectionId,
|
|
pub destination: Destination,
|
|
pub options: Map,
|
|
tx: mpsc::UnboundedSender<Action>,
|
|
|
|
action_task: JoinHandle<()>,
|
|
request_task: JoinHandle<()>,
|
|
response_task: JoinHandle<()>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ManagerChannel {
|
|
channel_id: ManagerChannelId,
|
|
tx: mpsc::UnboundedSender<Action>,
|
|
}
|
|
|
|
impl ManagerChannel {
|
|
/// Returns the id associated with the channel.
|
|
pub fn id(&self) -> ManagerChannelId {
|
|
self.channel_id
|
|
}
|
|
|
|
/// Sends the untyped request to the server on the other side of the channel.
|
|
pub fn send(&self, req: UntypedRequest<'static>) -> io::Result<()> {
|
|
let id = self.channel_id;
|
|
|
|
self.tx.send(Action::Write { id, req }).map_err(|x| {
|
|
io::Error::new(
|
|
io::ErrorKind::BrokenPipe,
|
|
format!("channel {id} send failed: {x}"),
|
|
)
|
|
})
|
|
}
|
|
|
|
/// Closes the channel, unregistering it with the connection.
|
|
pub fn close(&self) -> io::Result<()> {
|
|
let id = self.channel_id;
|
|
self.tx.send(Action::Unregister { id }).map_err(|x| {
|
|
io::Error::new(
|
|
io::ErrorKind::BrokenPipe,
|
|
format!("channel {id} close failed: {x}"),
|
|
)
|
|
})
|
|
}
|
|
}
|
|
|
|
impl ManagerConnection {
|
|
pub async fn spawn(
|
|
spawn: Destination,
|
|
options: Map,
|
|
mut client: UntypedClient,
|
|
) -> io::Result<Self> {
|
|
let connection_id = rand::random();
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
|
|
// NOTE: Ensure that the connection is severed when the client is dropped; otherwise, when
|
|
// the connection is terminated via aborting it or the connection being dropped, the
|
|
// connection will persist which can cause problems such as lonely shutdown of the server
|
|
// never triggering!
|
|
client.shutdown_on_drop(true);
|
|
|
|
let (request_tx, request_rx) = mpsc::unbounded_channel();
|
|
let action_task = tokio::spawn(action_task(connection_id, rx, request_tx));
|
|
let response_task = tokio::spawn(response_task(
|
|
connection_id,
|
|
client.assign_default_mailbox(100).await?,
|
|
tx.clone(),
|
|
));
|
|
let request_task = tokio::spawn(request_task(connection_id, client, request_rx));
|
|
|
|
Ok(Self {
|
|
id: connection_id,
|
|
destination: spawn,
|
|
options,
|
|
tx,
|
|
action_task,
|
|
request_task,
|
|
response_task,
|
|
})
|
|
}
|
|
|
|
pub fn open_channel(&self, reply: ServerReply<ManagerResponse>) -> io::Result<ManagerChannel> {
|
|
let channel_id = rand::random();
|
|
self.tx
|
|
.send(Action::Register {
|
|
id: channel_id,
|
|
reply,
|
|
})
|
|
.map_err(|x| {
|
|
io::Error::new(
|
|
io::ErrorKind::BrokenPipe,
|
|
format!("open_channel failed: {x}"),
|
|
)
|
|
})?;
|
|
Ok(ManagerChannel {
|
|
channel_id,
|
|
tx: self.tx.clone(),
|
|
})
|
|
}
|
|
|
|
pub async fn channel_ids(&self) -> io::Result<Vec<ManagerChannelId>> {
|
|
let (tx, rx) = oneshot::channel();
|
|
self.tx
|
|
.send(Action::GetRegistered { cb: tx })
|
|
.map_err(|x| {
|
|
io::Error::new(
|
|
io::ErrorKind::BrokenPipe,
|
|
format!("channel_ids failed: {x}"),
|
|
)
|
|
})?;
|
|
|
|
let channel_ids = rx.await.map_err(|x| {
|
|
io::Error::new(
|
|
io::ErrorKind::BrokenPipe,
|
|
format!("channel_ids callback dropped: {x}"),
|
|
)
|
|
})?;
|
|
Ok(channel_ids)
|
|
}
|
|
|
|
/// Aborts the tasks used to engage with the connection.
|
|
pub fn abort(&self) {
|
|
self.action_task.abort();
|
|
self.request_task.abort();
|
|
self.response_task.abort();
|
|
}
|
|
}
|
|
|
|
impl Drop for ManagerConnection {
|
|
fn drop(&mut self) {
|
|
self.abort();
|
|
}
|
|
}
|
|
|
|
enum Action {
|
|
Register {
|
|
id: ManagerChannelId,
|
|
reply: ServerReply<ManagerResponse>,
|
|
},
|
|
|
|
Unregister {
|
|
id: ManagerChannelId,
|
|
},
|
|
|
|
GetRegistered {
|
|
cb: oneshot::Sender<Vec<ManagerChannelId>>,
|
|
},
|
|
|
|
Read {
|
|
res: UntypedResponse<'static>,
|
|
},
|
|
|
|
Write {
|
|
id: ManagerChannelId,
|
|
req: UntypedRequest<'static>,
|
|
},
|
|
}
|
|
|
|
impl fmt::Debug for Action {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
Self::Register { id, .. } => write!(f, "Action::Register {{ id: {id}, .. }}"),
|
|
Self::Unregister { id } => write!(f, "Action::Unregister {{ id: {id} }}"),
|
|
Self::GetRegistered { .. } => write!(f, "Action::GetRegistered {{ .. }}"),
|
|
Self::Read { .. } => write!(f, "Action::Read {{ .. }}"),
|
|
Self::Write { id, .. } => write!(f, "Action::Write {{ id: {id}, .. }}"),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Internal task to process outgoing [`UntypedRequest`]s.
|
|
async fn request_task(
|
|
id: ConnectionId,
|
|
mut client: UntypedClient,
|
|
mut rx: mpsc::UnboundedReceiver<UntypedRequest<'static>>,
|
|
) {
|
|
while let Some(req) = rx.recv().await {
|
|
trace!("[Conn {id}] Firing off request {}", req.id);
|
|
if let Err(x) = client.fire(req).await {
|
|
error!("[Conn {id}] Failed to send request: {x}");
|
|
}
|
|
}
|
|
|
|
trace!("[Conn {id}] Manager request task closed");
|
|
}
|
|
|
|
/// Internal task to process incoming [`UntypedResponse`]s.
|
|
async fn response_task(
|
|
id: ConnectionId,
|
|
mut mailbox: Mailbox<UntypedResponse<'static>>,
|
|
tx: mpsc::UnboundedSender<Action>,
|
|
) {
|
|
while let Some(res) = mailbox.next().await {
|
|
trace!(
|
|
"[Conn {id}] Receiving response {} to request {}",
|
|
res.id,
|
|
res.origin_id
|
|
);
|
|
if let Err(x) = tx.send(Action::Read { res }) {
|
|
error!("[Conn {id}] Failed to forward received response: {x}");
|
|
}
|
|
}
|
|
|
|
trace!("[Conn {id}] Manager response task closed");
|
|
}
|
|
|
|
/// Internal task to process [`Action`] items.
|
|
///
|
|
/// * `id` - the id of the connection.
|
|
/// * `rx` - used to receive new [`Action`]s to process.
|
|
/// * `tx` - used to send outgoing requests through the connection.
|
|
async fn action_task(
|
|
id: ConnectionId,
|
|
mut rx: mpsc::UnboundedReceiver<Action>,
|
|
tx: mpsc::UnboundedSender<UntypedRequest<'static>>,
|
|
) {
|
|
let mut registered = HashMap::new();
|
|
|
|
while let Some(action) = rx.recv().await {
|
|
trace!("[Conn {id}] {action:?}");
|
|
|
|
match action {
|
|
Action::Register { id, reply } => {
|
|
registered.insert(id, reply);
|
|
}
|
|
Action::Unregister { id } => {
|
|
registered.remove(&id);
|
|
}
|
|
Action::GetRegistered { cb } => {
|
|
let _ = cb.send(registered.keys().copied().collect());
|
|
}
|
|
Action::Read { mut res } => {
|
|
// Split {channel id}_{request id} back into pieces and
|
|
// update the origin id to match the request id only
|
|
let channel_id = match res.origin_id.split_once('_') {
|
|
Some((cid_str, oid_str)) => {
|
|
if let Ok(cid) = cid_str.parse::<ManagerChannelId>() {
|
|
res.set_origin_id(oid_str.to_string());
|
|
cid
|
|
} else {
|
|
continue;
|
|
}
|
|
}
|
|
None => continue,
|
|
};
|
|
|
|
if let Some(reply) = registered.get(&channel_id) {
|
|
let response = ManagerResponse::Channel {
|
|
id: channel_id,
|
|
response: res,
|
|
};
|
|
|
|
if let Err(x) = reply.send(response) {
|
|
error!("[Conn {id}] {x}");
|
|
}
|
|
}
|
|
}
|
|
Action::Write { id, mut req } => {
|
|
// Combine channel id with request id so we can properly forward
|
|
// the response containing this in the origin id
|
|
req.set_id(format!("{id}_{}", req.id));
|
|
|
|
if let Err(x) = tx.send(req) {
|
|
error!("[Conn {id}] {x}");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
trace!("[Conn {id}] Manager action task closed");
|
|
}
|