mirror of
https://github.com/esphome/esphome.git
synced 2025-09-05 21:02:20 +01:00
Merge branch 'multi_device_args' into integration
This commit is contained in:
@@ -9,6 +9,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
import argcomplete
|
import argcomplete
|
||||||
|
|
||||||
@@ -44,6 +45,7 @@ from esphome.const import (
|
|||||||
from esphome.core import CORE, EsphomeError, coroutine
|
from esphome.core import CORE, EsphomeError, coroutine
|
||||||
from esphome.helpers import get_bool_env, indent, is_ip_address
|
from esphome.helpers import get_bool_env, indent, is_ip_address
|
||||||
from esphome.log import AnsiFore, color, setup_log
|
from esphome.log import AnsiFore, color, setup_log
|
||||||
|
from esphome.types import ConfigType
|
||||||
from esphome.util import (
|
from esphome.util import (
|
||||||
get_serial_ports,
|
get_serial_ports,
|
||||||
list_yaml_files,
|
list_yaml_files,
|
||||||
@@ -55,6 +57,23 @@ from esphome.util import (
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ArgsProtocol(Protocol):
|
||||||
|
device: list[str] | None
|
||||||
|
reset: bool
|
||||||
|
username: str | None
|
||||||
|
password: str | None
|
||||||
|
client_id: str | None
|
||||||
|
topic: str | None
|
||||||
|
file: str | None
|
||||||
|
no_logs: bool
|
||||||
|
only_generate: bool
|
||||||
|
show_secrets: bool
|
||||||
|
dashboard: bool
|
||||||
|
configuration: str
|
||||||
|
name: str
|
||||||
|
upload_speed: str | None
|
||||||
|
|
||||||
|
|
||||||
def choose_prompt(options, purpose: str = None):
|
def choose_prompt(options, purpose: str = None):
|
||||||
if not options:
|
if not options:
|
||||||
raise EsphomeError(
|
raise EsphomeError(
|
||||||
@@ -88,30 +107,50 @@ def choose_prompt(options, purpose: str = None):
|
|||||||
|
|
||||||
|
|
||||||
def choose_upload_log_host(
|
def choose_upload_log_host(
|
||||||
default, check_default, show_ota, show_mqtt, show_api, purpose: str = None
|
default: list[str] | str | None,
|
||||||
):
|
check_default: str | None,
|
||||||
|
show_ota: bool,
|
||||||
|
show_mqtt: bool,
|
||||||
|
show_api: bool,
|
||||||
|
purpose: str | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
# Convert to list for uniform handling
|
||||||
|
defaults = [default] if isinstance(default, str) else default or []
|
||||||
|
|
||||||
|
# If devices specified, resolve them
|
||||||
|
if defaults:
|
||||||
|
resolved: list[str] = []
|
||||||
|
for device in defaults:
|
||||||
|
if device == "SERIAL":
|
||||||
|
options = [
|
||||||
|
(f"{port.path} ({port.description})", port.path)
|
||||||
|
for port in get_serial_ports()
|
||||||
|
]
|
||||||
|
resolved.append(choose_prompt(options, purpose=purpose))
|
||||||
|
elif device == "OTA":
|
||||||
|
if (show_ota and "ota" in CORE.config) or (
|
||||||
|
show_api and "api" in CORE.config
|
||||||
|
):
|
||||||
|
resolved.append(CORE.address)
|
||||||
|
elif show_mqtt and has_mqtt_logging():
|
||||||
|
resolved.append("MQTT")
|
||||||
|
else:
|
||||||
|
resolved.append(device)
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
# No devices specified, show interactive chooser
|
||||||
options = [
|
options = [
|
||||||
(f"{port.path} ({port.description})", port.path) for port in get_serial_ports()
|
(f"{port.path} ({port.description})", port.path) for port in get_serial_ports()
|
||||||
]
|
]
|
||||||
if default == "SERIAL":
|
|
||||||
return choose_prompt(options, purpose=purpose)
|
|
||||||
if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config):
|
if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config):
|
||||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||||
if default == "OTA":
|
if show_mqtt and has_mqtt_logging():
|
||||||
return CORE.address
|
mqtt_config = CORE.config[CONF_MQTT]
|
||||||
if (
|
|
||||||
show_mqtt
|
|
||||||
and (mqtt_config := CORE.config.get(CONF_MQTT))
|
|
||||||
and mqtt_logging_enabled(mqtt_config)
|
|
||||||
):
|
|
||||||
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
||||||
if default == "OTA":
|
|
||||||
return "MQTT"
|
|
||||||
if default is not None:
|
|
||||||
return default
|
|
||||||
if check_default is not None and check_default in [opt[1] for opt in options]:
|
if check_default is not None and check_default in [opt[1] for opt in options]:
|
||||||
return check_default
|
return [check_default]
|
||||||
return choose_prompt(options, purpose=purpose)
|
return [choose_prompt(options, purpose=purpose)]
|
||||||
|
|
||||||
|
|
||||||
def mqtt_logging_enabled(mqtt_config):
|
def mqtt_logging_enabled(mqtt_config):
|
||||||
@@ -123,7 +162,14 @@ def mqtt_logging_enabled(mqtt_config):
|
|||||||
return log_topic.get(CONF_LEVEL, None) != "NONE"
|
return log_topic.get(CONF_LEVEL, None) != "NONE"
|
||||||
|
|
||||||
|
|
||||||
def get_port_type(port):
|
def has_mqtt_logging() -> bool:
|
||||||
|
"""Check if MQTT logging is available."""
|
||||||
|
return (mqtt_config := CORE.config.get(CONF_MQTT)) and mqtt_logging_enabled(
|
||||||
|
mqtt_config
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_port_type(port: str) -> str:
|
||||||
if port.startswith("/") or port.startswith("COM"):
|
if port.startswith("/") or port.startswith("COM"):
|
||||||
return "SERIAL"
|
return "SERIAL"
|
||||||
if port == "MQTT":
|
if port == "MQTT":
|
||||||
@@ -131,7 +177,7 @@ def get_port_type(port):
|
|||||||
return "NETWORK"
|
return "NETWORK"
|
||||||
|
|
||||||
|
|
||||||
def run_miniterm(config, port, args):
|
def run_miniterm(config: ConfigType, port: str, args) -> int:
|
||||||
from aioesphomeapi import LogParser
|
from aioesphomeapi import LogParser
|
||||||
import serial
|
import serial
|
||||||
|
|
||||||
@@ -208,7 +254,7 @@ def wrap_to_code(name, comp):
|
|||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def write_cpp(config):
|
def write_cpp(config: ConfigType) -> int:
|
||||||
if not get_bool_env(ENV_NOGITIGNORE):
|
if not get_bool_env(ENV_NOGITIGNORE):
|
||||||
writer.write_gitignore()
|
writer.write_gitignore()
|
||||||
|
|
||||||
@@ -216,7 +262,7 @@ def write_cpp(config):
|
|||||||
return write_cpp_file()
|
return write_cpp_file()
|
||||||
|
|
||||||
|
|
||||||
def generate_cpp_contents(config):
|
def generate_cpp_contents(config: ConfigType) -> None:
|
||||||
_LOGGER.info("Generating C++ source...")
|
_LOGGER.info("Generating C++ source...")
|
||||||
|
|
||||||
for name, component, conf in iter_component_configs(CORE.config):
|
for name, component, conf in iter_component_configs(CORE.config):
|
||||||
@@ -227,7 +273,7 @@ def generate_cpp_contents(config):
|
|||||||
CORE.flush_tasks()
|
CORE.flush_tasks()
|
||||||
|
|
||||||
|
|
||||||
def write_cpp_file():
|
def write_cpp_file() -> int:
|
||||||
code_s = indent(CORE.cpp_main_section)
|
code_s = indent(CORE.cpp_main_section)
|
||||||
writer.write_cpp(code_s)
|
writer.write_cpp(code_s)
|
||||||
|
|
||||||
@@ -238,7 +284,7 @@ def write_cpp_file():
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def compile_program(args, config):
|
def compile_program(args: ArgsProtocol, config: ConfigType) -> int:
|
||||||
from esphome import platformio_api
|
from esphome import platformio_api
|
||||||
|
|
||||||
_LOGGER.info("Compiling app...")
|
_LOGGER.info("Compiling app...")
|
||||||
@@ -249,7 +295,9 @@ def compile_program(args, config):
|
|||||||
return 0 if idedata is not None else 1
|
return 0 if idedata is not None else 1
|
||||||
|
|
||||||
|
|
||||||
def upload_using_esptool(config, port, file, speed):
|
def upload_using_esptool(
|
||||||
|
config: ConfigType, port: str, file: str, speed: int
|
||||||
|
) -> str | int:
|
||||||
from esphome import platformio_api
|
from esphome import platformio_api
|
||||||
|
|
||||||
first_baudrate = speed or config[CONF_ESPHOME][CONF_PLATFORMIO_OPTIONS].get(
|
first_baudrate = speed or config[CONF_ESPHOME][CONF_PLATFORMIO_OPTIONS].get(
|
||||||
@@ -314,7 +362,7 @@ def upload_using_esptool(config, port, file, speed):
|
|||||||
return run_esptool(115200)
|
return run_esptool(115200)
|
||||||
|
|
||||||
|
|
||||||
def upload_using_platformio(config, port):
|
def upload_using_platformio(config: ConfigType, port: str):
|
||||||
from esphome import platformio_api
|
from esphome import platformio_api
|
||||||
|
|
||||||
upload_args = ["-t", "upload", "-t", "nobuild"]
|
upload_args = ["-t", "upload", "-t", "nobuild"]
|
||||||
@@ -323,7 +371,7 @@ def upload_using_platformio(config, port):
|
|||||||
return platformio_api.run_platformio_cli_run(config, CORE.verbose, *upload_args)
|
return platformio_api.run_platformio_cli_run(config, CORE.verbose, *upload_args)
|
||||||
|
|
||||||
|
|
||||||
def check_permissions(port):
|
def check_permissions(port: str):
|
||||||
if os.name == "posix" and get_port_type(port) == "SERIAL":
|
if os.name == "posix" and get_port_type(port) == "SERIAL":
|
||||||
# Check if we can open selected serial port
|
# Check if we can open selected serial port
|
||||||
if not os.access(port, os.F_OK):
|
if not os.access(port, os.F_OK):
|
||||||
@@ -341,7 +389,7 @@ def check_permissions(port):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def upload_program(config, args, host):
|
def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | str:
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module("esphome.components." + CORE.target_platform)
|
module = importlib.import_module("esphome.components." + CORE.target_platform)
|
||||||
if getattr(module, "upload_program")(config, args, host):
|
if getattr(module, "upload_program")(config, args, host):
|
||||||
@@ -356,7 +404,7 @@ def upload_program(config, args, host):
|
|||||||
return upload_using_esptool(config, host, file, args.upload_speed)
|
return upload_using_esptool(config, host, file, args.upload_speed)
|
||||||
|
|
||||||
if CORE.target_platform in (PLATFORM_RP2040):
|
if CORE.target_platform in (PLATFORM_RP2040):
|
||||||
return upload_using_platformio(config, args.device)
|
return upload_using_platformio(config, host)
|
||||||
|
|
||||||
if CORE.is_libretiny:
|
if CORE.is_libretiny:
|
||||||
return upload_using_platformio(config, host)
|
return upload_using_platformio(config, host)
|
||||||
@@ -379,9 +427,12 @@ def upload_program(config, args, host):
|
|||||||
remote_port = int(ota_conf[CONF_PORT])
|
remote_port = int(ota_conf[CONF_PORT])
|
||||||
password = ota_conf.get(CONF_PASSWORD, "")
|
password = ota_conf.get(CONF_PASSWORD, "")
|
||||||
|
|
||||||
|
# Check if we should use MQTT for address resolution
|
||||||
|
# This happens when no device was specified, or the current host is "MQTT"/"OTA"
|
||||||
|
devices: list[str] = args.device or []
|
||||||
if (
|
if (
|
||||||
CONF_MQTT in config # pylint: disable=too-many-boolean-expressions
|
CONF_MQTT in config # pylint: disable=too-many-boolean-expressions
|
||||||
and (not args.device or args.device in ("MQTT", "OTA"))
|
and (not devices or host in ("MQTT", "OTA"))
|
||||||
and (
|
and (
|
||||||
((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address))
|
((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address))
|
||||||
or get_port_type(host) == "MQTT"
|
or get_port_type(host) == "MQTT"
|
||||||
@@ -399,23 +450,28 @@ def upload_program(config, args, host):
|
|||||||
return espota2.run_ota(host, remote_port, password, CORE.firmware_bin)
|
return espota2.run_ota(host, remote_port, password, CORE.firmware_bin)
|
||||||
|
|
||||||
|
|
||||||
def show_logs(config, args, port):
|
def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None:
|
||||||
if "logger" not in config:
|
if "logger" not in config:
|
||||||
raise EsphomeError("Logger is not configured!")
|
raise EsphomeError("Logger is not configured!")
|
||||||
|
|
||||||
|
port = devices[0]
|
||||||
|
|
||||||
if get_port_type(port) == "SERIAL":
|
if get_port_type(port) == "SERIAL":
|
||||||
check_permissions(port)
|
check_permissions(port)
|
||||||
return run_miniterm(config, port, args)
|
return run_miniterm(config, port, args)
|
||||||
if get_port_type(port) == "NETWORK" and "api" in config:
|
if get_port_type(port) == "NETWORK" and "api" in config:
|
||||||
|
addresses_to_use = devices
|
||||||
if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config:
|
if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config:
|
||||||
from esphome import mqtt
|
from esphome import mqtt
|
||||||
|
|
||||||
port = mqtt.get_esphome_device_ip(
|
mqtt_address = mqtt.get_esphome_device_ip(
|
||||||
config, args.username, args.password, args.client_id
|
config, args.username, args.password, args.client_id
|
||||||
)[0]
|
)[0]
|
||||||
|
addresses_to_use = [mqtt_address]
|
||||||
|
|
||||||
from esphome.components.api.client import run_logs
|
from esphome.components.api.client import run_logs
|
||||||
|
|
||||||
return run_logs(config, port)
|
return run_logs(config, addresses_to_use)
|
||||||
if get_port_type(port) == "MQTT" and "mqtt" in config:
|
if get_port_type(port) == "MQTT" and "mqtt" in config:
|
||||||
from esphome import mqtt
|
from esphome import mqtt
|
||||||
|
|
||||||
@@ -426,7 +482,7 @@ def show_logs(config, args, port):
|
|||||||
raise EsphomeError("No remote or local logging method configured (api/mqtt/logger)")
|
raise EsphomeError("No remote or local logging method configured (api/mqtt/logger)")
|
||||||
|
|
||||||
|
|
||||||
def clean_mqtt(config, args):
|
def clean_mqtt(config: ConfigType, args: ArgsProtocol) -> int | None:
|
||||||
from esphome import mqtt
|
from esphome import mqtt
|
||||||
|
|
||||||
return mqtt.clear_topic(
|
return mqtt.clear_topic(
|
||||||
@@ -434,13 +490,13 @@ def clean_mqtt(config, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def command_wizard(args):
|
def command_wizard(args: ArgsProtocol) -> int | None:
|
||||||
from esphome import wizard
|
from esphome import wizard
|
||||||
|
|
||||||
return wizard.wizard(args.configuration)
|
return wizard.wizard(args.configuration)
|
||||||
|
|
||||||
|
|
||||||
def command_config(args, config):
|
def command_config(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
if not CORE.verbose:
|
if not CORE.verbose:
|
||||||
config = strip_default_ids(config)
|
config = strip_default_ids(config)
|
||||||
output = yaml_util.dump(config, args.show_secrets)
|
output = yaml_util.dump(config, args.show_secrets)
|
||||||
@@ -455,7 +511,7 @@ def command_config(args, config):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def command_vscode(args):
|
def command_vscode(args: ArgsProtocol) -> int | None:
|
||||||
from esphome import vscode
|
from esphome import vscode
|
||||||
|
|
||||||
logging.disable(logging.INFO)
|
logging.disable(logging.INFO)
|
||||||
@@ -463,7 +519,7 @@ def command_vscode(args):
|
|||||||
vscode.read_config(args)
|
vscode.read_config(args)
|
||||||
|
|
||||||
|
|
||||||
def command_compile(args, config):
|
def command_compile(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
# Set memory analysis options in config
|
# Set memory analysis options in config
|
||||||
if args.analyze_memory:
|
if args.analyze_memory:
|
||||||
config.setdefault(CONF_ESPHOME, {})["analyze_memory"] = True
|
config.setdefault(CONF_ESPHOME, {})["analyze_memory"] = True
|
||||||
@@ -484,8 +540,9 @@ def command_compile(args, config):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def command_upload(args, config):
|
def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
port = choose_upload_log_host(
|
# Get devices, resolving special identifiers like OTA
|
||||||
|
devices = choose_upload_log_host(
|
||||||
default=args.device,
|
default=args.device,
|
||||||
check_default=None,
|
check_default=None,
|
||||||
show_ota=True,
|
show_ota=True,
|
||||||
@@ -493,14 +550,22 @@ def command_upload(args, config):
|
|||||||
show_api=False,
|
show_api=False,
|
||||||
purpose="uploading",
|
purpose="uploading",
|
||||||
)
|
)
|
||||||
exit_code = upload_program(config, args, port)
|
|
||||||
if exit_code != 0:
|
# Try each device until one succeeds
|
||||||
return exit_code
|
exit_code = 1
|
||||||
_LOGGER.info("Successfully uploaded program.")
|
for device in devices:
|
||||||
return 0
|
_LOGGER.info("Uploading to %s", device)
|
||||||
|
exit_code = upload_program(config, args, device)
|
||||||
|
if exit_code == 0:
|
||||||
|
_LOGGER.info("Successfully uploaded program.")
|
||||||
|
return 0
|
||||||
|
if len(devices) > 1:
|
||||||
|
_LOGGER.warning("Failed to upload to %s", device)
|
||||||
|
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
|
||||||
def command_discover(args, config):
|
def command_discover(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
if "mqtt" in config:
|
if "mqtt" in config:
|
||||||
from esphome import mqtt
|
from esphome import mqtt
|
||||||
|
|
||||||
@@ -509,8 +574,9 @@ def command_discover(args, config):
|
|||||||
raise EsphomeError("No discover method configured (mqtt)")
|
raise EsphomeError("No discover method configured (mqtt)")
|
||||||
|
|
||||||
|
|
||||||
def command_logs(args, config):
|
def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
port = choose_upload_log_host(
|
# Get devices, resolving special identifiers like OTA
|
||||||
|
devices = choose_upload_log_host(
|
||||||
default=args.device,
|
default=args.device,
|
||||||
check_default=None,
|
check_default=None,
|
||||||
show_ota=False,
|
show_ota=False,
|
||||||
@@ -518,10 +584,10 @@ def command_logs(args, config):
|
|||||||
show_api=True,
|
show_api=True,
|
||||||
purpose="logging",
|
purpose="logging",
|
||||||
)
|
)
|
||||||
return show_logs(config, args, port)
|
return show_logs(config, args, devices)
|
||||||
|
|
||||||
|
|
||||||
def command_run(args, config):
|
def command_run(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
exit_code = write_cpp(config)
|
exit_code = write_cpp(config)
|
||||||
if exit_code != 0:
|
if exit_code != 0:
|
||||||
return exit_code
|
return exit_code
|
||||||
@@ -538,7 +604,8 @@ def command_run(args, config):
|
|||||||
program_path = idedata.raw["prog_path"]
|
program_path = idedata.raw["prog_path"]
|
||||||
return run_external_process(program_path)
|
return run_external_process(program_path)
|
||||||
|
|
||||||
port = choose_upload_log_host(
|
# Get devices, resolving special identifiers like OTA
|
||||||
|
devices = choose_upload_log_host(
|
||||||
default=args.device,
|
default=args.device,
|
||||||
check_default=None,
|
check_default=None,
|
||||||
show_ota=True,
|
show_ota=True,
|
||||||
@@ -546,39 +613,53 @@ def command_run(args, config):
|
|||||||
show_api=True,
|
show_api=True,
|
||||||
purpose="uploading",
|
purpose="uploading",
|
||||||
)
|
)
|
||||||
exit_code = upload_program(config, args, port)
|
|
||||||
if exit_code != 0:
|
# Try each device for upload until one succeeds
|
||||||
|
successful_device: str | None = None
|
||||||
|
for device in devices:
|
||||||
|
_LOGGER.info("Uploading to %s", device)
|
||||||
|
exit_code = upload_program(config, args, device)
|
||||||
|
if exit_code == 0:
|
||||||
|
_LOGGER.info("Successfully uploaded program.")
|
||||||
|
successful_device = device
|
||||||
|
break
|
||||||
|
if len(devices) > 1:
|
||||||
|
_LOGGER.warning("Failed to upload to %s", device)
|
||||||
|
|
||||||
|
if successful_device is None:
|
||||||
return exit_code
|
return exit_code
|
||||||
_LOGGER.info("Successfully uploaded program.")
|
|
||||||
if args.no_logs:
|
if args.no_logs:
|
||||||
return 0
|
return 0
|
||||||
port = choose_upload_log_host(
|
|
||||||
default=args.device,
|
# For logs, prefer the device we successfully uploaded to
|
||||||
check_default=port,
|
devices = choose_upload_log_host(
|
||||||
|
default=successful_device,
|
||||||
|
check_default=successful_device,
|
||||||
show_ota=False,
|
show_ota=False,
|
||||||
show_mqtt=True,
|
show_mqtt=True,
|
||||||
show_api=True,
|
show_api=True,
|
||||||
purpose="logging",
|
purpose="logging",
|
||||||
)
|
)
|
||||||
return show_logs(config, args, port)
|
return show_logs(config, args, devices)
|
||||||
|
|
||||||
|
|
||||||
def command_clean_mqtt(args, config):
|
def command_clean_mqtt(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
return clean_mqtt(config, args)
|
return clean_mqtt(config, args)
|
||||||
|
|
||||||
|
|
||||||
def command_mqtt_fingerprint(args, config):
|
def command_mqtt_fingerprint(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
from esphome import mqtt
|
from esphome import mqtt
|
||||||
|
|
||||||
return mqtt.get_fingerprint(config)
|
return mqtt.get_fingerprint(config)
|
||||||
|
|
||||||
|
|
||||||
def command_version(args):
|
def command_version(args: ArgsProtocol) -> int | None:
|
||||||
safe_print(f"Version: {const.__version__}")
|
safe_print(f"Version: {const.__version__}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def command_clean(args, config):
|
def command_clean(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
try:
|
try:
|
||||||
writer.clean_build()
|
writer.clean_build()
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
@@ -588,13 +669,13 @@ def command_clean(args, config):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def command_dashboard(args):
|
def command_dashboard(args: ArgsProtocol) -> int | None:
|
||||||
from esphome.dashboard import dashboard
|
from esphome.dashboard import dashboard
|
||||||
|
|
||||||
return dashboard.start_dashboard(args)
|
return dashboard.start_dashboard(args)
|
||||||
|
|
||||||
|
|
||||||
def command_update_all(args):
|
def command_update_all(args: ArgsProtocol) -> int | None:
|
||||||
import click
|
import click
|
||||||
|
|
||||||
success = {}
|
success = {}
|
||||||
@@ -641,7 +722,7 @@ def command_update_all(args):
|
|||||||
return failed
|
return failed
|
||||||
|
|
||||||
|
|
||||||
def command_idedata(args, config):
|
def command_idedata(args: ArgsProtocol, config: ConfigType) -> int:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from esphome import platformio_api
|
from esphome import platformio_api
|
||||||
@@ -657,7 +738,7 @@ def command_idedata(args, config):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def command_rename(args, config):
|
def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||||
for c in args.name:
|
for c in args.name:
|
||||||
if c not in ALLOWED_NAME_CHARS:
|
if c not in ALLOWED_NAME_CHARS:
|
||||||
print(
|
print(
|
||||||
@@ -774,6 +855,12 @@ POST_CONFIG_ACTIONS = {
|
|||||||
"discover": command_discover,
|
"discover": command_discover,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SIMPLE_CONFIG_ACTIONS = [
|
||||||
|
"clean",
|
||||||
|
"clean-mqtt",
|
||||||
|
"config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def parse_args(argv):
|
def parse_args(argv):
|
||||||
options_parser = argparse.ArgumentParser(add_help=False)
|
options_parser = argparse.ArgumentParser(add_help=False)
|
||||||
@@ -872,7 +959,8 @@ def parse_args(argv):
|
|||||||
)
|
)
|
||||||
parser_upload.add_argument(
|
parser_upload.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.",
|
action="append",
|
||||||
|
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.",
|
||||||
)
|
)
|
||||||
parser_upload.add_argument(
|
parser_upload.add_argument(
|
||||||
"--upload_speed",
|
"--upload_speed",
|
||||||
@@ -894,7 +982,8 @@ def parse_args(argv):
|
|||||||
)
|
)
|
||||||
parser_logs.add_argument(
|
parser_logs.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.",
|
action="append",
|
||||||
|
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.",
|
||||||
)
|
)
|
||||||
parser_logs.add_argument(
|
parser_logs.add_argument(
|
||||||
"--reset",
|
"--reset",
|
||||||
@@ -923,7 +1012,8 @@ def parse_args(argv):
|
|||||||
)
|
)
|
||||||
parser_run.add_argument(
|
parser_run.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.",
|
action="append",
|
||||||
|
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.",
|
||||||
)
|
)
|
||||||
parser_run.add_argument(
|
parser_run.add_argument(
|
||||||
"--upload_speed",
|
"--upload_speed",
|
||||||
@@ -1050,6 +1140,13 @@ def parse_args(argv):
|
|||||||
arguments = argv[1:]
|
arguments = argv[1:]
|
||||||
|
|
||||||
argcomplete.autocomplete(parser)
|
argcomplete.autocomplete(parser)
|
||||||
|
|
||||||
|
if len(arguments) > 0 and arguments[0] in SIMPLE_CONFIG_ACTIONS:
|
||||||
|
args, unknown_args = parser.parse_known_args(arguments)
|
||||||
|
if unknown_args:
|
||||||
|
_LOGGER.warning("Ignored unrecognized arguments: %s", unknown_args)
|
||||||
|
return args
|
||||||
|
|
||||||
return parser.parse_args(arguments)
|
return parser.parse_args(arguments)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_run_logs(config: dict[str, Any], address: str) -> None:
|
async def async_run_logs(config: dict[str, Any], addresses: list[str]) -> None:
|
||||||
"""Run the logs command in the event loop."""
|
"""Run the logs command in the event loop."""
|
||||||
conf = config["api"]
|
conf = config["api"]
|
||||||
name = config["esphome"]["name"]
|
name = config["esphome"]["name"]
|
||||||
@@ -39,13 +39,21 @@ async def async_run_logs(config: dict[str, Any], address: str) -> None:
|
|||||||
noise_psk: str | None = None
|
noise_psk: str | None = None
|
||||||
if (encryption := conf.get(CONF_ENCRYPTION)) and (key := encryption.get(CONF_KEY)):
|
if (encryption := conf.get(CONF_ENCRYPTION)) and (key := encryption.get(CONF_KEY)):
|
||||||
noise_psk = key
|
noise_psk = key
|
||||||
_LOGGER.info("Starting log output from %s using esphome API", address)
|
|
||||||
|
if len(addresses) == 1:
|
||||||
|
_LOGGER.info("Starting log output from %s using esphome API", addresses[0])
|
||||||
|
else:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Starting log output from %s using esphome API", " or ".join(addresses)
|
||||||
|
)
|
||||||
|
|
||||||
cli = APIClient(
|
cli = APIClient(
|
||||||
address,
|
addresses[0], # Primary address for compatibility
|
||||||
port,
|
port,
|
||||||
password,
|
password,
|
||||||
client_info=f"ESPHome Logs {__version__}",
|
client_info=f"ESPHome Logs {__version__}",
|
||||||
noise_psk=noise_psk,
|
noise_psk=noise_psk,
|
||||||
|
addresses=addresses, # Pass all addresses for automatic retry
|
||||||
)
|
)
|
||||||
dashboard = CORE.dashboard
|
dashboard = CORE.dashboard
|
||||||
|
|
||||||
@@ -66,7 +74,7 @@ async def async_run_logs(config: dict[str, Any], address: str) -> None:
|
|||||||
await stop()
|
await stop()
|
||||||
|
|
||||||
|
|
||||||
def run_logs(config: dict[str, Any], address: str) -> None:
|
def run_logs(config: dict[str, Any], addresses: list[str]) -> None:
|
||||||
"""Run the logs command."""
|
"""Run the logs command."""
|
||||||
with contextlib.suppress(KeyboardInterrupt):
|
with contextlib.suppress(KeyboardInterrupt):
|
||||||
asyncio.run(async_run_logs(config, address))
|
asyncio.run(async_run_logs(config, addresses))
|
||||||
|
@@ -2,11 +2,7 @@
|
|||||||
#include "esphome/core/helpers.h"
|
#include "esphome/core/helpers.h"
|
||||||
#include "esphome/core/log.h"
|
#include "esphome/core/log.h"
|
||||||
|
|
||||||
#ifdef USE_ESP32
|
#if defined(USE_ESP32_VARIANT_ESP32) || defined(USE_ESP32_VARIANT_ESP32S2)
|
||||||
|
|
||||||
#ifdef USE_ARDUINO
|
|
||||||
#include <esp32-hal-dac.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace esphome {
|
namespace esphome {
|
||||||
namespace esp32_dac {
|
namespace esp32_dac {
|
||||||
@@ -23,18 +19,12 @@ void ESP32DAC::setup() {
|
|||||||
this->pin_->setup();
|
this->pin_->setup();
|
||||||
this->turn_off();
|
this->turn_off();
|
||||||
|
|
||||||
#ifdef USE_ESP_IDF
|
|
||||||
const dac_channel_t channel = this->pin_->get_pin() == DAC0_PIN ? DAC_CHAN_0 : DAC_CHAN_1;
|
const dac_channel_t channel = this->pin_->get_pin() == DAC0_PIN ? DAC_CHAN_0 : DAC_CHAN_1;
|
||||||
const dac_oneshot_config_t oneshot_cfg{channel};
|
const dac_oneshot_config_t oneshot_cfg{channel};
|
||||||
dac_oneshot_new_channel(&oneshot_cfg, &this->dac_handle_);
|
dac_oneshot_new_channel(&oneshot_cfg, &this->dac_handle_);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ESP32DAC::on_safe_shutdown() {
|
void ESP32DAC::on_safe_shutdown() { dac_oneshot_del_channel(this->dac_handle_); }
|
||||||
#ifdef USE_ESP_IDF
|
|
||||||
dac_oneshot_del_channel(this->dac_handle_);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
void ESP32DAC::dump_config() {
|
void ESP32DAC::dump_config() {
|
||||||
ESP_LOGCONFIG(TAG, "ESP32 DAC:");
|
ESP_LOGCONFIG(TAG, "ESP32 DAC:");
|
||||||
@@ -48,15 +38,10 @@ void ESP32DAC::write_state(float state) {
|
|||||||
|
|
||||||
state = state * 255;
|
state = state * 255;
|
||||||
|
|
||||||
#ifdef USE_ESP_IDF
|
|
||||||
dac_oneshot_output_voltage(this->dac_handle_, state);
|
dac_oneshot_output_voltage(this->dac_handle_, state);
|
||||||
#endif
|
|
||||||
#ifdef USE_ARDUINO
|
|
||||||
dacWrite(this->pin_->get_pin(), state);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace esp32_dac
|
} // namespace esp32_dac
|
||||||
} // namespace esphome
|
} // namespace esphome
|
||||||
|
|
||||||
#endif
|
#endif // USE_ESP32_VARIANT_ESP32 || USE_ESP32_VARIANT_ESP32S2
|
||||||
|
@@ -1,15 +1,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "esphome/components/output/float_output.h"
|
||||||
|
#include "esphome/core/automation.h"
|
||||||
#include "esphome/core/component.h"
|
#include "esphome/core/component.h"
|
||||||
#include "esphome/core/hal.h"
|
#include "esphome/core/hal.h"
|
||||||
#include "esphome/core/automation.h"
|
|
||||||
#include "esphome/components/output/float_output.h"
|
|
||||||
|
|
||||||
#ifdef USE_ESP32
|
#if defined(USE_ESP32_VARIANT_ESP32) || defined(USE_ESP32_VARIANT_ESP32S2)
|
||||||
|
|
||||||
#ifdef USE_ESP_IDF
|
|
||||||
#include <driver/dac_oneshot.h>
|
#include <driver/dac_oneshot.h>
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace esphome {
|
namespace esphome {
|
||||||
namespace esp32_dac {
|
namespace esp32_dac {
|
||||||
@@ -29,12 +27,10 @@ class ESP32DAC : public output::FloatOutput, public Component {
|
|||||||
void write_state(float state) override;
|
void write_state(float state) override;
|
||||||
|
|
||||||
InternalGPIOPin *pin_;
|
InternalGPIOPin *pin_;
|
||||||
#ifdef USE_ESP_IDF
|
|
||||||
dac_oneshot_handle_t dac_handle_;
|
dac_oneshot_handle_t dac_handle_;
|
||||||
#endif
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace esp32_dac
|
} // namespace esp32_dac
|
||||||
} // namespace esphome
|
} // namespace esphome
|
||||||
|
|
||||||
#endif
|
#endif // USE_ESP32_VARIANT_ESP32 || USE_ESP32_VARIANT_ESP32S2
|
||||||
|
@@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
|
#include <limits>
|
||||||
#include "esphome/core/hal.h"
|
#include "esphome/core/hal.h"
|
||||||
|
|
||||||
namespace esphome {
|
namespace esphome::gpio_expander {
|
||||||
namespace gpio_expander {
|
|
||||||
|
|
||||||
/// @brief A class to cache the read state of a GPIO expander.
|
/// @brief A class to cache the read state of a GPIO expander.
|
||||||
/// This class caches reads between GPIO Pins which are on the same bank.
|
/// This class caches reads between GPIO Pins which are on the same bank.
|
||||||
@@ -17,12 +18,22 @@ namespace gpio_expander {
|
|||||||
/// N - Number of pins
|
/// N - Number of pins
|
||||||
template<typename T, T N> class CachedGpioExpander {
|
template<typename T, T N> class CachedGpioExpander {
|
||||||
public:
|
public:
|
||||||
|
/// @brief Read the state of the given pin. This will invalidate the cache for the given pin number.
|
||||||
|
/// @param pin Pin number to read
|
||||||
|
/// @return Pin state
|
||||||
bool digital_read(T pin) {
|
bool digital_read(T pin) {
|
||||||
uint8_t bank = pin / (sizeof(T) * BITS_PER_BYTE);
|
const uint8_t bank = pin / BANK_SIZE;
|
||||||
if (this->read_cache_invalidated_[bank]) {
|
const T pin_mask = (1 << (pin % BANK_SIZE));
|
||||||
this->read_cache_invalidated_[bank] = false;
|
// Check if specific pin cache is valid
|
||||||
|
if (this->read_cache_valid_[bank] & pin_mask) {
|
||||||
|
// Invalidate pin
|
||||||
|
this->read_cache_valid_[bank] &= ~pin_mask;
|
||||||
|
} else {
|
||||||
|
// Read whole bank from hardware
|
||||||
if (!this->digital_read_hw(pin))
|
if (!this->digital_read_hw(pin))
|
||||||
return false;
|
return false;
|
||||||
|
// Mark bank cache as valid except the pin that is being returned now
|
||||||
|
this->read_cache_valid_[bank] = std::numeric_limits<T>::max() & ~pin_mask;
|
||||||
}
|
}
|
||||||
return this->digital_read_cache(pin);
|
return this->digital_read_cache(pin);
|
||||||
}
|
}
|
||||||
@@ -36,18 +47,16 @@ template<typename T, T N> class CachedGpioExpander {
|
|||||||
virtual bool digital_read_cache(T pin) = 0;
|
virtual bool digital_read_cache(T pin) = 0;
|
||||||
/// @brief Call component low level function to write GPIO state to device
|
/// @brief Call component low level function to write GPIO state to device
|
||||||
virtual void digital_write_hw(T pin, bool value) = 0;
|
virtual void digital_write_hw(T pin, bool value) = 0;
|
||||||
const uint8_t cache_byte_size_ = N / (sizeof(T) * BITS_PER_BYTE);
|
|
||||||
|
|
||||||
/// @brief Invalidate cache. This function should be called in component loop().
|
/// @brief Invalidate cache. This function should be called in component loop().
|
||||||
void reset_pin_cache_() {
|
void reset_pin_cache_() { memset(this->read_cache_valid_, 0x00, CACHE_SIZE_BYTES); }
|
||||||
for (T i = 0; i < this->cache_byte_size_; i++) {
|
|
||||||
this->read_cache_invalidated_[i] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static const uint8_t BITS_PER_BYTE = 8;
|
static constexpr uint8_t BITS_PER_BYTE = 8;
|
||||||
std::array<bool, N / (sizeof(T) * BITS_PER_BYTE)> read_cache_invalidated_{};
|
static constexpr uint8_t BANK_SIZE = sizeof(T) * BITS_PER_BYTE;
|
||||||
|
static constexpr size_t BANKS = N / BANK_SIZE;
|
||||||
|
static constexpr size_t CACHE_SIZE_BYTES = BANKS * sizeof(T);
|
||||||
|
|
||||||
|
T read_cache_valid_[BANKS]{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gpio_expander
|
} // namespace esphome::gpio_expander
|
||||||
} // namespace esphome
|
|
||||||
|
@@ -324,39 +324,47 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
|
|||||||
configuration = json_message["configuration"]
|
configuration = json_message["configuration"]
|
||||||
config_file = settings.rel_path(configuration)
|
config_file = settings.rel_path(configuration)
|
||||||
port = json_message["port"]
|
port = json_message["port"]
|
||||||
|
addresses: list[str] = [port]
|
||||||
if (
|
if (
|
||||||
port == "OTA" # pylint: disable=too-many-boolean-expressions
|
port == "OTA" # pylint: disable=too-many-boolean-expressions
|
||||||
and (entry := entries.get(config_file))
|
and (entry := entries.get(config_file))
|
||||||
and entry.loaded_integrations
|
and entry.loaded_integrations
|
||||||
and "api" in entry.loaded_integrations
|
and "api" in entry.loaded_integrations
|
||||||
):
|
):
|
||||||
if (mdns := dashboard.mdns_status) and (
|
addresses = []
|
||||||
address_list := await mdns.async_resolve_host(entry.name)
|
# First priority: entry.address AKA use_address
|
||||||
):
|
if (
|
||||||
# Use the IP address if available but only
|
(use_address := entry.address)
|
||||||
# if the API is loaded and the device is online
|
|
||||||
# since MQTT logging will not work otherwise
|
|
||||||
port = sort_ip_addresses(address_list)[0]
|
|
||||||
elif (
|
|
||||||
entry.address
|
|
||||||
and (
|
and (
|
||||||
address_list := await dashboard.dns_cache.async_resolve(
|
address_list := await dashboard.dns_cache.async_resolve(
|
||||||
entry.address, time.monotonic()
|
use_address, time.monotonic()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
and not isinstance(address_list, Exception)
|
and not isinstance(address_list, Exception)
|
||||||
):
|
):
|
||||||
# If mdns is not available, try to use the DNS cache
|
addresses.extend(sort_ip_addresses(address_list))
|
||||||
port = sort_ip_addresses(address_list)[0]
|
|
||||||
|
|
||||||
return [
|
# Second priority: mDNS
|
||||||
*DASHBOARD_COMMAND,
|
if (
|
||||||
*args,
|
(mdns := dashboard.mdns_status)
|
||||||
config_file,
|
and (address_list := await mdns.async_resolve_host(entry.name))
|
||||||
"--device",
|
and (
|
||||||
port,
|
new_addresses := [
|
||||||
|
addr for addr in address_list if addr not in addresses
|
||||||
|
]
|
||||||
|
)
|
||||||
|
):
|
||||||
|
# Use the IP address if available but only
|
||||||
|
# if the API is loaded and the device is online
|
||||||
|
# since MQTT logging will not work otherwise
|
||||||
|
addresses.extend(sort_ip_addresses(new_addresses))
|
||||||
|
|
||||||
|
device_args: list[str] = [
|
||||||
|
arg for address in addresses for arg in ("--device", address)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
return [*DASHBOARD_COMMAND, *args, config_file, *device_args]
|
||||||
|
|
||||||
|
|
||||||
class EsphomeLogsHandler(EsphomePortCommandWebSocket):
|
class EsphomeLogsHandler(EsphomePortCommandWebSocket):
|
||||||
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
|
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
|
||||||
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from esphome import const
|
from esphome import const
|
||||||
|
|
||||||
@@ -110,7 +111,7 @@ class RedirectText:
|
|||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return getattr(self._out, item)
|
return getattr(self._out, item)
|
||||||
|
|
||||||
def _write_color_replace(self, s):
|
def _write_color_replace(self, s: str | bytes) -> None:
|
||||||
from esphome.core import CORE
|
from esphome.core import CORE
|
||||||
|
|
||||||
if CORE.dashboard:
|
if CORE.dashboard:
|
||||||
@@ -121,7 +122,7 @@ class RedirectText:
|
|||||||
s = s.replace("\033", "\\033")
|
s = s.replace("\033", "\\033")
|
||||||
self._out.write(s)
|
self._out.write(s)
|
||||||
|
|
||||||
def write(self, s):
|
def write(self, s: str | bytes) -> int:
|
||||||
# s is usually a str already (self._out is of type TextIOWrapper)
|
# s is usually a str already (self._out is of type TextIOWrapper)
|
||||||
# However, s is sometimes also a bytes object in python3. Let's make sure it's a
|
# However, s is sometimes also a bytes object in python3. Let's make sure it's a
|
||||||
# str
|
# str
|
||||||
@@ -223,7 +224,7 @@ def run_external_command(
|
|||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
|
||||||
def run_external_process(*cmd, **kwargs):
|
def run_external_process(*cmd: str, **kwargs: Any) -> int | str:
|
||||||
full_cmd = " ".join(shlex_quote(x) for x in cmd)
|
full_cmd = " ".join(shlex_quote(x) for x in cmd)
|
||||||
_LOGGER.debug("Running: %s", full_cmd)
|
_LOGGER.debug("Running: %s", full_cmd)
|
||||||
filter_lines = kwargs.get("filter_lines")
|
filter_lines = kwargs.get("filter_lines")
|
||||||
@@ -266,7 +267,7 @@ class OrderedDict(collections.OrderedDict):
|
|||||||
return dict(self).__repr__()
|
return dict(self).__repr__()
|
||||||
|
|
||||||
|
|
||||||
def list_yaml_files(folders):
|
def list_yaml_files(folders: list[str]) -> list[str]:
|
||||||
files = filter_yaml_files(
|
files = filter_yaml_files(
|
||||||
[os.path.join(folder, p) for folder in folders for p in os.listdir(folder)]
|
[os.path.join(folder, p) for folder in folders for p in os.listdir(folder)]
|
||||||
)
|
)
|
||||||
@@ -274,7 +275,7 @@ def list_yaml_files(folders):
|
|||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
||||||
def filter_yaml_files(files):
|
def filter_yaml_files(files: list[str]) -> list[str]:
|
||||||
return [
|
return [
|
||||||
f
|
f
|
||||||
for f in files
|
for f in files
|
||||||
|
@@ -12,7 +12,7 @@ platformio==6.1.18 # When updating platformio, also update /docker/Dockerfile
|
|||||||
esptool==5.0.2
|
esptool==5.0.2
|
||||||
click==8.1.7
|
click==8.1.7
|
||||||
esphome-dashboard==20250514.0
|
esphome-dashboard==20250514.0
|
||||||
aioesphomeapi==37.2.4
|
aioesphomeapi==37.2.5
|
||||||
zeroconf==0.147.0
|
zeroconf==0.147.0
|
||||||
puremagic==1.30
|
puremagic==1.30
|
||||||
ruamel.yaml==0.18.14 # dashboard_import
|
ruamel.yaml==0.18.14 # dashboard_import
|
||||||
|
@@ -0,0 +1,25 @@
|
|||||||
|
import esphome.codegen as cg
|
||||||
|
import esphome.config_validation as cv
|
||||||
|
from esphome.const import CONF_ID
|
||||||
|
|
||||||
|
AUTO_LOAD = ["gpio_expander"]
|
||||||
|
|
||||||
|
gpio_expander_test_component_ns = cg.esphome_ns.namespace(
|
||||||
|
"gpio_expander_test_component"
|
||||||
|
)
|
||||||
|
|
||||||
|
GPIOExpanderTestComponent = gpio_expander_test_component_ns.class_(
|
||||||
|
"GPIOExpanderTestComponent", cg.Component
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = cv.Schema(
|
||||||
|
{
|
||||||
|
cv.GenerateID(): cv.declare_id(GPIOExpanderTestComponent),
|
||||||
|
}
|
||||||
|
).extend(cv.COMPONENT_SCHEMA)
|
||||||
|
|
||||||
|
|
||||||
|
async def to_code(config):
|
||||||
|
var = cg.new_Pvariable(config[CONF_ID])
|
||||||
|
await cg.register_component(var, config)
|
@@ -0,0 +1,38 @@
|
|||||||
|
#include "gpio_expander_test_component.h"
|
||||||
|
|
||||||
|
#include "esphome/core/application.h"
|
||||||
|
#include "esphome/core/log.h"
|
||||||
|
|
||||||
|
namespace esphome::gpio_expander_test_component {
|
||||||
|
|
||||||
|
static const char *const TAG = "gpio_expander_test";
|
||||||
|
|
||||||
|
void GPIOExpanderTestComponent::setup() {
|
||||||
|
for (uint8_t pin = 0; pin < 32; pin++) {
|
||||||
|
this->digital_read(pin);
|
||||||
|
}
|
||||||
|
|
||||||
|
this->digital_read(3);
|
||||||
|
this->digital_read(3);
|
||||||
|
this->digital_read(4);
|
||||||
|
this->digital_read(3);
|
||||||
|
this->digital_read(10);
|
||||||
|
this->reset_pin_cache_(); // Reset cache to ensure next read is from hardware
|
||||||
|
this->digital_read(15);
|
||||||
|
this->digital_read(14);
|
||||||
|
this->digital_read(14);
|
||||||
|
|
||||||
|
ESP_LOGD(TAG, "DONE");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GPIOExpanderTestComponent::digital_read_hw(uint8_t pin) {
|
||||||
|
ESP_LOGD(TAG, "digital_read_hw pin=%d", pin);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GPIOExpanderTestComponent::digital_read_cache(uint8_t pin) {
|
||||||
|
ESP_LOGD(TAG, "digital_read_cache pin=%d", pin);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace esphome::gpio_expander_test_component
|
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "esphome/components/gpio_expander/cached_gpio.h"
|
||||||
|
#include "esphome/core/component.h"
|
||||||
|
|
||||||
|
namespace esphome::gpio_expander_test_component {
|
||||||
|
|
||||||
|
class GPIOExpanderTestComponent : public Component, public esphome::gpio_expander::CachedGpioExpander<uint8_t, 32> {
|
||||||
|
public:
|
||||||
|
void setup() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool digital_read_hw(uint8_t pin) override;
|
||||||
|
bool digital_read_cache(uint8_t pin) override;
|
||||||
|
void digital_write_hw(uint8_t pin, bool value) override{};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace esphome::gpio_expander_test_component
|
17
tests/integration/fixtures/gpio_expander_cache.yaml
Normal file
17
tests/integration/fixtures/gpio_expander_cache.yaml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
esphome:
|
||||||
|
name: gpio-expander-cache
|
||||||
|
host:
|
||||||
|
|
||||||
|
logger:
|
||||||
|
level: DEBUG
|
||||||
|
|
||||||
|
api:
|
||||||
|
|
||||||
|
# External component that uses gpio_expander::CachedGpioExpander
|
||||||
|
external_components:
|
||||||
|
- source:
|
||||||
|
type: local
|
||||||
|
path: EXTERNAL_COMPONENT_PATH
|
||||||
|
components: [gpio_expander_test_component]
|
||||||
|
|
||||||
|
gpio_expander_test_component:
|
123
tests/integration/test_gpio_expander_cache.py
Normal file
123
tests/integration/test_gpio_expander_cache.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Integration test for CachedGPIOExpander to ensure correct behavior."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .types import APIClientConnectedFactory, RunCompiledFunction
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gpio_expander_cache(
|
||||||
|
yaml_config: str,
|
||||||
|
run_compiled: RunCompiledFunction,
|
||||||
|
api_client_connected: APIClientConnectedFactory,
|
||||||
|
) -> None:
|
||||||
|
"""Test gpio_expander::CachedGpioExpander correctly calls hardware functions."""
|
||||||
|
# Get the path to the external components directory
|
||||||
|
external_components_path = str(
|
||||||
|
Path(__file__).parent / "fixtures" / "external_components"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace the placeholder in the YAML config with the actual path
|
||||||
|
yaml_config = yaml_config.replace(
|
||||||
|
"EXTERNAL_COMPONENT_PATH", external_components_path
|
||||||
|
)
|
||||||
|
|
||||||
|
logs_done = asyncio.Event()
|
||||||
|
|
||||||
|
# Patterns to match in logs
|
||||||
|
digital_read_hw_pattern = re.compile(r"digital_read_hw pin=(\d+)")
|
||||||
|
digital_read_cache_pattern = re.compile(r"digital_read_cache pin=(\d+)")
|
||||||
|
|
||||||
|
# ensure logs are in the expected order
|
||||||
|
log_order = [
|
||||||
|
(digital_read_hw_pattern, 0),
|
||||||
|
[(digital_read_cache_pattern, i) for i in range(0, 8)],
|
||||||
|
(digital_read_hw_pattern, 8),
|
||||||
|
[(digital_read_cache_pattern, i) for i in range(8, 16)],
|
||||||
|
(digital_read_hw_pattern, 16),
|
||||||
|
[(digital_read_cache_pattern, i) for i in range(16, 24)],
|
||||||
|
(digital_read_hw_pattern, 24),
|
||||||
|
[(digital_read_cache_pattern, i) for i in range(24, 32)],
|
||||||
|
(digital_read_hw_pattern, 3),
|
||||||
|
(digital_read_cache_pattern, 3),
|
||||||
|
(digital_read_hw_pattern, 3),
|
||||||
|
(digital_read_cache_pattern, 3),
|
||||||
|
(digital_read_cache_pattern, 4),
|
||||||
|
(digital_read_hw_pattern, 3),
|
||||||
|
(digital_read_cache_pattern, 3),
|
||||||
|
(digital_read_hw_pattern, 10),
|
||||||
|
(digital_read_cache_pattern, 10),
|
||||||
|
# full cache reset here for testing
|
||||||
|
(digital_read_hw_pattern, 15),
|
||||||
|
(digital_read_cache_pattern, 15),
|
||||||
|
(digital_read_cache_pattern, 14),
|
||||||
|
(digital_read_hw_pattern, 14),
|
||||||
|
(digital_read_cache_pattern, 14),
|
||||||
|
]
|
||||||
|
# Flatten the log order for easier processing
|
||||||
|
log_order: list[tuple[re.Pattern, int]] = [
|
||||||
|
item
|
||||||
|
for sublist in log_order
|
||||||
|
for item in (sublist if isinstance(sublist, list) else [sublist])
|
||||||
|
]
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
|
||||||
|
def check_output(line: str) -> None:
|
||||||
|
"""Check log output for expected messages."""
|
||||||
|
nonlocal index
|
||||||
|
if logs_done.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
clean_line = re.sub(r"\x1b\[[0-9;]*m", "", line)
|
||||||
|
|
||||||
|
if "digital_read" in clean_line:
|
||||||
|
if index >= len(log_order):
|
||||||
|
print(f"Received unexpected log line: {clean_line}")
|
||||||
|
logs_done.set()
|
||||||
|
return
|
||||||
|
|
||||||
|
pattern, expected_pin = log_order[index]
|
||||||
|
match = pattern.search(clean_line)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
print(f"Log line did not match next expected pattern: {clean_line}")
|
||||||
|
logs_done.set()
|
||||||
|
return
|
||||||
|
|
||||||
|
pin = int(match.group(1))
|
||||||
|
if pin != expected_pin:
|
||||||
|
print(f"Unexpected pin number. Expected {expected_pin}, got {pin}")
|
||||||
|
logs_done.set()
|
||||||
|
return
|
||||||
|
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
elif "DONE" in clean_line:
|
||||||
|
# Check if we reached the end of the expected log entries
|
||||||
|
logs_done.set()
|
||||||
|
|
||||||
|
# Run with log monitoring
|
||||||
|
async with (
|
||||||
|
run_compiled(yaml_config, line_callback=check_output),
|
||||||
|
api_client_connected() as client,
|
||||||
|
):
|
||||||
|
# Verify device info
|
||||||
|
device_info = await client.device_info()
|
||||||
|
assert device_info is not None
|
||||||
|
assert device_info.name == "gpio-expander-cache"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(logs_done.wait(), timeout=5.0)
|
||||||
|
except TimeoutError:
|
||||||
|
pytest.fail("Timeout waiting for logs to complete")
|
||||||
|
|
||||||
|
assert index == len(log_order), (
|
||||||
|
f"Expected {len(log_order)} log entries, but got {index}"
|
||||||
|
)
|
Reference in New Issue
Block a user