1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-01 10:52:19 +01:00

Reserve buffer space to avoid frequent realloc when generating protobuf messages (#8707)

This commit is contained in:
J. Nick Koston
2025-05-07 21:56:54 -05:00
committed by GitHub
parent d60e1f02c0
commit 54ead9a6b4
7 changed files with 1705 additions and 7 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import IntEnum
import os
from pathlib import Path
import re
@@ -10,11 +11,29 @@ import sys
from textwrap import dedent
from typing import Any
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
class WireType(IntEnum):
    """Wire types used by the Protocol Buffers binary encoding.

    The numeric values follow the Protocol Buffers encoding guide:
    https://protobuf.dev/programming-guides/encoding/#structure
    """

    # int32, int64, uint32, uint64, sint32, sint64, bool, enum
    VARINT = 0
    # fixed64, sfixed64, double
    FIXED64 = 1
    # string, bytes, embedded messages, packed repeated fields
    LENGTH_DELIMITED = 2
    # groups (deprecated)
    START_GROUP = 3
    # groups (deprecated)
    END_GROUP = 4
    # fixed32, sfixed32, float
    FIXED32 = 5
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
It's pretty crappy spaghetti code, but it works.
@@ -35,7 +54,7 @@ will be generated, they still need to be formatted
FILE_HEADER = """// This file was automatically generated with a tool.
// See scripts/api_protobuf/api_protobuf.py
// See script/api_protobuf/api_protobuf.py
"""
@@ -63,6 +82,11 @@ def camel_to_snake(name: str) -> str:
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def force_str(force: bool) -> str:
    """Render ``force`` as a C++ boolean literal.

    The value is interpreted by truthiness, so a non-bool argument such as
    ``1`` or ``None`` still yields a valid literal instead of text like
    ``1`` or ``none`` leaking into the generated C++ code.

    Args:
        force: The force flag to render.

    Returns:
        ``"true"`` when ``force`` is truthy, otherwise ``"false"``.
    """
    return "true" if force else "false"
class TypeInfo(ABC):
"""Base class for all type information."""
@@ -99,6 +123,11 @@ class TypeInfo(ABC):
"""Check if the field is repeated."""
return self._field.label == 3
@property
def wire_type(self) -> WireType:
    """Wire type used for this field.

    The base implementation has no answer and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()
@property
def cpp_type(self) -> str:
raise NotImplementedError
@@ -200,6 +229,35 @@ class TypeInfo(ABC):
def dump(self, name: str) -> str:
"""Dump the value to the output."""
def calculate_field_id_size(self) -> int:
    """Return how many bytes the varint-encoded field tag occupies.

    The tag is the field number shifted left by three bits, OR-ed with the
    low three bits of the wire type.

    Returns:
        The encoded tag length in bytes, between 1 and 5.
    """
    tag = (self.number << 3) | (self.wire_type & 0b111)
    # Each varint byte carries 7 payload bits, so n bytes hold tags below 2**(7*n).
    upper_bounds = (1 << 7, 1 << 14, 1 << 21, 1 << 28)
    for byte_count, bound in enumerate(upper_bounds, start=1):
        if tag < bound:
            return byte_count
    # Anything larger needs the fifth byte (maximum for uint32_t).
    return 5
@abstractmethod
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Produce the code that calculates the encoded size of this field.

    Args:
        name: The name of the field.
        force: Whether the field is encoded even when it holds its
            default value.

    Returns:
        The size calculation as source text.
    """
TYPE_INFO: dict[int, TypeInfo] = {}
@@ -221,12 +279,18 @@ class DoubleType(TypeInfo):
default_value = "0.0"
decode_64bit = "value.as_double()"
encode_func = "encode_double"
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``%g`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%g", {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_fixed_field<8>`` statement for this field.

    Args:
        name: C++ expression for the field value; compared against ``0.0``.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_fixed_field<8>(total_size, {tag_size}, "
        f"{name} != 0.0, {force_str(force)});"
    )
@register_type(2)
class FloatType(TypeInfo):
@@ -234,12 +298,18 @@ class FloatType(TypeInfo):
default_value = "0.0f"
decode_32bit = "value.as_float()"
encode_func = "encode_float"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``%g`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%g", {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_fixed_field<4>`` statement for this field.

    Args:
        name: C++ expression for the field value; compared against ``0.0f``.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_fixed_field<4>(total_size, {tag_size}, "
        f"{name} != 0.0f, {force_str(force)});"
    )
@register_type(3)
class Int64Type(TypeInfo):
@@ -247,12 +317,18 @@ class Int64Type(TypeInfo):
default_value = "0"
decode_varint = "value.as_int64()"
encode_func = "encode_int64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``%lld`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%lld", {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_int64_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_int64_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(4)
class UInt64Type(TypeInfo):
@@ -260,12 +336,18 @@ class UInt64Type(TypeInfo):
default_value = "0"
decode_varint = "value.as_uint64()"
encode_func = "encode_uint64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``%llu`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%llu", {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_uint64_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_uint64_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(5)
class Int32Type(TypeInfo):
@@ -273,12 +355,18 @@ class Int32Type(TypeInfo):
default_value = "0"
decode_varint = "value.as_int32()"
encode_func = "encode_int32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``PRId32`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%" PRId32, {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_int32_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_int32_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(6)
class Fixed64Type(TypeInfo):
@@ -286,12 +374,18 @@ class Fixed64Type(TypeInfo):
default_value = "0"
decode_64bit = "value.as_fixed64()"
encode_func = "encode_fixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``%llu`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%llu", {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_fixed_field<8>`` statement for this field.

    Args:
        name: C++ expression for the field value; compared against ``0``.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_fixed_field<8>(total_size, {tag_size}, "
        f"{name} != 0, {force_str(force)});"
    )
@register_type(7)
class Fixed32Type(TypeInfo):
@@ -299,12 +393,18 @@ class Fixed32Type(TypeInfo):
default_value = "0"
decode_32bit = "value.as_fixed32()"
encode_func = "encode_fixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``PRIu32`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%" PRIu32, {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_fixed_field<4>`` statement for this field.

    Args:
        name: C++ expression for the field value; compared against ``0``.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_fixed_field<4>(total_size, {tag_size}, "
        f"{name} != 0, {force_str(force)});"
    )
@register_type(8)
class BoolType(TypeInfo):
@@ -312,11 +412,17 @@ class BoolType(TypeInfo):
default_value = "false"
decode_varint = "value.as_bool()"
encode_func = "encode_bool"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
    """Return C++ code that appends ``YESNO(name)`` to ``out``."""
    return f"out.append(YESNO({name}));"
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_bool_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_bool_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(9)
class StringType(TypeInfo):
@@ -326,11 +432,17 @@ class StringType(TypeInfo):
const_reference_type = "const std::string &"
decode_length = "value.as_string()"
encode_func = "encode_string"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
def dump(self, name: str) -> str:
    """Return C++ code that appends ``name`` to ``out`` wrapped in single quotes.

    Args:
        name: C++ expression for the string value to dump.

    Returns:
        A single chained ``out.append(...)`` statement.
    """
    return f"out.append(\"'\").append({name}).append(\"'\");"
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_string_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_string_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(11)
class MessageType(TypeInfo):
@@ -339,6 +451,7 @@ class MessageType(TypeInfo):
return self._field.type_name[1:]
default_value = ""
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
@property
def reference_type(self) -> str:
@@ -360,6 +473,11 @@ class MessageType(TypeInfo):
o = f"{name}.dump_to(out);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_message_object`` statement for this field.

    Args:
        name: C++ expression for the nested message object.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_message_object(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(12)
class BytesType(TypeInfo):
@@ -369,11 +487,17 @@ class BytesType(TypeInfo):
const_reference_type = "const std::string &"
decode_length = "value.as_string()"
encode_func = "encode_string"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
def dump(self, name: str) -> str:
    """Return C++ code that appends ``name`` to ``out`` wrapped in single quotes."""
    return f"out.append(\"'\").append({name}).append(\"'\");"
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_string_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_string_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(13)
class UInt32Type(TypeInfo):
@@ -381,12 +505,18 @@ class UInt32Type(TypeInfo):
default_value = "0"
decode_varint = "value.as_uint32()"
encode_func = "encode_uint32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``PRIu32`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%" PRIu32, {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_uint32_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_uint32_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(14)
class EnumType(TypeInfo):
@@ -399,6 +529,7 @@ class EnumType(TypeInfo):
return f"value.as_enum<{self.cpp_type}>()"
default_value = ""
wire_type = WireType.VARINT # Uses wire type 0
@property
def encode_func(self) -> str:
@@ -408,6 +539,11 @@ class EnumType(TypeInfo):
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_enum_field`` statement for this field.

    Args:
        name: C++ expression for the enum value; emitted inside a
            ``static_cast<uint32_t>(...)``.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_enum_field(total_size, {tag_size}, "
        f"static_cast<uint32_t>({name}), {force_str(force)});"
    )
@register_type(15)
class SFixed32Type(TypeInfo):
@@ -415,12 +551,18 @@ class SFixed32Type(TypeInfo):
default_value = "0"
decode_32bit = "value.as_sfixed32()"
encode_func = "encode_sfixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``PRId32`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%" PRId32, {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_fixed_field<4>`` statement for this field.

    Args:
        name: C++ expression for the field value; compared against ``0``.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_fixed_field<4>(total_size, {tag_size}, "
        f"{name} != 0, {force_str(force)});"
    )
@register_type(16)
class SFixed64Type(TypeInfo):
@@ -428,12 +570,18 @@ class SFixed64Type(TypeInfo):
default_value = "0"
decode_64bit = "value.as_sfixed64()"
encode_func = "encode_sfixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``%lld`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%lld", {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_fixed_field<8>`` statement for this field.

    Args:
        name: C++ expression for the field value; compared against ``0``.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_fixed_field<8>(total_size, {tag_size}, "
        f"{name} != 0, {force_str(force)});"
    )
@register_type(17)
class SInt32Type(TypeInfo):
@@ -441,12 +589,18 @@ class SInt32Type(TypeInfo):
default_value = "0"
decode_varint = "value.as_sint32()"
encode_func = "encode_sint32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``PRId32`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%" PRId32, {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_sint32_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_sint32_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
@register_type(18)
class SInt64Type(TypeInfo):
@@ -454,12 +608,18 @@ class SInt64Type(TypeInfo):
default_value = "0"
decode_varint = "value.as_sint64()"
encode_func = "encode_sint64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
    """Return C++ code that formats ``name`` with ``%lld`` and appends it to ``out``."""
    statements = (f'sprintf(buffer, "%lld", {name});', "out.append(buffer);")
    return "\n".join(statements)
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ ``ProtoSize::add_sint64_field`` statement for this field.

    Args:
        name: C++ expression for the field value.
        force: Rendered as the helper's trailing ``true``/``false`` argument.
    """
    tag_size = self.calculate_field_id_size()
    return (
        f"ProtoSize::add_sint64_field(total_size, {tag_size}, "
        f"{name}, {force_str(force)});"
    )
class RepeatedTypeInfo(TypeInfo):
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
@@ -478,6 +638,14 @@ class RepeatedTypeInfo(TypeInfo):
def const_reference_type(self) -> str:
return f"const {self.cpp_type} &"
@property
def wire_type(self) -> WireType:
    """Wire type of this repeated field.

    A repeated field reports the wire type of the element type it wraps.
    """
    element_info = self._ti
    return element_info.wire_type
@property
def decode_varint_content(self) -> str:
content = self._ti.decode_varint
@@ -554,6 +722,22 @@ class RepeatedTypeInfo(TypeInfo):
def dump(self, _: str):
    """No-op: ignores its argument and returns ``None``."""
    return None
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ size calculation for every element of a repeated field.

    Args:
        name: C++ expression for the repeated container.
        force: Not used here; each element is always calculated with
            ``force=True`` because the encode method always sets force=true
            for repeated fields.

    Returns:
        A single ``ProtoSize::add_repeated_message`` statement when the
        element type is a message, otherwise a guarded C++ range-for loop
        wrapping the element type's own size calculation.
    """
    element_info = self._ti

    # Repeated messages use the dedicated helper, which iterates internally.
    if isinstance(element_info, MessageType):
        tag_size = element_info.calculate_field_id_size()
        return f"ProtoSize::add_repeated_message(total_size, {tag_size}, {name});"

    # Bool elements are iterated by value, everything else by const reference.
    ref = "" if self._ti_is_bool else "&"
    per_element = element_info.get_size_calculation("it", True)
    return "\n".join(
        (
            f"if (!{name}.empty()) {{",
            f" for (const auto {ref}it : {name}) {{",
            f" {per_element}",
            " }",
            "}",
        )
    )
def build_enum_type(desc) -> tuple[str, str]:
"""Builds the enum type."""
@@ -587,6 +771,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
decode_64bit: list[str] = []
encode: list[str] = []
dump: list[str] = []
size_calc: list[str] = []
for field in desc.field:
if field.label == 3:
@@ -596,6 +781,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
protected_content.extend(ti.protected_content)
public_content.extend(ti.public_content)
encode.append(ti.encode_content)
size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}"))
if ti.decode_varint_content:
decode_varint.append(ti.decode_varint_content)
@@ -662,6 +848,25 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
prot = "void encode(ProtoWriteBuffer buffer) const override;"
public_content.append(prot)
# Add calculate_size method
o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{"
# Only emit size calculations when the message has at least one field;
# a message without fields gets an empty method body.
if size_calc:
# For a single field, just inline it for simplicity
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
o += f" {size_calc[0]} "
else:
# Otherwise (several fields, or one too long to inline) start a new line
o += "\n"
# Performance optimization: add all the size calculations
o += indent("\n".join(size_calc)) + "\n"
o += "}\n"
cpp += o
prot = "void calculate_size(uint32_t &total_size) const override;"
public_content.append(prot)
o = f"void {desc.name}::dump_to(std::string &out) const {{"
if dump:
if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120:
@@ -796,6 +1001,7 @@ def main() -> None:
#pragma once
#include "proto.h"
#include "api_pb2_size.h"
namespace esphome {
namespace api {
@@ -805,6 +1011,7 @@ def main() -> None:
cpp = FILE_HEADER
cpp += """\
#include "api_pb2.h"
#include "api_pb2_size.h"
#include "esphome/core/log.h"
#include <cinttypes>