Support sequential batch processing (#201)

pull/203/head
Chip Senkbeil 11 months ago committed by GitHub
parent efad345a0d
commit 4fb9045152
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`distant-local` implementation sending a separate `Changed` event per path
- `ChangeDetails` now includes a `renamed` field to capture the new path name
when known
- `DistantApi` now handles batch requests in parallel, returning the results in
order. To restore the previous sequential processing of a batch request, set
the request's `sequence` header value to `true`
## [0.20.0-alpha.8]

@ -27,7 +27,7 @@ pub struct DistantApiServerHandler<T, D>
where
T: DistantApi<LocalData = D>,
{
api: T,
api: Arc<T>,
}
impl<T, D> DistantApiServerHandler<T, D>
@ -35,7 +35,7 @@ where
T: DistantApi<LocalData = D>,
{
pub fn new(api: T) -> Self {
Self { api }
Self { api: Arc::new(api) }
}
}
@ -424,8 +424,8 @@ pub trait DistantApi {
#[async_trait]
impl<T, D> ServerHandler for DistantApiServerHandler<T, D>
where
T: DistantApi<LocalData = D> + Send + Sync,
D: Send + Sync,
T: DistantApi<LocalData = D> + Send + Sync + 'static,
D: Send + Sync + 'static,
{
type LocalData = D;
type Request = protocol::Msg<protocol::Request>;
@ -457,7 +457,7 @@ where
local_data,
};
let data = handle_request(self, ctx, data).await;
let data = handle_request(Arc::clone(&self.api), ctx, data).await;
// Report outgoing errors in our debug logs
if let protocol::Response::Error(x) = &data {
@ -466,27 +466,35 @@ where
protocol::Msg::Single(data)
}
protocol::Msg::Batch(list) => {
protocol::Msg::Batch(list)
if matches!(request.header.get_as("sequence"), Some(Ok(true))) =>
{
let mut out = Vec::new();
let mut has_failed = false;
for data in list {
// Once we hit a failure, all remaining requests return interrupted
if has_failed {
out.push(protocol::Response::Error(protocol::Error {
kind: protocol::ErrorKind::Interrupted,
description: String::from("Canceled due to earlier error"),
}));
continue;
}
let ctx = DistantCtx {
connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data: Arc::clone(&local_data),
};
// TODO: This does not run in parallel, meaning that the next item in the
// batch will not be queued until the previous item completes! This
// would be useful if we wanted to chain requests where the previous
// request feeds into the current request, but not if we just want
// to run everything together. So we should instead rewrite this
// to spawn a task per request and then await completion of all tasks
let data = handle_request(self, ctx, data).await;
let data = handle_request(Arc::clone(&self.api), ctx, data).await;
// Report outgoing errors in our debug logs
// Report outgoing errors in our debug logs and mark as failed
// to cancel any future tasks being run
if let protocol::Response::Error(x) = &data {
debug!("[Conn {}] {}", connection_id, x);
has_failed = true;
}
out.push(data);
@ -494,6 +502,44 @@ where
protocol::Msg::Batch(out)
}
protocol::Msg::Batch(list) => {
let mut tasks = Vec::new();
// The `sequence` header was not set to true, so spawn one task per request to run
// them concurrently; results are still returned in the original request order
for data in list {
let api = Arc::clone(&self.api);
let ctx = DistantCtx {
connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data: Arc::clone(&local_data),
};
let task = tokio::spawn(async move {
let data = handle_request(api, ctx, data).await;
// Report outgoing errors in our debug logs
if let protocol::Response::Error(x) = &data {
debug!("[Conn {}] {}", connection_id, x);
}
data
});
tasks.push(task);
}
let out = futures::future::join_all(tasks)
.await
.into_iter()
.map(|x| match x {
Ok(x) => x,
Err(x) => protocol::Response::Error(x.to_string().into()),
})
.collect();
protocol::Msg::Batch(out)
}
};
// Queue up our result to go before ANY of the other messages that might be sent.
@ -515,7 +561,7 @@ where
/// Processes an incoming request
async fn handle_request<T, D>(
server: &DistantApiServerHandler<T, D>,
api: Arc<T>,
ctx: DistantCtx<D>,
request: protocol::Request,
) -> protocol::Response
@ -524,44 +570,37 @@ where
D: Send + Sync,
{
match request {
protocol::Request::Version {} => server
.api
protocol::Request::Version {} => api
.version(ctx)
.await
.map(protocol::Response::Version)
.unwrap_or_else(protocol::Response::from),
protocol::Request::FileRead { path } => server
.api
protocol::Request::FileRead { path } => api
.read_file(ctx, path)
.await
.map(|data| protocol::Response::Blob { data })
.unwrap_or_else(protocol::Response::from),
protocol::Request::FileReadText { path } => server
.api
protocol::Request::FileReadText { path } => api
.read_file_text(ctx, path)
.await
.map(|data| protocol::Response::Text { data })
.unwrap_or_else(protocol::Response::from),
protocol::Request::FileWrite { path, data } => server
.api
protocol::Request::FileWrite { path, data } => api
.write_file(ctx, path, data)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::FileWriteText { path, text } => server
.api
protocol::Request::FileWriteText { path, text } => api
.write_file_text(ctx, path, text)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::FileAppend { path, data } => server
.api
protocol::Request::FileAppend { path, data } => api
.append_file(ctx, path, data)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::FileAppendText { path, text } => server
.api
protocol::Request::FileAppendText { path, text } => api
.append_file_text(ctx, path, text)
.await
.map(|_| protocol::Response::Ok)
@ -572,8 +611,7 @@ where
absolute,
canonicalize,
include_root,
} => server
.api
} => api
.read_dir(ctx, path, depth, absolute, canonicalize, include_root)
.await
.map(|(entries, errors)| protocol::Response::DirEntries {
@ -581,26 +619,22 @@ where
errors: errors.into_iter().map(Error::from).collect(),
})
.unwrap_or_else(protocol::Response::from),
protocol::Request::DirCreate { path, all } => server
.api
protocol::Request::DirCreate { path, all } => api
.create_dir(ctx, path, all)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::Remove { path, force } => server
.api
protocol::Request::Remove { path, force } => api
.remove(ctx, path, force)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::Copy { src, dst } => server
.api
protocol::Request::Copy { src, dst } => api
.copy(ctx, src, dst)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::Rename { src, dst } => server
.api
protocol::Request::Rename { src, dst } => api
.rename(ctx, src, dst)
.await
.map(|_| protocol::Response::Ok)
@ -610,20 +644,17 @@ where
recursive,
only,
except,
} => server
.api
} => api
.watch(ctx, path, recursive, only, except)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::Unwatch { path } => server
.api
protocol::Request::Unwatch { path } => api
.unwatch(ctx, path)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::Exists { path } => server
.api
protocol::Request::Exists { path } => api
.exists(ctx, path)
.await
.map(|value| protocol::Response::Exists { value })
@ -632,8 +663,7 @@ where
path,
canonicalize,
resolve_file_type,
} => server
.api
} => api
.metadata(ctx, path, canonicalize, resolve_file_type)
.await
.map(protocol::Response::Metadata)
@ -642,20 +672,17 @@ where
path,
permissions,
options,
} => server
.api
} => api
.set_permissions(ctx, path, permissions, options)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::Search { query } => server
.api
protocol::Request::Search { query } => api
.search(ctx, query)
.await
.map(|id| protocol::Response::SearchStarted { id })
.unwrap_or_else(protocol::Response::from),
protocol::Request::CancelSearch { id } => server
.api
protocol::Request::CancelSearch { id } => api
.cancel_search(ctx, id)
.await
.map(|_| protocol::Response::Ok)
@ -665,32 +692,27 @@ where
environment,
current_dir,
pty,
} => server
.api
} => api
.proc_spawn(ctx, cmd.into(), environment, current_dir, pty)
.await
.map(|id| protocol::Response::ProcSpawned { id })
.unwrap_or_else(protocol::Response::from),
protocol::Request::ProcKill { id } => server
.api
protocol::Request::ProcKill { id } => api
.proc_kill(ctx, id)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::ProcStdin { id, data } => server
.api
protocol::Request::ProcStdin { id, data } => api
.proc_stdin(ctx, id, data)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::ProcResizePty { id, size } => server
.api
protocol::Request::ProcResizePty { id, size } => api
.proc_resize_pty(ctx, id, size)
.await
.map(|_| protocol::Response::Ok)
.unwrap_or_else(protocol::Response::from),
protocol::Request::SystemInfo {} => server
.api
protocol::Request::SystemInfo {} => api
.system_info(ctx)
.await
.map(protocol::Response::SystemInfo)

@ -0,0 +1,347 @@
use std::io;
use std::path::PathBuf;
use async_trait::async_trait;
use distant_core::{
DistantApi, DistantApiServerHandler, DistantChannelExt, DistantClient, DistantCtx,
};
use distant_net::auth::{DummyAuthHandler, Verifier};
use distant_net::client::Client;
use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::server::{Server, ServerRef};
/// Spins up an in-memory server backed by `api` along with a client connected to it.
///
/// Returns the connected client and a reference to the running server. Panics if the
/// server fails to start or the client fails to connect.
async fn setup(
    api: impl DistantApi<LocalData = ()> + Send + Sync + 'static,
) -> (DistantClient, Box<dyn ServerRef>) {
    // One end of the in-memory pair goes to the client, the other to the server
    let (client_transport, server_transport) = InmemoryTransport::pair(100);

    // Start the server first so that it is listening when the client connects
    let handler = DistantApiServerHandler::new(api);
    let listener = OneshotListener::from_value(server_transport);
    let server = Server::new()
        .handler(handler)
        .verifier(Verifier::none())
        .start(listener)
        .expect("Failed to start server");

    let builder = Client::build()
        .auth_handler(DummyAuthHandler)
        .connector(client_transport);
    let client: DistantClient = builder
        .connect()
        .await
        .expect("Failed to connect to server");

    (client, server)
}
mod single {
    use super::*;
    use test_log::test;

    // A failure reported by the api must reach the client with the same kind and message.
    #[test(tokio::test)]
    async fn should_support_single_request_returning_error() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            type LocalData = ();

            async fn read_file(
                &self,
                _ctx: DistantCtx<Self::LocalData>,
                _path: PathBuf,
            ) -> io::Result<Vec<u8>> {
                let failure = io::Error::new(io::ErrorKind::NotFound, "test error");
                Err(failure)
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        let outcome = client.read_file(PathBuf::from("file")).await;
        let error = outcome.unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::NotFound);
        assert_eq!(error.to_string(), "test error");
    }

    // A successful api call must hand its bytes back to the client untouched.
    #[test(tokio::test)]
    async fn should_support_single_request_returning_success() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            type LocalData = ();

            async fn read_file(
                &self,
                _ctx: DistantCtx<Self::LocalData>,
                _path: PathBuf,
            ) -> io::Result<Vec<u8>> {
                let payload: &[u8] = b"hello world";
                Ok(payload.to_vec())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        let outcome = client.read_file(PathBuf::from("file")).await;
        let contents = outcome.unwrap();
        assert_eq!(contents, b"hello world");
    }
}
mod batch_parallel {
    use super::*;
    use distant_net::common::Request;
    use distant_protocol::{Msg, Request as RequestPayload};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use test_log::test;

    // A batch sent without the `sequence` header should have its requests processed
    // concurrently: a slow request in the middle must not delay the one after it.
    #[test(tokio::test)]
    async fn should_support_multiple_requests_running_in_parallel() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            type LocalData = ();

            // Replies with the current time (milliseconds since the epoch, big-endian),
            // sleeping for 500ms beforehand when asked to read the path "slow"
            async fn read_file(
                &self,
                _ctx: DistantCtx<Self::LocalData>,
                path: PathBuf,
            ) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "slow" {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                }

                let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
                Ok((time.as_millis() as u64).to_be_bytes().to_vec())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        let request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("slow"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Collect our times from the reading
        let mut times = Vec::new();
        for payload in payloads {
            match payload {
                distant_protocol::Response::Blob { data } => {
                    let mut buf = [0u8; 8];
                    buf.copy_from_slice(&data[..8]);
                    times.push(u64::from_be_bytes(buf));
                }
                x => panic!("Unexpected payload: {x:?}"),
            }
        }

        // Verify that these ran in parallel. Had they run in sequence, the 500ms sleep in
        // the middle would put the first and third timestamps at least 500 milliseconds
        // apart, so the bound is strict: a gap of exactly 500 is not evidence of
        // parallelism.
        let diff = times[0].abs_diff(times[2]);
        assert!(diff < 500, "Sequential ordering detected");
    }

    // In parallel mode a failing request only affects its own slot in the response;
    // the requests around it still run and report their own results.
    #[test(tokio::test)]
    async fn should_run_all_requests_even_if_some_fail() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            type LocalData = ();

            // Fails for the path "fail" and succeeds with empty contents otherwise
            async fn read_file(
                &self,
                _ctx: DistantCtx<Self::LocalData>,
                path: PathBuf,
            ) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "fail" {
                    return Err(io::Error::new(io::ErrorKind::Other, "test error"));
                }

                Ok(Vec::new())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        let request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("fail"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Should be a success, error, and success
        assert!(
            matches!(payloads[0], distant_protocol::Response::Blob { .. }),
            "Unexpected payloads[0]: {:?}",
            payloads[0]
        );
        assert!(
            matches!(
                &payloads[1],
                distant_protocol::Response::Error(distant_protocol::Error { kind, description })
                if matches!(kind, distant_protocol::ErrorKind::Other) && description == "test error"
            ),
            "Unexpected payloads[1]: {:?}",
            payloads[1]
        );
        assert!(
            matches!(payloads[2], distant_protocol::Response::Blob { .. }),
            "Unexpected payloads[2]: {:?}",
            payloads[2]
        );
    }
}
mod batch_sequence {
    use super::*;
    use distant_net::common::Request;
    use distant_protocol::{Msg, Request as RequestPayload};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use test_log::test;

    // A batch sent with the `sequence` header set to true should have its requests
    // processed one after another, so a slow request delays everything behind it.
    #[test(tokio::test)]
    async fn should_support_multiple_requests_running_in_sequence() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            type LocalData = ();

            // Replies with the current time (milliseconds since the epoch, big-endian),
            // sleeping for 500ms beforehand when asked to read the path "slow"
            async fn read_file(
                &self,
                _ctx: DistantCtx<Self::LocalData>,
                path: PathBuf,
            ) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "slow" {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                }

                let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
                Ok((time.as_millis() as u64).to_be_bytes().to_vec())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        let mut request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("slow"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        // Mark as running in sequence
        request.header.insert("sequence", true);

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Collect our times from the reading
        let mut times = Vec::new();
        for payload in payloads {
            match payload {
                distant_protocol::Response::Blob { data } => {
                    let mut buf = [0u8; 8];
                    buf.copy_from_slice(&data[..8]);
                    times.push(u64::from_be_bytes(buf));
                }
                x => panic!("Unexpected payload: {x:?}"),
            }
        }

        // Verify that these ran in sequence: the 500ms sleep in the middle separates the
        // first and third requests. The timestamps are truncated to whole milliseconds, so
        // a genuinely sequential run can measure exactly 500; the bound is therefore
        // inclusive to avoid a spurious failure.
        let diff = times[0].abs_diff(times[2]);
        assert!(diff >= 500, "Parallel ordering detected");
    }

    // In sequence mode the first failure stops the batch: requests after it are answered
    // with an interrupted error.
    #[test(tokio::test)]
    async fn should_interrupt_any_requests_following_a_failure() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            type LocalData = ();

            // Fails for the path "fail" and succeeds with empty contents otherwise
            async fn read_file(
                &self,
                _ctx: DistantCtx<Self::LocalData>,
                path: PathBuf,
            ) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "fail" {
                    return Err(io::Error::new(io::ErrorKind::Other, "test error"));
                }

                Ok(Vec::new())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        let mut request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("fail"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        // Mark as running in sequence
        request.header.insert("sequence", true);

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Should be a success, error, and interrupt
        assert!(
            matches!(payloads[0], distant_protocol::Response::Blob { .. }),
            "Unexpected payloads[0]: {:?}",
            payloads[0]
        );
        assert!(
            matches!(
                &payloads[1],
                distant_protocol::Response::Error(distant_protocol::Error { kind, description })
                if matches!(kind, distant_protocol::ErrorKind::Other) && description == "test error"
            ),
            "Unexpected payloads[1]: {:?}",
            payloads[1]
        );
        assert!(
            matches!(
                &payloads[2],
                distant_protocol::Response::Error(distant_protocol::Error { kind, .. })
                if matches!(kind, distant_protocol::ErrorKind::Interrupted)
            ),
            "Unexpected payloads[2]: {:?}",
            payloads[2]
        );
    }
}

@ -20,5 +20,4 @@ pub use listener::*;
pub use map::*;
pub use packet::*;
pub use port::*;
pub use serde_json::Value;
pub use transport::*;

@ -1,10 +1,12 @@
mod header;
mod request;
mod response;
mod value;
pub use header::*;
pub use request::*;
pub use response::*;
pub use value::*;
use std::io::Cursor;

@ -1,5 +1,6 @@
use crate::common::{utils, Value};
use derive_more::IntoIterator;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io;
@ -54,6 +55,17 @@ impl Header {
self.0.insert(key.into(), value.into())
}
/// Looks up `key` in the header and attempts to convert the associated value to `T`.
///
/// Returns `None` when no value is stored under `key`. Otherwise the stored value is
/// cloned and the clone is converted, so the header keeps its original value whether
/// or not the conversion succeeds.
pub fn get_as<T>(&self, key: impl AsRef<str>) -> Option<io::Result<T>>
where
    T: DeserializeOwned,
{
    let stored = self.0.get(key.as_ref())?;
    Some(stored.clone().cast_as())
}
/// Serializes the header into bytes.
pub fn to_vec(&self) -> io::Result<Vec<u8>> {
utils::serialize_to_vec(self)

@ -0,0 +1,110 @@
use crate::common::utils;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::io;
use std::ops::{Deref, DerefMut};
/// Loosely-typed value carried inside a header.
///
/// Thin wrapper around [`serde_json::Value`]; because of `#[serde(transparent)]` it
/// serializes and deserializes exactly as the wrapped value does.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Value(serde_json::Value);

impl Value {
    /// Builds a [`Value`] from anything convertible into a [`serde_json::Value`].
    pub fn new(value: impl Into<serde_json::Value>) -> Self {
        let inner: serde_json::Value = value.into();
        Self(inner)
    }

    /// Encodes this value as bytes using the crate's serialization helper.
    pub fn to_vec(&self) -> io::Result<Vec<u8>> {
        utils::serialize_to_vec(self)
    }

    /// Decodes a value from bytes using the crate's deserialization helper.
    pub fn from_slice(slice: &[u8]) -> io::Result<Self> {
        utils::deserialize_from_slice(slice)
    }

    /// Consumes this value, trying to interpret it as a `T`.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] wrapping the
    /// underlying conversion error when the value cannot be represented as `T`.
    pub fn cast_as<T>(self) -> io::Result<T>
    where
        T: DeserializeOwned,
    {
        let Self(inner) = self;
        match serde_json::from_value::<T>(inner) {
            Ok(converted) => Ok(converted),
            Err(error) => Err(io::Error::new(io::ErrorKind::InvalidData, error)),
        }
    }
}
impl Deref for Value {
type Target = serde_json::Value;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Value {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
// Generates `impl From<S> for Value` for every listed source type `S` by delegating
// to the matching `From<S>` conversion of `serde_json::Value`.
macro_rules! impl_from {
    ($($source:ty),+) => {
        $(
            impl From<$source> for Value {
                fn from(value: $source) -> Self {
                    Self(serde_json::Value::from(value))
                }
            }
        )+
    };
}

impl_from!(
    (),
    i8,
    i16,
    i32,
    i64,
    isize,
    u8,
    u16,
    u32,
    u64,
    usize,
    f32,
    f64,
    bool,
    String,
    serde_json::Number,
    serde_json::Map<String, serde_json::Value>
);
/// Converts a slice of convertible items by delegating to [`serde_json::Value`].
impl<T> From<&[T]> for Value
where
    T: Clone + Into<serde_json::Value>,
{
    fn from(items: &[T]) -> Self {
        Self(serde_json::Value::from(items))
    }
}
impl<'a> From<&'a str> for Value {
fn from(x: &'a str) -> Self {
Self(From::from(x))
}
}
impl<'a> From<Cow<'a, str>> for Value {
fn from(x: Cow<'a, str>) -> Self {
Self(From::from(x))
}
}
/// Converts an optional convertible item by delegating to [`serde_json::Value`].
impl<T> From<Option<T>> for Value
where
    T: Into<serde_json::Value>,
{
    fn from(maybe: Option<T>) -> Self {
        Self(serde_json::Value::from(maybe))
    }
}
/// Converts a vector of convertible items by delegating to [`serde_json::Value`].
impl<T> From<Vec<T>> for Value
where
    T: Into<serde_json::Value>,
{
    fn from(items: Vec<T>) -> Self {
        Self(serde_json::Value::from(items))
    }
}
Loading…
Cancel
Save