mirror of
https://github.com/esphome/esphome.git
synced 2025-09-01 10:52:19 +01:00
Reduce API proto vtable overhead by splitting decode functionality (#9541)
This commit is contained in:
@@ -877,14 +877,15 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
|
||||
def build_type_usage_map(
|
||||
file_desc: descriptor.FileDescriptorProto,
|
||||
) -> tuple[dict[str, str | None], dict[str, str | None]]:
|
||||
) -> tuple[dict[str, str | None], dict[str, str | None], dict[str, int]]:
|
||||
"""Build mappings for both enums and messages to their ifdefs based on usage.
|
||||
|
||||
Returns:
|
||||
tuple: (enum_ifdef_map, message_ifdef_map)
|
||||
tuple: (enum_ifdef_map, message_ifdef_map, message_source_map)
|
||||
"""
|
||||
enum_ifdef_map: dict[str, str | None] = {}
|
||||
message_ifdef_map: dict[str, str | None] = {}
|
||||
message_source_map: dict[str, int] = {}
|
||||
|
||||
# Build maps of which types are used by which messages
|
||||
enum_usage: dict[
|
||||
@@ -971,7 +972,44 @@ def build_type_usage_map(
|
||||
message_ifdef_map[message.name] = parent_ifdefs.pop()
|
||||
changed = True
|
||||
|
||||
return enum_ifdef_map, message_ifdef_map
|
||||
# Build message source map
|
||||
# First pass: Get explicit sources for messages with source option or id
|
||||
for msg in file_desc.message_type:
|
||||
if msg.options.HasExtension(pb.source):
|
||||
# Explicit source option takes precedence
|
||||
message_source_map[msg.name] = get_opt(msg, pb.source, SOURCE_BOTH)
|
||||
elif msg.options.HasExtension(pb.id):
|
||||
# Service messages (with id) default to SOURCE_BOTH
|
||||
message_source_map[msg.name] = SOURCE_BOTH
|
||||
|
||||
# Second pass: Determine sources for embedded messages based on their usage
|
||||
for msg in file_desc.message_type:
|
||||
if msg.name in message_source_map:
|
||||
continue # Already has explicit source
|
||||
|
||||
if msg.name in message_usage:
|
||||
# Get sources from all parent messages that use this one
|
||||
parent_sources = {
|
||||
message_source_map[parent]
|
||||
for parent in message_usage[msg.name]
|
||||
if parent in message_source_map
|
||||
}
|
||||
|
||||
# Combine parent sources
|
||||
if not parent_sources:
|
||||
# No parent has explicit source, default to encode-only
|
||||
message_source_map[msg.name] = SOURCE_SERVER
|
||||
elif len(parent_sources) > 1:
|
||||
# Multiple different sources or SOURCE_BOTH present
|
||||
message_source_map[msg.name] = SOURCE_BOTH
|
||||
else:
|
||||
# Inherit single parent source
|
||||
message_source_map[msg.name] = parent_sources.pop()
|
||||
else:
|
||||
# Not used by any message and no explicit source - default to encode-only
|
||||
message_source_map[msg.name] = SOURCE_SERVER
|
||||
|
||||
return enum_ifdef_map, message_ifdef_map, message_source_map
|
||||
|
||||
|
||||
def build_enum_type(desc, enum_ifdef_map) -> tuple[str, str, str]:
|
||||
@@ -1023,7 +1061,8 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
|
||||
|
||||
def build_message_type(
|
||||
desc: descriptor.DescriptorProto,
|
||||
base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]] = None,
|
||||
base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]],
|
||||
message_source_map: dict[str, int],
|
||||
) -> tuple[str, str, str]:
|
||||
public_content: list[str] = []
|
||||
protected_content: list[str] = []
|
||||
@@ -1045,7 +1084,7 @@ def build_message_type(
|
||||
message_id: int | None = get_opt(desc, pb.id)
|
||||
|
||||
# Get source direction to determine if we need decode/encode methods
|
||||
source: int = get_opt(desc, pb.source, SOURCE_BOTH)
|
||||
source = message_source_map[desc.name]
|
||||
needs_decode = source in (SOURCE_BOTH, SOURCE_CLIENT)
|
||||
needs_encode = source in (SOURCE_BOTH, SOURCE_SERVER)
|
||||
|
||||
@@ -1250,7 +1289,9 @@ def build_message_type(
|
||||
if base_class:
|
||||
out = f"class {desc.name} : public {base_class} {{\n"
|
||||
else:
|
||||
out = f"class {desc.name} : public ProtoMessage {{\n"
|
||||
# Determine inheritance based on whether the message needs decoding
|
||||
base_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage"
|
||||
out = f"class {desc.name} : public {base_class} {{\n"
|
||||
out += " public:\n"
|
||||
out += indent("\n".join(public_content)) + "\n"
|
||||
out += "\n"
|
||||
@@ -1351,6 +1392,7 @@ def find_common_fields(
|
||||
def build_base_class(
|
||||
base_class_name: str,
|
||||
common_fields: list[descriptor.FieldDescriptorProto],
|
||||
messages: list[descriptor.DescriptorProto],
|
||||
) -> tuple[str, str, str]:
|
||||
"""Build the base class definition and implementation."""
|
||||
public_content = []
|
||||
@@ -1365,8 +1407,15 @@ def build_base_class(
|
||||
protected_content.extend(ti.protected_content)
|
||||
public_content.extend(ti.public_content)
|
||||
|
||||
# Determine if any message using this base class needs decoding
|
||||
needs_decode = any(
|
||||
get_opt(msg, pb.source, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
|
||||
for msg in messages
|
||||
)
|
||||
|
||||
# Build header
|
||||
out = f"class {base_class_name} : public ProtoMessage {{\n"
|
||||
parent_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage"
|
||||
out = f"class {base_class_name} : public {parent_class} {{\n"
|
||||
out += " public:\n"
|
||||
|
||||
# Add destructor with override
|
||||
@@ -1404,7 +1453,9 @@ def generate_base_classes(
|
||||
|
||||
if common_fields:
|
||||
# Generate base class
|
||||
header, cpp, dump_cpp = build_base_class(base_class_name, common_fields)
|
||||
header, cpp, dump_cpp = build_base_class(
|
||||
base_class_name, common_fields, messages
|
||||
)
|
||||
all_headers.append(header)
|
||||
all_cpp.append(cpp)
|
||||
all_dump_cpp.append(dump_cpp)
|
||||
@@ -1516,7 +1567,7 @@ namespace api {
|
||||
content += "namespace enums {\n\n"
|
||||
|
||||
# Build dynamic ifdef mappings for both enums and messages
|
||||
enum_ifdef_map, message_ifdef_map = build_type_usage_map(file)
|
||||
enum_ifdef_map, message_ifdef_map, message_source_map = build_type_usage_map(file)
|
||||
|
||||
# Simple grouping of enums by ifdef
|
||||
current_ifdef = None
|
||||
@@ -1570,7 +1621,7 @@ namespace api {
|
||||
current_ifdef = None
|
||||
|
||||
for m in mt:
|
||||
s, c, dc = build_message_type(m, base_class_fields)
|
||||
s, c, dc = build_message_type(m, base_class_fields, message_source_map)
|
||||
msg_ifdef = message_ifdef_map.get(m.name)
|
||||
|
||||
# Handle ifdef changes
|
||||
|
Reference in New Issue
Block a user