diff --git a/distant-net/src/packet.rs b/distant-net/src/packet.rs index 1cccf7c..0c26f7a 100644 --- a/distant-net/src/packet.rs +++ b/distant-net/src/packet.rs @@ -1,68 +1,254 @@ /// Represents a generic id type pub type Id = String; -/// Represents a request to send -#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -pub struct Request { - /// Unique id associated with the request - pub id: Id, - - /// Payload associated with the request - pub payload: T, -} +mod request; +mod response; -impl Request { - /// Creates a new request with a random, unique id - pub fn new(payload: T) -> Self { - Self { - id: rand::random::().to_string(), - payload, - } - } -} +pub use request::*; +pub use response::*; -#[cfg(feature = "schemars")] -impl Request { - pub fn root_schema() -> schemars::schema::RootSchema { - schemars::schema_for!(Request) - } +#[derive(Clone, Debug, PartialEq, Eq)] +enum MsgPackStrParseError { + InvalidFormat, + Utf8Error(std::str::Utf8Error), } -impl From for Request { - fn from(payload: T) -> Self { - Self::new(payload) +/// Parse msgpack str, returning remaining bytes and str on success, or error on failure +fn parse_msg_pack_str(input: &[u8]) -> Result<(&[u8], &str), MsgPackStrParseError> { + let ilen = input.len(); + if ilen == 0 { + return Err(MsgPackStrParseError::InvalidFormat); } -} -/// Represents a response received related to some request -#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -pub struct Response { - /// Unique id associated with the response - pub id: Id, + // * fixstr using 0xa0 - 0xbf to mark the start of the str where < 32 bytes + // * str 8 (0xd9) if up to (2^8)-1 bytes, using next byte for len + // * str 16 (0xda) if up to (2^16)-1 bytes, using next two bytes for len + // * str 32 (0xdb) if up to (2^32)-1 bytes, using next four bytes for len + let (input, len): (&[u8], usize) = if input[0] >= 0xa0 && input[0] <= 0xbf { + (&input[1..], (input[0] & 0b00011111).into()) + } else if input[0] == 0xd9 && ilen > 2 { + (&input[2..], input[1].into()) + } else if input[0] == 0xda && ilen > 3 { + (&input[3..], u16::from_be_bytes([input[1], input[2]]).into()) + } else if input[0] == 0xdb && ilen > 5 { + ( + &input[5..], + u32::from_be_bytes([input[1], input[2], input[3], input[4]]) + .try_into() + .unwrap(), + ) + } else { + return Err(MsgPackStrParseError::InvalidFormat); + }; - /// Unique id associated with the request that triggered the response - pub origin_id: Id, + let s = match std::str::from_utf8(&input[..len]) { + Ok(s) => s, + Err(x) => return Err(MsgPackStrParseError::Utf8Error(x)), + }; - /// Payload associated with the response - pub payload: T, + Ok((&input[len..], s)) } -impl Response { - /// Creates a new response with a random, unique id - pub fn new(origin_id: Id, payload: T) -> Self { - Self { - id: rand::random::().to_string(), - origin_id, - payload, +#[cfg(test)] +mod tests { + use super::*; + + mod parse_msg_pack_str { + use super::*; + + #[test] + fn should_be_able_to_parse_fixstr() { + // Empty str + let (input, s) = parse_msg_pack_str(&[0xa0]).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, ""); + + // Single character + let (input, s) = parse_msg_pack_str(&[0xa1, b'a']).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, "a"); + + // 31 byte str + let (input, s) = parse_msg_pack_str(&[ + 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', + ]) + .unwrap(); + assert!(input.is_empty()); + assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + + // Verify that we only consume up to fixstr length + assert_eq!(parse_msg_pack_str(&[0xa0, b'a']).unwrap().0, b"a"); + assert_eq!( + parse_msg_pack_str(&[ + 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'b' + ]) + .unwrap() + .0, + b"b" + ); } - } -} -#[cfg(feature = "schemars")] -impl Response { - pub fn root_schema() -> schemars::schema::RootSchema { - schemars::schema_for!(Response) + #[test] + fn should_be_able_to_parse_str_8() { + // 32 byte str + let (input, s) = parse_msg_pack_str(&[ + 0xd9, 32, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', + ]) + .unwrap(); + assert!(input.is_empty()); + assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + + // 2^8 - 1 (255) byte str + let test_str = "a".repeat(2usize.pow(8) - 1); + let mut input = vec![0xd9, 255]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // Verify that we only consume up to 2^8 - 1 length + let mut input = vec![0xd9, 255]; + input.extend_from_slice(test_str.as_bytes()); + input.extend_from_slice(b"hello"); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert_eq!(input, b"hello"); + assert_eq!(s, test_str); + } + + #[test] + fn should_be_able_to_parse_str_16() { + // 2^8 byte str (256) + let test_str = "a".repeat(2usize.pow(8)); + let mut input = vec![0xda, 1, 0]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // 2^16 - 1 (65535) byte str + let test_str = "a".repeat(2usize.pow(16) - 1); + let mut input = vec![0xda, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // Verify that we only consume up to 2^16 - 1 length + let mut input = vec![0xda, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + input.extend_from_slice(b"hello"); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert_eq!(input, b"hello"); + assert_eq!(s, test_str); + } + + #[test] + fn should_be_able_to_parse_str_32() { + // 2^16 byte str + let test_str = "a".repeat(2usize.pow(16)); + let mut input = vec![0xdb, 0, 1, 0, 0]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // NOTE: We are not going to run the below tests, not because they aren't valid but + // because this generates a 4GB str which takes 20+ seconds to run + + // 2^32 - 1 byte str (4294967295 bytes) + /* let test_str = "a".repeat(2usize.pow(32) - 1); + let mut input = vec![0xdb, 255, 255, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); */ + + // Verify that we only consume up to 2^32 - 1 length + /* let mut input = vec![0xdb, 255, 255, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + input.extend_from_slice(b"hello"); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert_eq!(input, b"hello"); + assert_eq!(s, test_str); */ + } + + #[test] + fn should_fail_parsing_str_with_invalid_length() { + // Make sure that parse doesn't fail looking for bytes after str 8 len + assert_eq!( + parse_msg_pack_str(&[0xd9]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xd9, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + + // Make sure that parse doesn't fail looking for bytes after str 16 len + assert_eq!( + parse_msg_pack_str(&[0xda]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xda, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xda, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + + // Make sure that parse doesn't fail looking for bytes after str 32 len + assert_eq!( + parse_msg_pack_str(&[0xdb]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0, 0, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + } + + #[test] + fn should_fail_parsing_other_types() { + assert_eq!( + parse_msg_pack_str(&[0xc3]), // Boolean (true) + Err(MsgPackStrParseError::InvalidFormat) + ); + } + + #[test] + fn should_fail_if_empty_input() { + assert_eq!( + parse_msg_pack_str(&[]), + Err(MsgPackStrParseError::InvalidFormat) + ); + } + + #[test] + fn should_fail_if_str_is_not_utf8() { + assert!(matches!( + parse_msg_pack_str(&[0xa4, 0, 159, 146, 150]), + Err(MsgPackStrParseError::Utf8Error(_)) + )); + } } } diff --git a/distant-net/src/packet/request.rs b/distant-net/src/packet/request.rs new file mode 100644 index 0000000..71511f9 --- /dev/null +++ b/distant-net/src/packet/request.rs @@ -0,0 +1,316 @@ +use super::{parse_msg_pack_str, Id}; +use crate::utils; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{borrow::Cow, io, str}; + +/// Represents a request to send +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Request { + /// Unique id associated with the request + pub id: Id, + + /// Payload associated with the request + pub payload: T, +} + +impl Request { + /// Creates a new request with a random, unique id + pub fn new(payload: T) -> Self { + Self { + id: rand::random::().to_string(), + payload, + } + } +} + +impl Request +where + T: Serialize, +{ + /// Serializes the request into bytes + pub fn to_vec(&self) -> io::Result> { + utils::serialize_to_vec(self) + } + + /// Serializes the request's payload into bytes + pub fn to_payload_vec(&self) -> io::Result> { + utils::serialize_to_vec(&self.payload) + } +} + +impl Request +where + T: DeserializeOwned, +{ + /// Deserializes the request from bytes + pub fn from_slice(slice: &[u8]) -> io::Result { + utils::deserialize_from_slice(slice) + } +} + +#[cfg(feature = "schemars")] +impl Request { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Request) + } +} + +impl From for Request { + fn from(payload: T) -> Self { + Self::new(payload) + } +} + +/// Error encountered when attempting to parse bytes as an untyped request +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum UntypedRequestParseError { + /// When the bytes do not represent a request + WrongType, + + /// When the id is not a valid UTF-8 string + InvalidId, +} + +/// Represents a request to send whose payload is bytes instead of a specific type +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct UntypedRequest<'a> { + /// Unique id associated with the request + pub id: Cow<'a, str>, + + /// Payload associated with the request as bytes + pub payload: Cow<'a, [u8]>, +} + +impl<'a> UntypedRequest<'a> { + /// Attempts to convert an untyped request to a typed request + pub fn to_typed_request(&self) -> io::Result> { + Ok(Request { + id: self.id.to_string(), + payload: utils::deserialize_from_slice(&self.payload)?, + }) + } + + /// Convert into a borrowed version + pub fn as_borrowed(&self) -> UntypedRequest<'_> { + UntypedRequest { + id: match &self.id { + Cow::Borrowed(x) => Cow::Borrowed(x), + Cow::Owned(x) => Cow::Borrowed(x.as_str()), + }, + payload: match &self.payload { + Cow::Borrowed(x) => Cow::Borrowed(x), + Cow::Owned(x) => Cow::Borrowed(x.as_slice()), + }, + } + } + + /// Convert into an owned version + pub fn into_owned(self) -> UntypedRequest<'static> { + UntypedRequest { + id: match self.id { + Cow::Borrowed(x) => Cow::Owned(x.to_string()), + Cow::Owned(x) => Cow::Owned(x), + }, + payload: match self.payload { + Cow::Borrowed(x) => Cow::Owned(x.to_vec()), + Cow::Owned(x) => Cow::Owned(x), + }, + } + } + + /// Parses a collection of bytes, returning a partial request if it can be potentially + /// represented as a [`Request`] depending on the payload, or the original bytes if it does not + /// represent a [`Request`] + /// + /// NOTE: This supports parsing an invalid request where the payload would not properly + /// deserialize, but the bytes themselves represent a complete request of some kind. + pub fn from_slice(input: &'a [u8]) -> Result { + if input.len() < 2 { + return Err(UntypedRequestParseError::WrongType); + } + + // MsgPack marks a fixmap using 0x80 - 0x8f to indicate the size (up to 15 elements). + // + // In the case of the request, there are only two elements: id and payload. So the first + // byte should ALWAYS be 0x82 (130). + if input[0] != 0x82 { + return Err(UntypedRequestParseError::WrongType); + } + + // Skip the first byte representing the fixmap + let input = &input[1..]; + + // Validate that first field is id + let (input, id_key) = + parse_msg_pack_str(input).map_err(|_| UntypedRequestParseError::WrongType)?; + if id_key != "id" { + return Err(UntypedRequestParseError::WrongType); + } + + // Get the id itself + let (input, id) = + parse_msg_pack_str(input).map_err(|_| UntypedRequestParseError::InvalidId)?; + + // Validate that second field is payload + let (input, payload_key) = + parse_msg_pack_str(input).map_err(|_| UntypedRequestParseError::WrongType)?; + if payload_key != "payload" { + return Err(UntypedRequestParseError::WrongType); + } + + let id = Cow::Borrowed(id); + let payload = Cow::Borrowed(input); + + Ok(Self { id, payload }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const TRUE_BYTE: u8 = 0xc3; + const NEVER_USED_BYTE: u8 = 0xc1; + + // fixstr of 2 bytes with str "id" + const ID_FIELD_BYTES: &[u8] = &[0xa2, 0x69, 0x64]; + + // fixstr of 7 bytes with str "payload" + const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64]; + + /// fixstr of 4 bytes with str "test" + const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74]; + + #[test] + fn untyped_request_should_support_parsing_from_request_bytes_with_valid_payload() { + let bytes = Request { + id: "some id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + assert_eq!( + UntypedRequest::from_slice(&bytes), + Ok(UntypedRequest { + id: Cow::Borrowed("some id"), + payload: Cow::Owned(vec![TRUE_BYTE]), + }) + ); + } + + #[test] + fn untyped_request_should_support_parsing_from_request_bytes_with_invalid_payload() { + // Request with id < 32 bytes + let mut bytes = Request { + id: "".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + // Push never used byte in msgpack + bytes.push(NEVER_USED_BYTE); + + // We don't actually check for a valid payload, so the extra byte shows up + assert_eq!( + UntypedRequest::from_slice(&bytes), + Ok(UntypedRequest { + id: Cow::Owned("".to_string()), + payload: Cow::Owned(vec![TRUE_BYTE, NEVER_USED_BYTE]), + }) + ); + } + + #[test] + fn untyped_request_should_fail_to_parse_if_given_bytes_not_representing_a_request() { + // Empty byte slice + assert_eq!( + UntypedRequest::from_slice(&[]), + Err(UntypedRequestParseError::WrongType) + ); + + // Wrong starting byte + assert_eq!( + UntypedRequest::from_slice(&[0x00]), + Err(UntypedRequestParseError::WrongType) + ); + + // Wrong starting byte (fixmap of 0 fields) + assert_eq!( + UntypedRequest::from_slice(&[0x80]), + Err(UntypedRequestParseError::WrongType) + ); + + // Missing fields (corrupt data) + assert_eq!( + UntypedRequest::from_slice(&[0x82]), + Err(UntypedRequestParseError::WrongType) + ); + + // Missing id field (has valid data itself) + assert_eq!( + UntypedRequest::from_slice( + [ + &[0x82], + &[0xa0], // id would be defined here, set to empty str + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedRequestParseError::WrongType) + ); + + // Non-str id field value + assert_eq!( + UntypedRequest::from_slice( + [ + &[0x82], + ID_FIELD_BYTES, + &[TRUE_BYTE], // id value set to boolean + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedRequestParseError::InvalidId) + ); + + // Non-utf8 id field value + assert_eq!( + UntypedRequest::from_slice( + [ + &[0x82], + ID_FIELD_BYTES, + &[0xa4, 0, 159, 146, 150], + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedRequestParseError::InvalidId) + ); + + // Missing payload field (has valid data itself) + assert_eq!( + UntypedRequest::from_slice( + [ + &[0x82], + ID_FIELD_BYTES, + TEST_STR_BYTES, + &[0xa0], // payload would be defined here, set to empty str + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedRequestParseError::WrongType) + ); + } +} diff --git a/distant-net/src/packet/response.rs b/distant-net/src/packet/response.rs new file mode 100644 index 0000000..eea2d66 --- /dev/null +++ b/distant-net/src/packet/response.rs @@ -0,0 +1,415 @@ +use super::{parse_msg_pack_str, Id}; +use crate::utils; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{borrow::Cow, io}; + +/// Represents a response received related to some response +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Response { + /// Unique id associated with the response + pub id: Id, + + /// Unique id associated with the response that triggered the response + pub origin_id: Id, + + /// Payload associated with the response + pub payload: T, +} + +impl Response { + /// Creates a new response with a random, unique id + pub fn new(origin_id: Id, payload: T) -> Self { + Self { + id: rand::random::().to_string(), + origin_id, + payload, + } + } +} + +impl Response +where + T: Serialize, +{ + /// Serializes the response into bytes + pub fn to_vec(&self) -> std::io::Result> { + utils::serialize_to_vec(self) + } + + /// Serializes the response's payload into bytes + pub fn to_payload_vec(&self) -> io::Result> { + utils::serialize_to_vec(&self.payload) + } +} + +impl Response +where + T: DeserializeOwned, +{ + /// Deserializes the response from bytes + pub fn from_slice(slice: &[u8]) -> std::io::Result { + utils::deserialize_from_slice(slice) + } +} + +#[cfg(feature = "schemars")] +impl Response { + pub fn root_schema() -> schemars::schema::RootSchema { + schemars::schema_for!(Response) + } +} + +/// Error encountered when attempting to parse bytes as an untyped response +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum UntypedResponseParseError { + /// When the bytes do not represent a response + WrongType, + + /// When the id is not a valid UTF-8 string + InvalidId, + + /// When the origin id is not a valid UTF-8 string + InvalidOriginId, +} + +/// Represents a response to send whose payload is bytes instead of a specific type +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +pub struct UntypedResponse<'a> { + /// Unique id associated with the response + pub id: Cow<'a, str>, + + /// Unique id associated with the response that triggered the response + pub origin_id: Cow<'a, str>, + + /// Payload associated with the response as bytes + pub payload: Cow<'a, [u8]>, +} + +impl<'a> UntypedResponse<'a> { + /// Attempts to convert an untyped request to a typed request + pub fn to_typed_request(&self) -> io::Result> { + Ok(Response { + id: self.id.to_string(), + origin_id: self.origin_id.to_string(), + payload: utils::deserialize_from_slice(&self.payload)?, + }) + } + + /// Convert into a borrowed version + pub fn as_borrowed(&self) -> UntypedResponse<'_> { + UntypedResponse { + id: match &self.id { + Cow::Borrowed(x) => Cow::Borrowed(x), + Cow::Owned(x) => Cow::Borrowed(x.as_str()), + }, + origin_id: match &self.origin_id { + Cow::Borrowed(x) => Cow::Borrowed(x), + Cow::Owned(x) => Cow::Borrowed(x.as_str()), + }, + payload: match &self.payload { + Cow::Borrowed(x) => Cow::Borrowed(x), + Cow::Owned(x) => Cow::Borrowed(x.as_slice()), + }, + } + } + + /// Convert into an owned version + pub fn into_owned(self) -> UntypedResponse<'static> { + UntypedResponse { + id: match self.id { + Cow::Borrowed(x) => Cow::Owned(x.to_string()), + Cow::Owned(x) => Cow::Owned(x), + }, + origin_id: match self.origin_id { + Cow::Borrowed(x) => Cow::Owned(x.to_string()), + Cow::Owned(x) => Cow::Owned(x), + }, + payload: match self.payload { + Cow::Borrowed(x) => Cow::Owned(x.to_vec()), + Cow::Owned(x) => Cow::Owned(x), + }, + } + } + + /// Parses a collection of bytes, returning an untyped response if it can be potentially + /// represented as a [`Response`] depending on the payload, or the original bytes if it does not + /// represent a [`Response`] + /// + /// NOTE: This supports parsing an invalid response where the payload would not properly + /// deserialize, but the bytes themselves represent a complete response of some kind. + pub fn from_slice(input: &'a [u8]) -> Result { + if input.len() < 2 { + return Err(UntypedResponseParseError::WrongType); + } + + // MsgPack marks a fixmap using 0x80 - 0x8f to indicate the size (up to 15 elements). + // + // In the case of the request, there are only three elements: id, origin_id, and payload. + // So the first byte should ALWAYS be 0x83 (131). + if input[0] != 0x83 { + return Err(UntypedResponseParseError::WrongType); + } + + // Skip the first byte representing the fixmap + let input = &input[1..]; + + // Validate that first field is id + let (input, id_key) = + parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::WrongType)?; + if id_key != "id" { + return Err(UntypedResponseParseError::WrongType); + } + + // Get the id itself + let (input, id) = + parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::InvalidId)?; + + // Validate that second field is origin_id + let (input, origin_id_key) = + parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::WrongType)?; + if origin_id_key != "origin_id" { + return Err(UntypedResponseParseError::WrongType); + } + + // Get the origin_id itself + let (input, origin_id) = + parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::InvalidOriginId)?; + + // Validate that second field is payload + let (input, payload_key) = + parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::WrongType)?; + if payload_key != "payload" { + return Err(UntypedResponseParseError::WrongType); + } + + let id = Cow::Borrowed(id); + let origin_id = Cow::Borrowed(origin_id); + let payload = Cow::Borrowed(input); + + Ok(Self { + id, + origin_id, + payload, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const TRUE_BYTE: u8 = 0xc3; + const NEVER_USED_BYTE: u8 = 0xc1; + + // fixstr of 2 bytes with str "id" + const ID_FIELD_BYTES: &[u8] = &[0xa2, 0x69, 0x64]; + + // fixstr of 9 bytes with str "origin_id" + const ORIGIN_ID_FIELD_BYTES: &[u8] = + &[0xa9, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x5f, 0x69, 0x64]; + + // fixstr of 7 bytes with str "payload" + const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64]; + + /// fixstr of 4 bytes with str "test" + const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74]; + + #[test] + fn untyped_response_should_support_parsing_from_response_bytes_with_valid_payload() { + let bytes = Response { + id: "some id".to_string(), + origin_id: "some origin id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + assert_eq!( + UntypedResponse::from_slice(&bytes), + Ok(UntypedResponse { + id: Cow::Borrowed("some id"), + origin_id: Cow::Borrowed("some origin id"), + payload: Cow::Owned(vec![TRUE_BYTE]), + }) + ); + } + + #[test] + fn untyped_response_should_support_parsing_from_response_bytes_with_invalid_payload() { + // Response with id < 32 bytes + let mut bytes = Response { + id: "".to_string(), + origin_id: "".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + // Push never used byte in msgpack + bytes.push(NEVER_USED_BYTE); + + // We don't actually check for a valid payload, so the extra byte shows up + assert_eq!( + UntypedResponse::from_slice(&bytes), + Ok(UntypedResponse { + id: Cow::Owned("".to_string()), + origin_id: Cow::Owned("".to_string()), + payload: Cow::Owned(vec![TRUE_BYTE, NEVER_USED_BYTE]), + }) + ); + } + + #[test] + fn untyped_response_should_fail_to_parse_if_given_bytes_not_representing_a_response() { + // Empty byte slice + assert_eq!( + UntypedResponse::from_slice(&[]), + Err(UntypedResponseParseError::WrongType) + ); + + // Wrong starting byte + assert_eq!( + UntypedResponse::from_slice(&[0x00]), + Err(UntypedResponseParseError::WrongType) + ); + + // Wrong starting byte (fixmap of 0 fields) + assert_eq!( + UntypedResponse::from_slice(&[0x80]), + Err(UntypedResponseParseError::WrongType) + ); + + // Missing fields (corrupt data) + assert_eq!( + UntypedResponse::from_slice(&[0x83]), + Err(UntypedResponseParseError::WrongType) + ); + + // Missing id field (has valid data itself) + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x83], + &[0xa0], // id would be defined here, set to empty str + TEST_STR_BYTES, + ORIGIN_ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::WrongType) + ); + + // Non-str id field value + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x83], + ID_FIELD_BYTES, + &[TRUE_BYTE], // id value set to boolean + ORIGIN_ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::InvalidId) + ); + + // Non-utf8 id field value + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x83], + ID_FIELD_BYTES, + &[0xa4, 0, 159, 146, 150], + ORIGIN_ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::InvalidId) + ); + + // Missing origin_id field (has valid data itself) + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x83], + ID_FIELD_BYTES, + TEST_STR_BYTES, + &[0xa0], // id would be defined here, set to empty str + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::WrongType) + ); + + // Non-str origin_id field value + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x83], + ID_FIELD_BYTES, + TEST_STR_BYTES, + ORIGIN_ID_FIELD_BYTES, + &[TRUE_BYTE], // id value set to boolean + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::InvalidOriginId) + ); + + // Non-utf8 origin_id field value + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x83], + ID_FIELD_BYTES, + TEST_STR_BYTES, + ORIGIN_ID_FIELD_BYTES, + &[0xa4, 0, 159, 146, 150], + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::InvalidOriginId) + ); + + // Missing payload field (has valid data itself) + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x83], + ID_FIELD_BYTES, + TEST_STR_BYTES, + ORIGIN_ID_FIELD_BYTES, + TEST_STR_BYTES, + &[0xa0], // payload would be defined here, set to empty str + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::WrongType) + ); + } +} diff --git a/distant-net/src/server/ext.rs b/distant-net/src/server/ext.rs index d540f06..52d8f3c 100644 --- a/distant-net/src/server/ext.rs +++ b/distant-net/src/server/ext.rs @@ -160,6 +160,17 @@ where let (tx, mut rx) = mpsc::channel::>(1); connection.writer_task = Some(tokio::spawn(async move { while let Some(data) = rx.recv().await { + // Log our message as a string, which can be expensive + if log_enabled!(Level::Trace) { + trace!( + "[Conn {connection_id}] Sending {}", + &data + .to_vec() + .map(|x| String::from_utf8_lossy(&x).to_string()) + .unwrap_or_else(|_| "".to_string()) + ); + } + if let Err(x) = writer.write(data).await { error!("[Conn {connection_id}] Failed to send {x}"); break;