From 286fbc9e7d91c4aea857726c775f486f7e7986fe Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Sun, 11 Jun 2023 18:09:46 -0500 Subject: [PATCH] Initial commit --- distant-net/src/common/packet.rs | 68 ---- distant-net/src/common/packet/request.rs | 339 +++++++++++--------- distant-net/src/common/packet/response.rs | 360 +++++++++++----------- 3 files changed, 381 insertions(+), 386 deletions(-) diff --git a/distant-net/src/common/packet.rs b/distant-net/src/common/packet.rs index c55fe2d..3ad1575 100644 --- a/distant-net/src/common/packet.rs +++ b/distant-net/src/common/packet.rs @@ -238,18 +238,6 @@ fn read_str_bytes(input: &[u8]) -> Result<(&str, &[u8]), &[u8]> { } } -/// Reads a str key from msgpack input and checks if it matches `key`. If so, the input is -/// advanced, otherwise the original input is returned. -/// -/// * If key read successfully and matches, returns (unit, remaining). -/// * Otherwise, returns existing bytes. -fn read_key_eq<'a>(input: &'a [u8], key: &str) -> Result<((), &'a [u8]), &'a [u8]> { - match read_str_bytes(input) { - Ok((s, input)) if s == key => Ok(((), input)), - _ => Err(input), - } -} - #[cfg(test)] mod tests { use super::*; @@ -279,62 +267,6 @@ mod tests { } } - mod read_key_eq { - use super::*; - use test_log::test; - - #[test] - fn should_fail_if_input_is_empty() { - let input = read_key_eq(&[], "key").unwrap_err(); - assert!(input.is_empty()); - } - - #[test] - fn should_fail_if_input_does_not_start_with_str() { - let input = &[ - 0xff, - rmp::Marker::FixStr(5).to_u8(), - b'h', - b'e', - b'l', - b'l', - b'o', - ]; - let remaining = read_key_eq(input, "key").unwrap_err(); - assert_eq!(remaining, input); - } - - #[test] - fn should_fail_if_read_key_does_not_match_specified_key() { - let input = &[ - rmp::Marker::FixStr(5).to_u8(), - b'h', - b'e', - b'l', - b'l', - b'o', - 0xff, - ]; - let remaining = read_key_eq(input, "key").unwrap_err(); - assert_eq!(remaining, input); - } - - #[test] - fn should_succeed_if_read_key_matches_specified_key() { - let input = &[ - rmp::Marker::FixStr(5).to_u8(), - b'h', - b'e', - b'l', - b'l', - b'o', - 0xff, - ]; - let (_, remaining) = read_key_eq(input, "hello").unwrap(); - assert_eq!(remaining, [0xff]); - } - } - mod read_header_bytes { use super::*; use test_log::test; diff --git a/distant-net/src/common/packet/request.rs b/distant-net/src/common/packet/request.rs index bd74e34..1e1bf23 100644 --- a/distant-net/src/common/packet/request.rs +++ b/distant-net/src/common/packet/request.rs @@ -5,26 +5,25 @@ use derive_more::{Display, Error}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id}; +use super::{read_header_bytes, read_str_bytes, Header, Id}; use crate::common::utils; use crate::header; -/// Represents a request to send -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +/// Represents a request to send. +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Request { - /// Optional header data to include with request - #[serde(default, skip_serializing_if = "Header::is_empty")] + /// Optional header data to include with request. pub header: Header, - /// Unique id associated with the request + /// Unique id associated with the request. pub id: Id, - /// Payload associated with the request + /// Payload associated with the request. pub payload: T, } impl Request { - /// Creates a new request with a random, unique id and no header data + /// Creates a new request with a random, unique id and no header data. pub fn new(payload: T) -> Self { Self { header: header!(), @@ -38,21 +37,21 @@ impl Request where T: Serialize, { - /// Serializes the request into bytes + /// Serializes the request into bytes using a compact approach. pub fn to_vec(&self) -> io::Result> { - utils::serialize_to_vec(self) + Ok(self.to_untyped_request()?.to_bytes()) } - /// Serializes the request's payload into bytes + /// Serializes the request's payload into bytes. pub fn to_payload_vec(&self) -> io::Result> { utils::serialize_to_vec(&self.payload) } - /// Attempts to convert a typed request to an untyped request + /// Attempts to convert a typed request to an untyped request. pub fn to_untyped_request(&self) -> io::Result { Ok(UntypedRequest { header: Cow::Owned(if !self.header.is_empty() { - utils::serialize_to_vec(&self.header)? + self.header.to_vec()? } else { Vec::new() }), @@ -66,9 +65,11 @@ impl Request where T: DeserializeOwned, { - /// Deserializes the request from bytes + /// Deserializes the request from bytes. pub fn from_slice(slice: &[u8]) -> io::Result { - utils::deserialize_from_slice(slice) + UntypedRequest::from_slice(slice) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))? + .to_typed_request() } } @@ -78,26 +79,122 @@ impl From for Request { } } -/// Error encountered when attempting to parse bytes as an untyped request +mod serde_impl { + use super::*; + use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor}; + use serde::ser::{Serialize, SerializeSeq, Serializer}; + use std::fmt; + use std::marker::PhantomData; + + impl Serialize for Request + where + T: Serialize, + { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let has_header = !self.header.is_empty(); + let mut cnt = 2; + if has_header { + cnt += 1; + } + + let mut seq = serializer.serialize_seq(Some(cnt))?; + + if has_header { + seq.serialize_element(&self.header)?; + } + seq.serialize_element(&self.id)?; + seq.serialize_element(&self.payload)?; + seq.end() + } + } + + impl<'de, T> Deserialize<'de> for Request + where + T: Deserialize<'de>, + { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_seq(RequestVisitor::new()) + } + } + + struct RequestVisitor { + marker: PhantomData Request>, + } + + impl RequestVisitor { + fn new() -> Self { + Self { + marker: PhantomData, + } + } + } + + impl<'de, T> Visitor<'de> for RequestVisitor + where + T: Deserialize<'de>, + { + type Value = Request; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a response") + } + + fn visit_seq(self, mut access: S) -> Result + where + S: SeqAccess<'de>, + { + // Attempt to determine if we have a header based on size, defaulting to attempting + // to parse the header first if we don't know the size. If we cannot parse the header, + // we use the default header and keep going. + let header = match access.size_hint() { + Some(2) => Header::default(), + Some(3) => access + .next_element()? + .ok_or_else(|| de::Error::custom("missing header"))?, + Some(_) => return Err(de::Error::custom("invalid response array len")), + None => access + .next_element::
() + .ok() + .flatten() + .unwrap_or_default(), + }; + + let id = access + .next_element()? + .ok_or_else(|| de::Error::custom("missing id"))?; + let payload = access + .next_element()? + .ok_or_else(|| de::Error::custom("missing payload"))?; + + Ok(Request { + header, + id, + payload, + }) + } + } +} + +/// Error encountered when attempting to parse bytes as an untyped request. #[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)] pub enum UntypedRequestParseError { - /// When the bytes do not represent a request + /// When the bytes do not represent a request. WrongType, - /// When a header should be present, but the key is wrong - InvalidHeaderKey, - - /// When a header should be present, but the header bytes are wrong + /// When a header should be present, but the header bytes are wrong. InvalidHeader, - /// When the key for the id is wrong - InvalidIdKey, - - /// When the id is not a valid UTF-8 string + /// When the id is not a valid UTF-8 string. InvalidId, - /// When the key for the payload is wrong - InvalidPayloadKey, + /// When no payload found in the request. + MissingPayload, } #[inline] @@ -105,22 +202,22 @@ fn header_is_empty(header: &[u8]) -> bool { header.is_empty() } -/// Represents a request to send whose payload is bytes instead of a specific type +/// Represents a request to send whose payload is bytes instead of a specific type. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct UntypedRequest<'a> { - /// Header data associated with the request as bytes + /// Header data associated with the request as bytes. #[serde(default, skip_serializing_if = "header_is_empty")] pub header: Cow<'a, [u8]>, - /// Unique id associated with the request + /// Unique id associated with the request. pub id: Cow<'a, str>, - /// Payload associated with the request as bytes + /// Payload associated with the request as bytes. pub payload: Cow<'a, [u8]>, } impl<'a> UntypedRequest<'a> { - /// Attempts to convert an untyped request to a typed request + /// Attempts to convert an untyped request to a typed request. pub fn to_typed_request(&self) -> io::Result> { Ok(Request { header: if header_is_empty(&self.header) { @@ -133,7 +230,7 @@ impl<'a> UntypedRequest<'a> { }) } - /// Convert into a borrowed version + /// Convert into a borrowed version. pub fn as_borrowed(&self) -> UntypedRequest<'_> { UntypedRequest { header: match &self.header { @@ -151,7 +248,7 @@ impl<'a> UntypedRequest<'a> { } } - /// Convert into an owned version + /// Convert into an owned version. pub fn into_owned(self) -> UntypedRequest<'static> { UntypedRequest { header: match self.header { @@ -169,6 +266,12 @@ impl<'a> UntypedRequest<'a> { } } + /// Returns true if the request has an empty header. + #[inline] + pub fn is_header_empty(&self) -> bool { + header_is_empty(&self.header) + } + /// Updates the header of the request to the given `header`. pub fn set_header(&mut self, header: impl IntoIterator) { self.header = Cow::Owned(header.into_iter().collect()); @@ -185,20 +288,16 @@ impl<'a> UntypedRequest<'a> { let has_header = !header_is_empty(&self.header); if has_header { - rmp::encode::write_map_len(&mut bytes, 3).unwrap(); + rmp::encode::write_array_len(&mut bytes, 3).unwrap(); } else { - rmp::encode::write_map_len(&mut bytes, 2).unwrap(); + rmp::encode::write_array_len(&mut bytes, 2).unwrap(); } if has_header { - rmp::encode::write_str(&mut bytes, "header").unwrap(); bytes.extend_from_slice(&self.header); } - rmp::encode::write_str(&mut bytes, "id").unwrap(); rmp::encode::write_str(&mut bytes, &self.id).unwrap(); - - rmp::encode::write_str(&mut bytes, "payload").unwrap(); bytes.extend_from_slice(&self.payload); bytes @@ -215,8 +314,8 @@ impl<'a> UntypedRequest<'a> { } let has_header = match rmp::Marker::from_u8(input[0]) { - rmp::Marker::FixMap(2) => false, - rmp::Marker::FixMap(3) => true, + rmp::Marker::FixArray(2) => false, + rmp::Marker::FixArray(3) => true, _ => return Err(UntypedRequestParseError::WrongType), }; @@ -225,26 +324,17 @@ impl<'a> UntypedRequest<'a> { // Parse the header if we have one let (header, input) = if has_header { - let (_, input) = read_key_eq(input, "header") - .map_err(|_| UntypedRequestParseError::InvalidHeaderKey)?; - - let (header, input) = - read_header_bytes(input).map_err(|_| UntypedRequestParseError::InvalidHeader)?; - (header, input) + read_header_bytes(input).map_err(|_| UntypedRequestParseError::InvalidHeader)? } else { ([0u8; 0].as_slice(), input) }; - // Validate that next field is id - let (_, input) = - read_key_eq(input, "id").map_err(|_| UntypedRequestParseError::InvalidIdKey)?; - - // Get the id itself let (id, input) = read_str_bytes(input).map_err(|_| UntypedRequestParseError::InvalidId)?; - // Validate that final field is payload - let (_, input) = read_key_eq(input, "payload") - .map_err(|_| UntypedRequestParseError::InvalidPayloadKey)?; + // Check if we have input remaining, which should be our payload + if input.is_empty() { + return Err(UntypedRequestParseError::MissingPayload); + } let header = Cow::Borrowed(header); let id = Cow::Borrowed(id); @@ -267,9 +357,6 @@ mod tests { const TRUE_BYTE: u8 = 0xc3; const NEVER_USED_BYTE: u8 = 0xc1; - // fixstr of 6 bytes with str "header" - const HEADER_FIELD_BYTES: &[u8] = &[0xa6, b'h', b'e', b'a', b'd', b'e', b'r']; - // fixmap of 2 objects with // 1. key fixstr "key" and value fixstr "value" // 1. key fixstr "num" and value fixint 123 @@ -281,12 +368,6 @@ mod tests { 0x7b, // value: 123 ]; - // fixstr of 2 bytes with str "id" - const ID_FIELD_BYTES: &[u8] = &[0xa2, b'i', b'd']; - - // fixstr of 7 bytes with str "payload" - const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, b'p', b'a', b'y', b'l', b'o', b'a', b'd']; - // fixstr of 4 bytes with str "test" const TEST_STR_BYTES: &[u8] = &[0xa4, b't', b'e', b's', b't']; @@ -383,15 +464,30 @@ mod tests { ); } + #[test] + fn untyped_request_should_support_parsing_without_header() { + let input = [ + &[rmp::Marker::FixArray(2).to_u8()], + TEST_STR_BYTES, + &[TRUE_BYTE], + ] + .concat(); + + // Convert into typed so we can test + let untyped_request = UntypedRequest::from_slice(&input).unwrap(); + let request: Request = untyped_request.to_typed_request().unwrap(); + + assert_eq!(request.header, header!()); + assert_eq!(request.id, "test"); + assert!(request.payload); + } + #[test] fn untyped_request_should_support_parsing_full_request() { let input = [ - &[0x83], - HEADER_FIELD_BYTES, + &[rmp::Marker::FixArray(3).to_u8()], HEADER_BYTES, - ID_FIELD_BYTES, TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat(); @@ -419,116 +515,71 @@ mod tests { Err(UntypedRequestParseError::WrongType) ); - // Wrong starting byte (fixmap of 0 fields) + // Wrong starting byte (fixarray of 0 fields) assert_eq!( - UntypedRequest::from_slice(&[0x80]), + UntypedRequest::from_slice(&[rmp::Marker::FixArray(0).to_u8()]), Err(UntypedRequestParseError::WrongType) ); - // Invalid header key - assert_eq!( - UntypedRequest::from_slice( - [ - &[0x83], - &[0xa0], // header key would be defined here, set to empty str - HEADER_BYTES, - ID_FIELD_BYTES, - TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, - &[TRUE_BYTE], - ] - .concat() - .as_slice() - ), - Err(UntypedRequestParseError::InvalidHeaderKey) - ); - - // Invalid header bytes + // Missing fields (corrupt data) assert_eq!( - UntypedRequest::from_slice( - [ - &[0x83], - HEADER_FIELD_BYTES, - &[0xa0], // header would be defined here, set to empty str - ID_FIELD_BYTES, - TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, - &[TRUE_BYTE], - ] - .concat() - .as_slice() - ), - Err(UntypedRequestParseError::InvalidHeader) + UntypedRequest::from_slice(&[rmp::Marker::FixArray(2).to_u8()]), + Err(UntypedRequestParseError::InvalidId) ); // Missing fields (corrupt data) assert_eq!( - UntypedRequest::from_slice(&[0x82]), - Err(UntypedRequestParseError::InvalidIdKey) + UntypedRequest::from_slice(&[rmp::Marker::FixArray(3).to_u8()]), + Err(UntypedRequestParseError::InvalidHeader) ); - // Missing id field (has valid data itself) + // Invalid header bytes assert_eq!( UntypedRequest::from_slice( [ - &[0x82], - &[0xa0], // id would be defined here, set to empty str + &[rmp::Marker::FixArray(3).to_u8()], + &[0xa0], // header would be defined here, set to empty str TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat() .as_slice() ), - Err(UntypedRequestParseError::InvalidIdKey) + Err(UntypedRequestParseError::InvalidHeader) ); // Non-str id field value assert_eq!( - UntypedRequest::from_slice( - [ - &[0x82], - ID_FIELD_BYTES, - &[TRUE_BYTE], // id value set to boolean - PAYLOAD_FIELD_BYTES, - &[TRUE_BYTE], - ] - .concat() - .as_slice() - ), + UntypedRequest::from_slice(&[ + rmp::Marker::FixArray(2).to_u8(), + TRUE_BYTE, // id value set to boolean + TRUE_BYTE, + ]), Err(UntypedRequestParseError::InvalidId) ); // Non-utf8 id field value assert_eq!( - UntypedRequest::from_slice( - [ - &[0x82], - ID_FIELD_BYTES, - &[0xa4, 0, 159, 146, 150], - PAYLOAD_FIELD_BYTES, - &[TRUE_BYTE], - ] - .concat() - .as_slice() - ), + UntypedRequest::from_slice(&[ + rmp::Marker::FixArray(2).to_u8(), + 0xa4, + 0, + 159, + 146, + 150, + TRUE_BYTE, + ]), Err(UntypedRequestParseError::InvalidId) ); - // Missing payload field (has valid data itself) + // Missing payload assert_eq!( UntypedRequest::from_slice( - [ - &[0x82], - ID_FIELD_BYTES, - TEST_STR_BYTES, - &[0xa0], // payload would be defined here, set to empty str - &[TRUE_BYTE], - ] - .concat() - .as_slice() + [&[rmp::Marker::FixArray(2).to_u8()], TEST_STR_BYTES] + .concat() + .as_slice() ), - Err(UntypedRequestParseError::InvalidPayloadKey) + Err(UntypedRequestParseError::MissingPayload) ); } } diff --git a/distant-net/src/common/packet/response.rs b/distant-net/src/common/packet/response.rs index 6056fe6..aa3df59 100644 --- a/distant-net/src/common/packet/response.rs +++ b/distant-net/src/common/packet/response.rs @@ -5,29 +5,28 @@ use derive_more::{Display, Error}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id}; +use super::{read_header_bytes, read_str_bytes, Header, Id}; use crate::common::utils; use crate::header; -/// Represents a response received related to some response -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +/// Represents a response received related to some response. +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Response { - /// Optional header data to include with response - #[serde(default, skip_serializing_if = "Header::is_empty")] + /// Optional header data to include with response. pub header: Header, - /// Unique id associated with the response + /// Unique id associated with the response. pub id: Id, - /// Unique id associated with the response that triggered the response + /// Unique id associated with the response that triggered the response. pub origin_id: Id, - /// Payload associated with the response + /// Payload associated with the response. pub payload: T, } impl Response { - /// Creates a new response with a random, unique id and no header data + /// Creates a new response with a random, unique id and no header data. pub fn new(origin_id: Id, payload: T) -> Self { Self { header: header!(), @@ -42,21 +41,21 @@ impl Response where T: Serialize, { - /// Serializes the response into bytes + /// Serializes the response into bytes using a compact approach. pub fn to_vec(&self) -> std::io::Result> { - utils::serialize_to_vec(self) + Ok(self.to_untyped_response()?.to_bytes()) } - /// Serializes the response's payload into bytes + /// Serializes the response's payload into bytes. pub fn to_payload_vec(&self) -> io::Result> { utils::serialize_to_vec(&self.payload) } - /// Attempts to convert a typed response to an untyped response + /// Attempts to convert a typed response to an untyped response. pub fn to_untyped_response(&self) -> io::Result { Ok(UntypedResponse { header: Cow::Owned(if !self.header.is_empty() { - utils::serialize_to_vec(&self.header)? + self.header.to_vec()? } else { Vec::new() }), @@ -71,38 +70,138 @@ impl Response where T: DeserializeOwned, { - /// Deserializes the response from bytes + /// Deserializes the response from bytes. pub fn from_slice(slice: &[u8]) -> std::io::Result { - utils::deserialize_from_slice(slice) + UntypedResponse::from_slice(slice) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))? + .to_typed_response() } } -/// Error encountered when attempting to parse bytes as an untyped response +mod serde_impl { + use super::*; + use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor}; + use serde::ser::{Serialize, SerializeSeq, Serializer}; + use std::fmt; + use std::marker::PhantomData; + + impl Serialize for Response + where + T: Serialize, + { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let has_header = !self.header.is_empty(); + let mut cnt = 3; + if has_header { + cnt += 1; + } + + let mut seq = serializer.serialize_seq(Some(cnt))?; + + if has_header { + seq.serialize_element(&self.header)?; + } + seq.serialize_element(&self.id)?; + seq.serialize_element(&self.origin_id)?; + seq.serialize_element(&self.payload)?; + seq.end() + } + } + + impl<'de, T> Deserialize<'de> for Response + where + T: Deserialize<'de>, + { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_seq(ResponseVisitor::new()) + } + } + + struct ResponseVisitor { + marker: PhantomData Response>, + } + + impl ResponseVisitor { + fn new() -> Self { + Self { + marker: PhantomData, + } + } + } + + impl<'de, T> Visitor<'de> for ResponseVisitor + where + T: Deserialize<'de>, + { + type Value = Response; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a response") + } + + fn visit_seq(self, mut access: S) -> Result + where + S: SeqAccess<'de>, + { + // Attempt to determine if we have a header based on size, defaulting to attempting + // to parse the header first if we don't know the size. If we cannot parse the header, + // we use the default header and keep going. + let header = match access.size_hint() { + Some(3) => Header::default(), + Some(4) => access + .next_element()? + .ok_or_else(|| de::Error::custom("missing header"))?, + Some(_) => return Err(de::Error::custom("invalid response array len")), + None => access + .next_element::
() + .ok() + .flatten() + .unwrap_or_default(), + }; + + let id = access + .next_element()? + .ok_or_else(|| de::Error::custom("missing id"))?; + let origin_id = access + .next_element()? + .ok_or_else(|| de::Error::custom("missing origin_id"))?; + let payload = access + .next_element()? + .ok_or_else(|| de::Error::custom("missing payload"))?; + + Ok(Response { + header, + id, + origin_id, + payload, + }) + } + } +} + +/// Error encountered when attempting to parse bytes as an untyped response. #[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)] pub enum UntypedResponseParseError { - /// When the bytes do not represent a response + /// When the bytes do not represent a response. WrongType, - /// When a header should be present, but the key is wrong - InvalidHeaderKey, - - /// When a header should be present, but the header bytes are wrong + /// When a header should be present, but the header bytes are wrong. InvalidHeader, - /// When the key for the id is wrong - InvalidIdKey, - - /// When the id is not a valid UTF-8 string + /// When the id is not a valid UTF-8 string. InvalidId, - /// When the key for the origin id is wrong - InvalidOriginIdKey, - - /// When the origin id is not a valid UTF-8 string + /// When the origin id is not a valid UTF-8 string. InvalidOriginId, - /// When the key for the payload is wrong - InvalidPayloadKey, + /// When no payload found in the request. + MissingPayload, } #[inline] @@ -110,31 +209,31 @@ fn header_is_empty(header: &[u8]) -> bool { header.is_empty() } -/// Represents a response to send whose payload is bytes instead of a specific type -#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +/// Represents a response to send whose payload is bytes instead of a specific type. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct UntypedResponse<'a> { - /// Header data associated with the response as bytes + /// Header data associated with the response as bytes. #[serde(default, skip_serializing_if = "header_is_empty")] pub header: Cow<'a, [u8]>, - /// Unique id associated with the response + /// Unique id associated with the response. pub id: Cow<'a, str>, - /// Unique id associated with the response that triggered the response + /// Unique id associated with the response that triggered the response. pub origin_id: Cow<'a, str>, - /// Payload associated with the response as bytes + /// Payload associated with the response as bytes. pub payload: Cow<'a, [u8]>, } impl<'a> UntypedResponse<'a> { - /// Attempts to convert an untyped response to a typed response + /// Attempts to convert an untyped response to a typed response. pub fn to_typed_response(&self) -> io::Result> { Ok(Response { header: if header_is_empty(&self.header) { header!() } else { - utils::deserialize_from_slice(&self.header)? + Header::from_slice(&self.header)? }, id: self.id.to_string(), origin_id: self.origin_id.to_string(), @@ -142,7 +241,7 @@ impl<'a> UntypedResponse<'a> { }) } - /// Convert into a borrowed version + /// Convert into a borrowed version. pub fn as_borrowed(&self) -> UntypedResponse<'_> { UntypedResponse { header: match &self.header { @@ -164,7 +263,7 @@ impl<'a> UntypedResponse<'a> { } } - /// Convert into an owned version + /// Convert into an owned version. pub fn into_owned(self) -> UntypedResponse<'static> { UntypedResponse { header: match self.header { @@ -207,23 +306,17 @@ impl<'a> UntypedResponse<'a> { let has_header = !header_is_empty(&self.header); if has_header { - rmp::encode::write_map_len(&mut bytes, 4).unwrap(); + rmp::encode::write_array_len(&mut bytes, 4).unwrap(); } else { - rmp::encode::write_map_len(&mut bytes, 3).unwrap(); + rmp::encode::write_array_len(&mut bytes, 3).unwrap(); } if has_header { - rmp::encode::write_str(&mut bytes, "header").unwrap(); bytes.extend_from_slice(&self.header); } - rmp::encode::write_str(&mut bytes, "id").unwrap(); rmp::encode::write_str(&mut bytes, &self.id).unwrap(); - - rmp::encode::write_str(&mut bytes, "origin_id").unwrap(); rmp::encode::write_str(&mut bytes, &self.origin_id).unwrap(); - - rmp::encode::write_str(&mut bytes, "payload").unwrap(); bytes.extend_from_slice(&self.payload); bytes @@ -240,8 +333,8 @@ impl<'a> UntypedResponse<'a> { } let has_header = match rmp::Marker::from_u8(input[0]) { - rmp::Marker::FixMap(3) => false, - rmp::Marker::FixMap(4) => true, + rmp::Marker::FixArray(3) => false, + rmp::Marker::FixArray(4) => true, _ => return Err(UntypedResponseParseError::WrongType), }; @@ -250,35 +343,21 @@ impl<'a> UntypedResponse<'a> { // Parse the header if we have one let (header, input) = if has_header { - let (_, input) = read_key_eq(input, "header") - .map_err(|_| UntypedResponseParseError::InvalidHeaderKey)?; - - let (header, input) = - read_header_bytes(input).map_err(|_| UntypedResponseParseError::InvalidHeader)?; - (header, input) + read_header_bytes(input).map_err(|_| UntypedResponseParseError::InvalidHeader)? } else { ([0u8; 0].as_slice(), input) }; - // Validate that next field is id - let (_, input) = - read_key_eq(input, "id").map_err(|_| UntypedResponseParseError::InvalidIdKey)?; - - // Get the id itself let (id, input) = read_str_bytes(input).map_err(|_| UntypedResponseParseError::InvalidId)?; - // Validate that next field is origin_id - let (_, input) = read_key_eq(input, "origin_id") - .map_err(|_| UntypedResponseParseError::InvalidOriginIdKey)?; - - // Get the origin_id itself let (origin_id, input) = read_str_bytes(input).map_err(|_| UntypedResponseParseError::InvalidOriginId)?; - // Validate that final field is payload - let (_, input) = read_key_eq(input, "payload") - .map_err(|_| UntypedResponseParseError::InvalidPayloadKey)?; + // Check if we have input remaining, which should be our payload + if input.is_empty() { + return Err(UntypedResponseParseError::MissingPayload); + } let header = Cow::Borrowed(header); let id = Cow::Borrowed(id); @@ -303,9 +382,6 @@ mod tests { const TRUE_BYTE: u8 = 0xc3; const NEVER_USED_BYTE: u8 = 0xc1; - // fixstr of 6 bytes with str "header" - const HEADER_FIELD_BYTES: &[u8] = &[0xa6, b'h', b'e', b'a', b'd', b'e', b'r']; - // fixmap of 2 objects with // 1. key fixstr "key" and value fixstr "value" // 1. key fixstr "num" and value fixint 123 @@ -317,16 +393,6 @@ mod tests { 0x7b, // value: 123 ]; - // fixstr of 2 bytes with str "id" - const ID_FIELD_BYTES: &[u8] = &[0xa2, b'i', b'd']; - - // fixstr of 9 bytes with str "origin_id" - const ORIGIN_ID_FIELD_BYTES: &[u8] = - &[0xa9, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x5f, 0x69, 0x64]; - - // fixstr of 7 bytes with str "payload" - const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, b'p', b'a', b'y', b'l', b'o', b'a', b'd']; - /// fixstr of 4 bytes with str "test" const TEST_STR_BYTES: &[u8] = &[0xa4, b't', b'e', b's', b't']; @@ -431,17 +497,33 @@ mod tests { ); } + #[test] + fn untyped_response_should_support_parsing_without_header() { + let input = [ + &[rmp::Marker::FixArray(3).to_u8()], + TEST_STR_BYTES, + &[0xa2, b'o', b'g'], + &[TRUE_BYTE], + ] + .concat(); + + // Convert into typed so we can test + let untyped_response = UntypedResponse::from_slice(&input).unwrap(); + let response: Response = untyped_response.to_typed_response().unwrap(); + + assert_eq!(response.header, header!()); + assert_eq!(response.id, "test"); + assert_eq!(response.origin_id, "og"); + assert!(response.payload); + } + #[test] fn untyped_response_should_support_parsing_full_request() { let input = [ - &[0x84], - HEADER_FIELD_BYTES, + &[rmp::Marker::FixArray(4).to_u8()], HEADER_BYTES, - ID_FIELD_BYTES, TEST_STR_BYTES, - ORIGIN_ID_FIELD_BYTES, &[0xa2, b'o', b'g'], - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat(); @@ -476,80 +558,41 @@ mod tests { Err(UntypedResponseParseError::WrongType) ); - // Invalid header key - assert_eq!( - UntypedResponse::from_slice( - [ - &[0x84], - &[0xa0], // header key would be defined here, set to empty str - HEADER_BYTES, - ID_FIELD_BYTES, - TEST_STR_BYTES, - ORIGIN_ID_FIELD_BYTES, - TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, - &[TRUE_BYTE], - ] - .concat() - .as_slice() - ), - Err(UntypedResponseParseError::InvalidHeaderKey) - ); - - // Invalid header bytes + // Missing fields (corrupt data) assert_eq!( - UntypedResponse::from_slice( - [ - &[0x84], - HEADER_FIELD_BYTES, - &[0xa0], // header would be defined here, set to empty str - ID_FIELD_BYTES, - TEST_STR_BYTES, - ORIGIN_ID_FIELD_BYTES, - TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, - &[TRUE_BYTE], - ] - .concat() - .as_slice() - ), - Err(UntypedResponseParseError::InvalidHeader) + UntypedResponse::from_slice(&[rmp::Marker::FixArray(3).to_u8()]), + Err(UntypedResponseParseError::InvalidId) ); // Missing fields (corrupt data) assert_eq!( - UntypedResponse::from_slice(&[0x83]), - Err(UntypedResponseParseError::InvalidIdKey) + UntypedResponse::from_slice(&[rmp::Marker::FixArray(4).to_u8()]), + Err(UntypedResponseParseError::InvalidHeader) ); - // Missing id field (has valid data itself) + // Invalid header bytes assert_eq!( UntypedResponse::from_slice( [ - &[0x83], - &[0xa0], // id would be defined here, set to empty str + &[rmp::Marker::FixArray(4).to_u8()], + &[0xa0], // header would be defined here, set to empty str TEST_STR_BYTES, - ORIGIN_ID_FIELD_BYTES, TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat() .as_slice() ), - Err(UntypedResponseParseError::InvalidIdKey) + Err(UntypedResponseParseError::InvalidHeader) ); // Non-str id field value assert_eq!( UntypedResponse::from_slice( [ - &[0x83], - ID_FIELD_BYTES, + &[rmp::Marker::FixArray(3).to_u8()], &[TRUE_BYTE], // id value set to boolean - ORIGIN_ID_FIELD_BYTES, TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat() @@ -562,12 +605,9 @@ mod tests { assert_eq!( UntypedResponse::from_slice( [ - &[0x83], - ID_FIELD_BYTES, + &[rmp::Marker::FixArray(3).to_u8()] as &[u8], &[0xa4, 0, 159, 146, 150], - ORIGIN_ID_FIELD_BYTES, TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat() @@ -576,34 +616,13 @@ mod tests { Err(UntypedResponseParseError::InvalidId) ); - // Missing origin_id field (has valid data itself) - assert_eq!( - UntypedResponse::from_slice( - [ - &[0x83], - ID_FIELD_BYTES, - TEST_STR_BYTES, - &[0xa0], // id would be defined here, set to empty str - TEST_STR_BYTES, - PAYLOAD_FIELD_BYTES, - &[TRUE_BYTE], - ] - .concat() - .as_slice() - ), - Err(UntypedResponseParseError::InvalidOriginIdKey) - ); - // Non-str origin_id field value assert_eq!( UntypedResponse::from_slice( [ - &[0x83], - ID_FIELD_BYTES, + &[rmp::Marker::FixArray(3).to_u8()], TEST_STR_BYTES, - ORIGIN_ID_FIELD_BYTES, &[TRUE_BYTE], // id value set to boolean - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat() @@ -616,12 +635,9 @@ mod tests { assert_eq!( UntypedResponse::from_slice( [ - &[0x83], - ID_FIELD_BYTES, + &[rmp::Marker::FixArray(3).to_u8()], TEST_STR_BYTES, - ORIGIN_ID_FIELD_BYTES, &[0xa4, 0, 159, 146, 150], - PAYLOAD_FIELD_BYTES, &[TRUE_BYTE], ] .concat() @@ -634,18 +650,14 @@ mod tests { assert_eq!( UntypedResponse::from_slice( [ - &[0x83], - ID_FIELD_BYTES, + &[rmp::Marker::FixArray(3).to_u8()], TEST_STR_BYTES, - ORIGIN_ID_FIELD_BYTES, TEST_STR_BYTES, - &[0xa0], // payload would be defined here, set to empty str - &[TRUE_BYTE], ] .concat() .as_slice() ), - Err(UntypedResponseParseError::InvalidPayloadKey) + Err(UntypedResponseParseError::MissingPayload) ); } }