diff --git a/esphome/components/api/api.proto b/esphome/components/api/api.proto index c7d8fb28f0..f7c51b8e97 100644 --- a/esphome/components/api/api.proto +++ b/esphome/components/api/api.proto @@ -876,10 +876,10 @@ message ExecuteServiceArgument { string string_ = 4; // ESPHome 1.14 (api v1.3) make int a signed value sint32 int_ = 5; - repeated bool bool_array = 6 [packed=false]; - repeated sint32 int_array = 7 [packed=false]; - repeated float float_array = 8 [packed=false]; - repeated string string_array = 9; + repeated bool bool_array = 6 [packed=false, (fixed_vector) = true]; + repeated sint32 int_array = 7 [packed=false, (fixed_vector) = true]; + repeated float float_array = 8 [packed=false, (fixed_vector) = true]; + repeated string string_array = 9 [(fixed_vector) = true]; } message ExecuteServiceRequest { option (id) = 42; @@ -888,7 +888,7 @@ message ExecuteServiceRequest { option (ifdef) = "USE_API_SERVICES"; fixed32 key = 1; - repeated ExecuteServiceArgument args = 2; + repeated ExecuteServiceArgument args = 2 [(fixed_vector) = true]; } // ==================== CAMERA ==================== diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp index 70bcf082a6..12b0bf6c98 100644 --- a/esphome/components/api/api_pb2.cpp +++ b/esphome/components/api/api_pb2.cpp @@ -1064,6 +1064,17 @@ bool ExecuteServiceArgument::decode_32bit(uint32_t field_id, Proto32Bit value) { } return true; } +void ExecuteServiceArgument::decode(const uint8_t *buffer, size_t length) { + uint32_t count_bool_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 6); + this->bool_array.init(count_bool_array); + uint32_t count_int_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 7); + this->int_array.init(count_int_array); + uint32_t count_float_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 8); + this->float_array.init(count_float_array); + uint32_t count_string_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 9); + this->string_array.init(count_string_array); + ProtoDecodableMessage::decode(buffer, length); +} bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { case 2: @@ -1085,6 +1096,11 @@ bool ExecuteServiceRequest::decode_32bit(uint32_t field_id, Proto32Bit value) { } return true; } +void ExecuteServiceRequest::decode(const uint8_t *buffer, size_t length) { + uint32_t count_args = ProtoDecodableMessage::count_repeated_field(buffer, length, 2); + this->args.init(count_args); + ProtoDecodableMessage::decode(buffer, length); +} #endif #ifdef USE_CAMERA void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const { diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h index 20866850a9..5433496d90 100644 --- a/esphome/components/api/api_pb2.h +++ b/esphome/components/api/api_pb2.h @@ -1279,10 +1279,11 @@ class ExecuteServiceArgument final : public ProtoDecodableMessage { float float_{0.0f}; std::string string_{}; int32_t int_{0}; - std::vector bool_array{}; - std::vector int_array{}; - std::vector float_array{}; - std::vector string_array{}; + FixedVector bool_array{}; + FixedVector int_array{}; + FixedVector float_array{}; + FixedVector string_array{}; + void decode(const uint8_t *buffer, size_t length) override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1300,7 +1301,8 @@ class ExecuteServiceRequest final : public ProtoDecodableMessage { const char *message_name() const override { return "execute_service_request"; } #endif uint32_t key{0}; - std::vector args{}; + FixedVector args{}; + void decode(const uint8_t *buffer, size_t length) override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif diff --git a/esphome/components/api/proto.cpp b/esphome/components/api/proto.cpp index afda5d32ba..f99e5b66e5 100644 --- a/esphome/components/api/proto.cpp +++ b/esphome/components/api/proto.cpp @@ -7,6 +7,69 @@ namespace esphome::api { static const char *const TAG = "api.proto"; +uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) { + uint32_t count = 0; + const uint8_t *ptr = buffer; + const uint8_t *end = buffer + length; + + while (ptr < end) { + uint32_t consumed; + + // Parse field header (tag) + auto res = ProtoVarInt::parse(ptr, end - ptr, &consumed); + if (!res.has_value()) { + break; // Invalid data, stop counting + } + + uint32_t tag = res->as_uint32(); + uint32_t field_type = tag & 0b111; + uint32_t field_id = tag >> 3; + ptr += consumed; + + // Count if this is the target field + if (field_id == target_field_id) { + count++; + } + + // Skip field data based on wire type + switch (field_type) { + case 0: { // VarInt - parse and skip + res = ProtoVarInt::parse(ptr, end - ptr, &consumed); + if (!res.has_value()) { + return count; // Invalid data, return what we have + } + ptr += consumed; + break; + } + case 2: { // Length-delimited - parse length and skip data + res = ProtoVarInt::parse(ptr, end - ptr, &consumed); + if (!res.has_value()) { + return count; + } + uint32_t field_length = res->as_uint32(); + ptr += consumed; + if (ptr + field_length > end) { + return count; // Out of bounds + } + ptr += field_length; + break; + } + case 5: { // 32-bit - skip 4 bytes + if (ptr + 4 > end) { + return count; + } + ptr += 4; + break; + } + default: + // Unknown wire type, can't continue + return count; + } + } + + return count; +} + void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { const uint8_t *ptr = buffer; const uint8_t *end = buffer + length; diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index a6a09bf7c5..50f85fb247 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -354,7 +354,18 @@ class ProtoMessage { // Base class for messages that support decoding class ProtoDecodableMessage : public ProtoMessage { public: - void decode(const uint8_t *buffer, size_t length); + virtual void decode(const uint8_t *buffer, size_t length); + + /** + * Count occurrences of a repeated field in a protobuf buffer. + * This is a lightweight scan that only parses tags and skips field data. + * + * @param buffer Pointer to the protobuf buffer + * @param length Length of the buffer in bytes + * @param target_field_id The field ID to count + * @return Number of times the field appears in the buffer + */ + static uint32_t count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id); protected: virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; } diff --git a/esphome/components/api/user_services.cpp b/esphome/components/api/user_services.cpp index 27b30eb332..3cbf2ab5f9 100644 --- a/esphome/components/api/user_services.cpp +++ b/esphome/components/api/user_services.cpp @@ -12,16 +12,16 @@ template<> int32_t get_execute_arg_value(const ExecuteServiceArgument & template<> float get_execute_arg_value(const ExecuteServiceArgument &arg) { return arg.float_; } template<> std::string get_execute_arg_value(const ExecuteServiceArgument &arg) { return arg.string_; } template<> std::vector get_execute_arg_value>(const ExecuteServiceArgument &arg) { - return arg.bool_array; + return std::vector(arg.bool_array.begin(), arg.bool_array.end()); } template<> std::vector get_execute_arg_value>(const ExecuteServiceArgument &arg) { - return arg.int_array; + return std::vector(arg.int_array.begin(), arg.int_array.end()); } template<> std::vector get_execute_arg_value>(const ExecuteServiceArgument &arg) { - return arg.float_array; + return std::vector(arg.float_array.begin(), arg.float_array.end()); } template<> std::vector get_execute_arg_value>(const ExecuteServiceArgument &arg) { - return arg.string_array; + return std::vector(arg.string_array.begin(), arg.string_array.end()); } template<> enums::ServiceArgType to_service_arg_type() { return enums::SERVICE_ARG_TYPE_BOOL; } diff --git a/esphome/components/api/user_services.h b/esphome/components/api/user_services.h index 29843a2f78..9ca5e1093e 100644 --- a/esphome/components/api/user_services.h +++ b/esphome/components/api/user_services.h @@ -55,7 +55,7 @@ template class UserServiceBase : public UserServiceDescriptor { protected: virtual void execute(Ts... x) = 0; - template void execute_(const std::vector &args, seq type) { + template void execute_(const ArgsContainer &args, seq type) { this->execute((get_execute_arg_value(args[S]))...); } diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 9a55f1d136..f58442ff01 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -1879,6 +1879,9 @@ def build_message_type( ) public_content.append("#endif") + # Collect fixed_vector fields for custom decode generation + fixed_vector_fields = [] + for field in desc.field: # Skip deprecated fields completely if field.options.deprecated: @@ -1910,6 +1913,14 @@ def build_message_type( f"since we cannot trust or control the number of items received from clients." ) + # Collect fixed_vector repeated fields for custom decode generation + if ( + needs_decode + and field.label == 3 + and get_field_opt(field, pb.fixed_vector, False) + ): + fixed_vector_fields.append((field.name, field.number)) + ti = create_field_type_info(field, needs_decode, needs_encode) # Skip field declarations for fields that are in the base class @@ -2018,6 +2029,22 @@ def build_message_type( prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;" protected_content.insert(0, prot) + # Generate custom decode() override for messages with FixedVector fields + if fixed_vector_fields: + # Generate the decode() implementation in cpp + o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n" + # Count and init each FixedVector field + for field_name, field_number in fixed_vector_fields: + o += f" uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n" + o += f" this->{field_name}.init(count_{field_name});\n" + # Call parent decode to populate the fields + o += " ProtoDecodableMessage::decode(buffer, length);\n" + o += "}\n" + cpp += o + # Generate the decode() declaration in header (public method) + prot = "void decode(const uint8_t *buffer, size_t length) override;" + public_content.append(prot) + # Only generate encode method if this message needs encoding and has fields if needs_encode and encode: o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"