mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-25 21:23:53 +01:00 
			
		
		
		
	Merge branch 'execute_fixed_vector' into integration
This commit is contained in:
		| @@ -876,10 +876,10 @@ message ExecuteServiceArgument { | ||||
|   string string_ = 4; | ||||
|   // ESPHome 1.14 (api v1.3) make int a signed value | ||||
|   sint32 int_ = 5; | ||||
|   repeated bool bool_array = 6 [packed=false]; | ||||
|   repeated sint32 int_array = 7 [packed=false]; | ||||
|   repeated float float_array = 8 [packed=false]; | ||||
|   repeated string string_array = 9; | ||||
|   repeated bool bool_array = 6 [packed=false, (fixed_vector) = true]; | ||||
|   repeated sint32 int_array = 7 [packed=false, (fixed_vector) = true]; | ||||
|   repeated float float_array = 8 [packed=false, (fixed_vector) = true]; | ||||
|   repeated string string_array = 9 [(fixed_vector) = true]; | ||||
| } | ||||
| message ExecuteServiceRequest { | ||||
|   option (id) = 42; | ||||
| @@ -888,7 +888,7 @@ message ExecuteServiceRequest { | ||||
|   option (ifdef) = "USE_API_SERVICES"; | ||||
|  | ||||
|   fixed32 key = 1; | ||||
|   repeated ExecuteServiceArgument args = 2; | ||||
|   repeated ExecuteServiceArgument args = 2 [(fixed_vector) = true]; | ||||
| } | ||||
|  | ||||
| // ==================== CAMERA ==================== | ||||
|   | ||||
| @@ -1064,6 +1064,17 @@ bool ExecuteServiceArgument::decode_32bit(uint32_t field_id, Proto32Bit value) { | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
| void ExecuteServiceArgument::decode(const uint8_t *buffer, size_t length) { | ||||
|   uint32_t count_bool_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 6); | ||||
|   this->bool_array.init(count_bool_array); | ||||
|   uint32_t count_int_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 7); | ||||
|   this->int_array.init(count_int_array); | ||||
|   uint32_t count_float_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 8); | ||||
|   this->float_array.init(count_float_array); | ||||
|   uint32_t count_string_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 9); | ||||
|   this->string_array.init(count_string_array); | ||||
|   ProtoDecodableMessage::decode(buffer, length); | ||||
| } | ||||
| bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { | ||||
|   switch (field_id) { | ||||
|     case 2: | ||||
| @@ -1085,6 +1096,11 @@ bool ExecuteServiceRequest::decode_32bit(uint32_t field_id, Proto32Bit value) { | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
| void ExecuteServiceRequest::decode(const uint8_t *buffer, size_t length) { | ||||
|   uint32_t count_args = ProtoDecodableMessage::count_repeated_field(buffer, length, 2); | ||||
|   this->args.init(count_args); | ||||
|   ProtoDecodableMessage::decode(buffer, length); | ||||
| } | ||||
| #endif | ||||
| #ifdef USE_CAMERA | ||||
| void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const { | ||||
|   | ||||
| @@ -1279,10 +1279,11 @@ class ExecuteServiceArgument final : public ProtoDecodableMessage { | ||||
|   float float_{0.0f}; | ||||
|   std::string string_{}; | ||||
|   int32_t int_{0}; | ||||
|   std::vector<bool> bool_array{}; | ||||
|   std::vector<int32_t> int_array{}; | ||||
|   std::vector<float> float_array{}; | ||||
|   std::vector<std::string> string_array{}; | ||||
|   FixedVector<bool> bool_array{}; | ||||
|   FixedVector<int32_t> int_array{}; | ||||
|   FixedVector<float> float_array{}; | ||||
|   FixedVector<std::string> string_array{}; | ||||
|   void decode(const uint8_t *buffer, size_t length) override; | ||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | ||||
|   void dump_to(std::string &out) const override; | ||||
| #endif | ||||
| @@ -1300,7 +1301,8 @@ class ExecuteServiceRequest final : public ProtoDecodableMessage { | ||||
|   const char *message_name() const override { return "execute_service_request"; } | ||||
| #endif | ||||
|   uint32_t key{0}; | ||||
|   std::vector<ExecuteServiceArgument> args{}; | ||||
|   FixedVector<ExecuteServiceArgument> args{}; | ||||
|   void decode(const uint8_t *buffer, size_t length) override; | ||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | ||||
|   void dump_to(std::string &out) const override; | ||||
| #endif | ||||
|   | ||||
| @@ -7,6 +7,69 @@ namespace esphome::api { | ||||
|  | ||||
| static const char *const TAG = "api.proto"; | ||||
|  | ||||
| uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) { | ||||
|   uint32_t count = 0; | ||||
|   const uint8_t *ptr = buffer; | ||||
|   const uint8_t *end = buffer + length; | ||||
|  | ||||
|   while (ptr < end) { | ||||
|     uint32_t consumed; | ||||
|  | ||||
|     // Parse field header (tag) | ||||
|     auto res = ProtoVarInt::parse(ptr, end - ptr, &consumed); | ||||
|     if (!res.has_value()) { | ||||
|       break;  // Invalid data, stop counting | ||||
|     } | ||||
|  | ||||
|     uint32_t tag = res->as_uint32(); | ||||
|     uint32_t field_type = tag & 0b111; | ||||
|     uint32_t field_id = tag >> 3; | ||||
|     ptr += consumed; | ||||
|  | ||||
|     // Count if this is the target field | ||||
|     if (field_id == target_field_id) { | ||||
|       count++; | ||||
|     } | ||||
|  | ||||
|     // Skip field data based on wire type | ||||
|     switch (field_type) { | ||||
|       case 0: {  // VarInt - parse and skip | ||||
|         res = ProtoVarInt::parse(ptr, end - ptr, &consumed); | ||||
|         if (!res.has_value()) { | ||||
|           return count;  // Invalid data, return what we have | ||||
|         } | ||||
|         ptr += consumed; | ||||
|         break; | ||||
|       } | ||||
|       case 2: {  // Length-delimited - parse length and skip data | ||||
|         res = ProtoVarInt::parse(ptr, end - ptr, &consumed); | ||||
|         if (!res.has_value()) { | ||||
|           return count; | ||||
|         } | ||||
|         uint32_t field_length = res->as_uint32(); | ||||
|         ptr += consumed; | ||||
|         if (ptr + field_length > end) { | ||||
|           return count;  // Out of bounds | ||||
|         } | ||||
|         ptr += field_length; | ||||
|         break; | ||||
|       } | ||||
|       case 5: {  // 32-bit - skip 4 bytes | ||||
|         if (ptr + 4 > end) { | ||||
|           return count; | ||||
|         } | ||||
|         ptr += 4; | ||||
|         break; | ||||
|       } | ||||
|       default: | ||||
|         // Unknown wire type, can't continue | ||||
|         return count; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return count; | ||||
| } | ||||
|  | ||||
| void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { | ||||
|   const uint8_t *ptr = buffer; | ||||
|   const uint8_t *end = buffer + length; | ||||
|   | ||||
| @@ -354,7 +354,18 @@ class ProtoMessage { | ||||
| // Base class for messages that support decoding | ||||
| class ProtoDecodableMessage : public ProtoMessage { | ||||
|  public: | ||||
|   void decode(const uint8_t *buffer, size_t length); | ||||
|   virtual void decode(const uint8_t *buffer, size_t length); | ||||
|  | ||||
|   /** | ||||
|    * Count occurrences of a repeated field in a protobuf buffer. | ||||
|    * This is a lightweight scan that only parses tags and skips field data. | ||||
|    * | ||||
|    * @param buffer Pointer to the protobuf buffer | ||||
|    * @param length Length of the buffer in bytes | ||||
|    * @param target_field_id The field ID to count | ||||
|    * @return Number of times the field appears in the buffer | ||||
|    */ | ||||
|   static uint32_t count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id); | ||||
|  | ||||
|  protected: | ||||
|   virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; } | ||||
|   | ||||
| @@ -12,16 +12,16 @@ template<> int32_t get_execute_arg_value<int32_t>(const ExecuteServiceArgument & | ||||
| template<> float get_execute_arg_value<float>(const ExecuteServiceArgument &arg) { return arg.float_; } | ||||
| template<> std::string get_execute_arg_value<std::string>(const ExecuteServiceArgument &arg) { return arg.string_; } | ||||
| template<> std::vector<bool> get_execute_arg_value<std::vector<bool>>(const ExecuteServiceArgument &arg) { | ||||
|   return arg.bool_array; | ||||
|   return std::vector<bool>(arg.bool_array.begin(), arg.bool_array.end()); | ||||
| } | ||||
| template<> std::vector<int32_t> get_execute_arg_value<std::vector<int32_t>>(const ExecuteServiceArgument &arg) { | ||||
|   return arg.int_array; | ||||
|   return std::vector<int32_t>(arg.int_array.begin(), arg.int_array.end()); | ||||
| } | ||||
| template<> std::vector<float> get_execute_arg_value<std::vector<float>>(const ExecuteServiceArgument &arg) { | ||||
|   return arg.float_array; | ||||
|   return std::vector<float>(arg.float_array.begin(), arg.float_array.end()); | ||||
| } | ||||
| template<> std::vector<std::string> get_execute_arg_value<std::vector<std::string>>(const ExecuteServiceArgument &arg) { | ||||
|   return arg.string_array; | ||||
|   return std::vector<std::string>(arg.string_array.begin(), arg.string_array.end()); | ||||
| } | ||||
|  | ||||
| template<> enums::ServiceArgType to_service_arg_type<bool>() { return enums::SERVICE_ARG_TYPE_BOOL; } | ||||
|   | ||||
| @@ -55,7 +55,7 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor { | ||||
|  | ||||
|  protected: | ||||
|   virtual void execute(Ts... x) = 0; | ||||
|   template<int... S> void execute_(const std::vector<ExecuteServiceArgument> &args, seq<S...> type) { | ||||
|   template<typename ArgsContainer, int... S> void execute_(const ArgsContainer &args, seq<S...> type) { | ||||
|     this->execute((get_execute_arg_value<Ts>(args[S]))...); | ||||
|   } | ||||
|  | ||||
|   | ||||
| @@ -1879,6 +1879,9 @@ def build_message_type( | ||||
|         ) | ||||
|         public_content.append("#endif") | ||||
|  | ||||
|     # Collect fixed_vector fields for custom decode generation | ||||
|     fixed_vector_fields = [] | ||||
|  | ||||
|     for field in desc.field: | ||||
|         # Skip deprecated fields completely | ||||
|         if field.options.deprecated: | ||||
| @@ -1910,6 +1913,14 @@ def build_message_type( | ||||
|                 f"since we cannot trust or control the number of items received from clients." | ||||
|             ) | ||||
|  | ||||
|         # Collect fixed_vector repeated fields for custom decode generation | ||||
|         if ( | ||||
|             needs_decode | ||||
|             and field.label == 3 | ||||
|             and get_field_opt(field, pb.fixed_vector, False) | ||||
|         ): | ||||
|             fixed_vector_fields.append((field.name, field.number)) | ||||
|  | ||||
|         ti = create_field_type_info(field, needs_decode, needs_encode) | ||||
|  | ||||
|         # Skip field declarations for fields that are in the base class | ||||
| @@ -2018,6 +2029,22 @@ def build_message_type( | ||||
|         prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;" | ||||
|         protected_content.insert(0, prot) | ||||
|  | ||||
|     # Generate custom decode() override for messages with FixedVector fields | ||||
|     if fixed_vector_fields: | ||||
|         # Generate the decode() implementation in cpp | ||||
|         o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n" | ||||
|         # Count and init each FixedVector field | ||||
|         for field_name, field_number in fixed_vector_fields: | ||||
|             o += f"  uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n" | ||||
|             o += f"  this->{field_name}.init(count_{field_name});\n" | ||||
|         # Call parent decode to populate the fields | ||||
|         o += "  ProtoDecodableMessage::decode(buffer, length);\n" | ||||
|         o += "}\n" | ||||
|         cpp += o | ||||
|         # Generate the decode() declaration in header (public method) | ||||
|         prot = "void decode(const uint8_t *buffer, size_t length) override;" | ||||
|         public_content.append(prot) | ||||
|  | ||||
|     # Only generate encode method if this message needs encoding and has fields | ||||
|     if needs_encode and encode: | ||||
|         o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user