mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-30 22:53:59 +00:00 
			
		
		
		
	Use message_source_map consistently in proto generation (#9542)
This commit is contained in:
		| @@ -1495,6 +1495,7 @@ def build_base_class( | ||||
|     base_class_name: str, | ||||
|     common_fields: list[descriptor.FieldDescriptorProto], | ||||
|     messages: list[descriptor.DescriptorProto], | ||||
|     message_source_map: dict[str, int], | ||||
| ) -> tuple[str, str, str]: | ||||
|     """Build the base class definition and implementation.""" | ||||
|     public_content = [] | ||||
| @@ -1511,7 +1512,7 @@ def build_base_class( | ||||
|  | ||||
|     # Determine if any message using this base class needs decoding | ||||
|     needs_decode = any( | ||||
|         get_opt(msg, pb.source, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) | ||||
|         message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT) | ||||
|         for msg in messages | ||||
|     ) | ||||
|  | ||||
| @@ -1543,6 +1544,7 @@ def build_base_class( | ||||
|  | ||||
| def generate_base_classes( | ||||
|     base_class_groups: dict[str, list[descriptor.DescriptorProto]], | ||||
|     message_source_map: dict[str, int], | ||||
| ) -> tuple[str, str, str]: | ||||
|     """Generate all base classes.""" | ||||
|     all_headers = [] | ||||
| @@ -1556,7 +1558,7 @@ def generate_base_classes( | ||||
|         if common_fields: | ||||
|             # Generate base class | ||||
|             header, cpp, dump_cpp = build_base_class( | ||||
|                 base_class_name, common_fields, messages | ||||
|                 base_class_name, common_fields, messages, message_source_map | ||||
|             ) | ||||
|             all_headers.append(header) | ||||
|             all_cpp.append(cpp) | ||||
| @@ -1567,6 +1569,7 @@ def generate_base_classes( | ||||
|  | ||||
| def build_service_message_type( | ||||
|     mt: descriptor.DescriptorProto, | ||||
|     message_source_map: dict[str, int], | ||||
| ) -> tuple[str, str] | None: | ||||
|     """Builds the service message type.""" | ||||
|     snake = camel_to_snake(mt.name) | ||||
| @@ -1574,7 +1577,7 @@ def build_service_message_type( | ||||
|     if id_ is None: | ||||
|         return None | ||||
|  | ||||
|     source: int = get_opt(mt, pb.source, 0) | ||||
|     source: int = message_source_map.get(mt.name, SOURCE_BOTH) | ||||
|  | ||||
|     ifdef: str | None = get_opt(mt, pb.ifdef) | ||||
|     log: bool = get_opt(mt, pb.log, True) | ||||
| @@ -1714,7 +1717,9 @@ namespace api { | ||||
|  | ||||
|     # Generate base classes | ||||
|     if base_class_fields: | ||||
|         base_headers, base_cpp, base_dump_cpp = generate_base_classes(base_class_groups) | ||||
|         base_headers, base_cpp, base_dump_cpp = generate_base_classes( | ||||
|             base_class_groups, message_source_map | ||||
|         ) | ||||
|         content += base_headers | ||||
|         cpp += base_cpp | ||||
|         dump_cpp += base_dump_cpp | ||||
| @@ -1832,7 +1837,7 @@ static const char *const TAG = "api.service"; | ||||
|     cpp += "#endif\n\n" | ||||
|  | ||||
|     for mt in file.message_type: | ||||
|         obj = build_service_message_type(mt) | ||||
|         obj = build_service_message_type(mt, message_source_map) | ||||
|         if obj is None: | ||||
|             continue | ||||
|         hout, cout = obj | ||||
|   | ||||
		Reference in New Issue
	
	Block a user