mirror of
https://github.com/esphome/esphome.git
synced 2025-09-03 20:02:22 +01:00
Merge branch 'multi_device_args' into integration
This commit is contained in:
@@ -9,6 +9,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Protocol
|
||||
|
||||
import argcomplete
|
||||
|
||||
@@ -44,6 +45,7 @@ from esphome.const import (
|
||||
from esphome.core import CORE, EsphomeError, coroutine
|
||||
from esphome.helpers import get_bool_env, indent, is_ip_address
|
||||
from esphome.log import AnsiFore, color, setup_log
|
||||
from esphome.types import ConfigType
|
||||
from esphome.util import (
|
||||
get_serial_ports,
|
||||
list_yaml_files,
|
||||
@@ -55,6 +57,23 @@ from esphome.util import (
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArgsProtocol(Protocol):
|
||||
device: list[str] | None
|
||||
reset: bool
|
||||
username: str | None
|
||||
password: str | None
|
||||
client_id: str | None
|
||||
topic: str | None
|
||||
file: str | None
|
||||
no_logs: bool
|
||||
only_generate: bool
|
||||
show_secrets: bool
|
||||
dashboard: bool
|
||||
configuration: str
|
||||
name: str
|
||||
upload_speed: str | None
|
||||
|
||||
|
||||
def choose_prompt(options, purpose: str = None):
|
||||
if not options:
|
||||
raise EsphomeError(
|
||||
@@ -88,30 +107,50 @@ def choose_prompt(options, purpose: str = None):
|
||||
|
||||
|
||||
def choose_upload_log_host(
|
||||
default, check_default, show_ota, show_mqtt, show_api, purpose: str = None
|
||||
):
|
||||
default: list[str] | str | None,
|
||||
check_default: str | None,
|
||||
show_ota: bool,
|
||||
show_mqtt: bool,
|
||||
show_api: bool,
|
||||
purpose: str | None = None,
|
||||
) -> list[str]:
|
||||
# Convert to list for uniform handling
|
||||
defaults = [default] if isinstance(default, str) else default or []
|
||||
|
||||
# If devices specified, resolve them
|
||||
if defaults:
|
||||
resolved: list[str] = []
|
||||
for device in defaults:
|
||||
if device == "SERIAL":
|
||||
options = [
|
||||
(f"{port.path} ({port.description})", port.path)
|
||||
for port in get_serial_ports()
|
||||
]
|
||||
resolved.append(choose_prompt(options, purpose=purpose))
|
||||
elif device == "OTA":
|
||||
if (show_ota and "ota" in CORE.config) or (
|
||||
show_api and "api" in CORE.config
|
||||
):
|
||||
resolved.append(CORE.address)
|
||||
elif show_mqtt and has_mqtt_logging():
|
||||
resolved.append("MQTT")
|
||||
else:
|
||||
resolved.append(device)
|
||||
return resolved
|
||||
|
||||
# No devices specified, show interactive chooser
|
||||
options = [
|
||||
(f"{port.path} ({port.description})", port.path) for port in get_serial_ports()
|
||||
]
|
||||
if default == "SERIAL":
|
||||
return choose_prompt(options, purpose=purpose)
|
||||
if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config):
|
||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||
if default == "OTA":
|
||||
return CORE.address
|
||||
if (
|
||||
show_mqtt
|
||||
and (mqtt_config := CORE.config.get(CONF_MQTT))
|
||||
and mqtt_logging_enabled(mqtt_config)
|
||||
):
|
||||
if show_mqtt and has_mqtt_logging():
|
||||
mqtt_config = CORE.config[CONF_MQTT]
|
||||
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
||||
if default == "OTA":
|
||||
return "MQTT"
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
if check_default is not None and check_default in [opt[1] for opt in options]:
|
||||
return check_default
|
||||
return choose_prompt(options, purpose=purpose)
|
||||
return [check_default]
|
||||
return [choose_prompt(options, purpose=purpose)]
|
||||
|
||||
|
||||
def mqtt_logging_enabled(mqtt_config):
|
||||
@@ -123,7 +162,14 @@ def mqtt_logging_enabled(mqtt_config):
|
||||
return log_topic.get(CONF_LEVEL, None) != "NONE"
|
||||
|
||||
|
||||
def get_port_type(port):
|
||||
def has_mqtt_logging() -> bool:
|
||||
"""Check if MQTT logging is available."""
|
||||
return (mqtt_config := CORE.config.get(CONF_MQTT)) and mqtt_logging_enabled(
|
||||
mqtt_config
|
||||
)
|
||||
|
||||
|
||||
def get_port_type(port: str) -> str:
|
||||
if port.startswith("/") or port.startswith("COM"):
|
||||
return "SERIAL"
|
||||
if port == "MQTT":
|
||||
@@ -131,7 +177,7 @@ def get_port_type(port):
|
||||
return "NETWORK"
|
||||
|
||||
|
||||
def run_miniterm(config, port, args):
|
||||
def run_miniterm(config: ConfigType, port: str, args) -> int:
|
||||
from aioesphomeapi import LogParser
|
||||
import serial
|
||||
|
||||
@@ -208,7 +254,7 @@ def wrap_to_code(name, comp):
|
||||
return wrapped
|
||||
|
||||
|
||||
def write_cpp(config):
|
||||
def write_cpp(config: ConfigType) -> int:
|
||||
if not get_bool_env(ENV_NOGITIGNORE):
|
||||
writer.write_gitignore()
|
||||
|
||||
@@ -216,7 +262,7 @@ def write_cpp(config):
|
||||
return write_cpp_file()
|
||||
|
||||
|
||||
def generate_cpp_contents(config):
|
||||
def generate_cpp_contents(config: ConfigType) -> None:
|
||||
_LOGGER.info("Generating C++ source...")
|
||||
|
||||
for name, component, conf in iter_component_configs(CORE.config):
|
||||
@@ -227,7 +273,7 @@ def generate_cpp_contents(config):
|
||||
CORE.flush_tasks()
|
||||
|
||||
|
||||
def write_cpp_file():
|
||||
def write_cpp_file() -> int:
|
||||
code_s = indent(CORE.cpp_main_section)
|
||||
writer.write_cpp(code_s)
|
||||
|
||||
@@ -238,7 +284,7 @@ def write_cpp_file():
|
||||
return 0
|
||||
|
||||
|
||||
def compile_program(args, config):
|
||||
def compile_program(args: ArgsProtocol, config: ConfigType) -> int:
|
||||
from esphome import platformio_api
|
||||
|
||||
_LOGGER.info("Compiling app...")
|
||||
@@ -249,7 +295,9 @@ def compile_program(args, config):
|
||||
return 0 if idedata is not None else 1
|
||||
|
||||
|
||||
def upload_using_esptool(config, port, file, speed):
|
||||
def upload_using_esptool(
|
||||
config: ConfigType, port: str, file: str, speed: int
|
||||
) -> str | int:
|
||||
from esphome import platformio_api
|
||||
|
||||
first_baudrate = speed or config[CONF_ESPHOME][CONF_PLATFORMIO_OPTIONS].get(
|
||||
@@ -314,7 +362,7 @@ def upload_using_esptool(config, port, file, speed):
|
||||
return run_esptool(115200)
|
||||
|
||||
|
||||
def upload_using_platformio(config, port):
|
||||
def upload_using_platformio(config: ConfigType, port: str):
|
||||
from esphome import platformio_api
|
||||
|
||||
upload_args = ["-t", "upload", "-t", "nobuild"]
|
||||
@@ -323,7 +371,7 @@ def upload_using_platformio(config, port):
|
||||
return platformio_api.run_platformio_cli_run(config, CORE.verbose, *upload_args)
|
||||
|
||||
|
||||
def check_permissions(port):
|
||||
def check_permissions(port: str):
|
||||
if os.name == "posix" and get_port_type(port) == "SERIAL":
|
||||
# Check if we can open selected serial port
|
||||
if not os.access(port, os.F_OK):
|
||||
@@ -341,7 +389,7 @@ def check_permissions(port):
|
||||
)
|
||||
|
||||
|
||||
def upload_program(config, args, host):
|
||||
def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | str:
|
||||
try:
|
||||
module = importlib.import_module("esphome.components." + CORE.target_platform)
|
||||
if getattr(module, "upload_program")(config, args, host):
|
||||
@@ -356,7 +404,7 @@ def upload_program(config, args, host):
|
||||
return upload_using_esptool(config, host, file, args.upload_speed)
|
||||
|
||||
if CORE.target_platform in (PLATFORM_RP2040):
|
||||
return upload_using_platformio(config, args.device)
|
||||
return upload_using_platformio(config, host)
|
||||
|
||||
if CORE.is_libretiny:
|
||||
return upload_using_platformio(config, host)
|
||||
@@ -379,9 +427,12 @@ def upload_program(config, args, host):
|
||||
remote_port = int(ota_conf[CONF_PORT])
|
||||
password = ota_conf.get(CONF_PASSWORD, "")
|
||||
|
||||
# Check if we should use MQTT for address resolution
|
||||
# This happens when no device was specified, or the current host is "MQTT"/"OTA"
|
||||
devices: list[str] = args.device or []
|
||||
if (
|
||||
CONF_MQTT in config # pylint: disable=too-many-boolean-expressions
|
||||
and (not args.device or args.device in ("MQTT", "OTA"))
|
||||
and (not devices or host in ("MQTT", "OTA"))
|
||||
and (
|
||||
((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address))
|
||||
or get_port_type(host) == "MQTT"
|
||||
@@ -399,23 +450,28 @@ def upload_program(config, args, host):
|
||||
return espota2.run_ota(host, remote_port, password, CORE.firmware_bin)
|
||||
|
||||
|
||||
def show_logs(config, args, port):
|
||||
def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None:
|
||||
if "logger" not in config:
|
||||
raise EsphomeError("Logger is not configured!")
|
||||
|
||||
port = devices[0]
|
||||
|
||||
if get_port_type(port) == "SERIAL":
|
||||
check_permissions(port)
|
||||
return run_miniterm(config, port, args)
|
||||
if get_port_type(port) == "NETWORK" and "api" in config:
|
||||
addresses_to_use = devices
|
||||
if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config:
|
||||
from esphome import mqtt
|
||||
|
||||
port = mqtt.get_esphome_device_ip(
|
||||
mqtt_address = mqtt.get_esphome_device_ip(
|
||||
config, args.username, args.password, args.client_id
|
||||
)[0]
|
||||
addresses_to_use = [mqtt_address]
|
||||
|
||||
from esphome.components.api.client import run_logs
|
||||
|
||||
return run_logs(config, port)
|
||||
return run_logs(config, addresses_to_use)
|
||||
if get_port_type(port) == "MQTT" and "mqtt" in config:
|
||||
from esphome import mqtt
|
||||
|
||||
@@ -426,7 +482,7 @@ def show_logs(config, args, port):
|
||||
raise EsphomeError("No remote or local logging method configured (api/mqtt/logger)")
|
||||
|
||||
|
||||
def clean_mqtt(config, args):
|
||||
def clean_mqtt(config: ConfigType, args: ArgsProtocol) -> int | None:
|
||||
from esphome import mqtt
|
||||
|
||||
return mqtt.clear_topic(
|
||||
@@ -434,13 +490,13 @@ def clean_mqtt(config, args):
|
||||
)
|
||||
|
||||
|
||||
def command_wizard(args):
|
||||
def command_wizard(args: ArgsProtocol) -> int | None:
|
||||
from esphome import wizard
|
||||
|
||||
return wizard.wizard(args.configuration)
|
||||
|
||||
|
||||
def command_config(args, config):
|
||||
def command_config(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
if not CORE.verbose:
|
||||
config = strip_default_ids(config)
|
||||
output = yaml_util.dump(config, args.show_secrets)
|
||||
@@ -455,7 +511,7 @@ def command_config(args, config):
|
||||
return 0
|
||||
|
||||
|
||||
def command_vscode(args):
|
||||
def command_vscode(args: ArgsProtocol) -> int | None:
|
||||
from esphome import vscode
|
||||
|
||||
logging.disable(logging.INFO)
|
||||
@@ -463,7 +519,7 @@ def command_vscode(args):
|
||||
vscode.read_config(args)
|
||||
|
||||
|
||||
def command_compile(args, config):
|
||||
def command_compile(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
# Set memory analysis options in config
|
||||
if args.analyze_memory:
|
||||
config.setdefault(CONF_ESPHOME, {})["analyze_memory"] = True
|
||||
@@ -484,8 +540,9 @@ def command_compile(args, config):
|
||||
return 0
|
||||
|
||||
|
||||
def command_upload(args, config):
|
||||
port = choose_upload_log_host(
|
||||
def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
# Get devices, resolving special identifiers like OTA
|
||||
devices = choose_upload_log_host(
|
||||
default=args.device,
|
||||
check_default=None,
|
||||
show_ota=True,
|
||||
@@ -493,14 +550,22 @@ def command_upload(args, config):
|
||||
show_api=False,
|
||||
purpose="uploading",
|
||||
)
|
||||
exit_code = upload_program(config, args, port)
|
||||
if exit_code != 0:
|
||||
return exit_code
|
||||
_LOGGER.info("Successfully uploaded program.")
|
||||
return 0
|
||||
|
||||
# Try each device until one succeeds
|
||||
exit_code = 1
|
||||
for device in devices:
|
||||
_LOGGER.info("Uploading to %s", device)
|
||||
exit_code = upload_program(config, args, device)
|
||||
if exit_code == 0:
|
||||
_LOGGER.info("Successfully uploaded program.")
|
||||
return 0
|
||||
if len(devices) > 1:
|
||||
_LOGGER.warning("Failed to upload to %s", device)
|
||||
|
||||
return exit_code
|
||||
|
||||
|
||||
def command_discover(args, config):
|
||||
def command_discover(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
if "mqtt" in config:
|
||||
from esphome import mqtt
|
||||
|
||||
@@ -509,8 +574,9 @@ def command_discover(args, config):
|
||||
raise EsphomeError("No discover method configured (mqtt)")
|
||||
|
||||
|
||||
def command_logs(args, config):
|
||||
port = choose_upload_log_host(
|
||||
def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
# Get devices, resolving special identifiers like OTA
|
||||
devices = choose_upload_log_host(
|
||||
default=args.device,
|
||||
check_default=None,
|
||||
show_ota=False,
|
||||
@@ -518,10 +584,10 @@ def command_logs(args, config):
|
||||
show_api=True,
|
||||
purpose="logging",
|
||||
)
|
||||
return show_logs(config, args, port)
|
||||
return show_logs(config, args, devices)
|
||||
|
||||
|
||||
def command_run(args, config):
|
||||
def command_run(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
exit_code = write_cpp(config)
|
||||
if exit_code != 0:
|
||||
return exit_code
|
||||
@@ -538,7 +604,8 @@ def command_run(args, config):
|
||||
program_path = idedata.raw["prog_path"]
|
||||
return run_external_process(program_path)
|
||||
|
||||
port = choose_upload_log_host(
|
||||
# Get devices, resolving special identifiers like OTA
|
||||
devices = choose_upload_log_host(
|
||||
default=args.device,
|
||||
check_default=None,
|
||||
show_ota=True,
|
||||
@@ -546,39 +613,53 @@ def command_run(args, config):
|
||||
show_api=True,
|
||||
purpose="uploading",
|
||||
)
|
||||
exit_code = upload_program(config, args, port)
|
||||
if exit_code != 0:
|
||||
|
||||
# Try each device for upload until one succeeds
|
||||
successful_device: str | None = None
|
||||
for device in devices:
|
||||
_LOGGER.info("Uploading to %s", device)
|
||||
exit_code = upload_program(config, args, device)
|
||||
if exit_code == 0:
|
||||
_LOGGER.info("Successfully uploaded program.")
|
||||
successful_device = device
|
||||
break
|
||||
if len(devices) > 1:
|
||||
_LOGGER.warning("Failed to upload to %s", device)
|
||||
|
||||
if successful_device is None:
|
||||
return exit_code
|
||||
_LOGGER.info("Successfully uploaded program.")
|
||||
|
||||
if args.no_logs:
|
||||
return 0
|
||||
port = choose_upload_log_host(
|
||||
default=args.device,
|
||||
check_default=port,
|
||||
|
||||
# For logs, prefer the device we successfully uploaded to
|
||||
devices = choose_upload_log_host(
|
||||
default=successful_device,
|
||||
check_default=successful_device,
|
||||
show_ota=False,
|
||||
show_mqtt=True,
|
||||
show_api=True,
|
||||
purpose="logging",
|
||||
)
|
||||
return show_logs(config, args, port)
|
||||
return show_logs(config, args, devices)
|
||||
|
||||
|
||||
def command_clean_mqtt(args, config):
|
||||
def command_clean_mqtt(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
return clean_mqtt(config, args)
|
||||
|
||||
|
||||
def command_mqtt_fingerprint(args, config):
|
||||
def command_mqtt_fingerprint(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
from esphome import mqtt
|
||||
|
||||
return mqtt.get_fingerprint(config)
|
||||
|
||||
|
||||
def command_version(args):
|
||||
def command_version(args: ArgsProtocol) -> int | None:
|
||||
safe_print(f"Version: {const.__version__}")
|
||||
return 0
|
||||
|
||||
|
||||
def command_clean(args, config):
|
||||
def command_clean(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
try:
|
||||
writer.clean_build()
|
||||
except OSError as err:
|
||||
@@ -588,13 +669,13 @@ def command_clean(args, config):
|
||||
return 0
|
||||
|
||||
|
||||
def command_dashboard(args):
|
||||
def command_dashboard(args: ArgsProtocol) -> int | None:
|
||||
from esphome.dashboard import dashboard
|
||||
|
||||
return dashboard.start_dashboard(args)
|
||||
|
||||
|
||||
def command_update_all(args):
|
||||
def command_update_all(args: ArgsProtocol) -> int | None:
|
||||
import click
|
||||
|
||||
success = {}
|
||||
@@ -641,7 +722,7 @@ def command_update_all(args):
|
||||
return failed
|
||||
|
||||
|
||||
def command_idedata(args, config):
|
||||
def command_idedata(args: ArgsProtocol, config: ConfigType) -> int:
|
||||
import json
|
||||
|
||||
from esphome import platformio_api
|
||||
@@ -657,7 +738,7 @@ def command_idedata(args, config):
|
||||
return 0
|
||||
|
||||
|
||||
def command_rename(args, config):
|
||||
def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None:
|
||||
for c in args.name:
|
||||
if c not in ALLOWED_NAME_CHARS:
|
||||
print(
|
||||
@@ -774,6 +855,12 @@ POST_CONFIG_ACTIONS = {
|
||||
"discover": command_discover,
|
||||
}
|
||||
|
||||
SIMPLE_CONFIG_ACTIONS = [
|
||||
"clean",
|
||||
"clean-mqtt",
|
||||
"config",
|
||||
]
|
||||
|
||||
|
||||
def parse_args(argv):
|
||||
options_parser = argparse.ArgumentParser(add_help=False)
|
||||
@@ -872,7 +959,8 @@ def parse_args(argv):
|
||||
)
|
||||
parser_upload.add_argument(
|
||||
"--device",
|
||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.",
|
||||
action="append",
|
||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.",
|
||||
)
|
||||
parser_upload.add_argument(
|
||||
"--upload_speed",
|
||||
@@ -894,7 +982,8 @@ def parse_args(argv):
|
||||
)
|
||||
parser_logs.add_argument(
|
||||
"--device",
|
||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.",
|
||||
action="append",
|
||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.",
|
||||
)
|
||||
parser_logs.add_argument(
|
||||
"--reset",
|
||||
@@ -923,7 +1012,8 @@ def parse_args(argv):
|
||||
)
|
||||
parser_run.add_argument(
|
||||
"--device",
|
||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0.",
|
||||
action="append",
|
||||
help="Manually specify the serial port/address to use, for example /dev/ttyUSB0. Can be specified multiple times for fallback addresses.",
|
||||
)
|
||||
parser_run.add_argument(
|
||||
"--upload_speed",
|
||||
@@ -1050,6 +1140,13 @@ def parse_args(argv):
|
||||
arguments = argv[1:]
|
||||
|
||||
argcomplete.autocomplete(parser)
|
||||
|
||||
if len(arguments) > 0 and arguments[0] in SIMPLE_CONFIG_ACTIONS:
|
||||
args, unknown_args = parser.parse_known_args(arguments)
|
||||
if unknown_args:
|
||||
_LOGGER.warning("Ignored unrecognized arguments: %s", unknown_args)
|
||||
return args
|
||||
|
||||
return parser.parse_args(arguments)
|
||||
|
||||
|
||||
|
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_run_logs(config: dict[str, Any], address: str) -> None:
|
||||
async def async_run_logs(config: dict[str, Any], addresses: list[str]) -> None:
|
||||
"""Run the logs command in the event loop."""
|
||||
conf = config["api"]
|
||||
name = config["esphome"]["name"]
|
||||
@@ -39,13 +39,21 @@ async def async_run_logs(config: dict[str, Any], address: str) -> None:
|
||||
noise_psk: str | None = None
|
||||
if (encryption := conf.get(CONF_ENCRYPTION)) and (key := encryption.get(CONF_KEY)):
|
||||
noise_psk = key
|
||||
_LOGGER.info("Starting log output from %s using esphome API", address)
|
||||
|
||||
if len(addresses) == 1:
|
||||
_LOGGER.info("Starting log output from %s using esphome API", addresses[0])
|
||||
else:
|
||||
_LOGGER.info(
|
||||
"Starting log output from %s using esphome API", " or ".join(addresses)
|
||||
)
|
||||
|
||||
cli = APIClient(
|
||||
address,
|
||||
addresses[0], # Primary address for compatibility
|
||||
port,
|
||||
password,
|
||||
client_info=f"ESPHome Logs {__version__}",
|
||||
noise_psk=noise_psk,
|
||||
addresses=addresses, # Pass all addresses for automatic retry
|
||||
)
|
||||
dashboard = CORE.dashboard
|
||||
|
||||
@@ -66,7 +74,7 @@ async def async_run_logs(config: dict[str, Any], address: str) -> None:
|
||||
await stop()
|
||||
|
||||
|
||||
def run_logs(config: dict[str, Any], address: str) -> None:
|
||||
def run_logs(config: dict[str, Any], addresses: list[str]) -> None:
|
||||
"""Run the logs command."""
|
||||
with contextlib.suppress(KeyboardInterrupt):
|
||||
asyncio.run(async_run_logs(config, address))
|
||||
asyncio.run(async_run_logs(config, addresses))
|
||||
|
@@ -2,11 +2,7 @@
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
#ifdef USE_ESP32
|
||||
|
||||
#ifdef USE_ARDUINO
|
||||
#include <esp32-hal-dac.h>
|
||||
#endif
|
||||
#if defined(USE_ESP32_VARIANT_ESP32) || defined(USE_ESP32_VARIANT_ESP32S2)
|
||||
|
||||
namespace esphome {
|
||||
namespace esp32_dac {
|
||||
@@ -23,18 +19,12 @@ void ESP32DAC::setup() {
|
||||
this->pin_->setup();
|
||||
this->turn_off();
|
||||
|
||||
#ifdef USE_ESP_IDF
|
||||
const dac_channel_t channel = this->pin_->get_pin() == DAC0_PIN ? DAC_CHAN_0 : DAC_CHAN_1;
|
||||
const dac_oneshot_config_t oneshot_cfg{channel};
|
||||
dac_oneshot_new_channel(&oneshot_cfg, &this->dac_handle_);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ESP32DAC::on_safe_shutdown() {
|
||||
#ifdef USE_ESP_IDF
|
||||
dac_oneshot_del_channel(this->dac_handle_);
|
||||
#endif
|
||||
}
|
||||
void ESP32DAC::on_safe_shutdown() { dac_oneshot_del_channel(this->dac_handle_); }
|
||||
|
||||
void ESP32DAC::dump_config() {
|
||||
ESP_LOGCONFIG(TAG, "ESP32 DAC:");
|
||||
@@ -48,15 +38,10 @@ void ESP32DAC::write_state(float state) {
|
||||
|
||||
state = state * 255;
|
||||
|
||||
#ifdef USE_ESP_IDF
|
||||
dac_oneshot_output_voltage(this->dac_handle_, state);
|
||||
#endif
|
||||
#ifdef USE_ARDUINO
|
||||
dacWrite(this->pin_->get_pin(), state);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace esp32_dac
|
||||
} // namespace esphome
|
||||
|
||||
#endif
|
||||
#endif // USE_ESP32_VARIANT_ESP32 || USE_ESP32_VARIANT_ESP32S2
|
||||
|
@@ -1,15 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "esphome/components/output/float_output.h"
|
||||
#include "esphome/core/automation.h"
|
||||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/hal.h"
|
||||
#include "esphome/core/automation.h"
|
||||
#include "esphome/components/output/float_output.h"
|
||||
|
||||
#ifdef USE_ESP32
|
||||
#if defined(USE_ESP32_VARIANT_ESP32) || defined(USE_ESP32_VARIANT_ESP32S2)
|
||||
|
||||
#ifdef USE_ESP_IDF
|
||||
#include <driver/dac_oneshot.h>
|
||||
#endif
|
||||
|
||||
namespace esphome {
|
||||
namespace esp32_dac {
|
||||
@@ -29,12 +27,10 @@ class ESP32DAC : public output::FloatOutput, public Component {
|
||||
void write_state(float state) override;
|
||||
|
||||
InternalGPIOPin *pin_;
|
||||
#ifdef USE_ESP_IDF
|
||||
dac_oneshot_handle_t dac_handle_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace esp32_dac
|
||||
} // namespace esphome
|
||||
|
||||
#endif
|
||||
#endif // USE_ESP32_VARIANT_ESP32 || USE_ESP32_VARIANT_ESP32S2
|
||||
|
@@ -2,10 +2,11 @@
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include "esphome/core/hal.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace gpio_expander {
|
||||
namespace esphome::gpio_expander {
|
||||
|
||||
/// @brief A class to cache the read state of a GPIO expander.
|
||||
/// This class caches reads between GPIO Pins which are on the same bank.
|
||||
@@ -17,12 +18,22 @@ namespace gpio_expander {
|
||||
/// N - Number of pins
|
||||
template<typename T, T N> class CachedGpioExpander {
|
||||
public:
|
||||
/// @brief Read the state of the given pin. This will invalidate the cache for the given pin number.
|
||||
/// @param pin Pin number to read
|
||||
/// @return Pin state
|
||||
bool digital_read(T pin) {
|
||||
uint8_t bank = pin / (sizeof(T) * BITS_PER_BYTE);
|
||||
if (this->read_cache_invalidated_[bank]) {
|
||||
this->read_cache_invalidated_[bank] = false;
|
||||
const uint8_t bank = pin / BANK_SIZE;
|
||||
const T pin_mask = (1 << (pin % BANK_SIZE));
|
||||
// Check if specific pin cache is valid
|
||||
if (this->read_cache_valid_[bank] & pin_mask) {
|
||||
// Invalidate pin
|
||||
this->read_cache_valid_[bank] &= ~pin_mask;
|
||||
} else {
|
||||
// Read whole bank from hardware
|
||||
if (!this->digital_read_hw(pin))
|
||||
return false;
|
||||
// Mark bank cache as valid except the pin that is being returned now
|
||||
this->read_cache_valid_[bank] = std::numeric_limits<T>::max() & ~pin_mask;
|
||||
}
|
||||
return this->digital_read_cache(pin);
|
||||
}
|
||||
@@ -36,18 +47,16 @@ template<typename T, T N> class CachedGpioExpander {
|
||||
virtual bool digital_read_cache(T pin) = 0;
|
||||
/// @brief Call component low level function to write GPIO state to device
|
||||
virtual void digital_write_hw(T pin, bool value) = 0;
|
||||
const uint8_t cache_byte_size_ = N / (sizeof(T) * BITS_PER_BYTE);
|
||||
|
||||
/// @brief Invalidate cache. This function should be called in component loop().
|
||||
void reset_pin_cache_() {
|
||||
for (T i = 0; i < this->cache_byte_size_; i++) {
|
||||
this->read_cache_invalidated_[i] = true;
|
||||
}
|
||||
}
|
||||
void reset_pin_cache_() { memset(this->read_cache_valid_, 0x00, CACHE_SIZE_BYTES); }
|
||||
|
||||
static const uint8_t BITS_PER_BYTE = 8;
|
||||
std::array<bool, N / (sizeof(T) * BITS_PER_BYTE)> read_cache_invalidated_{};
|
||||
static constexpr uint8_t BITS_PER_BYTE = 8;
|
||||
static constexpr uint8_t BANK_SIZE = sizeof(T) * BITS_PER_BYTE;
|
||||
static constexpr size_t BANKS = N / BANK_SIZE;
|
||||
static constexpr size_t CACHE_SIZE_BYTES = BANKS * sizeof(T);
|
||||
|
||||
T read_cache_valid_[BANKS]{0};
|
||||
};
|
||||
|
||||
} // namespace gpio_expander
|
||||
} // namespace esphome
|
||||
} // namespace esphome::gpio_expander
|
||||
|
@@ -324,39 +324,47 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
|
||||
configuration = json_message["configuration"]
|
||||
config_file = settings.rel_path(configuration)
|
||||
port = json_message["port"]
|
||||
addresses: list[str] = [port]
|
||||
if (
|
||||
port == "OTA" # pylint: disable=too-many-boolean-expressions
|
||||
and (entry := entries.get(config_file))
|
||||
and entry.loaded_integrations
|
||||
and "api" in entry.loaded_integrations
|
||||
):
|
||||
if (mdns := dashboard.mdns_status) and (
|
||||
address_list := await mdns.async_resolve_host(entry.name)
|
||||
):
|
||||
# Use the IP address if available but only
|
||||
# if the API is loaded and the device is online
|
||||
# since MQTT logging will not work otherwise
|
||||
port = sort_ip_addresses(address_list)[0]
|
||||
elif (
|
||||
entry.address
|
||||
addresses = []
|
||||
# First priority: entry.address AKA use_address
|
||||
if (
|
||||
(use_address := entry.address)
|
||||
and (
|
||||
address_list := await dashboard.dns_cache.async_resolve(
|
||||
entry.address, time.monotonic()
|
||||
use_address, time.monotonic()
|
||||
)
|
||||
)
|
||||
and not isinstance(address_list, Exception)
|
||||
):
|
||||
# If mdns is not available, try to use the DNS cache
|
||||
port = sort_ip_addresses(address_list)[0]
|
||||
addresses.extend(sort_ip_addresses(address_list))
|
||||
|
||||
return [
|
||||
*DASHBOARD_COMMAND,
|
||||
*args,
|
||||
config_file,
|
||||
"--device",
|
||||
port,
|
||||
# Second priority: mDNS
|
||||
if (
|
||||
(mdns := dashboard.mdns_status)
|
||||
and (address_list := await mdns.async_resolve_host(entry.name))
|
||||
and (
|
||||
new_addresses := [
|
||||
addr for addr in address_list if addr not in addresses
|
||||
]
|
||||
)
|
||||
):
|
||||
# Use the IP address if available but only
|
||||
# if the API is loaded and the device is online
|
||||
# since MQTT logging will not work otherwise
|
||||
addresses.extend(sort_ip_addresses(new_addresses))
|
||||
|
||||
device_args: list[str] = [
|
||||
arg for address in addresses for arg in ("--device", address)
|
||||
]
|
||||
|
||||
return [*DASHBOARD_COMMAND, *args, config_file, *device_args]
|
||||
|
||||
|
||||
class EsphomeLogsHandler(EsphomePortCommandWebSocket):
|
||||
async def build_command(self, json_message: dict[str, Any]) -> list[str]:
|
||||
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from esphome import const
|
||||
|
||||
@@ -110,7 +111,7 @@ class RedirectText:
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._out, item)
|
||||
|
||||
def _write_color_replace(self, s):
|
||||
def _write_color_replace(self, s: str | bytes) -> None:
|
||||
from esphome.core import CORE
|
||||
|
||||
if CORE.dashboard:
|
||||
@@ -121,7 +122,7 @@ class RedirectText:
|
||||
s = s.replace("\033", "\\033")
|
||||
self._out.write(s)
|
||||
|
||||
def write(self, s):
|
||||
def write(self, s: str | bytes) -> int:
|
||||
# s is usually a str already (self._out is of type TextIOWrapper)
|
||||
# However, s is sometimes also a bytes object in python3. Let's make sure it's a
|
||||
# str
|
||||
@@ -223,7 +224,7 @@ def run_external_command(
|
||||
return retval
|
||||
|
||||
|
||||
def run_external_process(*cmd, **kwargs):
|
||||
def run_external_process(*cmd: str, **kwargs: Any) -> int | str:
|
||||
full_cmd = " ".join(shlex_quote(x) for x in cmd)
|
||||
_LOGGER.debug("Running: %s", full_cmd)
|
||||
filter_lines = kwargs.get("filter_lines")
|
||||
@@ -266,7 +267,7 @@ class OrderedDict(collections.OrderedDict):
|
||||
return dict(self).__repr__()
|
||||
|
||||
|
||||
def list_yaml_files(folders):
|
||||
def list_yaml_files(folders: list[str]) -> list[str]:
|
||||
files = filter_yaml_files(
|
||||
[os.path.join(folder, p) for folder in folders for p in os.listdir(folder)]
|
||||
)
|
||||
@@ -274,7 +275,7 @@ def list_yaml_files(folders):
|
||||
return files
|
||||
|
||||
|
||||
def filter_yaml_files(files):
|
||||
def filter_yaml_files(files: list[str]) -> list[str]:
|
||||
return [
|
||||
f
|
||||
for f in files
|
||||
|
@@ -12,7 +12,7 @@ platformio==6.1.18 # When updating platformio, also update /docker/Dockerfile
|
||||
esptool==5.0.2
|
||||
click==8.1.7
|
||||
esphome-dashboard==20250514.0
|
||||
aioesphomeapi==37.2.4
|
||||
aioesphomeapi==37.2.5
|
||||
zeroconf==0.147.0
|
||||
puremagic==1.30
|
||||
ruamel.yaml==0.18.14 # dashboard_import
|
||||
|
@@ -0,0 +1,25 @@
|
||||
import esphome.codegen as cg
|
||||
import esphome.config_validation as cv
|
||||
from esphome.const import CONF_ID
|
||||
|
||||
AUTO_LOAD = ["gpio_expander"]
|
||||
|
||||
gpio_expander_test_component_ns = cg.esphome_ns.namespace(
|
||||
"gpio_expander_test_component"
|
||||
)
|
||||
|
||||
GPIOExpanderTestComponent = gpio_expander_test_component_ns.class_(
|
||||
"GPIOExpanderTestComponent", cg.Component
|
||||
)
|
||||
|
||||
|
||||
CONFIG_SCHEMA = cv.Schema(
|
||||
{
|
||||
cv.GenerateID(): cv.declare_id(GPIOExpanderTestComponent),
|
||||
}
|
||||
).extend(cv.COMPONENT_SCHEMA)
|
||||
|
||||
|
||||
async def to_code(config):
|
||||
var = cg.new_Pvariable(config[CONF_ID])
|
||||
await cg.register_component(var, config)
|
@@ -0,0 +1,38 @@
|
||||
#include "gpio_expander_test_component.h"
|
||||
|
||||
#include "esphome/core/application.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome::gpio_expander_test_component {
|
||||
|
||||
static const char *const TAG = "gpio_expander_test";
|
||||
|
||||
void GPIOExpanderTestComponent::setup() {
|
||||
for (uint8_t pin = 0; pin < 32; pin++) {
|
||||
this->digital_read(pin);
|
||||
}
|
||||
|
||||
this->digital_read(3);
|
||||
this->digital_read(3);
|
||||
this->digital_read(4);
|
||||
this->digital_read(3);
|
||||
this->digital_read(10);
|
||||
this->reset_pin_cache_(); // Reset cache to ensure next read is from hardware
|
||||
this->digital_read(15);
|
||||
this->digital_read(14);
|
||||
this->digital_read(14);
|
||||
|
||||
ESP_LOGD(TAG, "DONE");
|
||||
}
|
||||
|
||||
bool GPIOExpanderTestComponent::digital_read_hw(uint8_t pin) {
|
||||
ESP_LOGD(TAG, "digital_read_hw pin=%d", pin);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPIOExpanderTestComponent::digital_read_cache(uint8_t pin) {
|
||||
ESP_LOGD(TAG, "digital_read_cache pin=%d", pin);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace esphome::gpio_expander_test_component
|
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "esphome/components/gpio_expander/cached_gpio.h"
|
||||
#include "esphome/core/component.h"
|
||||
|
||||
namespace esphome::gpio_expander_test_component {
|
||||
|
||||
class GPIOExpanderTestComponent : public Component, public esphome::gpio_expander::CachedGpioExpander<uint8_t, 32> {
|
||||
public:
|
||||
void setup() override;
|
||||
|
||||
protected:
|
||||
bool digital_read_hw(uint8_t pin) override;
|
||||
bool digital_read_cache(uint8_t pin) override;
|
||||
void digital_write_hw(uint8_t pin, bool value) override{};
|
||||
};
|
||||
|
||||
} // namespace esphome::gpio_expander_test_component
|
17
tests/integration/fixtures/gpio_expander_cache.yaml
Normal file
17
tests/integration/fixtures/gpio_expander_cache.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
esphome:
|
||||
name: gpio-expander-cache
|
||||
host:
|
||||
|
||||
logger:
|
||||
level: DEBUG
|
||||
|
||||
api:
|
||||
|
||||
# External component that uses gpio_expander::CachedGpioExpander
|
||||
external_components:
|
||||
- source:
|
||||
type: local
|
||||
path: EXTERNAL_COMPONENT_PATH
|
||||
components: [gpio_expander_test_component]
|
||||
|
||||
gpio_expander_test_component:
|
123
tests/integration/test_gpio_expander_cache.py
Normal file
123
tests/integration/test_gpio_expander_cache.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Integration test for CachedGPIOExpander to ensure correct behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from .types import APIClientConnectedFactory, RunCompiledFunction
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpio_expander_cache(
|
||||
yaml_config: str,
|
||||
run_compiled: RunCompiledFunction,
|
||||
api_client_connected: APIClientConnectedFactory,
|
||||
) -> None:
|
||||
"""Test gpio_expander::CachedGpioExpander correctly calls hardware functions."""
|
||||
# Get the path to the external components directory
|
||||
external_components_path = str(
|
||||
Path(__file__).parent / "fixtures" / "external_components"
|
||||
)
|
||||
|
||||
# Replace the placeholder in the YAML config with the actual path
|
||||
yaml_config = yaml_config.replace(
|
||||
"EXTERNAL_COMPONENT_PATH", external_components_path
|
||||
)
|
||||
|
||||
logs_done = asyncio.Event()
|
||||
|
||||
# Patterns to match in logs
|
||||
digital_read_hw_pattern = re.compile(r"digital_read_hw pin=(\d+)")
|
||||
digital_read_cache_pattern = re.compile(r"digital_read_cache pin=(\d+)")
|
||||
|
||||
# ensure logs are in the expected order
|
||||
log_order = [
|
||||
(digital_read_hw_pattern, 0),
|
||||
[(digital_read_cache_pattern, i) for i in range(0, 8)],
|
||||
(digital_read_hw_pattern, 8),
|
||||
[(digital_read_cache_pattern, i) for i in range(8, 16)],
|
||||
(digital_read_hw_pattern, 16),
|
||||
[(digital_read_cache_pattern, i) for i in range(16, 24)],
|
||||
(digital_read_hw_pattern, 24),
|
||||
[(digital_read_cache_pattern, i) for i in range(24, 32)],
|
||||
(digital_read_hw_pattern, 3),
|
||||
(digital_read_cache_pattern, 3),
|
||||
(digital_read_hw_pattern, 3),
|
||||
(digital_read_cache_pattern, 3),
|
||||
(digital_read_cache_pattern, 4),
|
||||
(digital_read_hw_pattern, 3),
|
||||
(digital_read_cache_pattern, 3),
|
||||
(digital_read_hw_pattern, 10),
|
||||
(digital_read_cache_pattern, 10),
|
||||
# full cache reset here for testing
|
||||
(digital_read_hw_pattern, 15),
|
||||
(digital_read_cache_pattern, 15),
|
||||
(digital_read_cache_pattern, 14),
|
||||
(digital_read_hw_pattern, 14),
|
||||
(digital_read_cache_pattern, 14),
|
||||
]
|
||||
# Flatten the log order for easier processing
|
||||
log_order: list[tuple[re.Pattern, int]] = [
|
||||
item
|
||||
for sublist in log_order
|
||||
for item in (sublist if isinstance(sublist, list) else [sublist])
|
||||
]
|
||||
|
||||
index = 0
|
||||
|
||||
def check_output(line: str) -> None:
|
||||
"""Check log output for expected messages."""
|
||||
nonlocal index
|
||||
if logs_done.is_set():
|
||||
return
|
||||
|
||||
clean_line = re.sub(r"\x1b\[[0-9;]*m", "", line)
|
||||
|
||||
if "digital_read" in clean_line:
|
||||
if index >= len(log_order):
|
||||
print(f"Received unexpected log line: {clean_line}")
|
||||
logs_done.set()
|
||||
return
|
||||
|
||||
pattern, expected_pin = log_order[index]
|
||||
match = pattern.search(clean_line)
|
||||
|
||||
if not match:
|
||||
print(f"Log line did not match next expected pattern: {clean_line}")
|
||||
logs_done.set()
|
||||
return
|
||||
|
||||
pin = int(match.group(1))
|
||||
if pin != expected_pin:
|
||||
print(f"Unexpected pin number. Expected {expected_pin}, got {pin}")
|
||||
logs_done.set()
|
||||
return
|
||||
|
||||
index += 1
|
||||
|
||||
elif "DONE" in clean_line:
|
||||
# Check if we reached the end of the expected log entries
|
||||
logs_done.set()
|
||||
|
||||
# Run with log monitoring
|
||||
async with (
|
||||
run_compiled(yaml_config, line_callback=check_output),
|
||||
api_client_connected() as client,
|
||||
):
|
||||
# Verify device info
|
||||
device_info = await client.device_info()
|
||||
assert device_info is not None
|
||||
assert device_info.name == "gpio-expander-cache"
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(logs_done.wait(), timeout=5.0)
|
||||
except TimeoutError:
|
||||
pytest.fail("Timeout waiting for logs to complete")
|
||||
|
||||
assert index == len(log_order), (
|
||||
f"Expected {len(log_order)} log entries, but got {index}"
|
||||
)
|
Reference in New Issue
Block a user