mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-30 22:53:59 +00:00 
			
		
		
		
	Add common base classes for entity protobuf messages to reduce duplicate code (#9090)
This commit is contained in:
		| @@ -848,7 +848,10 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int: | ||||
|     return total_size | ||||
|  | ||||
|  | ||||
| def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: | ||||
| def build_message_type( | ||||
|     desc: descriptor.DescriptorProto, | ||||
|     base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]] = None, | ||||
| ) -> tuple[str, str]: | ||||
|     public_content: list[str] = [] | ||||
|     protected_content: list[str] = [] | ||||
|     decode_varint: list[str] = [] | ||||
| @@ -859,6 +862,12 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: | ||||
|     dump: list[str] = [] | ||||
|     size_calc: list[str] = [] | ||||
|  | ||||
|     # Check if this message has a base class | ||||
|     base_class = get_base_class(desc) | ||||
|     common_field_names = set() | ||||
|     if base_class and base_class_fields and base_class in base_class_fields: | ||||
|         common_field_names = {f.name for f in base_class_fields[base_class]} | ||||
|  | ||||
|     # Get message ID if it's a service message | ||||
|     message_id: int | None = get_opt(desc, pb.id) | ||||
|  | ||||
| @@ -886,8 +895,14 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: | ||||
|             ti = RepeatedTypeInfo(field) | ||||
|         else: | ||||
|             ti = TYPE_INFO[field.type](field) | ||||
|         protected_content.extend(ti.protected_content) | ||||
|         public_content.extend(ti.public_content) | ||||
|  | ||||
|         # Skip field declarations for fields that are in the base class | ||||
|         # but include their encode/decode logic | ||||
|         if field.name not in common_field_names: | ||||
|             protected_content.extend(ti.protected_content) | ||||
|             public_content.extend(ti.public_content) | ||||
|  | ||||
|         # Always include encode/decode logic for all fields | ||||
|         encode.append(ti.encode_content) | ||||
|         size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}")) | ||||
|  | ||||
| @@ -1001,7 +1016,10 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: | ||||
|     prot += "#endif\n" | ||||
|     public_content.append(prot) | ||||
|  | ||||
|     out = f"class {desc.name} : public ProtoMessage {{\n" | ||||
|     if base_class: | ||||
|         out = f"class {desc.name} : public {base_class} {{\n" | ||||
|     else: | ||||
|         out = f"class {desc.name} : public ProtoMessage {{\n" | ||||
|     out += " public:\n" | ||||
|     out += indent("\n".join(public_content)) + "\n" | ||||
|     out += "\n" | ||||
| @@ -1033,6 +1051,132 @@ def get_opt( | ||||
|     return desc.options.Extensions[opt] | ||||
|  | ||||
|  | ||||
| def get_base_class(desc: descriptor.DescriptorProto) -> str | None: | ||||
|     """Get the base_class option from a message descriptor.""" | ||||
|     if not desc.options.HasExtension(pb.base_class): | ||||
|         return None | ||||
|     return desc.options.Extensions[pb.base_class] | ||||
|  | ||||
|  | ||||
| def collect_messages_by_base_class( | ||||
|     messages: list[descriptor.DescriptorProto], | ||||
| ) -> dict[str, list[descriptor.DescriptorProto]]: | ||||
|     """Group messages by their base_class option.""" | ||||
|     base_class_groups = {} | ||||
|  | ||||
|     for msg in messages: | ||||
|         base_class = get_base_class(msg) | ||||
|         if base_class: | ||||
|             if base_class not in base_class_groups: | ||||
|                 base_class_groups[base_class] = [] | ||||
|             base_class_groups[base_class].append(msg) | ||||
|  | ||||
|     return base_class_groups | ||||
|  | ||||
|  | ||||
| def find_common_fields( | ||||
|     messages: list[descriptor.DescriptorProto], | ||||
| ) -> list[descriptor.FieldDescriptorProto]: | ||||
|     """Find fields that are common to all messages in the list.""" | ||||
|     if not messages: | ||||
|         return [] | ||||
|  | ||||
|     # Start with fields from the first message | ||||
|     first_msg_fields = {field.name: field for field in messages[0].field} | ||||
|     common_fields = [] | ||||
|  | ||||
|     # Check each field to see if it exists in all messages with same type | ||||
|     # Field numbers can vary between messages - derived classes handle the mapping | ||||
|     for field_name, field in first_msg_fields.items(): | ||||
|         is_common = True | ||||
|  | ||||
|         for msg in messages[1:]: | ||||
|             found = False | ||||
|             for other_field in msg.field: | ||||
|                 if ( | ||||
|                     other_field.name == field_name | ||||
|                     and other_field.type == field.type | ||||
|                     and other_field.label == field.label | ||||
|                 ): | ||||
|                     found = True | ||||
|                     break | ||||
|  | ||||
|             if not found: | ||||
|                 is_common = False | ||||
|                 break | ||||
|  | ||||
|         if is_common: | ||||
|             common_fields.append(field) | ||||
|  | ||||
|     # Sort by field number to maintain order | ||||
|     common_fields.sort(key=lambda f: f.number) | ||||
|     return common_fields | ||||
|  | ||||
|  | ||||
| def build_base_class( | ||||
|     base_class_name: str, | ||||
|     common_fields: list[descriptor.FieldDescriptorProto], | ||||
| ) -> tuple[str, str]: | ||||
|     """Build the base class definition and implementation.""" | ||||
|     public_content = [] | ||||
|     protected_content = [] | ||||
|  | ||||
|     # For base classes, we only declare the fields but don't handle encode/decode | ||||
|     # The derived classes will handle encoding/decoding with their specific field numbers | ||||
|     for field in common_fields: | ||||
|         if field.label == 3:  # repeated | ||||
|             ti = RepeatedTypeInfo(field) | ||||
|         else: | ||||
|             ti = TYPE_INFO[field.type](field) | ||||
|  | ||||
|         # Only add field declarations, not encode/decode logic | ||||
|         protected_content.extend(ti.protected_content) | ||||
|         public_content.extend(ti.public_content) | ||||
|  | ||||
|     # Build header | ||||
|     out = f"class {base_class_name} : public ProtoMessage {{\n" | ||||
|     out += " public:\n" | ||||
|  | ||||
|     # Add destructor with override | ||||
|     public_content.insert(0, f"~{base_class_name}() override = default;") | ||||
|  | ||||
|     # Base classes don't implement encode/decode/calculate_size | ||||
|     # Derived classes handle these with their specific field numbers | ||||
|     cpp = "" | ||||
|  | ||||
|     out += indent("\n".join(public_content)) + "\n" | ||||
|     out += "\n" | ||||
|     out += " protected:\n" | ||||
|     out += indent("\n".join(protected_content)) | ||||
|     if protected_content: | ||||
|         out += "\n" | ||||
|     out += "};\n" | ||||
|  | ||||
|     # No implementation needed for base classes | ||||
|  | ||||
|     return out, cpp | ||||
|  | ||||
|  | ||||
| def generate_base_classes( | ||||
|     base_class_groups: dict[str, list[descriptor.DescriptorProto]], | ||||
| ) -> tuple[str, str]: | ||||
|     """Generate all base classes.""" | ||||
|     all_headers = [] | ||||
|     all_cpp = [] | ||||
|  | ||||
|     for base_class_name, messages in base_class_groups.items(): | ||||
|         # Find common fields | ||||
|         common_fields = find_common_fields(messages) | ||||
|  | ||||
|         if common_fields: | ||||
|             # Generate base class | ||||
|             header, cpp = build_base_class(base_class_name, common_fields) | ||||
|             all_headers.append(header) | ||||
|             all_cpp.append(cpp) | ||||
|  | ||||
|     return "\n".join(all_headers), "\n".join(all_cpp) | ||||
|  | ||||
|  | ||||
| def build_service_message_type( | ||||
|     mt: descriptor.DescriptorProto, | ||||
| ) -> tuple[str, str] | None: | ||||
| @@ -1134,8 +1278,25 @@ def main() -> None: | ||||
|  | ||||
|     mt = file.message_type | ||||
|  | ||||
|     # Collect messages by base class | ||||
|     base_class_groups = collect_messages_by_base_class(mt) | ||||
|  | ||||
|     # Find common fields for each base class | ||||
|     base_class_fields = {} | ||||
|     for base_class_name, messages in base_class_groups.items(): | ||||
|         common_fields = find_common_fields(messages) | ||||
|         if common_fields: | ||||
|             base_class_fields[base_class_name] = common_fields | ||||
|  | ||||
|     # Generate base classes | ||||
|     if base_class_fields: | ||||
|         base_headers, base_cpp = generate_base_classes(base_class_groups) | ||||
|         content += base_headers | ||||
|         cpp += base_cpp | ||||
|  | ||||
|     # Generate message types with base class information | ||||
|     for m in mt: | ||||
|         s, c = build_message_type(m) | ||||
|         s, c = build_message_type(m, base_class_fields) | ||||
|         content += s | ||||
|         cpp += c | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user