use crate::{ client::utils, constants::CLIENT_MAILBOX_CAPACITY, data::{Request, Response}, net::{Codec, DataStream, Transport, TransportError}, }; use log::*; use std::{ convert, net::SocketAddr, ops::{Deref, DerefMut}, sync::{Arc, Weak}, }; use tokio::{ io, net::TcpStream, sync::{mpsc, Mutex}, task::{JoinError, JoinHandle}, time::Duration, }; mod ext; pub use ext::{Metadata, SessionExt, SessionExtError}; mod info; pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError}; mod mailbox; pub use mailbox::Mailbox; use mailbox::PostOffice; /// Represents a session with a remote server that can be used to send requests & receive responses pub struct Session { /// Used to send requests to a server channel: SessionChannel, /// Contains the task that is running to send requests to a server request_task: JoinHandle<()>, /// Contains the task that is running to receive responses from a server response_task: JoinHandle<()>, /// Contains the task that runs on a timer to prune closed mailboxes prune_task: JoinHandle<()>, } impl Session { /// Connect to a remote TCP server using the provided information pub async fn tcp_connect(addr: SocketAddr, codec: U) -> io::Result where U: Codec + Send + 'static, { let transport = Transport::::connect(addr, codec).await?; debug!( "Session has been established with {}", transport .peer_addr() .map(|x| x.to_string()) .unwrap_or_else(|_| String::from("???")) ); Self::initialize(transport) } /// Connect to a remote TCP server, timing out after duration has passed pub async fn tcp_connect_timeout( addr: SocketAddr, codec: U, duration: Duration, ) -> io::Result where U: Codec + Send + 'static, { utils::timeout(duration, Self::tcp_connect(addr, codec)) .await .and_then(convert::identity) } } #[cfg(unix)] impl Session { /// Connect to a proxy unix socket pub async fn unix_connect(path: impl AsRef, codec: U) -> io::Result where U: Codec + Send + 'static, { let transport = Transport::::connect(path, codec).await?; debug!( "Session has been established with {}", transport .peer_addr() .map(|x| format!("{:?}", x)) .unwrap_or_else(|_| String::from("???")) ); Self::initialize(transport) } /// Connect to a proxy unix socket, timing out after duration has passed pub async fn unix_connect_timeout( path: impl AsRef, codec: U, duration: Duration, ) -> io::Result where U: Codec + Send + 'static, { utils::timeout(duration, Self::unix_connect(path, codec)) .await .and_then(convert::identity) } } impl Session { /// Initializes a session using the provided transport pub fn initialize(transport: Transport) -> io::Result where T: DataStream, U: Codec + Send + 'static, { let (mut t_read, mut t_write) = transport.into_split(); let post_office = Arc::new(Mutex::new(PostOffice::new())); let weak_post_office = Arc::downgrade(&post_office); // Start a task that continually checks for responses and delivers them using the // post office let response_task = tokio::spawn(async move { loop { match t_read.receive::().await { Ok(Some(res)) => { trace!("Incoming response: {:?}", res); let res_id = res.id; let res_origin_id = res.origin_id; // Try to send response to appropriate mailbox // NOTE: We don't log failures as errors as using fire(...) for a // session is valid and would not have a mailbox if !post_office.lock().await.deliver(res).await { trace!( "Response {} has no mailbox for origin {}", res_id, res_origin_id ); } } Ok(None) => { debug!("Session closing response task as transport read-half closed!"); break; } Err(x) => { error!("{}", x); break; } } } // Clean up remaining mailbox senders post_office.lock().await.clear_mailboxes(); }); let (tx, mut rx) = mpsc::channel::(1); let request_task = tokio::spawn(async move { while let Some(req) = rx.recv().await { if let Err(x) = t_write.send(req).await { error!("Failed to send request to server: {}", x); break; } } }); // Create a task that runs once a minute and prunes mailboxes let post_office = Weak::clone(&weak_post_office); let prune_task = tokio::spawn(async move { loop { tokio::time::sleep(Duration::from_secs(60)).await; if let Some(post_office) = Weak::upgrade(&post_office) { post_office.lock().await.prune_mailboxes(); } else { break; } } }); let channel = SessionChannel { tx, post_office: weak_post_office, }; Ok(Self { channel, request_task, response_task, prune_task, }) } } impl Session { /// Waits for the session to terminate, which results when the receiving end of the network /// connection is closed (or the session is shutdown) pub async fn wait(self) -> Result<(), JoinError> { self.prune_task.abort(); tokio::try_join!(self.request_task, self.response_task).map(|_| ()) } /// Abort the session's current connection by forcing its response task to shutdown pub fn abort(&self) { self.request_task.abort(); self.response_task.abort(); self.prune_task.abort(); } /// Clones the underlying channel for requests and returns the cloned instance pub fn clone_channel(&self) -> SessionChannel { self.channel.clone() } } impl Deref for Session { type Target = SessionChannel; fn deref(&self) -> &Self::Target { &self.channel } } impl DerefMut for Session { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.channel } } impl From for SessionChannel { fn from(session: Session) -> Self { session.channel } } /// Represents a sender of requests tied to a session, holding onto a weak reference of /// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped, /// any sent request will no longer be able to receive responses #[derive(Clone)] pub struct SessionChannel { /// Used to send requests to a server tx: mpsc::Sender, /// Collection of mailboxes for receiving responses to requests post_office: Weak>, } impl SessionChannel { /// Returns true if no more requests can be transferred pub fn is_closed(&self) -> bool { self.tx.is_closed() } /// Sends a request and returns a mailbox that can receive one or more responses, failing if /// unable to send a request or if the session's receiving line to the remote server has /// already been severed pub async fn mail(&mut self, req: Request) -> Result { trace!("Mailing request: {:?}", req); // First, create a mailbox using the request's id let mailbox = Weak::upgrade(&self.post_office) .ok_or_else(|| { TransportError::IoError(io::Error::new( io::ErrorKind::NotConnected, "Session's post office is no longer available", )) })? .lock() .await .make_mailbox(req.id, CLIENT_MAILBOX_CAPACITY); // Second, send the request self.fire(req).await?; // Third, return mailbox Ok(mailbox) } /// Sends a request and waits for a response, failing if unable to send a request or if /// the session's receiving line to the remote server has already been severed pub async fn send(&mut self, req: Request) -> Result { trace!("Sending request: {:?}", req); // Send mail and get back a mailbox let mut mailbox = self.mail(req).await?; // Wait for first response, and then drop the mailbox mailbox.next().await.ok_or_else(|| { TransportError::IoError(io::Error::from(io::ErrorKind::ConnectionAborted)) }) } /// Sends a request and waits for a response, timing out after duration has passed pub async fn send_timeout( &mut self, req: Request, duration: Duration, ) -> Result { utils::timeout(duration, self.send(req)) .await .map_err(TransportError::from) .and_then(convert::identity) } /// Sends a request without waiting for a response; this method is able to be used even /// if the session's receiving line to the remote server has been severed pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> { trace!("Firing off request: {:?}", req); self.tx .send(req) .await .map_err(|x| TransportError::IoError(io::Error::new(io::ErrorKind::BrokenPipe, x))) } /// Sends a request without waiting for a response, timing out after duration has passed pub async fn fire_timeout( &mut self, req: Request, duration: Duration, ) -> Result<(), TransportError> { utils::timeout(duration, self.fire(req)) .await .map_err(TransportError::from) .and_then(convert::identity) } } #[cfg(test)] mod tests { use super::*; use crate::{ constants::test::TENANT, data::{RequestData, ResponseData}, }; use std::time::Duration; #[tokio::test] async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() { let (t1, mut t2) = Transport::make_pair(); let mut session = Session::initialize(t1).unwrap(); let req = Request::new(TENANT, vec![RequestData::ProcList {}]); let res = Response::new(TENANT, req.id, vec![ResponseData::Ok]); let mut mailbox = session.mail(req).await.unwrap(); // Get first response match tokio::join!(mailbox.next(), t2.send(res.clone())) { (Some(actual), _) => assert_eq!(actual, res), x => panic!("Unexpected response: {:?}", x), } // Get second response match tokio::join!(mailbox.next(), t2.send(res.clone())) { (Some(actual), _) => assert_eq!(actual, res), x => panic!("Unexpected response: {:?}", x), } // Trigger the mailbox to wait BEFORE closing our transport to ensure that // we don't get stuck if the mailbox was already waiting let next_task = tokio::spawn(async move { mailbox.next().await }); tokio::task::yield_now().await; drop(t2); match next_task.await { Ok(None) => {} x => panic!("Unexpected response: {:?}", x), } } #[tokio::test] async fn send_should_wait_until_response_received() { let (t1, mut t2) = Transport::make_pair(); let mut session = Session::initialize(t1).unwrap(); let req = Request::new(TENANT, vec![RequestData::ProcList {}]); let res = Response::new( TENANT, req.id, vec![ResponseData::ProcEntries { entries: Vec::new(), }], ); let (actual, _) = tokio::join!(session.send(req), t2.send(res.clone())); match actual { Ok(actual) => assert_eq!(actual, res), x => panic!("Unexpected response: {:?}", x), } } #[tokio::test] async fn send_timeout_should_fail_if_response_not_received_in_time() { let (t1, mut t2) = Transport::make_pair(); let mut session = Session::initialize(t1).unwrap(); let req = Request::new(TENANT, vec![RequestData::ProcList {}]); match session.send_timeout(req, Duration::from_millis(30)).await { Err(TransportError::IoError(x)) => assert_eq!(x.kind(), io::ErrorKind::TimedOut), x => panic!("Unexpected response: {:?}", x), } let req = t2.receive::().await.unwrap().unwrap(); assert_eq!(req.tenant, TENANT); } #[tokio::test] async fn fire_should_send_request_and_not_wait_for_response() { let (t1, mut t2) = Transport::make_pair(); let mut session = Session::initialize(t1).unwrap(); let req = Request::new(TENANT, vec![RequestData::ProcList {}]); match session.fire(req).await { Ok(_) => {} x => panic!("Unexpected response: {:?}", x), } let req = t2.receive::().await.unwrap().unwrap(); assert_eq!(req.tenant, TENANT); } }