mirror of
https://github.com/esphome/esphome.git
synced 2025-10-24 04:33:49 +01:00
[api] Use FixedVector for ExecuteServiceRequest/Argument arrays to eliminate reallocations
This commit is contained in:
@@ -876,10 +876,10 @@ message ExecuteServiceArgument {
|
|||||||
string string_ = 4;
|
string string_ = 4;
|
||||||
// ESPHome 1.14 (api v1.3) make int a signed value
|
// ESPHome 1.14 (api v1.3) make int a signed value
|
||||||
sint32 int_ = 5;
|
sint32 int_ = 5;
|
||||||
repeated bool bool_array = 6 [packed=false];
|
repeated bool bool_array = 6 [packed=false, (fixed_vector) = true];
|
||||||
repeated sint32 int_array = 7 [packed=false];
|
repeated sint32 int_array = 7 [packed=false, (fixed_vector) = true];
|
||||||
repeated float float_array = 8 [packed=false];
|
repeated float float_array = 8 [packed=false, (fixed_vector) = true];
|
||||||
repeated string string_array = 9;
|
repeated string string_array = 9 [(fixed_vector) = true];
|
||||||
}
|
}
|
||||||
message ExecuteServiceRequest {
|
message ExecuteServiceRequest {
|
||||||
option (id) = 42;
|
option (id) = 42;
|
||||||
@@ -888,7 +888,7 @@ message ExecuteServiceRequest {
|
|||||||
option (ifdef) = "USE_API_SERVICES";
|
option (ifdef) = "USE_API_SERVICES";
|
||||||
|
|
||||||
fixed32 key = 1;
|
fixed32 key = 1;
|
||||||
repeated ExecuteServiceArgument args = 2;
|
repeated ExecuteServiceArgument args = 2 [(fixed_vector) = true];
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== CAMERA ====================
|
// ==================== CAMERA ====================
|
||||||
|
@@ -1064,6 +1064,17 @@ bool ExecuteServiceArgument::decode_32bit(uint32_t field_id, Proto32Bit value) {
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
void ExecuteServiceArgument::decode(const uint8_t *buffer, size_t length) {
|
||||||
|
uint32_t count_bool_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 6);
|
||||||
|
this->bool_array.init(count_bool_array);
|
||||||
|
uint32_t count_int_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 7);
|
||||||
|
this->int_array.init(count_int_array);
|
||||||
|
uint32_t count_float_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 8);
|
||||||
|
this->float_array.init(count_float_array);
|
||||||
|
uint32_t count_string_array = ProtoDecodableMessage::count_repeated_field(buffer, length, 9);
|
||||||
|
this->string_array.init(count_string_array);
|
||||||
|
ProtoDecodableMessage::decode(buffer, length);
|
||||||
|
}
|
||||||
bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
|
bool ExecuteServiceRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
|
||||||
switch (field_id) {
|
switch (field_id) {
|
||||||
case 2:
|
case 2:
|
||||||
@@ -1085,6 +1096,11 @@ bool ExecuteServiceRequest::decode_32bit(uint32_t field_id, Proto32Bit value) {
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
void ExecuteServiceRequest::decode(const uint8_t *buffer, size_t length) {
|
||||||
|
uint32_t count_args = ProtoDecodableMessage::count_repeated_field(buffer, length, 2);
|
||||||
|
this->args.init(count_args);
|
||||||
|
ProtoDecodableMessage::decode(buffer, length);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifdef USE_CAMERA
|
#ifdef USE_CAMERA
|
||||||
void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const {
|
void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const {
|
||||||
|
@@ -1279,10 +1279,11 @@ class ExecuteServiceArgument final : public ProtoDecodableMessage {
|
|||||||
float float_{0.0f};
|
float float_{0.0f};
|
||||||
std::string string_{};
|
std::string string_{};
|
||||||
int32_t int_{0};
|
int32_t int_{0};
|
||||||
std::vector<bool> bool_array{};
|
FixedVector<bool> bool_array{};
|
||||||
std::vector<int32_t> int_array{};
|
FixedVector<int32_t> int_array{};
|
||||||
std::vector<float> float_array{};
|
FixedVector<float> float_array{};
|
||||||
std::vector<std::string> string_array{};
|
FixedVector<std::string> string_array{};
|
||||||
|
void decode(const uint8_t *buffer, size_t length) override;
|
||||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||||
void dump_to(std::string &out) const override;
|
void dump_to(std::string &out) const override;
|
||||||
#endif
|
#endif
|
||||||
@@ -1300,7 +1301,8 @@ class ExecuteServiceRequest final : public ProtoDecodableMessage {
|
|||||||
const char *message_name() const override { return "execute_service_request"; }
|
const char *message_name() const override { return "execute_service_request"; }
|
||||||
#endif
|
#endif
|
||||||
uint32_t key{0};
|
uint32_t key{0};
|
||||||
std::vector<ExecuteServiceArgument> args{};
|
FixedVector<ExecuteServiceArgument> args{};
|
||||||
|
void decode(const uint8_t *buffer, size_t length) override;
|
||||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||||
void dump_to(std::string &out) const override;
|
void dump_to(std::string &out) const override;
|
||||||
#endif
|
#endif
|
||||||
|
@@ -7,6 +7,69 @@ namespace esphome::api {
|
|||||||
|
|
||||||
static const char *const TAG = "api.proto";
|
static const char *const TAG = "api.proto";
|
||||||
|
|
||||||
|
uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) {
|
||||||
|
uint32_t count = 0;
|
||||||
|
const uint8_t *ptr = buffer;
|
||||||
|
const uint8_t *end = buffer + length;
|
||||||
|
|
||||||
|
while (ptr < end) {
|
||||||
|
uint32_t consumed;
|
||||||
|
|
||||||
|
// Parse field header (tag)
|
||||||
|
auto res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||||
|
if (!res.has_value()) {
|
||||||
|
break; // Invalid data, stop counting
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t tag = res->as_uint32();
|
||||||
|
uint32_t field_type = tag & 0b111;
|
||||||
|
uint32_t field_id = tag >> 3;
|
||||||
|
ptr += consumed;
|
||||||
|
|
||||||
|
// Count if this is the target field
|
||||||
|
if (field_id == target_field_id) {
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip field data based on wire type
|
||||||
|
switch (field_type) {
|
||||||
|
case 0: { // VarInt - parse and skip
|
||||||
|
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||||
|
if (!res.has_value()) {
|
||||||
|
return count; // Invalid data, return what we have
|
||||||
|
}
|
||||||
|
ptr += consumed;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { // Length-delimited - parse length and skip data
|
||||||
|
res = ProtoVarInt::parse(ptr, end - ptr, &consumed);
|
||||||
|
if (!res.has_value()) {
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
uint32_t field_length = res->as_uint32();
|
||||||
|
ptr += consumed;
|
||||||
|
if (ptr + field_length > end) {
|
||||||
|
return count; // Out of bounds
|
||||||
|
}
|
||||||
|
ptr += field_length;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: { // 32-bit - skip 4 bytes
|
||||||
|
if (ptr + 4 > end) {
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
ptr += 4;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Unknown wire type, can't continue
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
||||||
const uint8_t *ptr = buffer;
|
const uint8_t *ptr = buffer;
|
||||||
const uint8_t *end = buffer + length;
|
const uint8_t *end = buffer + length;
|
||||||
|
@@ -354,7 +354,18 @@ class ProtoMessage {
|
|||||||
// Base class for messages that support decoding
|
// Base class for messages that support decoding
|
||||||
class ProtoDecodableMessage : public ProtoMessage {
|
class ProtoDecodableMessage : public ProtoMessage {
|
||||||
public:
|
public:
|
||||||
void decode(const uint8_t *buffer, size_t length);
|
virtual void decode(const uint8_t *buffer, size_t length);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Count occurrences of a repeated field in a protobuf buffer.
|
||||||
|
* This is a lightweight scan that only parses tags and skips field data.
|
||||||
|
*
|
||||||
|
* @param buffer Pointer to the protobuf buffer
|
||||||
|
* @param length Length of the buffer in bytes
|
||||||
|
* @param target_field_id The field ID to count
|
||||||
|
* @return Number of times the field appears in the buffer
|
||||||
|
*/
|
||||||
|
static uint32_t count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; }
|
virtual bool decode_varint(uint32_t field_id, ProtoVarInt value) { return false; }
|
||||||
|
@@ -12,16 +12,16 @@ template<> int32_t get_execute_arg_value<int32_t>(const ExecuteServiceArgument &
|
|||||||
template<> float get_execute_arg_value<float>(const ExecuteServiceArgument &arg) { return arg.float_; }
|
template<> float get_execute_arg_value<float>(const ExecuteServiceArgument &arg) { return arg.float_; }
|
||||||
template<> std::string get_execute_arg_value<std::string>(const ExecuteServiceArgument &arg) { return arg.string_; }
|
template<> std::string get_execute_arg_value<std::string>(const ExecuteServiceArgument &arg) { return arg.string_; }
|
||||||
template<> std::vector<bool> get_execute_arg_value<std::vector<bool>>(const ExecuteServiceArgument &arg) {
|
template<> std::vector<bool> get_execute_arg_value<std::vector<bool>>(const ExecuteServiceArgument &arg) {
|
||||||
return arg.bool_array;
|
return std::vector<bool>(arg.bool_array.begin(), arg.bool_array.end());
|
||||||
}
|
}
|
||||||
template<> std::vector<int32_t> get_execute_arg_value<std::vector<int32_t>>(const ExecuteServiceArgument &arg) {
|
template<> std::vector<int32_t> get_execute_arg_value<std::vector<int32_t>>(const ExecuteServiceArgument &arg) {
|
||||||
return arg.int_array;
|
return std::vector<int32_t>(arg.int_array.begin(), arg.int_array.end());
|
||||||
}
|
}
|
||||||
template<> std::vector<float> get_execute_arg_value<std::vector<float>>(const ExecuteServiceArgument &arg) {
|
template<> std::vector<float> get_execute_arg_value<std::vector<float>>(const ExecuteServiceArgument &arg) {
|
||||||
return arg.float_array;
|
return std::vector<float>(arg.float_array.begin(), arg.float_array.end());
|
||||||
}
|
}
|
||||||
template<> std::vector<std::string> get_execute_arg_value<std::vector<std::string>>(const ExecuteServiceArgument &arg) {
|
template<> std::vector<std::string> get_execute_arg_value<std::vector<std::string>>(const ExecuteServiceArgument &arg) {
|
||||||
return arg.string_array;
|
return std::vector<std::string>(arg.string_array.begin(), arg.string_array.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> enums::ServiceArgType to_service_arg_type<bool>() { return enums::SERVICE_ARG_TYPE_BOOL; }
|
template<> enums::ServiceArgType to_service_arg_type<bool>() { return enums::SERVICE_ARG_TYPE_BOOL; }
|
||||||
|
@@ -55,7 +55,7 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void execute(Ts... x) = 0;
|
virtual void execute(Ts... x) = 0;
|
||||||
template<int... S> void execute_(const std::vector<ExecuteServiceArgument> &args, seq<S...> type) {
|
template<typename ArgsContainer, int... S> void execute_(const ArgsContainer &args, seq<S...> type) {
|
||||||
this->execute((get_execute_arg_value<Ts>(args[S]))...);
|
this->execute((get_execute_arg_value<Ts>(args[S]))...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1879,6 +1879,9 @@ def build_message_type(
|
|||||||
)
|
)
|
||||||
public_content.append("#endif")
|
public_content.append("#endif")
|
||||||
|
|
||||||
|
# Collect fixed_vector fields for custom decode generation
|
||||||
|
fixed_vector_fields = []
|
||||||
|
|
||||||
for field in desc.field:
|
for field in desc.field:
|
||||||
# Skip deprecated fields completely
|
# Skip deprecated fields completely
|
||||||
if field.options.deprecated:
|
if field.options.deprecated:
|
||||||
@@ -1910,6 +1913,14 @@ def build_message_type(
|
|||||||
f"since we cannot trust or control the number of items received from clients."
|
f"since we cannot trust or control the number of items received from clients."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Collect fixed_vector repeated fields for custom decode generation
|
||||||
|
if (
|
||||||
|
needs_decode
|
||||||
|
and field.label == 3
|
||||||
|
and get_field_opt(field, pb.fixed_vector, False)
|
||||||
|
):
|
||||||
|
fixed_vector_fields.append((field.name, field.number))
|
||||||
|
|
||||||
ti = create_field_type_info(field, needs_decode, needs_encode)
|
ti = create_field_type_info(field, needs_decode, needs_encode)
|
||||||
|
|
||||||
# Skip field declarations for fields that are in the base class
|
# Skip field declarations for fields that are in the base class
|
||||||
@@ -2018,6 +2029,22 @@ def build_message_type(
|
|||||||
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
|
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
|
||||||
protected_content.insert(0, prot)
|
protected_content.insert(0, prot)
|
||||||
|
|
||||||
|
# Generate custom decode() override for messages with FixedVector fields
|
||||||
|
if fixed_vector_fields:
|
||||||
|
# Generate the decode() implementation in cpp
|
||||||
|
o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n"
|
||||||
|
# Count and init each FixedVector field
|
||||||
|
for field_name, field_number in fixed_vector_fields:
|
||||||
|
o += f" uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n"
|
||||||
|
o += f" this->{field_name}.init(count_{field_name});\n"
|
||||||
|
# Call parent decode to populate the fields
|
||||||
|
o += " ProtoDecodableMessage::decode(buffer, length);\n"
|
||||||
|
o += "}\n"
|
||||||
|
cpp += o
|
||||||
|
# Generate the decode() declaration in header (public method)
|
||||||
|
prot = "void decode(const uint8_t *buffer, size_t length) override;"
|
||||||
|
public_content.append(prot)
|
||||||
|
|
||||||
# Only generate encode method if this message needs encoding and has fields
|
# Only generate encode method if this message needs encoding and has fields
|
||||||
if needs_encode and encode:
|
if needs_encode and encode:
|
||||||
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
|
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
|
||||||
|
Reference in New Issue
Block a user