From efad345a0dea72bcd350f10dcef31e92773e9291 Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Sun, 11 Jun 2023 16:07:36 -0700 Subject: [PATCH] Add header support to request & response (#200) --- CHANGELOG.md | 5 + Cargo.lock | 1 + distant-net/Cargo.toml | 2 + distant-net/src/common.rs | 1 + distant-net/src/common/packet.rs | 1263 ++++++++++++--------- distant-net/src/common/packet/header.rs | 80 ++ distant-net/src/common/packet/request.rs | 262 ++++- distant-net/src/common/packet/response.rs | 288 ++++- distant-protocol/src/response.rs | 12 +- 9 files changed, 1274 insertions(+), 640 deletions(-) create mode 100644 distant-net/src/common/packet/header.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index e3114b0..77bf3ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- `Request` and `Response` types from `distant-net` now support an optional + `Header` to send miscellaneous information + ### Changed - `Change` structure now provides a single `path` instead of `paths` with the diff --git a/Cargo.lock b/Cargo.lock index 5eb89a4..383ef57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -935,6 +935,7 @@ dependencies = [ "p256", "paste", "rand", + "rmp", "rmp-serde", "serde", "serde_bytes", diff --git a/distant-net/Cargo.toml b/distant-net/Cargo.toml index 1a52b05..c6193df 100644 --- a/distant-net/Cargo.toml +++ b/distant-net/Cargo.toml @@ -25,10 +25,12 @@ log = "0.4.18" paste = "1.0.12" p256 = { version = "0.13.2", features = ["ecdh", "pem"] } rand = { version = "0.8.5", features = ["getrandom"] } +rmp = "0.8.11" rmp-serde = "1.1.1" sha2 = "0.10.6" serde = { version = "1.0.163", features = ["derive"] } serde_bytes = "0.11.9" +serde_json = "1.0.96" strum = { version = "0.24.1", features = ["derive"] } tokio = { version = "1.28.2", features = ["full"] } diff --git a/distant-net/src/common.rs b/distant-net/src/common.rs index 5f793c8..a0e79dc 100644 --- a/distant-net/src/common.rs +++ b/distant-net/src/common.rs @@ -20,4 +20,5 @@ pub use listener::*; pub use map::*; pub use packet::*; pub use port::*; +pub use serde_json::Value; pub use transport::*; diff --git a/distant-net/src/common/packet.rs b/distant-net/src/common/packet.rs index f78bbfe..c55fe2d 100644 --- a/distant-net/src/common/packet.rs +++ b/distant-net/src/common/packet.rs @@ -1,628 +1,805 @@ -/// Represents a generic id type -pub type Id = String; - +mod header; mod request; mod response; +pub use header::*; pub use request::*; pub use response::*; -#[derive(Clone, Debug, PartialEq, Eq)] -enum MsgPackStrParseError { - InvalidFormat, - Utf8Error(std::str::Utf8Error), -} +use std::io::Cursor; -/// Writes the given str to the end of `buf` as the str's msgpack representation. -/// -/// # Panics +/// Represents a generic id type +pub type Id = String; + +/// Reads the header bytes from msgpack input, including the marker and len bytes. /// -/// Panics if `s.len() >= 2 ^ 32` as the maximum str length for a msgpack str is `(2 ^ 32) - 1`. -fn write_str_msg_pack(s: &str, buf: &mut Vec) { - assert!( - s.len() < 2usize.pow(32), - "str cannot be longer than (2^32)-1 bytes" - ); - - if s.len() < 32 { - buf.push(s.len() as u8 | 0b10100000); - } else if s.len() < 2usize.pow(8) { - buf.push(0xd9); - buf.push(s.len() as u8); - } else if s.len() < 2usize.pow(16) { - buf.push(0xda); - for b in (s.len() as u16).to_be_bytes() { - buf.push(b); - } - } else { - buf.push(0xdb); - for b in (s.len() as u32).to_be_bytes() { - buf.push(b); +/// * If succeeds, returns (header, remaining). +/// * If fails, returns existing bytes. +fn read_header_bytes(input: &[u8]) -> Result<(&[u8], &[u8]), &[u8]> { + let mut cursor = Cursor::new(input); + let input_len = input.len(); + + // Determine size of header map in terms of total objects + let len = match rmp::decode::read_map_len(&mut cursor) { + Ok(x) => x, + Err(_) => return Err(input), + }; + + // For each object, we have a corresponding key in front of it has a string, + // so we need to iterate, advancing by a string key and then the object + for _i in 0..len { + // Read just the length of the key to avoid copying the key itself + let key_len = match rmp::decode::read_str_len(&mut cursor) { + Ok(x) => x as u64, + Err(_) => return Err(input), + }; + + // Advance forward past the key + cursor.set_position(cursor.position() + key_len); + + // If we would have advanced past our input, fail + if cursor.position() as usize > input_len { + return Err(input); + } + + // Point locally to just past the str key so we can determine next byte len to skip + let input = &input[cursor.position() as usize..]; + + // Read the type of object and advance accordingly + match find_msgpack_byte_len(input) { + Some(len) => cursor.set_position(cursor.position() + len), + None => return Err(input), } + + // If we would have advanced past our input, fail + if cursor.position() as usize > input_len { + return Err(input); + } + } + + let pos = cursor.position() as usize; + + // Check if we've read beyond the input (being equal to len is okay + // because we could consume all of the remaining input this way) + if pos > input_len { + return Err(input); } - buf.extend_from_slice(s.as_bytes()); + Ok((&input[..pos], &input[pos..])) } -/// Parse msgpack str, returning remaining bytes and str on success, or error on failure. -fn parse_msg_pack_str(input: &[u8]) -> Result<(&[u8], &str), MsgPackStrParseError> { - let ilen = input.len(); - if ilen == 0 { - return Err(MsgPackStrParseError::InvalidFormat); +/// Determines the length of the next object based on its marker. From the marker, some objects +/// need to be traversed (e.g. map) in order to fully understand the total byte length. +/// +/// This will include the marker bytes in the total byte len such that collecting all of the +/// bytes up to len will yield a valid msgpack object in byte form. +/// +/// If the first byte does not signify a valid marker, this method returns None. +fn find_msgpack_byte_len(input: &[u8]) -> Option { + if input.is_empty() { + return None; } - // * fixstr using 0xa0 - 0xbf to mark the start of the str where < 32 bytes - // * str 8 (0xd9) if up to (2^8)-1 bytes, using next byte for len - // * str 16 (0xda) if up to (2^16)-1 bytes, using next two bytes for len - // * str 32 (0xdb) if up to (2^32)-1 bytes, using next four bytes for len - let (input, len): (&[u8], usize) = if input[0] >= 0xa0 && input[0] <= 0xbf { - (&input[1..], (input[0] & 0b00011111).into()) - } else if input[0] == 0xd9 && ilen > 2 { - (&input[2..], input[1].into()) - } else if input[0] == 0xda && ilen > 3 { - (&input[3..], u16::from_be_bytes([input[1], input[2]]).into()) - } else if input[0] == 0xdb && ilen > 5 { - ( - &input[5..], - u32::from_be_bytes([input[1], input[2], input[3], input[4]]) - .try_into() - .unwrap(), - ) - } else { - return Err(MsgPackStrParseError::InvalidFormat); - }; + macro_rules! read_len { + (u8: $input:expr $(, start = $start:expr)?) => {{ + let input = $input; + + $( + if input.len() < $start { + return None; + } + let input = &input[$start..]; + )? + + if input.is_empty() { + return None; + } else { + input[0] as u64 + } + }}; + (u16: $input:expr $(, start = $start:expr)?) => {{ + let input = $input; + + $( + if input.len() < $start { + return None; + } + let input = &input[$start..]; + )? + + if input.len() < 2 { + return None; + } else { + u16::from_be_bytes([input[0], input[1]]) as u64 + } + }}; + (u32: $input:expr $(, start = $start:expr)?) => {{ + let input = $input; + + $( + if input.len() < $start { + return None; + } + let input = &input[$start..]; + )? + + if input.len() < 4 { + return None; + } else { + u32::from_be_bytes([input[0], input[1], input[2], input[3]]) as u64 + } + }}; + ($cnt:expr => $input:expr $(, start = $start:expr)?) => {{ + let input = $input; + + $( + if input.len() < $start { + return None; + } + let input = &input[$start..]; + )? + + let cnt = $cnt; + let mut len = 0; + for _i in 0..cnt { + if input.len() < len { + return None; + } + + let input = &input[len..]; + match find_msgpack_byte_len(input) { + Some(x) => len += x as usize, + None => return None, + } + } + len as u64 + }}; + } - let s = match std::str::from_utf8(&input[..len]) { - Ok(s) => s, - Err(x) => return Err(MsgPackStrParseError::Utf8Error(x)), - }; + Some(match rmp::Marker::from_u8(input[0]) { + // Booleans and nil (aka null) are a combination of marker and value (single byte) + rmp::Marker::Null => 1, + rmp::Marker::True => 1, + rmp::Marker::False => 1, + + // Integers are stored in 1, 2, 3, 5, or 9 bytes + rmp::Marker::FixPos(_) => 1, + rmp::Marker::FixNeg(_) => 1, + rmp::Marker::U8 => 2, + rmp::Marker::U16 => 3, + rmp::Marker::U32 => 5, + rmp::Marker::U64 => 9, + rmp::Marker::I8 => 2, + rmp::Marker::I16 => 3, + rmp::Marker::I32 => 5, + rmp::Marker::I64 => 9, + + // Floats are stored in 5 or 9 bytes + rmp::Marker::F32 => 5, + rmp::Marker::F64 => 9, + + // Str are stored in 1, 2, 3, or 5 bytes + the data buffer + rmp::Marker::FixStr(len) => 1 + len as u64, + rmp::Marker::Str8 => 2 + read_len!(u8: input, start = 1), + rmp::Marker::Str16 => 3 + read_len!(u16: input, start = 1), + rmp::Marker::Str32 => 5 + read_len!(u32: input, start = 1), + + // Bin are stored in 2, 3, or 5 bytes + the data buffer + rmp::Marker::Bin8 => 2 + read_len!(u8: input, start = 1), + rmp::Marker::Bin16 => 3 + read_len!(u16: input, start = 1), + rmp::Marker::Bin32 => 5 + read_len!(u32: input, start = 1), + + // Arrays are stored in 1, 3, or 5 bytes + N objects (where each object has its own len) + rmp::Marker::FixArray(cnt) => 1 + read_len!(cnt => input, start = 1), + rmp::Marker::Array16 => { + let cnt = read_len!(u16: input, start = 1); + 3 + read_len!(cnt => input, start = 3) + } + rmp::Marker::Array32 => { + let cnt = read_len!(u32: input, start = 1); + 5 + read_len!(cnt => input, start = 5) + } - Ok((&input[len..], s)) + // Maps are stored in 1, 3, or 5 bytes + 2*N objects (where each object has its own len) + rmp::Marker::FixMap(cnt) => 1 + read_len!(2 * cnt => input, start = 1), + rmp::Marker::Map16 => { + let cnt = read_len!(u16: input, start = 1); + 3 + read_len!(2 * cnt => input, start = 3) + } + rmp::Marker::Map32 => { + let cnt = read_len!(u32: input, start = 1); + 5 + read_len!(2 * cnt => input, start = 5) + } + + // Ext are stored in an integer (8-bit, 16-bit, 32-bit), type (8-bit), and byte array + rmp::Marker::FixExt1 => 3, + rmp::Marker::FixExt2 => 4, + rmp::Marker::FixExt4 => 6, + rmp::Marker::FixExt8 => 10, + rmp::Marker::FixExt16 => 18, + rmp::Marker::Ext8 => 3 + read_len!(u8: input, start = 1), + rmp::Marker::Ext16 => 4 + read_len!(u16: input, start = 1), + rmp::Marker::Ext32 => 6 + read_len!(u32: input, start = 1), + + // NOTE: This is marked in the msgpack spec as never being used, so we return none + // as this is signfies something has gone wrong! + rmp::Marker::Reserved => return None, + }) +} + +/// Reads the str bytes from msgpack input, including the marker and len bytes. +/// +/// * If succeeds, returns (str, remaining). +/// * If fails, returns existing bytes. +fn read_str_bytes(input: &[u8]) -> Result<(&str, &[u8]), &[u8]> { + match rmp::decode::read_str_from_slice(input) { + Ok(x) => Ok(x), + Err(_) => Err(input), + } +} + +/// Reads a str key from msgpack input and checks if it matches `key`. If so, the input is +/// advanced, otherwise the original input is returned. +/// +/// * If key read successfully and matches, returns (unit, remaining). +/// * Otherwise, returns existing bytes. +fn read_key_eq<'a>(input: &'a [u8], key: &str) -> Result<((), &'a [u8]), &'a [u8]> { + match read_str_bytes(input) { + Ok((s, input)) if s == key => Ok(((), input)), + _ => Err(input), + } } #[cfg(test)] mod tests { use super::*; - mod write_str_msg_pack { + mod read_str_bytes { use super::*; + use test_log::test; #[test] - fn should_support_fixstr() { - // 0-byte str - let mut buf = Vec::new(); - write_str_msg_pack("", &mut buf); - assert_eq!(buf, &[0xa0]); - - // 1-byte str - let mut buf = Vec::new(); - write_str_msg_pack("a", &mut buf); - assert_eq!(buf, &[0xa1, b'a']); - - // 2-byte str - let mut buf = Vec::new(); - write_str_msg_pack("ab", &mut buf); - assert_eq!(buf, &[0xa2, b'a', b'b']); - - // 3-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abc", &mut buf); - assert_eq!(buf, &[0xa3, b'a', b'b', b'c']); - - // 4-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcd", &mut buf); - assert_eq!(buf, &[0xa4, b'a', b'b', b'c', b'd']); - - // 5-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcde", &mut buf); - assert_eq!(buf, &[0xa5, b'a', b'b', b'c', b'd', b'e']); - - // 6-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdef", &mut buf); - assert_eq!(buf, &[0xa6, b'a', b'b', b'c', b'd', b'e', b'f']); - - // 7-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefg", &mut buf); - assert_eq!(buf, &[0xa7, b'a', b'b', b'c', b'd', b'e', b'f', b'g']); - - // 8-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefgh", &mut buf); - assert_eq!(buf, &[0xa8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']); - - // 9-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghi", &mut buf); - assert_eq!( - buf, - &[0xa9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i'] - ); - - // 10-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghij", &mut buf); - assert_eq!( - buf, - &[0xaa, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j'] - ); - - // 11-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijk", &mut buf); - assert_eq!( - buf, - &[0xab, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k'] - ); + fn should_fail_if_input_is_empty() { + let input = read_str_bytes(&[]).unwrap_err(); + assert!(input.is_empty()); + } - // 12-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijkl", &mut buf); - assert_eq!( - buf, - &[0xac, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l'] - ); + #[test] + fn should_fail_if_input_does_not_start_with_str() { + let input = read_str_bytes(&[0xff, 0xa5, b'h', b'e', b'l', b'l', b'o']).unwrap_err(); + assert_eq!(input, [0xff, 0xa5, b'h', b'e', b'l', b'l', b'o']); + } - // 13-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklm", &mut buf); - assert_eq!( - buf, - &[ - 0xad, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm' - ] - ); + #[test] + fn should_succeed_if_input_starts_with_str() { + let (s, remaining) = + read_str_bytes(&[0xa5, b'h', b'e', b'l', b'l', b'o', 0xff]).unwrap(); + assert_eq!(s, "hello"); + assert_eq!(remaining, [0xff]); + } + } - // 14-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmn", &mut buf); - assert_eq!( - buf, - &[ - 0xae, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n' - ] - ); + mod read_key_eq { + use super::*; + use test_log::test; - // 15-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmno", &mut buf); - assert_eq!( - buf, - &[ - 0xaf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o' - ] - ); + #[test] + fn should_fail_if_input_is_empty() { + let input = read_key_eq(&[], "key").unwrap_err(); + assert!(input.is_empty()); + } - // 16-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnop", &mut buf); - assert_eq!( - buf, - &[ - 0xb0, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p' - ] - ); + #[test] + fn should_fail_if_input_does_not_start_with_str() { + let input = &[ + 0xff, + rmp::Marker::FixStr(5).to_u8(), + b'h', + b'e', + b'l', + b'l', + b'o', + ]; + let remaining = read_key_eq(input, "key").unwrap_err(); + assert_eq!(remaining, input); + } - // 17-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopq", &mut buf); - assert_eq!( - buf, - &[ - 0xb1, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q' - ] - ); + #[test] + fn should_fail_if_read_key_does_not_match_specified_key() { + let input = &[ + rmp::Marker::FixStr(5).to_u8(), + b'h', + b'e', + b'l', + b'l', + b'o', + 0xff, + ]; + let remaining = read_key_eq(input, "key").unwrap_err(); + assert_eq!(remaining, input); + } - // 18-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqr", &mut buf); - assert_eq!( - buf, - &[ - 0xb2, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r' - ] - ); + #[test] + fn should_succeed_if_read_key_matches_specified_key() { + let input = &[ + rmp::Marker::FixStr(5).to_u8(), + b'h', + b'e', + b'l', + b'l', + b'o', + 0xff, + ]; + let (_, remaining) = read_key_eq(input, "hello").unwrap(); + assert_eq!(remaining, [0xff]); + } + } - // 19-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrs", &mut buf); - assert_eq!( - buf, - &[ - 0xb3, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's' - ] - ); + mod read_header_bytes { + use super::*; + use test_log::test; - // 20-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrst", &mut buf); - assert_eq!( - buf, - &[ - 0xb4, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't' - ] - ); + #[test] + fn should_fail_if_input_is_empty() { + let input = vec![]; + assert!(read_header_bytes(&input).is_err()); + } - // 21-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstu", &mut buf); - assert_eq!( - buf, - &[ - 0xb5, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u' - ] - ); + #[test] + fn should_fail_if_not_a_map() { + // Provide an array instead of a map + let input = vec![0x93, 0xa3, b'a', b'b', b'c', 0xcc, 0xff, 0xc2]; + assert!(read_header_bytes(&input).is_err()); + } - // 22-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuv", &mut buf); - assert_eq!( - buf, - &[ - 0xb6, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v' - ] - ); + #[test] + fn should_fail_if_cannot_read_str_key_length() { + let input = vec![ + 0x81, // valid map with 1 pair, but key is not a str + 0x03, 0xa3, b'a', b'b', b'c', // 3 -> "abc" + ]; + assert!(read_header_bytes(&input).is_err()); + } + #[test] + fn should_fail_if_key_length_exceeds_remaining_bytes() { + let input = vec![ + 0x81, // valid map with 1 pair, but key length is too long + 0xa8, b'a', b'b', b'c', // key: "abc" (but len is much greater) + 0xa3, b'a', b'b', b'c', // value: "abc" + ]; + assert!(read_header_bytes(&input).is_err()); + } - // 23-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvw", &mut buf); - assert_eq!( - buf, - &[ - 0xb7, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w' - ] - ); + #[test] + fn should_fail_if_missing_value_for_key() { + let input = vec![ + 0x81, // valid map with 1 pair, but value is missing + 0xa3, b'a', b'b', b'c', // key: "abc" + ]; + assert!(read_header_bytes(&input).is_err()); + } - // 24-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwx", &mut buf); - assert_eq!( - buf, - &[ - 0xb8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x' - ] - ); + #[test] + fn should_fail_if_unable_to_read_value_length() { + let input = vec![ + 0x81, // valid map with 1 pair, but value is missing + 0xa3, b'a', b'b', b'c', // key: "abc" + 0xd9, // value: str 8 with missing length + ]; + assert!(read_header_bytes(&input).is_err()); + } - // 25-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwxy", &mut buf); - assert_eq!( - buf, - &[ - 0xb9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y' - ] - ); + #[test] + fn should_fail_if_value_length_exceeds_remaining_bytes() { + let input = vec![ + 0x81, // valid map with 1 pair, but value is too long + 0xa3, b'a', b'b', b'c', // key: "abc" + 0xa2, b'd', // value: fixstr w/ len 1 too long + ]; + assert!(read_header_bytes(&input).is_err()); + } - // 26-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwxyz", &mut buf); - assert_eq!( - buf, - &[ - 0xba, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', - b'z' - ] - ); + #[test] + fn should_succeed_with_empty_map() { + // fixmap with 0 pairs + let input = vec![0x80]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + + // map 16 with 0 pairs + let input = vec![0xde, 0x00, 0x00]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + + // map 32 with 0 pairs + let input = vec![0xdf, 0x00, 0x00, 0x00, 0x00]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + } - // 27-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0", &mut buf); - assert_eq!( - buf, - &[ - 0xbb, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', - b'z', b'0' - ] - ); + #[test] + fn should_succeed_with_single_key_value_map() { + // fixmap with single pair + let input = vec![ + 0x81, // valid map with 1 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + ]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + + // map 16 with single pair + let input = vec![ + 0xde, 0x00, 0x01, // valid map with 1 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + ]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + + // map 32 with single pair + let input = vec![ + 0xdf, 0x00, 0x00, 0x00, 0x01, // valid map with 1 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + ]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + } - // 28-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01", &mut buf); - assert_eq!( - buf, - &[ - 0xbc, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', - b'z', b'0', b'1' - ] - ); + #[test] + fn should_succeed_with_multiple_key_value_map() { + // fixmap with single pair + let input = vec![ + 0x82, // valid map with 2 pairs + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + 0xa3, b'y', b'e', b'k', // key: "yek" + 0x7b, // value: 123 (fixint) + ]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + + // map 16 with single pair + let input = vec![ + 0xde, 0x00, 0x02, // valid map with 2 pairs + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + 0xa3, b'y', b'e', b'k', // key: "yek" + 0x7b, // value: 123 (fixint) + ]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + + // map 32 with single pair + let input = vec![ + 0xdf, 0x00, 0x00, 0x00, 0x02, // valid map with 2 pairs + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + 0xa3, b'y', b'e', b'k', // key: "yek" + 0x7b, // value: 123 (fixint) + ]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + } - // 29-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwxyz012", &mut buf); - assert_eq!( - buf, - &[ - 0xbd, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', - b'z', b'0', b'1', b'2' - ] - ); + #[test] + fn should_succeed_with_nested_map() { + // fixmap with single pair + let input = vec![ + 0x81, // valid map with 1 pair + 0xa3, b'm', b'a', b'p', // key: "map" + 0x81, // value: valid map with 1 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + ]; + let (header, _) = read_header_bytes(&input).unwrap(); + assert_eq!(header, input); + } - // 30-byte str - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0123", &mut buf); - assert_eq!( - buf, - &[ - 0xbe, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', - b'z', b'0', b'1', b'2', b'3' + #[test] + fn should_only_consume_map_from_input() { + // fixmap with single pair + let input = vec![ + 0x81, // valid map with 1 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + 0xa4, b'm', b'o', b'r', b'e', // "more" (fixstr) + ]; + let (header, remaining) = read_header_bytes(&input).unwrap(); + assert_eq!( + header, + vec![ + 0x81, // valid map with 1 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" ] ); - - // 31-byte str is maximum len of fixstr - let mut buf = Vec::new(); - write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01234", &mut buf); assert_eq!( - buf, - &[ - 0xbf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', - b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', - b'z', b'0', b'1', b'2', b'3', b'4' - ] + remaining, + vec![ + 0xa4, b'm', b'o', b'r', b'e', // "more" (fixstr) + ] ); } + } + + mod find_msgpack_byte_len { + use super::*; + use test_log::test; #[test] - fn should_support_str_8() { - let input = "a".repeat(32); - let mut buf = Vec::new(); - write_str_msg_pack(&input, &mut buf); - assert_eq!(buf[0], 0xd9); - assert_eq!(buf[1], input.len() as u8); - assert_eq!(&buf[2..], input.as_bytes()); - - let input = "a".repeat(2usize.pow(8) - 1); - let mut buf = Vec::new(); - write_str_msg_pack(&input, &mut buf); - assert_eq!(buf[0], 0xd9); - assert_eq!(buf[1], input.len() as u8); - assert_eq!(&buf[2..], input.as_bytes()); + fn should_return_none_if_input_is_empty() { + let input = vec![]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); } #[test] - fn should_support_str_16() { - let input = "a".repeat(2usize.pow(8)); - let mut buf = Vec::new(); - write_str_msg_pack(&input, &mut buf); - assert_eq!(buf[0], 0xda); - assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes()); - assert_eq!(&buf[3..], input.as_bytes()); - - let input = "a".repeat(2usize.pow(16) - 1); - let mut buf = Vec::new(); - write_str_msg_pack(&input, &mut buf); - assert_eq!(buf[0], 0xda); - assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes()); - assert_eq!(&buf[3..], input.as_bytes()); + fn should_return_none_if_input_has_reserved_marker() { + let input = vec![rmp::Marker::Reserved.to_u8()]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); } #[test] - fn should_support_str_32() { - let input = "a".repeat(2usize.pow(16)); - let mut buf = Vec::new(); - write_str_msg_pack(&input, &mut buf); - assert_eq!(buf[0], 0xdb); - assert_eq!(&buf[1..5], &(input.len() as u32).to_be_bytes()); - assert_eq!(&buf[5..], input.as_bytes()); + fn should_return_1_if_input_is_nil() { + let input = vec![0xc0]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1), "Wrong len for {input:X?}"); } - } - - mod parse_msg_pack_str { - use super::*; #[test] - fn should_be_able_to_parse_fixstr() { - // Empty str - let (input, s) = parse_msg_pack_str(&[0xa0]).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, ""); - - // Single character - let (input, s) = parse_msg_pack_str(&[0xa1, b'a']).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, "a"); - - // 31 byte str - let (input, s) = parse_msg_pack_str(&[ - 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', - ]) - .unwrap(); - assert!(input.is_empty()); - assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); - - // Verify that we only consume up to fixstr length - assert_eq!(parse_msg_pack_str(&[0xa0, b'a']).unwrap().0, b"a"); - assert_eq!( - parse_msg_pack_str(&[ - 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'b' - ]) - .unwrap() - .0, - b"b" - ); + fn should_return_1_if_input_is_a_boolean() { + let input = vec![0xc2]; // false + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1), "Wrong len for {input:X?}"); + + let input = vec![0xc3]; // true + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1), "Wrong len for {input:X?}"); } #[test] - fn should_be_able_to_parse_str_8() { - // 32 byte str - let (input, s) = parse_msg_pack_str(&[ - 0xd9, 32, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', - ]) - .unwrap(); - assert!(input.is_empty()); - assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); - - // 2^8 - 1 (255) byte str - let test_str = "a".repeat(2usize.pow(8) - 1); - let mut input = vec![0xd9, 255]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); - - // Verify that we only consume up to 2^8 - 1 length - let mut input = vec![0xd9, 255]; - input.extend_from_slice(test_str.as_bytes()); - input.extend_from_slice(b"hello"); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert_eq!(input, b"hello"); - assert_eq!(s, test_str); + fn should_return_appropriate_len_if_input_is_some_integer() { + let input = vec![0x00]; // positive fixint (0) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1), "Wrong len for {input:X?}"); + + let input = vec![0xff]; // negative fixint (-1) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1), "Wrong len for {input:X?}"); + + let input = vec![0xcc, 0xff]; // unsigned 8-bit (255) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(2), "Wrong len for {input:X?}"); + + let input = vec![0xcd, 0xff, 0xff]; // unsigned 16-bit (65535) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(3), "Wrong len for {input:X?}"); + + let input = vec![0xce, 0xff, 0xff, 0xff, 0xff]; // unsigned 32-bit (4294967295) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(5), "Wrong len for {input:X?}"); + + let input = vec![0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00]; // unsigned 64-bit (4294967296) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(9), "Wrong len for {input:X?}"); + + let input = vec![0xd0, 0x81]; // signed 8-bit (-127) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(2), "Wrong len for {input:X?}"); + + let input = vec![0xd1, 0x80, 0x01]; // signed 16-bit (-32767) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(3), "Wrong len for {input:X?}"); + + let input = vec![0xd2, 0x80, 0x00, 0x00, 0x01]; // signed 32-bit (-2147483647) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(5), "Wrong len for {input:X?}"); + + let input = vec![0xd3, 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00]; // signed 64-bit (-2147483648) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(9), "Wrong len for {input:X?}"); } #[test] - fn should_be_able_to_parse_str_16() { - // 2^8 byte str (256) - let test_str = "a".repeat(2usize.pow(8)); - let mut input = vec![0xda, 1, 0]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); - - // 2^16 - 1 (65535) byte str - let test_str = "a".repeat(2usize.pow(16) - 1); - let mut input = vec![0xda, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); - - // Verify that we only consume up to 2^16 - 1 length - let mut input = vec![0xda, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - input.extend_from_slice(b"hello"); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert_eq!(input, b"hello"); - assert_eq!(s, test_str); + fn should_return_appropriate_len_if_input_is_some_float() { + let input = vec![0xca, 0x3d, 0xcc, 0xcc, 0xcd]; // f32 (0.1) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(5), "Wrong len for {input:X?}"); + + let input = vec![0xcb, 0x3f, 0xb9, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a]; // f64 (0.1) + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(9), "Wrong len for {input:X?}"); } #[test] - fn should_be_able_to_parse_str_32() { - // 2^16 byte str - let test_str = "a".repeat(2usize.pow(16)); - let mut input = vec![0xdb, 0, 1, 0, 0]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); - - // NOTE: We are not going to run the below tests, not because they aren't valid but - // because this generates a 4GB str which takes 20+ seconds to run - - // 2^32 - 1 byte str (4294967295 bytes) - /* let test_str = "a".repeat(2usize.pow(32) - 1); - let mut input = vec![0xdb, 255, 255, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); */ - - // Verify that we only consume up to 2^32 - 1 length - /* let mut input = vec![0xdb, 255, 255, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - input.extend_from_slice(b"hello"); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert_eq!(input, b"hello"); - assert_eq!(s, test_str); */ + fn should_return_appropriate_len_if_input_is_some_str() { + // fixstr (31 bytes max) + let input = vec![0xa5, b'h', b'e', b'l', b'l', b'o']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(5 + 1), "Wrong len for {input:X?}"); + + // str 8 will read second byte (u8) for size + let input = vec![0xd9, 0xff, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u8::MAX as u64 + 2), "Wrong len for {input:X?}"); + + // str 16 will read second & third bytes (u16) for size + let input = vec![0xda, 0xff, 0xff, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u16::MAX as u64 + 3), "Wrong len for {input:X?}"); + + // str 32 will read second, third, fourth, & fifth bytes (u32) for size + let input = vec![0xdb, 0xff, 0xff, 0xff, 0xff, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u32::MAX as u64 + 5), "Wrong len for {input:X?}"); } #[test] - fn should_fail_parsing_str_with_invalid_length() { - // Make sure that parse doesn't fail looking for bytes after str 8 len - assert_eq!( - parse_msg_pack_str(&[0xd9]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xd9, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - - // Make sure that parse doesn't fail looking for bytes after str 16 len - assert_eq!( - parse_msg_pack_str(&[0xda]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xda, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xda, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - - // Make sure that parse doesn't fail looking for bytes after str 32 len - assert_eq!( - parse_msg_pack_str(&[0xdb]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0, 0, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); + fn should_return_appropriate_len_if_input_is_some_bin() { + // bin 8 will read second byte (u8) for size + let input = vec![0xc4, 0xff, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u8::MAX as u64 + 2), "Wrong len for {input:X?}"); + + // bin 16 will read second & third bytes (u16) for size + let input = vec![0xc5, 0xff, 0xff, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u16::MAX as u64 + 3), "Wrong len for {input:X?}"); + + // bin 32 will read second, third, fourth, & fifth bytes (u32) for size + let input = vec![0xc6, 0xff, 0xff, 0xff, 0xff, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u32::MAX as u64 + 5), "Wrong len for {input:X?}"); } #[test] - fn should_fail_parsing_other_types() { - assert_eq!( - parse_msg_pack_str(&[0xc3]), // Boolean (true) - Err(MsgPackStrParseError::InvalidFormat) - ); + fn should_return_appropriate_len_if_input_is_some_array() { + // fixarray has a length up to 15 objects + // + // In this example, we have an array of 3 objects that are a str, integer, and bool + let input = vec![0x93, 0xa3, b'a', b'b', b'c', 0xcc, 0xff, 0xc2]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1 + 4 + 2 + 1), "Wrong len for {input:X?}"); + + // Invalid fixarray count should return none + let input = vec![0x93, 0xa3, b'a', b'b', b'c', 0xcc, 0xff]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); + + // array 16 will read second & third bytes (u16) for object length + // + // In this example, we have an array of 3 objects that are a str, integer, and bool + let input = vec![0xdc, 0x00, 0x03, 0xa3, b'a', b'b', b'c', 0xcc, 0xff, 0xc2]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(3 + 4 + 2 + 1), "Wrong len for {input:X?}"); + + // Invalid array 16 count should return none + let input = vec![0xdc, 0x00, 0x03, 0xa3, b'a', b'b', b'c', 0xcc, 0xff]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); + + // array 32 will read second, third, fourth, & fifth bytes (u32) for object length + let input = vec![ + 0xdd, 0x00, 0x00, 0x00, 0x03, 0xa3, b'a', b'b', b'c', 0xcc, 0xff, 0xc2, + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(5 + 4 + 2 + 1), "Wrong len for {input:X?}"); + + // Invalid array 32 count should return none + let input = vec![ + 0xdd, 0x00, 0x00, 0x00, 0x03, 0xa3, b'a', b'b', b'c', 0xcc, 0xff, + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); } #[test] - fn should_fail_if_empty_input() { - assert_eq!( - parse_msg_pack_str(&[]), - Err(MsgPackStrParseError::InvalidFormat) - ); + fn should_return_appropriate_len_if_input_is_some_map() { + // fixmap has a length up to 2*15 objects + let input = vec![ + 0x83, // 3 objects /w keys + 0x03, 0xa3, b'a', b'b', b'c', // 3 -> "abc" + 0xa3, b'a', b'b', b'c', 0xcc, 0xff, // "abc" -> 255 + 0xc3, 0xc2, // true -> false + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1 + 5 + 6 + 2), "Wrong len for {input:X?}"); + + // Invalid fixmap count should return none + let input = vec![ + 0x83, // 3 objects /w keys + 0x03, 0xa3, b'a', b'b', b'c', // 3 -> "abc" + 0xa3, b'a', b'b', b'c', 0xcc, 0xff, // "abc" -> 255 + 0xc3, // true -> ??? + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); + + // map 16 will read second & third bytes (u16) for object length + let input = vec![ + 0xde, 0x00, 0x03, // 3 objects w/ keys + 0x03, 0xa3, b'a', b'b', b'c', // 3 -> "abc" + 0xa3, b'a', b'b', b'c', 0xcc, 0xff, // "abc" -> 255 + 0xc3, 0xc2, // true -> false + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(3 + 5 + 6 + 2), "Wrong len for {input:X?}"); + + // Invalid map 16 count should return none + let input = vec![ + 0xde, 0x00, 0x03, // 3 objects w/ keys + 0x03, 0xa3, b'a', b'b', b'c', // 3 -> "abc" + 0xa3, b'a', b'b', b'c', 0xcc, 0xff, // "abc" -> 255 + 0xc3, // true -> ??? + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); + + // map 32 will read second, third, fourth, & fifth bytes (u32) for object length + let input = vec![ + 0xdf, 0x00, 0x00, 0x00, 0x03, // 3 objects w/ keys + 0x03, 0xa3, b'a', b'b', b'c', // 3 -> "abc" + 0xa3, b'a', b'b', b'c', 0xcc, 0xff, // "abc" -> 255 + 0xc3, 0xc2, // true -> false + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(5 + 5 + 6 + 2), "Wrong len for {input:X?}"); + + // Invalid map 32 count should return none + let input = vec![ + 0xdf, 0x00, 0x00, 0x00, 0x03, // 3 objects w/ keys + 0x03, 0xa3, b'a', b'b', b'c', // 3 -> "abc" + 0xa3, b'a', b'b', b'c', 0xcc, 0xff, // "abc" -> 255 + 0xc3, // true -> ??? + ]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, None, "Wrong len for {input:X?}"); } #[test] - fn should_fail_if_str_is_not_utf8() { - assert!(matches!( - parse_msg_pack_str(&[0xa4, 0, 159, 146, 150]), - Err(MsgPackStrParseError::Utf8Error(_)) - )); + fn should_return_appropriate_len_if_input_is_some_ext() { + // fixext 1 claims single data byte (excluding type) + let input = vec![0xd4, 0x00, 0x12]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1 + 1 + 1), "Wrong len for {input:X?}"); + + // fixext 2 claims two data bytes (excluding type) + let input = vec![0xd5, 0x00, 0x12, 0x34]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1 + 1 + 2), "Wrong len for {input:X?}"); + + // fixext 4 claims four data bytes (excluding type) + let input = vec![0xd6, 0x00, 0x12, 0x34, 0x56, 0x78]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1 + 1 + 4), "Wrong len for {input:X?}"); + + // fixext 8 claims eight data bytes (excluding type) + let input = vec![0xd7, 0x00, 0x12, 0x34, 0x56, 0x78]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1 + 1 + 8), "Wrong len for {input:X?}"); + + // fixext 16 claims sixteen data bytes (excluding type) + let input = vec![0xd8, 0x00, 0x12, 0x34, 0x56, 0x78]; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(1 + 1 + 16), "Wrong len for {input:X?}"); + + // ext 8 will read second byte (u8) for size (excluding type) + let input = vec![0xc7, 0xff, 0x00, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u8::MAX as u64 + 3), "Wrong len for {input:X?}"); + + // ext 16 will read second & third bytes (u16) for size (excluding type) + let input = vec![0xc8, 0xff, 0xff, 0x00, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u16::MAX as u64 + 4), "Wrong len for {input:X?}"); + + // ext 32 will read second, third, fourth, & fifth bytes (u32) for size (excluding type) + let input = vec![0xc9, 0xff, 0xff, 0xff, 0xff, 0x00, b'd', b'a', b't', b'a']; + let len = find_msgpack_byte_len(&input); + assert_eq!(len, Some(u32::MAX as u64 + 6), "Wrong len for {input:X?}"); } } } diff --git a/distant-net/src/common/packet/header.rs b/distant-net/src/common/packet/header.rs new file mode 100644 index 0000000..93425f4 --- /dev/null +++ b/distant-net/src/common/packet/header.rs @@ -0,0 +1,80 @@ +use crate::common::{utils, Value}; +use derive_more::IntoIterator; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::io; +use std::ops::{Deref, DerefMut}; + +/// Generates a new [`Header`] of key/value pairs based on literals. +/// +/// ``` +/// use distant_net::header; +/// +/// let _header = header!("key" -> "value", "key2" -> 123); +/// ``` +#[macro_export] +macro_rules! header { + ($($key:literal -> $value:expr),* $(,)?) => {{ + let mut _header = $crate::common::Header::default(); + + $( + _header.insert($key, $value); + )* + + _header + }}; +} + +/// Represents a packet header comprised of arbitrary data tied to string keys. +#[derive(Clone, Debug, Default, PartialEq, Eq, IntoIterator, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Header(HashMap); + +impl Header { + /// Creates an empty [`Header`] newtype wrapper. + pub fn new() -> Self { + Self::default() + } + + /// Exists purely to support serde serialization checks. + #[inline] + pub(crate) fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Inserts a key-value pair into the map. + /// + /// If the map did not have this key present, [`None`] is returned. + /// + /// If the map did have this key present, the value is updated, and the old value is returned. + /// The key is not updated, though; this matters for types that can be `==` without being + /// identical. See the [module-level documentation](std::collections#insert-and-complex-keys) + /// for more. + pub fn insert(&mut self, key: impl Into, value: impl Into) -> Option { + self.0.insert(key.into(), value.into()) + } + + /// Serializes the header into bytes. + pub fn to_vec(&self) -> io::Result> { + utils::serialize_to_vec(self) + } + + /// Deserializes the header from bytes. + pub fn from_slice(slice: &[u8]) -> io::Result { + utils::deserialize_from_slice(slice) + } +} + +impl Deref for Header { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Header { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/distant-net/src/common/packet/request.rs b/distant-net/src/common/packet/request.rs index 84c8b81..bd74e34 100644 --- a/distant-net/src/common/packet/request.rs +++ b/distant-net/src/common/packet/request.rs @@ -5,12 +5,17 @@ use derive_more::{Display, Error}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use super::{parse_msg_pack_str, write_str_msg_pack, Id}; +use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id}; use crate::common::utils; +use crate::header; /// Represents a request to send -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Request { + /// Optional header data to include with request + #[serde(default, skip_serializing_if = "Header::is_empty")] + pub header: Header, + /// Unique id associated with the request pub id: Id, @@ -19,9 +24,10 @@ pub struct Request { } impl Request { - /// Creates a new request with a random, unique id + /// Creates a new request with a random, unique id and no header data pub fn new(payload: T) -> Self { Self { + header: header!(), id: rand::random::().to_string(), payload, } @@ -45,6 +51,11 @@ where /// Attempts to convert a typed request to an untyped request pub fn to_untyped_request(&self) -> io::Result { Ok(UntypedRequest { + header: Cow::Owned(if !self.header.is_empty() { + utils::serialize_to_vec(&self.header)? + } else { + Vec::new() + }), id: Cow::Borrowed(&self.id), payload: Cow::Owned(self.to_payload_vec()?), }) @@ -73,13 +84,34 @@ pub enum UntypedRequestParseError { /// When the bytes do not represent a request WrongType, + /// When a header should be present, but the key is wrong + InvalidHeaderKey, + + /// When a header should be present, but the header bytes are wrong + InvalidHeader, + + /// When the key for the id is wrong + InvalidIdKey, + /// When the id is not a valid UTF-8 string InvalidId, + + /// When the key for the payload is wrong + InvalidPayloadKey, +} + +#[inline] +fn header_is_empty(header: &[u8]) -> bool { + header.is_empty() } /// Represents a request to send whose payload is bytes instead of a specific type -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct UntypedRequest<'a> { + /// Header data associated with the request as bytes + #[serde(default, skip_serializing_if = "header_is_empty")] + pub header: Cow<'a, [u8]>, + /// Unique id associated with the request pub id: Cow<'a, str>, @@ -91,6 +123,11 @@ impl<'a> UntypedRequest<'a> { /// Attempts to convert an untyped request to a typed request pub fn to_typed_request(&self) -> io::Result> { Ok(Request { + header: if header_is_empty(&self.header) { + header!() + } else { + utils::deserialize_from_slice(&self.header)? + }, id: self.id.to_string(), payload: utils::deserialize_from_slice(&self.payload)?, }) @@ -99,6 +136,10 @@ impl<'a> UntypedRequest<'a> { /// Convert into a borrowed version pub fn as_borrowed(&self) -> UntypedRequest<'_> { UntypedRequest { + header: match &self.header { + Cow::Borrowed(x) => Cow::Borrowed(x), + Cow::Owned(x) => Cow::Borrowed(x.as_slice()), + }, id: match &self.id { Cow::Borrowed(x) => Cow::Borrowed(x), Cow::Owned(x) => Cow::Borrowed(x.as_str()), @@ -113,6 +154,10 @@ impl<'a> UntypedRequest<'a> { /// Convert into an owned version pub fn into_owned(self) -> UntypedRequest<'static> { UntypedRequest { + header: match self.header { + Cow::Borrowed(x) => Cow::Owned(x.to_vec()), + Cow::Owned(x) => Cow::Owned(x), + }, id: match self.id { Cow::Borrowed(x) => Cow::Owned(x.to_string()), Cow::Owned(x) => Cow::Owned(x), @@ -124,6 +169,11 @@ impl<'a> UntypedRequest<'a> { } } + /// Updates the header of the request to the given `header`. + pub fn set_header(&mut self, header: impl IntoIterator) { + self.header = Cow::Owned(header.into_iter().collect()); + } + /// Updates the id of the request to the given `id`. pub fn set_id(&mut self, id: impl Into) { self.id = Cow::Owned(id.into()); @@ -131,61 +181,80 @@ impl<'a> UntypedRequest<'a> { /// Allocates a new collection of bytes representing the request. pub fn to_bytes(&self) -> Vec { - let mut bytes = vec![0x82]; + let mut bytes = vec![]; + + let has_header = !header_is_empty(&self.header); + if has_header { + rmp::encode::write_map_len(&mut bytes, 3).unwrap(); + } else { + rmp::encode::write_map_len(&mut bytes, 2).unwrap(); + } + + if has_header { + rmp::encode::write_str(&mut bytes, "header").unwrap(); + bytes.extend_from_slice(&self.header); + } - write_str_msg_pack("id", &mut bytes); - write_str_msg_pack(&self.id, &mut bytes); + rmp::encode::write_str(&mut bytes, "id").unwrap(); + rmp::encode::write_str(&mut bytes, &self.id).unwrap(); - write_str_msg_pack("payload", &mut bytes); + rmp::encode::write_str(&mut bytes, "payload").unwrap(); bytes.extend_from_slice(&self.payload); bytes } /// Parses a collection of bytes, returning a partial request if it can be potentially - /// represented as a [`Request`] depending on the payload, or the original bytes if it does not - /// represent a [`Request`] + /// represented as a [`Request`] depending on the payload. /// /// NOTE: This supports parsing an invalid request where the payload would not properly /// deserialize, but the bytes themselves represent a complete request of some kind. pub fn from_slice(input: &'a [u8]) -> Result { - if input.len() < 2 { + if input.is_empty() { return Err(UntypedRequestParseError::WrongType); } - // MsgPack marks a fixmap using 0x80 - 0x8f to indicate the size (up to 15 elements). - // - // In the case of the request, there are only two elements: id and payload. So the first - // byte should ALWAYS be 0x82 (130). - if input[0] != 0x82 { - return Err(UntypedRequestParseError::WrongType); - } + let has_header = match rmp::Marker::from_u8(input[0]) { + rmp::Marker::FixMap(2) => false, + rmp::Marker::FixMap(3) => true, + _ => return Err(UntypedRequestParseError::WrongType), + }; - // Skip the first byte representing the fixmap + // Advance position by marker let input = &input[1..]; - // Validate that first field is id - let (input, id_key) = - parse_msg_pack_str(input).map_err(|_| UntypedRequestParseError::WrongType)?; - if id_key != "id" { - return Err(UntypedRequestParseError::WrongType); - } + // Parse the header if we have one + let (header, input) = if has_header { + let (_, input) = read_key_eq(input, "header") + .map_err(|_| UntypedRequestParseError::InvalidHeaderKey)?; + + let (header, input) = + read_header_bytes(input).map_err(|_| UntypedRequestParseError::InvalidHeader)?; + (header, input) + } else { + ([0u8; 0].as_slice(), input) + }; + + // Validate that next field is id + let (_, input) = + read_key_eq(input, "id").map_err(|_| UntypedRequestParseError::InvalidIdKey)?; // Get the id itself - let (input, id) = - parse_msg_pack_str(input).map_err(|_| UntypedRequestParseError::InvalidId)?; + let (id, input) = read_str_bytes(input).map_err(|_| UntypedRequestParseError::InvalidId)?; - // Validate that second field is payload - let (input, payload_key) = - parse_msg_pack_str(input).map_err(|_| UntypedRequestParseError::WrongType)?; - if payload_key != "payload" { - return Err(UntypedRequestParseError::WrongType); - } + // Validate that final field is payload + let (_, input) = read_key_eq(input, "payload") + .map_err(|_| UntypedRequestParseError::InvalidPayloadKey)?; + let header = Cow::Borrowed(header); let id = Cow::Borrowed(id); let payload = Cow::Borrowed(input); - Ok(Self { id, payload }) + Ok(Self { + header, + id, + payload, + }) } } @@ -198,18 +267,33 @@ mod tests { const TRUE_BYTE: u8 = 0xc3; const NEVER_USED_BYTE: u8 = 0xc1; + // fixstr of 6 bytes with str "header" + const HEADER_FIELD_BYTES: &[u8] = &[0xa6, b'h', b'e', b'a', b'd', b'e', b'r']; + + // fixmap of 2 objects with + // 1. key fixstr "key" and value fixstr "value" + // 1. key fixstr "num" and value fixint 123 + const HEADER_BYTES: &[u8] = &[ + 0x82, // valid map with 2 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + 0xa3, b'n', b'u', b'm', // key: "num" + 0x7b, // value: 123 + ]; + // fixstr of 2 bytes with str "id" - const ID_FIELD_BYTES: &[u8] = &[0xa2, 0x69, 0x64]; + const ID_FIELD_BYTES: &[u8] = &[0xa2, b'i', b'd']; // fixstr of 7 bytes with str "payload" - const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64]; + const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, b'p', b'a', b'y', b'l', b'o', b'a', b'd']; - /// fixstr of 4 bytes with str "test" - const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74]; + // fixstr of 4 bytes with str "test" + const TEST_STR_BYTES: &[u8] = &[0xa4, b't', b'e', b's', b't']; #[test] fn untyped_request_should_support_converting_to_bytes() { let bytes = Request { + header: header!(), id: "some id".to_string(), payload: true, } @@ -220,9 +304,44 @@ mod tests { assert_eq!(untyped_request.to_bytes(), bytes); } + #[test] + fn untyped_request_should_support_converting_to_bytes_with_header() { + let bytes = Request { + header: header!("key" -> 123), + id: "some id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + let untyped_request = UntypedRequest::from_slice(&bytes).unwrap(); + assert_eq!(untyped_request.to_bytes(), bytes); + } + + #[test] + fn untyped_request_should_support_parsing_from_request_bytes_with_header() { + let bytes = Request { + header: header!("key" -> 123), + id: "some id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + assert_eq!( + UntypedRequest::from_slice(&bytes), + Ok(UntypedRequest { + header: Cow::Owned(utils::serialize_to_vec(&header!("key" -> 123)).unwrap()), + id: Cow::Borrowed("some id"), + payload: Cow::Owned(vec![TRUE_BYTE]), + }) + ); + } + #[test] fn untyped_request_should_support_parsing_from_request_bytes_with_valid_payload() { let bytes = Request { + header: header!(), id: "some id".to_string(), payload: true, } @@ -232,6 +351,7 @@ mod tests { assert_eq!( UntypedRequest::from_slice(&bytes), Ok(UntypedRequest { + header: Cow::Owned(vec![]), id: Cow::Borrowed("some id"), payload: Cow::Owned(vec![TRUE_BYTE]), }) @@ -242,6 +362,7 @@ mod tests { fn untyped_request_should_support_parsing_from_request_bytes_with_invalid_payload() { // Request with id < 32 bytes let mut bytes = Request { + header: header!(), id: "".to_string(), payload: true, } @@ -255,12 +376,35 @@ mod tests { assert_eq!( UntypedRequest::from_slice(&bytes), Ok(UntypedRequest { + header: Cow::Owned(vec![]), id: Cow::Owned("".to_string()), payload: Cow::Owned(vec![TRUE_BYTE, NEVER_USED_BYTE]), }) ); } + #[test] + fn untyped_request_should_support_parsing_full_request() { + let input = [ + &[0x83], + HEADER_FIELD_BYTES, + HEADER_BYTES, + ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat(); + + // Convert into typed so we can test + let untyped_request = UntypedRequest::from_slice(&input).unwrap(); + let request: Request = untyped_request.to_typed_request().unwrap(); + + assert_eq!(request.header, header!("key" -> "value", "num" -> 123)); + assert_eq!(request.id, "test"); + assert!(request.payload); + } + #[test] fn untyped_request_should_fail_to_parse_if_given_bytes_not_representing_a_request() { // Empty byte slice @@ -281,10 +425,46 @@ mod tests { Err(UntypedRequestParseError::WrongType) ); + // Invalid header key + assert_eq!( + UntypedRequest::from_slice( + [ + &[0x83], + &[0xa0], // header key would be defined here, set to empty str + HEADER_BYTES, + ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedRequestParseError::InvalidHeaderKey) + ); + + // Invalid header bytes + assert_eq!( + UntypedRequest::from_slice( + [ + &[0x83], + HEADER_FIELD_BYTES, + &[0xa0], // header would be defined here, set to empty str + ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedRequestParseError::InvalidHeader) + ); + // Missing fields (corrupt data) assert_eq!( UntypedRequest::from_slice(&[0x82]), - Err(UntypedRequestParseError::WrongType) + Err(UntypedRequestParseError::InvalidIdKey) ); // Missing id field (has valid data itself) @@ -300,7 +480,7 @@ mod tests { .concat() .as_slice() ), - Err(UntypedRequestParseError::WrongType) + Err(UntypedRequestParseError::InvalidIdKey) ); // Non-str id field value @@ -348,7 +528,7 @@ mod tests { .concat() .as_slice() ), - Err(UntypedRequestParseError::WrongType) + Err(UntypedRequestParseError::InvalidPayloadKey) ); } } diff --git a/distant-net/src/common/packet/response.rs b/distant-net/src/common/packet/response.rs index cc96e96..6056fe6 100644 --- a/distant-net/src/common/packet/response.rs +++ b/distant-net/src/common/packet/response.rs @@ -5,12 +5,17 @@ use derive_more::{Display, Error}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use super::{parse_msg_pack_str, write_str_msg_pack, Id}; +use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id}; use crate::common::utils; +use crate::header; /// Represents a response received related to some response -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Response { + /// Optional header data to include with response + #[serde(default, skip_serializing_if = "Header::is_empty")] + pub header: Header, + /// Unique id associated with the response pub id: Id, @@ -22,9 +27,10 @@ pub struct Response { } impl Response { - /// Creates a new response with a random, unique id + /// Creates a new response with a random, unique id and no header data pub fn new(origin_id: Id, payload: T) -> Self { Self { + header: header!(), id: rand::random::().to_string(), origin_id, payload, @@ -49,6 +55,11 @@ where /// Attempts to convert a typed response to an untyped response pub fn to_untyped_response(&self) -> io::Result { Ok(UntypedResponse { + header: Cow::Owned(if !self.header.is_empty() { + utils::serialize_to_vec(&self.header)? + } else { + Vec::new() + }), id: Cow::Borrowed(&self.id), origin_id: Cow::Borrowed(&self.origin_id), payload: Cow::Owned(self.to_payload_vec()?), @@ -72,16 +83,40 @@ pub enum UntypedResponseParseError { /// When the bytes do not represent a response WrongType, + /// When a header should be present, but the key is wrong + InvalidHeaderKey, + + /// When a header should be present, but the header bytes are wrong + InvalidHeader, + + /// When the key for the id is wrong + InvalidIdKey, + /// When the id is not a valid UTF-8 string InvalidId, + /// When the key for the origin id is wrong + InvalidOriginIdKey, + /// When the origin id is not a valid UTF-8 string InvalidOriginId, + + /// When the key for the payload is wrong + InvalidPayloadKey, +} + +#[inline] +fn header_is_empty(header: &[u8]) -> bool { + header.is_empty() } /// Represents a response to send whose payload is bytes instead of a specific type #[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] pub struct UntypedResponse<'a> { + /// Header data associated with the response as bytes + #[serde(default, skip_serializing_if = "header_is_empty")] + pub header: Cow<'a, [u8]>, + /// Unique id associated with the response pub id: Cow<'a, str>, @@ -93,9 +128,14 @@ pub struct UntypedResponse<'a> { } impl<'a> UntypedResponse<'a> { - /// Attempts to convert an untyped request to a typed request + /// Attempts to convert an untyped response to a typed response pub fn to_typed_response(&self) -> io::Result> { Ok(Response { + header: if header_is_empty(&self.header) { + header!() + } else { + utils::deserialize_from_slice(&self.header)? + }, id: self.id.to_string(), origin_id: self.origin_id.to_string(), payload: utils::deserialize_from_slice(&self.payload)?, @@ -105,6 +145,10 @@ impl<'a> UntypedResponse<'a> { /// Convert into a borrowed version pub fn as_borrowed(&self) -> UntypedResponse<'_> { UntypedResponse { + header: match &self.header { + Cow::Borrowed(x) => Cow::Borrowed(x), + Cow::Owned(x) => Cow::Borrowed(x.as_slice()), + }, id: match &self.id { Cow::Borrowed(x) => Cow::Borrowed(x), Cow::Owned(x) => Cow::Borrowed(x.as_str()), @@ -123,6 +167,10 @@ impl<'a> UntypedResponse<'a> { /// Convert into an owned version pub fn into_owned(self) -> UntypedResponse<'static> { UntypedResponse { + header: match self.header { + Cow::Borrowed(x) => Cow::Owned(x.to_vec()), + Cow::Owned(x) => Cow::Owned(x), + }, id: match self.id { Cow::Borrowed(x) => Cow::Owned(x.to_string()), Cow::Owned(x) => Cow::Owned(x), @@ -138,6 +186,11 @@ impl<'a> UntypedResponse<'a> { } } + /// Updates the header of the response to the given `header`. + pub fn set_header(&mut self, header: impl IntoIterator) { + self.header = Cow::Owned(header.into_iter().collect()); + } + /// Updates the id of the response to the given `id`. pub fn set_id(&mut self, id: impl Into) { self.id = Cow::Owned(id.into()); @@ -150,76 +203,90 @@ impl<'a> UntypedResponse<'a> { /// Allocates a new collection of bytes representing the response. pub fn to_bytes(&self) -> Vec { - let mut bytes = vec![0x83]; + let mut bytes = vec![]; - write_str_msg_pack("id", &mut bytes); - write_str_msg_pack(&self.id, &mut bytes); + let has_header = !header_is_empty(&self.header); + if has_header { + rmp::encode::write_map_len(&mut bytes, 4).unwrap(); + } else { + rmp::encode::write_map_len(&mut bytes, 3).unwrap(); + } - write_str_msg_pack("origin_id", &mut bytes); - write_str_msg_pack(&self.origin_id, &mut bytes); + if has_header { + rmp::encode::write_str(&mut bytes, "header").unwrap(); + bytes.extend_from_slice(&self.header); + } - write_str_msg_pack("payload", &mut bytes); + rmp::encode::write_str(&mut bytes, "id").unwrap(); + rmp::encode::write_str(&mut bytes, &self.id).unwrap(); + + rmp::encode::write_str(&mut bytes, "origin_id").unwrap(); + rmp::encode::write_str(&mut bytes, &self.origin_id).unwrap(); + + rmp::encode::write_str(&mut bytes, "payload").unwrap(); bytes.extend_from_slice(&self.payload); bytes } /// Parses a collection of bytes, returning an untyped response if it can be potentially - /// represented as a [`Response`] depending on the payload, or the original bytes if it does not - /// represent a [`Response`]. + /// represented as a [`Response`] depending on the payload. /// /// NOTE: This supports parsing an invalid response where the payload would not properly /// deserialize, but the bytes themselves represent a complete response of some kind. pub fn from_slice(input: &'a [u8]) -> Result { - if input.len() < 2 { + if input.is_empty() { return Err(UntypedResponseParseError::WrongType); } - // MsgPack marks a fixmap using 0x80 - 0x8f to indicate the size (up to 15 elements). - // - // In the case of the request, there are only three elements: id, origin_id, and payload. - // So the first byte should ALWAYS be 0x83 (131). - if input[0] != 0x83 { - return Err(UntypedResponseParseError::WrongType); - } + let has_header = match rmp::Marker::from_u8(input[0]) { + rmp::Marker::FixMap(3) => false, + rmp::Marker::FixMap(4) => true, + _ => return Err(UntypedResponseParseError::WrongType), + }; - // Skip the first byte representing the fixmap + // Advance position by marker let input = &input[1..]; - // Validate that first field is id - let (input, id_key) = - parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::WrongType)?; - if id_key != "id" { - return Err(UntypedResponseParseError::WrongType); - } + // Parse the header if we have one + let (header, input) = if has_header { + let (_, input) = read_key_eq(input, "header") + .map_err(|_| UntypedResponseParseError::InvalidHeaderKey)?; + + let (header, input) = + read_header_bytes(input).map_err(|_| UntypedResponseParseError::InvalidHeader)?; + (header, input) + } else { + ([0u8; 0].as_slice(), input) + }; + + // Validate that next field is id + let (_, input) = + read_key_eq(input, "id").map_err(|_| UntypedResponseParseError::InvalidIdKey)?; // Get the id itself - let (input, id) = - parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::InvalidId)?; + let (id, input) = + read_str_bytes(input).map_err(|_| UntypedResponseParseError::InvalidId)?; - // Validate that second field is origin_id - let (input, origin_id_key) = - parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::WrongType)?; - if origin_id_key != "origin_id" { - return Err(UntypedResponseParseError::WrongType); - } + // Validate that next field is origin_id + let (_, input) = read_key_eq(input, "origin_id") + .map_err(|_| UntypedResponseParseError::InvalidOriginIdKey)?; // Get the origin_id itself - let (input, origin_id) = - parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::InvalidOriginId)?; + let (origin_id, input) = + read_str_bytes(input).map_err(|_| UntypedResponseParseError::InvalidOriginId)?; - // Validate that second field is payload - let (input, payload_key) = - parse_msg_pack_str(input).map_err(|_| UntypedResponseParseError::WrongType)?; - if payload_key != "payload" { - return Err(UntypedResponseParseError::WrongType); - } + // Validate that final field is payload + let (_, input) = read_key_eq(input, "payload") + .map_err(|_| UntypedResponseParseError::InvalidPayloadKey)?; + let header = Cow::Borrowed(header); let id = Cow::Borrowed(id); let origin_id = Cow::Borrowed(origin_id); let payload = Cow::Borrowed(input); Ok(Self { + header, id, origin_id, payload, @@ -236,22 +303,52 @@ mod tests { const TRUE_BYTE: u8 = 0xc3; const NEVER_USED_BYTE: u8 = 0xc1; + // fixstr of 6 bytes with str "header" + const HEADER_FIELD_BYTES: &[u8] = &[0xa6, b'h', b'e', b'a', b'd', b'e', b'r']; + + // fixmap of 2 objects with + // 1. key fixstr "key" and value fixstr "value" + // 1. key fixstr "num" and value fixint 123 + const HEADER_BYTES: &[u8] = &[ + 0x82, // valid map with 2 pair + 0xa3, b'k', b'e', b'y', // key: "key" + 0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value" + 0xa3, b'n', b'u', b'm', // key: "num" + 0x7b, // value: 123 + ]; + // fixstr of 2 bytes with str "id" - const ID_FIELD_BYTES: &[u8] = &[0xa2, 0x69, 0x64]; + const ID_FIELD_BYTES: &[u8] = &[0xa2, b'i', b'd']; // fixstr of 9 bytes with str "origin_id" const ORIGIN_ID_FIELD_BYTES: &[u8] = &[0xa9, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x5f, 0x69, 0x64]; // fixstr of 7 bytes with str "payload" - const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64]; + const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, b'p', b'a', b'y', b'l', b'o', b'a', b'd']; /// fixstr of 4 bytes with str "test" - const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74]; + const TEST_STR_BYTES: &[u8] = &[0xa4, b't', b'e', b's', b't']; #[test] fn untyped_response_should_support_converting_to_bytes() { let bytes = Response { + header: header!(), + id: "some id".to_string(), + origin_id: "some origin id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + let untyped_response = UntypedResponse::from_slice(&bytes).unwrap(); + assert_eq!(untyped_response.to_bytes(), bytes); + } + + #[test] + fn untyped_response_should_support_converting_to_bytes_with_header() { + let bytes = Response { + header: header!("key" -> 123), id: "some id".to_string(), origin_id: "some origin id".to_string(), payload: true, @@ -263,9 +360,32 @@ mod tests { assert_eq!(untyped_response.to_bytes(), bytes); } + #[test] + fn untyped_response_should_support_parsing_from_response_bytes_with_header() { + let bytes = Response { + header: header!("key" -> 123), + id: "some id".to_string(), + origin_id: "some origin id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + assert_eq!( + UntypedResponse::from_slice(&bytes), + Ok(UntypedResponse { + header: Cow::Owned(utils::serialize_to_vec(&header!("key" -> 123)).unwrap()), + id: Cow::Borrowed("some id"), + origin_id: Cow::Borrowed("some origin id"), + payload: Cow::Owned(vec![TRUE_BYTE]), + }) + ); + } + #[test] fn untyped_response_should_support_parsing_from_response_bytes_with_valid_payload() { let bytes = Response { + header: header!(), id: "some id".to_string(), origin_id: "some origin id".to_string(), payload: true, @@ -276,6 +396,7 @@ mod tests { assert_eq!( UntypedResponse::from_slice(&bytes), Ok(UntypedResponse { + header: Cow::Owned(vec![]), id: Cow::Borrowed("some id"), origin_id: Cow::Borrowed("some origin id"), payload: Cow::Owned(vec![TRUE_BYTE]), @@ -287,6 +408,7 @@ mod tests { fn untyped_response_should_support_parsing_from_response_bytes_with_invalid_payload() { // Response with id < 32 bytes let mut bytes = Response { + header: header!(), id: "".to_string(), origin_id: "".to_string(), payload: true, @@ -301,6 +423,7 @@ mod tests { assert_eq!( UntypedResponse::from_slice(&bytes), Ok(UntypedResponse { + header: Cow::Owned(vec![]), id: Cow::Owned("".to_string()), origin_id: Cow::Owned("".to_string()), payload: Cow::Owned(vec![TRUE_BYTE, NEVER_USED_BYTE]), @@ -308,6 +431,31 @@ mod tests { ); } + #[test] + fn untyped_response_should_support_parsing_full_request() { + let input = [ + &[0x84], + HEADER_FIELD_BYTES, + HEADER_BYTES, + ID_FIELD_BYTES, + TEST_STR_BYTES, + ORIGIN_ID_FIELD_BYTES, + &[0xa2, b'o', b'g'], + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat(); + + // Convert into typed so we can test + let untyped_response = UntypedResponse::from_slice(&input).unwrap(); + let response: Response = untyped_response.to_typed_response().unwrap(); + + assert_eq!(response.header, header!("key" -> "value", "num" -> 123)); + assert_eq!(response.id, "test"); + assert_eq!(response.origin_id, "og"); + assert!(response.payload); + } + #[test] fn untyped_response_should_fail_to_parse_if_given_bytes_not_representing_a_response() { // Empty byte slice @@ -328,10 +476,50 @@ mod tests { Err(UntypedResponseParseError::WrongType) ); + // Invalid header key + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x84], + &[0xa0], // header key would be defined here, set to empty str + HEADER_BYTES, + ID_FIELD_BYTES, + TEST_STR_BYTES, + ORIGIN_ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::InvalidHeaderKey) + ); + + // Invalid header bytes + assert_eq!( + UntypedResponse::from_slice( + [ + &[0x84], + HEADER_FIELD_BYTES, + &[0xa0], // header would be defined here, set to empty str + ID_FIELD_BYTES, + TEST_STR_BYTES, + ORIGIN_ID_FIELD_BYTES, + TEST_STR_BYTES, + PAYLOAD_FIELD_BYTES, + &[TRUE_BYTE], + ] + .concat() + .as_slice() + ), + Err(UntypedResponseParseError::InvalidHeader) + ); + // Missing fields (corrupt data) assert_eq!( UntypedResponse::from_slice(&[0x83]), - Err(UntypedResponseParseError::WrongType) + Err(UntypedResponseParseError::InvalidIdKey) ); // Missing id field (has valid data itself) @@ -349,7 +537,7 @@ mod tests { .concat() .as_slice() ), - Err(UntypedResponseParseError::WrongType) + Err(UntypedResponseParseError::InvalidIdKey) ); // Non-str id field value @@ -403,7 +591,7 @@ mod tests { .concat() .as_slice() ), - Err(UntypedResponseParseError::WrongType) + Err(UntypedResponseParseError::InvalidOriginIdKey) ); // Non-str origin_id field value @@ -457,7 +645,7 @@ mod tests { .concat() .as_slice() ), - Err(UntypedResponseParseError::WrongType) + Err(UntypedResponseParseError::InvalidPayloadKey) ); } } diff --git a/distant-protocol/src/response.rs b/distant-protocol/src/response.rs index 8470ce1..c73812e 100644 --- a/distant-protocol/src/response.rs +++ b/distant-protocol/src/response.rs @@ -631,7 +631,7 @@ mod tests { value, serde_json::json!({ "type": "changed", - "ts": u64::MAX, + "timestamp": u64::MAX, "kind": "access", "path": "path", }) @@ -657,13 +657,13 @@ mod tests { value, serde_json::json!({ "type": "changed", - "ts": u64::MAX, + "timestamp": u64::MAX, "kind": "access", "path": "path", "details": { "attribute": "permissions", "renamed": "renamed", - "ts": u64::MAX, + "timestamp": u64::MAX, "extra": "info", }, }) @@ -674,7 +674,7 @@ mod tests { fn should_be_able_to_deserialize_minimal_payload_from_json() { let value = serde_json::json!({ "type": "changed", - "ts": u64::MAX, + "timestamp": u64::MAX, "kind": "access", "path": "path", }); @@ -695,13 +695,13 @@ mod tests { fn should_be_able_to_deserialize_full_payload_from_json() { let value = serde_json::json!({ "type": "changed", - "ts": u64::MAX, + "timestamp": u64::MAX, "kind": "access", "path": "path", "details": { "attribute": "permissions", "renamed": "renamed", - "ts": u64::MAX, + "timestamp": u64::MAX, "extra": "info", }, });