mirror of
https://github.com/esphome/esphome.git
synced 2025-10-20 18:53:47 +01:00
[api] Use FixedVector for ExecuteServiceRequest/Argument arrays to eliminate reallocations (#11270)
This commit is contained in:
@@ -876,10 +876,10 @@ message ExecuteServiceArgument {
|
||||
string string_ = 4;
|
||||
// ESPHome 1.14 (api v1.3) make int a signed value
|
||||
sint32 int_ = 5;
|
||||
repeated bool bool_array = 6 [packed=false];
|
||||
repeated sint32 int_array = 7 [packed=false];
|
||||
repeated float float_array = 8 [packed=false];
|
||||
repeated string string_array = 9;
|
||||
repeated bool bool_array = 6 [packed=false, (fixed_vector) = true];
|
||||
repeated sint32 int_array = 7 [packed=false, (fixed_vector) = true];
|
||||
repeated float float_array = 8 [packed=false, (fixed_vector) = true];
|
||||
repeated string string_array = 9 [(fixed_vector) = true];
|
||||
}
|
||||
message ExecuteServiceRequest {
|
||||
option (id) = 42;
|
||||
@@ -888,7 +888,7 @@ message ExecuteServiceRequest {
|
||||
option (ifdef) = "USE_API_SERVICES";
|
||||
|
||||
fixed32 key = 1;
|
||||
repeated ExecuteServiceArgument args = 2;
|
||||
repeated ExecuteServiceArgument args = 2 [(fixed_vector) = true];
|
||||
}
|
||||
|
||||
// ==================== CAMERA ====================
|
||||
|
@@ -1064,6 +1064,17 @@ bool ExecuteServiceArgument::decode_32bit(uint32_t field_id, Proto32Bit value) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void ExecuteServiceArgument::decode(const uint8_t *buffer, size_t length) {
|
||||
uint32_t count_bool_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 6);
|
||||
this->bool_array.init(count_bool_array);
|
||||
uint32_t count_int_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 7);
|
||||
this->int_array.init(count_int_array);
|
||||
uint32_t count_float_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 8);
|
||||
this->float_array.init(count_float_array);
|
||||
uint32_t count_string_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 9);
|
||||
this->string_array.init(count_string_array);
|
||||
ProtoDecodableMessage::decode(buffer, length);
|
||||
}
|
||||
bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
|
||||
switch (field_id) {
|
||||
case 2:
|
||||
@@ -1085,6 +1096,11 @@ bool ExecuteServiceRequest::decode_32bit(uint32_t field_id, Proto32Bit value) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void ExecuteServiceRequest::decode(const uint8_t *buffer, size_t length) {
|
||||
uint32_t count_args = ProtoDecodableMessage::count_repeated_field(buffer, length, 2);
|
||||
this->args.init(count_args);
|
||||
ProtoDecodableMessage::decode(buffer, length);
|
||||
}
|
||||
#endif
|
||||
#ifdef USE_CAMERA
|
||||
void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const {
|
||||
|
@@ -1279,10 +1279,11 @@ class ExecuteServiceArgument final : public ProtoDecodableMessage {
|
||||
float float_{0.0f};
|
||||
std::string string_{};
|
||||
int32_t int_{0};
|
||||
std::vector<bool> bool_array{};
|
||||
std::vector<int32_t> int_array{};
|
||||
std::vector<float> float_array{};
|
||||
std::vector<std::string> string_array{};
|
||||
FixedVector<bool> bool_array{};
|
||||
FixedVector<int32_t> int_array{};
|
||||
FixedVector<float> float_array{};
|
||||
FixedVector<std::string> string_array{};
|
||||
void decode(const uint8_t *buffer, size_t length) override;
|
||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||
void dump_to(std::string &out) const override;
|
||||
#endif
|
||||
@@ -1300,7 +1301,8 @@ class ExecuteServiceRequest final : public ProtoDecodableMessage {
|
||||
const char *message_name() const override { return "execute_service_request"; }
|
||||
#endif
|
||||
uint32_t key{0};
|
||||
std::vector<ExecuteServiceArgument> args{};
|
||||
FixedVector<ExecuteServiceArgument> args{};
|
||||
void decode(const uint8_t *buffer, size_t length) override;
|
||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||
void dump_to(std::string &out) const override;
|
||||
#endif
|
||||
|
@@ -7,6 +7,69 @@ namespace esphome::api {
|
||||
|
||||
static const char *const TAG = "api.proto";
|
||||
|
||||
uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) {
|
||||
uint32_t count = 0;
|
||||
const uint8_t *ptr = buffer;
|
||||
const uint8_t *end = buffer + length;
|
||||
|
||||
while (ptr < end) {
|
||||
uint32_t consumed;
|
||||
|
||||
// Parse field header (tag)
|
||||
auto res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||
if (!res.has_value()) {
|
||||
break; // Invalid data, stop counting
|
||||
}
|
||||
|
||||
uint32_t tag = res->as_uint32();
|
||||
uint32_t field_type = tag & WIRE_TYPE_MASK;
|
||||
uint32_t field_id = tag >> 3;
|
||||
ptr += consumed;
|
||||
|
||||
// Count if this is the target field
|
||||
if (field_id == target_field_id) {
|
||||
count++;
|
||||
}
|
||||
|
||||
// Skip field data based on wire type
|
||||
switch (field_type) {
|
||||
case WIRE_TYPE_VARINT: { // VarInt - parse and skip
|
||||
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||
if (!res.has_value()) {
|
||||
return count; // Invalid data, return what we have
|
||||
}
|
||||
ptr += consumed;
|
||||
break;
|
||||
}
|
||||
case WIRE_TYPE_LENGTH_DELIMITED: { // Length-delimited - parse length and skip data
|
||||
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||
if (!res.has_value()) {
|
||||
return count;
|
||||
}
|
||||
uint32_t field_length = res->as_uint32();
|
||||
ptr += consumed;
|
||||
if (ptr + field_length > end) {
|
||||
return count; // Out of bounds
|
||||
}
|
||||
ptr += field_length;
|
||||
break;
|
||||
}
|
||||
case WIRE_TYPE_FIXED32: { // 32-bit - skip 4 bytes
|
||||
if (ptr + 4 > end) {
|
||||
return count;
|
||||
}
|
||||
ptr += 4;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
// Unknown wire type, can't continue
|
||||
return count;
|
||||
}
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
||||
const uint8_t *ptr = buffer;
|
||||
const uint8_t *end = buffer + length;
|
||||
@@ -22,12 +85,12 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
||||
}
|
||||
|
||||
uint32_t tag = res->as_uint32();
|
||||
uint32_t field_type = tag & 0b111;
|
||||
uint32_t field_type = tag & WIRE_TYPE_MASK;
|
||||
uint32_t field_id = tag >> 3;
|
||||
ptr += consumed;
|
||||
|
||||
switch (field_type) {
|
||||
case 0: { // VarInt
|
||||
case WIRE_TYPE_VARINT: { // VarInt
|
||||
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||
if (!res.has_value()) {
|
||||
ESP_LOGV(TAG, "Invalid VarInt at offset %ld", (long) (ptr - buffer));
|
||||
@@ -39,7 +102,7 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
||||
ptr += consumed;
|
||||
break;
|
||||
}
|
||||
case 2: { // Length-delimited
|
||||
case WIRE_TYPE_LENGTH_DELIMITED: { // Length-delimited
|
||||
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||
if (!res.has_value()) {
|
||||
ESP_LOGV(TAG, "Invalid Length Delimited at offset %ld", (long) (ptr - buffer));
|
||||
@@ -57,7 +120,7 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
||||
ptr += field_length;
|
||||
break;
|
||||
}
|
||||
case 5: { // 32-bit
|
||||
case WIRE_TYPE_FIXED32: { // 32-bit
|
||||
if (ptr + 4 > end) {
|
||||
ESP_LOGV(TAG, "Out-of-bounds Fixed32-bit at offset %ld", (long) (ptr - buffer));
|
||||
return;
|
||||
|
@@ -15,6 +15,13 @@
|
||||
|
||||
namespace esphome::api {
|
||||
|
||||
// Protocol Buffer wire type constants
|
||||
// See https://protobuf.dev/programming-guides/encoding/#structure
|
||||
constexpr uint8_t WIRE_TYPE_VARINT = 0; // int32, int64, uint32, uint64, sint32, sint64, bool, enum
|
||||
constexpr uint8_t WIRE_TYPE_LENGTH_DELIMITED = 2; // string, bytes, embedded messages, packed repeated fields
|
||||
constexpr uint8_t WIRE_TYPE_FIXED32 = 5; // fixed32, sfixed32, float
|
||||
constexpr uint8_t WIRE_TYPE_MASK = 0b111; // Mask to extract wire type from tag
|
||||
|
||||
// Helper functions for ZigZag encoding/decoding
|
||||
inline constexpr uint32_t encode_zigzag32(int32_t value) {
|
||||
return (static_cast<uint32_t>(value) << 1) ^ (static_cast<uint32_t>(value >> 31));
|
||||
@@ -241,7 +248,7 @@ class ProtoWriteBuffer {
|
||||
* Following https://protobuf.dev/programming-guides/encoding/#structure
|
||||
*/
|
||||
void encode_field_raw(uint32_t field_id, uint32_t type) {
|
||||
uint32_t val = (field_id << 3) | (type & 0b111);
|
||||
uint32_t val = (field_id << 3) | (type & WIRE_TYPE_MASK);
|
||||
this->encode_varint_raw(val);
|
||||
}
|
||||
void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) {
|
||||
@@ -354,7 +361,18 @@ class ProtoMessage {
|
||||
// Base class for messages that support decoding
|
||||
class ProtoDecodableMessage : public ProtoMessage {
|
||||
public:
|
||||
void decode(const uint8_t *buffer, size_t length);
|
||||
virtual void decode(const uint8_t *buffer, size_t length);
|
||||
|
||||
/**
|
||||
* Count occurrences of a repeated field in a protobuf buffer.
|
||||
* This is a lightweight scan that only parses tags and skips field data.
|
||||
*
|
||||
* @param buffer Pointer to the protobuf buffer
|
||||
* @param length Length of the buffer in bytes
|
||||
* @param target_field_id The field ID to count
|
||||
* @return Number of times the field appears in the buffer
|
||||
*/
|
||||
static uint32_t count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id);
|
||||
|
||||
protected:
|
||||
virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; }
|
||||
@@ -482,7 +500,7 @@ class ProtoSize {
|
||||
* @return The number of bytes needed to encode the field ID and wire type
|
||||
*/
|
||||
static constexpr uint32_t field(uint32_t field_id, uint32_t type) {
|
||||
uint32_t tag = (field_id << 3) | (type & 0b111);
|
||||
uint32_t tag = (field_id << 3) | (type & WIRE_TYPE_MASK);
|
||||
return varint(tag);
|
||||
}
|
||||
|
||||
|
@@ -12,16 +12,16 @@ template<> int32_t get_execute_arg_value<int32_t>(const ExecuteServiceArgument &
|
||||
template<> float get_execute_arg_value<float>(const ExecuteServiceArgument &arg) { return arg.float_; }
|
||||
template<> std::string get_execute_arg_value<std::string>(const ExecuteServiceArgument &arg) { return arg.string_; }
|
||||
template<> std::vector<bool> get_execute_arg_value<std::vector<bool>>(const ExecuteServiceArgument &arg) {
|
||||
return arg.bool_array;
|
||||
return std::vector<bool>(arg.bool_array.begin(), arg.bool_array.end());
|
||||
}
|
||||
template<> std::vector<int32_t> get_execute_arg_value<std::vector<int32_t>>(const ExecuteServiceArgument &arg) {
|
||||
return arg.int_array;
|
||||
return std::vector<int32_t>(arg.int_array.begin(), arg.int_array.end());
|
||||
}
|
||||
template<> std::vector<float> get_execute_arg_value<std::vector<float>>(const ExecuteServiceArgument &arg) {
|
||||
return arg.float_array;
|
||||
return std::vector<float>(arg.float_array.begin(), arg.float_array.end());
|
||||
}
|
||||
template<> std::vector<std::string> get_execute_arg_value<std::vector<std::string>>(const ExecuteServiceArgument &arg) {
|
||||
return arg.string_array;
|
||||
return std::vector<std::string>(arg.string_array.begin(), arg.string_array.end());
|
||||
}
|
||||
|
||||
template<> enums::ServiceArgType to_service_arg_type<bool>() { return enums::SERVICE_ARG_TYPE_BOOL; }
|
||||
|
@@ -55,7 +55,7 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor {
|
||||
|
||||
protected:
|
||||
virtual void execute(Ts... x) = 0;
|
||||
template<int... S> void execute_(const std::vector<ExecuteServiceArgument> &args, seq<S...> type) {
|
||||
template<typename ArgsContainer, int... S> void execute_(const ArgsContainer &args, seq<S...> type) {
|
||||
this->execute((get_execute_arg_value<Ts>(args[S]))...);
|
||||
}
|
||||
|
||||
|
@@ -11,6 +11,7 @@ from typing import Any
|
||||
|
||||
import aioesphomeapi.api_options_pb2 as pb
|
||||
import google.protobuf.descriptor_pb2 as descriptor
|
||||
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
|
||||
|
||||
|
||||
class WireType(IntEnum):
|
||||
@@ -148,7 +149,7 @@ class TypeInfo(ABC):
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
"""Check if the field is repeated."""
|
||||
return self._field.label == 3
|
||||
return self._field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
|
||||
@property
|
||||
def wire_type(self) -> WireType:
|
||||
@@ -337,7 +338,7 @@ def create_field_type_info(
|
||||
needs_encode: bool = True,
|
||||
) -> TypeInfo:
|
||||
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
|
||||
if field.label == 3: # repeated
|
||||
if field.label == FieldDescriptorProto.LABEL_REPEATED:
|
||||
# Check if this repeated field has fixed_array_with_length_define option
|
||||
if (
|
||||
fixed_size := get_field_opt(field, pb.fixed_array_with_length_define)
|
||||
@@ -1879,6 +1880,9 @@ def build_message_type(
|
||||
)
|
||||
public_content.append("#endif")
|
||||
|
||||
# Collect fixed_vector fields for custom decode generation
|
||||
fixed_vector_fields = []
|
||||
|
||||
for field in desc.field:
|
||||
# Skip deprecated fields completely
|
||||
if field.options.deprecated:
|
||||
@@ -1887,7 +1891,7 @@ def build_message_type(
|
||||
# Validate that fixed_array_size is only used in encode-only messages
|
||||
if (
|
||||
needs_decode
|
||||
and field.label == 3
|
||||
and field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
and get_field_opt(field, pb.fixed_array_size) is not None
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -1900,7 +1904,7 @@ def build_message_type(
|
||||
# Validate that fixed_array_with_length_define is only used in encode-only messages
|
||||
if (
|
||||
needs_decode
|
||||
and field.label == 3
|
||||
and field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
and get_field_opt(field, pb.fixed_array_with_length_define) is not None
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -1910,6 +1914,14 @@ def build_message_type(
|
||||
f"since we cannot trust or control the number of items received from clients."
|
||||
)
|
||||
|
||||
# Collect fixed_vector repeated fields for custom decode generation
|
||||
if (
|
||||
needs_decode
|
||||
and field.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
and get_field_opt(field, pb.fixed_vector, False)
|
||||
):
|
||||
fixed_vector_fields.append((field.name, field.number))
|
||||
|
||||
ti = create_field_type_info(field, needs_decode, needs_encode)
|
||||
|
||||
# Skip field declarations for fields that are in the base class
|
||||
@@ -2018,6 +2030,22 @@ def build_message_type(
|
||||
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
|
||||
protected_content.insert(0, prot)
|
||||
|
||||
# Generate custom decode() override for messages with FixedVector fields
|
||||
if fixed_vector_fields:
|
||||
# Generate the decode() implementation in cpp
|
||||
o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n"
|
||||
# Count and init each FixedVector field
|
||||
for field_name, field_number in fixed_vector_fields:
|
||||
o += f" uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n"
|
||||
o += f" this->{field_name}.init(count_{field_name});\n"
|
||||
# Call parent decode to populate the fields
|
||||
o += " ProtoDecodableMessage::decode(buffer, length);\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
# Generate the decode() declaration in header (public method)
|
||||
prot = "void decode(const uint8_t *buffer, size_t length) override;"
|
||||
public_content.append(prot)
|
||||
|
||||
# Only generate encode method if this message needs encoding and has fields
|
||||
if needs_encode and encode:
|
||||
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
|
||||
|
Reference in New Issue
Block a user