mirror of
https://github.com/esphome/esphome.git
synced 2025-09-01 10:52:19 +01:00
Add common base classes for entity protobuf messages to reduce duplicate code (#9090)
This commit is contained in:
@@ -848,7 +848,10 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
|
||||
return total_size
|
||||
|
||||
|
||||
def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
def build_message_type(
|
||||
desc: descriptor.DescriptorProto,
|
||||
base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]] = None,
|
||||
) -> tuple[str, str]:
|
||||
public_content: list[str] = []
|
||||
protected_content: list[str] = []
|
||||
decode_varint: list[str] = []
|
||||
@@ -859,6 +862,12 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
dump: list[str] = []
|
||||
size_calc: list[str] = []
|
||||
|
||||
# Check if this message has a base class
|
||||
base_class = get_base_class(desc)
|
||||
common_field_names = set()
|
||||
if base_class and base_class_fields and base_class in base_class_fields:
|
||||
common_field_names = {f.name for f in base_class_fields[base_class]}
|
||||
|
||||
# Get message ID if it's a service message
|
||||
message_id: int | None = get_opt(desc, pb.id)
|
||||
|
||||
@@ -886,8 +895,14 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
ti = RepeatedTypeInfo(field)
|
||||
else:
|
||||
ti = TYPE_INFO[field.type](field)
|
||||
protected_content.extend(ti.protected_content)
|
||||
public_content.extend(ti.public_content)
|
||||
|
||||
# Skip field declarations for fields that are in the base class
|
||||
# but include their encode/decode logic
|
||||
if field.name not in common_field_names:
|
||||
protected_content.extend(ti.protected_content)
|
||||
public_content.extend(ti.public_content)
|
||||
|
||||
# Always include encode/decode logic for all fields
|
||||
encode.append(ti.encode_content)
|
||||
size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}"))
|
||||
|
||||
@@ -1001,7 +1016,10 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
prot += "#endif\n"
|
||||
public_content.append(prot)
|
||||
|
||||
out = f"class {desc.name} : public ProtoMessage {{\n"
|
||||
if base_class:
|
||||
out = f"class {desc.name} : public {base_class} {{\n"
|
||||
else:
|
||||
out = f"class {desc.name} : public ProtoMessage {{\n"
|
||||
out += " public:\n"
|
||||
out += indent("\n".join(public_content)) + "\n"
|
||||
out += "\n"
|
||||
@@ -1033,6 +1051,132 @@ def get_opt(
|
||||
return desc.options.Extensions[opt]
|
||||
|
||||
|
||||
def get_base_class(desc: descriptor.DescriptorProto) -> str | None:
|
||||
"""Get the base_class option from a message descriptor."""
|
||||
if not desc.options.HasExtension(pb.base_class):
|
||||
return None
|
||||
return desc.options.Extensions[pb.base_class]
|
||||
|
||||
|
||||
def collect_messages_by_base_class(
|
||||
messages: list[descriptor.DescriptorProto],
|
||||
) -> dict[str, list[descriptor.DescriptorProto]]:
|
||||
"""Group messages by their base_class option."""
|
||||
base_class_groups = {}
|
||||
|
||||
for msg in messages:
|
||||
base_class = get_base_class(msg)
|
||||
if base_class:
|
||||
if base_class not in base_class_groups:
|
||||
base_class_groups[base_class] = []
|
||||
base_class_groups[base_class].append(msg)
|
||||
|
||||
return base_class_groups
|
||||
|
||||
|
||||
def find_common_fields(
|
||||
messages: list[descriptor.DescriptorProto],
|
||||
) -> list[descriptor.FieldDescriptorProto]:
|
||||
"""Find fields that are common to all messages in the list."""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Start with fields from the first message
|
||||
first_msg_fields = {field.name: field for field in messages[0].field}
|
||||
common_fields = []
|
||||
|
||||
# Check each field to see if it exists in all messages with same type
|
||||
# Field numbers can vary between messages - derived classes handle the mapping
|
||||
for field_name, field in first_msg_fields.items():
|
||||
is_common = True
|
||||
|
||||
for msg in messages[1:]:
|
||||
found = False
|
||||
for other_field in msg.field:
|
||||
if (
|
||||
other_field.name == field_name
|
||||
and other_field.type == field.type
|
||||
and other_field.label == field.label
|
||||
):
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
is_common = False
|
||||
break
|
||||
|
||||
if is_common:
|
||||
common_fields.append(field)
|
||||
|
||||
# Sort by field number to maintain order
|
||||
common_fields.sort(key=lambda f: f.number)
|
||||
return common_fields
|
||||
|
||||
|
||||
def build_base_class(
|
||||
base_class_name: str,
|
||||
common_fields: list[descriptor.FieldDescriptorProto],
|
||||
) -> tuple[str, str]:
|
||||
"""Build the base class definition and implementation."""
|
||||
public_content = []
|
||||
protected_content = []
|
||||
|
||||
# For base classes, we only declare the fields but don't handle encode/decode
|
||||
# The derived classes will handle encoding/decoding with their specific field numbers
|
||||
for field in common_fields:
|
||||
if field.label == 3: # repeated
|
||||
ti = RepeatedTypeInfo(field)
|
||||
else:
|
||||
ti = TYPE_INFO[field.type](field)
|
||||
|
||||
# Only add field declarations, not encode/decode logic
|
||||
protected_content.extend(ti.protected_content)
|
||||
public_content.extend(ti.public_content)
|
||||
|
||||
# Build header
|
||||
out = f"class {base_class_name} : public ProtoMessage {{\n"
|
||||
out += " public:\n"
|
||||
|
||||
# Add destructor with override
|
||||
public_content.insert(0, f"~{base_class_name}() override = default;")
|
||||
|
||||
# Base classes don't implement encode/decode/calculate_size
|
||||
# Derived classes handle these with their specific field numbers
|
||||
cpp = ""
|
||||
|
||||
out += indent("\n".join(public_content)) + "\n"
|
||||
out += "\n"
|
||||
out += " protected:\n"
|
||||
out += indent("\n".join(protected_content))
|
||||
if protected_content:
|
||||
out += "\n"
|
||||
out += "};\n"
|
||||
|
||||
# No implementation needed for base classes
|
||||
|
||||
return out, cpp
|
||||
|
||||
|
||||
def generate_base_classes(
|
||||
base_class_groups: dict[str, list[descriptor.DescriptorProto]],
|
||||
) -> tuple[str, str]:
|
||||
"""Generate all base classes."""
|
||||
all_headers = []
|
||||
all_cpp = []
|
||||
|
||||
for base_class_name, messages in base_class_groups.items():
|
||||
# Find common fields
|
||||
common_fields = find_common_fields(messages)
|
||||
|
||||
if common_fields:
|
||||
# Generate base class
|
||||
header, cpp = build_base_class(base_class_name, common_fields)
|
||||
all_headers.append(header)
|
||||
all_cpp.append(cpp)
|
||||
|
||||
return "\n".join(all_headers), "\n".join(all_cpp)
|
||||
|
||||
|
||||
def build_service_message_type(
|
||||
mt: descriptor.DescriptorProto,
|
||||
) -> tuple[str, str] | None:
|
||||
@@ -1134,8 +1278,25 @@ def main() -> None:
|
||||
|
||||
mt = file.message_type
|
||||
|
||||
# Collect messages by base class
|
||||
base_class_groups = collect_messages_by_base_class(mt)
|
||||
|
||||
# Find common fields for each base class
|
||||
base_class_fields = {}
|
||||
for base_class_name, messages in base_class_groups.items():
|
||||
common_fields = find_common_fields(messages)
|
||||
if common_fields:
|
||||
base_class_fields[base_class_name] = common_fields
|
||||
|
||||
# Generate base classes
|
||||
if base_class_fields:
|
||||
base_headers, base_cpp = generate_base_classes(base_class_groups)
|
||||
content += base_headers
|
||||
cpp += base_cpp
|
||||
|
||||
# Generate message types with base class information
|
||||
for m in mt:
|
||||
s, c = build_message_type(m)
|
||||
s, c = build_message_type(m, base_class_fields)
|
||||
content += s
|
||||
cpp += c
|
||||
|
||||
|
Reference in New Issue
Block a user