mirror of
https://github.com/esphome/esphome.git
synced 2025-09-01 10:52:19 +01:00
Reserve buffer space to avoid frequent realloc when generating protobuf messages (#8707)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import IntEnum
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
@@ -10,11 +11,29 @@ import sys
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
# Generate with
|
||||
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
||||
import aioesphomeapi.api_options_pb2 as pb
|
||||
import google.protobuf.descriptor_pb2 as descriptor
|
||||
|
||||
|
||||
class WireType(IntEnum):
    """Numeric wire-type identifiers carried in protobuf field tags.

    The values follow the Protocol Buffers encoding guide:
    https://protobuf.dev/programming-guides/encoding/#structure
    """

    # Variable-length integers: int32, int64, uint32, uint64, sint32, sint64, bool, enum
    VARINT = 0
    # Eight-byte values: fixed64, sfixed64, double
    FIXED64 = 1
    # Length-prefixed payloads: string, bytes, embedded messages, packed repeated fields
    LENGTH_DELIMITED = 2
    # Group delimiters (deprecated)
    START_GROUP = 3
    END_GROUP = 4
    # Four-byte values: fixed32, sfixed32, float
    FIXED32 = 5
|
||||
|
||||
|
||||
# Generate with
|
||||
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
||||
|
||||
|
||||
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
|
||||
|
||||
It's pretty crappy spaghetti code, but it works.
|
||||
@@ -35,7 +54,7 @@ will be generated, they still need to be formatted
|
||||
|
||||
|
||||
FILE_HEADER = """// This file was automatically generated with a tool.
|
||||
// See scripts/api_protobuf/api_protobuf.py
|
||||
// See script/api_protobuf/api_protobuf.py
|
||||
"""
|
||||
|
||||
|
||||
@@ -63,6 +82,11 @@ def camel_to_snake(name: str) -> str:
|
||||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
||||
|
||||
|
||||
def force_str(force: bool) -> str:
    """Render a boolean force flag as the lowercase text used in generated C++ code."""
    rendered = str(force)
    return rendered.lower()
|
||||
|
||||
|
||||
class TypeInfo(ABC):
|
||||
"""Base class for all type information."""
|
||||
|
||||
@@ -99,6 +123,11 @@ class TypeInfo(ABC):
|
||||
"""Check if the field is repeated."""
|
||||
return self._field.label == 3
|
||||
|
||||
@property
def wire_type(self) -> WireType:
    """Wire type of the field; this base implementation always raises."""
    raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def cpp_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
@@ -200,6 +229,35 @@ class TypeInfo(ABC):
|
||||
def dump(self, name: str) -> str:
|
||||
"""Dump the value to the output."""
|
||||
|
||||
def calculate_field_id_size(self) -> int:
    """Return how many bytes the field's tag occupies when written as a varint.

    The tag is the field number shifted left by three bits, combined with
    the low three bits of the wire type.

    Returns:
        The byte count of the encoded tag, between 1 and 5.
    """
    tag = (self.number << 3) | (self.wire_type & 0b111)

    # Each varint byte carries 7 payload bits, so n bytes hold tags below 2**(7 * n).
    for byte_count in range(1, 5):
        if tag < 1 << (7 * byte_count):
            return byte_count
    # Anything larger is reported as 5 bytes (the maximum for a uint32_t tag).
    return 5
|
||||
|
||||
@abstractmethod
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Produce the size calculation needed for encoding this field.

    Args:
        name: The name of the field
        force: Whether the field is counted even when it holds its default value
    """
|
||||
|
||||
|
||||
TYPE_INFO: dict[int, TypeInfo] = {}
|
||||
|
||||
@@ -221,12 +279,18 @@ class DoubleType(TypeInfo):
|
||||
default_value = "0.0"
|
||||
decode_64bit = "value.as_double()"
|
||||
encode_func = "encode_double"
|
||||
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints ``name`` with ``%g`` into ``buffer`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%g", {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this 8-byte double field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_fixed_field<8>(total_size, {tag_bytes}, {name} != 0.0, {forced});"
|
||||
|
||||
|
||||
@register_type(2)
|
||||
class FloatType(TypeInfo):
|
||||
@@ -234,12 +298,18 @@ class FloatType(TypeInfo):
|
||||
default_value = "0.0f"
|
||||
decode_32bit = "value.as_float()"
|
||||
encode_func = "encode_float"
|
||||
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints the float ``name`` with ``%g`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%g", {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this 4-byte float field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_fixed_field<4>(total_size, {tag_bytes}, {name} != 0.0f, {forced});"
|
||||
|
||||
|
||||
@register_type(3)
|
||||
class Int64Type(TypeInfo):
|
||||
@@ -247,12 +317,18 @@ class Int64Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_varint = "value.as_int64()"
|
||||
encode_func = "encode_int64"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints ``name`` with ``%lld`` into ``buffer`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%lld", {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this int64 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_int64_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(4)
|
||||
class UInt64Type(TypeInfo):
|
||||
@@ -260,12 +336,18 @@ class UInt64Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_varint = "value.as_uint64()"
|
||||
encode_func = "encode_uint64"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints ``name`` with ``%llu`` into ``buffer`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%llu", {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this uint64 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_uint64_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(5)
|
||||
class Int32Type(TypeInfo):
|
||||
@@ -273,12 +355,18 @@ class Int32Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_varint = "value.as_int32()"
|
||||
encode_func = "encode_int32"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints ``name`` with the ``PRId32`` format and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%" PRId32, {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this int32 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_int32_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(6)
|
||||
class Fixed64Type(TypeInfo):
|
||||
@@ -286,12 +374,18 @@ class Fixed64Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_64bit = "value.as_fixed64()"
|
||||
encode_func = "encode_fixed64"
|
||||
wire_type = WireType.FIXED64 # Uses wire type 1
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints the fixed64 ``name`` with ``%llu`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%llu", {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this 8-byte fixed64 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_fixed_field<8>(total_size, {tag_bytes}, {name} != 0, {forced});"
|
||||
|
||||
|
||||
@register_type(7)
|
||||
class Fixed32Type(TypeInfo):
|
||||
@@ -299,12 +393,18 @@ class Fixed32Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_32bit = "value.as_fixed32()"
|
||||
encode_func = "encode_fixed32"
|
||||
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints ``name`` with the ``PRIu32`` format and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%" PRIu32, {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this 4-byte fixed32 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_fixed_field<4>(total_size, {tag_bytes}, {name} != 0, {forced});"
|
||||
|
||||
|
||||
@register_type(8)
|
||||
class BoolType(TypeInfo):
|
||||
@@ -312,11 +412,17 @@ class BoolType(TypeInfo):
|
||||
default_value = "false"
|
||||
decode_varint = "value.as_bool()"
|
||||
encode_func = "encode_bool"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that appends the ``YESNO`` text of the boolean ``name`` to ``out``."""
    return f"out.append(YESNO({name}));"
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this bool field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_bool_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(9)
|
||||
class StringType(TypeInfo):
|
||||
@@ -326,11 +432,17 @@ class StringType(TypeInfo):
|
||||
const_reference_type = "const std::string &"
|
||||
decode_length = "value.as_string()"
|
||||
encode_func = "encode_string"
|
||||
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that appends ``name`` to ``out`` wrapped in single quotes.

    Args:
        name: The C++ expression for the string value to dump.

    Returns:
        One C++ statement chaining three ``append`` calls on ``out``.
    """
    # Signature is annotated to match the other dump() implementations in this file.
    o = f'out.append("\'").append({name}).append("\'");'
    return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this string field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_string_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(11)
|
||||
class MessageType(TypeInfo):
|
||||
@@ -339,6 +451,7 @@ class MessageType(TypeInfo):
|
||||
return self._field.type_name[1:]
|
||||
|
||||
default_value = ""
|
||||
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
|
||||
|
||||
@property
|
||||
def reference_type(self) -> str:
|
||||
@@ -360,6 +473,11 @@ class MessageType(TypeInfo):
|
||||
o = f"{name}.dump_to(out);"
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this embedded message field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_message_object(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(12)
|
||||
class BytesType(TypeInfo):
|
||||
@@ -369,11 +487,17 @@ class BytesType(TypeInfo):
|
||||
const_reference_type = "const std::string &"
|
||||
decode_length = "value.as_string()"
|
||||
encode_func = "encode_string"
|
||||
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that appends ``name`` to ``out`` wrapped in single quotes."""
    quote_call = "append(\"'\")"
    return f"out.{quote_call}.append({name}).{quote_call};"
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this bytes field to ``total_size`` (sized like a string)."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_string_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(13)
|
||||
class UInt32Type(TypeInfo):
|
||||
@@ -381,12 +505,18 @@ class UInt32Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_varint = "value.as_uint32()"
|
||||
encode_func = "encode_uint32"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints the uint32 ``name`` with ``PRIu32`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%" PRIu32, {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this uint32 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_uint32_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(14)
|
||||
class EnumType(TypeInfo):
|
||||
@@ -399,6 +529,7 @@ class EnumType(TypeInfo):
|
||||
return f"value.as_enum<{self.cpp_type}>()"
|
||||
|
||||
default_value = ""
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
@property
|
||||
def encode_func(self) -> str:
|
||||
@@ -408,6 +539,11 @@ class EnumType(TypeInfo):
|
||||
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this enum field, cast to ``uint32_t``, to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_enum_field(total_size, {tag_bytes}, static_cast<uint32_t>({name}), {forced});"
|
||||
|
||||
|
||||
@register_type(15)
|
||||
class SFixed32Type(TypeInfo):
|
||||
@@ -415,12 +551,18 @@ class SFixed32Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_32bit = "value.as_sfixed32()"
|
||||
encode_func = "encode_sfixed32"
|
||||
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints the sfixed32 ``name`` with ``PRId32`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%" PRId32, {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this 4-byte sfixed32 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_fixed_field<4>(total_size, {tag_bytes}, {name} != 0, {forced});"
|
||||
|
||||
|
||||
@register_type(16)
|
||||
class SFixed64Type(TypeInfo):
|
||||
@@ -428,12 +570,18 @@ class SFixed64Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_64bit = "value.as_sfixed64()"
|
||||
encode_func = "encode_sfixed64"
|
||||
wire_type = WireType.FIXED64 # Uses wire type 1
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints the sfixed64 ``name`` with ``%lld`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%lld", {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this 8-byte sfixed64 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_fixed_field<8>(total_size, {tag_bytes}, {name} != 0, {forced});"
|
||||
|
||||
|
||||
@register_type(17)
|
||||
class SInt32Type(TypeInfo):
|
||||
@@ -441,12 +589,18 @@ class SInt32Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_varint = "value.as_sint32()"
|
||||
encode_func = "encode_sint32"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints the sint32 ``name`` with ``PRId32`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%" PRId32, {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this sint32 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_sint32_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
@register_type(18)
|
||||
class SInt64Type(TypeInfo):
|
||||
@@ -454,12 +608,18 @@ class SInt64Type(TypeInfo):
|
||||
default_value = "0"
|
||||
decode_varint = "value.as_sint64()"
|
||||
encode_func = "encode_sint64"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
|
||||
def dump(self, name: str) -> str:
    """Emit C++ that prints the sint64 ``name`` with ``%lld`` and appends it to ``out``."""
    statements = (
        f'sprintf(buffer, "%lld", {name});',
        "out.append(buffer);",
    )
    return "\n".join(statements)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return the C++ call adding this sint64 field to ``total_size``."""
    tag_bytes = self.calculate_field_id_size()
    forced = force_str(force)
    return f"ProtoSize::add_sint64_field(total_size, {tag_bytes}, {name}, {forced});"
|
||||
|
||||
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
@@ -478,6 +638,14 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
def const_reference_type(self) -> str:
|
||||
return f"const {self.cpp_type} &"
|
||||
|
||||
@property
def wire_type(self) -> WireType:
    """Wire type of this repeated field.

    A repeated field reports the wire type of its underlying element type.
    """
    element_info = self._ti
    return element_info.wire_type
|
||||
|
||||
@property
|
||||
def decode_varint_content(self) -> str:
|
||||
content = self._ti.decode_varint
|
||||
@@ -554,6 +722,22 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
def dump(self, _: str):
    """Do nothing: this method produces no dump code and returns None."""
    return None
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
    """Return C++ code that adds the size of every element of a repeated field.

    ``force`` is not consulted: elements are always sized with force=True,
    because the encode method always sets force=true for repeated fields.
    """
    if isinstance(self._ti, MessageType):
        # Repeated messages use a dedicated helper that handles iteration internally.
        tag_bytes = self._ti.calculate_field_id_size()
        return f"ProtoSize::add_repeated_message(total_size, {tag_bytes}, {name});"

    # Every other element type reuses the underlying type's calculation per item.
    reference = "" if self._ti_is_bool else "&"
    element_calc = self._ti.get_size_calculation("it", True)
    generated = [
        f"if (!{name}.empty()) {{",
        f"  for (const auto {reference}it : {name}) {{",
        f"    {element_calc}",
        "  }",
        "}",
    ]
    return "\n".join(generated)
|
||||
|
||||
|
||||
def build_enum_type(desc) -> tuple[str, str]:
|
||||
"""Builds the enum type."""
|
||||
@@ -587,6 +771,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
decode_64bit: list[str] = []
|
||||
encode: list[str] = []
|
||||
dump: list[str] = []
|
||||
size_calc: list[str] = []
|
||||
|
||||
for field in desc.field:
|
||||
if field.label == 3:
|
||||
@@ -596,6 +781,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
protected_content.extend(ti.protected_content)
|
||||
public_content.extend(ti.public_content)
|
||||
encode.append(ti.encode_content)
|
||||
size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}"))
|
||||
|
||||
if ti.decode_varint_content:
|
||||
decode_varint.append(ti.decode_varint_content)
|
||||
@@ -662,6 +848,25 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
prot = "void encode(ProtoWriteBuffer buffer) const override;"
|
||||
public_content.append(prot)
|
||||
|
||||
# Add calculate_size method
|
||||
o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{"
|
||||
|
||||
# Emit the per-field size calculations into the method body
|
||||
# Only emit a body when the message has fields to size
|
||||
if size_calc:
|
||||
# For a single field, just inline it for simplicity
|
||||
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
|
||||
o += f" {size_calc[0]} "
|
||||
else:
|
||||
# Otherwise, put each size calculation on its own line
|
||||
o += "\n"
|
||||
# Performance optimization: add all the size calculations
|
||||
o += indent("\n".join(size_calc)) + "\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "void calculate_size(uint32_t &total_size) const override;"
|
||||
public_content.append(prot)
|
||||
|
||||
o = f"void {desc.name}::dump_to(std::string &out) const {{"
|
||||
if dump:
|
||||
if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120:
|
||||
@@ -796,6 +1001,7 @@ def main() -> None:
|
||||
#pragma once
|
||||
|
||||
#include "proto.h"
|
||||
#include "api_pb2_size.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
@@ -805,6 +1011,7 @@ def main() -> None:
|
||||
cpp = FILE_HEADER
|
||||
cpp += """\
|
||||
#include "api_pb2.h"
|
||||
#include "api_pb2_size.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
#include <cinttypes>
|
||||
|
Reference in New Issue
Block a user