mirror of
https://github.com/esphome/esphome.git
synced 2025-10-29 22:24:26 +00:00
Merge remote-tracking branch 'upstream/dev' into proto_field_ifdefs
This commit is contained in:
@@ -8,7 +8,6 @@ from pathlib import Path
|
||||
import re
|
||||
from subprocess import call
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
import aioesphomeapi.api_options_pb2 as pb
|
||||
@@ -181,13 +180,7 @@ class TypeInfo(ABC):
|
||||
content = self.decode_varint
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}"""
|
||||
)
|
||||
return f"case {self.number}: this->{self.field_name} = {content}; break;"
|
||||
|
||||
decode_varint = None
|
||||
|
||||
@@ -196,13 +189,7 @@ class TypeInfo(ABC):
|
||||
content = self.decode_length
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}"""
|
||||
)
|
||||
return f"case {self.number}: this->{self.field_name} = {content}; break;"
|
||||
|
||||
decode_length = None
|
||||
|
||||
@@ -211,13 +198,7 @@ class TypeInfo(ABC):
|
||||
content = self.decode_32bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}"""
|
||||
)
|
||||
return f"case {self.number}: this->{self.field_name} = {content}; break;"
|
||||
|
||||
decode_32bit = None
|
||||
|
||||
@@ -226,13 +207,7 @@ class TypeInfo(ABC):
|
||||
content = self.decode_64bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name} = {content};
|
||||
return true;
|
||||
}}"""
|
||||
)
|
||||
return f"case {self.number}: this->{self.field_name} = {content}; break;"
|
||||
|
||||
decode_64bit = None
|
||||
|
||||
@@ -337,6 +312,37 @@ class TypeInfo(ABC):
|
||||
|
||||
TYPE_INFO: dict[int, TypeInfo] = {}
|
||||
|
||||
# Unsupported 64-bit types that would add overhead for embedded systems
|
||||
# TYPE_DOUBLE = 1, TYPE_FIXED64 = 6, TYPE_SFIXED64 = 16, TYPE_SINT64 = 18
|
||||
UNSUPPORTED_TYPES = {1: "double", 6: "fixed64", 16: "sfixed64", 18: "sint64"}
|
||||
|
||||
|
||||
def validate_field_type(field_type: int, field_name: str = "") -> None:
|
||||
"""Validate that the field type is supported by ESPHome API.
|
||||
|
||||
Raises ValueError for unsupported 64-bit types.
|
||||
"""
|
||||
if field_type in UNSUPPORTED_TYPES:
|
||||
type_name = UNSUPPORTED_TYPES[field_type]
|
||||
field_info = f" (field: {field_name})" if field_name else ""
|
||||
raise ValueError(
|
||||
f"64-bit type '{type_name}'{field_info} is not supported by ESPHome API. "
|
||||
"These types add significant overhead for embedded systems. "
|
||||
"If you need 64-bit support, please add the necessary encoding/decoding "
|
||||
"functions to proto.h/proto.cpp first."
|
||||
)
|
||||
|
||||
|
||||
def get_type_info_for_field(field: descriptor.FieldDescriptorProto) -> TypeInfo:
|
||||
"""Get the appropriate TypeInfo for a field, handling repeated fields.
|
||||
|
||||
Also validates that the field type is supported.
|
||||
"""
|
||||
if field.label == 3: # repeated
|
||||
return RepeatedTypeInfo(field)
|
||||
validate_field_type(field.type, field.name)
|
||||
return TYPE_INFO[field.type](field)
|
||||
|
||||
|
||||
def register_type(name: int):
|
||||
"""Decorator to register a type with a name and number."""
|
||||
@@ -573,13 +579,7 @@ class MessageType(TypeInfo):
|
||||
@property
|
||||
def decode_length_content(self) -> str:
|
||||
# Custom decode that doesn't use templates
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
value.decode_to_message(this->{self.field_name});
|
||||
return true;
|
||||
}}"""
|
||||
)
|
||||
return f"case {self.number}: value.decode_to_message(this->{self.field_name}); break;"
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f"{name}.dump_to(out);"
|
||||
@@ -762,6 +762,7 @@ class SInt64Type(TypeInfo):
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
super().__init__(field)
|
||||
validate_field_type(field.type, field.name)
|
||||
self._ti: TypeInfo = TYPE_INFO[field.type](field)
|
||||
|
||||
@property
|
||||
@@ -789,12 +790,8 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
content = self._ti.decode_varint
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}"""
|
||||
return (
|
||||
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -802,22 +799,11 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
content = self._ti.decode_length
|
||||
if content is None and isinstance(self._ti, MessageType):
|
||||
# Special handling for non-template message decoding
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.emplace_back();
|
||||
value.decode_to_message(this->{self.field_name}.back());
|
||||
return true;
|
||||
}}"""
|
||||
)
|
||||
return f"case {self.number}: this->{self.field_name}.emplace_back(); value.decode_to_message(this->{self.field_name}.back()); break;"
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}"""
|
||||
return (
|
||||
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -825,12 +811,8 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
content = self._ti.decode_32bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}"""
|
||||
return (
|
||||
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -838,12 +820,8 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
content = self._ti.decode_64bit
|
||||
if content is None:
|
||||
return None
|
||||
return dedent(
|
||||
f"""\
|
||||
case {self.number}: {{
|
||||
this->{self.field_name}.push_back({content});
|
||||
return true;
|
||||
}}"""
|
||||
return (
|
||||
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -1049,10 +1027,7 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
|
||||
total_size = 0
|
||||
|
||||
for field in desc.field:
|
||||
if field.label == 3: # repeated
|
||||
ti = RepeatedTypeInfo(field)
|
||||
else:
|
||||
ti = TYPE_INFO[field.type](field)
|
||||
ti = get_type_info_for_field(field)
|
||||
|
||||
# Add estimated size for this field
|
||||
total_size += ti.get_estimated_size()
|
||||
@@ -1186,41 +1161,45 @@ def build_message_type(
|
||||
|
||||
cpp = ""
|
||||
if decode_varint:
|
||||
decode_varint.append("default:\n return false;")
|
||||
o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n"
|
||||
o += " switch (field_id) {\n"
|
||||
o += indent("\n".join(decode_varint), " ") + "\n"
|
||||
o += " default: return false;\n"
|
||||
o += " }\n"
|
||||
o += " return true;\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;"
|
||||
protected_content.insert(0, prot)
|
||||
if decode_length:
|
||||
decode_length.append("default:\n return false;")
|
||||
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
|
||||
o += " switch (field_id) {\n"
|
||||
o += indent("\n".join(decode_length), " ") + "\n"
|
||||
o += " default: return false;\n"
|
||||
o += " }\n"
|
||||
o += " return true;\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
|
||||
protected_content.insert(0, prot)
|
||||
if decode_32bit:
|
||||
decode_32bit.append("default:\n return false;")
|
||||
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
|
||||
o += " switch (field_id) {\n"
|
||||
o += indent("\n".join(decode_32bit), " ") + "\n"
|
||||
o += " default: return false;\n"
|
||||
o += " }\n"
|
||||
o += " return true;\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
|
||||
protected_content.insert(0, prot)
|
||||
if decode_64bit:
|
||||
decode_64bit.append("default:\n return false;")
|
||||
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
|
||||
o += " switch (field_id) {\n"
|
||||
o += indent("\n".join(decode_64bit), " ") + "\n"
|
||||
o += " default: return false;\n"
|
||||
o += " }\n"
|
||||
o += " return true;\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
|
||||
@@ -1394,10 +1373,7 @@ def build_base_class(
|
||||
# For base classes, we only declare the fields but don't handle encode/decode
|
||||
# The derived classes will handle encoding/decoding with their specific field numbers
|
||||
for field in common_fields:
|
||||
if field.label == 3: # repeated
|
||||
ti = RepeatedTypeInfo(field)
|
||||
else:
|
||||
ti = TYPE_INFO[field.type](field)
|
||||
ti = get_type_info_for_field(field)
|
||||
|
||||
# Only add field declarations, not encode/decode logic
|
||||
protected_content.extend(ti.protected_content)
|
||||
|
||||
@@ -62,26 +62,6 @@ def get_clang_tidy_version_from_requirements() -> str:
|
||||
return "clang-tidy version not found"
|
||||
|
||||
|
||||
def extract_platformio_flags() -> str:
|
||||
"""Extract clang-tidy related flags from platformio.ini"""
|
||||
flags: list[str] = []
|
||||
in_clangtidy_section = False
|
||||
|
||||
platformio_path = Path(__file__).parent.parent / "platformio.ini"
|
||||
lines = read_file_lines(platformio_path)
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("[flags:clangtidy]"):
|
||||
in_clangtidy_section = True
|
||||
continue
|
||||
elif line.startswith("[") and in_clangtidy_section:
|
||||
break
|
||||
elif in_clangtidy_section and line and not line.startswith("#"):
|
||||
flags.append(line)
|
||||
|
||||
return "\n".join(sorted(flags))
|
||||
|
||||
|
||||
def read_file_bytes(path: Path) -> bytes:
|
||||
"""Read bytes from a file."""
|
||||
with open(path, "rb") as f:
|
||||
@@ -101,9 +81,10 @@ def calculate_clang_tidy_hash() -> str:
|
||||
version = get_clang_tidy_version_from_requirements()
|
||||
hasher.update(version.encode())
|
||||
|
||||
# Hash relevant platformio.ini sections
|
||||
pio_flags = extract_platformio_flags()
|
||||
hasher.update(pio_flags.encode())
|
||||
# Hash the entire platformio.ini file
|
||||
platformio_path = Path(__file__).parent.parent / "platformio.ini"
|
||||
platformio_content = read_file_bytes(platformio_path)
|
||||
hasher.update(platformio_content)
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
@@ -126,7 +107,8 @@ def write_file_content(path: Path, content: str) -> None:
|
||||
def write_hash(hash_value: str) -> None:
|
||||
"""Write hash to file"""
|
||||
hash_file = Path(__file__).parent.parent / ".clang-tidy.hash"
|
||||
write_file_content(hash_file, hash_value)
|
||||
# Strip any trailing newlines to ensure consistent formatting
|
||||
write_file_content(hash_file, hash_value.strip() + "\n")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@@ -72,7 +72,7 @@ for f in ./tests/components/$target_component/*.*.yaml; do
|
||||
|
||||
if [ "$target_platform" = "all" ] || [ $file_name_parts = 2 ]; then
|
||||
# Test has *not* defined a specific target platform. Need to run tests for all possible target platforms.
|
||||
|
||||
|
||||
for target_platform_file in ./tests/test_build_components/build_components_base.*.yaml; do
|
||||
IFS='/' read -r -a folder_name <<< "$target_platform_file"
|
||||
IFS='.' read -r -a file_name <<< "${folder_name[3]}"
|
||||
@@ -83,7 +83,7 @@ for f in ./tests/components/$target_component/*.*.yaml; do
|
||||
|
||||
else
|
||||
# Test has defined a specific target platform.
|
||||
|
||||
|
||||
# Validate we have a base test yaml for selected platform.
|
||||
# The target_platform is sourced from the following location.
|
||||
# 1. `./tests/test_build_components/build_components_base.[target_platform].yaml`
|
||||
|
||||
Reference in New Issue
Block a user