From 2abbed2014a8cf45de8b48490efeca1a286e3359 Mon Sep 17 00:00:00 2001 From: Ronald Holshausen Date: Mon, 11 Nov 2024 14:15:12 +1100 Subject: [PATCH] refactor: Write out the additional field values when serialising a field #73 --- src/dynamic_message.rs | 517 +++++++++++++++++++++++++++++++++---- src/message_decoder/mod.rs | 2 +- 2 files changed, 461 insertions(+), 58 deletions(-) diff --git a/src/dynamic_message.rs b/src/dynamic_message.rs index d97cf8b..6ab0cfd 100644 --- a/src/dynamic_message.rs +++ b/src/dynamic_message.rs @@ -100,68 +100,83 @@ impl DynamicMessage { for (field_num, values) in self.fields.iter() .sorted_by(|(a, _), (b, _)| Ord::cmp(a, b)) { for field in values { - trace!(%field_num, field = field.to_string().as_str(), "Writing"); - encode_key(field.field_num, field.wire_type, buffer); - match field.wire_type { - WireType::Varint => match &field.data { - ProtobufFieldData::Boolean(b) => encode_varint(*b as u64, buffer), - ProtobufFieldData::UInteger32(n) => encode_varint(*n as u64, buffer), - ProtobufFieldData::Integer32(n) => encode_varint(*n as u64, buffer), - ProtobufFieldData::UInteger64(n) => encode_varint(*n, buffer), - ProtobufFieldData::Integer64(n) => encode_varint(*n as u64, buffer), - ProtobufFieldData::Enum(n, _) => encode_varint(*n as u64, buffer), - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); - }, - _ => return Err(anyhow!("Expected a varint, but field is {}", field.data)) - }, - WireType::SixtyFourBit => match &field.data { - ProtobufFieldData::UInteger64(n) => buffer.put_u64_le(*n), - ProtobufFieldData::Integer64(n) => buffer.put_i64_le(*n), - ProtobufFieldData::Double(n) => buffer.put_f64_le(*n), - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); - } - _ => return Err(anyhow!("Expected a 64 bit value, but field is {}", field.data)) - } - WireType::LengthDelimited => match &field.data { - ProtobufFieldData::String(s) => { - encode_varint(s.len() as u64, buffer); - buffer.put_slice(s.as_bytes()); - } - ProtobufFieldData::Bytes(b) => { - encode_varint(b.len() as u64, buffer); - buffer.put_slice(b.as_slice()); - } - ProtobufFieldData::Message(m, _) => { - encode_varint(m.len() as u64, buffer); - buffer.put_slice(m.as_slice()); - } - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); - }, - _ => return Err(anyhow!("Expected a length delimited value, but field is {}", field.data)) + Self::write_field(buffer, *field_num, field, &field.data)?; + if field.repeated_field() && !field.additional_data.is_empty() { + for data in &field.additional_data { + Self::write_field(buffer, *field_num, field, data)?; } - WireType::ThirtyTwoBit => match &field.data { - ProtobufFieldData::UInteger32(n) => buffer.put_u32_le(*n), - ProtobufFieldData::Integer32(n) => buffer.put_i32_le(*n), - ProtobufFieldData::Float(n) => buffer.put_f32_le(*n), - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); - }, - _ => return Err(anyhow!("Expected a 32 bit value, but field is {}", field.data)) - } - _ => return Err(anyhow!("Groups are not supported")) } } } Ok(()) } + fn write_field( + buffer: &mut B, + field_num: u32, + field: &ProtobufField, + data: &ProtobufFieldData + ) -> anyhow::Result<()> where B: BufMut { + trace!(%field_num, %field, %data, "Writing field data"); + encode_key(field.field_num, field.wire_type, buffer); + match field.wire_type { + WireType::Varint => match data { + ProtobufFieldData::Boolean(b) => encode_varint(*b as u64, buffer), + ProtobufFieldData::UInteger32(n) => encode_varint(*n as u64, buffer), + ProtobufFieldData::Integer32(n) => encode_varint(*n as u64, buffer), + ProtobufFieldData::UInteger64(n) => encode_varint(*n, buffer), + ProtobufFieldData::Integer64(n) => encode_varint(*n as u64, buffer), + ProtobufFieldData::Enum(n, _) => encode_varint(*n as u64, buffer), + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + }, + _ => return Err(anyhow!("Expected a varint, but field is {}", field.data)) + }, + WireType::SixtyFourBit => match data { + ProtobufFieldData::UInteger64(n) => buffer.put_u64_le(*n), + ProtobufFieldData::Integer64(n) => buffer.put_i64_le(*n), + ProtobufFieldData::Double(n) => buffer.put_f64_le(*n), + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + } + _ => return Err(anyhow!("Expected a 64 bit value, but field is {}", field.data)) + } + WireType::LengthDelimited => match data { + ProtobufFieldData::String(s) => { + encode_varint(s.len() as u64, buffer); + buffer.put_slice(s.as_bytes()); + } + ProtobufFieldData::Bytes(b) => { + encode_varint(b.len() as u64, buffer); + buffer.put_slice(b.as_slice()); + } + ProtobufFieldData::Message(m, _) => { + encode_varint(m.len() as u64, buffer); + buffer.put_slice(m.as_slice()); + } + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + }, + _ => return Err(anyhow!("Expected a length delimited value, but field is {}", field.data)) + } + WireType::ThirtyTwoBit => match data { + ProtobufFieldData::UInteger32(n) => buffer.put_u32_le(*n), + ProtobufFieldData::Integer32(n) => buffer.put_i32_le(*n), + ProtobufFieldData::Float(n) => buffer.put_f32_le(*n), + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + }, + _ => return Err(anyhow!("Expected a 32 bit value, but field is {}", field.data)) + } + _ => return Err(anyhow!("Groups are not supported")) + } + Ok(()) + } + /// Retrieve the value for a message field using the given path #[instrument(ret, skip(self), fields(path = %path))] pub fn fetch_field_value(&mut self, path: &DocPath) -> Option> { @@ -430,11 +445,13 @@ mod tests { use pact_models::generators::GeneratorTestMode; use pact_models::path_exp::DocPath; use pact_models::prelude::Generator::RandomInt; + use pretty_assertions::assert_eq; use prost::encoding::WireType; - use prost_types::{DescriptorProto, field_descriptor_proto, FieldDescriptorProto, FileDescriptorSet}; + use prost_types::{DescriptorProto, field_descriptor_proto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet}; use serde_json::json; + use crate::dynamic_message::DynamicMessage; - use crate::message_decoder::{ProtobufField, ProtobufFieldData}; + use crate::message_decoder::{decode_message, ProtobufField, ProtobufFieldData}; #[test] fn dynamic_message_fetch_value_with_no_fields() { @@ -793,4 +810,390 @@ mod tests { expect!(result.unwrap_err().to_string()).to(be_equal_to( "i64 can not be generated from 'sss' - invalid digit found in string")); } + + #[test] + fn dynamic_message_write_to_test() { + let field_descriptor = FieldDescriptorProto { + name: Some("one".to_string()), + number: Some(1), + r#type: Some(field_descriptor_proto::Type::Int64 as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(100), + additional_data: vec![], + descriptor: field_descriptor.clone() + }; + let descriptors = FileDescriptorSet { + file: vec![] + }; + let fields = vec![ field.clone() ]; + let descriptor = DescriptorProto { + field: vec![ + field_descriptor.clone() + ], + .. DescriptorProto::default() + }; + let message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); + + let mut buffer = BytesMut::new(); + message.write_to(&mut buffer).unwrap(); + + let result = decode_message(&mut buffer.freeze(), &descriptor, &descriptors).unwrap(); + expect!(result).to(be_equal_to(vec![ field ])); + } + + #[test] + fn dynamic_message_write_to_test_with_multiple_fields() { + let field_descriptor_1 = FieldDescriptorProto { + name: Some("one".to_string()), + number: Some(1), + r#type: Some(field_descriptor_proto::Type::Int64 as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field_1 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(100), + additional_data: vec![], + descriptor: field_descriptor_1.clone() + }; + + let field_descriptor_2 = FieldDescriptorProto { + name: Some("two".to_string()), + number: Some(2), + r#type: Some(field_descriptor_proto::Type::String as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field_2 = ProtobufField { + field_num: 2, + field_name: "two".to_string(), + wire_type: WireType::LengthDelimited, + data: ProtobufFieldData::String("test".to_string()), + additional_data: vec![], + descriptor: field_descriptor_2.clone() + }; + + let field_descriptor_3 = FieldDescriptorProto { + name: Some("three".to_string()), + number: Some(3), + r#type: Some(field_descriptor_proto::Type::Bool as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field_3 = ProtobufField { + field_num: 3, + field_name: "three".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Boolean(true), + additional_data: vec![], + descriptor: field_descriptor_3.clone() + }; + + let descriptors = FileDescriptorSet { + file: vec![] + }; + let fields = vec![ field_1.clone(), field_3.clone(), field_2.clone() ]; + let descriptor = DescriptorProto { + field: vec![ + field_descriptor_1.clone(), + field_descriptor_2.clone(), + field_descriptor_3.clone() + ], + .. DescriptorProto::default() + }; + let message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); + + let mut buffer = BytesMut::new(); + message.write_to(&mut buffer).unwrap(); + + let result = decode_message(&mut buffer.freeze(), &descriptor, &descriptors).unwrap(); + expect!(result).to(be_equal_to(vec![ field_1, field_2, field_3 ])); + } + + #[test] + fn dynamic_message_write_to_test_with_child_field() { + let child_proto_1 = FieldDescriptorProto { + name: Some("two".to_string()), + number: Some(1), + r#type: Some(3), + ..FieldDescriptorProto::default() + }; + let child_proto_2 = FieldDescriptorProto { + name: Some("three".to_string()), + number: Some(2), + r#type: Some(3), + ..FieldDescriptorProto::default() + }; + let child_descriptor = DescriptorProto { + name: Some("child".to_string()), + field: vec![ + child_proto_1.clone(), + child_proto_2.clone() + ], + .. DescriptorProto::default() + }; + let child_field = ProtobufField { + field_num: 1, + field_name: "two".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(100), + additional_data: vec![], + descriptor: child_proto_1.clone() + }; + let child_field2 = ProtobufField { + field_num: 2, + field_name: "three".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(200), + additional_data: vec![], + descriptor: child_proto_2.clone() + }; + + let field_descriptor = FieldDescriptorProto { + name: Some("one".to_string()), + number: Some(1), + r#type: Some(field_descriptor_proto::Type::Message as i32), + type_name: Some("child".to_string()), + label: None, + .. FieldDescriptorProto::default() + }; + let descriptor = DescriptorProto { + name: Some("parent".to_string()), + field: vec![ + field_descriptor.clone() + ], + .. DescriptorProto::default() + }; + let descriptors = FileDescriptorSet { + file: vec![ + FileDescriptorProto { + message_type: vec![ + descriptor.clone(), child_descriptor.clone() + ], + .. FileDescriptorProto::default() + } + ] + }; + + let child_message = DynamicMessage::new(&child_descriptor, &[child_field.clone(), child_field2], &descriptors); + let mut child_buffer = BytesMut::new(); + child_message.write_to(&mut child_buffer).unwrap(); + + let field = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::LengthDelimited, + data: ProtobufFieldData::Message(child_buffer.to_vec(), child_descriptor), + additional_data: vec![], + descriptor: field_descriptor.clone() + }; + let fields = vec![ field.clone() ]; + let message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); + + let mut buffer = BytesMut::new(); + message.write_to(&mut buffer).unwrap(); + + let result = decode_message(&mut buffer.freeze(), &descriptor, &descriptors).unwrap(); + assert_eq!(result, vec![ field ]); + } + + #[test] + fn dynamic_message_write_to_test_with_repeated_fields() { + let field_descriptor_1 = FieldDescriptorProto { + name: Some("one".to_string()), + number: Some(1), + r#type: Some(field_descriptor_proto::Type::Int64 as i32), + label: Some(field_descriptor_proto::Label::Repeated as i32), + .. FieldDescriptorProto::default() + }; + let field_1_1 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(100), + additional_data: vec![], + descriptor: field_descriptor_1.clone() + }; + let field_1_2 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(101), + additional_data: vec![], + descriptor: field_descriptor_1.clone() + }; + let field_1_3 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(102), + additional_data: vec![], + descriptor: field_descriptor_1.clone() + }; + + let field_descriptor_2 = FieldDescriptorProto { + name: Some("two".to_string()), + number: Some(2), + r#type: Some(field_descriptor_proto::Type::String as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field_2 = ProtobufField { + field_num: 2, + field_name: "two".to_string(), + wire_type: WireType::LengthDelimited, + data: ProtobufFieldData::String("test".to_string()), + additional_data: vec![], + descriptor: field_descriptor_2.clone() + }; + + let field_descriptor_3 = FieldDescriptorProto { + name: Some("three".to_string()), + number: Some(3), + r#type: Some(field_descriptor_proto::Type::Bool as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field_3 = ProtobufField { + field_num: 3, + field_name: "three".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Boolean(true), + additional_data: vec![], + descriptor: field_descriptor_3.clone() + }; + + let descriptors = FileDescriptorSet { + file: vec![] + }; + let fields = vec![ + field_1_1.clone(), + field_3.clone(), + field_1_2.clone(), + field_2.clone(), + field_1_3.clone() + ]; + let descriptor = DescriptorProto { + field: vec![ + field_descriptor_1.clone(), + field_descriptor_2.clone(), + field_descriptor_3.clone() + ], + .. DescriptorProto::default() + }; + let message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); + + let mut buffer = BytesMut::new(); + message.write_to(&mut buffer).unwrap(); + + let result = decode_message(&mut buffer.freeze(), &descriptor, &descriptors).unwrap(); + expect!(result).to(be_equal_to(vec![ field_1_1, field_1_2, field_1_3, field_2, field_3 ])); + } + + #[test] + fn dynamic_message_write_to_test_with_repeated_field_with_additional_values() { + let field_descriptor_1 = FieldDescriptorProto { + name: Some("one".to_string()), + number: Some(1), + r#type: Some(field_descriptor_proto::Type::Int64 as i32), + label: Some(field_descriptor_proto::Label::Repeated as i32), + .. FieldDescriptorProto::default() + }; + let field_1 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(100), + additional_data: vec![ ProtobufFieldData::Integer64(101), ProtobufFieldData::Integer64(102) ], + descriptor: field_descriptor_1.clone() + }; + + let field_descriptor_2 = FieldDescriptorProto { + name: Some("two".to_string()), + number: Some(2), + r#type: Some(field_descriptor_proto::Type::String as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field_2 = ProtobufField { + field_num: 2, + field_name: "two".to_string(), + wire_type: WireType::LengthDelimited, + data: ProtobufFieldData::String("test".to_string()), + additional_data: vec![], + descriptor: field_descriptor_2.clone() + }; + + let field_descriptor_3 = FieldDescriptorProto { + name: Some("three".to_string()), + number: Some(3), + r#type: Some(field_descriptor_proto::Type::Bool as i32), + label: None, + .. FieldDescriptorProto::default() + }; + let field_3 = ProtobufField { + field_num: 3, + field_name: "three".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Boolean(true), + additional_data: vec![], + descriptor: field_descriptor_3.clone() + }; + + let descriptors = FileDescriptorSet { + file: vec![] + }; + let fields = vec![ + field_1.clone(), + field_3.clone(), + field_2.clone() + ]; + let descriptor = DescriptorProto { + field: vec![ + field_descriptor_1.clone(), + field_descriptor_2.clone(), + field_descriptor_3.clone() + ], + .. DescriptorProto::default() + }; + let message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); + + let mut buffer = BytesMut::new(); + message.write_to(&mut buffer).unwrap(); + + let result = decode_message(&mut buffer.freeze(), &descriptor, &descriptors).unwrap(); + let field_1_1 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(100), + additional_data: vec![], + descriptor: field_descriptor_1.clone() + }; + let field_1_2 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(101), + additional_data: vec![], + descriptor: field_descriptor_1.clone() + }; + let field_1_3 = ProtobufField { + field_num: 1, + field_name: "one".to_string(), + wire_type: WireType::Varint, + data: ProtobufFieldData::Integer64(102), + additional_data: vec![], + descriptor: field_descriptor_1.clone() + }; + assert_eq!(result, vec![ field_1_1, field_1_2, field_1_3, field_2, field_3 ]); + } } diff --git a/src/message_decoder/mod.rs b/src/message_decoder/mod.rs index 6023a2d..25a3201 100644 --- a/src/message_decoder/mod.rs +++ b/src/message_decoder/mod.rs @@ -160,7 +160,7 @@ fn wire_type_for_field(descriptor: &FieldDescriptorProto) -> WireType { impl Display for ProtobufField { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}:({}, {:?}, {}) = {}", self.field_num, self.field_name, self.wire_type, self.data.type_name(), self.data) + write!(f, "{}:({}, {:?}, {})", self.field_num, self.field_name, self.wire_type, self.data.type_name()) } }