1
0
mirror of https://github.com/esphome/esphome.git synced 2025-10-26 04:33:47 +00:00
This commit is contained in:
J. Nick Koston
2025-07-20 19:56:42 -10:00
parent 8b74333e8b
commit 04953db51e
2 changed files with 37 additions and 38 deletions

View File

@@ -991,13 +991,7 @@ class NoiseEncryptionSetKeyRequest : public ProtoDecodableMessage {
#ifdef HAS_PROTO_MESSAGE_DUMP
const char *message_name() const override { return "noise_encryption_set_key_request"; }
#endif
const uint8_t *key_ptr_{nullptr};
size_t key_len_{0};
std::string key{}; // Storage for decoded data
void set_key(const uint8_t *data, size_t len) {
this->key_ptr_ = data;
this->key_len_ = len;
}
#ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override;
#endif
@@ -1920,13 +1914,7 @@ class BluetoothGATTWriteRequest : public ProtoDecodableMessage {
uint64_t address{0};
uint32_t handle{0};
bool response{false};
const uint8_t *data_ptr_{nullptr};
size_t data_len_{0};
std::string data{}; // Storage for decoded data
void set_data(const uint8_t *data, size_t len) {
this->data_ptr_ = data;
this->data_len_ = len;
}
#ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override;
#endif
@@ -1960,13 +1948,7 @@ class BluetoothGATTWriteDescriptorRequest : public ProtoDecodableMessage {
#endif
uint64_t address{0};
uint32_t handle{0};
const uint8_t *data_ptr_{nullptr};
size_t data_len_{0};
std::string data{}; // Storage for decoded data
void set_data(const uint8_t *data, size_t len) {
this->data_ptr_ = data;
this->data_len_ = len;
}
#ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override;
#endif

View File

@@ -314,7 +314,9 @@ def validate_field_type(field_type: int, field_name: str = "") -> None:
def create_field_type_info(
field: descriptor.FieldDescriptorProto, needs_decode: bool = True
field: descriptor.FieldDescriptorProto,
needs_decode: bool = True,
needs_encode: bool = True,
) -> TypeInfo:
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
if field.label == 3: # repeated
@@ -329,7 +331,7 @@ def create_field_type_info(
# Special handling for bytes fields
if field.type == 12:
return BytesType(field, needs_decode)
return BytesType(field, needs_decode, needs_encode)
validate_field_type(field.type, field.name)
return TYPE_INFO[field.type](field)
@@ -599,25 +601,36 @@ class BytesType(TypeInfo):
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
def __init__(
self, field: descriptor.FieldDescriptorProto, needs_decode: bool = True
self,
field: descriptor.FieldDescriptorProto,
needs_decode: bool = True,
needs_encode: bool = True,
) -> None:
super().__init__(field)
self.needs_decode = needs_decode
self.needs_encode = needs_encode
@property
def public_content(self) -> list[str]:
# Store both pointer and length for zero-copy encoding, plus setter method
content = [
content = []
# Add pointer/length fields if message needs encoding
if self.needs_encode:
content.extend(
[
f"const uint8_t* {self.field_name}_ptr_{{nullptr}};",
f"size_t {self.field_name}_len_{{0}};",
]
)
# Only add storage if message needs decoding
# Add std::string storage if message needs decoding
if self.needs_decode:
content.append(
f"std::string {self.field_name}{{}}; // Storage for decoded data"
)
# Add setter method if message needs encoding
if self.needs_encode:
content.extend(
[
f"void set_{self.field_name}(const uint8_t* data, size_t len) {{",
@@ -1304,7 +1317,7 @@ def build_message_type(
if field.options.deprecated:
continue
ti = create_field_type_info(field, needs_decode)
ti = create_field_type_info(field, needs_decode, needs_encode)
# Skip field declarations for fields that are in the base class
# but include their encode/decode logic
@@ -1619,16 +1632,20 @@ def build_base_class(
public_content = []
protected_content = []
# Determine if any message using this base class needs decoding
# Determine if any message using this base class needs decoding/encoding
needs_decode = any(
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
for msg in messages
)
needs_encode = any(
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_SERVER)
for msg in messages
)
# For base classes, we only declare the fields but don't handle encode/decode
# The derived classes will handle encoding/decoding with their specific field numbers
for field in common_fields:
ti = create_field_type_info(field, needs_decode)
ti = create_field_type_info(field, needs_decode, needs_encode)
# Get field_ifdef if it's consistent across all messages
field_ifdef = get_common_field_ifdef(field.name, messages)