1
0
mirror of https://github.com/esphome/esphome.git synced 2025-10-20 18:53:47 +01:00

[api] Use FixedVector for ExecuteServiceRequest/Argument arrays to eliminate reallocations (#11270)

This commit is contained in:
J. Nick Koston
2025-10-15 17:22:08 -10:00
committed by GitHub
parent 18062d154f
commit 6943b1d985
8 changed files with 153 additions and 26 deletions

View File

@@ -876,10 +876,10 @@ message ExecuteServiceArgument {
string string_ = 4;
// ESPHome 1.14 (api v1.3) make int a signed value
sint32 int_ = 5;
repeated bool bool_array = 6 [packed=false];
repeated sint32 int_array = 7 [packed=false];
repeated float float_array = 8 [packed=false];
repeated string string_array = 9;
repeated bool bool_array = 6 [packed=false, (fixed_vector) = true];
repeated sint32 int_array = 7 [packed=false, (fixed_vector) = true];
repeated float float_array = 8 [packed=false, (fixed_vector) = true];
repeated string string_array = 9 [(fixed_vector) = true];
}
message ExecuteServiceRequest {
option (id) = 42;
@@ -888,7 +888,7 @@ message ExecuteServiceRequest {
option (ifdef) = "USE_API_SERVICES";
fixed32 key = 1;
repeated ExecuteServiceArgument args = 2;
repeated ExecuteServiceArgument args = 2 [(fixed_vector) = true];
}
// ==================== CAMERA ====================

View File

@@ -1064,6 +1064,17 @@ bool ExecuteServiceArgument::decode_32bit(uint32_t field_id, Proto32Bit value) {
}
return true;
}
void ExecuteServiceArgument::decode(const uint8_t *buffer, size_t length) {
uint32_t count_bool_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 6);
this->bool_array.init(count_bool_array);
uint32_t count_int_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 7);
this->int_array.init(count_int_array);
uint32_t count_float_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 8);
this->float_array.init(count_float_array);
uint32_t count_string_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 9);
this->string_array.init(count_string_array);
ProtoDecodableMessage::decode(buffer, length);
}
bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
switch (field_id) {
case 2:
@@ -1085,6 +1096,11 @@ bool ExecuteServiceRequest::decode_32bit(uint32_t field_id, Proto32Bit value) {
}
return true;
}
void ExecuteServiceRequest::decode(const uint8_t *buffer, size_t length) {
uint32_t count_args = ProtoDecodableMessage::count_repeated_field(buffer, length, 2);
this->args.init(count_args);
ProtoDecodableMessage::decode(buffer, length);
}
#endif
#ifdef USE_CAMERA
void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const {

View File

@@ -1279,10 +1279,11 @@ class ExecuteServiceArgument final : public ProtoDecodableMessage {
float float_{0.0f};
std::string string_{};
int32_t int_{0};
std::vector<bool> bool_array{};
std::vector<int32_t> int_array{};
std::vector<float> float_array{};
std::vector<std::string> string_array{};
FixedVector<bool> bool_array{};
FixedVector<int32_t> int_array{};
FixedVector<float> float_array{};
FixedVector<std::string> string_array{};
void decode(const uint8_t *buffer, size_t length) override;
#ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override;
#endif
@@ -1300,7 +1301,8 @@ class ExecuteServiceRequest final : public ProtoDecodableMessage {
const char *message_name() const override { return "execute_service_request"; }
#endif
uint32_t key{0};
std::vector<ExecuteServiceArgument> args{};
FixedVector<ExecuteServiceArgument> args{};
void decode(const uint8_t *buffer, size_t length) override;
#ifdef HAS_PROTO_MESSAGE_DUMP
void dump_to(std::string &out) const override;
#endif

View File

@@ -7,6 +7,69 @@ namespace esphome::api {
static const char *const TAG = "api.proto";
uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) {
uint32_t count = 0;
const uint8_t *ptr = buffer;
const uint8_t *end = buffer + length;
while (ptr < end) {
uint32_t consumed;
// Parse field header (tag)
auto res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
if (!res.has_value()) {
break; // Invalid data, stop counting
}
uint32_t tag = res->as_uint32();
uint32_t field_type = tag & WIRE_TYPE_MASK;
uint32_t field_id = tag >> 3;
ptr += consumed;
// Count if this is the target field
if (field_id == target_field_id) {
count++;
}
// Skip field data based on wire type
switch (field_type) {
case WIRE_TYPE_VARINT: { // VarInt - parse and skip
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
if (!res.has_value()) {
return count; // Invalid data, return what we have
}
ptr += consumed;
break;
}
case WIRE_TYPE_LENGTH_DELIMITED: { // Length-delimited - parse length and skip data
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
if (!res.has_value()) {
return count;
}
uint32_t field_length = res->as_uint32();
ptr += consumed;
if (ptr + field_length > end) {
return count; // Out of bounds
}
ptr += field_length;
break;
}
case WIRE_TYPE_FIXED32: { // 32-bit - skip 4 bytes
if (ptr + 4 > end) {
return count;
}
ptr += 4;
break;
}
default:
// Unknown wire type, can't continue
return count;
}
}
return count;
}
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
const uint8_t *ptr = buffer;
const uint8_t *end = buffer + length;
@@ -22,12 +85,12 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
}
uint32_t tag = res->as_uint32();
uint32_t field_type = tag & 0b111;
uint32_t field_type = tag & WIRE_TYPE_MASK;
uint32_t field_id = tag >> 3;
ptr += consumed;
switch (field_type) {
case 0: { // VarInt
case WIRE_TYPE_VARINT: { // VarInt
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
if (!res.has_value()) {
ESP_LOGV(TAG, "Invalid VarInt at offset %ld", (long) (ptr - buffer));
@@ -39,7 +102,7 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
ptr += consumed;
break;
}
case 2: { // Length-delimited
case WIRE_TYPE_LENGTH_DELIMITED: { // Length-delimited
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
if (!res.has_value()) {
ESP_LOGV(TAG, "Invalid Length Delimited at offset %ld", (long) (ptr - buffer));
@@ -57,7 +120,7 @@ void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
ptr += field_length;
break;
}
case 5: { // 32-bit
case WIRE_TYPE_FIXED32: { // 32-bit
if (ptr + 4 > end) {
ESP_LOGV(TAG, "Out-of-bounds Fixed32-bit at offset %ld", (long) (ptr - buffer));
return;

View File

@@ -15,6 +15,13 @@
namespace esphome::api {
// Protocol Buffer wire type constants
// See https://protobuf.dev/programming-guides/encoding/#structure
constexpr uint8_t WIRE_TYPE_VARINT = 0; // int32, int64, uint32, uint64, sint32, sint64, bool, enum
constexpr uint8_t WIRE_TYPE_LENGTH_DELIMITED = 2; // string, bytes, embedded messages, packed repeated fields
constexpr uint8_t WIRE_TYPE_FIXED32 = 5; // fixed32, sfixed32, float
constexpr uint8_t WIRE_TYPE_MASK = 0b111; // Mask to extract wire type from tag
// Helper functions for ZigZag encoding/decoding
inline constexpr uint32_t encode_zigzag32(int32_t value) {
return (static_cast<uint32_t>(value) << 1) ^ (static_cast<uint32_t>(value >> 31));
@@ -241,7 +248,7 @@ class ProtoWriteBuffer {
* Following https://protobuf.dev/programming-guides/encoding/#structure
*/
void encode_field_raw(uint32_t field_id, uint32_t type) {
uint32_t val = (field_id << 3) | (type & 0b111);
uint32_t val = (field_id << 3) | (type & WIRE_TYPE_MASK);
this->encode_varint_raw(val);
}
void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) {
@@ -354,7 +361,18 @@ class ProtoMessage {
// Base class for messages that support decoding
class ProtoDecodableMessage : public ProtoMessage {
public:
void decode(const uint8_t *buffer, size_t length);
virtual void decode(const uint8_t *buffer, size_t length);
/**
* Count occurrences of a repeated field in a protobuf buffer.
* This is a lightweight scan that only parses tags and skips field data.
*
* @param buffer Pointer to the protobuf buffer
* @param length Length of the buffer in bytes
* @param target_field_id The field ID to count
* @return Number of times the field appears in the buffer
*/
static uint32_t count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id);
protected:
virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; }
@@ -482,7 +500,7 @@ class ProtoSize {
* @return The number of bytes needed to encode the field ID and wire type
*/
static constexpr uint32_t field(uint32_t field_id, uint32_t type) {
uint32_t tag = (field_id << 3) | (type & 0b111);
uint32_t tag = (field_id << 3) | (type & WIRE_TYPE_MASK);
return varint(tag);
}

View File

@@ -12,16 +12,16 @@ template<> int32_t get_execute_arg_value<int32_t>(const ExecuteServiceArgument &
template<> float get_execute_arg_value<float>(const ExecuteServiceArgument &arg) { return arg.float_; }
template<> std::string get_execute_arg_value<std::string>(const ExecuteServiceArgument &arg) { return arg.string_; }
template<> std::vector<bool> get_execute_arg_value<std::vector<bool>>(const ExecuteServiceArgument &arg) {
return arg.bool_array;
return std::vector<bool>(arg.bool_array.begin(), arg.bool_array.end());
}
template<> std::vector<int32_t> get_execute_arg_value<std::vector<int32_t>>(const ExecuteServiceArgument &arg) {
return arg.int_array;
return std::vector<int32_t>(arg.int_array.begin(), arg.int_array.end());
}
template<> std::vector<float> get_execute_arg_value<std::vector<float>>(const ExecuteServiceArgument &arg) {
return arg.float_array;
return std::vector<float>(arg.float_array.begin(), arg.float_array.end());
}
template<> std::vector<std::string> get_execute_arg_value<std::vector<std::string>>(const ExecuteServiceArgument &arg) {
return arg.string_array;
return std::vector<std::string>(arg.string_array.begin(), arg.string_array.end());
}
template<> enums::ServiceArgType to_service_arg_type<bool>() { return enums::SERVICE_ARG_TYPE_BOOL; }

View File

@@ -55,7 +55,7 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor {
protected:
virtual void execute(Ts... x) = 0;
template<int... S> void execute_(const std::vector<ExecuteServiceArgument> &args, seq<S...> type) {
template<typename ArgsContainer, int... S> void execute_(const ArgsContainer &args, seq<S...> type) {
this->execute((get_execute_arg_value<Ts>(args[S]))...);
}

View File

@@ -11,6 +11,7 @@ from typing import Any
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
class WireType(IntEnum):
@@ -148,7 +149,7 @@ class TypeInfo(ABC):
@property
def repeated(self) -> bool:
"""Check if the field is repeated."""
return self._field.label == 3
return self._field.label == FieldDescriptorProto.LABEL_REPEATED
@property
def wire_type(self) -> WireType:
@@ -337,7 +338,7 @@ def create_field_type_info(
needs_encode: bool = True,
) -> TypeInfo:
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
if field.label == 3: # repeated
if field.label == FieldDescriptorProto.LABEL_REPEATED:
# Check if this repeated field has fixed_array_with_length_define option
if (
fixed_size := get_field_opt(field, pb.fixed_array_with_length_define)
@@ -1879,6 +1880,9 @@ def build_message_type(
)
public_content.append("#endif")
# Collect fixed_vector fields for custom decode generation
fixed_vector_fields = []
for field in desc.field:
# Skip deprecated fields completely
if field.options.deprecated:
@@ -1887,7 +1891,7 @@ def build_message_type(
# Validate that fixed_array_size is only used in encode-only messages
if (
needs_decode
and field.label == 3
and field.label == FieldDescriptorProto.LABEL_REPEATED
and get_field_opt(field, pb.fixed_array_size) is not None
):
raise ValueError(
@@ -1900,7 +1904,7 @@ def build_message_type(
# Validate that fixed_array_with_length_define is only used in encode-only messages
if (
needs_decode
and field.label == 3
and field.label == FieldDescriptorProto.LABEL_REPEATED
and get_field_opt(field, pb.fixed_array_with_length_define) is not None
):
raise ValueError(
@@ -1910,6 +1914,14 @@ def build_message_type(
f"since we cannot trust or control the number of items received from clients."
)
# Collect fixed_vector repeated fields for custom decode generation
if (
needs_decode
and field.label == FieldDescriptorProto.LABEL_REPEATED
and get_field_opt(field, pb.fixed_vector, False)
):
fixed_vector_fields.append((field.name, field.number))
ti = create_field_type_info(field, needs_decode, needs_encode)
# Skip field declarations for fields that are in the base class
@@ -2018,6 +2030,22 @@ def build_message_type(
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
protected_content.insert(0, prot)
# Generate custom decode() override for messages with FixedVector fields
if fixed_vector_fields:
# Generate the decode() implementation in cpp
o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n"
# Count and init each FixedVector field
for field_name, field_number in fixed_vector_fields:
o += f" uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n"
o += f" this->{field_name}.init(count_{field_name});\n"
# Call parent decode to populate the fields
o += " ProtoDecodableMessage::decode(buffer, length);\n"
o += "}\n"
cpp += o
# Generate the decode() declaration in header (public method)
prot = "void decode(const uint8_t *buffer, size_t length) override;"
public_content.append(prot)
# Only generate encode method if this message needs encoding and has fields
if needs_encode and encode:
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"