1
0
mirror of https://github.com/esphome/esphome.git synced 2026-02-08 00:31:58 +00:00

[core] Simplify generation of Lambda during to_code() (#13533)

This commit is contained in:
Clyde Stubbs
2026-01-31 12:18:30 +11:00
committed by GitHub
parent 5e3561d60b
commit 9dcb469460
4 changed files with 310 additions and 8 deletions

View File

@@ -12,8 +12,8 @@ from esphome.components.packet_transport import (
) )
import esphome.config_validation as cv import esphome.config_validation as cv
from esphome.const import CONF_DATA, CONF_ID, CONF_PORT, CONF_TRIGGER_ID from esphome.const import CONF_DATA, CONF_ID, CONF_PORT, CONF_TRIGGER_ID
from esphome.core import ID, Lambda from esphome.core import ID
from esphome.cpp_generator import ExpressionStatement, MockObj from esphome.cpp_generator import literal
CODEOWNERS = ["@clydebarrow"] CODEOWNERS = ["@clydebarrow"]
DEPENDENCIES = ["network"] DEPENDENCIES = ["network"]
@@ -24,6 +24,8 @@ udp_ns = cg.esphome_ns.namespace("udp")
UDPComponent = udp_ns.class_("UDPComponent", cg.Component) UDPComponent = udp_ns.class_("UDPComponent", cg.Component)
UDPWriteAction = udp_ns.class_("UDPWriteAction", automation.Action) UDPWriteAction = udp_ns.class_("UDPWriteAction", automation.Action)
trigger_args = cg.std_vector.template(cg.uint8) trigger_args = cg.std_vector.template(cg.uint8)
trigger_argname = "data"
trigger_argtype = [(trigger_args, trigger_argname)]
CONF_ADDRESSES = "addresses" CONF_ADDRESSES = "addresses"
CONF_LISTEN_ADDRESS = "listen_address" CONF_LISTEN_ADDRESS = "listen_address"
@@ -111,13 +113,14 @@ async def to_code(config):
cg.add(var.set_addresses([str(addr) for addr in config[CONF_ADDRESSES]])) cg.add(var.set_addresses([str(addr) for addr in config[CONF_ADDRESSES]]))
if on_receive := config.get(CONF_ON_RECEIVE): if on_receive := config.get(CONF_ON_RECEIVE):
on_receive = on_receive[0] on_receive = on_receive[0]
trigger = cg.new_Pvariable(on_receive[CONF_TRIGGER_ID]) trigger_id = cg.new_Pvariable(on_receive[CONF_TRIGGER_ID])
trigger = await automation.build_automation( trigger = await automation.build_automation(
trigger, [(trigger_args, "data")], on_receive trigger_id, trigger_argtype, on_receive
) )
trigger = Lambda(str(ExpressionStatement(trigger.trigger(MockObj("data"))))) trigger_lambda = await cg.process_lambda(
trigger = await cg.process_lambda(trigger, [(trigger_args, "data")]) trigger.trigger(literal(trigger_argname)), trigger_argtype
cg.add(var.add_listener(trigger)) )
cg.add(var.add_listener(trigger_lambda))
cg.add(var.set_should_listen()) cg.add(var.set_should_listen())

View File

@@ -278,9 +278,13 @@ LAMBDA_PROG = re.compile(r"\bid\(\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\)(\.?)")
class Lambda: class Lambda:
def __init__(self, value): def __init__(self, value):
from esphome.cpp_generator import Expression, statement
# pylint: disable=protected-access # pylint: disable=protected-access
if isinstance(value, Lambda): if isinstance(value, Lambda):
self._value = value._value self._value = value._value
elif isinstance(value, Expression):
self._value = str(statement(value))
else: else:
self._value = value self._value = value
self._parts = None self._parts = None

View File

@@ -462,6 +462,16 @@ def statement(expression: Expression | Statement) -> Statement:
return ExpressionStatement(expression) return ExpressionStatement(expression)
def literal(name: str) -> "MockObj":
"""Create a literal name that will appear in the generated code
not surrounded by quotes.
:param name: The name of the literal.
:return: The literal as a MockObj.
"""
return MockObj(name, "")
def variable( def variable(
id_: ID, rhs: SafeExpType, type_: "MockObj" = None, register=True id_: ID, rhs: SafeExpType, type_: "MockObj" = None, register=True
) -> "MockObj": ) -> "MockObj":
@@ -665,7 +675,7 @@ async def get_variable_with_full_id(id_: ID) -> tuple[ID, "MockObj"]:
async def process_lambda( async def process_lambda(
value: Lambda, value: Lambda | Expression,
parameters: TemplateArgsType, parameters: TemplateArgsType,
capture: str = "", capture: str = "",
return_type: SafeExpType = None, return_type: SafeExpType = None,
@@ -689,6 +699,14 @@ async def process_lambda(
if value is None: if value is None:
return None return None
# Inadvertently passing a malformed parameters value will lead to the build process mysteriously hanging at the
# "Generating C++ source..." stage, so check here to save the developer's hair.
assert isinstance(parameters, list) and all(
isinstance(p, tuple) and len(p) == 2 for p in parameters
)
if isinstance(value, Expression):
value = Lambda(value)
parts = value.parts[:] parts = value.parts[:]
for i, id in enumerate(value.requires_ids): for i, id in enumerate(value.requires_ids):
full_id, var = await get_variable_with_full_id(id) full_id, var = await get_variable_with_full_id(id)

View File

@@ -347,3 +347,280 @@ class TestMockObj:
assert isinstance(actual, cg.MockObj) assert isinstance(actual, cg.MockObj)
assert actual.base == "foo.eek" assert actual.base == "foo.eek"
assert actual.op == "." assert actual.op == "."
class TestStatementFunction:
"""Tests for the statement() function."""
def test_statement__expression_converted_to_statement(self):
"""Test that expressions are converted to ExpressionStatement."""
expr = cg.RawExpression("foo()")
result = cg.statement(expr)
assert isinstance(result, cg.ExpressionStatement)
assert str(result) == "foo();"
def test_statement__statement_unchanged(self):
"""Test that statements are returned unchanged."""
stmt = cg.RawStatement("foo()")
result = cg.statement(stmt)
assert result is stmt
assert str(result) == "foo()"
def test_statement__expression_statement_unchanged(self):
"""Test that ExpressionStatement is returned unchanged."""
stmt = cg.ExpressionStatement(42)
result = cg.statement(stmt)
assert result is stmt
assert str(result) == "42;"
def test_statement__line_comment_unchanged(self):
"""Test that LineComment is returned unchanged."""
stmt = cg.LineComment("This is a comment")
result = cg.statement(stmt)
assert result is stmt
assert str(result) == "// This is a comment"
class TestLiteralFunction:
"""Tests for the literal() function."""
def test_literal__creates_mockobj(self):
"""Test that literal() creates a MockObj."""
result = cg.literal("MY_CONSTANT")
assert isinstance(result, cg.MockObj)
assert result.base == "MY_CONSTANT"
assert result.op == ""
def test_literal__string_representation(self):
"""Test that literal names appear unquoted in generated code."""
result = cg.literal("nullptr")
assert str(result) == "nullptr"
def test_literal__can_be_used_in_expressions(self):
"""Test that literals can be used as part of larger expressions."""
null_lit = cg.literal("nullptr")
expr = cg.CallExpression(cg.RawExpression("my_func"), null_lit)
assert str(expr) == "my_func(nullptr)"
def test_literal__common_cpp_literals(self):
"""Test common C++ literal values."""
test_cases = [
("nullptr", "nullptr"),
("true", "true"),
("false", "false"),
("NULL", "NULL"),
("NAN", "NAN"),
]
for name, expected in test_cases:
result = cg.literal(name)
assert str(result) == expected
class TestLambdaConstructor:
"""Tests for the Lambda class constructor in core/__init__.py."""
def test_lambda__from_string(self):
"""Test Lambda constructor with string argument."""
from esphome.core import Lambda
lambda_obj = Lambda("return x + 1;")
assert lambda_obj.value == "return x + 1;"
assert str(lambda_obj) == "return x + 1;"
def test_lambda__from_expression(self):
"""Test Lambda constructor with Expression argument."""
from esphome.core import Lambda
expr = cg.RawExpression("x + 1")
lambda_obj = Lambda(expr)
# Expression should be converted to statement (with semicolon)
assert lambda_obj.value == "x + 1;"
def test_lambda__from_lambda(self):
"""Test Lambda constructor with another Lambda argument."""
from esphome.core import Lambda
original = Lambda("return x + 1;")
copy = Lambda(original)
assert copy.value == original.value
assert copy.value == "return x + 1;"
def test_lambda__parts_parsing(self):
"""Test that Lambda correctly parses parts with id() references."""
from esphome.core import Lambda
lambda_obj = Lambda("return id(my_sensor).state;")
parts = lambda_obj.parts
# Parts should be split by LAMBDA_PROG regex: text, id, op, text
assert len(parts) == 4
assert parts[0] == "return "
assert parts[1] == "my_sensor"
assert parts[2] == "."
assert parts[3] == "state;"
def test_lambda__requires_ids(self):
"""Test that Lambda correctly extracts required IDs."""
from esphome.core import ID, Lambda
lambda_obj = Lambda("return id(sensor1).state + id(sensor2).value;")
ids = lambda_obj.requires_ids
assert len(ids) == 2
assert all(isinstance(id_obj, ID) for id_obj in ids)
assert ids[0].id == "sensor1"
assert ids[1].id == "sensor2"
def test_lambda__no_ids(self):
"""Test Lambda with no id() references."""
from esphome.core import Lambda
lambda_obj = Lambda("return 42;")
ids = lambda_obj.requires_ids
assert len(ids) == 0
def test_lambda__comment_removal(self):
"""Test that comments are removed when parsing parts."""
from esphome.core import Lambda
lambda_obj = Lambda("return id(sensor).state; // Get sensor state")
parts = lambda_obj.parts
# Comment should be replaced with space, not affect parsing
assert "my_sensor" not in str(parts)
def test_lambda__multiline_string(self):
"""Test Lambda with multiline string."""
from esphome.core import Lambda
code = """if (id(sensor).state > 0) {
return true;
}
return false;"""
lambda_obj = Lambda(code)
assert lambda_obj.value == code
assert "sensor" in [id_obj.id for id_obj in lambda_obj.requires_ids]
@pytest.mark.asyncio
class TestProcessLambda:
"""Tests for the process_lambda() async function."""
async def test_process_lambda__none_value(self):
"""Test that None returns None."""
result = await cg.process_lambda(None, [])
assert result is None
async def test_process_lambda__with_expression(self):
"""Test process_lambda with Expression argument."""
expr = cg.RawExpression("return x + 1")
result = await cg.process_lambda(expr, [(int, "x")])
assert isinstance(result, cg.LambdaExpression)
assert "x + 1" in str(result)
async def test_process_lambda__simple_lambda_no_ids(self):
"""Test process_lambda with simple Lambda without id() references."""
from esphome.core import Lambda
lambda_obj = Lambda("return x + 1;")
result = await cg.process_lambda(lambda_obj, [(int, "x")])
assert isinstance(result, cg.LambdaExpression)
# Should have parameter
lambda_str = str(result)
assert "int32_t x" in lambda_str
assert "return x + 1;" in lambda_str
async def test_process_lambda__with_return_type(self):
"""Test process_lambda with return type specified."""
from esphome.core import Lambda
lambda_obj = Lambda("return x > 0;")
result = await cg.process_lambda(lambda_obj, [(int, "x")], return_type=bool)
assert isinstance(result, cg.LambdaExpression)
lambda_str = str(result)
assert "-> bool" in lambda_str
async def test_process_lambda__with_capture(self):
"""Test process_lambda with capture specified."""
from esphome.core import Lambda
lambda_obj = Lambda("return captured + x;")
result = await cg.process_lambda(lambda_obj, [(int, "x")], capture="captured")
assert isinstance(result, cg.LambdaExpression)
lambda_str = str(result)
assert "[captured]" in lambda_str
async def test_process_lambda__empty_capture(self):
"""Test process_lambda with empty capture (stateless lambda)."""
from esphome.core import Lambda
lambda_obj = Lambda("return x + 1;")
result = await cg.process_lambda(lambda_obj, [(int, "x")], capture="")
assert isinstance(result, cg.LambdaExpression)
lambda_str = str(result)
assert "[]" in lambda_str
async def test_process_lambda__no_parameters(self):
"""Test process_lambda with no parameters."""
from esphome.core import Lambda
lambda_obj = Lambda("return 42;")
result = await cg.process_lambda(lambda_obj, [])
assert isinstance(result, cg.LambdaExpression)
lambda_str = str(result)
# Should have empty parameter list
assert "()" in lambda_str
async def test_process_lambda__multiple_parameters(self):
"""Test process_lambda with multiple parameters."""
from esphome.core import Lambda
lambda_obj = Lambda("return x + y + z;")
result = await cg.process_lambda(
lambda_obj, [(int, "x"), (float, "y"), (bool, "z")]
)
assert isinstance(result, cg.LambdaExpression)
lambda_str = str(result)
assert "int32_t x" in lambda_str
assert "float y" in lambda_str
assert "bool z" in lambda_str
async def test_process_lambda__parameter_validation(self):
"""Test that malformed parameters raise assertion error."""
from esphome.core import Lambda
lambda_obj = Lambda("return x;")
# Test invalid parameter format (not list of tuples)
with pytest.raises(AssertionError):
await cg.process_lambda(lambda_obj, "invalid")
# Test invalid tuple format (not 2-element tuples)
with pytest.raises(AssertionError):
await cg.process_lambda(lambda_obj, [(int, "x", "extra")])
# Test invalid tuple format (single element)
with pytest.raises(AssertionError):
await cg.process_lambda(lambda_obj, [(int,)])