diff --git a/esphome/components/udp/__init__.py b/esphome/components/udp/__init__.py index 9be196d420..8252e35023 100644 --- a/esphome/components/udp/__init__.py +++ b/esphome/components/udp/__init__.py @@ -12,8 +12,8 @@ from esphome.components.packet_transport import ( ) import esphome.config_validation as cv from esphome.const import CONF_DATA, CONF_ID, CONF_PORT, CONF_TRIGGER_ID -from esphome.core import ID, Lambda -from esphome.cpp_generator import ExpressionStatement, MockObj +from esphome.core import ID +from esphome.cpp_generator import literal CODEOWNERS = ["@clydebarrow"] DEPENDENCIES = ["network"] @@ -24,6 +24,8 @@ udp_ns = cg.esphome_ns.namespace("udp") UDPComponent = udp_ns.class_("UDPComponent", cg.Component) UDPWriteAction = udp_ns.class_("UDPWriteAction", automation.Action) trigger_args = cg.std_vector.template(cg.uint8) +trigger_argname = "data" +trigger_argtype = [(trigger_args, trigger_argname)] CONF_ADDRESSES = "addresses" CONF_LISTEN_ADDRESS = "listen_address" @@ -111,13 +113,14 @@ async def to_code(config): cg.add(var.set_addresses([str(addr) for addr in config[CONF_ADDRESSES]])) if on_receive := config.get(CONF_ON_RECEIVE): on_receive = on_receive[0] - trigger = cg.new_Pvariable(on_receive[CONF_TRIGGER_ID]) + trigger_id = cg.new_Pvariable(on_receive[CONF_TRIGGER_ID]) trigger = await automation.build_automation( - trigger, [(trigger_args, "data")], on_receive + trigger_id, trigger_argtype, on_receive ) - trigger = Lambda(str(ExpressionStatement(trigger.trigger(MockObj("data"))))) - trigger = await cg.process_lambda(trigger, [(trigger_args, "data")]) - cg.add(var.add_listener(trigger)) + trigger_lambda = await cg.process_lambda( + trigger.trigger(literal(trigger_argname)), trigger_argtype + ) + cg.add(var.add_listener(trigger_lambda)) cg.add(var.set_should_listen()) diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 9a7dd49609..5308ad241e 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -278,9 +278,13 @@ LAMBDA_PROG = re.compile(r"\bid\(\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\)(\.?)") class Lambda: def __init__(self, value): + from esphome.cpp_generator import Expression, statement + # pylint: disable=protected-access if isinstance(value, Lambda): self._value = value._value + elif isinstance(value, Expression): + self._value = str(statement(value)) else: self._value = value self._parts = None diff --git a/esphome/cpp_generator.py b/esphome/cpp_generator.py index cff0748c95..83f2d6cf81 100644 --- a/esphome/cpp_generator.py +++ b/esphome/cpp_generator.py @@ -462,6 +462,16 @@ def statement(expression: Expression | Statement) -> Statement: return ExpressionStatement(expression) +def literal(name: str) -> "MockObj": + """Create a literal name that will appear in the generated code + not surrounded by quotes. + + :param name: The name of the literal. + :return: The literal as a MockObj. + """ + return MockObj(name, "") + + def variable( id_: ID, rhs: SafeExpType, type_: "MockObj" = None, register=True ) -> "MockObj": @@ -665,7 +675,7 @@ async def get_variable_with_full_id(id_: ID) -> tuple[ID, "MockObj"]: async def process_lambda( - value: Lambda, + value: Lambda | Expression, parameters: TemplateArgsType, capture: str = "", return_type: SafeExpType = None, @@ -689,6 +699,14 @@ async def process_lambda( if value is None: return None + # Inadvertently passing a malformed parameters value will lead to the build process mysteriously hanging at the + # "Generating C++ source..." stage, so check here to save the developer's hair. + assert isinstance(parameters, list) and all( + isinstance(p, tuple) and len(p) == 2 for p in parameters + ) + if isinstance(value, Expression): + value = Lambda(value) + parts = value.parts[:] for i, id in enumerate(value.requires_ids): full_id, var = await get_variable_with_full_id(id) diff --git a/tests/unit_tests/test_cpp_generator.py b/tests/unit_tests/test_cpp_generator.py index 2c9f760c8e..8755e6e2a1 100644 --- a/tests/unit_tests/test_cpp_generator.py +++ b/tests/unit_tests/test_cpp_generator.py @@ -347,3 +347,280 @@ class TestMockObj: assert isinstance(actual, cg.MockObj) assert actual.base == "foo.eek" assert actual.op == "." + + +class TestStatementFunction: + """Tests for the statement() function.""" + + def test_statement__expression_converted_to_statement(self): + """Test that expressions are converted to ExpressionStatement.""" + expr = cg.RawExpression("foo()") + result = cg.statement(expr) + + assert isinstance(result, cg.ExpressionStatement) + assert str(result) == "foo();" + + def test_statement__statement_unchanged(self): + """Test that statements are returned unchanged.""" + stmt = cg.RawStatement("foo()") + result = cg.statement(stmt) + + assert result is stmt + assert str(result) == "foo()" + + def test_statement__expression_statement_unchanged(self): + """Test that ExpressionStatement is returned unchanged.""" + stmt = cg.ExpressionStatement(42) + result = cg.statement(stmt) + + assert result is stmt + assert str(result) == "42;" + + def test_statement__line_comment_unchanged(self): + """Test that LineComment is returned unchanged.""" + stmt = cg.LineComment("This is a comment") + result = cg.statement(stmt) + + assert result is stmt + assert str(result) == "// This is a comment" + + +class TestLiteralFunction: + """Tests for the literal() function.""" + + def test_literal__creates_mockobj(self): + """Test that literal() creates a MockObj.""" + result = cg.literal("MY_CONSTANT") + + assert isinstance(result, cg.MockObj) + assert result.base == "MY_CONSTANT" + assert result.op == "" + + def test_literal__string_representation(self): + """Test that literal names appear unquoted in generated code.""" + result = cg.literal("nullptr") + + assert str(result) == "nullptr" + + def test_literal__can_be_used_in_expressions(self): + """Test that literals can be used as part of larger expressions.""" + null_lit = cg.literal("nullptr") + expr = cg.CallExpression(cg.RawExpression("my_func"), null_lit) + + assert str(expr) == "my_func(nullptr)" + + def test_literal__common_cpp_literals(self): + """Test common C++ literal values.""" + test_cases = [ + ("nullptr", "nullptr"), + ("true", "true"), + ("false", "false"), + ("NULL", "NULL"), + ("NAN", "NAN"), + ] + + for name, expected in test_cases: + result = cg.literal(name) + assert str(result) == expected + + +class TestLambdaConstructor: + """Tests for the Lambda class constructor in core/__init__.py.""" + + def test_lambda__from_string(self): + """Test Lambda constructor with string argument.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return x + 1;") + + assert lambda_obj.value == "return x + 1;" + assert str(lambda_obj) == "return x + 1;" + + def test_lambda__from_expression(self): + """Test Lambda constructor with Expression argument.""" + from esphome.core import Lambda + + expr = cg.RawExpression("x + 1") + lambda_obj = Lambda(expr) + + # Expression should be converted to statement (with semicolon) + assert lambda_obj.value == "x + 1;" + + def test_lambda__from_lambda(self): + """Test Lambda constructor with another Lambda argument.""" + from esphome.core import Lambda + + original = Lambda("return x + 1;") + copy = Lambda(original) + + assert copy.value == original.value + assert copy.value == "return x + 1;" + + def test_lambda__parts_parsing(self): + """Test that Lambda correctly parses parts with id() references.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return id(my_sensor).state;") + parts = lambda_obj.parts + + # Parts should be split by LAMBDA_PROG regex: text, id, op, text + assert len(parts) == 4 + assert parts[0] == "return " + assert parts[1] == "my_sensor" + assert parts[2] == "." + assert parts[3] == "state;" + + def test_lambda__requires_ids(self): + """Test that Lambda correctly extracts required IDs.""" + from esphome.core import ID, Lambda + + lambda_obj = Lambda("return id(sensor1).state + id(sensor2).value;") + ids = lambda_obj.requires_ids + + assert len(ids) == 2 + assert all(isinstance(id_obj, ID) for id_obj in ids) + assert ids[0].id == "sensor1" + assert ids[1].id == "sensor2" + + def test_lambda__no_ids(self): + """Test Lambda with no id() references.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return 42;") + ids = lambda_obj.requires_ids + + assert len(ids) == 0 + + def test_lambda__comment_removal(self): + """Test that comments are removed when parsing parts.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return id(sensor).state; // Get sensor state") + parts = lambda_obj.parts + + # Comment should be replaced with space, not affect parsing + assert "my_sensor" not in str(parts) + + def test_lambda__multiline_string(self): + """Test Lambda with multiline string.""" + from esphome.core import Lambda + + code = """if (id(sensor).state > 0) { + return true; +} +return false;""" + lambda_obj = Lambda(code) + + assert lambda_obj.value == code + assert "sensor" in [id_obj.id for id_obj in lambda_obj.requires_ids] + + +@pytest.mark.asyncio +class TestProcessLambda: + """Tests for the process_lambda() async function.""" + + async def test_process_lambda__none_value(self): + """Test that None returns None.""" + result = await cg.process_lambda(None, []) + + assert result is None + + async def test_process_lambda__with_expression(self): + """Test process_lambda with Expression argument.""" + + expr = cg.RawExpression("return x + 1") + result = await cg.process_lambda(expr, [(int, "x")]) + + assert isinstance(result, cg.LambdaExpression) + assert "x + 1" in str(result) + + async def test_process_lambda__simple_lambda_no_ids(self): + """Test process_lambda with simple Lambda without id() references.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return x + 1;") + result = await cg.process_lambda(lambda_obj, [(int, "x")]) + + assert isinstance(result, cg.LambdaExpression) + # Should have parameter + lambda_str = str(result) + assert "int32_t x" in lambda_str + assert "return x + 1;" in lambda_str + + async def test_process_lambda__with_return_type(self): + """Test process_lambda with return type specified.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return x > 0;") + result = await cg.process_lambda(lambda_obj, [(int, "x")], return_type=bool) + + assert isinstance(result, cg.LambdaExpression) + lambda_str = str(result) + assert "-> bool" in lambda_str + + async def test_process_lambda__with_capture(self): + """Test process_lambda with capture specified.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return captured + x;") + result = await cg.process_lambda(lambda_obj, [(int, "x")], capture="captured") + + assert isinstance(result, cg.LambdaExpression) + lambda_str = str(result) + assert "[captured]" in lambda_str + + async def test_process_lambda__empty_capture(self): + """Test process_lambda with empty capture (stateless lambda).""" + from esphome.core import Lambda + + lambda_obj = Lambda("return x + 1;") + result = await cg.process_lambda(lambda_obj, [(int, "x")], capture="") + + assert isinstance(result, cg.LambdaExpression) + lambda_str = str(result) + assert "[]" in lambda_str + + async def test_process_lambda__no_parameters(self): + """Test process_lambda with no parameters.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return 42;") + result = await cg.process_lambda(lambda_obj, []) + + assert isinstance(result, cg.LambdaExpression) + lambda_str = str(result) + # Should have empty parameter list + assert "()" in lambda_str + + async def test_process_lambda__multiple_parameters(self): + """Test process_lambda with multiple parameters.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return x + y + z;") + result = await cg.process_lambda( + lambda_obj, [(int, "x"), (float, "y"), (bool, "z")] + ) + + assert isinstance(result, cg.LambdaExpression) + lambda_str = str(result) + assert "int32_t x" in lambda_str + assert "float y" in lambda_str + assert "bool z" in lambda_str + + async def test_process_lambda__parameter_validation(self): + """Test that malformed parameters raise assertion error.""" + from esphome.core import Lambda + + lambda_obj = Lambda("return x;") + + # Test invalid parameter format (not list of tuples) + with pytest.raises(AssertionError): + await cg.process_lambda(lambda_obj, "invalid") + + # Test invalid tuple format (not 2-element tuples) + with pytest.raises(AssertionError): + await cg.process_lambda(lambda_obj, [(int, "x", "extra")]) + + # Test invalid tuple format (single element) + with pytest.raises(AssertionError): + await cg.process_lambda(lambda_obj, [(int,)])