diff --git a/esphome/components/substitutions/__init__.py b/esphome/components/substitutions/__init__.py index 098d56bfad..7e15f714f7 100644 --- a/esphome/components/substitutions/__init__.py +++ b/esphome/components/substitutions/__init__.py @@ -1,4 +1,6 @@ import logging +from re import Match +from typing import Any from esphome import core from esphome.config_helpers import Extend, Remove, merge_config, merge_dicts_ordered @@ -39,7 +41,34 @@ async def to_code(config): pass -def _expand_jinja(value, orig_value, path, jinja, ignore_missing): +def _restore_data_base(value: Any, orig_value: ESPHomeDataBase) -> ESPHomeDataBase: + """This function restores ESPHomeDataBase metadata held by the original string. + This is needed because during jinja evaluation, strings can be replaced by other types, + but we want to keep the original metadata for error reporting and source mapping. + For example, if a substitution replaces a string with a dictionary, we want that items + in the dictionary to still point to the original document location + """ + if isinstance(value, ESPHomeDataBase): + return value + if isinstance(value, dict): + return { + _restore_data_base(k, orig_value): _restore_data_base(v, orig_value) + for k, v in value.items() + } + if isinstance(value, list): + return [_restore_data_base(v, orig_value) for v in value] + if isinstance(value, str): + return make_data_base(value, orig_value) + return value + + +def _expand_jinja( + value: str | JinjaStr, + orig_value: str | JinjaStr, + path, + jinja: Jinja, + ignore_missing: bool, +) -> Any: if has_jinja(value): # If the original value passed in to this function is a JinjaStr, it means it contains an unresolved # Jinja expression from a previous pass. @@ -65,10 +94,17 @@ def _expand_jinja(value, orig_value, path, jinja, ignore_missing): f"\nSee {'->'.join(str(x) for x in path)}", path, ) + # If the original, unexpanded string, contained document metadata (ESPHomeDatabase), + # assign this same document metadata to the resulting value. + if isinstance(orig_value, ESPHomeDataBase): + value = _restore_data_base(value, orig_value) + return value -def _expand_substitutions(substitutions, value, path, jinja, ignore_missing): +def _expand_substitutions( + substitutions: dict, value: str, path, jinja: Jinja, ignore_missing: bool +) -> Any: if "$" not in value: return value @@ -76,14 +112,14 @@ def _expand_substitutions(substitutions, value, path, jinja, ignore_missing): i = 0 while True: - m = cv.VARIABLE_PROG.search(value, i) + m: Match[str] = cv.VARIABLE_PROG.search(value, i) if not m: # No more variable substitutions found. See if the remainder looks like a jinja template value = _expand_jinja(value, orig_value, path, jinja, ignore_missing) break i, j = m.span(0) - name = m.group(1) + name: str = m.group(1) if name.startswith("{") and name.endswith("}"): name = name[1:-1] if name not in substitutions: @@ -98,7 +134,7 @@ def _expand_substitutions(substitutions, value, path, jinja, ignore_missing): i = j continue - sub = substitutions[name] + sub: Any = substitutions[name] if i == 0 and j == len(value): # The variable spans the whole expression, e.g., "${varName}". Return its resolved value directly @@ -121,7 +157,13 @@ def _expand_substitutions(substitutions, value, path, jinja, ignore_missing): return value -def _substitute_item(substitutions, item, path, jinja, ignore_missing): +def _substitute_item( + substitutions: dict, + item: Any, + path: list[int | str], + jinja: Jinja, + ignore_missing: bool, +) -> Any | None: if isinstance(item, ESPLiteralValue): return None # do not substitute inside literal blocks if isinstance(item, list): @@ -160,7 +202,9 @@ def _substitute_item(substitutions, item, path, jinja, ignore_missing): return None -def do_substitution_pass(config, command_line_substitutions, ignore_missing=False): +def do_substitution_pass( + config: dict, command_line_substitutions: dict, ignore_missing: bool = False +) -> None: if CONF_SUBSTITUTIONS not in config and not command_line_substitutions: return diff --git a/esphome/components/substitutions/jinja.py b/esphome/components/substitutions/jinja.py index dde0162993..cb3c6dfac5 100644 --- a/esphome/components/substitutions/jinja.py +++ b/esphome/components/substitutions/jinja.py @@ -1,10 +1,14 @@ from ast import literal_eval +from collections.abc import Iterator +from itertools import chain, islice import logging import math import re +from types import GeneratorType +from typing import Any import jinja2 as jinja -from jinja2.sandbox import SandboxedEnvironment +from jinja2.nativetypes import NativeCodeGenerator, NativeTemplate from esphome.yaml_util import ESPLiteralValue @@ -24,7 +28,7 @@ detect_jinja_re = re.compile( ) -def has_jinja(st): +def has_jinja(st: str) -> bool: return detect_jinja_re.search(st) is not None @@ -109,12 +113,56 @@ class TrackerContext(jinja.runtime.Context): return val -class Jinja(SandboxedEnvironment): +def _concat_nodes_override(values: Iterator[Any]) -> Any: + """ + This function customizes how Jinja preserves native types when concatenating + multiple result nodes together. If the result is a single node, its value + is returned. Otherwise, the nodes are concatenated as strings. If + the result can be parsed with `ast.literal_eval`, the parsed + value is returned. Otherwise, the string is returned. + This helps preserve metadata such as ESPHomeDataBase from original values + and mimicks how HomeAssistant deals with template evaluation and preserving + the original datatype. + """ + head: list[Any] = list(islice(values, 2)) + + if not head: + return None + + if len(head) == 1: + raw = head[0] + if not isinstance(raw, str): + return raw + else: + if isinstance(values, GeneratorType): + values = chain(head, values) + raw = "".join([str(v) for v in values]) + + try: + # Attempt to parse the concatenated string into a Python literal. + # This allows expressions like "1 + 2" to be evaluated to the integer 3. + # If the result is also a string or there is a parsing error, + # fall back to returning the raw string. This is consistent with + # Home Assistant's behavior when evaluating templates + result = literal_eval(raw) + if not isinstance(result, str): + return result + + except (ValueError, SyntaxError, MemoryError, TypeError): + pass + return raw + + +class Jinja(jinja.Environment): """ Wraps a Jinja environment """ - def __init__(self, context_vars): + # jinja environment customization overrides + code_generator_class = NativeCodeGenerator + concat = staticmethod(_concat_nodes_override) + + def __init__(self, context_vars: dict): super().__init__( trim_blocks=True, lstrip_blocks=True, @@ -142,19 +190,10 @@ class Jinja(SandboxedEnvironment): **SAFE_GLOBALS, } - def safe_eval(self, expr): - try: - result = literal_eval(expr) - if not isinstance(result, str): - return result - except (ValueError, SyntaxError, MemoryError, TypeError): - pass - return expr - - def expand(self, content_str): + def expand(self, content_str: str | JinjaStr) -> Any: """ Renders a string that may contain Jinja expressions or statements - Returns the resulting processed string if all values could be resolved. + Returns the resulting value if all variables and expressions could be resolved. Otherwise, it returns a tagged (JinjaStr) string that captures variables in scope (upvalues), like a closure for later evaluation. """ @@ -172,7 +211,7 @@ class Jinja(SandboxedEnvironment): self.context_trace = {} try: template = self.from_string(content_str) - result = self.safe_eval(template.render(override_vars)) + result = template.render(override_vars) if isinstance(result, Undefined): print("" + result) # force a UndefinedError exception except (TemplateSyntaxError, UndefinedError) as err: @@ -201,3 +240,10 @@ class Jinja(SandboxedEnvironment): content_str.result = result return result, None + + +class JinjaTemplate(NativeTemplate): + environment_class = Jinja + + +Jinja.template_class = JinjaTemplate diff --git a/tests/unit_tests/test_substitutions.py b/tests/unit_tests/test_substitutions.py index beb1ebc73e..7d50b44506 100644 --- a/tests/unit_tests/test_substitutions.py +++ b/tests/unit_tests/test_substitutions.py @@ -1,6 +1,7 @@ import glob import logging from pathlib import Path +from typing import Any from esphome import config as config_module, yaml_util from esphome.components import substitutions @@ -60,6 +61,29 @@ def write_yaml(path: Path, data: dict) -> None: path.write_text(yaml_util.dump(data), encoding="utf-8") +def verify_database(value: Any, path: str = "") -> str | None: + if isinstance(value, list): + for i, v in enumerate(value): + result = verify_database(v, f"{path}[{i}]") + if result is not None: + return result + return None + if isinstance(value, dict): + for k, v in value.items(): + key_result = verify_database(k, f"{path}/{k}") + if key_result is not None: + return key_result + value_result = verify_database(v, f"{path}/{k}") + if value_result is not None: + return value_result + return None + if isinstance(value, str): + if not isinstance(value, yaml_util.ESPHomeDataBase): + return f"{path}: {value!r} is not ESPHomeDataBase" + return None + return None + + def test_substitutions_fixtures(fixture_path): base_dir = fixture_path / "substitutions" sources = sorted(glob.glob(str(base_dir / "*.input.yaml"))) @@ -83,6 +107,9 @@ def test_substitutions_fixtures(fixture_path): substitutions.do_substitution_pass(config, None) resolve_extend_remove(config) + verify_database_result = verify_database(config) + if verify_database_result is not None: + raise AssertionError(verify_database_result) # Also load expected using ESPHome's loader, or use {} if missing and DEV_MODE if expected_path.is_file():