From 27e1095cd7071c6195236ebeaedc1a476d1a88f8 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 7 Oct 2025 09:36:27 +1300 Subject: [PATCH] [core] Allow `AUTO_LOAD` to receive the component config to determine if it should load other components (#10961) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: J. Nick Koston --- esphome/config.py | 61 +++++++- esphome/loader.py | 9 +- script/list-components.py | 36 ++++- tests/unit_tests/conftest.py | 7 + .../fixtures/auto_load_dynamic.yaml | 10 ++ .../unit_tests/fixtures/auto_load_static.yaml | 8 ++ tests/unit_tests/test_config_auto_load.py | 131 ++++++++++++++++++ tests/unit_tests/test_config_normalization.py | 7 - 8 files changed, 245 insertions(+), 24 deletions(-) create mode 100644 tests/unit_tests/fixtures/auto_load_dynamic.yaml create mode 100644 tests/unit_tests/fixtures/auto_load_static.yaml create mode 100644 tests/unit_tests/test_config_auto_load.py diff --git a/esphome/config.py b/esphome/config.py index a5297a53cb..7a083fee33 100644 --- a/esphome/config.py +++ b/esphome/config.py @@ -67,6 +67,31 @@ ConfigPath = list[str | int] path_context = contextvars.ContextVar("Config path") +def _add_auto_load_steps(result: Config, loads: list[str]) -> None: + """Add AutoLoadValidationStep for each component in loads that isn't already loaded.""" + for load in loads: + if load not in result: + result.add_validation_step(AutoLoadValidationStep(load)) + + +def _process_auto_load( + result: Config, platform: ComponentManifest, path: ConfigPath +) -> None: + # Process platform's AUTO_LOAD + auto_load = platform.auto_load + if isinstance(auto_load, list): + _add_auto_load_steps(result, auto_load) + elif callable(auto_load): + import inspect + + if inspect.signature(auto_load).parameters: + result.add_validation_step( + AddDynamicAutoLoadsValidationStep(path, platform) + ) + else: + _add_auto_load_steps(result, auto_load()) + + def _process_platform_config( result: Config, component_name: str, @@ -91,9 +116,7 @@ def _process_platform_config( CORE.loaded_platforms.add(f"{component_name}/{platform_name}") # Process platform's AUTO_LOAD - for load in platform.auto_load: - if load not in result: - result.add_validation_step(AutoLoadValidationStep(load)) + _process_auto_load(result, platform, path) # Add validation steps for the platform p_domain = f"{component_name}.{platform_name}" @@ -390,9 +413,7 @@ class LoadValidationStep(ConfigValidationStep): result[self.domain] = self.conf = [self.conf] # Process AUTO_LOAD - for load in component.auto_load: - if load not in result: - result.add_validation_step(AutoLoadValidationStep(load)) + _process_auto_load(result, component, path) result.add_validation_step( MetadataValidationStep([self.domain], self.domain, self.conf, component) @@ -618,6 +639,34 @@ class MetadataValidationStep(ConfigValidationStep): result.add_validation_step(FinalValidateValidationStep(self.path, self.comp)) +class AddDynamicAutoLoadsValidationStep(ConfigValidationStep): + """Add dynamic auto loads step. + + This step is used to auto-load components where one component can alter its + AUTO_LOAD based on its configuration. + """ + + # Has to happen after normal schema is validated and before final schema validation + priority = -10.0 + + def __init__(self, path: ConfigPath, comp: ComponentManifest) -> None: + self.path = path + self.comp = comp + + def run(self, result: Config) -> None: + if result.errors: + # If result already has errors, skip this step + return + + conf = result.get_nested_item(self.path) + with result.catch_error(self.path): + auto_load = self.comp.auto_load + if not callable(auto_load): + return + loads = auto_load(conf) + _add_auto_load_steps(result, loads) + + class SchemaValidationStep(ConfigValidationStep): """Schema validation step. diff --git a/esphome/loader.py b/esphome/loader.py index ec2f5101da..387443c032 100644 --- a/esphome/loader.py +++ b/esphome/loader.py @@ -82,11 +82,10 @@ class ComponentManifest: return getattr(self.module, "CONFLICTS_WITH", []) @property - def auto_load(self) -> list[str]: - al = getattr(self.module, "AUTO_LOAD", []) - if callable(al): - return al() - return al + def auto_load( + self, + ) -> list[str] | Callable[[], list[str]] | Callable[[ConfigType], list[str]]: + return getattr(self.module, "AUTO_LOAD", []) @property def codeowners(self) -> list[str]: diff --git a/script/list-components.py b/script/list-components.py index ef02aecdf6..9ab1cdd852 100755 --- a/script/list-components.py +++ b/script/list-components.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import argparse +from collections.abc import Callable from pathlib import Path import sys @@ -13,7 +14,7 @@ from esphome.const import ( PLATFORM_ESP8266, ) from esphome.core import CORE -from esphome.loader import get_component, get_platform +from esphome.loader import ComponentManifest, get_component, get_platform def filter_component_files(str): @@ -45,6 +46,29 @@ def add_item_to_components_graph(components_graph, parent, child): components_graph[parent].append(child) +def resolve_auto_load( + auto_load: list[str] | Callable[[], list[str]] | Callable[[dict | None], list[str]], + config: dict | None = None, +) -> list[str]: + """Resolve AUTO_LOAD to a list, handling callables with or without config parameter. + + Args: + auto_load: The AUTO_LOAD value (list or callable) + config: Optional config to pass to callable AUTO_LOAD functions + + Returns: + List of component names to auto-load + """ + if not callable(auto_load): + return auto_load + + import inspect + + if inspect.signature(auto_load).parameters: + return auto_load(config) + return auto_load() + + def create_components_graph(): # The root directory of the repo root = Path(__file__).parent.parent @@ -63,7 +87,7 @@ def create_components_graph(): components_graph = {} platforms = [] - components = [] + components: list[tuple[ComponentManifest, str, Path]] = [] for path in components_dir.iterdir(): if not path.is_dir(): @@ -92,8 +116,8 @@ def create_components_graph(): for target_config in TARGET_CONFIGURATIONS: CORE.data[KEY_CORE] = target_config - for auto_load in comp.auto_load: - add_item_to_components_graph(components_graph, auto_load, name) + for item in resolve_auto_load(comp.auto_load, config=None): + add_item_to_components_graph(components_graph, item, name) # restore config CORE.data[KEY_CORE] = TARGET_CONFIGURATIONS[0] @@ -114,8 +138,8 @@ def create_components_graph(): for target_config in TARGET_CONFIGURATIONS: CORE.data[KEY_CORE] = target_config - for auto_load in platform.auto_load: - add_item_to_components_graph(components_graph, auto_load, name) + for item in resolve_auto_load(platform.auto_load, config={}): + add_item_to_components_graph(components_graph, item, name) # restore config CORE.data[KEY_CORE] = TARGET_CONFIGURATIONS[0] diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index e8d9c02524..932221997c 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -101,3 +101,10 @@ def mock_get_idedata() -> Generator[Mock, None, None]: """Mock get_idedata for platformio_api.""" with patch("esphome.platformio_api.get_idedata") as mock: yield mock + + +@pytest.fixture +def mock_get_component() -> Generator[Mock, None, None]: + """Mock get_component for config module.""" + with patch("esphome.config.get_component") as mock: + yield mock diff --git a/tests/unit_tests/fixtures/auto_load_dynamic.yaml b/tests/unit_tests/fixtures/auto_load_dynamic.yaml new file mode 100644 index 0000000000..b604a2a42b --- /dev/null +++ b/tests/unit_tests/fixtures/auto_load_dynamic.yaml @@ -0,0 +1,10 @@ +esphome: + name: test-device + +esp32: + board: esp32dev + +# Test component with dynamic AUTO_LOAD +test_component: + enable_logger: true + enable_api: false diff --git a/tests/unit_tests/fixtures/auto_load_static.yaml b/tests/unit_tests/fixtures/auto_load_static.yaml new file mode 100644 index 0000000000..c8f9e6222a --- /dev/null +++ b/tests/unit_tests/fixtures/auto_load_static.yaml @@ -0,0 +1,8 @@ +esphome: + name: test-device + +esp32: + board: esp32dev + +# Test component with static AUTO_LOAD +test_component: diff --git a/tests/unit_tests/test_config_auto_load.py b/tests/unit_tests/test_config_auto_load.py new file mode 100644 index 0000000000..d31b17eeec --- /dev/null +++ b/tests/unit_tests/test_config_auto_load.py @@ -0,0 +1,131 @@ +"""Tests for AUTO_LOAD functionality including dynamic AUTO_LOAD.""" + +from pathlib import Path +from typing import Any +from unittest.mock import Mock + +import pytest + +from esphome import config, config_validation as cv, yaml_util +from esphome.core import CORE + + +@pytest.fixture +def fixtures_dir() -> Path: + """Get the fixtures directory.""" + return Path(__file__).parent / "fixtures" + + +@pytest.fixture +def default_component() -> Mock: + """Create a default mock component for unmocked components.""" + return Mock( + auto_load=[], + is_platform_component=False, + is_platform=False, + multi_conf=False, + multi_conf_no_default=False, + dependencies=[], + conflicts_with=[], + config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA), + ) + + +@pytest.fixture +def static_auto_load_component() -> Mock: + """Create a mock component with static AUTO_LOAD.""" + return Mock( + auto_load=["logger"], + is_platform_component=False, + is_platform=False, + multi_conf=False, + multi_conf_no_default=False, + dependencies=[], + conflicts_with=[], + config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA), + ) + + +def test_static_auto_load_adds_components( + mock_get_component: Mock, + fixtures_dir: Path, + static_auto_load_component: Mock, + default_component: Mock, +) -> None: + """Test that static AUTO_LOAD triggers loading of specified components.""" + CORE.config_path = fixtures_dir / "auto_load_static.yaml" + + config_file = fixtures_dir / "auto_load_static.yaml" + raw_config = yaml_util.load_yaml(config_file) + + component_mocks = {"test_component": static_auto_load_component} + mock_get_component.side_effect = lambda name: component_mocks.get( + name, default_component + ) + + result = config.validate_config(raw_config, {}) + + # Check for validation errors + assert not result.errors, f"Validation errors: {result.errors}" + + # Logger should have been auto-loaded by test_component + assert "logger" in result + assert "test_component" in result + + +def test_dynamic_auto_load_with_config_param( + mock_get_component: Mock, + fixtures_dir: Path, + default_component: Mock, +) -> None: + """Test that dynamic AUTO_LOAD evaluates based on configuration.""" + CORE.config_path = fixtures_dir / "auto_load_dynamic.yaml" + + config_file = fixtures_dir / "auto_load_dynamic.yaml" + raw_config = yaml_util.load_yaml(config_file) + + # Track if auto_load was called with config + auto_load_calls = [] + + def dynamic_auto_load(conf: dict[str, Any]) -> list[str]: + """Dynamically load components based on config.""" + auto_load_calls.append(conf) + component_map = { + "enable_logger": "logger", + "enable_api": "api", + } + return [comp for key, comp in component_map.items() if conf.get(key)] + + dynamic_component = Mock( + auto_load=dynamic_auto_load, + is_platform_component=False, + is_platform=False, + multi_conf=False, + multi_conf_no_default=False, + dependencies=[], + conflicts_with=[], + config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA), + ) + + component_mocks = {"test_component": dynamic_component} + mock_get_component.side_effect = lambda name: component_mocks.get( + name, default_component + ) + + result = config.validate_config(raw_config, {}) + + # Check for validation errors + assert not result.errors, f"Validation errors: {result.errors}" + + # Verify auto_load was called with the validated config + assert len(auto_load_calls) == 1, "auto_load should be called exactly once" + assert auto_load_calls[0].get("enable_logger") is True + assert auto_load_calls[0].get("enable_api") is False + + # Only logger should be auto-loaded (enable_logger=true in YAML) + assert "logger" in result, ( + f"Logger not found in result. Result keys: {list(result.keys())}" + ) + # API should NOT be auto-loaded (enable_api=false in YAML) + assert "api" not in result + assert "test_component" in result diff --git a/tests/unit_tests/test_config_normalization.py b/tests/unit_tests/test_config_normalization.py index 4b79ddd426..d70f3c24e0 100644 --- a/tests/unit_tests/test_config_normalization.py +++ b/tests/unit_tests/test_config_normalization.py @@ -10,13 +10,6 @@ from esphome import config, yaml_util from esphome.core import CORE -@pytest.fixture -def mock_get_component() -> Generator[Mock, None, None]: - """Fixture for mocking get_component.""" - with patch("esphome.config.get_component") as mock_get_component: - yield mock_get_component - - @pytest.fixture def mock_get_platform() -> Generator[Mock, None, None]: """Fixture for mocking get_platform."""