1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-01 10:52:19 +01:00

Fix missing ifdef guards in API protobuf generator (#9296)

This commit is contained in:
J. Nick Koston
2025-07-02 16:39:20 -05:00
committed by GitHub
parent b9391f2cd4
commit 4ef5c941c9
5 changed files with 4616 additions and 4539 deletions

View File

@@ -813,27 +813,137 @@ class RepeatedTypeInfo(TypeInfo):
return underlying_size * 2
def build_enum_type(desc) -> tuple[str, str]:
"""Builds the enum type."""
def build_type_usage_map(
file_desc: descriptor.FileDescriptorProto,
) -> tuple[dict[str, str | None], dict[str, str | None]]:
"""Build mappings for both enums and messages to their ifdefs based on usage.
Returns:
tuple: (enum_ifdef_map, message_ifdef_map)
"""
enum_ifdef_map: dict[str, str | None] = {}
message_ifdef_map: dict[str, str | None] = {}
# Build maps of which types are used by which messages
enum_usage: dict[
str, set[str]
] = {} # enum_name -> set of message names that use it
message_usage: dict[
str, set[str]
] = {} # message_name -> set of message names that use it
# Build message name to ifdef mapping for quick lookup
message_to_ifdef: dict[str, str | None] = {
msg.name: get_opt(msg, pb.ifdef) for msg in file_desc.message_type
}
# Analyze field usage
for message in file_desc.message_type:
for field in message.field:
type_name = field.type_name.split(".")[-1] if field.type_name else None
if not type_name:
continue
# Track enum usage
if field.type == 14: # TYPE_ENUM
enum_usage.setdefault(type_name, set()).add(message.name)
# Track message usage
elif field.type == 11: # TYPE_MESSAGE
message_usage.setdefault(type_name, set()).add(message.name)
# Helper to get unique ifdef from a set of messages
def get_unique_ifdef(message_names: set[str]) -> str | None:
ifdefs: set[str] = {
message_to_ifdef[name]
for name in message_names
if message_to_ifdef.get(name)
}
return ifdefs.pop() if len(ifdefs) == 1 else None
# Build enum ifdef map
for enum in file_desc.enum_type:
if enum.name in enum_usage:
enum_ifdef_map[enum.name] = get_unique_ifdef(enum_usage[enum.name])
else:
enum_ifdef_map[enum.name] = None
# Build message ifdef map
for message in file_desc.message_type:
# Explicit ifdef takes precedence
explicit_ifdef = message_to_ifdef.get(message.name)
if explicit_ifdef:
message_ifdef_map[message.name] = explicit_ifdef
elif message.name in message_usage:
# Inherit ifdef if all parent messages have the same one
message_ifdef_map[message.name] = get_unique_ifdef(
message_usage[message.name]
)
else:
message_ifdef_map[message.name] = None
# Second pass: propagate ifdefs recursively
# Keep iterating until no more changes are made
changed = True
iterations = 0
while changed and iterations < 10: # Add safety limit
changed = False
iterations += 1
for message in file_desc.message_type:
# Skip if already has an ifdef
if message_ifdef_map.get(message.name):
continue
# Check if this message is used by other messages
if message.name not in message_usage:
continue
# Get ifdefs from all messages that use this one
parent_ifdefs: set[str] = {
message_ifdef_map.get(parent)
for parent in message_usage[message.name]
if message_ifdef_map.get(parent)
}
# If all parents have the same ifdef, inherit it
if len(parent_ifdefs) == 1 and None not in parent_ifdefs:
message_ifdef_map[message.name] = parent_ifdefs.pop()
changed = True
return enum_ifdef_map, message_ifdef_map
def build_enum_type(desc, enum_ifdef_map) -> tuple[str, str, str]:
"""Builds the enum type.
Args:
desc: The enum descriptor
enum_ifdef_map: Mapping of enum names to their ifdefs
Returns:
tuple: (header_content, cpp_content, dump_cpp_content)
"""
name = desc.name
out = f"enum {name} : uint32_t {{\n"
for v in desc.value:
out += f" {v.name} = {v.number},\n"
out += "};\n"
cpp = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += f"template<> const char *proto_enum_to_string<enums::{name}>(enums::{name} value) {{\n"
cpp += " switch (value) {\n"
for v in desc.value:
cpp += f" case enums::{v.name}:\n"
cpp += f' return "{v.name}";\n'
cpp += " default:\n"
cpp += ' return "UNKNOWN";\n'
cpp += " }\n"
cpp += "}\n"
cpp += "#endif\n"
# Regular cpp file has no enum content anymore
cpp = ""
return out, cpp
# Dump cpp content for enum string conversion
dump_cpp = f"template<> const char *proto_enum_to_string<enums::{name}>(enums::{name} value) {{\n"
dump_cpp += " switch (value) {\n"
for v in desc.value:
dump_cpp += f" case enums::{v.name}:\n"
dump_cpp += f' return "{v.name}";\n'
dump_cpp += " default:\n"
dump_cpp += ' return "UNKNOWN";\n'
dump_cpp += " }\n"
dump_cpp += "}\n"
return out, cpp, dump_cpp
def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
@@ -855,7 +965,7 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
def build_message_type(
desc: descriptor.DescriptorProto,
base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]] = None,
) -> tuple[str, str]:
) -> tuple[str, str, str]:
public_content: list[str] = []
protected_content: list[str] = []
decode_varint: list[str] = []
@@ -886,7 +996,7 @@ def build_message_type(
f"static constexpr uint16_t ESTIMATED_SIZE = {estimated_size};"
)
# Add message_name method for debugging
# Add message_name method inline in header
public_content.append("#ifdef HAS_PROTO_MESSAGE_DUMP")
snake_name = camel_to_snake(desc.name)
public_content.append(
@@ -993,32 +1103,32 @@ def build_message_type(
public_content.append(prot)
# If no fields to calculate size for, the default implementation in ProtoMessage will be used
o = f"void {desc.name}::dump_to(std::string &out) const {{"
if dump:
if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120:
o += f" {dump[0]} "
else:
o += "\n"
o += " __attribute__((unused)) char buffer[64];\n"
o += f' out.append("{desc.name} {{\\n");\n'
o += indent("\n".join(dump)) + "\n"
o += ' out.append("}");\n'
else:
o2 = f'out.append("{desc.name} {{}}");'
if len(o) + len(o2) + 3 < 120:
o += f" {o2} "
else:
o += "\n"
o += f" {o2}\n"
o += "}\n"
cpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += o
cpp += "#endif\n"
# dump_to method declaration in header
prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
prot += "void dump_to(std::string &out) const override;\n"
prot += "#endif\n"
public_content.append(prot)
# dump_to implementation will go in dump_cpp
dump_impl = f"void {desc.name}::dump_to(std::string &out) const {{"
if dump:
if len(dump) == 1 and len(dump[0]) + len(dump_impl) + 3 < 120:
dump_impl += f" {dump[0]} "
else:
dump_impl += "\n"
dump_impl += " __attribute__((unused)) char buffer[64];\n"
dump_impl += f' out.append("{desc.name} {{\\n");\n'
dump_impl += indent("\n".join(dump)) + "\n"
dump_impl += ' out.append("}");\n'
else:
o2 = f'out.append("{desc.name} {{}}");'
if len(dump_impl) + len(o2) + 3 < 120:
dump_impl += f" {o2} "
else:
dump_impl += "\n"
dump_impl += f" {o2}\n"
dump_impl += "}\n"
if base_class:
out = f"class {desc.name} : public {base_class} {{\n"
else:
@@ -1031,7 +1141,11 @@ def build_message_type(
if len(protected_content) > 0:
out += "\n"
out += "};\n"
return out, cpp
# Build dump_cpp content with dump_to implementation
dump_cpp = dump_impl
return out, cpp, dump_cpp
SOURCE_BOTH = 0
@@ -1119,7 +1233,7 @@ def find_common_fields(
def build_base_class(
base_class_name: str,
common_fields: list[descriptor.FieldDescriptorProto],
) -> tuple[str, str]:
) -> tuple[str, str, str]:
"""Build the base class definition and implementation."""
public_content = []
protected_content = []
@@ -1156,16 +1270,18 @@ def build_base_class(
out += "};\n"
# No implementation needed for base classes
dump_cpp = ""
return out, cpp
return out, cpp, dump_cpp
def generate_base_classes(
base_class_groups: dict[str, list[descriptor.DescriptorProto]],
) -> tuple[str, str]:
) -> tuple[str, str, str]:
"""Generate all base classes."""
all_headers = []
all_cpp = []
all_dump_cpp = []
for base_class_name, messages in base_class_groups.items():
# Find common fields
@@ -1173,11 +1289,12 @@ def generate_base_classes(
if common_fields:
# Generate base class
header, cpp = build_base_class(base_class_name, common_fields)
header, cpp, dump_cpp = build_base_class(base_class_name, common_fields)
all_headers.append(header)
all_cpp.append(cpp)
all_dump_cpp.append(dump_cpp)
return "\n".join(all_headers), "\n".join(all_cpp)
return "\n".join(all_headers), "\n".join(all_cpp), "\n".join(all_dump_cpp)
def build_service_message_type(
@@ -1244,15 +1361,17 @@ def main() -> None:
file = d.file[0]
content = FILE_HEADER
content += """\
#pragma once
#pragma once
#include "proto.h"
#include "api_pb2_size.h"
#include "esphome/core/defines.h"
namespace esphome {
namespace api {
#include "proto.h"
#include "api_pb2_size.h"
"""
namespace esphome {
namespace api {
"""
cpp = FILE_HEADER
cpp += """\
@@ -1261,19 +1380,56 @@ def main() -> None:
#include "esphome/core/log.h"
#include "esphome/core/helpers.h"
#include <cinttypes>
namespace esphome {
namespace api {
namespace esphome {
namespace api {
"""
"""
# Initialize dump cpp content
dump_cpp = FILE_HEADER
dump_cpp += """\
#include "api_pb2.h"
#include "esphome/core/helpers.h"
#include <cinttypes>
#ifdef HAS_PROTO_MESSAGE_DUMP
namespace esphome {
namespace api {
"""
content += "namespace enums {\n\n"
# Build dynamic ifdef mappings for both enums and messages
enum_ifdef_map, message_ifdef_map = build_type_usage_map(file)
# Simple grouping of enums by ifdef
current_ifdef = None
for enum in file.enum_type:
s, c = build_enum_type(enum)
s, c, dc = build_enum_type(enum, enum_ifdef_map)
enum_ifdef = enum_ifdef_map.get(enum.name)
# Handle ifdef changes
if enum_ifdef != current_ifdef:
if current_ifdef is not None:
content += "#endif\n"
dump_cpp += "#endif\n"
if enum_ifdef is not None:
content += f"#ifdef {enum_ifdef}\n"
dump_cpp += f"#ifdef {enum_ifdef}\n"
current_ifdef = enum_ifdef
content += s
cpp += c
dump_cpp += dc
# Close last ifdef
if current_ifdef is not None:
content += "#endif\n"
dump_cpp += "#endif\n"
content += "\n} // namespace enums\n\n"
@@ -1291,26 +1447,61 @@ def main() -> None:
# Generate base classes
if base_class_fields:
base_headers, base_cpp = generate_base_classes(base_class_groups)
base_headers, base_cpp, base_dump_cpp = generate_base_classes(base_class_groups)
content += base_headers
cpp += base_cpp
dump_cpp += base_dump_cpp
# Generate message types with base class information
# Simple grouping by ifdef
current_ifdef = None
for m in mt:
s, c = build_message_type(m, base_class_fields)
s, c, dc = build_message_type(m, base_class_fields)
msg_ifdef = message_ifdef_map.get(m.name)
# Handle ifdef changes
if msg_ifdef != current_ifdef:
if current_ifdef is not None:
content += "#endif\n"
if cpp:
cpp += "#endif\n"
if dump_cpp:
dump_cpp += "#endif\n"
if msg_ifdef is not None:
content += f"#ifdef {msg_ifdef}\n"
cpp += f"#ifdef {msg_ifdef}\n"
dump_cpp += f"#ifdef {msg_ifdef}\n"
current_ifdef = msg_ifdef
content += s
cpp += c
dump_cpp += dc
# Close last ifdef
if current_ifdef is not None:
content += "#endif\n"
cpp += "#endif\n"
dump_cpp += "#endif\n"
content += """\
} // namespace api
} // namespace esphome
"""
} // namespace api
} // namespace esphome
"""
cpp += """\
} // namespace api
} // namespace esphome
"""
} // namespace api
} // namespace esphome
"""
dump_cpp += """\
} // namespace api
} // namespace esphome
#endif // HAS_PROTO_MESSAGE_DUMP
"""
with open(root / "api_pb2.h", "w", encoding="utf-8") as f:
f.write(content)
@@ -1318,29 +1509,33 @@ def main() -> None:
with open(root / "api_pb2.cpp", "w", encoding="utf-8") as f:
f.write(cpp)
with open(root / "api_pb2_dump.cpp", "w", encoding="utf-8") as f:
f.write(dump_cpp)
hpp = FILE_HEADER
hpp += """\
#pragma once
#pragma once
#include "api_pb2.h"
#include "esphome/core/defines.h"
#include "esphome/core/defines.h"
namespace esphome {
namespace api {
#include "api_pb2.h"
"""
namespace esphome {
namespace api {
"""
cpp = FILE_HEADER
cpp += """\
#include "api_pb2_service.h"
#include "esphome/core/log.h"
#include "api_pb2_service.h"
#include "esphome/core/log.h"
namespace esphome {
namespace api {
namespace esphome {
namespace api {
static const char *const TAG = "api.service";
static const char *const TAG = "api.service";
"""
"""
class_name = "APIServerConnectionBase"
@@ -1419,7 +1614,7 @@ def main() -> None:
needs_conn = get_opt(m, pb.needs_setup_connection, True)
needs_auth = get_opt(m, pb.needs_authentication, True)
ifdef = ifdefs.get(inp, None)
ifdef = message_ifdef_map.get(inp, ifdefs.get(inp, None))
if ifdef is not None:
hpp += f"#ifdef {ifdef}\n"
@@ -1476,14 +1671,14 @@ def main() -> None:
hpp += """\
} // namespace api
} // namespace esphome
"""
} // namespace api
} // namespace esphome
"""
cpp += """\
} // namespace api
} // namespace esphome
"""
} // namespace api
} // namespace esphome
"""
with open(root / "api_pb2_service.h", "w", encoding="utf-8") as f:
f.write(hpp)
@@ -1506,6 +1701,8 @@ def main() -> None:
exec_clang_format(root / "api_pb2_service.cpp")
exec_clang_format(root / "api_pb2.h")
exec_clang_format(root / "api_pb2.cpp")
exec_clang_format(root / "api_pb2_dump.h")
exec_clang_format(root / "api_pb2_dump.cpp")
except ImportError:
pass