diff --git a/esphome/components/api/proto.cpp b/esphome/components/api/proto.cpp index f99e5b66e5..4f0d0846d7 100644 --- a/esphome/components/api/proto.cpp +++ b/esphome/components/api/proto.cpp @@ -22,7 +22,7 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size } uint32_t tag = res->as_uint32(); - uint32_t field_type = tag & 0b111; + uint32_t field_type = tag & WIRE_TYPE_MASK; uint32_t field_id = tag >> 3; ptr += consumed; @@ -33,7 +33,7 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size // Skip field data based on wire type switch (field_type) { - case 0: { // VarInt - parse and skip + case WIRE_TYPE_VARINT: { // VarInt - parse and skip res = ProtoVarInt::parse(ptr, end - ptr, &consumed); if (!res.has_value()) { return count; // Invalid data, return what we have @@ -41,7 +41,7 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size ptr += consumed; break; } - case 2: { // Length-delimited - parse length and skip data + case WIRE_TYPE_LENGTH_DELIMITED: { // Length-delimited - parse length and skip data res = ProtoVarInt::parse(ptr, end - ptr, &consumed); if (!res.has_value()) { return count; @@ -54,7 +54,7 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size ptr += field_length; break; } - case 5: { // 32-bit - skip 4 bytes + case WIRE_TYPE_FIXED32: { // 32-bit - skip 4 bytes if (ptr + 4 > end) { return count; } @@ -85,12 +85,12 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { } uint32_t tag = res->as_uint32(); - uint32_t field_type = tag & 0b111; + uint32_t field_type = tag & WIRE_TYPE_MASK; uint32_t field_id = tag >> 3; ptr += consumed; switch (field_type) { - case 0: { // VarInt + case WIRE_TYPE_VARINT: { // VarInt res = ProtoVarInt::parse(ptr, end - ptr, &consumed); if (!res.has_value()) { ESP_LOGV(TAG, "Invalid VarInt at offset %ld", (long) (ptr - buffer)); @@ -102,7 +102,7 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { ptr += consumed; break; } - case 2: { // Length-delimited + case WIRE_TYPE_LENGTH_DELIMITED: { // Length-delimited res = ProtoVarInt::parse(ptr, end - ptr, &consumed); if (!res.has_value()) { ESP_LOGV(TAG, "Invalid Length Delimited at offset %ld", (long) (ptr - buffer)); @@ -120,7 +120,7 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { ptr += field_length; break; } - case 5: { // 32-bit + case WIRE_TYPE_FIXED32: { // 32-bit if (ptr + 4 > end) { ESP_LOGV(TAG, "Out-of-bounds Fixed32-bit at offset %ld", (long) (ptr - buffer)); return; diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index 50f85fb247..e7585924a5 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -15,6 +15,13 @@ namespace esphome::api { +// Protocol Buffer wire type constants +// See https://protobuf.dev/programming-guides/encoding/#structure +constexpr uint8_t WIRE_TYPE_VARINT = 0; // int32, int64, uint32, uint64, sint32, sint64, bool, enum +constexpr uint8_t WIRE_TYPE_LENGTH_DELIMITED = 2; // string, bytes, embedded messages, packed repeated fields +constexpr uint8_t WIRE_TYPE_FIXED32 = 5; // fixed32, sfixed32, float +constexpr uint8_t WIRE_TYPE_MASK = 0b111; // Mask to extract wire type from tag + // Helper functions for ZigZag encoding/decoding inline constexpr uint32_t encode_zigzag32(int32_t value) { return (static_cast(value) << 1) ^ (static_cast(value >> 31)); @@ -241,7 +248,7 @@ class ProtoWriteBuffer { * Following https://protobuf.dev/programming-guides/encoding/#structure */ void encode_field_raw(uint32_t field_id, uint32_t type) { - uint32_t val = (field_id << 3) | (type & 0b111); + uint32_t val = (field_id << 3) | (type & WIRE_TYPE_MASK); this->encode_varint_raw(val); } void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) { @@ -493,7 +500,7 @@ class ProtoSize { * @return The number of bytes needed to encode the field ID and wire type */ static constexpr uint32_t field(uint32_t field_id, uint32_t type) { - uint32_t tag = (field_id << 3) | (type & 0b111); + uint32_t tag = (field_id << 3) | (type & WIRE_TYPE_MASK); return varint(tag); } diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index f58442ff01..4936434fc2 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -11,6 +11,7 @@ from typing import Any import aioesphomeapi.api_options_pb2 as pb import google.protobuf.descriptor_pb2 as descriptor +from google.protobuf.descriptor_pb2 import FieldDescriptorProto class WireType(IntEnum): @@ -148,7 +149,7 @@ class TypeInfo(ABC): @property def repeated(self) -> bool: """Check if the field is repeated.""" - return self._field.label == 3 + return self._field.label == FieldDescriptorProto.LABEL_REPEATED @property def wire_type(self) -> WireType: @@ -337,7 +338,7 @@ def create_field_type_info( needs_encode: bool = True, ) -> TypeInfo: """Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options.""" - if field.label == 3: # repeated + if field.label == FieldDescriptorProto.LABEL_REPEATED: # Check if this repeated field has fixed_array_with_length_define option if ( fixed_size := get_field_opt(field, pb.fixed_array_with_length_define) @@ -1890,7 +1891,7 @@ def build_message_type( # Validate that fixed_array_size is only used in encode-only messages if ( needs_decode - and field.label == 3 + and field.label == FieldDescriptorProto.LABEL_REPEATED and get_field_opt(field, pb.fixed_array_size) is not None ): raise ValueError( @@ -1903,7 +1904,7 @@ def build_message_type( # Validate that fixed_array_with_length_define is only used in encode-only messages if ( needs_decode - and field.label == 3 + and field.label == FieldDescriptorProto.LABEL_REPEATED and get_field_opt(field, pb.fixed_array_with_length_define) is not None ): raise ValueError( @@ -1916,7 +1917,7 @@ def build_message_type( # Collect fixed_vector repeated fields for custom decode generation if ( needs_decode - and field.label == 3 + and field.label == FieldDescriptorProto.LABEL_REPEATED and get_field_opt(field, pb.fixed_vector, False) ): fixed_vector_fields.append((field.name, field.number))