diff --git a/esphome/__main__.py b/esphome/__main__.py index 6e50af95cf..2ab98582f1 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -15,9 +15,11 @@ import argcomplete from esphome import const, writer, yaml_util import esphome.codegen as cg +from esphome.components.mqtt import CONF_DISCOVER_IP from esphome.config import iter_component_configs, read_config, strip_default_ids from esphome.const import ( ALLOWED_NAME_CHARS, + CONF_API, CONF_BAUD_RATE, CONF_BROKER, CONF_DEASSERT_RTS_DTR, @@ -43,6 +45,7 @@ from esphome.const import ( SECRETS_FILES, ) from esphome.core import CORE, EsphomeError, coroutine +from esphome.enum import StrEnum from esphome.helpers import get_bool_env, indent, is_ip_address from esphome.log import AnsiFore, color, setup_log from esphome.types import ConfigType @@ -106,13 +109,15 @@ def choose_prompt(options, purpose: str = None): return options[opt - 1][1] +class Purpose(StrEnum): + UPLOADING = "uploading" + LOGGING = "logging" + + def choose_upload_log_host( default: list[str] | str | None, check_default: str | None, - show_ota: bool, - show_mqtt: bool, - show_api: bool, - purpose: str | None = None, + purpose: Purpose, ) -> list[str]: # Convert to list for uniform handling defaults = [default] if isinstance(default, str) else default or [] @@ -132,13 +137,30 @@ def choose_upload_log_host( ] resolved.append(choose_prompt(options, purpose=purpose)) elif device == "OTA": - if CORE.address and ( - (show_ota and "ota" in CORE.config) - or (show_api and "api" in CORE.config) + # ensure IP addresses are used first + if is_ip_address(CORE.address) and ( + (purpose == Purpose.LOGGING and has_api()) + or (purpose == Purpose.UPLOADING and has_ota()) ): resolved.append(CORE.address) - elif show_mqtt and has_mqtt_logging(): - resolved.append("MQTT") + + if purpose == Purpose.LOGGING: + if has_api() and has_mqtt_ip_lookup(): + resolved.append("MQTTIP") + + if has_mqtt_logging(): + resolved.append("MQTT") + + if has_api() and has_non_ip_address(): + resolved.append(CORE.address) + + elif purpose == Purpose.UPLOADING: + if has_ota() and has_mqtt_ip_lookup(): + resolved.append("MQTTIP") + + if has_ota() and has_non_ip_address(): + resolved.append(CORE.address) + else: resolved.append(device) if not resolved: @@ -149,39 +171,111 @@ def choose_upload_log_host( options = [ (f"{port.path} ({port.description})", port.path) for port in get_serial_ports() ] - if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config): - options.append((f"Over The Air ({CORE.address})", CORE.address)) - if show_mqtt and has_mqtt_logging(): - mqtt_config = CORE.config[CONF_MQTT] - options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) + + if purpose == Purpose.LOGGING: + if has_mqtt_logging(): + mqtt_config = CORE.config[CONF_MQTT] + options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) + + if has_api(): + if has_resolvable_address(): + options.append((f"Over The Air ({CORE.address})", CORE.address)) + if has_mqtt_ip_lookup(): + options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) + + elif purpose == Purpose.UPLOADING and has_ota(): + if has_resolvable_address(): + options.append((f"Over The Air ({CORE.address})", CORE.address)) + if has_mqtt_ip_lookup(): + options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) if check_default is not None and check_default in [opt[1] for opt in options]: return [check_default] return [choose_prompt(options, purpose=purpose)] -def mqtt_logging_enabled(mqtt_config): +def has_mqtt_logging() -> bool: + """Check if 
MQTT logging is available.""" + if CONF_MQTT not in CORE.config: + return False + + mqtt_config = CORE.config[CONF_MQTT] + + # enabled by default + if CONF_LOG_TOPIC not in mqtt_config: + return True + log_topic = mqtt_config[CONF_LOG_TOPIC] if log_topic is None: return False + if CONF_TOPIC not in log_topic: return False - return log_topic.get(CONF_LEVEL, None) != "NONE" + + return log_topic[CONF_LEVEL] != "NONE" -def has_mqtt_logging() -> bool: - """Check if MQTT logging is available.""" - return (mqtt_config := CORE.config.get(CONF_MQTT)) and mqtt_logging_enabled( - mqtt_config - ) +def has_mqtt() -> bool: + """Check if MQTT is available.""" + return CONF_MQTT in CORE.config + + +def has_api() -> bool: + """Check if API is available.""" + return CONF_API in CORE.config + + +def has_ota() -> bool: + """Check if OTA is available.""" + return CONF_OTA in CORE.config + + +def has_mqtt_ip_lookup() -> bool: + """Check if MQTT is available and IP lookup is supported.""" + if CONF_MQTT not in CORE.config: + return False + # Default Enabled + if CONF_DISCOVER_IP not in CORE.config[CONF_MQTT]: + return True + return CORE.config[CONF_MQTT][CONF_DISCOVER_IP] + + +def has_mdns() -> bool: + """Check if MDNS is available.""" + return CONF_MDNS not in CORE.config or not CORE.config[CONF_MDNS][CONF_DISABLED] + + +def has_non_ip_address() -> bool: + """Check if CORE.address is set and is not an IP address.""" + return CORE.address is not None and not is_ip_address(CORE.address) + + +def has_ip_address() -> bool: + """Check if CORE.address is a valid IP address.""" + return CORE.address is not None and is_ip_address(CORE.address) + + +def has_resolvable_address() -> bool: + """Check if CORE.address is resolvable (via mDNS or is an IP address).""" + return has_mdns() or has_ip_address() + + +def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str): + from esphome import mqtt + + return mqtt.get_esphome_device_ip(config, username, password, client_id) + + +_PORT_TO_PORT_TYPE = { + "MQTT": "MQTT", + "MQTTIP": "MQTTIP", +} def get_port_type(port: str) -> str: if port.startswith("/") or port.startswith("COM"): return "SERIAL" - if port == "MQTT": - return "MQTT" - return "NETWORK" + return _PORT_TO_PORT_TYPE.get(port, "NETWORK") def run_miniterm(config: ConfigType, port: str, args) -> int: @@ -226,7 +320,9 @@ def run_miniterm(config: ConfigType, port: str, args) -> int: .replace(b"\n", b"") .decode("utf8", "backslashreplace") ) - time_str = datetime.now().time().strftime("[%H:%M:%S]") + time_ = datetime.now() + nanoseconds = time_.microsecond // 1000 + time_str = f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{nanoseconds:03}]" safe_print(parser.parse_line(line, time_str)) backtrace_state = platformio_api.process_stacktrace( @@ -437,23 +533,9 @@ def upload_program( password = ota_conf.get(CONF_PASSWORD, "") binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin - # Check if we should use MQTT for address resolution - # This happens when no device was specified, or the current host is "MQTT"/"OTA" - if ( - CONF_MQTT in config # pylint: disable=too-many-boolean-expressions - and (not devices or host in ("MQTT", "OTA")) - and ( - ((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address)) - or get_port_type(host) == "MQTT" - ) - ): - from esphome import mqtt - - devices = [ - mqtt.get_esphome_device_ip( - config, args.username, args.password, args.client_id - ) - ] + # MQTT address resolution + if get_port_type(host) in ("MQTT", 
"MQTTIP"): + devices = mqtt_get_ip(config, args.username, args.password, args.client_id) return espota2.run_ota(devices, remote_port, password, binary) @@ -474,20 +556,28 @@ def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int if get_port_type(port) == "SERIAL": check_permissions(port) return run_miniterm(config, port, args) - if get_port_type(port) == "NETWORK" and "api" in config: - addresses_to_use = devices - if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config: - from esphome import mqtt - mqtt_address = mqtt.get_esphome_device_ip( + port_type = get_port_type(port) + + # Check if we should use API for logging + if has_api(): + addresses_to_use: list[str] | None = None + + if port_type == "NETWORK" and (has_mdns() or is_ip_address(port)): + addresses_to_use = devices + elif port_type in ("NETWORK", "MQTT", "MQTTIP") and has_mqtt_ip_lookup(): + # Only use MQTT IP lookup if the first condition didn't match + # (for MQTT/MQTTIP types, or for NETWORK when mdns/ip check fails) + addresses_to_use = mqtt_get_ip( config, args.username, args.password, args.client_id - )[0] - addresses_to_use = [mqtt_address] + ) - from esphome.components.api.client import run_logs + if addresses_to_use is not None: + from esphome.components.api.client import run_logs - return run_logs(config, addresses_to_use) - if get_port_type(port) in ("NETWORK", "MQTT") and "mqtt" in config: + return run_logs(config, addresses_to_use) + + if port_type in ("NETWORK", "MQTT") and has_mqtt_logging(): from esphome import mqtt return mqtt.show_logs( @@ -560,10 +650,7 @@ def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=args.device, check_default=None, - show_ota=True, - show_mqtt=False, - show_api=False, - purpose="uploading", + purpose=Purpose.UPLOADING, ) exit_code, _ = upload_program(config, args, devices) @@ -588,10 +675,7 @@ def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=args.device, check_default=None, - show_ota=False, - show_mqtt=True, - show_api=True, - purpose="logging", + purpose=Purpose.LOGGING, ) return show_logs(config, args, devices) @@ -617,10 +701,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=args.device, check_default=None, - show_ota=True, - show_mqtt=False, - show_api=True, - purpose="uploading", + purpose=Purpose.UPLOADING, ) exit_code, successful_device = upload_program(config, args, devices) @@ -637,10 +718,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=successful_device, check_default=successful_device, - show_ota=False, - show_mqtt=True, - show_api=True, - purpose="logging", + purpose=Purpose.LOGGING, ) return show_logs(config, args, devices) diff --git a/esphome/components/adc/__init__.py b/esphome/components/adc/__init__.py index f260e13242..15dc447b6c 100644 --- a/esphome/components/adc/__init__.py +++ b/esphome/components/adc/__init__.py @@ -11,15 +11,8 @@ from esphome.components.esp32.const import ( VARIANT_ESP32S2, VARIANT_ESP32S3, ) -from esphome.config_helpers import filter_source_files_from_platform import esphome.config_validation as cv -from esphome.const import ( - CONF_ANALOG, - CONF_INPUT, - CONF_NUMBER, - PLATFORM_ESP8266, - PlatformFramework, -) +from esphome.const import CONF_ANALOG, CONF_INPUT, CONF_NUMBER, PLATFORM_ESP8266 from esphome.core import CORE CODEOWNERS = 
["@esphome/core"] @@ -273,21 +266,3 @@ def validate_adc_pin(value): )(value) raise NotImplementedError - - -FILTER_SOURCE_FILES = filter_source_files_from_platform( - { - "adc_sensor_esp32.cpp": { - PlatformFramework.ESP32_ARDUINO, - PlatformFramework.ESP32_IDF, - }, - "adc_sensor_esp8266.cpp": {PlatformFramework.ESP8266_ARDUINO}, - "adc_sensor_rp2040.cpp": {PlatformFramework.RP2040_ARDUINO}, - "adc_sensor_libretiny.cpp": { - PlatformFramework.BK72XX_ARDUINO, - PlatformFramework.RTL87XX_ARDUINO, - PlatformFramework.LN882X_ARDUINO, - }, - "adc_sensor_zephyr.cpp": {PlatformFramework.NRF52_ZEPHYR}, - } -) diff --git a/esphome/components/adc/sensor.py b/esphome/components/adc/sensor.py index 49970c5e3d..607609bbc7 100644 --- a/esphome/components/adc/sensor.py +++ b/esphome/components/adc/sensor.py @@ -9,6 +9,7 @@ from esphome.components.zephyr import ( zephyr_add_prj_conf, zephyr_add_user, ) +from esphome.config_helpers import filter_source_files_from_platform import esphome.config_validation as cv from esphome.const import ( CONF_ATTENUATION, @@ -20,6 +21,7 @@ from esphome.const import ( PLATFORM_NRF52, STATE_CLASS_MEASUREMENT, UNIT_VOLT, + PlatformFramework, ) from esphome.core import CORE @@ -174,3 +176,21 @@ async def to_code(config): }}; """ ) + + +FILTER_SOURCE_FILES = filter_source_files_from_platform( + { + "adc_sensor_esp32.cpp": { + PlatformFramework.ESP32_ARDUINO, + PlatformFramework.ESP32_IDF, + }, + "adc_sensor_esp8266.cpp": {PlatformFramework.ESP8266_ARDUINO}, + "adc_sensor_rp2040.cpp": {PlatformFramework.RP2040_ARDUINO}, + "adc_sensor_libretiny.cpp": { + PlatformFramework.BK72XX_ARDUINO, + PlatformFramework.RTL87XX_ARDUINO, + PlatformFramework.LN882X_ARDUINO, + }, + "adc_sensor_zephyr.cpp": {PlatformFramework.NRF52_ZEPHYR}, + } +) diff --git a/esphome/components/api/client.py b/esphome/components/api/client.py index ce018b3b98..ca1fc089fa 100644 --- a/esphome/components/api/client.py +++ b/esphome/components/api/client.py @@ -62,9 +62,11 @@ async def async_run_logs(config: dict[str, Any], addresses: list[str]) -> None: time_ = datetime.now() message: bytes = msg.message text = message.decode("utf8", "backslashreplace") - for parsed_msg in parse_log_message( - text, f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]" - ): + nanoseconds = time_.microsecond // 1000 + timestamp = ( + f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{nanoseconds:03}]" + ) + for parsed_msg in parse_log_message(text, timestamp): print(parsed_msg.replace("\033", "\\033") if dashboard else parsed_msg) stop = await async_run(cli, on_log, name=name) diff --git a/esphome/components/esp32/__init__.py b/esphome/components/esp32/__init__.py index 12d84dd4b3..50a47765bf 100644 --- a/esphome/components/esp32/__init__.py +++ b/esphome/components/esp32/__init__.py @@ -353,6 +353,7 @@ SUPPORTED_PLATFORMIO_ESP_IDF_5X = [ # pioarduino versions that don't require a release number # List based on https://github.com/pioarduino/esp-idf/releases SUPPORTED_PIOARDUINO_ESP_IDF_5X = [ + cv.Version(5, 5, 1), cv.Version(5, 5, 0), cv.Version(5, 4, 2), cv.Version(5, 4, 1), diff --git a/esphome/components/factory_reset/button/factory_reset_button.cpp b/esphome/components/factory_reset/button/factory_reset_button.cpp index 585975c043..d582317767 100644 --- a/esphome/components/factory_reset/button/factory_reset_button.cpp +++ b/esphome/components/factory_reset/button/factory_reset_button.cpp @@ -1,7 +1,13 @@ #include "factory_reset_button.h" + +#include "esphome/core/defines.h" + +#ifdef USE_OPENTHREAD +#include 
"esphome/components/openthread/openthread.h" +#endif +#include "esphome/core/application.h" #include "esphome/core/hal.h" #include "esphome/core/log.h" -#include "esphome/core/application.h" namespace esphome { namespace factory_reset { @@ -13,9 +19,20 @@ void FactoryResetButton::press_action() { ESP_LOGI(TAG, "Resetting"); // Let MQTT settle a bit delay(100); // NOLINT +#ifdef USE_OPENTHREAD + openthread::global_openthread_component->on_factory_reset(FactoryResetButton::factory_reset_callback); +#else + global_preferences->reset(); + App.safe_reboot(); +#endif +} + +#ifdef USE_OPENTHREAD +void FactoryResetButton::factory_reset_callback() { global_preferences->reset(); App.safe_reboot(); } +#endif } // namespace factory_reset } // namespace esphome diff --git a/esphome/components/factory_reset/button/factory_reset_button.h b/esphome/components/factory_reset/button/factory_reset_button.h index 9996a860d9..c68da2ca74 100644 --- a/esphome/components/factory_reset/button/factory_reset_button.h +++ b/esphome/components/factory_reset/button/factory_reset_button.h @@ -1,7 +1,9 @@ #pragma once -#include "esphome/core/component.h" +#include "esphome/core/defines.h" + #include "esphome/components/button/button.h" +#include "esphome/core/component.h" namespace esphome { namespace factory_reset { @@ -9,6 +11,9 @@ namespace factory_reset { class FactoryResetButton : public button::Button, public Component { public: void dump_config() override; +#ifdef USE_OPENTHREAD + static void factory_reset_callback(); +#endif protected: void press_action() override; diff --git a/esphome/components/factory_reset/switch/factory_reset_switch.cpp b/esphome/components/factory_reset/switch/factory_reset_switch.cpp index 1282c73f4e..75449aa526 100644 --- a/esphome/components/factory_reset/switch/factory_reset_switch.cpp +++ b/esphome/components/factory_reset/switch/factory_reset_switch.cpp @@ -1,7 +1,13 @@ #include "factory_reset_switch.h" + +#include "esphome/core/defines.h" + +#ifdef USE_OPENTHREAD +#include "esphome/components/openthread/openthread.h" +#endif +#include "esphome/core/application.h" #include "esphome/core/hal.h" #include "esphome/core/log.h" -#include "esphome/core/application.h" namespace esphome { namespace factory_reset { @@ -17,10 +23,21 @@ void FactoryResetSwitch::write_state(bool state) { ESP_LOGI(TAG, "Resetting"); // Let MQTT settle a bit delay(100); // NOLINT +#ifdef USE_OPENTHREAD + openthread::global_openthread_component->on_factory_reset(FactoryResetSwitch::factory_reset_callback); +#else global_preferences->reset(); App.safe_reboot(); +#endif } } +#ifdef USE_OPENTHREAD +void FactoryResetSwitch::factory_reset_callback() { + global_preferences->reset(); + App.safe_reboot(); +} +#endif + } // namespace factory_reset } // namespace esphome diff --git a/esphome/components/factory_reset/switch/factory_reset_switch.h b/esphome/components/factory_reset/switch/factory_reset_switch.h index 2c914ea76d..8ea0c79108 100644 --- a/esphome/components/factory_reset/switch/factory_reset_switch.h +++ b/esphome/components/factory_reset/switch/factory_reset_switch.h @@ -1,7 +1,8 @@ #pragma once -#include "esphome/core/component.h" #include "esphome/components/switch/switch.h" +#include "esphome/core/component.h" +#include "esphome/core/defines.h" namespace esphome { namespace factory_reset { @@ -9,6 +10,9 @@ namespace factory_reset { class FactoryResetSwitch : public switch_::Switch, public Component { public: void dump_config() override; +#ifdef USE_OPENTHREAD + static void factory_reset_callback(); +#endif 
 protected: void write_state(bool state) override; diff --git a/esphome/components/openthread/openthread.cpp b/esphome/components/openthread/openthread.cpp index 322ff43238..5b5c113f83 100644 --- a/esphome/components/openthread/openthread.cpp +++ b/esphome/components/openthread/openthread.cpp @@ -11,8 +11,6 @@ #include #include #include -#include -#include #include #include @@ -77,8 +75,14 @@ std::optional OpenThreadComponent::get_omr_address_(InstanceLock & return {}; } -void srp_callback(otError err, const otSrpClientHostInfo *host_info, const otSrpClientService *services, - const otSrpClientService *removed_services, void *context) { +void OpenThreadComponent::defer_factory_reset_external_callback() { + ESP_LOGD(TAG, "Defer factory_reset_external_callback_"); + this->defer([this]() { this->factory_reset_external_callback_(); }); +} + +void OpenThreadSrpComponent::srp_callback(otError err, const otSrpClientHostInfo *host_info, + const otSrpClientService *services, + const otSrpClientService *removed_services, void *context) { if (err != 0) { ESP_LOGW(TAG, "SRP client reported an error: %s", otThreadErrorToString(err)); for (const otSrpClientHostInfo *host = host_info; host; host = nullptr) { @@ -90,16 +94,30 @@ } } -void srp_start_callback(const otSockAddr *server_socket_address, void *context) { +void OpenThreadSrpComponent::srp_start_callback(const otSockAddr *server_socket_address, void *context) { ESP_LOGI(TAG, "SRP client has started"); } +void OpenThreadSrpComponent::srp_factory_reset_callback(otError err, const otSrpClientHostInfo *host_info, + const otSrpClientService *services, + const otSrpClientService *removed_services, void *context) { + OpenThreadComponent *obj = (OpenThreadComponent *) context; + if (err == OT_ERROR_NONE && removed_services != NULL && host_info != NULL && + host_info->mState == OT_SRP_CLIENT_ITEM_STATE_REMOVED) { + ESP_LOGD(TAG, "Successful Removal SRP Host and Services"); + } else if (err != OT_ERROR_NONE) { + // Handle other SRP client events or errors + ESP_LOGW(TAG, "SRP client event/error: %s", otThreadErrorToString(err)); + } + obj->defer_factory_reset_external_callback(); +} + void OpenThreadSrpComponent::setup() { otError error; InstanceLock lock = InstanceLock::acquire(); otInstance *instance = lock.get_instance(); - otSrpClientSetCallback(instance, srp_callback, nullptr); + otSrpClientSetCallback(instance, OpenThreadSrpComponent::srp_callback, nullptr); // set the host name uint16_t size; @@ -179,7 +197,8 @@ ESP_LOGD(TAG, "Added service: %s", full_service.c_str()); } - otSrpClientEnableAutoStartMode(instance, srp_start_callback, nullptr); + otSrpClientEnableAutoStartMode(instance, OpenThreadSrpComponent::srp_start_callback, nullptr); + ESP_LOGD(TAG, "Finished SRP setup"); } void *OpenThreadSrpComponent::pool_alloc_(size_t size) { @@ -217,6 +236,21 @@ bool OpenThreadComponent::teardown() { return this->teardown_complete_; } +void OpenThreadComponent::on_factory_reset(std::function<void()> callback) { + factory_reset_external_callback_ = callback; + ESP_LOGD(TAG, "Start Removal SRP Host and Services"); + otError error; + InstanceLock lock = InstanceLock::acquire(); + otInstance *instance = lock.get_instance(); + otSrpClientSetCallback(instance, OpenThreadSrpComponent::srp_factory_reset_callback, this); + error = otSrpClientRemoveHostAndServices(instance, true, true); + if (error != OT_ERROR_NONE) { + ESP_LOGW(TAG, "Failed to Remove SRP 
Host and Services"); + return; + } + ESP_LOGD(TAG, "Waiting on Confirmation Removal SRP Host and Services"); +} + } // namespace openthread } // namespace esphome diff --git a/esphome/components/openthread/openthread.h b/esphome/components/openthread/openthread.h index a0ea1b3f3a..a9aff78e56 100644 --- a/esphome/components/openthread/openthread.h +++ b/esphome/components/openthread/openthread.h @@ -6,6 +6,8 @@ #include "esphome/components/network/ip_address.h" #include "esphome/core/component.h" +#include +#include #include #include @@ -28,11 +30,14 @@ class OpenThreadComponent : public Component { network::IPAddresses get_ip_addresses(); std::optional get_omr_address(); void ot_main(); + void on_factory_reset(std::function callback); + void defer_factory_reset_external_callback(); protected: std::optional get_omr_address_(InstanceLock &lock); bool teardown_started_{false}; bool teardown_complete_{false}; + std::function factory_reset_external_callback_; }; extern OpenThreadComponent *global_openthread_component; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) @@ -43,6 +48,12 @@ class OpenThreadSrpComponent : public Component { // This has to run after the mdns component or else no services are available to advertise float get_setup_priority() const override { return this->mdns_->get_setup_priority() - 1.0; } void setup() override; + static void srp_callback(otError err, const otSrpClientHostInfo *host_info, const otSrpClientService *services, + const otSrpClientService *removed_services, void *context); + static void srp_start_callback(const otSockAddr *server_socket_address, void *context); + static void srp_factory_reset_callback(otError err, const otSrpClientHostInfo *host_info, + const otSrpClientService *services, const otSrpClientService *removed_services, + void *context); protected: esphome::mdns::MDNSComponent *mdns_{nullptr}; diff --git a/esphome/components/packet_transport/packet_transport.cpp b/esphome/components/packet_transport/packet_transport.cpp index b6ce24bc1b..8bde4ee505 100644 --- a/esphome/components/packet_transport/packet_transport.cpp +++ b/esphome/components/packet_transport/packet_transport.cpp @@ -270,6 +270,7 @@ void PacketTransport::add_binary_data_(uint8_t key, const char *id, bool data) { auto len = 1 + 1 + 1 + strlen(id); if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { this->flush_(); + this->init_data_(); } add(this->data_, key); add(this->data_, (uint8_t) data); @@ -284,6 +285,7 @@ void PacketTransport::add_data_(uint8_t key, const char *id, uint32_t data) { auto len = 4 + 1 + 1 + strlen(id); if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { this->flush_(); + this->init_data_(); } add(this->data_, key); add(this->data_, data); diff --git a/esphome/const.py b/esphome/const.py index 308abe7706..677b9173ec 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -114,6 +114,7 @@ CONF_AND = "and" CONF_ANGLE = "angle" CONF_ANY = "any" CONF_AP = "ap" +CONF_API = "api" CONF_APPARENT_POWER = "apparent_power" CONF_ARDUINO_VERSION = "arduino_version" CONF_AREA = "area" diff --git a/requirements_test.txt b/requirements_test.txt index 01661f3b7c..5ec9c98408 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -8,7 +8,7 @@ pre-commit pytest==8.4.2 pytest-cov==7.0.0 pytest-mock==3.15.0 -pytest-asyncio==1.1.0 +pytest-asyncio==1.2.0 pytest-xdist==3.8.0 asyncmock==0.4.2 hypothesis==6.92.1 diff --git a/tests/dashboard/test_entries.py b/tests/dashboard/test_entries.py new file 
mode 100644 index 0000000000..a86c33a16f --- /dev/null +++ b/tests/dashboard/test_entries.py @@ -0,0 +1,203 @@ +"""Tests for dashboard entries Path-related functionality.""" + +from __future__ import annotations + +from pathlib import Path +import tempfile +from unittest.mock import MagicMock + +import pytest +import pytest_asyncio + +from esphome.core import CORE +from esphome.dashboard.entries import DashboardEntries, DashboardEntry + + +def create_cache_key() -> tuple[int, int, float, int]: + """Helper to create a valid DashboardCacheKeyType.""" + return (0, 0, 0.0, 0) + + +@pytest.fixture(autouse=True) +def setup_core(): + """Set up CORE for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + CORE.config_path = str(Path(tmpdir) / "test.yaml") + yield + CORE.reset() + + +@pytest.fixture +def mock_settings() -> MagicMock: + """Create mock dashboard settings.""" + settings = MagicMock() + settings.config_dir = "/test/config" + settings.absolute_config_dir = Path("/test/config") + return settings + + +@pytest_asyncio.fixture +async def dashboard_entries(mock_settings: MagicMock) -> DashboardEntries: + """Create a DashboardEntries instance for testing.""" + return DashboardEntries(mock_settings) + + +def test_dashboard_entry_path_initialization() -> None: + """Test DashboardEntry initializes with path correctly.""" + test_path = "/test/config/device.yaml" + cache_key = create_cache_key() + + entry = DashboardEntry(test_path, cache_key) + + assert entry.path == test_path + assert entry.cache_key == cache_key + + +def test_dashboard_entry_path_with_absolute_path() -> None: + """Test DashboardEntry handles absolute paths.""" + # Use a truly absolute path for the platform + test_path = Path.cwd() / "absolute" / "path" / "to" / "config.yaml" + cache_key = create_cache_key() + + entry = DashboardEntry(str(test_path), cache_key) + + assert entry.path == str(test_path) + assert Path(entry.path).is_absolute() + + +def test_dashboard_entry_path_with_relative_path() -> None: + """Test DashboardEntry handles relative paths.""" + test_path = "configs/device.yaml" + cache_key = create_cache_key() + + entry = DashboardEntry(test_path, cache_key) + + assert entry.path == test_path + assert not Path(entry.path).is_absolute() + + +@pytest.mark.asyncio +async def test_dashboard_entries_get_by_path( + dashboard_entries: DashboardEntries, +) -> None: + """Test getting entry by path.""" + test_path = "/test/config/device.yaml" + entry = DashboardEntry(test_path, create_cache_key()) + + dashboard_entries._entries[test_path] = entry + + result = dashboard_entries.get(test_path) + assert result == entry + + +@pytest.mark.asyncio +async def test_dashboard_entries_get_nonexistent_path( + dashboard_entries: DashboardEntries, +) -> None: + """Test getting non-existent entry returns None.""" + result = dashboard_entries.get("/nonexistent/path.yaml") + assert result is None + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_normalization( + dashboard_entries: DashboardEntries, +) -> None: + """Test that paths are handled consistently.""" + path1 = "/test/config/device.yaml" + + entry = DashboardEntry(path1, create_cache_key()) + dashboard_entries._entries[path1] = entry + + result = dashboard_entries.get(path1) + assert result == entry + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_with_spaces( + dashboard_entries: DashboardEntries, +) -> None: + """Test handling paths with spaces.""" + test_path = "/test/config/my device.yaml" + entry = DashboardEntry(test_path, 
create_cache_key()) + + dashboard_entries._entries[test_path] = entry + + result = dashboard_entries.get(test_path) + assert result == entry + assert result.path == test_path + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_with_special_chars( + dashboard_entries: DashboardEntries, +) -> None: + """Test handling paths with special characters.""" + test_path = "/test/config/device-01_test.yaml" + entry = DashboardEntry(test_path, create_cache_key()) + + dashboard_entries._entries[test_path] = entry + + result = dashboard_entries.get(test_path) + assert result == entry + + +def test_dashboard_entries_windows_path() -> None: + """Test handling Windows-style paths.""" + test_path = r"C:\Users\test\esphome\device.yaml" + cache_key = create_cache_key() + + entry = DashboardEntry(test_path, cache_key) + + assert entry.path == test_path + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_to_cache_key_mapping( + dashboard_entries: DashboardEntries, +) -> None: + """Test internal entries storage with paths and cache keys.""" + path1 = "/test/config/device1.yaml" + path2 = "/test/config/device2.yaml" + + entry1 = DashboardEntry(path1, create_cache_key()) + entry2 = DashboardEntry(path2, (1, 1, 1.0, 1)) + + dashboard_entries._entries[path1] = entry1 + dashboard_entries._entries[path2] = entry2 + + assert path1 in dashboard_entries._entries + assert path2 in dashboard_entries._entries + assert dashboard_entries._entries[path1].cache_key == create_cache_key() + assert dashboard_entries._entries[path2].cache_key == (1, 1, 1.0, 1) + + +def test_dashboard_entry_path_property() -> None: + """Test that path property returns expected value.""" + test_path = "/test/config/device.yaml" + entry = DashboardEntry(test_path, create_cache_key()) + + assert entry.path == test_path + assert isinstance(entry.path, str) + + +@pytest.mark.asyncio +async def test_dashboard_entries_all_returns_entries_with_paths( + dashboard_entries: DashboardEntries, +) -> None: + """Test that all() returns entries with their paths intact.""" + paths = [ + "/test/config/device1.yaml", + "/test/config/device2.yaml", + "/test/config/subfolder/device3.yaml", + ] + + for path in paths: + entry = DashboardEntry(path, create_cache_key()) + dashboard_entries._entries[path] = entry + + all_entries = dashboard_entries.async_all() + + assert len(all_entries) == len(paths) + retrieved_paths = [entry.path for entry in all_entries] + assert set(retrieved_paths) == set(paths) diff --git a/tests/dashboard/test_settings.py b/tests/dashboard/test_settings.py new file mode 100644 index 0000000000..90a79ac0f8 --- /dev/null +++ b/tests/dashboard/test_settings.py @@ -0,0 +1,168 @@ +"""Tests for dashboard settings Path-related functionality.""" + +from __future__ import annotations + +import os +from pathlib import Path +import tempfile + +import pytest + +from esphome.dashboard.settings import DashboardSettings + + +@pytest.fixture +def dashboard_settings(tmp_path: Path) -> DashboardSettings: + """Create DashboardSettings instance with temp directory.""" + settings = DashboardSettings() + # Resolve symlinks to ensure paths match + resolved_dir = tmp_path.resolve() + settings.config_dir = str(resolved_dir) + settings.absolute_config_dir = resolved_dir + return settings + + +def test_rel_path_simple(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with simple relative path.""" + result = dashboard_settings.rel_path("config.yaml") + + expected = str(Path(dashboard_settings.config_dir) / "config.yaml") + assert 
result == expected + + +def test_rel_path_multiple_components(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with multiple path components.""" + result = dashboard_settings.rel_path("subfolder", "device", "config.yaml") + + expected = str( + Path(dashboard_settings.config_dir) / "subfolder" / "device" / "config.yaml" + ) + assert result == expected + + +def test_rel_path_with_dots(dashboard_settings: DashboardSettings) -> None: + """Test rel_path prevents directory traversal.""" + # This should raise ValueError as it tries to go outside config_dir + with pytest.raises(ValueError): + dashboard_settings.rel_path("..", "outside.yaml") + + +def test_rel_path_absolute_path_within_config( + dashboard_settings: DashboardSettings, +) -> None: + """Test rel_path with absolute path that's within config dir.""" + internal_path = dashboard_settings.absolute_config_dir / "internal.yaml" + + internal_path.touch() + result = dashboard_settings.rel_path("internal.yaml") + expected = str(Path(dashboard_settings.config_dir) / "internal.yaml") + assert result == expected + + +def test_rel_path_absolute_path_outside_config( + dashboard_settings: DashboardSettings, +) -> None: + """Test rel_path with absolute path outside config dir raises error.""" + outside_path = "/tmp/outside/config.yaml" + + with pytest.raises(ValueError): + dashboard_settings.rel_path(outside_path) + + +def test_rel_path_empty_args(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with no arguments returns config_dir.""" + result = dashboard_settings.rel_path() + assert result == dashboard_settings.config_dir + + +def test_rel_path_with_pathlib_path(dashboard_settings: DashboardSettings) -> None: + """Test rel_path works with Path objects as arguments.""" + path_obj = Path("subfolder") / "config.yaml" + result = dashboard_settings.rel_path(path_obj) + + expected = str(Path(dashboard_settings.config_dir) / "subfolder" / "config.yaml") + assert result == expected + + +def test_rel_path_normalizes_slashes(dashboard_settings: DashboardSettings) -> None: + """Test rel_path normalizes path separators.""" + # os.path.join normalizes slashes on Windows but preserves them on Unix + # Test that providing components separately gives same result + result1 = dashboard_settings.rel_path("folder", "subfolder", "file.yaml") + result2 = dashboard_settings.rel_path("folder", "subfolder", "file.yaml") + assert result1 == result2 + + # Also test that the result is as expected + expected = os.path.join( + dashboard_settings.config_dir, "folder", "subfolder", "file.yaml" + ) + assert result1 == expected + + +def test_rel_path_handles_spaces(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles paths with spaces.""" + result = dashboard_settings.rel_path("my folder", "my config.yaml") + + expected = str(Path(dashboard_settings.config_dir) / "my folder" / "my config.yaml") + assert result == expected + + +def test_rel_path_handles_special_chars(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles paths with special characters.""" + result = dashboard_settings.rel_path("device-01_test", "config.yaml") + + expected = str( + Path(dashboard_settings.config_dir) / "device-01_test" / "config.yaml" + ) + assert result == expected + + +def test_config_dir_as_path_property(dashboard_settings: DashboardSettings) -> None: + """Test that config_dir can be accessed and used with Path operations.""" + config_path = Path(dashboard_settings.config_dir) + + assert config_path.exists() + assert 
config_path.is_dir() + assert config_path.is_absolute() + + +def test_absolute_config_dir_property(dashboard_settings: DashboardSettings) -> None: + """Test absolute_config_dir is a Path object.""" + assert isinstance(dashboard_settings.absolute_config_dir, Path) + assert dashboard_settings.absolute_config_dir.exists() + assert dashboard_settings.absolute_config_dir.is_dir() + assert dashboard_settings.absolute_config_dir.is_absolute() + + +def test_rel_path_symlink_inside_config(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with symlink that points inside config dir.""" + target = dashboard_settings.absolute_config_dir / "target.yaml" + target.touch() + symlink = dashboard_settings.absolute_config_dir / "link.yaml" + symlink.symlink_to(target) + result = dashboard_settings.rel_path("link.yaml") + expected = str(Path(dashboard_settings.config_dir) / "link.yaml") + assert result == expected + + +def test_rel_path_symlink_outside_config(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with symlink that points outside config dir.""" + with tempfile.NamedTemporaryFile(suffix=".yaml") as tmp: + symlink = dashboard_settings.absolute_config_dir / "external_link.yaml" + symlink.symlink_to(tmp.name) + with pytest.raises(ValueError): + dashboard_settings.rel_path("external_link.yaml") + + +def test_rel_path_with_none_arg(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles None arguments gracefully.""" + result = dashboard_settings.rel_path("None") + expected = str(Path(dashboard_settings.config_dir) / "None") + assert result == expected + + +def test_rel_path_with_numeric_args(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles numeric arguments.""" + result = dashboard_settings.rel_path("123", "456.789") + expected = str(Path(dashboard_settings.config_dir) / "123" / "456.789") + assert result == expected diff --git a/tests/dashboard/test_web_server.py b/tests/dashboard/test_web_server.py index b77ab7a7a3..e206090ac0 100644 --- a/tests/dashboard/test_web_server.py +++ b/tests/dashboard/test_web_server.py @@ -1,13 +1,16 @@ from __future__ import annotations import asyncio +from collections.abc import Generator +import gzip import json import os -from unittest.mock import Mock +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch import pytest import pytest_asyncio -from tornado.httpclient import AsyncHTTPClient, HTTPResponse +from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop from tornado.testing import bind_unused_port @@ -34,6 +37,66 @@ class DashboardTestHelper: return await future +@pytest.fixture +def mock_async_run_system_command() -> Generator[MagicMock]: + """Fixture to mock async_run_system_command.""" + with patch("esphome.dashboard.web_server.async_run_system_command") as mock: + yield mock + + +@pytest.fixture +def mock_trash_storage_path(tmp_path: Path) -> Generator[MagicMock]: + """Fixture to mock trash_storage_path.""" + trash_dir = tmp_path / "trash" + with patch( + "esphome.dashboard.web_server.trash_storage_path", return_value=str(trash_dir) + ) as mock: + yield mock + + +@pytest.fixture +def mock_archive_storage_path(tmp_path: Path) -> Generator[MagicMock]: + """Fixture to mock archive_storage_path.""" + archive_dir = tmp_path / "archive" + with patch( + "esphome.dashboard.web_server.archive_storage_path", + return_value=str(archive_dir), + ) as mock: + yield mock + + 
+@pytest.fixture +def mock_dashboard_settings() -> Generator[MagicMock]: + """Fixture to mock dashboard settings.""" + with patch("esphome.dashboard.web_server.settings") as mock_settings: + # Set default auth settings to avoid authentication issues + mock_settings.using_auth = False + mock_settings.on_ha_addon = False + yield mock_settings + + +@pytest.fixture +def mock_ext_storage_path(tmp_path: Path) -> Generator[MagicMock]: + """Fixture to mock ext_storage_path.""" + with patch("esphome.dashboard.web_server.ext_storage_path") as mock: + mock.return_value = str(tmp_path / "storage.json") + yield mock + + +@pytest.fixture +def mock_storage_json() -> Generator[MagicMock]: + """Fixture to mock StorageJSON.""" + with patch("esphome.dashboard.web_server.StorageJSON") as mock: + yield mock + + +@pytest.fixture +def mock_idedata() -> Generator[MagicMock]: + """Fixture to mock platformio_api.IDEData.""" + with patch("esphome.dashboard.web_server.platformio_api.IDEData") as mock: + yield mock + + @pytest_asyncio.fixture() async def dashboard() -> DashboardTestHelper: sock, port = bind_unused_port() @@ -80,3 +143,499 @@ async def test_devices_page(dashboard: DashboardTestHelper) -> None: first_device = configured_devices[0] assert first_device["name"] == "pico" assert first_device["configuration"] == "pico.yaml" + + +@pytest.mark.asyncio +async def test_wizard_handler_invalid_input(dashboard: DashboardTestHelper) -> None: + """Test the WizardRequestHandler.post method with invalid inputs.""" + # Test with missing name (should fail with 422) + body_no_name = json.dumps( + { + "name": "", # Empty name + "platform": "ESP32", + "board": "esp32dev", + } + ) + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch( + "/wizard", + method="POST", + body=body_no_name, + headers={"Content-Type": "application/json"}, + ) + assert exc_info.value.code == 422 + + # Test with invalid wizard type (should fail with 422) + body_invalid_type = json.dumps( + { + "name": "test_device", + "type": "invalid_type", + "platform": "ESP32", + "board": "esp32dev", + } + ) + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch( + "/wizard", + method="POST", + body=body_invalid_type, + headers={"Content-Type": "application/json"}, + ) + assert exc_info.value.code == 422 + + +@pytest.mark.asyncio +async def test_wizard_handler_conflict(dashboard: DashboardTestHelper) -> None: + """Test the WizardRequestHandler.post when config already exists.""" + # Try to create a wizard for existing pico.yaml (should conflict) + body = json.dumps( + { + "name": "pico", # This already exists in fixtures + "platform": "ESP32", + "board": "esp32dev", + } + ) + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch( + "/wizard", + method="POST", + body=body, + headers={"Content-Type": "application/json"}, + ) + assert exc_info.value.code == 409 + + +@pytest.mark.asyncio +async def test_download_binary_handler_not_found( + dashboard: DashboardTestHelper, +) -> None: + """Test the DownloadBinaryRequestHandler.get with non-existent config.""" + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch( + "/download.bin?configuration=nonexistent.yaml", + method="GET", + ) + assert exc_info.value.code == 404 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_ext_storage_path") +async def test_download_binary_handler_no_file_param( + dashboard: DashboardTestHelper, + tmp_path: Path, + mock_storage_json: MagicMock, +) -> None: + """Test the DownloadBinaryRequestHandler.get 
without file parameter.""" + # Mock storage to exist, but still should fail without file param + mock_storage = Mock() + mock_storage.name = "test_device" + mock_storage.firmware_bin_path = str(tmp_path / "firmware.bin") + mock_storage_json.load.return_value = mock_storage + + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch( + "/download.bin?configuration=pico.yaml", + method="GET", + ) + assert exc_info.value.code == 400 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_ext_storage_path") +async def test_download_binary_handler_with_file( + dashboard: DashboardTestHelper, + tmp_path: Path, + mock_storage_json: MagicMock, +) -> None: + """Test the DownloadBinaryRequestHandler.get with existing binary file.""" + # Create a fake binary file + build_dir = tmp_path / ".esphome" / "build" / "test" + build_dir.mkdir(parents=True) + firmware_file = build_dir / "firmware.bin" + firmware_file.write_bytes(b"fake firmware content") + + # Mock storage JSON + mock_storage = Mock() + mock_storage.name = "test_device" + mock_storage.firmware_bin_path = str(firmware_file) + mock_storage_json.load.return_value = mock_storage + + response = await dashboard.fetch( + "/download.bin?configuration=test.yaml&file=firmware.bin", + method="GET", + ) + assert response.code == 200 + assert response.body == b"fake firmware content" + assert response.headers["Content-Type"] == "application/octet-stream" + assert "attachment" in response.headers["Content-Disposition"] + assert "test_device-firmware.bin" in response.headers["Content-Disposition"] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_ext_storage_path") +async def test_download_binary_handler_compressed( + dashboard: DashboardTestHelper, + tmp_path: Path, + mock_storage_json: MagicMock, +) -> None: + """Test the DownloadBinaryRequestHandler.get with compression.""" + # Create a fake binary file + build_dir = tmp_path / ".esphome" / "build" / "test" + build_dir.mkdir(parents=True) + firmware_file = build_dir / "firmware.bin" + original_content = b"fake firmware content for compression test" + firmware_file.write_bytes(original_content) + + # Mock storage JSON + mock_storage = Mock() + mock_storage.name = "test_device" + mock_storage.firmware_bin_path = str(firmware_file) + mock_storage_json.load.return_value = mock_storage + + response = await dashboard.fetch( + "/download.bin?configuration=test.yaml&file=firmware.bin&compressed=1", + method="GET", + ) + assert response.code == 200 + # Decompress and verify content + decompressed = gzip.decompress(response.body) + assert decompressed == original_content + assert response.headers["Content-Type"] == "application/octet-stream" + assert "firmware.bin.gz" in response.headers["Content-Disposition"] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_ext_storage_path") +async def test_download_binary_handler_custom_download_name( + dashboard: DashboardTestHelper, + tmp_path: Path, + mock_storage_json: MagicMock, +) -> None: + """Test the DownloadBinaryRequestHandler.get with custom download name.""" + # Create a fake binary file + build_dir = tmp_path / ".esphome" / "build" / "test" + build_dir.mkdir(parents=True) + firmware_file = build_dir / "firmware.bin" + firmware_file.write_bytes(b"content") + + # Mock storage JSON + mock_storage = Mock() + mock_storage.name = "test_device" + mock_storage.firmware_bin_path = str(firmware_file) + mock_storage_json.load.return_value = mock_storage + + response = await dashboard.fetch( + 
"/download.bin?configuration=test.yaml&file=firmware.bin&download=custom_name.bin", + method="GET", + ) + assert response.code == 200 + assert "custom_name.bin" in response.headers["Content-Disposition"] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_ext_storage_path") +async def test_download_binary_handler_idedata_fallback( + dashboard: DashboardTestHelper, + tmp_path: Path, + mock_async_run_system_command: MagicMock, + mock_storage_json: MagicMock, + mock_idedata: MagicMock, +) -> None: + """Test the DownloadBinaryRequestHandler.get falling back to idedata for extra images.""" + # Create build directory but no bootloader file initially + build_dir = tmp_path / ".esphome" / "build" / "test" + build_dir.mkdir(parents=True) + firmware_file = build_dir / "firmware.bin" + firmware_file.write_bytes(b"firmware") + + # Create bootloader file that idedata will find + bootloader_file = tmp_path / "bootloader.bin" + bootloader_file.write_bytes(b"bootloader content") + + # Mock storage JSON + mock_storage = Mock() + mock_storage.name = "test_device" + mock_storage.firmware_bin_path = str(firmware_file) + mock_storage_json.load.return_value = mock_storage + + # Mock idedata response + mock_image = Mock() + mock_image.path = str(bootloader_file) + mock_idedata_instance = Mock() + mock_idedata_instance.extra_flash_images = [mock_image] + mock_idedata.return_value = mock_idedata_instance + + # Mock async_run_system_command to return idedata JSON + mock_async_run_system_command.return_value = (0, '{"extra_flash_images": []}', "") + + response = await dashboard.fetch( + "/download.bin?configuration=test.yaml&file=bootloader.bin", + method="GET", + ) + assert response.code == 200 + assert response.body == b"bootloader content" + + +@pytest.mark.asyncio +async def test_edit_request_handler_post_invalid_file( + dashboard: DashboardTestHelper, +) -> None: + """Test the EditRequestHandler.post with non-yaml file.""" + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch( + "/edit?configuration=test.txt", + method="POST", + body=b"content", + ) + assert exc_info.value.code == 404 + + +@pytest.mark.asyncio +async def test_edit_request_handler_post_existing( + dashboard: DashboardTestHelper, + tmp_path: Path, + mock_dashboard_settings: MagicMock, +) -> None: + """Test the EditRequestHandler.post with existing yaml file.""" + # Create a temporary yaml file to edit (don't modify fixtures) + test_file = tmp_path / "test_edit.yaml" + test_file.write_text("esphome:\n name: original\n") + + # Configure the mock settings + mock_dashboard_settings.rel_path.return_value = str(test_file) + mock_dashboard_settings.absolute_config_dir = test_file.parent + + new_content = "esphome:\n name: modified\n" + response = await dashboard.fetch( + "/edit?configuration=test_edit.yaml", + method="POST", + body=new_content.encode(), + ) + assert response.code == 200 + + # Verify the file was actually modified + assert test_file.read_text() == new_content + + +@pytest.mark.asyncio +async def test_unarchive_request_handler( + dashboard: DashboardTestHelper, + mock_archive_storage_path: MagicMock, + mock_dashboard_settings: MagicMock, + tmp_path: Path, +) -> None: + """Test the UnArchiveRequestHandler.post method.""" + # Set up an archived file + archive_dir = Path(mock_archive_storage_path.return_value) + archive_dir.mkdir(parents=True, exist_ok=True) + archived_file = archive_dir / "archived.yaml" + archived_file.write_text("test content") + + # Set up the destination path where the file should be moved + 
config_dir = tmp_path / "config" + config_dir.mkdir(parents=True, exist_ok=True) + destination_file = config_dir / "archived.yaml" + mock_dashboard_settings.rel_path.return_value = str(destination_file) + + response = await dashboard.fetch( + "/unarchive?configuration=archived.yaml", + method="POST", + body=b"", + ) + assert response.code == 200 + + # Verify the file was actually moved from archive to config + assert not archived_file.exists() # File should be gone from archive + assert destination_file.exists() # File should now be in config + assert destination_file.read_text() == "test content" # Content preserved + + +@pytest.mark.asyncio +async def test_secret_keys_handler_no_file(dashboard: DashboardTestHelper) -> None: + """Test the SecretKeysRequestHandler.get when no secrets file exists.""" + # By default, there's no secrets file in the test fixtures + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/secret_keys", method="GET") + assert exc_info.value.code == 404 + + +@pytest.mark.asyncio +async def test_secret_keys_handler_with_file( + dashboard: DashboardTestHelper, + tmp_path: Path, + mock_dashboard_settings: MagicMock, +) -> None: + """Test the SecretKeysRequestHandler.get when secrets file exists.""" + # Create a secrets file in temp directory + secrets_file = tmp_path / "secrets.yaml" + secrets_file.write_text( + "wifi_ssid: TestNetwork\nwifi_password: TestPass123\napi_key: test_key\n" + ) + + # Configure mock to return our temp secrets file + # Since the file actually exists, os.path.isfile will return True naturally + mock_dashboard_settings.rel_path.return_value = str(secrets_file) + + response = await dashboard.fetch("/secret_keys", method="GET") + assert response.code == 200 + data = json.loads(response.body.decode()) + assert "wifi_ssid" in data + assert "wifi_password" in data + assert "api_key" in data + + +@pytest.mark.asyncio +async def test_json_config_handler( + dashboard: DashboardTestHelper, + mock_async_run_system_command: MagicMock, +) -> None: + """Test the JsonConfigRequestHandler.get method.""" + # This will actually run the esphome config command on pico.yaml + mock_output = json.dumps( + { + "esphome": {"name": "pico"}, + "esp32": {"board": "esp32dev"}, + } + ) + mock_async_run_system_command.return_value = (0, mock_output, "") + + response = await dashboard.fetch( + "/json-config?configuration=pico.yaml", method="GET" + ) + assert response.code == 200 + data = json.loads(response.body.decode()) + assert data["esphome"]["name"] == "pico" + + +@pytest.mark.asyncio +async def test_json_config_handler_invalid_config( + dashboard: DashboardTestHelper, + mock_async_run_system_command: MagicMock, +) -> None: + """Test the JsonConfigRequestHandler.get with invalid config.""" + # Simulate esphome config command failure + mock_async_run_system_command.return_value = (1, "", "Error: Invalid configuration") + + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/json-config?configuration=pico.yaml", method="GET") + assert exc_info.value.code == 422 + + +@pytest.mark.asyncio +async def test_json_config_handler_not_found(dashboard: DashboardTestHelper) -> None: + """Test the JsonConfigRequestHandler.get with non-existent file.""" + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch( + "/json-config?configuration=nonexistent.yaml", method="GET" + ) + assert exc_info.value.code == 404 + + +def test_start_web_server_with_address_port( + tmp_path: Path, + mock_trash_storage_path: MagicMock, + 
mock_archive_storage_path: MagicMock, +) -> None: + """Test the start_web_server function with address and port.""" + app = Mock() + trash_dir = Path(mock_trash_storage_path.return_value) + archive_dir = Path(mock_archive_storage_path.return_value) + + # Create trash dir to test migration + trash_dir.mkdir() + (trash_dir / "old.yaml").write_text("old") + + web_server.start_web_server(app, None, "127.0.0.1", 6052, str(tmp_path / "config")) + + # The function calls app.listen directly for non-socket mode + app.listen.assert_called_once_with(6052, "127.0.0.1") + + # Verify trash was moved to archive + assert not trash_dir.exists() + assert archive_dir.exists() + assert (archive_dir / "old.yaml").exists() + + +@pytest.mark.asyncio +async def test_edit_request_handler_get(dashboard: DashboardTestHelper) -> None: + """Test EditRequestHandler.get method.""" + # Test getting a valid yaml file + response = await dashboard.fetch("/edit?configuration=pico.yaml") + assert response.code == 200 + assert response.headers["content-type"] == "application/yaml" + content = response.body.decode() + assert "esphome:" in content # Verify it's a valid ESPHome config + + # Test getting a non-existent file + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/edit?configuration=nonexistent.yaml") + assert exc_info.value.code == 404 + + # Test getting a non-yaml file + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/edit?configuration=test.txt") + assert exc_info.value.code == 404 + + # Test path traversal attempt + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/edit?configuration=../../../etc/passwd") + assert exc_info.value.code == 404 + + +@pytest.mark.asyncio +async def test_archive_request_handler_post( + dashboard: DashboardTestHelper, + mock_archive_storage_path: MagicMock, + mock_ext_storage_path: MagicMock, + tmp_path: Path, +) -> None: + """Test ArchiveRequestHandler.post method.""" + + # Set up temp directories + config_dir = Path(get_fixture_path("conf")) + archive_dir = tmp_path / "archive" + + # Create a test configuration file + test_config = config_dir / "test_archive.yaml" + test_config.write_text("esphome:\n name: test_archive\n") + + # Archive the configuration + response = await dashboard.fetch( + "/archive", + method="POST", + body="configuration=test_archive.yaml", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + assert response.code == 200 + + # Verify file was moved to archive + assert not test_config.exists() + assert (archive_dir / "test_archive.yaml").exists() + assert ( + archive_dir / "test_archive.yaml" + ).read_text() == "esphome:\n name: test_archive\n" + + +@pytest.mark.skipif(os.name == "nt", reason="Unix sockets are not supported on Windows") +@pytest.mark.usefixtures("mock_trash_storage_path", "mock_archive_storage_path") +def test_start_web_server_with_unix_socket(tmp_path: Path) -> None: + """Test the start_web_server function with unix socket.""" + app = Mock() + socket_path = tmp_path / "test.sock" + + # Don't create trash_dir - it doesn't exist, so no migration needed + with ( + patch("tornado.httpserver.HTTPServer") as mock_server_class, + patch("tornado.netutil.bind_unix_socket") as mock_bind, + ): + server = Mock() + mock_server_class.return_value = server + mock_bind.return_value = Mock() + + web_server.start_web_server( + app, str(socket_path), None, None, str(tmp_path / "config") + ) + + mock_server_class.assert_called_once_with(app) + 
mock_bind.assert_called_once_with(str(socket_path), mode=0o666) + server.add_socket.assert_called_once() diff --git a/tests/dashboard/test_web_server_paths.py b/tests/dashboard/test_web_server_paths.py new file mode 100644 index 0000000000..f66e6a7ec2 --- /dev/null +++ b/tests/dashboard/test_web_server_paths.py @@ -0,0 +1,230 @@ +"""Tests for dashboard web_server Path-related functionality.""" + +from __future__ import annotations + +import gzip +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +from esphome.dashboard import web_server + + +def test_get_base_frontend_path_production() -> None: + """Test get_base_frontend_path in production mode.""" + mock_module = MagicMock() + mock_module.where.return_value = "/usr/local/lib/esphome_dashboard" + + with ( + patch.dict(os.environ, {}, clear=True), + patch.dict("sys.modules", {"esphome_dashboard": mock_module}), + ): + result = web_server.get_base_frontend_path() + assert result == "/usr/local/lib/esphome_dashboard" + mock_module.where.assert_called_once() + + +def test_get_base_frontend_path_dev_mode() -> None: + """Test get_base_frontend_path in development mode.""" + test_path = "/home/user/esphome/dashboard" + + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": test_path}): + result = web_server.get_base_frontend_path() + + # The function uses os.path.abspath which doesn't resolve symlinks + # We need to match that behavior + # The actual function adds "/" to the path, so we simulate that + test_path_with_slash = test_path if test_path.endswith("/") else test_path + "/" + expected = os.path.abspath( + os.path.join(os.getcwd(), test_path_with_slash, "esphome_dashboard") + ) + assert result == expected + + +def test_get_base_frontend_path_dev_mode_with_trailing_slash() -> None: + """Test get_base_frontend_path in dev mode with trailing slash.""" + test_path = "/home/user/esphome/dashboard/" + + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": test_path}): + result = web_server.get_base_frontend_path() + + # The function uses os.path.abspath which doesn't resolve symlinks + expected = os.path.abspath(str(Path.cwd() / test_path / "esphome_dashboard")) + assert result == expected + + +def test_get_base_frontend_path_dev_mode_relative_path() -> None: + """Test get_base_frontend_path with relative dev path.""" + test_path = "./dashboard" + + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": test_path}): + result = web_server.get_base_frontend_path() + + # The function uses os.path.abspath which doesn't resolve symlinks + # We need to match that behavior + # The actual function adds "/" to the path, so we simulate that + test_path_with_slash = test_path if test_path.endswith("/") else test_path + "/" + expected = os.path.abspath( + os.path.join(os.getcwd(), test_path_with_slash, "esphome_dashboard") + ) + assert result == expected + assert Path(result).is_absolute() + + +def test_get_static_path_single_component() -> None: + """Test get_static_path with single path component.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = "/base/frontend" + + result = web_server.get_static_path("file.js") + + assert result == os.path.join("/base/frontend", "static", "file.js") + + +def test_get_static_path_multiple_components() -> None: + """Test get_static_path with multiple path components.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = "/base/frontend" + + result = 
web_server.get_static_path("js", "esphome", "index.js") + + assert result == os.path.join( + "/base/frontend", "static", "js", "esphome", "index.js" + ) + + +def test_get_static_path_empty_args() -> None: + """Test get_static_path with no arguments.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = "/base/frontend" + + result = web_server.get_static_path() + + assert result == os.path.join("/base/frontend", "static") + + +def test_get_static_path_with_pathlib_path() -> None: + """Test get_static_path with Path objects.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = "/base/frontend" + + path_obj = Path("js") / "app.js" + result = web_server.get_static_path(str(path_obj)) + + assert result == os.path.join("/base/frontend", "static", "js", "app.js") + + +def test_get_static_file_url_production() -> None: + """Test get_static_file_url in production mode.""" + web_server.get_static_file_url.cache_clear() + mock_module = MagicMock() + mock_file = MagicMock() + mock_file.read.return_value = b"test content" + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=None) + + with ( + patch.dict(os.environ, {}, clear=True), + patch.dict("sys.modules", {"esphome_dashboard": mock_module}), + patch("esphome.dashboard.web_server.get_static_path") as mock_get_path, + patch("esphome.dashboard.web_server.open", create=True, return_value=mock_file), + ): + mock_get_path.return_value = "/fake/path/js/app.js" + result = web_server.get_static_file_url("js/app.js") + assert result.startswith("./static/js/app.js?hash=") + + +def test_get_static_file_url_dev_mode() -> None: + """Test get_static_file_url in development mode.""" + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": "/dev/path"}): + web_server.get_static_file_url.cache_clear() + result = web_server.get_static_file_url("js/app.js") + + assert result == "./static/js/app.js" + + +def test_get_static_file_url_index_js_special_case() -> None: + """Test get_static_file_url replaces index.js with entrypoint.""" + web_server.get_static_file_url.cache_clear() + mock_module = MagicMock() + mock_module.entrypoint.return_value = "main.js" + + with ( + patch.dict(os.environ, {}, clear=True), + patch.dict("sys.modules", {"esphome_dashboard": mock_module}), + ): + result = web_server.get_static_file_url("js/esphome/index.js") + assert result == "./static/js/esphome/main.js" + + +def test_load_file_path(tmp_path: Path) -> None: + """Test loading a file.""" + test_file = tmp_path / "test.txt" + test_file.write_bytes(b"test content") + + with open(test_file, "rb") as f: + content = f.read() + assert content == b"test content" + + +def test_load_file_compressed_path(tmp_path: Path) -> None: + """Test loading a compressed file.""" + test_file = tmp_path / "test.txt.gz" + + with gzip.open(test_file, "wb") as gz: + gz.write(b"compressed content") + + with gzip.open(test_file, "rb") as gz: + content = gz.read() + assert content == b"compressed content" + + +def test_path_normalization_in_static_path() -> None: + """Test that paths are normalized correctly.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = "/base/frontend" + + # Test with separate components + result1 = web_server.get_static_path("js", "app.js") + result2 = web_server.get_static_path("js", "app.js") + + assert result1 == result2 + assert result1 == 
os.path.join("/base/frontend", "static", "js", "app.js") + + +def test_windows_path_handling() -> None: + """Test handling of Windows-style paths.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = r"C:\Program Files\esphome\frontend" + + result = web_server.get_static_path("js", "app.js") + + # os.path.join should handle this correctly on the platform + expected = os.path.join( + r"C:\Program Files\esphome\frontend", "static", "js", "app.js" + ) + assert result == expected + + +def test_path_with_special_characters() -> None: + """Test paths with special characters.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = "/base/frontend" + + result = web_server.get_static_path("js-modules", "app_v1.0.js") + + assert result == os.path.join( + "/base/frontend", "static", "js-modules", "app_v1.0.js" + ) + + +def test_path_with_spaces() -> None: + """Test paths with spaces.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = "/base/my frontend" + + result = web_server.get_static_path("my js", "my app.js") + + assert result == os.path.join( + "/base/my frontend", "static", "my js", "my app.js" + ) diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/.hidden.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/.hidden.yaml new file mode 100644 index 0000000000..75eb989ea5 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/.hidden.yaml @@ -0,0 +1,3 @@ +# This file should be ignored +platform: template +name: "Hidden Sensor" diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/not_yaml.txt b/tests/unit_tests/fixtures/yaml_util/named_dir/not_yaml.txt new file mode 100644 index 0000000000..98efb74b0f --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/not_yaml.txt @@ -0,0 +1 @@ +This is not a YAML file and should be ignored diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/sensor1.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor1.yaml new file mode 100644 index 0000000000..a4b0a11916 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor1.yaml @@ -0,0 +1,4 @@ +platform: template +name: "Sensor 1" +lambda: |- + return 42.0; diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/sensor2.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor2.yaml new file mode 100644 index 0000000000..72d4b714b6 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor2.yaml @@ -0,0 +1,4 @@ +platform: template +name: "Sensor 2" +lambda: |- + return 100.0; diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/subdir/sensor3.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/subdir/sensor3.yaml new file mode 100644 index 0000000000..bcb8dd320d --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/subdir/sensor3.yaml @@ -0,0 +1,4 @@ +platform: template +name: "Sensor 3 in subdir" +lambda: |- + return 200.0; diff --git a/tests/unit_tests/fixtures/yaml_util/secrets.yaml b/tests/unit_tests/fixtures/yaml_util/secrets.yaml new file mode 100644 index 0000000000..4eef570926 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/secrets.yaml @@ -0,0 +1,4 @@ +test_secret: "my_secret_value" +another_secret: "another_value" +wifi_password: "super_secret_wifi" +api_key: "0123456789abcdef" diff --git a/tests/unit_tests/fixtures/yaml_util/test_secret.yaml b/tests/unit_tests/fixtures/yaml_util/test_secret.yaml new file mode 100644 index 
0000000000..c23afaee94 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/test_secret.yaml @@ -0,0 +1,17 @@ +esphome: + name: test_device + platform: ESP32 + board: esp32dev + +wifi: + ssid: "TestNetwork" + password: !secret wifi_password + +api: + encryption: + key: !secret api_key + +sensor: + - platform: template + name: "Test Sensor" + id: !secret test_secret diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py new file mode 100644 index 0000000000..bfebb44545 --- /dev/null +++ b/tests/unit_tests/test_main.py @@ -0,0 +1,1533 @@ +"""Unit tests for esphome.__main__ module.""" + +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest +from pytest import CaptureFixture + +from esphome.__main__ import ( + Purpose, + choose_upload_log_host, + command_rename, + command_wizard, + get_port_type, + has_ip_address, + has_mqtt, + has_mqtt_ip_lookup, + has_mqtt_logging, + has_non_ip_address, + has_resolvable_address, + mqtt_get_ip, + show_logs, + upload_program, +) +from esphome.const import ( + CONF_API, + CONF_BROKER, + CONF_DISABLED, + CONF_ESPHOME, + CONF_LEVEL, + CONF_LOG_TOPIC, + CONF_MDNS, + CONF_MQTT, + CONF_NAME, + CONF_OTA, + CONF_PASSWORD, + CONF_PLATFORM, + CONF_PORT, + CONF_SUBSTITUTIONS, + CONF_TOPIC, + CONF_USE_ADDRESS, + CONF_WIFI, + KEY_CORE, + KEY_TARGET_PLATFORM, + PLATFORM_BK72XX, + PLATFORM_ESP32, + PLATFORM_ESP8266, + PLATFORM_RP2040, +) +from esphome.core import CORE, EsphomeError + + +@dataclass +class MockSerialPort: + """Mock serial port for testing. + + Attributes: + path (str): The device path of the mock serial port (e.g., '/dev/ttyUSB0'). + description (str): A human-readable description of the mock serial port. + """ + + path: str + description: str + + +def setup_core( + config: dict[str, Any] | None = None, + address: str | None = None, + platform: str | None = None, + tmp_path: Path | None = None, + name: str = "test", +) -> None: + """ + Helper to set up CORE configuration with optional address. + + Args: + config (dict[str, Any] | None): The configuration dictionary to set for CORE. If None, an empty dict is used. + address (str | None): Optional network address to set in the configuration. If provided, it is set under the wifi config. + platform (str | None): Optional target platform to set in CORE.data. + tmp_path (Path | None): Optional temp path for setting up build paths. + name (str): The name of the device (defaults to "test"). 
+ """ + if config is None: + config = {} + + if address is not None: + # Set address via wifi config (could also use ethernet) + config[CONF_WIFI] = {CONF_USE_ADDRESS: address} + + CORE.config = config + + if platform is not None: + CORE.data[KEY_CORE] = {} + CORE.data[KEY_CORE][KEY_TARGET_PLATFORM] = platform + + if tmp_path is not None: + CORE.config_path = str(tmp_path / f"{name}.yaml") + CORE.name = name + CORE.build_path = str(tmp_path / ".esphome" / "build" / name) + + +@pytest.fixture +def mock_no_serial_ports() -> Generator[Mock]: + """Mock get_serial_ports to return no ports.""" + with patch("esphome.__main__.get_serial_ports", return_value=[]) as mock: + yield mock + + +@pytest.fixture +def mock_get_port_type() -> Generator[Mock]: + """Mock get_port_type for testing.""" + with patch("esphome.__main__.get_port_type") as mock: + yield mock + + +@pytest.fixture +def mock_check_permissions() -> Generator[Mock]: + """Mock check_permissions for testing.""" + with patch("esphome.__main__.check_permissions") as mock: + yield mock + + +@pytest.fixture +def mock_run_miniterm() -> Generator[Mock]: + """Mock run_miniterm for testing.""" + with patch("esphome.__main__.run_miniterm") as mock: + yield mock + + +@pytest.fixture +def mock_upload_using_esptool() -> Generator[Mock]: + """Mock upload_using_esptool for testing.""" + with patch("esphome.__main__.upload_using_esptool") as mock: + yield mock + + +@pytest.fixture +def mock_upload_using_platformio() -> Generator[Mock]: + """Mock upload_using_platformio for testing.""" + with patch("esphome.__main__.upload_using_platformio") as mock: + yield mock + + +@pytest.fixture +def mock_run_ota() -> Generator[Mock]: + """Mock espota2.run_ota for testing.""" + with patch("esphome.espota2.run_ota") as mock: + yield mock + + +@pytest.fixture +def mock_is_ip_address() -> Generator[Mock]: + """Mock is_ip_address for testing.""" + with patch("esphome.__main__.is_ip_address") as mock: + yield mock + + +@pytest.fixture +def mock_mqtt_get_ip() -> Generator[Mock]: + """Mock mqtt_get_ip for testing.""" + with patch("esphome.__main__.mqtt_get_ip") as mock: + yield mock + + +@pytest.fixture +def mock_serial_ports() -> Generator[Mock]: + """Mock get_serial_ports to return test ports.""" + mock_ports = [ + MockSerialPort("/dev/ttyUSB0", "USB Serial"), + MockSerialPort("/dev/ttyUSB1", "Another USB Serial"), + ] + with patch("esphome.__main__.get_serial_ports", return_value=mock_ports) as mock: + yield mock + + +@pytest.fixture +def mock_choose_prompt() -> Generator[Mock]: + """Mock choose_prompt to return default selection.""" + with patch("esphome.__main__.choose_prompt", return_value="/dev/ttyUSB0") as mock: + yield mock + + +@pytest.fixture +def mock_no_mqtt_logging() -> Generator[Mock]: + """Mock has_mqtt_logging to return False.""" + with patch("esphome.__main__.has_mqtt_logging", return_value=False) as mock: + yield mock + + +@pytest.fixture +def mock_has_mqtt_logging() -> Generator[Mock]: + """Mock has_mqtt_logging to return True.""" + with patch("esphome.__main__.has_mqtt_logging", return_value=True) as mock: + yield mock + + +@pytest.fixture +def mock_run_external_process() -> Generator[Mock]: + """Mock run_external_process for testing.""" + with patch("esphome.__main__.run_external_process") as mock: + mock.return_value = 0 # Default to success + yield mock + + +def test_choose_upload_log_host_with_string_default() -> None: + """Test with a single string default device.""" + setup_core() + result = choose_upload_log_host( + default="192.168.1.100", + 
check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100"] + + +def test_choose_upload_log_host_with_list_default() -> None: + """Test with a list of default devices.""" + setup_core() + result = choose_upload_log_host( + default=["192.168.1.100", "192.168.1.101"], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100", "192.168.1.101"] + + +def test_choose_upload_log_host_with_multiple_ip_addresses() -> None: + """Test with multiple IP addresses as defaults.""" + setup_core() + result = choose_upload_log_host( + default=["1.2.3.4", "4.5.5.6"], + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["1.2.3.4", "4.5.5.6"] + + +def test_choose_upload_log_host_with_mixed_hostnames_and_ips() -> None: + """Test with a mix of hostnames and IP addresses.""" + setup_core() + result = choose_upload_log_host( + default=["host.one", "host.one.local", "1.2.3.4"], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["host.one", "host.one.local", "1.2.3.4"] + + +def test_choose_upload_log_host_with_ota_list() -> None: + """Test with OTA as the only item in the list.""" + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") + + result = choose_upload_log_host( + default=["OTA"], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100"] + + +@pytest.mark.usefixtures("mock_has_mqtt_logging") +def test_choose_upload_log_host_with_ota_list_mqtt_fallback() -> None: + """Test with OTA list falling back to MQTT when no address.""" + setup_core(config={CONF_OTA: {}, "mqtt": {}}) + + result = choose_upload_log_host( + default=["OTA"], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["MQTTIP"] + + +@pytest.mark.usefixtures("mock_has_mqtt_logging") +def test_choose_upload_log_host_with_ota_list_mqtt_fallback_logging() -> None: + """Test with OTA list with API and MQTT when no address.""" + setup_core(config={CONF_API: {}, "mqtt": {}}) + + result = choose_upload_log_host( + default=["OTA"], + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["MQTTIP", "MQTT"] + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def test_choose_upload_log_host_with_serial_device_no_ports( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test SERIAL device when no serial ports are found.""" + setup_core() + result = choose_upload_log_host( + default="SERIAL", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == [] + assert "No serial ports found, skipping SERIAL device" in caplog.text + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_with_serial_device_with_ports( + mock_choose_prompt: Mock, +) -> None: + """Test SERIAL device when serial ports are available.""" + setup_core() + result = choose_upload_log_host( + default="SERIAL", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["/dev/ttyUSB0"] + mock_choose_prompt.assert_called_once_with( + [ + ("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0"), + ("/dev/ttyUSB1 (Another USB Serial)", "/dev/ttyUSB1"), + ], + purpose=Purpose.UPLOADING, + ) + + +def test_choose_upload_log_host_with_ota_device_with_ota_config() -> None: + """Test OTA device when OTA is configured.""" + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100"] + + +def 
test_choose_upload_log_host_with_ota_device_with_api_config() -> None: + """Test OTA device when API is configured (no upload without OTA in config).""" + setup_core(config={CONF_API: {}}, address="192.168.1.100") + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == [] + + +def test_choose_upload_log_host_with_ota_device_with_api_config_logging() -> None: + """Test OTA device when API is configured.""" + setup_core(config={CONF_API: {}}, address="192.168.1.100") + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["192.168.1.100"] + + +@pytest.mark.usefixtures("mock_has_mqtt_logging") +def test_choose_upload_log_host_with_ota_device_fallback_to_mqtt() -> None: + """Test OTA device fallback to MQTT when no OTA/API config.""" + setup_core(config={"mqtt": {}}) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["MQTT"] + + +@pytest.mark.usefixtures("mock_no_mqtt_logging") +def test_choose_upload_log_host_with_ota_device_no_fallback() -> None: + """Test OTA device with no valid fallback options.""" + setup_core() + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == [] + + +@pytest.mark.usefixtures("mock_choose_prompt") +def test_choose_upload_log_host_multiple_devices() -> None: + """Test with multiple devices including special identifiers.""" + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") + + mock_ports = [MockSerialPort("/dev/ttyUSB0", "USB Serial")] + + with patch("esphome.__main__.get_serial_ports", return_value=mock_ports): + result = choose_upload_log_host( + default=["192.168.1.50", "OTA", "SERIAL"], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.50", "192.168.1.100", "/dev/ttyUSB0"] + + +def test_choose_upload_log_host_no_defaults_with_serial_ports( + mock_choose_prompt: Mock, +) -> None: + """Test interactive mode with serial ports available.""" + mock_ports = [ + MockSerialPort("/dev/ttyUSB0", "USB Serial"), + ] + + setup_core() + + with patch("esphome.__main__.get_serial_ports", return_value=mock_ports): + result = choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["/dev/ttyUSB0"] + mock_choose_prompt.assert_called_once_with( + [("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0")], + purpose=Purpose.UPLOADING, + ) + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def test_choose_upload_log_host_no_defaults_with_ota() -> None: + """Test interactive mode with OTA option.""" + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") + + with patch( + "esphome.__main__.choose_prompt", return_value="192.168.1.100" + ) as mock_prompt: + result = choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100"] + mock_prompt.assert_called_once_with( + [("Over The Air (192.168.1.100)", "192.168.1.100")], + purpose=Purpose.UPLOADING, + ) + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def test_choose_upload_log_host_no_defaults_with_api() -> None: + """Test interactive mode with API option.""" + setup_core(config={CONF_API: {}}, address="192.168.1.100") + + with patch( + "esphome.__main__.choose_prompt", return_value="192.168.1.100" + ) as mock_prompt: + result = choose_upload_log_host( + default=None, + 
check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["192.168.1.100"] + mock_prompt.assert_called_once_with( + [("Over The Air (192.168.1.100)", "192.168.1.100")], + purpose=Purpose.LOGGING, + ) + + +@pytest.mark.usefixtures("mock_no_serial_ports", "mock_has_mqtt_logging") +def test_choose_upload_log_host_no_defaults_with_mqtt() -> None: + """Test interactive mode with MQTT option.""" + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local"}}) + + with patch("esphome.__main__.choose_prompt", return_value="MQTT") as mock_prompt: + result = choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["MQTT"] + mock_prompt.assert_called_once_with( + [("MQTT (mqtt.local)", "MQTT")], + purpose=Purpose.LOGGING, + ) + + +@pytest.mark.usefixtures("mock_has_mqtt_logging") +def test_choose_upload_log_host_no_defaults_with_all_options( + mock_choose_prompt: Mock, +) -> None: + """Test interactive mode with all options available.""" + setup_core( + config={CONF_OTA: {}, CONF_API: {}, CONF_MQTT: {CONF_BROKER: "mqtt.local"}}, + address="192.168.1.100", + ) + + mock_ports = [MockSerialPort("/dev/ttyUSB0", "USB Serial")] + + with patch("esphome.__main__.get_serial_ports", return_value=mock_ports): + result = choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["/dev/ttyUSB0"] + + expected_options = [ + ("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0"), + ("Over The Air (192.168.1.100)", "192.168.1.100"), + ("Over The Air (MQTT IP lookup)", "MQTTIP"), + ] + mock_choose_prompt.assert_called_once_with( + expected_options, purpose=Purpose.UPLOADING + ) + + +def test_choose_upload_log_host_no_defaults_with_all_options_logging( + mock_choose_prompt: Mock, +) -> None: + """Test interactive mode with all options available.""" + setup_core( + config={CONF_OTA: {}, CONF_API: {}, CONF_MQTT: {CONF_BROKER: "mqtt.local"}}, + address="192.168.1.100", + ) + + mock_ports = [MockSerialPort("/dev/ttyUSB0", "USB Serial")] + + with patch("esphome.__main__.get_serial_ports", return_value=mock_ports): + result = choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["/dev/ttyUSB0"] + + expected_options = [ + ("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0"), + ("MQTT (mqtt.local)", "MQTT"), + ("Over The Air (192.168.1.100)", "192.168.1.100"), + ("Over The Air (MQTT IP lookup)", "MQTTIP"), + ] + mock_choose_prompt.assert_called_once_with( + expected_options, purpose=Purpose.LOGGING + ) + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def test_choose_upload_log_host_check_default_matches() -> None: + """Test when check_default matches an available option.""" + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") + + result = choose_upload_log_host( + default=None, + check_default="192.168.1.100", + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100"] + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def test_choose_upload_log_host_check_default_no_match() -> None: + """Test when check_default doesn't match any available option.""" + setup_core() + + with patch( + "esphome.__main__.choose_prompt", return_value="fallback" + ) as mock_prompt: + result = choose_upload_log_host( + default=None, + check_default="192.168.1.100", + purpose=Purpose.UPLOADING, + ) + assert result == ["fallback"] + mock_prompt.assert_called_once() + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def 
test_choose_upload_log_host_empty_defaults_list() -> None: + """Test with an empty list as default.""" + setup_core() + with patch("esphome.__main__.choose_prompt", return_value="chosen") as mock_prompt: + result = choose_upload_log_host( + default=[], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["chosen"] + mock_prompt.assert_called_once() + + +@pytest.mark.usefixtures("mock_no_serial_ports", "mock_no_mqtt_logging") +def test_choose_upload_log_host_all_devices_unresolved( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test when all specified devices cannot be resolved.""" + setup_core() + + result = choose_upload_log_host( + default=["SERIAL", "OTA"], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == [] + assert ( + "All specified devices: ['SERIAL', 'OTA'] could not be resolved." in caplog.text + ) + + +@pytest.mark.usefixtures("mock_no_serial_ports", "mock_no_mqtt_logging") +def test_choose_upload_log_host_mixed_resolved_unresolved() -> None: + """Test with a mix of resolved and unresolved devices.""" + setup_core() + + result = choose_upload_log_host( + default=["192.168.1.50", "SERIAL", "OTA"], + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.50"] + + +def test_choose_upload_log_host_ota_both_conditions() -> None: + """Test OTA device when both OTA and API are configured and enabled.""" + setup_core(config={CONF_OTA: {}, CONF_API: {}}, address="192.168.1.100") + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100"] + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_ip_all_options() -> None: + """Test OTA upload with a static IP address when OTA, API and MQTT are configured but mDNS is disabled.""" + setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="192.168.1.100", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100", "MQTTIP"] + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_local_all_options() -> None: + """Test OTA upload with a .local address when OTA, API and MQTT are configured but mDNS is disabled.""" + setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="test.local", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["MQTTIP", "test.local"] + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_ip_all_options_logging() -> None: + """Test OTA logging with a static IP address when OTA, API and MQTT are configured but mDNS is disabled.""" + setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="192.168.1.100", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["192.168.1.100", "MQTTIP", "MQTT"] + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_local_all_options_logging() -> None: + """Test OTA logging with a .local address when OTA, API and MQTT are configured but mDNS is disabled.""" + 
setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="test.local", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["MQTTIP", "MQTT", "test.local"] + + +@pytest.mark.usefixtures("mock_no_mqtt_logging") +def test_choose_upload_log_host_no_address_with_ota_config() -> None: + """Test OTA device when OTA is configured but no address is set.""" + setup_core(config={CONF_OTA: {}}) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == [] + + +@dataclass +class MockArgs: + """Mock args for testing.""" + + file: str | None = None + upload_speed: int = 460800 + username: str | None = None + password: str | None = None + client_id: str | None = None + topic: str | None = None + configuration: str | None = None + name: str | None = None + dashboard: bool = False + + +def test_upload_program_serial_esp32( + mock_upload_using_esptool: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, +) -> None: + """Test upload_program with serial port for ESP32.""" + setup_core(platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "SERIAL" + mock_upload_using_esptool.return_value = 0 + + config = {} + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "/dev/ttyUSB0" + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_upload_using_esptool.assert_called_once() + + +def test_upload_program_serial_esp8266_with_file( + mock_upload_using_esptool: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, +) -> None: + """Test upload_program with serial port for ESP8266 with custom file.""" + setup_core(platform=PLATFORM_ESP8266) + mock_get_port_type.return_value = "SERIAL" + mock_upload_using_esptool.return_value = 0 + + config = {} + args = MockArgs(file="firmware.bin") + devices = ["/dev/ttyUSB0"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "/dev/ttyUSB0" + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_upload_using_esptool.assert_called_once_with( + config, "/dev/ttyUSB0", "firmware.bin", 460800 + ) + + +@pytest.mark.parametrize( + "platform,device", + [ + (PLATFORM_RP2040, "/dev/ttyACM0"), + (PLATFORM_BK72XX, "/dev/ttyUSB0"), # LibreTiny platform + ], +) +def test_upload_program_serial_platformio_platforms( + mock_upload_using_platformio: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, + platform: str, + device: str, +) -> None: + """Test upload_program with serial port for platformio platforms (RP2040/LibreTiny).""" + setup_core(platform=platform) + mock_get_port_type.return_value = "SERIAL" + mock_upload_using_platformio.return_value = 0 + + config = {} + args = MockArgs() + devices = [device] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == device + mock_check_permissions.assert_called_once_with(device) + mock_upload_using_platformio.assert_called_once_with(config, device) + + +def test_upload_program_serial_upload_failed( + mock_upload_using_esptool: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, +) -> None: + """Test upload_program when serial upload fails.""" + setup_core(platform=PLATFORM_ESP32) + mock_get_port_type.return_value = 
"SERIAL" + mock_upload_using_esptool.return_value = 1 # Failed + + config = {} + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 1 + assert host is None + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_upload_using_esptool.assert_called_once() + + +def test_upload_program_ota_success( + mock_run_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """Test upload_program with OTA.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + + mock_get_port_type.return_value = "NETWORK" + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + CONF_PASSWORD: "secret", + } + ] + } + args = MockArgs() + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + expected_firmware = str( + tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" + ) + mock_run_ota.assert_called_once_with( + ["192.168.1.100"], 3232, "secret", expected_firmware + ) + + +def test_upload_program_ota_with_file_arg( + mock_run_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """Test upload_program with OTA and custom file.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + + mock_get_port_type.return_value = "NETWORK" + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + } + ] + } + args = MockArgs(file="custom.bin") + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + mock_run_ota.assert_called_once_with(["192.168.1.100"], 3232, "", "custom.bin") + + +def test_upload_program_ota_no_config( + mock_get_port_type: Mock, +) -> None: + """Test upload_program with OTA but no OTA config.""" + setup_core(platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "NETWORK" + + config = {} # No OTA config + args = MockArgs() + devices = ["192.168.1.100"] + + with pytest.raises(EsphomeError, match="Cannot upload Over the Air"): + upload_program(config, args, devices) + + +def test_upload_program_ota_with_mqtt_resolution( + mock_mqtt_get_ip: Mock, + mock_is_ip_address: Mock, + mock_run_ota: Mock, + tmp_path: Path, +) -> None: + """Test upload_program with OTA using MQTT for address resolution.""" + setup_core(address="device.local", platform=PLATFORM_ESP32, tmp_path=tmp_path) + + mock_is_ip_address.return_value = False + mock_mqtt_get_ip.return_value = ["192.168.1.100"] + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + } + ], + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + } + args = MockArgs(username="user", password="pass", client_id="client") + devices = ["MQTT"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + mock_mqtt_get_ip.assert_called_once_with(config, "user", "pass", "client") + expected_firmware = str( + tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" + ) + mock_run_ota.assert_called_once_with(["192.168.1.100"], 3232, "", expected_firmware) + + +@patch("esphome.__main__.importlib.import_module") +def test_upload_program_platform_specific_handler( + mock_import: 
Mock, + mock_get_port_type: Mock, +) -> None: + """Test upload_program with platform-specific upload handler.""" + setup_core(platform="custom_platform") + mock_get_port_type.return_value = "CUSTOM" + + mock_module = MagicMock() + mock_module.upload_program.return_value = True + mock_import.return_value = mock_module + + config = {} + args = MockArgs() + devices = ["custom_device"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "custom_device" + mock_import.assert_called_once_with("esphome.components.custom_platform") + mock_module.upload_program.assert_called_once_with(config, args, "custom_device") + + +def test_show_logs_serial( + mock_get_port_type: Mock, + mock_check_permissions: Mock, + mock_run_miniterm: Mock, +) -> None: + """Test show_logs with serial port.""" + setup_core(config={"logger": {}}, platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "SERIAL" + mock_run_miniterm.return_value = 0 + + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_run_miniterm.assert_called_once_with(CORE.config, "/dev/ttyUSB0", args) + + +def test_show_logs_no_logger() -> None: + """Test show_logs when logger is not configured.""" + setup_core(config={}, platform=PLATFORM_ESP32) # No logger config + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + with pytest.raises(EsphomeError, match="Logger is not configured"): + show_logs(CORE.config, args, devices) + + +@patch("esphome.components.api.client.run_logs") +def test_show_logs_api( + mock_run_logs: Mock, +) -> None: + """Test show_logs with API.""" + setup_core( + config={ + "logger": {}, + CONF_API: {}, + CONF_MDNS: {CONF_DISABLED: False}, + }, + platform=PLATFORM_ESP32, + ) + mock_run_logs.return_value = 0 + + args = MockArgs() + devices = ["192.168.1.100", "192.168.1.101"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_run_logs.assert_called_once_with( + CORE.config, ["192.168.1.100", "192.168.1.101"] + ) + + +@patch("esphome.components.api.client.run_logs") +def test_show_logs_api_with_mqtt_fallback( + mock_run_logs: Mock, + mock_mqtt_get_ip: Mock, +) -> None: + """Test show_logs with API using MQTT for address resolution.""" + setup_core( + config={ + "logger": {}, + CONF_API: {}, + CONF_MDNS: {CONF_DISABLED: True}, + CONF_MQTT: {CONF_BROKER: "mqtt.local"}, + }, + platform=PLATFORM_ESP32, + ) + mock_run_logs.return_value = 0 + mock_mqtt_get_ip.return_value = ["192.168.1.200"] + + args = MockArgs(username="user", password="pass", client_id="client") + devices = ["device.local"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_mqtt_get_ip.assert_called_once_with(CORE.config, "user", "pass", "client") + mock_run_logs.assert_called_once_with(CORE.config, ["192.168.1.200"]) + + +@patch("esphome.mqtt.show_logs") +def test_show_logs_mqtt( + mock_mqtt_show_logs: Mock, +) -> None: + """Test show_logs with MQTT.""" + setup_core( + config={ + "logger": {}, + "mqtt": {CONF_BROKER: "mqtt.local"}, + }, + platform=PLATFORM_ESP32, + ) + mock_mqtt_show_logs.return_value = 0 + + args = MockArgs( + topic="esphome/logs", + username="user", + password="pass", + client_id="client", + ) + devices = ["MQTT"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_mqtt_show_logs.assert_called_once_with( + CORE.config, "esphome/logs", "user", "pass", "client" + ) + + 
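+# The show_logs tests above and below cover each transport path exercised here: SERIAL via miniterm, NETWORK via the native API (with MQTT IP lookup when mDNS is disabled), MQTT log subscription when MQTT is the only remote option, plus the error and platform-handler cases.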
+@patch("esphome.mqtt.show_logs") +def test_show_logs_network_with_mqtt_only( + mock_mqtt_show_logs: Mock, +) -> None: + """Test show_logs with network port but only MQTT configured.""" + setup_core( + config={ + "logger": {}, + "mqtt": {CONF_BROKER: "mqtt.local"}, + # No API configured + }, + platform=PLATFORM_ESP32, + ) + mock_mqtt_show_logs.return_value = 0 + + args = MockArgs( + topic="esphome/logs", + username="user", + password="pass", + client_id="client", + ) + devices = ["192.168.1.100"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_mqtt_show_logs.assert_called_once_with( + CORE.config, "esphome/logs", "user", "pass", "client" + ) + + +def test_show_logs_no_method_configured() -> None: + """Test show_logs when no remote logging method is configured.""" + setup_core( + config={ + "logger": {}, + # No API or MQTT configured + }, + platform=PLATFORM_ESP32, + ) + + args = MockArgs() + devices = ["192.168.1.100"] + + with pytest.raises( + EsphomeError, match="No remote or local logging method configured" + ): + show_logs(CORE.config, args, devices) + + +@patch("esphome.__main__.importlib.import_module") +def test_show_logs_platform_specific_handler( + mock_import: Mock, +) -> None: + """Test show_logs with platform-specific logs handler.""" + setup_core(platform="custom_platform", config={"logger": {}}) + + mock_module = MagicMock() + mock_module.show_logs.return_value = True + mock_import.return_value = mock_module + + config = {"logger": {}} + args = MockArgs() + devices = ["custom_device"] + + result = show_logs(config, args, devices) + + assert result == 0 + mock_import.assert_called_once_with("esphome.components.custom_platform") + mock_module.show_logs.assert_called_once_with(config, args, devices) + + +def test_has_mqtt_logging_no_log_topic() -> None: + """Test has_mqtt_logging across the possible log_topic configurations.""" + + # Setup MQTT config without CONF_LOG_TOPIC (logging is enabled by default) + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local"}}) + assert has_mqtt_logging() is True + + # Setup MQTT config with CONF_LOG_TOPIC set to None (explicitly disabled) + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local", CONF_LOG_TOPIC: None}}) + assert has_mqtt_logging() is False + + # Setup MQTT config with CONF_LOG_TOPIC set with topic and level (explicitly enabled) + setup_core( + config={ + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + CONF_LOG_TOPIC: {CONF_TOPIC: "esphome/logs", CONF_LEVEL: "DEBUG"}, + } + } + ) + assert has_mqtt_logging() is True + + # Setup MQTT config with CONF_LOG_TOPIC set but level is NONE (disabled) + setup_core( + config={ + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + CONF_LOG_TOPIC: {CONF_TOPIC: "esphome/logs", CONF_LEVEL: "NONE"}, + } + } + ) + assert has_mqtt_logging() is False + + # Setup without MQTT config at all + setup_core(config={}) + assert has_mqtt_logging() is False + + +def test_has_mqtt() -> None: + """Test has_mqtt function.""" + + # Test with MQTT configured + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local"}}) + assert has_mqtt() is True + + # Test without MQTT configured + setup_core(config={}) + assert has_mqtt() is False + + # Test with other components but no MQTT + setup_core(config={CONF_API: {}, CONF_OTA: {}}) + assert has_mqtt() is False + + +def test_get_port_type() -> None: + """Test get_port_type function.""" + + assert get_port_type("/dev/ttyUSB0") == "SERIAL" + assert get_port_type("/dev/ttyACM0") == "SERIAL" + assert 
get_port_type("COM1") == "SERIAL" + assert get_port_type("COM10") == "SERIAL" + + assert get_port_type("MQTT") == "MQTT" + assert get_port_type("MQTTIP") == "MQTTIP" + + assert get_port_type("192.168.1.100") == "NETWORK" + assert get_port_type("esphome-device.local") == "NETWORK" + assert get_port_type("10.0.0.1") == "NETWORK" + + +def test_has_mqtt_ip_lookup() -> None: + """Test has_mqtt_ip_lookup function.""" + + CONF_DISCOVER_IP = "discover_ip" + + setup_core(config={}) + assert has_mqtt_ip_lookup() is False + + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local"}}) + assert has_mqtt_ip_lookup() is True + + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local", CONF_DISCOVER_IP: True}}) + assert has_mqtt_ip_lookup() is True + + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local", CONF_DISCOVER_IP: False}}) + assert has_mqtt_ip_lookup() is False + + +def test_has_non_ip_address() -> None: + """Test has_non_ip_address function.""" + + setup_core(address=None) + assert has_non_ip_address() is False + + setup_core(address="192.168.1.100") + assert has_non_ip_address() is False + + setup_core(address="10.0.0.1") + assert has_non_ip_address() is False + + setup_core(address="esphome-device.local") + assert has_non_ip_address() is True + + setup_core(address="my-device") + assert has_non_ip_address() is True + + +def test_has_ip_address() -> None: + """Test has_ip_address function.""" + + setup_core(address=None) + assert has_ip_address() is False + + setup_core(address="192.168.1.100") + assert has_ip_address() is True + + setup_core(address="10.0.0.1") + assert has_ip_address() is True + + setup_core(address="esphome-device.local") + assert has_ip_address() is False + + setup_core(address="my-device") + assert has_ip_address() is False + + +def test_mqtt_get_ip() -> None: + """Test mqtt_get_ip function.""" + config = {CONF_MQTT: {CONF_BROKER: "mqtt.local"}} + + with patch("esphome.mqtt.get_esphome_device_ip") as mock_get_ip: + mock_get_ip.return_value = ["192.168.1.100", "192.168.1.101"] + + result = mqtt_get_ip(config, "user", "pass", "client-id") + + assert result == ["192.168.1.100", "192.168.1.101"] + mock_get_ip.assert_called_once_with(config, "user", "pass", "client-id") + + +def test_has_resolvable_address() -> None: + """Test has_resolvable_address function.""" + + # Test with mDNS enabled and hostname address + setup_core(config={}, address="esphome-device.local") + assert has_resolvable_address() is True + + # Test with mDNS disabled and hostname address + setup_core( + config={CONF_MDNS: {CONF_DISABLED: True}}, address="esphome-device.local" + ) + assert has_resolvable_address() is False + + # Test with IP address (mDNS doesn't matter) + setup_core(config={}, address="192.168.1.100") + assert has_resolvable_address() is True + + # Test with IP address and mDNS disabled + setup_core(config={CONF_MDNS: {CONF_DISABLED: True}}, address="192.168.1.100") + assert has_resolvable_address() is True + + # Test with no address but mDNS enabled (can still resolve mDNS names) + setup_core(config={}, address=None) + assert has_resolvable_address() is True + + # Test with no address and mDNS disabled + setup_core(config={CONF_MDNS: {CONF_DISABLED: True}}, address=None) + assert has_resolvable_address() is False + + +def test_command_wizard(tmp_path: Path) -> None: + """Test command_wizard function.""" + config_file = tmp_path / "test.yaml" + + # Mock wizard.wizard to avoid interactive prompts + with patch("esphome.wizard.wizard") as mock_wizard: + mock_wizard.return_value = 0 + + 
args = MockArgs(configuration=str(config_file)) + result = command_wizard(args) + + assert result == 0 + mock_wizard.assert_called_once_with(str(config_file)) + + +def test_command_rename_invalid_characters( + tmp_path: Path, capfd: CaptureFixture[str] +) -> None: + """Test command_rename with invalid characters in name.""" + setup_core(tmp_path=tmp_path) + + # Test with invalid character (space) + args = MockArgs(name="invalid name") + result = command_rename(args, {}) + + assert result == 1 + captured = capfd.readouterr() + assert "invalid character" in captured.out.lower() + + +def test_command_rename_complex_yaml( + tmp_path: Path, capfd: CaptureFixture[str] +) -> None: + """Test command_rename with complex YAML that cannot be renamed.""" + config_file = tmp_path / "test.yaml" + config_file.write_text("# Complex YAML without esphome section\nsome_key: value\n") + setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + args = MockArgs(name="newname") + result = command_rename(args, {}) + + assert result == 1 + captured = capfd.readouterr() + assert "complex yaml" in captured.out.lower() + + +def test_command_rename_success( + tmp_path: Path, + capfd: CaptureFixture[str], + mock_run_external_process: Mock, +) -> None: + """Test successful rename of a simple configuration.""" + config_file = tmp_path / "oldname.yaml" + config_file.write_text(""" +esphome: + name: oldname + +esp32: + board: nodemcu-32s + +wifi: + ssid: "test" + password: "test1234" +""") + setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + # Set up CORE.config to avoid ValueError when accessing CORE.address + CORE.config = {CONF_ESPHOME: {CONF_NAME: "oldname"}} + + args = MockArgs(name="newname", dashboard=False) + + # Simulate successful validation and upload + mock_run_external_process.return_value = 0 + + result = command_rename(args, {}) + + assert result == 0 + + # Verify new file was created + new_file = tmp_path / "newname.yaml" + assert new_file.exists() + + # Verify old file was removed + assert not config_file.exists() + + # Verify content was updated + content = new_file.read_text() + assert ( + 'name: "newname"' in content + or "name: 'newname'" in content + or "name: newname" in content + ) + + captured = capfd.readouterr() + assert "SUCCESS" in captured.out + + +def test_command_rename_with_substitutions( + tmp_path: Path, + mock_run_external_process: Mock, +) -> None: + """Test rename with substitutions in YAML.""" + config_file = tmp_path / "oldname.yaml" + config_file.write_text(""" +substitutions: + device_name: oldname + +esphome: + name: ${device_name} + +esp32: + board: nodemcu-32s +""") + setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + # Set up CORE.config to avoid ValueError when accessing CORE.address + CORE.config = { + CONF_ESPHOME: {CONF_NAME: "oldname"}, + CONF_SUBSTITUTIONS: {"device_name": "oldname"}, + } + + args = MockArgs(name="newname", dashboard=False) + + mock_run_external_process.return_value = 0 + + result = command_rename(args, {}) + + assert result == 0 + + # Verify substitution was updated + new_file = tmp_path / "newname.yaml" + content = new_file.read_text() + assert 'device_name: "newname"' in content + + +def test_command_rename_validation_failure( + tmp_path: Path, + capfd: CaptureFixture[str], + mock_run_external_process: Mock, +) -> None: + """Test rename when validation fails.""" + config_file = tmp_path / "oldname.yaml" + config_file.write_text(""" +esphome: + name: oldname + +esp32: + board: nodemcu-32s +""") + 
setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + args = MockArgs(name="newname", dashboard=False) + + # First call for validation fails + mock_run_external_process.return_value = 1 + + result = command_rename(args, {}) + + assert result == 1 + + # Verify new file was created but then removed due to failure + new_file = tmp_path / "newname.yaml" + assert not new_file.exists() + + # Verify old file still exists (not removed on failure) + assert config_file.exists() + + captured = capfd.readouterr() + assert "Rename failed" in captured.out diff --git a/tests/unit_tests/test_util.py b/tests/unit_tests/test_util.py index 74d6a74709..34f40a651f 100644 --- a/tests/unit_tests/test_util.py +++ b/tests/unit_tests/test_util.py @@ -141,3 +141,170 @@ def test_list_yaml_files_mixed_extensions(tmp_path: Path) -> None: str(yaml_file), str(yml_file), } + + +def test_list_yaml_files_does_not_recurse_into_subdirectories(tmp_path: Path) -> None: + """Test that list_yaml_files only finds files in specified directory, not subdirectories.""" + # Create directory structure with YAML files at different depths + root = tmp_path / "configs" + root.mkdir() + + # Create YAML files in the root directory + (root / "config1.yaml").write_text("test: 1") + (root / "config2.yml").write_text("test: 2") + (root / "device.yaml").write_text("test: device") + + # Create subdirectory with YAML files (should NOT be found) + subdir = root / "subdir" + subdir.mkdir() + (subdir / "nested1.yaml").write_text("test: nested1") + (subdir / "nested2.yml").write_text("test: nested2") + + # Create deeper subdirectory (should NOT be found) + deep_subdir = subdir / "deeper" + deep_subdir.mkdir() + (deep_subdir / "very_nested.yaml").write_text("test: very_nested") + + # Test listing files from the root directory + result = util.list_yaml_files([str(root)]) + + # Should only find the 3 files in root, not the 3 in subdirectories + assert len(result) == 3 + + # Check that only root-level files are found + assert str(root / "config1.yaml") in result + assert str(root / "config2.yml") in result + assert str(root / "device.yaml") in result + + # Ensure nested files are NOT found + for r in result: + assert "subdir" not in r + assert "deeper" not in r + assert "nested1.yaml" not in r + assert "nested2.yml" not in r + assert "very_nested.yaml" not in r + + +def test_list_yaml_files_excludes_secrets(tmp_path: Path) -> None: + """Test that secrets.yaml and secrets.yml are excluded.""" + root = tmp_path / "configs" + root.mkdir() + + # Create various YAML files including secrets + (root / "config.yaml").write_text("test: config") + (root / "secrets.yaml").write_text("wifi_password: secret123") + (root / "secrets.yml").write_text("api_key: secret456") + (root / "device.yaml").write_text("test: device") + + result = util.list_yaml_files([str(root)]) + + # Should find 2 files (config.yaml and device.yaml), not secrets + assert len(result) == 2 + assert str(root / "config.yaml") in result + assert str(root / "device.yaml") in result + assert str(root / "secrets.yaml") not in result + assert str(root / "secrets.yml") not in result + + +def test_list_yaml_files_excludes_hidden_files(tmp_path: Path) -> None: + """Test that hidden files (starting with .) 
are excluded.""" + root = tmp_path / "configs" + root.mkdir() + + # Create regular and hidden YAML files + (root / "config.yaml").write_text("test: config") + (root / ".hidden.yaml").write_text("test: hidden") + (root / ".backup.yml").write_text("test: backup") + (root / "device.yaml").write_text("test: device") + + result = util.list_yaml_files([str(root)]) + + # Should find only non-hidden files + assert len(result) == 2 + assert str(root / "config.yaml") in result + assert str(root / "device.yaml") in result + assert str(root / ".hidden.yaml") not in result + assert str(root / ".backup.yml") not in result + + +def test_filter_yaml_files_basic() -> None: + """Test filter_yaml_files function.""" + files = [ + "/path/to/config.yaml", + "/path/to/device.yml", + "/path/to/readme.txt", + "/path/to/script.py", + "/path/to/data.json", + "/path/to/another.yaml", + ] + + result = util.filter_yaml_files(files) + + assert len(result) == 3 + assert "/path/to/config.yaml" in result + assert "/path/to/device.yml" in result + assert "/path/to/another.yaml" in result + assert "/path/to/readme.txt" not in result + assert "/path/to/script.py" not in result + assert "/path/to/data.json" not in result + + +def test_filter_yaml_files_excludes_secrets() -> None: + """Test that filter_yaml_files excludes secrets files.""" + files = [ + "/path/to/config.yaml", + "/path/to/secrets.yaml", + "/path/to/secrets.yml", + "/path/to/device.yaml", + "/some/dir/secrets.yaml", + ] + + result = util.filter_yaml_files(files) + + assert len(result) == 2 + assert "/path/to/config.yaml" in result + assert "/path/to/device.yaml" in result + assert "/path/to/secrets.yaml" not in result + assert "/path/to/secrets.yml" not in result + assert "/some/dir/secrets.yaml" not in result + + +def test_filter_yaml_files_excludes_hidden() -> None: + """Test that filter_yaml_files excludes hidden files.""" + files = [ + "/path/to/config.yaml", + "/path/to/.hidden.yaml", + "/path/to/.backup.yml", + "/path/to/device.yaml", + "/some/dir/.config.yaml", + ] + + result = util.filter_yaml_files(files) + + assert len(result) == 2 + assert "/path/to/config.yaml" in result + assert "/path/to/device.yaml" in result + assert "/path/to/.hidden.yaml" not in result + assert "/path/to/.backup.yml" not in result + assert "/some/dir/.config.yaml" not in result + + +def test_filter_yaml_files_case_sensitive() -> None: + """Test that filter_yaml_files is case-sensitive for extensions.""" + files = [ + "/path/to/config.yaml", + "/path/to/config.YAML", + "/path/to/config.YML", + "/path/to/config.Yaml", + "/path/to/config.yml", + ] + + result = util.filter_yaml_files(files) + + # Should only match lowercase .yaml and .yml + assert len(result) == 2 + assert "/path/to/config.yaml" in result + assert "/path/to/config.yml" in result + assert "/path/to/config.YAML" not in result + assert "/path/to/config.YML" not in result + assert "/path/to/config.Yaml" not in result diff --git a/tests/unit_tests/test_writer.py b/tests/unit_tests/test_writer.py index f47947ff37..f1f86a322e 100644 --- a/tests/unit_tests/test_writer.py +++ b/tests/unit_tests/test_writer.py @@ -1,13 +1,34 @@ """Test writer module functionality.""" from collections.abc import Callable +from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch import pytest +from esphome.core import EsphomeError from esphome.storage_json import StorageJSON -from esphome.writer import storage_should_clean, update_storage_json +from esphome.writer import ( + CPP_AUTO_GENERATE_BEGIN, + 
CPP_AUTO_GENERATE_END, + CPP_INCLUDE_BEGIN, + CPP_INCLUDE_END, + GITIGNORE_CONTENT, + clean_build, + clean_cmake_cache, + storage_should_clean, + update_storage_json, + write_cpp, + write_gitignore, +) + + +@pytest.fixture +def mock_copy_src_tree(): + """Mock copy_src_tree to avoid side effects during tests.""" + with patch("esphome.writer.copy_src_tree"): + yield @pytest.fixture @@ -218,3 +239,396 @@ def test_update_storage_json_logging_components_removed( # Verify save was called new_storage.save.assert_called_once_with("/test/path") + + +@patch("esphome.writer.CORE") +def test_clean_cmake_cache( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_cmake_cache removes CMakeCache.txt file.""" + # Create directory structure + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + device_dir = pioenvs_dir / "test_device" + device_dir.mkdir() + cmake_cache_file = device_dir / "CMakeCache.txt" + cmake_cache_file.write_text("# CMake cache file") + + # Setup mocks + mock_core.relative_pioenvs_path.side_effect = [ + str(pioenvs_dir), # First call for directory check + str(cmake_cache_file), # Second call for file path + ] + mock_core.name = "test_device" + + # Verify file exists before + assert cmake_cache_file.exists() + + # Call the function + with caplog.at_level("INFO"): + clean_cmake_cache() + + # Verify file was removed + assert not cmake_cache_file.exists() + + # Verify logging + assert "Deleting" in caplog.text + assert "CMakeCache.txt" in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_cmake_cache_no_pioenvs_dir( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test clean_cmake_cache when pioenvs directory doesn't exist.""" + # Setup non-existent directory path + pioenvs_dir = tmp_path / ".pioenvs" + + # Setup mocks + mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + + # Verify directory doesn't exist + assert not pioenvs_dir.exists() + + # Call the function - should not crash + clean_cmake_cache() + + # Verify directory still doesn't exist + assert not pioenvs_dir.exists() + + +@patch("esphome.writer.CORE") +def test_clean_cmake_cache_no_cmake_file( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test clean_cmake_cache when CMakeCache.txt doesn't exist.""" + # Create directory structure without CMakeCache.txt + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + device_dir = pioenvs_dir / "test_device" + device_dir.mkdir() + cmake_cache_file = device_dir / "CMakeCache.txt" + + # Setup mocks + mock_core.relative_pioenvs_path.side_effect = [ + str(pioenvs_dir), # First call for directory check + str(cmake_cache_file), # Second call for file path + ] + mock_core.name = "test_device" + + # Verify file doesn't exist + assert not cmake_cache_file.exists() + + # Call the function - should not crash + clean_cmake_cache() + + # Verify file still doesn't exist + assert not cmake_cache_file.exists() + + +@patch("esphome.writer.CORE") +def test_clean_build( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_build removes all build artifacts.""" + # Create directory structure and files + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + (pioenvs_dir / "test_file.o").write_text("object file") + + piolibdeps_dir = tmp_path / ".piolibdeps" + piolibdeps_dir.mkdir() + (piolibdeps_dir / "library").mkdir() + + dependencies_lock = tmp_path / "dependencies.lock" + dependencies_lock.write_text("lock file") + + # Setup mocks + 
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) + mock_core.relative_build_path.return_value = str(dependencies_lock) + + # Verify all exist before + assert pioenvs_dir.exists() + assert piolibdeps_dir.exists() + assert dependencies_lock.exists() + + # Call the function + with caplog.at_level("INFO"): + clean_build() + + # Verify all were removed + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Verify logging + assert "Deleting" in caplog.text + assert ".pioenvs" in caplog.text + assert ".piolibdeps" in caplog.text + assert "dependencies.lock" in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_build_partial_exists( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_build when only some paths exist.""" + # Create only pioenvs directory + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + (pioenvs_dir / "test_file.o").write_text("object file") + + piolibdeps_dir = tmp_path / ".piolibdeps" + dependencies_lock = tmp_path / "dependencies.lock" + + # Setup mocks + mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) + mock_core.relative_build_path.return_value = str(dependencies_lock) + + # Verify only pioenvs exists + assert pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Call the function + with caplog.at_level("INFO"): + clean_build() + + # Verify only existing path was removed + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Verify logging - only pioenvs should be logged + assert "Deleting" in caplog.text + assert ".pioenvs" in caplog.text + assert ".piolibdeps" not in caplog.text + assert "dependencies.lock" not in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_build_nothing_exists( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test clean_build when no build artifacts exist.""" + # Setup paths that don't exist + pioenvs_dir = tmp_path / ".pioenvs" + piolibdeps_dir = tmp_path / ".piolibdeps" + dependencies_lock = tmp_path / "dependencies.lock" + + # Setup mocks + mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) + mock_core.relative_build_path.return_value = str(dependencies_lock) + + # Verify nothing exists + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Call the function - should not crash + clean_build() + + # Verify nothing was created + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + +@patch("esphome.writer.CORE") +def test_write_gitignore_creates_new_file( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_gitignore creates a new .gitignore file when it doesn't exist.""" + gitignore_path = tmp_path / ".gitignore" + + # Setup mocks + mock_core.relative_config_path.return_value = str(gitignore_path) + + # Verify file doesn't exist + assert not gitignore_path.exists() + + # Call the function + write_gitignore() + + # Verify file was created with correct content + assert gitignore_path.exists() + assert gitignore_path.read_text() == GITIGNORE_CONTENT + + +@patch("esphome.writer.CORE") +def 
test_write_gitignore_skips_existing_file( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_gitignore doesn't overwrite existing .gitignore file.""" + gitignore_path = tmp_path / ".gitignore" + existing_content = "# Custom gitignore\n/custom_dir/\n" + gitignore_path.write_text(existing_content) + + # Setup mocks + mock_core.relative_config_path.return_value = str(gitignore_path) + + # Verify file exists with custom content + assert gitignore_path.exists() + assert gitignore_path.read_text() == existing_content + + # Call the function + write_gitignore() + + # Verify file was not modified + assert gitignore_path.exists() + assert gitignore_path.read_text() == existing_content + + +@patch("esphome.writer.write_file_if_changed") # Mock to capture output +@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex +@patch("esphome.writer.CORE") +def test_write_cpp_with_existing_file( + mock_core: MagicMock, + mock_copy_src_tree: MagicMock, + mock_write_file: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp when main.cpp already exists.""" + # Create a real file with markers + main_cpp = tmp_path / "main.cpp" + existing_content = f"""#include "esphome.h" +{CPP_INCLUDE_BEGIN} +// Old includes +{CPP_INCLUDE_END} +void setup() {{ +{CPP_AUTO_GENERATE_BEGIN} +// Old code +{CPP_AUTO_GENERATE_END} +}} +void loop() {{}}""" + main_cpp.write_text(existing_content) + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.cpp_global_section = "// Global section" + + # Call the function + test_code = " // New generated code" + write_cpp(test_code) + + # Verify copy_src_tree was called + mock_copy_src_tree.assert_called_once() + + # Get the content that would be written + mock_write_file.assert_called_once() + written_path, written_content = mock_write_file.call_args[0] + + # Check that markers are preserved and content is updated + assert CPP_INCLUDE_BEGIN in written_content + assert CPP_INCLUDE_END in written_content + assert CPP_AUTO_GENERATE_BEGIN in written_content + assert CPP_AUTO_GENERATE_END in written_content + assert test_code in written_content + assert "// Global section" in written_content + + +@patch("esphome.writer.write_file_if_changed") # Mock to capture output +@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex +@patch("esphome.writer.CORE") +def test_write_cpp_creates_new_file( + mock_core: MagicMock, + mock_copy_src_tree: MagicMock, + mock_write_file: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp when main.cpp doesn't exist.""" + # Setup path for new file + main_cpp = tmp_path / "main.cpp" + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.cpp_global_section = "// Global section" + + # Verify file doesn't exist + assert not main_cpp.exists() + + # Call the function + test_code = " // Generated code" + write_cpp(test_code) + + # Verify copy_src_tree was called + mock_copy_src_tree.assert_called_once() + + # Get the content that would be written + mock_write_file.assert_called_once() + written_path, written_content = mock_write_file.call_args[0] + assert written_path == str(main_cpp) + + # Check that all necessary parts are in the new file + assert '#include "esphome.h"' in written_content + assert CPP_INCLUDE_BEGIN in written_content + assert CPP_INCLUDE_END in written_content + assert CPP_AUTO_GENERATE_BEGIN in written_content + assert CPP_AUTO_GENERATE_END in written_content + assert test_code in written_content + assert "void setup()" in 
written_content + assert "void loop()" in written_content + assert "App.setup();" in written_content + assert "App.loop();" in written_content + + +@pytest.mark.usefixtures("mock_copy_src_tree") +@patch("esphome.writer.CORE") +def test_write_cpp_with_missing_end_marker( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp raises error when end marker is missing.""" + # Create a file with begin marker but no end marker + main_cpp = tmp_path / "main.cpp" + existing_content = f"""#include "esphome.h" +{CPP_AUTO_GENERATE_BEGIN} +// Code without end marker""" + main_cpp.write_text(existing_content) + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + + # Call should raise an error + with pytest.raises(EsphomeError, match="Could not find auto generated code end"): + write_cpp("// New code") + + +@pytest.mark.usefixtures("mock_copy_src_tree") +@patch("esphome.writer.CORE") +def test_write_cpp_with_duplicate_markers( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp raises error when duplicate markers exist.""" + # Create a file with duplicate begin markers + main_cpp = tmp_path / "main.cpp" + existing_content = f"""#include "esphome.h" +{CPP_AUTO_GENERATE_BEGIN} +// First section +{CPP_AUTO_GENERATE_END} +{CPP_AUTO_GENERATE_BEGIN} +// Duplicate section +{CPP_AUTO_GENERATE_END}""" + main_cpp.write_text(existing_content) + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + + # Call should raise an error + with pytest.raises(EsphomeError, match="Found multiple auto generate code begins"): + write_cpp("// New code") diff --git a/tests/unit_tests/test_yaml_util.py b/tests/unit_tests/test_yaml_util.py index f31e9554dc..bc3c89a64d 100644 --- a/tests/unit_tests/test_yaml_util.py +++ b/tests/unit_tests/test_yaml_util.py @@ -1,9 +1,26 @@ -from esphome import yaml_util +from pathlib import Path +import shutil +from unittest.mock import patch + +import pytest + +from esphome import core, yaml_util from esphome.components import substitutions from esphome.core import EsphomeError +from esphome.util import OrderedDict -def test_include_with_vars(fixture_path): +@pytest.fixture(autouse=True) +def clear_secrets_cache() -> None: + """Clear the secrets cache before each test.""" + yaml_util._SECRET_VALUES.clear() + yaml_util._SECRET_CACHE.clear() + yield + yaml_util._SECRET_VALUES.clear() + yaml_util._SECRET_CACHE.clear() + + +def test_include_with_vars(fixture_path: Path) -> None: yaml_file = fixture_path / "yaml_util" / "includetest.yaml" actual = yaml_util.load_yaml(yaml_file) @@ -62,3 +79,202 @@ def test_parsing_with_custom_loader(fixture_path): assert loader_calls[0].endswith("includes/included.yaml") assert loader_calls[1].endswith("includes/list.yaml") assert loader_calls[2].endswith("includes/scalar.yaml") + + +def test_construct_secret_simple(fixture_path: Path) -> None: + """Test loading a YAML file with !secret tags.""" + yaml_file = fixture_path / "yaml_util" / "test_secret.yaml" + + actual = yaml_util.load_yaml(yaml_file) + + # Check that secrets were properly loaded + assert actual["wifi"]["password"] == "super_secret_wifi" + assert actual["api"]["encryption"]["key"] == "0123456789abcdef" + assert actual["sensor"][0]["id"] == "my_secret_value" + + +def test_construct_secret_missing(fixture_path: Path, tmp_path: Path) -> None: + """Test that missing secrets raise proper errors.""" + # Create a YAML file with a secret that doesn't exist + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" 
+esphome: + name: test + +wifi: + password: !secret nonexistent_secret +""") + + # Create an empty secrets file + secrets_yaml = tmp_path / "secrets.yaml" + secrets_yaml.write_text("some_other_secret: value") + + with pytest.raises(EsphomeError, match="Secret 'nonexistent_secret' not defined"): + yaml_util.load_yaml(str(test_yaml)) + + +def test_construct_secret_no_secrets_file(tmp_path: Path) -> None: + """Test that missing secrets.yaml file raises proper error.""" + # Create a YAML file with a secret but no secrets.yaml + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" +wifi: + password: !secret some_secret +""") + + # Mock CORE.config_path to avoid NoneType error + with ( + patch.object(core.CORE, "config_path", str(tmp_path / "main.yaml")), + pytest.raises(EsphomeError, match="secrets.yaml"), + ): + yaml_util.load_yaml(str(test_yaml)) + + +def test_construct_secret_fallback_to_main_config_dir( + fixture_path: Path, tmp_path: Path +) -> None: + """Test fallback to main config directory for secrets.""" + # Create a subdirectory with a YAML file that uses secrets + subdir = tmp_path / "subdir" + subdir.mkdir() + + test_yaml = subdir / "test.yaml" + test_yaml.write_text(""" +wifi: + password: !secret test_secret +""") + + # Create secrets.yaml in the main directory + main_secrets = tmp_path / "secrets.yaml" + main_secrets.write_text("test_secret: main_secret_value") + + # Mock CORE.config_path to point to main directory + with patch.object(core.CORE, "config_path", str(tmp_path / "main.yaml")): + actual = yaml_util.load_yaml(str(test_yaml)) + assert actual["wifi"]["password"] == "main_secret_value" + + +def test_construct_include_dir_named(fixture_path: Path, tmp_path: Path) -> None: + """Test !include_dir_named directive.""" + # Copy fixture directory to temporary location + src_dir = fixture_path / "yaml_util" + dst_dir = tmp_path / "yaml_util" + shutil.copytree(src_dir, dst_dir) + + # Create test YAML that uses include_dir_named + test_yaml = dst_dir / "test_include_named.yaml" + test_yaml.write_text(""" +sensor: !include_dir_named named_dir +""") + + actual = yaml_util.load_yaml(str(test_yaml)) + actual_sensor = actual["sensor"] + + # Check that files were loaded with their names as keys + assert isinstance(actual_sensor, OrderedDict) + assert "sensor1" in actual_sensor + assert "sensor2" in actual_sensor + assert "sensor3" in actual_sensor # Files from subdirs are included with basename + + # Check content of loaded files + assert actual_sensor["sensor1"]["platform"] == "template" + assert actual_sensor["sensor1"]["name"] == "Sensor 1" + assert actual_sensor["sensor2"]["platform"] == "template" + assert actual_sensor["sensor2"]["name"] == "Sensor 2" + + # Check that subdirectory files are included with their basename + assert actual_sensor["sensor3"]["platform"] == "template" + assert actual_sensor["sensor3"]["name"] == "Sensor 3 in subdir" + + # Check that hidden files and non-YAML files are not included + assert ".hidden" not in actual_sensor + assert "not_yaml" not in actual_sensor + + +def test_construct_include_dir_named_empty_dir(tmp_path: Path) -> None: + """Test !include_dir_named with empty directory.""" + # Create empty directory + empty_dir = tmp_path / "empty_dir" + empty_dir.mkdir() + + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" +sensor: !include_dir_named empty_dir +""") + + actual = yaml_util.load_yaml(str(test_yaml)) + + # Should return empty OrderedDict + assert isinstance(actual["sensor"], OrderedDict) + assert len(actual["sensor"]) 
== 0 + + +def test_construct_include_dir_named_with_dots(tmp_path: Path) -> None: + """Test that include_dir_named ignores files starting with dots.""" + # Create directory with various files + test_dir = tmp_path / "test_dir" + test_dir.mkdir() + + # Create visible file + visible_file = test_dir / "visible.yaml" + visible_file.write_text("key: visible_value") + + # Create hidden file + hidden_file = test_dir / ".hidden.yaml" + hidden_file.write_text("key: hidden_value") + + # Create hidden directory with files + hidden_dir = test_dir / ".hidden_dir" + hidden_dir.mkdir() + hidden_subfile = hidden_dir / "subfile.yaml" + hidden_subfile.write_text("key: hidden_subfile_value") + + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" +test: !include_dir_named test_dir +""") + + actual = yaml_util.load_yaml(str(test_yaml)) + + # Should only include visible file + assert "visible" in actual["test"] + assert actual["test"]["visible"]["key"] == "visible_value" + + # Should not include hidden files or directories + assert ".hidden" not in actual["test"] + assert ".hidden_dir" not in actual["test"] + + +def test_find_files_recursive(fixture_path: Path, tmp_path: Path) -> None: + """Test that _find_files works recursively through include_dir_named.""" + # Copy fixture directory to temporary location + src_dir = fixture_path / "yaml_util" + dst_dir = tmp_path / "yaml_util" + shutil.copytree(src_dir, dst_dir) + + # This indirectly tests _find_files by using include_dir_named + test_yaml = dst_dir / "test_include_recursive.yaml" + test_yaml.write_text(""" +all_sensors: !include_dir_named named_dir +""") + + actual = yaml_util.load_yaml(str(test_yaml)) + + # Should find sensor1.yaml, sensor2.yaml, and subdir/sensor3.yaml (all flattened) + assert len(actual["all_sensors"]) == 3 + assert "sensor1" in actual["all_sensors"] + assert "sensor2" in actual["all_sensors"] + assert "sensor3" in actual["all_sensors"] + + +def test_secret_values_tracking(fixture_path: Path) -> None: + """Test that secret values are properly tracked for dumping.""" + yaml_file = fixture_path / "yaml_util" / "test_secret.yaml" + + yaml_util.load_yaml(yaml_file) + + # Check that secret values are tracked + assert "super_secret_wifi" in yaml_util._SECRET_VALUES + assert yaml_util._SECRET_VALUES["super_secret_wifi"] == "wifi_password" + assert "0123456789abcdef" in yaml_util._SECRET_VALUES + assert yaml_util._SECRET_VALUES["0123456789abcdef"] == "api_key"
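
Note: test_secret_values_tracking above relies on yaml_util._SECRET_VALUES mapping each
resolved secret value back to the name of the secret that produced it, so that secrets can
later be re-emitted as !secret references when a config is dumped. A minimal, self-contained
sketch of that idea follows; it is a hypothetical helper for illustration only, not ESPHome's
actual dump code, and it assumes only a value-to-name map shaped like _SECRET_VALUES.

def redact_secrets(node, secret_values: dict[str, str]):
    """Recursively replace known secret values with '!secret <name>' placeholders.

    secret_values maps a resolved secret value to the secret's name, mirroring
    the shape asserted for yaml_util._SECRET_VALUES in the test above.
    """
    if isinstance(node, dict):
        return {key: redact_secrets(value, secret_values) for key, value in node.items()}
    if isinstance(node, list):
        return [redact_secrets(value, secret_values) for value in node]
    if isinstance(node, str) and node in secret_values:
        return f"!secret {secret_values[node]}"
    return node

# Example: with {"super_secret_wifi": "wifi_password"}, the config
# {"wifi": {"password": "super_secret_wifi"}} becomes
# {"wifi": {"password": "!secret wifi_password"}}.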