mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 07:03:55 +00:00 
			
		
		
		
	Merge branch 'multi_device_args' into integration
This commit is contained in:
		| @@ -9,6 +9,7 @@ import os | |||||||
| import re | import re | ||||||
| import sys | import sys | ||||||
| import time | import time | ||||||
|  | from typing import Protocol | ||||||
|  |  | ||||||
| import argcomplete | import argcomplete | ||||||
|  |  | ||||||
| @@ -44,6 +45,7 @@ from esphome.const import ( | |||||||
| from esphome.core import CORE, EsphomeError, coroutine | from esphome.core import CORE, EsphomeError, coroutine | ||||||
| from esphome.helpers import get_bool_env, indent, is_ip_address | from esphome.helpers import get_bool_env, indent, is_ip_address | ||||||
| from esphome.log import AnsiFore, color, setup_log | from esphome.log import AnsiFore, color, setup_log | ||||||
|  | from esphome.types import ConfigType | ||||||
| from esphome.util import ( | from esphome.util import ( | ||||||
|     get_serial_ports, |     get_serial_ports, | ||||||
|     list_yaml_files, |     list_yaml_files, | ||||||
| @@ -55,6 +57,23 @@ from esphome.util import ( | |||||||
| _LOGGER = logging.getLogger(__name__) | _LOGGER = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ArgsProtocol(Protocol): | ||||||
|  |     device: list[str] | None | ||||||
|  |     reset: bool | ||||||
|  |     username: str | None | ||||||
|  |     password: str | None | ||||||
|  |     client_id: str | None | ||||||
|  |     topic: str | None | ||||||
|  |     file: str | None | ||||||
|  |     no_logs: bool | ||||||
|  |     only_generate: bool | ||||||
|  |     show_secrets: bool | ||||||
|  |     dashboard: bool | ||||||
|  |     configuration: str | ||||||
|  |     name: str | ||||||
|  |     upload_speed: str | None | ||||||
|  |  | ||||||
|  |  | ||||||
| def choose_prompt(options, purpose: str = None): | def choose_prompt(options, purpose: str = None): | ||||||
|     if not options: |     if not options: | ||||||
|         raise EsphomeError( |         raise EsphomeError( | ||||||
| @@ -88,30 +107,50 @@ def choose_prompt(options, purpose: str = None): | |||||||
|  |  | ||||||
|  |  | ||||||
| def choose_upload_log_host( | def choose_upload_log_host( | ||||||
|     default, check_default, show_ota, show_mqtt, show_api, purpose: str = None |     default: list[str] | str | None, | ||||||
| ): |     check_default: str | None, | ||||||
|  |     show_ota: bool, | ||||||
|  |     show_mqtt: bool, | ||||||
|  |     show_api: bool, | ||||||
|  |     purpose: str | None = None, | ||||||
|  | ) -> list[str]: | ||||||
|  |     # Convert to list for uniform handling | ||||||
|  |     defaults = [default] if isinstance(default, str) else default or [] | ||||||
|  |  | ||||||
|  |     # If devices specified, resolve them | ||||||
|  |     if defaults: | ||||||
|  |         resolved: list[str] = [] | ||||||
|  |         for device in defaults: | ||||||
|  |             if device == "SERIAL": | ||||||
|  |                 options = [ | ||||||
|  |                     (f"{port.path} ({port.description})", port.path) | ||||||
|  |                     for port in get_serial_ports() | ||||||
|  |                 ] | ||||||
|  |                 resolved.append(choose_prompt(options, purpose=purpose)) | ||||||
|  |             elif device == "OTA": | ||||||
|  |                 if (show_ota and "ota" in CORE.config) or ( | ||||||
|  |                     show_api and "api" in CORE.config | ||||||
|  |                 ): | ||||||
|  |                     resolved.append(CORE.address) | ||||||
|  |                 elif show_mqtt and has_mqtt_logging(): | ||||||
|  |                     resolved.append("MQTT") | ||||||
|  |             else: | ||||||
|  |                 resolved.append(device) | ||||||
|  |         return resolved | ||||||
|  |  | ||||||
|  |     # No devices specified, show interactive chooser | ||||||
|     options = [ |     options = [ | ||||||
|         (f"{port.path} ({port.description})", port.path) for port in get_serial_ports() |         (f"{port.path} ({port.description})", port.path) for port in get_serial_ports() | ||||||
|     ] |     ] | ||||||
|     if default == "SERIAL": |  | ||||||
|         return choose_prompt(options, purpose=purpose) |  | ||||||
|     if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config): |     if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config): | ||||||
|         options.append((f"Over The Air ({CORE.address})", CORE.address)) |         options.append((f"Over The Air ({CORE.address})", CORE.address)) | ||||||
|         if default == "OTA": |     if show_mqtt and has_mqtt_logging(): | ||||||
|             return CORE.address |         mqtt_config = CORE.config[CONF_MQTT] | ||||||
|     if ( |  | ||||||
|         show_mqtt |  | ||||||
|         and (mqtt_config := CORE.config.get(CONF_MQTT)) |  | ||||||
|         and mqtt_logging_enabled(mqtt_config) |  | ||||||
|     ): |  | ||||||
|         options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) |         options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) | ||||||
|         if default == "OTA": |  | ||||||
|             return "MQTT" |  | ||||||
|     if default is not None: |  | ||||||
|         return default |  | ||||||
|     if check_default is not None and check_default in [opt[1] for opt in options]: |     if check_default is not None and check_default in [opt[1] for opt in options]: | ||||||
|         return check_default |         return [check_default] | ||||||
|     return choose_prompt(options, purpose=purpose) |     return [choose_prompt(options, purpose=purpose)] | ||||||
|  |  | ||||||
|  |  | ||||||
| def mqtt_logging_enabled(mqtt_config): | def mqtt_logging_enabled(mqtt_config): | ||||||
| @@ -123,7 +162,14 @@ def mqtt_logging_enabled(mqtt_config): | |||||||
|     return log_topic.get(CONF_LEVEL, None) != "NONE" |     return log_topic.get(CONF_LEVEL, None) != "NONE" | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_port_type(port): | def has_mqtt_logging() -> bool: | ||||||
|  |     """Check if MQTT logging is available.""" | ||||||
|  |     return (mqtt_config := CORE.config.get(CONF_MQTT)) and mqtt_logging_enabled( | ||||||
|  |         mqtt_config | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_port_type(port: str) -> str: | ||||||
|     if port.startswith("/") or port.startswith("COM"): |     if port.startswith("/") or port.startswith("COM"): | ||||||
|         return "SERIAL" |         return "SERIAL" | ||||||
|     if port == "MQTT": |     if port == "MQTT": | ||||||
| @@ -131,7 +177,7 @@ def get_port_type(port): | |||||||
|     return "NETWORK" |     return "NETWORK" | ||||||
|  |  | ||||||
|  |  | ||||||
| def run_miniterm(config, port, args): | def run_miniterm(config: ConfigType, port: str, args) -> int: | ||||||
|     from aioesphomeapi import LogParser |     from aioesphomeapi import LogParser | ||||||
|     import serial |     import serial | ||||||
|  |  | ||||||
| @@ -208,7 +254,7 @@ def wrap_to_code(name, comp): | |||||||
|     return wrapped |     return wrapped | ||||||
|  |  | ||||||
|  |  | ||||||
| def write_cpp(config): | def write_cpp(config: ConfigType) -> int: | ||||||
|     if not get_bool_env(ENV_NOGITIGNORE): |     if not get_bool_env(ENV_NOGITIGNORE): | ||||||
|         writer.write_gitignore() |         writer.write_gitignore() | ||||||
|  |  | ||||||
| @@ -216,7 +262,7 @@ def write_cpp(config): | |||||||
|     return write_cpp_file() |     return write_cpp_file() | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_cpp_contents(config): | def generate_cpp_contents(config: ConfigType) -> None: | ||||||
|     _LOGGER.info("Generating C++ source...") |     _LOGGER.info("Generating C++ source...") | ||||||
|  |  | ||||||
|     for name, component, conf in iter_component_configs(CORE.config): |     for name, component, conf in iter_component_configs(CORE.config): | ||||||
| @@ -227,7 +273,7 @@ def generate_cpp_contents(config): | |||||||
|     CORE.flush_tasks() |     CORE.flush_tasks() | ||||||
|  |  | ||||||
|  |  | ||||||
| def write_cpp_file(): | def write_cpp_file() -> int: | ||||||
|     code_s = indent(CORE.cpp_main_section) |     code_s = indent(CORE.cpp_main_section) | ||||||
|     writer.write_cpp(code_s) |     writer.write_cpp(code_s) | ||||||
|  |  | ||||||
| @@ -238,7 +284,7 @@ def write_cpp_file(): | |||||||
|     return 0 |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def compile_program(args, config): | def compile_program(args: ArgsProtocol, config: ConfigType) -> int: | ||||||
|     from esphome import platformio_api |     from esphome import platformio_api | ||||||
|  |  | ||||||
|     _LOGGER.info("Compiling app...") |     _LOGGER.info("Compiling app...") | ||||||
| @@ -249,7 +295,9 @@ def compile_program(args, config): | |||||||
|     return 0 if idedata is not None else 1 |     return 0 if idedata is not None else 1 | ||||||
|  |  | ||||||
|  |  | ||||||
| def upload_using_esptool(config, port, file, speed): | def upload_using_esptool( | ||||||
|  |     config: ConfigType, port: str, file: str, speed: int | ||||||
|  | ) -> str | int: | ||||||
|     from esphome import platformio_api |     from esphome import platformio_api | ||||||
|  |  | ||||||
|     first_baudrate = speed or config[CONF_ESPHOME][CONF_PLATFORMIO_OPTIONS].get( |     first_baudrate = speed or config[CONF_ESPHOME][CONF_PLATFORMIO_OPTIONS].get( | ||||||
| @@ -314,7 +362,7 @@ def upload_using_esptool(config, port, file, speed): | |||||||
|     return run_esptool(115200) |     return run_esptool(115200) | ||||||
|  |  | ||||||
|  |  | ||||||
| def upload_using_platformio(config, port): | def upload_using_platformio(config: ConfigType, port: str): | ||||||
|     from esphome import platformio_api |     from esphome import platformio_api | ||||||
|  |  | ||||||
|     upload_args = ["-t", "upload", "-t", "nobuild"] |     upload_args = ["-t", "upload", "-t", "nobuild"] | ||||||
| @@ -323,7 +371,7 @@ def upload_using_platformio(config, port): | |||||||
|     return platformio_api.run_platformio_cli_run(config, CORE.verbose, *upload_args) |     return platformio_api.run_platformio_cli_run(config, CORE.verbose, *upload_args) | ||||||
|  |  | ||||||
|  |  | ||||||
| def check_permissions(port): | def check_permissions(port: str): | ||||||
|     if os.name == "posix" and get_port_type(port) == "SERIAL": |     if os.name == "posix" and get_port_type(port) == "SERIAL": | ||||||
|         # Check if we can open selected serial port |         # Check if we can open selected serial port | ||||||
|         if not os.access(port, os.F_OK): |         if not os.access(port, os.F_OK): | ||||||
| @@ -341,7 +389,7 @@ def check_permissions(port): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def upload_program(config, args, host): | def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | str: | ||||||
|     try: |     try: | ||||||
|         module = importlib.import_module("esphome.components." + CORE.target_platform) |         module = importlib.import_module("esphome.components." + CORE.target_platform) | ||||||
|         if getattr(module, "upload_program")(config, args, host): |         if getattr(module, "upload_program")(config, args, host): | ||||||
| @@ -356,7 +404,7 @@ def upload_program(config, args, host): | |||||||
|             return upload_using_esptool(config, host, file, args.upload_speed) |             return upload_using_esptool(config, host, file, args.upload_speed) | ||||||
|  |  | ||||||
|         if CORE.target_platform in (PLATFORM_RP2040): |         if CORE.target_platform in (PLATFORM_RP2040): | ||||||
|             return upload_using_platformio(config, args.device) |             return upload_using_platformio(config, host) | ||||||
|  |  | ||||||
|         if CORE.is_libretiny: |         if CORE.is_libretiny: | ||||||
|             return upload_using_platformio(config, host) |             return upload_using_platformio(config, host) | ||||||
| @@ -379,9 +427,12 @@ def upload_program(config, args, host): | |||||||
|     remote_port = int(ota_conf[CONF_PORT]) |     remote_port = int(ota_conf[CONF_PORT]) | ||||||
|     password = ota_conf.get(CONF_PASSWORD, "") |     password = ota_conf.get(CONF_PASSWORD, "") | ||||||
|  |  | ||||||
|  |     # Check if we should use MQTT for address resolution | ||||||
|  |     # This happens when no device was specified, or the current host is "MQTT"/"OTA" | ||||||
|  |     devices: list[str] = args.device or [] | ||||||
|     if ( |     if ( | ||||||
|         CONF_MQTT in config  # pylint: disable=too-many-boolean-expressions |         CONF_MQTT in config  # pylint: disable=too-many-boolean-expressions | ||||||
|         and (not args.device or args.device in ("MQTT", "OTA")) |         and (not devices or host in ("MQTT", "OTA")) | ||||||
|         and ( |         and ( | ||||||
|             ((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address)) |             ((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address)) | ||||||
|             or get_port_type(host) == "MQTT" |             or get_port_type(host) == "MQTT" | ||||||
| @@ -399,23 +450,28 @@ def upload_program(config, args, host): | |||||||
|     return espota2.run_ota(host, remote_port, password, CORE.firmware_bin) |     return espota2.run_ota(host, remote_port, password, CORE.firmware_bin) | ||||||
|  |  | ||||||
|  |  | ||||||
| def show_logs(config, args, port): | def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None: | ||||||
|     if "logger" not in config: |     if "logger" not in config: | ||||||
|         raise EsphomeError("Logger is not configured!") |         raise EsphomeError("Logger is not configured!") | ||||||
|  |  | ||||||
|  |     port = devices[0] | ||||||
|  |  | ||||||
|     if get_port_type(port) == "SERIAL": |     if get_port_type(port) == "SERIAL": | ||||||
|         check_permissions(port) |         check_permissions(port) | ||||||
|         return run_miniterm(config, port, args) |         return run_miniterm(config, port, args) | ||||||
|     if get_port_type(port) == "NETWORK" and "api" in config: |     if get_port_type(port) == "NETWORK" and "api" in config: | ||||||
|  |         addresses_to_use = devices | ||||||
|         if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config: |         if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config: | ||||||
|             from esphome import mqtt |             from esphome import mqtt | ||||||
|  |  | ||||||
|             port = mqtt.get_esphome_device_ip( |             mqtt_address = mqtt.get_esphome_device_ip( | ||||||
|                 config, args.username, args.password, args.client_id |                 config, args.username, args.password, args.client_id | ||||||
|             )[0] |             )[0] | ||||||
|  |             addresses_to_use = [mqtt_address] | ||||||
|  |  | ||||||
|         from esphome.components.api.client import run_logs |         from esphome.components.api.client import run_logs | ||||||
|  |  | ||||||
|         return run_logs(config, port) |         return run_logs(config, addresses_to_use) | ||||||
|     if get_port_type(port) == "MQTT" and "mqtt" in config: |     if get_port_type(port) == "MQTT" and "mqtt" in config: | ||||||
|         from esphome import mqtt |         from esphome import mqtt | ||||||
|  |  | ||||||
| @@ -426,7 +482,7 @@ def show_logs(config, args, port): | |||||||
|     raise EsphomeError("No remote or local logging method configured (api/mqtt/logger)") |     raise EsphomeError("No remote or local logging method configured (api/mqtt/logger)") | ||||||
|  |  | ||||||
|  |  | ||||||
| def clean_mqtt(config, args): | def clean_mqtt(config: ConfigType, args: ArgsProtocol) -> int | None: | ||||||
|     from esphome import mqtt |     from esphome import mqtt | ||||||
|  |  | ||||||
|     return mqtt.clear_topic( |     return mqtt.clear_topic( | ||||||
| @@ -434,13 +490,13 @@ def clean_mqtt(config, args): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_wizard(args): | def command_wizard(args: ArgsProtocol) -> int | None: | ||||||
|     from esphome import wizard |     from esphome import wizard | ||||||
|  |  | ||||||
|     return wizard.wizard(args.configuration) |     return wizard.wizard(args.configuration) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_config(args, config): | def command_config(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     if not CORE.verbose: |     if not CORE.verbose: | ||||||
|         config = strip_default_ids(config) |         config = strip_default_ids(config) | ||||||
|     output = yaml_util.dump(config, args.show_secrets) |     output = yaml_util.dump(config, args.show_secrets) | ||||||
| @@ -455,7 +511,7 @@ def command_config(args, config): | |||||||
|     return 0 |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_vscode(args): | def command_vscode(args: ArgsProtocol) -> int | None: | ||||||
|     from esphome import vscode |     from esphome import vscode | ||||||
|  |  | ||||||
|     logging.disable(logging.INFO) |     logging.disable(logging.INFO) | ||||||
| @@ -463,7 +519,7 @@ def command_vscode(args): | |||||||
|     vscode.read_config(args) |     vscode.read_config(args) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_compile(args, config): | def command_compile(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     # Set memory analysis options in config |     # Set memory analysis options in config | ||||||
|     if args.analyze_memory: |     if args.analyze_memory: | ||||||
|         config.setdefault(CONF_ESPHOME, {})["analyze_memory"] = True |         config.setdefault(CONF_ESPHOME, {})["analyze_memory"] = True | ||||||
| @@ -484,8 +540,9 @@ def command_compile(args, config): | |||||||
|     return 0 |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_upload(args, config): | def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     port = choose_upload_log_host( |     # Get devices, resolving special identifiers like OTA | ||||||
|  |     devices = choose_upload_log_host( | ||||||
|         default=args.device, |         default=args.device, | ||||||
|         check_default=None, |         check_default=None, | ||||||
|         show_ota=True, |         show_ota=True, | ||||||
| @@ -493,14 +550,22 @@ def command_upload(args, config): | |||||||
|         show_api=False, |         show_api=False, | ||||||
|         purpose="uploading", |         purpose="uploading", | ||||||
|     ) |     ) | ||||||
|     exit_code = upload_program(config, args, port) |  | ||||||
|     if exit_code != 0: |     # Try each device until one succeeds | ||||||
|         return exit_code |     exit_code = 1 | ||||||
|  |     for device in devices: | ||||||
|  |         _LOGGER.info("Uploading to %s", device) | ||||||
|  |         exit_code = upload_program(config, args, device) | ||||||
|  |         if exit_code == 0: | ||||||
|             _LOGGER.info("Successfully uploaded program.") |             _LOGGER.info("Successfully uploaded program.") | ||||||
|             return 0 |             return 0 | ||||||
|  |         if len(devices) > 1: | ||||||
|  |             _LOGGER.warning("Failed to upload to %s", device) | ||||||
|  |  | ||||||
|  |     return exit_code | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_discover(args, config): | def command_discover(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     if "mqtt" in config: |     if "mqtt" in config: | ||||||
|         from esphome import mqtt |         from esphome import mqtt | ||||||
|  |  | ||||||
| @@ -509,8 +574,9 @@ def command_discover(args, config): | |||||||
|     raise EsphomeError("No discover method configured (mqtt)") |     raise EsphomeError("No discover method configured (mqtt)") | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_logs(args, config): | def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     port = choose_upload_log_host( |     # Get devices, resolving special identifiers like OTA | ||||||
|  |     devices = choose_upload_log_host( | ||||||
|         default=args.device, |         default=args.device, | ||||||
|         check_default=None, |         check_default=None, | ||||||
|         show_ota=False, |         show_ota=False, | ||||||
| @@ -518,10 +584,10 @@ def command_logs(args, config): | |||||||
|         show_api=True, |         show_api=True, | ||||||
|         purpose="logging", |         purpose="logging", | ||||||
|     ) |     ) | ||||||
|     return show_logs(config, args, port) |     return show_logs(config, args, devices) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_run(args, config): | def command_run(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     exit_code = write_cpp(config) |     exit_code = write_cpp(config) | ||||||
|     if exit_code != 0: |     if exit_code != 0: | ||||||
|         return exit_code |         return exit_code | ||||||
| @@ -538,7 +604,8 @@ def command_run(args, config): | |||||||
|         program_path = idedata.raw["prog_path"] |         program_path = idedata.raw["prog_path"] | ||||||
|         return run_external_process(program_path) |         return run_external_process(program_path) | ||||||
|  |  | ||||||
|     port = choose_upload_log_host( |     # Get devices, resolving special identifiers like OTA | ||||||
|  |     devices = choose_upload_log_host( | ||||||
|         default=args.device, |         default=args.device, | ||||||
|         check_default=None, |         check_default=None, | ||||||
|         show_ota=True, |         show_ota=True, | ||||||
| @@ -546,39 +613,53 @@ def command_run(args, config): | |||||||
|         show_api=True, |         show_api=True, | ||||||
|         purpose="uploading", |         purpose="uploading", | ||||||
|     ) |     ) | ||||||
|     exit_code = upload_program(config, args, port) |  | ||||||
|     if exit_code != 0: |     # Try each device for upload until one succeeds | ||||||
|         return exit_code |     successful_device: str | None = None | ||||||
|  |     for device in devices: | ||||||
|  |         _LOGGER.info("Uploading to %s", device) | ||||||
|  |         exit_code = upload_program(config, args, device) | ||||||
|  |         if exit_code == 0: | ||||||
|             _LOGGER.info("Successfully uploaded program.") |             _LOGGER.info("Successfully uploaded program.") | ||||||
|  |             successful_device = device | ||||||
|  |             break | ||||||
|  |         if len(devices) > 1: | ||||||
|  |             _LOGGER.warning("Failed to upload to %s", device) | ||||||
|  |  | ||||||
|  |     if successful_device is None: | ||||||
|  |         return exit_code | ||||||
|  |  | ||||||
|     if args.no_logs: |     if args.no_logs: | ||||||
|         return 0 |         return 0 | ||||||
|     port = choose_upload_log_host( |  | ||||||
|         default=args.device, |     # For logs, prefer the device we successfully uploaded to | ||||||
|         check_default=port, |     devices = choose_upload_log_host( | ||||||
|  |         default=successful_device, | ||||||
|  |         check_default=successful_device, | ||||||
|         show_ota=False, |         show_ota=False, | ||||||
|         show_mqtt=True, |         show_mqtt=True, | ||||||
|         show_api=True, |         show_api=True, | ||||||
|         purpose="logging", |         purpose="logging", | ||||||
|     ) |     ) | ||||||
|     return show_logs(config, args, port) |     return show_logs(config, args, devices) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_clean_mqtt(args, config): | def command_clean_mqtt(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     return clean_mqtt(config, args) |     return clean_mqtt(config, args) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_mqtt_fingerprint(args, config): | def command_mqtt_fingerprint(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     from esphome import mqtt |     from esphome import mqtt | ||||||
|  |  | ||||||
|     return mqtt.get_fingerprint(config) |     return mqtt.get_fingerprint(config) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_version(args): | def command_version(args: ArgsProtocol) -> int | None: | ||||||
|     safe_print(f"Version: {const.__version__}") |     safe_print(f"Version: {const.__version__}") | ||||||
|     return 0 |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_clean(args, config): | def command_clean(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     try: |     try: | ||||||
|         writer.clean_build() |         writer.clean_build() | ||||||
|     except OSError as err: |     except OSError as err: | ||||||
| @@ -588,13 +669,13 @@ def command_clean(args, config): | |||||||
|     return 0 |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_dashboard(args): | def command_dashboard(args: ArgsProtocol) -> int | None: | ||||||
|     from esphome.dashboard import dashboard |     from esphome.dashboard import dashboard | ||||||
|  |  | ||||||
|     return dashboard.start_dashboard(args) |     return dashboard.start_dashboard(args) | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_update_all(args): | def command_update_all(args: ArgsProtocol) -> int | None: | ||||||
|     import click |     import click | ||||||
|  |  | ||||||
|     success = {} |     success = {} | ||||||
| @@ -641,7 +722,7 @@ def command_update_all(args): | |||||||
|     return failed |     return failed | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_idedata(args, config): | def command_idedata(args: ArgsProtocol, config: ConfigType) -> int: | ||||||
|     import json |     import json | ||||||
|  |  | ||||||
|     from esphome import platformio_api |     from esphome import platformio_api | ||||||
| @@ -657,7 +738,7 @@ def command_idedata(args, config): | |||||||
|     return 0 |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| def command_rename(args, config): | def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None: | ||||||
|     for c in args.name: |     for c in args.name: | ||||||
|         if c not in ALLOWED_NAME_CHARS: |         if c not in ALLOWED_NAME_CHARS: | ||||||
|             print( |             print( | ||||||
| @@ -774,6 +855,12 @@ POST_CONFIG_ACTIONS = { | |||||||
|     "discover": command_discover, |     "discover": command_discover, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | SIMPLE_CONFIG_ACTIONS = [ | ||||||
|  |     "clean", | ||||||
|  |     "clean-mqtt", | ||||||
|  |     "config", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_args(argv): | def parse_args(argv): | ||||||
|     options_parser = argparse.ArgumentParser(add_help=False) |     options_parser = argparse.ArgumentParser(add_help=False) | ||||||
| @@ -872,7 +959,8 @@ def parse_args(argv): | |||||||
|     ) |     ) | ||||||
|     parser_upload.add_argument( |     parser_upload.add_argument( | ||||||
|         "--device", |         "--device", | ||||||
|         help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.", |         action="append", | ||||||
|  |         help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.", | ||||||
|     ) |     ) | ||||||
|     parser_upload.add_argument( |     parser_upload.add_argument( | ||||||
|         "--upload_speed", |         "--upload_speed", | ||||||
| @@ -894,7 +982,8 @@ def parse_args(argv): | |||||||
|     ) |     ) | ||||||
|     parser_logs.add_argument( |     parser_logs.add_argument( | ||||||
|         "--device", |         "--device", | ||||||
|         help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.", |         action="append", | ||||||
|  |         help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.", | ||||||
|     ) |     ) | ||||||
|     parser_logs.add_argument( |     parser_logs.add_argument( | ||||||
|         "--reset", |         "--reset", | ||||||
| @@ -923,7 +1012,8 @@ def parse_args(argv): | |||||||
|     ) |     ) | ||||||
|     parser_run.add_argument( |     parser_run.add_argument( | ||||||
|         "--device", |         "--device", | ||||||
|         help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.", |         action="append", | ||||||
|  |         help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.", | ||||||
|     ) |     ) | ||||||
|     parser_run.add_argument( |     parser_run.add_argument( | ||||||
|         "--upload_speed", |         "--upload_speed", | ||||||
| @@ -1050,6 +1140,13 @@ def parse_args(argv): | |||||||
|     arguments = argv[1:] |     arguments = argv[1:] | ||||||
|  |  | ||||||
|     argcomplete.autocomplete(parser) |     argcomplete.autocomplete(parser) | ||||||
|  |  | ||||||
|  |     if len(arguments) > 0 and arguments[0] in SIMPLE_CONFIG_ACTIONS: | ||||||
|  |         args, unknown_args = parser.parse_known_args(arguments) | ||||||
|  |         if unknown_args: | ||||||
|  |             _LOGGER.warning("Ignored unrecognized arguments: %s", unknown_args) | ||||||
|  |         return args | ||||||
|  |  | ||||||
|     return parser.parse_args(arguments) |     return parser.parse_args(arguments) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -30,7 +30,7 @@ if TYPE_CHECKING: | |||||||
| _LOGGER = logging.getLogger(__name__) | _LOGGER = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def async_run_logs(config: dict[str, Any], address: str) -> None: | async def async_run_logs(config: dict[str, Any], addresses: list[str]) -> None: | ||||||
|     """Run the logs command in the event loop.""" |     """Run the logs command in the event loop.""" | ||||||
|     conf = config["api"] |     conf = config["api"] | ||||||
|     name = config["esphome"]["name"] |     name = config["esphome"]["name"] | ||||||
| @@ -39,13 +39,21 @@ async def async_run_logs(config: dict[str, Any], address: str) -> None: | |||||||
|     noise_psk: str | None = None |     noise_psk: str | None = None | ||||||
|     if (encryption := conf.get(CONF_ENCRYPTION)) and (key := encryption.get(CONF_KEY)): |     if (encryption := conf.get(CONF_ENCRYPTION)) and (key := encryption.get(CONF_KEY)): | ||||||
|         noise_psk = key |         noise_psk = key | ||||||
|     _LOGGER.info("Starting log output from %s using esphome API", address) |  | ||||||
|  |     if len(addresses) == 1: | ||||||
|  |         _LOGGER.info("Starting log output from %s using esphome API", addresses[0]) | ||||||
|  |     else: | ||||||
|  |         _LOGGER.info( | ||||||
|  |             "Starting log output from %s using esphome API", " or ".join(addresses) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     cli = APIClient( |     cli = APIClient( | ||||||
|         address, |         addresses[0],  # Primary address for compatibility | ||||||
|         port, |         port, | ||||||
|         password, |         password, | ||||||
|         client_info=f"ESPHome Logs {__version__}", |         client_info=f"ESPHome Logs {__version__}", | ||||||
|         noise_psk=noise_psk, |         noise_psk=noise_psk, | ||||||
|  |         addresses=addresses,  # Pass all addresses for automatic retry | ||||||
|     ) |     ) | ||||||
|     dashboard = CORE.dashboard |     dashboard = CORE.dashboard | ||||||
|  |  | ||||||
| @@ -66,7 +74,7 @@ async def async_run_logs(config: dict[str, Any], address: str) -> None: | |||||||
|         await stop() |         await stop() | ||||||
|  |  | ||||||
|  |  | ||||||
| def run_logs(config: dict[str, Any], address: str) -> None: | def run_logs(config: dict[str, Any], addresses: list[str]) -> None: | ||||||
|     """Run the logs command.""" |     """Run the logs command.""" | ||||||
|     with contextlib.suppress(KeyboardInterrupt): |     with contextlib.suppress(KeyboardInterrupt): | ||||||
|         asyncio.run(async_run_logs(config, address)) |         asyncio.run(async_run_logs(config, addresses)) | ||||||
|   | |||||||
| @@ -2,11 +2,7 @@ | |||||||
| #include "esphome/core/helpers.h" | #include "esphome/core/helpers.h" | ||||||
| #include "esphome/core/log.h" | #include "esphome/core/log.h" | ||||||
|  |  | ||||||
| #ifdef USE_ESP32 | #if defined(USE_ESP32_VARIANT_ESP32) || defined(USE_ESP32_VARIANT_ESP32S2) | ||||||
|  |  | ||||||
| #ifdef USE_ARDUINO |  | ||||||
| #include <esp32-hal-dac.h> |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace esphome { | namespace esphome { | ||||||
| namespace esp32_dac { | namespace esp32_dac { | ||||||
| @@ -23,18 +19,12 @@ void ESP32DAC::setup() { | |||||||
|   this->pin_->setup(); |   this->pin_->setup(); | ||||||
|   this->turn_off(); |   this->turn_off(); | ||||||
|  |  | ||||||
| #ifdef USE_ESP_IDF |  | ||||||
|   const dac_channel_t channel = this->pin_->get_pin() == DAC0_PIN ? DAC_CHAN_0 : DAC_CHAN_1; |   const dac_channel_t channel = this->pin_->get_pin() == DAC0_PIN ? DAC_CHAN_0 : DAC_CHAN_1; | ||||||
|   const dac_oneshot_config_t oneshot_cfg{channel}; |   const dac_oneshot_config_t oneshot_cfg{channel}; | ||||||
|   dac_oneshot_new_channel(&oneshot_cfg, &this->dac_handle_); |   dac_oneshot_new_channel(&oneshot_cfg, &this->dac_handle_); | ||||||
| #endif |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void ESP32DAC::on_safe_shutdown() { | void ESP32DAC::on_safe_shutdown() { dac_oneshot_del_channel(this->dac_handle_); } | ||||||
| #ifdef USE_ESP_IDF |  | ||||||
|   dac_oneshot_del_channel(this->dac_handle_); |  | ||||||
| #endif |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void ESP32DAC::dump_config() { | void ESP32DAC::dump_config() { | ||||||
|   ESP_LOGCONFIG(TAG, "ESP32 DAC:"); |   ESP_LOGCONFIG(TAG, "ESP32 DAC:"); | ||||||
| @@ -48,15 +38,10 @@ void ESP32DAC::write_state(float state) { | |||||||
|  |  | ||||||
|   state = state * 255; |   state = state * 255; | ||||||
|  |  | ||||||
| #ifdef USE_ESP_IDF |  | ||||||
|   dac_oneshot_output_voltage(this->dac_handle_, state); |   dac_oneshot_output_voltage(this->dac_handle_, state); | ||||||
| #endif |  | ||||||
| #ifdef USE_ARDUINO |  | ||||||
|   dacWrite(this->pin_->get_pin(), state); |  | ||||||
| #endif |  | ||||||
| } | } | ||||||
|  |  | ||||||
| }  // namespace esp32_dac | }  // namespace esp32_dac | ||||||
| }  // namespace esphome | }  // namespace esphome | ||||||
|  |  | ||||||
| #endif | #endif  // USE_ESP32_VARIANT_ESP32 || USE_ESP32_VARIANT_ESP32S2 | ||||||
|   | |||||||
| @@ -1,15 +1,13 @@ | |||||||
| #pragma once | #pragma once | ||||||
|  |  | ||||||
|  | #include "esphome/components/output/float_output.h" | ||||||
|  | #include "esphome/core/automation.h" | ||||||
| #include "esphome/core/component.h" | #include "esphome/core/component.h" | ||||||
| #include "esphome/core/hal.h" | #include "esphome/core/hal.h" | ||||||
| #include "esphome/core/automation.h" |  | ||||||
| #include "esphome/components/output/float_output.h" |  | ||||||
|  |  | ||||||
| #ifdef USE_ESP32 | #if defined(USE_ESP32_VARIANT_ESP32) || defined(USE_ESP32_VARIANT_ESP32S2) | ||||||
|  |  | ||||||
| #ifdef USE_ESP_IDF |  | ||||||
| #include <driver/dac_oneshot.h> | #include <driver/dac_oneshot.h> | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace esphome { | namespace esphome { | ||||||
| namespace esp32_dac { | namespace esp32_dac { | ||||||
| @@ -29,12 +27,10 @@ class ESP32DAC : public output::FloatOutput, public Component { | |||||||
|   void write_state(float state) override; |   void write_state(float state) override; | ||||||
|  |  | ||||||
|   InternalGPIOPin *pin_; |   InternalGPIOPin *pin_; | ||||||
| #ifdef USE_ESP_IDF |  | ||||||
|   dac_oneshot_handle_t dac_handle_; |   dac_oneshot_handle_t dac_handle_; | ||||||
| #endif |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| }  // namespace esp32_dac | }  // namespace esp32_dac | ||||||
| }  // namespace esphome | }  // namespace esphome | ||||||
|  |  | ||||||
| #endif | #endif  // USE_ESP32_VARIANT_ESP32 || USE_ESP32_VARIANT_ESP32S2 | ||||||
|   | |||||||
| @@ -2,10 +2,11 @@ | |||||||
|  |  | ||||||
| #include <array> | #include <array> | ||||||
| #include <cstdint> | #include <cstdint> | ||||||
|  | #include <cstring> | ||||||
|  | #include <limits> | ||||||
| #include "esphome/core/hal.h" | #include "esphome/core/hal.h" | ||||||
|  |  | ||||||
| namespace esphome { | namespace esphome::gpio_expander { | ||||||
| namespace gpio_expander { |  | ||||||
|  |  | ||||||
| /// @brief A class to cache the read state of a GPIO expander. | /// @brief A class to cache the read state of a GPIO expander. | ||||||
| ///        This class caches reads between GPIO Pins which are on the same bank. | ///        This class caches reads between GPIO Pins which are on the same bank. | ||||||
| @@ -17,12 +18,22 @@ namespace gpio_expander { | |||||||
| ///           N - Number of pins | ///           N - Number of pins | ||||||
| template<typename T, T N> class CachedGpioExpander { | template<typename T, T N> class CachedGpioExpander { | ||||||
|  public: |  public: | ||||||
|  |   /// @brief Read the state of the given pin. This will invalidate the cache for the given pin number. | ||||||
|  |   /// @param pin Pin number to read | ||||||
|  |   /// @return Pin state | ||||||
|   bool digital_read(T pin) { |   bool digital_read(T pin) { | ||||||
|     uint8_t bank = pin / (sizeof(T) * BITS_PER_BYTE); |     const uint8_t bank = pin / BANK_SIZE; | ||||||
|     if (this->read_cache_invalidated_[bank]) { |     const T pin_mask = (1 << (pin % BANK_SIZE)); | ||||||
|       this->read_cache_invalidated_[bank] = false; |     // Check if specific pin cache is valid | ||||||
|  |     if (this->read_cache_valid_[bank] & pin_mask) { | ||||||
|  |       // Invalidate pin | ||||||
|  |       this->read_cache_valid_[bank] &= ~pin_mask; | ||||||
|  |     } else { | ||||||
|  |       // Read whole bank from hardware | ||||||
|       if (!this->digital_read_hw(pin)) |       if (!this->digital_read_hw(pin)) | ||||||
|         return false; |         return false; | ||||||
|  |       // Mark bank cache as valid except the pin that is being returned now | ||||||
|  |       this->read_cache_valid_[bank] = std::numeric_limits<T>::max() & ~pin_mask; | ||||||
|     } |     } | ||||||
|     return this->digital_read_cache(pin); |     return this->digital_read_cache(pin); | ||||||
|   } |   } | ||||||
| @@ -36,18 +47,16 @@ template<typename T, T N> class CachedGpioExpander { | |||||||
|   virtual bool digital_read_cache(T pin) = 0; |   virtual bool digital_read_cache(T pin) = 0; | ||||||
|   /// @brief Call component low level function to write GPIO state to device |   /// @brief Call component low level function to write GPIO state to device | ||||||
|   virtual void digital_write_hw(T pin, bool value) = 0; |   virtual void digital_write_hw(T pin, bool value) = 0; | ||||||
|   const uint8_t cache_byte_size_ = N / (sizeof(T) * BITS_PER_BYTE); |  | ||||||
|  |  | ||||||
|   /// @brief Invalidate cache. This function should be called in component loop(). |   /// @brief Invalidate cache. This function should be called in component loop(). | ||||||
|   void reset_pin_cache_() { |   void reset_pin_cache_() { memset(this->read_cache_valid_, 0x00, CACHE_SIZE_BYTES); } | ||||||
|     for (T i = 0; i < this->cache_byte_size_; i++) { |  | ||||||
|       this->read_cache_invalidated_[i] = true; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   static const uint8_t BITS_PER_BYTE = 8; |   static constexpr uint8_t BITS_PER_BYTE = 8; | ||||||
|   std::array<bool, N / (sizeof(T) * BITS_PER_BYTE)> read_cache_invalidated_{}; |   static constexpr uint8_t BANK_SIZE = sizeof(T) * BITS_PER_BYTE; | ||||||
|  |   static constexpr size_t BANKS = N / BANK_SIZE; | ||||||
|  |   static constexpr size_t CACHE_SIZE_BYTES = BANKS * sizeof(T); | ||||||
|  |  | ||||||
|  |   T read_cache_valid_[BANKS]{0}; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| }  // namespace gpio_expander | }  // namespace esphome::gpio_expander | ||||||
| }  // namespace esphome |  | ||||||
|   | |||||||
| @@ -324,38 +324,46 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): | |||||||
|         configuration = json_message["configuration"] |         configuration = json_message["configuration"] | ||||||
|         config_file = settings.rel_path(configuration) |         config_file = settings.rel_path(configuration) | ||||||
|         port = json_message["port"] |         port = json_message["port"] | ||||||
|  |         addresses: list[str] = [port] | ||||||
|         if ( |         if ( | ||||||
|             port == "OTA"  # pylint: disable=too-many-boolean-expressions |             port == "OTA"  # pylint: disable=too-many-boolean-expressions | ||||||
|             and (entry := entries.get(config_file)) |             and (entry := entries.get(config_file)) | ||||||
|             and entry.loaded_integrations |             and entry.loaded_integrations | ||||||
|             and "api" in entry.loaded_integrations |             and "api" in entry.loaded_integrations | ||||||
|         ): |         ): | ||||||
|             if (mdns := dashboard.mdns_status) and ( |             addresses = [] | ||||||
|                 address_list := await mdns.async_resolve_host(entry.name) |             # First priority: entry.address AKA use_address | ||||||
|             ): |             if ( | ||||||
|                 # Use the IP address if available but only |                 (use_address := entry.address) | ||||||
|                 # if the API is loaded and the device is online |  | ||||||
|                 # since MQTT logging will not work otherwise |  | ||||||
|                 port = sort_ip_addresses(address_list)[0] |  | ||||||
|             elif ( |  | ||||||
|                 entry.address |  | ||||||
|                 and ( |                 and ( | ||||||
|                     address_list := await dashboard.dns_cache.async_resolve( |                     address_list := await dashboard.dns_cache.async_resolve( | ||||||
|                         entry.address, time.monotonic() |                         use_address, time.monotonic() | ||||||
|                     ) |                     ) | ||||||
|                 ) |                 ) | ||||||
|                 and not isinstance(address_list, Exception) |                 and not isinstance(address_list, Exception) | ||||||
|             ): |             ): | ||||||
|                 # If mdns is not available, try to use the DNS cache |                 addresses.extend(sort_ip_addresses(address_list)) | ||||||
|                 port = sort_ip_addresses(address_list)[0] |  | ||||||
|  |  | ||||||
|         return [ |             # Second priority: mDNS | ||||||
|             *DASHBOARD_COMMAND, |             if ( | ||||||
|             *args, |                 (mdns := dashboard.mdns_status) | ||||||
|             config_file, |                 and (address_list := await mdns.async_resolve_host(entry.name)) | ||||||
|             "--device", |                 and ( | ||||||
|             port, |                     new_addresses := [ | ||||||
|  |                         addr for addr in address_list if addr not in addresses | ||||||
|                     ] |                     ] | ||||||
|  |                 ) | ||||||
|  |             ): | ||||||
|  |                 # Use the IP address if available but only | ||||||
|  |                 # if the API is loaded and the device is online | ||||||
|  |                 # since MQTT logging will not work otherwise | ||||||
|  |                 addresses.extend(sort_ip_addresses(new_addresses)) | ||||||
|  |  | ||||||
|  |         device_args: list[str] = [ | ||||||
|  |             arg for address in addresses for arg in ("--device", address) | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         return [*DASHBOARD_COMMAND, *args, config_file, *device_args] | ||||||
|  |  | ||||||
|  |  | ||||||
| class EsphomeLogsHandler(EsphomePortCommandWebSocket): | class EsphomeLogsHandler(EsphomePortCommandWebSocket): | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ from pathlib import Path | |||||||
| import re | import re | ||||||
| import subprocess | import subprocess | ||||||
| import sys | import sys | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
| from esphome import const | from esphome import const | ||||||
|  |  | ||||||
| @@ -110,7 +111,7 @@ class RedirectText: | |||||||
|     def __getattr__(self, item): |     def __getattr__(self, item): | ||||||
|         return getattr(self._out, item) |         return getattr(self._out, item) | ||||||
|  |  | ||||||
|     def _write_color_replace(self, s): |     def _write_color_replace(self, s: str | bytes) -> None: | ||||||
|         from esphome.core import CORE |         from esphome.core import CORE | ||||||
|  |  | ||||||
|         if CORE.dashboard: |         if CORE.dashboard: | ||||||
| @@ -121,7 +122,7 @@ class RedirectText: | |||||||
|             s = s.replace("\033", "\\033") |             s = s.replace("\033", "\\033") | ||||||
|         self._out.write(s) |         self._out.write(s) | ||||||
|  |  | ||||||
|     def write(self, s): |     def write(self, s: str | bytes) -> int: | ||||||
|         # s is usually a str already (self._out is of type TextIOWrapper) |         # s is usually a str already (self._out is of type TextIOWrapper) | ||||||
|         # However, s is sometimes also a bytes object in python3. Let's make sure it's a |         # However, s is sometimes also a bytes object in python3. Let's make sure it's a | ||||||
|         # str |         # str | ||||||
| @@ -223,7 +224,7 @@ def run_external_command( | |||||||
|     return retval |     return retval | ||||||
|  |  | ||||||
|  |  | ||||||
| def run_external_process(*cmd, **kwargs): | def run_external_process(*cmd: str, **kwargs: Any) -> int | str: | ||||||
|     full_cmd = " ".join(shlex_quote(x) for x in cmd) |     full_cmd = " ".join(shlex_quote(x) for x in cmd) | ||||||
|     _LOGGER.debug("Running:  %s", full_cmd) |     _LOGGER.debug("Running:  %s", full_cmd) | ||||||
|     filter_lines = kwargs.get("filter_lines") |     filter_lines = kwargs.get("filter_lines") | ||||||
| @@ -266,7 +267,7 @@ class OrderedDict(collections.OrderedDict): | |||||||
|         return dict(self).__repr__() |         return dict(self).__repr__() | ||||||
|  |  | ||||||
|  |  | ||||||
| def list_yaml_files(folders): | def list_yaml_files(folders: list[str]) -> list[str]: | ||||||
|     files = filter_yaml_files( |     files = filter_yaml_files( | ||||||
|         [os.path.join(folder, p) for folder in folders for p in os.listdir(folder)] |         [os.path.join(folder, p) for folder in folders for p in os.listdir(folder)] | ||||||
|     ) |     ) | ||||||
| @@ -274,7 +275,7 @@ def list_yaml_files(folders): | |||||||
|     return files |     return files | ||||||
|  |  | ||||||
|  |  | ||||||
| def filter_yaml_files(files): | def filter_yaml_files(files: list[str]) -> list[str]: | ||||||
|     return [ |     return [ | ||||||
|         f |         f | ||||||
|         for f in files |         for f in files | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ platformio==6.1.18  # When updating platformio, also update /docker/Dockerfile | |||||||
| esptool==5.0.2 | esptool==5.0.2 | ||||||
| click==8.1.7 | click==8.1.7 | ||||||
| esphome-dashboard==20250514.0 | esphome-dashboard==20250514.0 | ||||||
| aioesphomeapi==37.2.4 | aioesphomeapi==37.2.5 | ||||||
| zeroconf==0.147.0 | zeroconf==0.147.0 | ||||||
| puremagic==1.30 | puremagic==1.30 | ||||||
| ruamel.yaml==0.18.14 # dashboard_import | ruamel.yaml==0.18.14 # dashboard_import | ||||||
|   | |||||||
| @@ -0,0 +1,25 @@ | |||||||
|  | import esphome.codegen as cg | ||||||
|  | import esphome.config_validation as cv | ||||||
|  | from esphome.const import CONF_ID | ||||||
|  |  | ||||||
|  | AUTO_LOAD = ["gpio_expander"] | ||||||
|  |  | ||||||
|  | gpio_expander_test_component_ns = cg.esphome_ns.namespace( | ||||||
|  |     "gpio_expander_test_component" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | GPIOExpanderTestComponent = gpio_expander_test_component_ns.class_( | ||||||
|  |     "GPIOExpanderTestComponent", cg.Component | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | CONFIG_SCHEMA = cv.Schema( | ||||||
|  |     { | ||||||
|  |         cv.GenerateID(): cv.declare_id(GPIOExpanderTestComponent), | ||||||
|  |     } | ||||||
|  | ).extend(cv.COMPONENT_SCHEMA) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def to_code(config): | ||||||
|  |     var = cg.new_Pvariable(config[CONF_ID]) | ||||||
|  |     await cg.register_component(var, config) | ||||||
| @@ -0,0 +1,38 @@ | |||||||
|  | #include "gpio_expander_test_component.h" | ||||||
|  |  | ||||||
|  | #include "esphome/core/application.h" | ||||||
|  | #include "esphome/core/log.h" | ||||||
|  |  | ||||||
|  | namespace esphome::gpio_expander_test_component { | ||||||
|  |  | ||||||
|  | static const char *const TAG = "gpio_expander_test"; | ||||||
|  |  | ||||||
|  | void GPIOExpanderTestComponent::setup() { | ||||||
|  |   for (uint8_t pin = 0; pin < 32; pin++) { | ||||||
|  |     this->digital_read(pin); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   this->digital_read(3); | ||||||
|  |   this->digital_read(3); | ||||||
|  |   this->digital_read(4); | ||||||
|  |   this->digital_read(3); | ||||||
|  |   this->digital_read(10); | ||||||
|  |   this->reset_pin_cache_();  // Reset cache to ensure next read is from hardware | ||||||
|  |   this->digital_read(15); | ||||||
|  |   this->digital_read(14); | ||||||
|  |   this->digital_read(14); | ||||||
|  |  | ||||||
|  |   ESP_LOGD(TAG, "DONE"); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | bool GPIOExpanderTestComponent::digital_read_hw(uint8_t pin) { | ||||||
|  |   ESP_LOGD(TAG, "digital_read_hw pin=%d", pin); | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | bool GPIOExpanderTestComponent::digital_read_cache(uint8_t pin) { | ||||||
|  |   ESP_LOGD(TAG, "digital_read_cache pin=%d", pin); | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | }  // namespace esphome::gpio_expander_test_component | ||||||
| @@ -0,0 +1,18 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "esphome/components/gpio_expander/cached_gpio.h" | ||||||
|  | #include "esphome/core/component.h" | ||||||
|  |  | ||||||
|  | namespace esphome::gpio_expander_test_component { | ||||||
|  |  | ||||||
|  | class GPIOExpanderTestComponent : public Component, public esphome::gpio_expander::CachedGpioExpander<uint8_t, 32> { | ||||||
|  |  public: | ||||||
|  |   void setup() override; | ||||||
|  |  | ||||||
|  |  protected: | ||||||
|  |   bool digital_read_hw(uint8_t pin) override; | ||||||
|  |   bool digital_read_cache(uint8_t pin) override; | ||||||
|  |   void digital_write_hw(uint8_t pin, bool value) override{}; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | }  // namespace esphome::gpio_expander_test_component | ||||||
							
								
								
									
										17
									
								
								tests/integration/fixtures/gpio_expander_cache.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								tests/integration/fixtures/gpio_expander_cache.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | esphome: | ||||||
|  |   name: gpio-expander-cache | ||||||
|  | host: | ||||||
|  |  | ||||||
|  | logger: | ||||||
|  |   level: DEBUG | ||||||
|  |  | ||||||
|  | api: | ||||||
|  |  | ||||||
|  | # External component that uses gpio_expander::CachedGpioExpander | ||||||
|  | external_components: | ||||||
|  |   - source: | ||||||
|  |       type: local | ||||||
|  |       path: EXTERNAL_COMPONENT_PATH | ||||||
|  |     components: [gpio_expander_test_component] | ||||||
|  |  | ||||||
|  | gpio_expander_test_component: | ||||||
							
								
								
									
										123
									
								
								tests/integration/test_gpio_expander_cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								tests/integration/test_gpio_expander_cache.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | |||||||
|  | """Integration test for CachedGPIOExpander to ensure correct behavior.""" | ||||||
|  |  | ||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | import asyncio | ||||||
|  | from pathlib import Path | ||||||
|  | import re | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | from .types import APIClientConnectedFactory, RunCompiledFunction | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_gpio_expander_cache( | ||||||
|  |     yaml_config: str, | ||||||
|  |     run_compiled: RunCompiledFunction, | ||||||
|  |     api_client_connected: APIClientConnectedFactory, | ||||||
|  | ) -> None: | ||||||
|  |     """Test gpio_expander::CachedGpioExpander correctly calls hardware functions.""" | ||||||
|  |     # Get the path to the external components directory | ||||||
|  |     external_components_path = str( | ||||||
|  |         Path(__file__).parent / "fixtures" / "external_components" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Replace the placeholder in the YAML config with the actual path | ||||||
|  |     yaml_config = yaml_config.replace( | ||||||
|  |         "EXTERNAL_COMPONENT_PATH", external_components_path | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     logs_done = asyncio.Event() | ||||||
|  |  | ||||||
|  |     # Patterns to match in logs | ||||||
|  |     digital_read_hw_pattern = re.compile(r"digital_read_hw pin=(\d+)") | ||||||
|  |     digital_read_cache_pattern = re.compile(r"digital_read_cache pin=(\d+)") | ||||||
|  |  | ||||||
|  |     # ensure logs are in the expected order | ||||||
|  |     log_order = [ | ||||||
|  |         (digital_read_hw_pattern, 0), | ||||||
|  |         [(digital_read_cache_pattern, i) for i in range(0, 8)], | ||||||
|  |         (digital_read_hw_pattern, 8), | ||||||
|  |         [(digital_read_cache_pattern, i) for i in range(8, 16)], | ||||||
|  |         (digital_read_hw_pattern, 16), | ||||||
|  |         [(digital_read_cache_pattern, i) for i in range(16, 24)], | ||||||
|  |         (digital_read_hw_pattern, 24), | ||||||
|  |         [(digital_read_cache_pattern, i) for i in range(24, 32)], | ||||||
|  |         (digital_read_hw_pattern, 3), | ||||||
|  |         (digital_read_cache_pattern, 3), | ||||||
|  |         (digital_read_hw_pattern, 3), | ||||||
|  |         (digital_read_cache_pattern, 3), | ||||||
|  |         (digital_read_cache_pattern, 4), | ||||||
|  |         (digital_read_hw_pattern, 3), | ||||||
|  |         (digital_read_cache_pattern, 3), | ||||||
|  |         (digital_read_hw_pattern, 10), | ||||||
|  |         (digital_read_cache_pattern, 10), | ||||||
|  |         # full cache reset here for testing | ||||||
|  |         (digital_read_hw_pattern, 15), | ||||||
|  |         (digital_read_cache_pattern, 15), | ||||||
|  |         (digital_read_cache_pattern, 14), | ||||||
|  |         (digital_read_hw_pattern, 14), | ||||||
|  |         (digital_read_cache_pattern, 14), | ||||||
|  |     ] | ||||||
|  |     # Flatten the log order for easier processing | ||||||
|  |     log_order: list[tuple[re.Pattern, int]] = [ | ||||||
|  |         item | ||||||
|  |         for sublist in log_order | ||||||
|  |         for item in (sublist if isinstance(sublist, list) else [sublist]) | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     index = 0 | ||||||
|  |  | ||||||
|  |     def check_output(line: str) -> None: | ||||||
|  |         """Check log output for expected messages.""" | ||||||
|  |         nonlocal index | ||||||
|  |         if logs_done.is_set(): | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         clean_line = re.sub(r"\x1b\[[0-9;]*m", "", line) | ||||||
|  |  | ||||||
|  |         if "digital_read" in clean_line: | ||||||
|  |             if index >= len(log_order): | ||||||
|  |                 print(f"Received unexpected log line: {clean_line}") | ||||||
|  |                 logs_done.set() | ||||||
|  |                 return | ||||||
|  |  | ||||||
|  |             pattern, expected_pin = log_order[index] | ||||||
|  |             match = pattern.search(clean_line) | ||||||
|  |  | ||||||
|  |             if not match: | ||||||
|  |                 print(f"Log line did not match next expected pattern: {clean_line}") | ||||||
|  |                 logs_done.set() | ||||||
|  |                 return | ||||||
|  |  | ||||||
|  |             pin = int(match.group(1)) | ||||||
|  |             if pin != expected_pin: | ||||||
|  |                 print(f"Unexpected pin number. Expected {expected_pin}, got {pin}") | ||||||
|  |                 logs_done.set() | ||||||
|  |                 return | ||||||
|  |  | ||||||
|  |             index += 1 | ||||||
|  |  | ||||||
|  |         elif "DONE" in clean_line: | ||||||
|  |             # Check if we reached the end of the expected log entries | ||||||
|  |             logs_done.set() | ||||||
|  |  | ||||||
|  |     # Run with log monitoring | ||||||
|  |     async with ( | ||||||
|  |         run_compiled(yaml_config, line_callback=check_output), | ||||||
|  |         api_client_connected() as client, | ||||||
|  |     ): | ||||||
|  |         # Verify device info | ||||||
|  |         device_info = await client.device_info() | ||||||
|  |         assert device_info is not None | ||||||
|  |         assert device_info.name == "gpio-expander-cache" | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             await asyncio.wait_for(logs_done.wait(), timeout=5.0) | ||||||
|  |         except TimeoutError: | ||||||
|  |             pytest.fail("Timeout waiting for logs to complete") | ||||||
|  |  | ||||||
|  |         assert index == len(log_order), ( | ||||||
|  |             f"Expected {len(log_order)} log entries, but got {index}" | ||||||
|  |         ) | ||||||
		Reference in New Issue
	
	Block a user