mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-26 04:33:47 +00:00 
			
		
		
		
	Merge branch 'execute_fixed_vector' into integration
This commit is contained in:
		| @@ -876,10 +876,10 @@ message ExecuteServiceArgument { | |||||||
|   string string_ = 4; |   string string_ = 4; | ||||||
|   // ESPHome 1.14 (api v1.3) make int a signed value |   // ESPHome 1.14 (api v1.3) make int a signed value | ||||||
|   sint32 int_ = 5; |   sint32 int_ = 5; | ||||||
|   repeated bool bool_array = 6 [packed=false]; |   repeated bool bool_array = 6 [packed=false, (fixed_vector) = true]; | ||||||
|   repeated sint32 int_array = 7 [packed=false]; |   repeated sint32 int_array = 7 [packed=false, (fixed_vector) = true]; | ||||||
|   repeated float float_array = 8 [packed=false]; |   repeated float float_array = 8 [packed=false, (fixed_vector) = true]; | ||||||
|   repeated string string_array = 9; |   repeated string string_array = 9 [(fixed_vector) = true]; | ||||||
| } | } | ||||||
| message ExecuteServiceRequest { | message ExecuteServiceRequest { | ||||||
|   option (id) = 42; |   option (id) = 42; | ||||||
| @@ -888,7 +888,7 @@ message ExecuteServiceRequest { | |||||||
|   option (ifdef) = "USE_API_SERVICES"; |   option (ifdef) = "USE_API_SERVICES"; | ||||||
|  |  | ||||||
|   fixed32 key = 1; |   fixed32 key = 1; | ||||||
|   repeated ExecuteServiceArgument args = 2; |   repeated ExecuteServiceArgument args = 2 [(fixed_vector) = true]; | ||||||
| } | } | ||||||
|  |  | ||||||
| // ==================== CAMERA ==================== | // ==================== CAMERA ==================== | ||||||
|   | |||||||
| @@ -1064,6 +1064,17 @@ bool ExecuteServiceArgument::decode_32bit(uint32_t field_id, Proto32Bit value) { | |||||||
|   } |   } | ||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|  | void ExecuteServiceArgument::decode(const uint8_t *buffer, size_t length) { | ||||||
|  |   uint32_t count_bool_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 6); | ||||||
|  |   this->bool_array.init(count_bool_array); | ||||||
|  |   uint32_t count_int_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 7); | ||||||
|  |   this->int_array.init(count_int_array); | ||||||
|  |   uint32_t count_float_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 8); | ||||||
|  |   this->float_array.init(count_float_array); | ||||||
|  |   uint32_t count_string_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 9); | ||||||
|  |   this->string_array.init(count_string_array); | ||||||
|  |   ProtoDecodableMessage::decode(buffer, length); | ||||||
|  | } | ||||||
| bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { | bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { | ||||||
|   switch (field_id) { |   switch (field_id) { | ||||||
|     case 2: |     case 2: | ||||||
| @@ -1085,6 +1096,11 @@ bool ExecuteServiceRequest::decode_32bit(uint32_t field_id, Proto32Bit value) { | |||||||
|   } |   } | ||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|  | void ExecuteServiceRequest::decode(const uint8_t *buffer, size_t length) { | ||||||
|  |   uint32_t count_args = ProtoDecodableMessage::count_repeated_field(buffer, length, 2); | ||||||
|  |   this->args.init(count_args); | ||||||
|  |   ProtoDecodableMessage::decode(buffer, length); | ||||||
|  | } | ||||||
| #endif | #endif | ||||||
| #ifdef USE_CAMERA | #ifdef USE_CAMERA | ||||||
| void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const { | void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const { | ||||||
|   | |||||||
| @@ -1279,10 +1279,11 @@ class ExecuteServiceArgument final : public ProtoDecodableMessage { | |||||||
|   float float_{0.0f}; |   float float_{0.0f}; | ||||||
|   std::string string_{}; |   std::string string_{}; | ||||||
|   int32_t int_{0}; |   int32_t int_{0}; | ||||||
|   std::vector<bool> bool_array{}; |   FixedVector<bool> bool_array{}; | ||||||
|   std::vector<int32_t> int_array{}; |   FixedVector<int32_t> int_array{}; | ||||||
|   std::vector<float> float_array{}; |   FixedVector<float> float_array{}; | ||||||
|   std::vector<std::string> string_array{}; |   FixedVector<std::string> string_array{}; | ||||||
|  |   void decode(const uint8_t *buffer, size_t length) override; | ||||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | #ifdef HAS_PROTO_MESSAGE_DUMP | ||||||
|   void dump_to(std::string &out) const override; |   void dump_to(std::string &out) const override; | ||||||
| #endif | #endif | ||||||
| @@ -1300,7 +1301,8 @@ class ExecuteServiceRequest final : public ProtoDecodableMessage { | |||||||
|   const char *message_name() const override { return "execute_service_request"; } |   const char *message_name() const override { return "execute_service_request"; } | ||||||
| #endif | #endif | ||||||
|   uint32_t key{0}; |   uint32_t key{0}; | ||||||
|   std::vector<ExecuteServiceArgument> args{}; |   FixedVector<ExecuteServiceArgument> args{}; | ||||||
|  |   void decode(const uint8_t *buffer, size_t length) override; | ||||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | #ifdef HAS_PROTO_MESSAGE_DUMP | ||||||
|   void dump_to(std::string &out) const override; |   void dump_to(std::string &out) const override; | ||||||
| #endif | #endif | ||||||
|   | |||||||
| @@ -7,6 +7,69 @@ namespace esphome::api { | |||||||
|  |  | ||||||
| static const char *const TAG = "api.proto"; | static const char *const TAG = "api.proto"; | ||||||
|  |  | ||||||
|  | uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) { | ||||||
|  |   uint32_t count = 0; | ||||||
|  |   const uint8_t *ptr = buffer; | ||||||
|  |   const uint8_t *end = buffer + length; | ||||||
|  |  | ||||||
|  |   while (ptr < end) { | ||||||
|  |     uint32_t consumed; | ||||||
|  |  | ||||||
|  |     // Parse field header (tag) | ||||||
|  |     auto res = ProtoVarInt::parse(ptr, end - ptr, &consumed); | ||||||
|  |     if (!res.has_value()) { | ||||||
|  |       break;  // Invalid data, stop counting | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     uint32_t tag = res->as_uint32(); | ||||||
|  |     uint32_t field_type = tag & 0b111; | ||||||
|  |     uint32_t field_id = tag >> 3; | ||||||
|  |     ptr += consumed; | ||||||
|  |  | ||||||
|  |     // Count if this is the target field | ||||||
|  |     if (field_id == target_field_id) { | ||||||
|  |       count++; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Skip field data based on wire type | ||||||
|  |     switch (field_type) { | ||||||
|  |       case 0: {  // VarInt - parse and skip | ||||||
|  |         res = ProtoVarInt::parse(ptr, end - ptr, &consumed); | ||||||
|  |         if (!res.has_value()) { | ||||||
|  |           return count;  // Invalid data, return what we have | ||||||
|  |         } | ||||||
|  |         ptr += consumed; | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       case 2: {  // Length-delimited - parse length and skip data | ||||||
|  |         res = ProtoVarInt::parse(ptr, end - ptr, &consumed); | ||||||
|  |         if (!res.has_value()) { | ||||||
|  |           return count; | ||||||
|  |         } | ||||||
|  |         uint32_t field_length = res->as_uint32(); | ||||||
|  |         ptr += consumed; | ||||||
|  |         if (ptr + field_length > end) { | ||||||
|  |           return count;  // Out of bounds | ||||||
|  |         } | ||||||
|  |         ptr += field_length; | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       case 5: {  // 32-bit - skip 4 bytes | ||||||
|  |         if (ptr + 4 > end) { | ||||||
|  |           return count; | ||||||
|  |         } | ||||||
|  |         ptr += 4; | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       default: | ||||||
|  |         // Unknown wire type, can't continue | ||||||
|  |         return count; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   return count; | ||||||
|  | } | ||||||
|  |  | ||||||
| void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { | void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { | ||||||
|   const uint8_t *ptr = buffer; |   const uint8_t *ptr = buffer; | ||||||
|   const uint8_t *end = buffer + length; |   const uint8_t *end = buffer + length; | ||||||
|   | |||||||
| @@ -354,7 +354,18 @@ class ProtoMessage { | |||||||
| // Base class for messages that support decoding | // Base class for messages that support decoding | ||||||
| class ProtoDecodableMessage : public ProtoMessage { | class ProtoDecodableMessage : public ProtoMessage { | ||||||
|  public: |  public: | ||||||
|   void decode(const uint8_t *buffer, size_t length); |   virtual void decode(const uint8_t *buffer, size_t length); | ||||||
|  |  | ||||||
|  |   /** | ||||||
|  |    * Count occurrences of a repeated field in a protobuf buffer. | ||||||
|  |    * This is a lightweight scan that only parses tags and skips field data. | ||||||
|  |    * | ||||||
|  |    * @param buffer Pointer to the protobuf buffer | ||||||
|  |    * @param length Length of the buffer in bytes | ||||||
|  |    * @param target_field_id The field ID to count | ||||||
|  |    * @return Number of times the field appears in the buffer | ||||||
|  |    */ | ||||||
|  |   static uint32_t count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id); | ||||||
|  |  | ||||||
|  protected: |  protected: | ||||||
|   virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; } |   virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; } | ||||||
|   | |||||||
| @@ -12,16 +12,16 @@ template<> int32_t get_execute_arg_value<int32_t>(const ExecuteServiceArgument & | |||||||
| template<> float get_execute_arg_value<float>(const ExecuteServiceArgument &arg) { return arg.float_; } | template<> float get_execute_arg_value<float>(const ExecuteServiceArgument &arg) { return arg.float_; } | ||||||
| template<> std::string get_execute_arg_value<std::string>(const ExecuteServiceArgument &arg) { return arg.string_; } | template<> std::string get_execute_arg_value<std::string>(const ExecuteServiceArgument &arg) { return arg.string_; } | ||||||
| template<> std::vector<bool> get_execute_arg_value<std::vector<bool>>(const ExecuteServiceArgument &arg) { | template<> std::vector<bool> get_execute_arg_value<std::vector<bool>>(const ExecuteServiceArgument &arg) { | ||||||
|   return arg.bool_array; |   return std::vector<bool>(arg.bool_array.begin(), arg.bool_array.end()); | ||||||
| } | } | ||||||
| template<> std::vector<int32_t> get_execute_arg_value<std::vector<int32_t>>(const ExecuteServiceArgument &arg) { | template<> std::vector<int32_t> get_execute_arg_value<std::vector<int32_t>>(const ExecuteServiceArgument &arg) { | ||||||
|   return arg.int_array; |   return std::vector<int32_t>(arg.int_array.begin(), arg.int_array.end()); | ||||||
| } | } | ||||||
| template<> std::vector<float> get_execute_arg_value<std::vector<float>>(const ExecuteServiceArgument &arg) { | template<> std::vector<float> get_execute_arg_value<std::vector<float>>(const ExecuteServiceArgument &arg) { | ||||||
|   return arg.float_array; |   return std::vector<float>(arg.float_array.begin(), arg.float_array.end()); | ||||||
| } | } | ||||||
| template<> std::vector<std::string> get_execute_arg_value<std::vector<std::string>>(const ExecuteServiceArgument &arg) { | template<> std::vector<std::string> get_execute_arg_value<std::vector<std::string>>(const ExecuteServiceArgument &arg) { | ||||||
|   return arg.string_array; |   return std::vector<std::string>(arg.string_array.begin(), arg.string_array.end()); | ||||||
| } | } | ||||||
|  |  | ||||||
| template<> enums::ServiceArgType to_service_arg_type<bool>() { return enums::SERVICE_ARG_TYPE_BOOL; } | template<> enums::ServiceArgType to_service_arg_type<bool>() { return enums::SERVICE_ARG_TYPE_BOOL; } | ||||||
|   | |||||||
| @@ -55,7 +55,7 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor { | |||||||
|  |  | ||||||
|  protected: |  protected: | ||||||
|   virtual void execute(Ts... x) = 0; |   virtual void execute(Ts... x) = 0; | ||||||
|   template<int... S> void execute_(const std::vector<ExecuteServiceArgument> &args, seq<S...> type) { |   template<typename ArgsContainer, int... S> void execute_(const ArgsContainer &args, seq<S...> type) { | ||||||
|     this->execute((get_execute_arg_value<Ts>(args[S]))...); |     this->execute((get_execute_arg_value<Ts>(args[S]))...); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1879,6 +1879,9 @@ def build_message_type( | |||||||
|         ) |         ) | ||||||
|         public_content.append("#endif") |         public_content.append("#endif") | ||||||
|  |  | ||||||
|  |     # Collect fixed_vector fields for custom decode generation | ||||||
|  |     fixed_vector_fields = [] | ||||||
|  |  | ||||||
|     for field in desc.field: |     for field in desc.field: | ||||||
|         # Skip deprecated fields completely |         # Skip deprecated fields completely | ||||||
|         if field.options.deprecated: |         if field.options.deprecated: | ||||||
| @@ -1910,6 +1913,14 @@ def build_message_type( | |||||||
|                 f"since we cannot trust or control the number of items received from clients." |                 f"since we cannot trust or control the number of items received from clients." | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |         # Collect fixed_vector repeated fields for custom decode generation | ||||||
|  |         if ( | ||||||
|  |             needs_decode | ||||||
|  |             and field.label == 3 | ||||||
|  |             and get_field_opt(field, pb.fixed_vector, False) | ||||||
|  |         ): | ||||||
|  |             fixed_vector_fields.append((field.name, field.number)) | ||||||
|  |  | ||||||
|         ti = create_field_type_info(field, needs_decode, needs_encode) |         ti = create_field_type_info(field, needs_decode, needs_encode) | ||||||
|  |  | ||||||
|         # Skip field declarations for fields that are in the base class |         # Skip field declarations for fields that are in the base class | ||||||
| @@ -2018,6 +2029,22 @@ def build_message_type( | |||||||
|         prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;" |         prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;" | ||||||
|         protected_content.insert(0, prot) |         protected_content.insert(0, prot) | ||||||
|  |  | ||||||
|  |     # Generate custom decode() override for messages with FixedVector fields | ||||||
|  |     if fixed_vector_fields: | ||||||
|  |         # Generate the decode() implementation in cpp | ||||||
|  |         o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n" | ||||||
|  |         # Count and init each FixedVector field | ||||||
|  |         for field_name, field_number in fixed_vector_fields: | ||||||
|  |             o += f"  uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n" | ||||||
|  |             o += f"  this->{field_name}.init(count_{field_name});\n" | ||||||
|  |         # Call parent decode to populate the fields | ||||||
|  |         o += "  ProtoDecodableMessage::decode(buffer, length);\n" | ||||||
|  |         o += "}\n" | ||||||
|  |         cpp += o | ||||||
|  |         # Generate the decode() declaration in header (public method) | ||||||
|  |         prot = "void decode(const uint8_t *buffer, size_t length) override;" | ||||||
|  |         public_content.append(prot) | ||||||
|  |  | ||||||
|     # Only generate encode method if this message needs encoding and has fields |     # Only generate encode method if this message needs encoding and has fields | ||||||
|     if needs_encode and encode: |     if needs_encode and encode: | ||||||
|         o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{" |         o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user