1
0
mirror of https://github.com/esphome/esphome.git synced 2025-11-15 14:25:45 +00:00

Compare commits

...

1 Commits

Author SHA1 Message Date
J. Nick Koston
cc1b547ad2 der dupe lam 2025-11-14 22:27:23 -06:00
2 changed files with 137 additions and 6 deletions

View File

@@ -497,6 +497,11 @@ def generate_cpp_contents(config: ConfigType) -> None:
CORE.flush_tasks()
# Flush deferred lambda deduplication declarations after all variables are declared
from esphome import cpp_generator as cg
cg.flush_lambda_dedup_declarations()
def write_cpp_file() -> int:
code_s = indent(CORE.cpp_main_section)

View File

@@ -24,6 +24,10 @@ from esphome.types import Expression, SafeExpType, TemplateArgsType
from esphome.util import OrderedDict
from esphome.yaml_util import ESPHomeDataBase
# Keys for lambda deduplication storage in CORE.data
_KEY_LAMBDA_DEDUP = "lambda_dedup"
_KEY_LAMBDA_DEDUP_DECLARATIONS = "lambda_dedup_declarations"
class RawExpression(Expression):
__slots__ = ("text",)
@@ -188,7 +192,7 @@ class LambdaExpression(Expression):
def __init__(
self, parts, parameters, capture: str = "=", return_type=None, source=None
):
) -> None:
self.parts = parts
if not isinstance(parameters, ParameterListExpression):
parameters = ParameterListExpression(*parameters)
@@ -197,16 +201,21 @@ class LambdaExpression(Expression):
self.capture = capture
self.return_type = safe_exp(return_type) if return_type is not None else None
def __str__(self):
def _format_body(self) -> str:
"""Format the lambda body with source directive and content."""
body = ""
if self.source is not None:
body += f"{self.source.as_line_directive}\n"
body += self.content
return body
def __str__(self) -> str:
# Stateless lambdas (empty capture) implicitly convert to function pointers
# when assigned to function pointer types - no unary + needed
cpp = f"[{self.capture}]({self.parameters})"
if self.return_type is not None:
cpp += f" -> {self.return_type}"
cpp += " {\n"
if self.source is not None:
cpp += f"{self.source.as_line_directive}\n"
cpp += f"{self.content}\n}}"
cpp += f" {{\n{self._format_body()}\n}}"
return indent_all_but_first_and_last(cpp)
@property
@@ -214,6 +223,37 @@ class LambdaExpression(Expression):
return "".join(str(part) for part in self.parts)
class SharedFunctionLambdaExpression(LambdaExpression):
"""A lambda expression that references a shared deduplicated function.
This class wraps a function pointer but maintains the LambdaExpression
interface so calling code works unchanged.
"""
__slots__ = ("_func_name",)
def __init__(
self,
func_name: str,
parameters: TemplateArgsType,
return_type: SafeExpType | None = None,
) -> None:
# Initialize parent with empty parts since we're just a function reference
super().__init__(
[], parameters, capture="", return_type=return_type, source=None
)
self._func_name = func_name
def __str__(self) -> str:
# Just return the function name - it's already a function pointer
return self._func_name
@property
def content(self) -> str:
# No content, just a function reference
return ""
# pylint: disable=abstract-method
class Literal(Expression, metaclass=abc.ABCMeta):
__slots__ = ()
@@ -583,6 +623,24 @@ def add_global(expression: SafeExpType | Statement, prepend: bool = False):
CORE.add_global(expression, prepend)
def flush_lambda_dedup_declarations():
"""Flush all deferred lambda deduplication declarations to global scope.
This must be called after all component code generation is complete
to ensure all referenced variables are declared before the shared
lambda functions that use them.
"""
if _KEY_LAMBDA_DEDUP_DECLARATIONS not in CORE.data:
return
declarations = CORE.data[_KEY_LAMBDA_DEDUP_DECLARATIONS]
for func_declaration in declarations:
add_global(RawStatement(func_declaration))
# Clear the list so we don't add them again
CORE.data[_KEY_LAMBDA_DEDUP_DECLARATIONS] = []
def add_library(name: str, version: str | None, repository: str | None = None):
"""Add a library to the codegen library storage.
@@ -656,6 +714,62 @@ async def get_variable_with_full_id(id_: ID) -> tuple[ID, "MockObj"]:
return await CORE.get_variable_with_full_id(id_)
def _try_deduplicate_lambda(lambda_expr: LambdaExpression) -> str | None:
"""Try to deduplicate a lambda expression.
If an identical lambda was already generated, returns the name of the
shared function. Otherwise, creates a new shared function and stores it.
Args:
lambda_expr: The lambda expression to potentially deduplicate
Returns:
The name of the shared function if this lambda should be deduplicated,
None if this is the first occurrence (caller should use original lambda)
"""
# Create a unique key from the lambda content, parameters, and return type
content = lambda_expr.content
param_str = str(lambda_expr.parameters)
return_str = (
str(lambda_expr.return_type) if lambda_expr.return_type is not None else "void"
)
# Use tuple of (content, params, return_type) as key
lambda_key = (content, param_str, return_str)
# Initialize deduplication storage in CORE.data if not exists
if _KEY_LAMBDA_DEDUP not in CORE.data:
CORE.data[_KEY_LAMBDA_DEDUP] = {}
lambda_cache = CORE.data[_KEY_LAMBDA_DEDUP]
# Check if we've seen this lambda before
if lambda_key in lambda_cache:
# Return name of existing shared function
return lambda_cache[lambda_key]
# First occurrence - create a shared function
# Use the cache size as the function number
func_name = f"shared_lambda_{len(lambda_cache)}"
# Build the function declaration using lambda's body formatting
func_declaration = (
f"{return_str} {func_name}({param_str}) {{\n{lambda_expr._format_body()}\n}}"
)
# Store the declaration to be added later (after all variable declarations)
# We can't add it immediately because it might reference variables not yet declared
if _KEY_LAMBDA_DEDUP_DECLARATIONS not in CORE.data:
CORE.data[_KEY_LAMBDA_DEDUP_DECLARATIONS] = []
CORE.data[_KEY_LAMBDA_DEDUP_DECLARATIONS].append(func_declaration)
# Store in cache
lambda_cache[lambda_key] = func_name
# Return the function name (this is the first occurrence, but we still generate shared function)
return func_name
async def process_lambda(
value: Lambda,
parameters: TemplateArgsType,
@@ -713,6 +827,18 @@ async def process_lambda(
location.line += value.content_offset
else:
location = None
# Lambda deduplication: Only deduplicate stateless lambdas (empty capture).
# Stateful lambdas cannot be shared as they capture different contexts.
if capture == "":
lambda_expr = LambdaExpression(
parts, parameters, capture, return_type, location
)
func_name = _try_deduplicate_lambda(lambda_expr)
if func_name is not None:
# Return a shared function reference instead of inline lambda
return SharedFunctionLambdaExpression(func_name, parameters, return_type)
return LambdaExpression(parts, parameters, capture, return_type, location)