diff --git a/esphome/__main__.py b/esphome/__main__.py index bba254436e..0147a82530 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -15,9 +15,11 @@ import argcomplete from esphome import const, writer, yaml_util import esphome.codegen as cg +from esphome.components.mqtt import CONF_DISCOVER_IP from esphome.config import iter_component_configs, read_config, strip_default_ids from esphome.const import ( ALLOWED_NAME_CHARS, + CONF_API, CONF_BAUD_RATE, CONF_BROKER, CONF_DEASSERT_RTS_DTR, @@ -43,6 +45,7 @@ from esphome.const import ( SECRETS_FILES, ) from esphome.core import CORE, EsphomeError, coroutine +from esphome.enum import StrEnum from esphome.helpers import get_bool_env, indent, is_ip_address from esphome.log import AnsiFore, color, setup_log from esphome.types import ConfigType @@ -106,13 +109,15 @@ def choose_prompt(options, purpose: str = None): return options[opt - 1][1] +class Purpose(StrEnum): + UPLOADING = "uploading" + LOGGING = "logging" + + def choose_upload_log_host( default: list[str] | str | None, check_default: str | None, - show_ota: bool, - show_mqtt: bool, - show_api: bool, - purpose: str | None = None, + purpose: Purpose, ) -> list[str]: # Convert to list for uniform handling defaults = [default] if isinstance(default, str) else default or [] @@ -132,13 +137,30 @@ def choose_upload_log_host( ] resolved.append(choose_prompt(options, purpose=purpose)) elif device == "OTA": - if CORE.address and ( - (show_ota and "ota" in CORE.config) - or (show_api and "api" in CORE.config) + # ensure IP adresses are used first + if is_ip_address(CORE.address) and ( + (purpose == Purpose.LOGGING and has_api()) + or (purpose == Purpose.UPLOADING and has_ota()) ): resolved.append(CORE.address) - elif show_mqtt and has_mqtt_logging(): - resolved.append("MQTT") + + if purpose == Purpose.LOGGING: + if has_api() and has_mqtt_ip_lookup(): + resolved.append("MQTTIP") + + if has_mqtt_logging(): + resolved.append("MQTT") + + if has_api() and has_non_ip_address(): + resolved.append(CORE.address) + + elif purpose == Purpose.UPLOADING: + if has_ota() and has_mqtt_ip_lookup(): + resolved.append("MQTTIP") + + if has_ota() and has_non_ip_address(): + resolved.append(CORE.address) + else: resolved.append(device) if not resolved: @@ -149,39 +171,111 @@ def choose_upload_log_host( options = [ (f"{port.path} ({port.description})", port.path) for port in get_serial_ports() ] - if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config): - options.append((f"Over The Air ({CORE.address})", CORE.address)) - if show_mqtt and has_mqtt_logging(): - mqtt_config = CORE.config[CONF_MQTT] - options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) + + if purpose == Purpose.LOGGING: + if has_mqtt_logging(): + mqtt_config = CORE.config[CONF_MQTT] + options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) + + if has_api(): + if has_resolvable_address(): + options.append((f"Over The Air ({CORE.address})", CORE.address)) + if has_mqtt_ip_lookup(): + options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) + + elif purpose == Purpose.UPLOADING and has_ota(): + if has_resolvable_address(): + options.append((f"Over The Air ({CORE.address})", CORE.address)) + if has_mqtt_ip_lookup(): + options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) if check_default is not None and check_default in [opt[1] for opt in options]: return [check_default] return [choose_prompt(options, purpose=purpose)] -def mqtt_logging_enabled(mqtt_config): +def has_mqtt_logging() -> bool: + """Check if MQTT logging is available.""" + if CONF_MQTT not in CORE.config: + return False + + mqtt_config = CORE.config[CONF_MQTT] + + # enabled by default + if CONF_LOG_TOPIC not in mqtt_config: + return True + log_topic = mqtt_config[CONF_LOG_TOPIC] if log_topic is None: return False + if CONF_TOPIC not in log_topic: return False - return log_topic.get(CONF_LEVEL, None) != "NONE" + + return log_topic[CONF_LEVEL] != "NONE" -def has_mqtt_logging() -> bool: - """Check if MQTT logging is available.""" - return (mqtt_config := CORE.config.get(CONF_MQTT)) and mqtt_logging_enabled( - mqtt_config - ) +def has_mqtt() -> bool: + """Check if MQTT is available.""" + return CONF_MQTT in CORE.config + + +def has_api() -> bool: + """Check if API is available.""" + return CONF_API in CORE.config + + +def has_ota() -> bool: + """Check if OTA is available.""" + return CONF_OTA in CORE.config + + +def has_mqtt_ip_lookup() -> bool: + """Check if MQTT is available and IP lookup is supported.""" + if CONF_MQTT not in CORE.config: + return False + # Default Enabled + if CONF_DISCOVER_IP not in CORE.config[CONF_MQTT]: + return True + return CORE.config[CONF_MQTT][CONF_DISCOVER_IP] + + +def has_mdns() -> bool: + """Check if MDNS is available.""" + return CONF_MDNS not in CORE.config or not CORE.config[CONF_MDNS][CONF_DISABLED] + + +def has_non_ip_address() -> bool: + """Check if CORE.address is set and is not an IP address.""" + return CORE.address is not None and not is_ip_address(CORE.address) + + +def has_ip_address() -> bool: + """Check if CORE.address is a valid IP address.""" + return CORE.address is not None and is_ip_address(CORE.address) + + +def has_resolvable_address() -> bool: + """Check if CORE.address is resolvable (via mDNS or is an IP address).""" + return has_mdns() or has_ip_address() + + +def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str): + from esphome import mqtt + + return mqtt.get_esphome_device_ip(config, username, password, client_id) + + +_PORT_TO_PORT_TYPE = { + "MQTT": "MQTT", + "MQTTIP": "MQTTIP", +} def get_port_type(port: str) -> str: if port.startswith("/") or port.startswith("COM"): return "SERIAL" - if port == "MQTT": - return "MQTT" - return "NETWORK" + return _PORT_TO_PORT_TYPE.get(port, "NETWORK") def run_miniterm(config: ConfigType, port: str, args) -> int: @@ -439,23 +533,9 @@ def upload_program( password = ota_conf.get(CONF_PASSWORD, "") binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin - # Check if we should use MQTT for address resolution - # This happens when no device was specified, or the current host is "MQTT"/"OTA" - if ( - CONF_MQTT in config # pylint: disable=too-many-boolean-expressions - and (not devices or host in ("MQTT", "OTA")) - and ( - ((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address)) - or get_port_type(host) == "MQTT" - ) - ): - from esphome import mqtt - - devices = [ - mqtt.get_esphome_device_ip( - config, args.username, args.password, args.client_id - ) - ] + # MQTT address resolution + if get_port_type(host) in ("MQTT", "MQTTIP"): + devices = mqtt_get_ip(config, args.username, args.password, args.client_id) return espota2.run_ota(devices, remote_port, password, binary) @@ -476,20 +556,28 @@ def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int if get_port_type(port) == "SERIAL": check_permissions(port) return run_miniterm(config, port, args) - if get_port_type(port) == "NETWORK" and "api" in config: - addresses_to_use = devices - if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config: - from esphome import mqtt - mqtt_address = mqtt.get_esphome_device_ip( + port_type = get_port_type(port) + + # Check if we should use API for logging + if has_api(): + addresses_to_use: list[str] | None = None + + if port_type == "NETWORK" and (has_mdns() or is_ip_address(port)): + addresses_to_use = devices + elif port_type in ("NETWORK", "MQTT", "MQTTIP") and has_mqtt_ip_lookup(): + # Only use MQTT IP lookup if the first condition didn't match + # (for MQTT/MQTTIP types, or for NETWORK when mdns/ip check fails) + addresses_to_use = mqtt_get_ip( config, args.username, args.password, args.client_id - )[0] - addresses_to_use = [mqtt_address] + ) - from esphome.components.api.client import run_logs + if addresses_to_use is not None: + from esphome.components.api.client import run_logs - return run_logs(config, addresses_to_use) - if get_port_type(port) in ("NETWORK", "MQTT") and "mqtt" in config: + return run_logs(config, addresses_to_use) + + if port_type in ("NETWORK", "MQTT") and has_mqtt_logging(): from esphome import mqtt return mqtt.show_logs( @@ -555,10 +643,7 @@ def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=args.device, check_default=None, - show_ota=True, - show_mqtt=False, - show_api=False, - purpose="uploading", + purpose=Purpose.UPLOADING, ) exit_code, _ = upload_program(config, args, devices) @@ -583,10 +668,7 @@ def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=args.device, check_default=None, - show_ota=False, - show_mqtt=True, - show_api=True, - purpose="logging", + purpose=Purpose.LOGGING, ) return show_logs(config, args, devices) @@ -612,10 +694,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=args.device, check_default=None, - show_ota=True, - show_mqtt=False, - show_api=True, - purpose="uploading", + purpose=Purpose.UPLOADING, ) exit_code, successful_device = upload_program(config, args, devices) @@ -632,10 +711,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None: devices = choose_upload_log_host( default=successful_device, check_default=successful_device, - show_ota=False, - show_mqtt=True, - show_api=True, - purpose="logging", + purpose=Purpose.LOGGING, ) return show_logs(config, args, devices) diff --git a/esphome/components/packet_transport/packet_transport.cpp b/esphome/components/packet_transport/packet_transport.cpp index b6ce24bc1b..8bde4ee505 100644 --- a/esphome/components/packet_transport/packet_transport.cpp +++ b/esphome/components/packet_transport/packet_transport.cpp @@ -270,6 +270,7 @@ void PacketTransport::add_binary_data_(uint8_t key, const char *id, bool data) { auto len = 1 + 1 + 1 + strlen(id); if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { this->flush_(); + this->init_data_(); } add(this->data_, key); add(this->data_, (uint8_t) data); @@ -284,6 +285,7 @@ void PacketTransport::add_data_(uint8_t key, const char *id, uint32_t data) { auto len = 4 + 1 + 1 + strlen(id); if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { this->flush_(); + this->init_data_(); } add(this->data_, key); add(this->data_, data); diff --git a/esphome/const.py b/esphome/const.py index 308abe7706..677b9173ec 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -114,6 +114,7 @@ CONF_AND = "and" CONF_ANGLE = "angle" CONF_ANY = "any" CONF_AP = "ap" +CONF_API = "api" CONF_APPARENT_POWER = "apparent_power" CONF_ARDUINO_VERSION = "arduino_version" CONF_AREA = "area" diff --git a/requirements_test.txt b/requirements_test.txt index 01661f3b7c..5ec9c98408 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -8,7 +8,7 @@ pre-commit pytest==8.4.2 pytest-cov==7.0.0 pytest-mock==3.15.0 -pytest-asyncio==1.1.0 +pytest-asyncio==1.2.0 pytest-xdist==3.8.0 asyncmock==0.4.2 hypothesis==6.92.1 diff --git a/tests/dashboard/test_web_server.py b/tests/dashboard/test_web_server.py index a22f4a8b2a..e206090ac0 100644 --- a/tests/dashboard/test_web_server.py +++ b/tests/dashboard/test_web_server.py @@ -556,6 +556,66 @@ def test_start_web_server_with_address_port( assert (archive_dir / "old.yaml").exists() +@pytest.mark.asyncio +async def test_edit_request_handler_get(dashboard: DashboardTestHelper) -> None: + """Test EditRequestHandler.get method.""" + # Test getting a valid yaml file + response = await dashboard.fetch("/edit?configuration=pico.yaml") + assert response.code == 200 + assert response.headers["content-type"] == "application/yaml" + content = response.body.decode() + assert "esphome:" in content # Verify it's a valid ESPHome config + + # Test getting a non-existent file + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/edit?configuration=nonexistent.yaml") + assert exc_info.value.code == 404 + + # Test getting a non-yaml file + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/edit?configuration=test.txt") + assert exc_info.value.code == 404 + + # Test path traversal attempt + with pytest.raises(HTTPClientError) as exc_info: + await dashboard.fetch("/edit?configuration=../../../etc/passwd") + assert exc_info.value.code == 404 + + +@pytest.mark.asyncio +async def test_archive_request_handler_post( + dashboard: DashboardTestHelper, + mock_archive_storage_path: MagicMock, + mock_ext_storage_path: MagicMock, + tmp_path: Path, +) -> None: + """Test ArchiveRequestHandler.post method.""" + + # Set up temp directories + config_dir = Path(get_fixture_path("conf")) + archive_dir = tmp_path / "archive" + + # Create a test configuration file + test_config = config_dir / "test_archive.yaml" + test_config.write_text("esphome:\n name: test_archive\n") + + # Archive the configuration + response = await dashboard.fetch( + "/archive", + method="POST", + body="configuration=test_archive.yaml", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + assert response.code == 200 + + # Verify file was moved to archive + assert not test_config.exists() + assert (archive_dir / "test_archive.yaml").exists() + assert ( + archive_dir / "test_archive.yaml" + ).read_text() == "esphome:\n name: test_archive\n" + + @pytest.mark.skipif(os.name == "nt", reason="Unix sockets are not supported on Windows") @pytest.mark.usefixtures("mock_trash_storage_path", "mock_archive_storage_path") def test_start_web_server_with_unix_socket(tmp_path: Path) -> None: diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index 2c7236c7f8..bfebb44545 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -4,14 +4,55 @@ from __future__ import annotations from collections.abc import Generator from dataclasses import dataclass +from pathlib import Path from typing import Any -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from pytest import CaptureFixture -from esphome.__main__ import choose_upload_log_host -from esphome.const import CONF_BROKER, CONF_MQTT, CONF_USE_ADDRESS, CONF_WIFI -from esphome.core import CORE +from esphome.__main__ import ( + Purpose, + choose_upload_log_host, + command_rename, + command_wizard, + get_port_type, + has_ip_address, + has_mqtt, + has_mqtt_ip_lookup, + has_mqtt_logging, + has_non_ip_address, + has_resolvable_address, + mqtt_get_ip, + show_logs, + upload_program, +) +from esphome.const import ( + CONF_API, + CONF_BROKER, + CONF_DISABLED, + CONF_ESPHOME, + CONF_LEVEL, + CONF_LOG_TOPIC, + CONF_MDNS, + CONF_MQTT, + CONF_NAME, + CONF_OTA, + CONF_PASSWORD, + CONF_PLATFORM, + CONF_PORT, + CONF_SUBSTITUTIONS, + CONF_TOPIC, + CONF_USE_ADDRESS, + CONF_WIFI, + KEY_CORE, + KEY_TARGET_PLATFORM, + PLATFORM_BK72XX, + PLATFORM_ESP32, + PLATFORM_ESP8266, + PLATFORM_RP2040, +) +from esphome.core import CORE, EsphomeError @dataclass @@ -28,7 +69,11 @@ class MockSerialPort: def setup_core( - config: dict[str, Any] | None = None, address: str | None = None + config: dict[str, Any] | None = None, + address: str | None = None, + platform: str | None = None, + tmp_path: Path | None = None, + name: str = "test", ) -> None: """ Helper to set up CORE configuration with optional address. @@ -36,6 +81,9 @@ def setup_core( Args: config (dict[str, Any] | None): The configuration dictionary to set for CORE. If None, an empty dict is used. address (str | None): Optional network address to set in the configuration. If provided, it is set under the wifi config. + platform (str | None): Optional target platform to set in CORE.data. + tmp_path (Path | None): Optional temp path for setting up build paths. + name (str): The name of the device (defaults to "test"). """ if config is None: config = {} @@ -46,6 +94,15 @@ def setup_core( CORE.config = config + if platform is not None: + CORE.data[KEY_CORE] = {} + CORE.data[KEY_CORE][KEY_TARGET_PLATFORM] = platform + + if tmp_path is not None: + CORE.config_path = str(tmp_path / f"{name}.yaml") + CORE.name = name + CORE.build_path = str(tmp_path / ".esphome" / "build" / name) + @pytest.fixture def mock_no_serial_ports() -> Generator[Mock]: @@ -54,6 +111,62 @@ def mock_no_serial_ports() -> Generator[Mock]: yield mock +@pytest.fixture +def mock_get_port_type() -> Generator[Mock]: + """Mock get_port_type for testing.""" + with patch("esphome.__main__.get_port_type") as mock: + yield mock + + +@pytest.fixture +def mock_check_permissions() -> Generator[Mock]: + """Mock check_permissions for testing.""" + with patch("esphome.__main__.check_permissions") as mock: + yield mock + + +@pytest.fixture +def mock_run_miniterm() -> Generator[Mock]: + """Mock run_miniterm for testing.""" + with patch("esphome.__main__.run_miniterm") as mock: + yield mock + + +@pytest.fixture +def mock_upload_using_esptool() -> Generator[Mock]: + """Mock upload_using_esptool for testing.""" + with patch("esphome.__main__.upload_using_esptool") as mock: + yield mock + + +@pytest.fixture +def mock_upload_using_platformio() -> Generator[Mock]: + """Mock upload_using_platformio for testing.""" + with patch("esphome.__main__.upload_using_platformio") as mock: + yield mock + + +@pytest.fixture +def mock_run_ota() -> Generator[Mock]: + """Mock espota2.run_ota for testing.""" + with patch("esphome.espota2.run_ota") as mock: + yield mock + + +@pytest.fixture +def mock_is_ip_address() -> Generator[Mock]: + """Mock is_ip_address for testing.""" + with patch("esphome.__main__.is_ip_address") as mock: + yield mock + + +@pytest.fixture +def mock_mqtt_get_ip() -> Generator[Mock]: + """Mock mqtt_get_ip for testing.""" + with patch("esphome.__main__.mqtt_get_ip") as mock: + yield mock + + @pytest.fixture def mock_serial_ports() -> Generator[Mock]: """Mock get_serial_ports to return test ports.""" @@ -86,64 +199,66 @@ def mock_has_mqtt_logging() -> Generator[Mock]: yield mock +@pytest.fixture +def mock_run_external_process() -> Generator[Mock]: + """Mock run_external_process for testing.""" + with patch("esphome.__main__.run_external_process") as mock: + mock.return_value = 0 # Default to success + yield mock + + def test_choose_upload_log_host_with_string_default() -> None: """Test with a single string default device.""" + setup_core() result = choose_upload_log_host( default="192.168.1.100", check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.100"] def test_choose_upload_log_host_with_list_default() -> None: """Test with a list of default devices.""" + setup_core() result = choose_upload_log_host( default=["192.168.1.100", "192.168.1.101"], check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.100", "192.168.1.101"] def test_choose_upload_log_host_with_multiple_ip_addresses() -> None: """Test with multiple IP addresses as defaults.""" + setup_core() result = choose_upload_log_host( default=["1.2.3.4", "4.5.5.6"], check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.LOGGING, ) assert result == ["1.2.3.4", "4.5.5.6"] def test_choose_upload_log_host_with_mixed_hostnames_and_ips() -> None: """Test with a mix of hostnames and IP addresses.""" + setup_core() result = choose_upload_log_host( default=["host.one", "host.one.local", "1.2.3.4"], check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["host.one", "host.one.local", "1.2.3.4"] def test_choose_upload_log_host_with_ota_list() -> None: """Test with OTA as the only item in the list.""" - setup_core(config={"ota": {}}, address="192.168.1.100") + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") result = choose_upload_log_host( default=["OTA"], check_default=None, - show_ota=True, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.100"] @@ -151,16 +266,27 @@ def test_choose_upload_log_host_with_ota_list() -> None: @pytest.mark.usefixtures("mock_has_mqtt_logging") def test_choose_upload_log_host_with_ota_list_mqtt_fallback() -> None: """Test with OTA list falling back to MQTT when no address.""" - setup_core() + setup_core(config={CONF_OTA: {}, "mqtt": {}}) result = choose_upload_log_host( default=["OTA"], check_default=None, - show_ota=False, - show_mqtt=True, - show_api=False, + purpose=Purpose.UPLOADING, ) - assert result == ["MQTT"] + assert result == ["MQTTIP"] + + +@pytest.mark.usefixtures("mock_has_mqtt_logging") +def test_choose_upload_log_host_with_ota_list_mqtt_fallback_logging() -> None: + """Test with OTA list with API and MQTT when no address.""" + setup_core(config={CONF_API: {}, "mqtt": {}}) + + result = choose_upload_log_host( + default=["OTA"], + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["MQTTIP", "MQTT"] @pytest.mark.usefixtures("mock_no_serial_ports") @@ -168,12 +294,11 @@ def test_choose_upload_log_host_with_serial_device_no_ports( caplog: pytest.LogCaptureFixture, ) -> None: """Test SERIAL device when no serial ports are found.""" + setup_core() result = choose_upload_log_host( default="SERIAL", check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == [] assert "No serial ports found, skipping SERIAL device" in caplog.text @@ -184,13 +309,11 @@ def test_choose_upload_log_host_with_serial_device_with_ports( mock_choose_prompt: Mock, ) -> None: """Test SERIAL device when serial ports are available.""" + setup_core() result = choose_upload_log_host( default="SERIAL", check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, - purpose="testing", + purpose=Purpose.UPLOADING, ) assert result == ["/dev/ttyUSB0"] mock_choose_prompt.assert_called_once_with( @@ -198,34 +321,42 @@ def test_choose_upload_log_host_with_serial_device_with_ports( ("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0"), ("/dev/ttyUSB1 (Another USB Serial)", "/dev/ttyUSB1"), ], - purpose="testing", + purpose=Purpose.UPLOADING, ) def test_choose_upload_log_host_with_ota_device_with_ota_config() -> None: """Test OTA device when OTA is configured.""" - setup_core(config={"ota": {}}, address="192.168.1.100") + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") result = choose_upload_log_host( default="OTA", check_default=None, - show_ota=True, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.100"] def test_choose_upload_log_host_with_ota_device_with_api_config() -> None: - """Test OTA device when API is configured.""" - setup_core(config={"api": {}}, address="192.168.1.100") + """Test OTA device when API is configured (no upload without OTA in config).""" + setup_core(config={CONF_API: {}}, address="192.168.1.100") result = choose_upload_log_host( default="OTA", check_default=None, - show_ota=False, - show_mqtt=False, - show_api=True, + purpose=Purpose.UPLOADING, + ) + assert result == [] + + +def test_choose_upload_log_host_with_ota_device_with_api_config_logging() -> None: + """Test OTA device when API is configured.""" + setup_core(config={CONF_API: {}}, address="192.168.1.100") + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.LOGGING, ) assert result == ["192.168.1.100"] @@ -233,14 +364,12 @@ def test_choose_upload_log_host_with_ota_device_with_api_config() -> None: @pytest.mark.usefixtures("mock_has_mqtt_logging") def test_choose_upload_log_host_with_ota_device_fallback_to_mqtt() -> None: """Test OTA device fallback to MQTT when no OTA/API config.""" - setup_core() + setup_core(config={"mqtt": {}}) result = choose_upload_log_host( default="OTA", check_default=None, - show_ota=False, - show_mqtt=True, - show_api=False, + purpose=Purpose.LOGGING, ) assert result == ["MQTT"] @@ -253,9 +382,7 @@ def test_choose_upload_log_host_with_ota_device_no_fallback() -> None: result = choose_upload_log_host( default="OTA", check_default=None, - show_ota=True, - show_mqtt=True, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == [] @@ -263,7 +390,7 @@ def test_choose_upload_log_host_with_ota_device_no_fallback() -> None: @pytest.mark.usefixtures("mock_choose_prompt") def test_choose_upload_log_host_multiple_devices() -> None: """Test with multiple devices including special identifiers.""" - setup_core(config={"ota": {}}, address="192.168.1.100") + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") mock_ports = [MockSerialPort("/dev/ttyUSB0", "USB Serial")] @@ -271,9 +398,7 @@ def test_choose_upload_log_host_multiple_devices() -> None: result = choose_upload_log_host( default=["192.168.1.50", "OTA", "SERIAL"], check_default=None, - show_ota=True, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.50", "192.168.1.100", "/dev/ttyUSB0"] @@ -292,22 +417,19 @@ def test_choose_upload_log_host_no_defaults_with_serial_ports( result = choose_upload_log_host( default=None, check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, - purpose="uploading", + purpose=Purpose.UPLOADING, ) assert result == ["/dev/ttyUSB0"] mock_choose_prompt.assert_called_once_with( [("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0")], - purpose="uploading", + purpose=Purpose.UPLOADING, ) @pytest.mark.usefixtures("mock_no_serial_ports") def test_choose_upload_log_host_no_defaults_with_ota() -> None: """Test interactive mode with OTA option.""" - setup_core(config={"ota": {}}, address="192.168.1.100") + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") with patch( "esphome.__main__.choose_prompt", return_value="192.168.1.100" @@ -315,21 +437,19 @@ def test_choose_upload_log_host_no_defaults_with_ota() -> None: result = choose_upload_log_host( default=None, check_default=None, - show_ota=True, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.100"] mock_prompt.assert_called_once_with( [("Over The Air (192.168.1.100)", "192.168.1.100")], - purpose=None, + purpose=Purpose.UPLOADING, ) @pytest.mark.usefixtures("mock_no_serial_ports") def test_choose_upload_log_host_no_defaults_with_api() -> None: """Test interactive mode with API option.""" - setup_core(config={"api": {}}, address="192.168.1.100") + setup_core(config={CONF_API: {}}, address="192.168.1.100") with patch( "esphome.__main__.choose_prompt", return_value="192.168.1.100" @@ -337,14 +457,12 @@ def test_choose_upload_log_host_no_defaults_with_api() -> None: result = choose_upload_log_host( default=None, check_default=None, - show_ota=False, - show_mqtt=False, - show_api=True, + purpose=Purpose.LOGGING, ) assert result == ["192.168.1.100"] mock_prompt.assert_called_once_with( [("Over The Air (192.168.1.100)", "192.168.1.100")], - purpose=None, + purpose=Purpose.LOGGING, ) @@ -357,14 +475,12 @@ def test_choose_upload_log_host_no_defaults_with_mqtt() -> None: result = choose_upload_log_host( default=None, check_default=None, - show_ota=False, - show_mqtt=True, - show_api=False, + purpose=Purpose.LOGGING, ) assert result == ["MQTT"] mock_prompt.assert_called_once_with( [("MQTT (mqtt.local)", "MQTT")], - purpose=None, + purpose=Purpose.LOGGING, ) @@ -374,7 +490,7 @@ def test_choose_upload_log_host_no_defaults_with_all_options( ) -> None: """Test interactive mode with all options available.""" setup_core( - config={"ota": {}, "api": {}, CONF_MQTT: {CONF_BROKER: "mqtt.local"}}, + config={CONF_OTA: {}, CONF_API: {}, CONF_MQTT: {CONF_BROKER: "mqtt.local"}}, address="192.168.1.100", ) @@ -384,32 +500,59 @@ def test_choose_upload_log_host_no_defaults_with_all_options( result = choose_upload_log_host( default=None, check_default=None, - show_ota=True, - show_mqtt=True, - show_api=True, - purpose="testing", + purpose=Purpose.UPLOADING, ) assert result == ["/dev/ttyUSB0"] expected_options = [ ("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0"), ("Over The Air (192.168.1.100)", "192.168.1.100"), - ("MQTT (mqtt.local)", "MQTT"), + ("Over The Air (MQTT IP lookup)", "MQTTIP"), ] - mock_choose_prompt.assert_called_once_with(expected_options, purpose="testing") + mock_choose_prompt.assert_called_once_with( + expected_options, purpose=Purpose.UPLOADING + ) + + +def test_choose_upload_log_host_no_defaults_with_all_options_logging( + mock_choose_prompt: Mock, +) -> None: + """Test interactive mode with all options available.""" + setup_core( + config={CONF_OTA: {}, CONF_API: {}, CONF_MQTT: {CONF_BROKER: "mqtt.local"}}, + address="192.168.1.100", + ) + + mock_ports = [MockSerialPort("/dev/ttyUSB0", "USB Serial")] + + with patch("esphome.__main__.get_serial_ports", return_value=mock_ports): + result = choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["/dev/ttyUSB0"] + + expected_options = [ + ("/dev/ttyUSB0 (USB Serial)", "/dev/ttyUSB0"), + ("MQTT (mqtt.local)", "MQTT"), + ("Over The Air (192.168.1.100)", "192.168.1.100"), + ("Over The Air (MQTT IP lookup)", "MQTTIP"), + ] + mock_choose_prompt.assert_called_once_with( + expected_options, purpose=Purpose.LOGGING + ) @pytest.mark.usefixtures("mock_no_serial_ports") def test_choose_upload_log_host_check_default_matches() -> None: """Test when check_default matches an available option.""" - setup_core(config={"ota": {}}, address="192.168.1.100") + setup_core(config={CONF_OTA: {}}, address="192.168.1.100") result = choose_upload_log_host( default=None, check_default="192.168.1.100", - show_ota=True, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.100"] @@ -425,9 +568,7 @@ def test_choose_upload_log_host_check_default_no_match() -> None: result = choose_upload_log_host( default=None, check_default="192.168.1.100", - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["fallback"] mock_prompt.assert_called_once() @@ -436,13 +577,12 @@ def test_choose_upload_log_host_check_default_no_match() -> None: @pytest.mark.usefixtures("mock_no_serial_ports") def test_choose_upload_log_host_empty_defaults_list() -> None: """Test with an empty list as default.""" + setup_core() with patch("esphome.__main__.choose_prompt", return_value="chosen") as mock_prompt: result = choose_upload_log_host( default=[], check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["chosen"] mock_prompt.assert_called_once() @@ -458,9 +598,7 @@ def test_choose_upload_log_host_all_devices_unresolved( result = choose_upload_log_host( default=["SERIAL", "OTA"], check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == [] assert ( @@ -476,37 +614,920 @@ def test_choose_upload_log_host_mixed_resolved_unresolved() -> None: result = choose_upload_log_host( default=["192.168.1.50", "SERIAL", "OTA"], check_default=None, - show_ota=False, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.50"] def test_choose_upload_log_host_ota_both_conditions() -> None: """Test OTA device when both OTA and API are configured and enabled.""" - setup_core(config={"ota": {}, "api": {}}, address="192.168.1.100") + setup_core(config={CONF_OTA: {}, CONF_API: {}}, address="192.168.1.100") result = choose_upload_log_host( default="OTA", check_default=None, - show_ota=True, - show_mqtt=False, - show_api=True, + purpose=Purpose.UPLOADING, ) assert result == ["192.168.1.100"] +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_ip_all_options() -> None: + """Test OTA device when both static IP, OTA, API and MQTT are configured and enabled but MDNS not.""" + setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="192.168.1.100", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["192.168.1.100", "MQTTIP"] + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_local_all_options() -> None: + """Test OTA device when both static IP, OTA, API and MQTT are configured and enabled but MDNS not.""" + setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="test.local", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + assert result == ["MQTTIP", "test.local"] + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_ip_all_options_logging() -> None: + """Test OTA device when both static IP, OTA, API and MQTT are configured and enabled but MDNS not.""" + setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="192.168.1.100", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["192.168.1.100", "MQTTIP", "MQTT"] + + +@pytest.mark.usefixtures("mock_serial_ports") +def test_choose_upload_log_host_ota_local_all_options_logging() -> None: + """Test OTA device when both static IP, OTA, API and MQTT are configured and enabled but MDNS not.""" + setup_core( + config={ + CONF_OTA: {}, + CONF_API: {}, + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + }, + address="test.local", + ) + + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.LOGGING, + ) + assert result == ["MQTTIP", "MQTT", "test.local"] + + @pytest.mark.usefixtures("mock_no_mqtt_logging") def test_choose_upload_log_host_no_address_with_ota_config() -> None: """Test OTA device when OTA is configured but no address is set.""" - setup_core(config={"ota": {}}) + setup_core(config={CONF_OTA: {}}) result = choose_upload_log_host( default="OTA", check_default=None, - show_ota=True, - show_mqtt=False, - show_api=False, + purpose=Purpose.UPLOADING, ) assert result == [] + + +@dataclass +class MockArgs: + """Mock args for testing.""" + + file: str | None = None + upload_speed: int = 460800 + username: str | None = None + password: str | None = None + client_id: str | None = None + topic: str | None = None + configuration: str | None = None + name: str | None = None + dashboard: bool = False + + +def test_upload_program_serial_esp32( + mock_upload_using_esptool: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, +) -> None: + """Test upload_program with serial port for ESP32.""" + setup_core(platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "SERIAL" + mock_upload_using_esptool.return_value = 0 + + config = {} + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "/dev/ttyUSB0" + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_upload_using_esptool.assert_called_once() + + +def test_upload_program_serial_esp8266_with_file( + mock_upload_using_esptool: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, +) -> None: + """Test upload_program with serial port for ESP8266 with custom file.""" + setup_core(platform=PLATFORM_ESP8266) + mock_get_port_type.return_value = "SERIAL" + mock_upload_using_esptool.return_value = 0 + + config = {} + args = MockArgs(file="firmware.bin") + devices = ["/dev/ttyUSB0"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "/dev/ttyUSB0" + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_upload_using_esptool.assert_called_once_with( + config, "/dev/ttyUSB0", "firmware.bin", 460800 + ) + + +@pytest.mark.parametrize( + "platform,device", + [ + (PLATFORM_RP2040, "/dev/ttyACM0"), + (PLATFORM_BK72XX, "/dev/ttyUSB0"), # LibreTiny platform + ], +) +def test_upload_program_serial_platformio_platforms( + mock_upload_using_platformio: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, + platform: str, + device: str, +) -> None: + """Test upload_program with serial port for platformio platforms (RP2040/LibreTiny).""" + setup_core(platform=platform) + mock_get_port_type.return_value = "SERIAL" + mock_upload_using_platformio.return_value = 0 + + config = {} + args = MockArgs() + devices = [device] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == device + mock_check_permissions.assert_called_once_with(device) + mock_upload_using_platformio.assert_called_once_with(config, device) + + +def test_upload_program_serial_upload_failed( + mock_upload_using_esptool: Mock, + mock_get_port_type: Mock, + mock_check_permissions: Mock, +) -> None: + """Test upload_program when serial upload fails.""" + setup_core(platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "SERIAL" + mock_upload_using_esptool.return_value = 1 # Failed + + config = {} + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 1 + assert host is None + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_upload_using_esptool.assert_called_once() + + +def test_upload_program_ota_success( + mock_run_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """Test upload_program with OTA.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + + mock_get_port_type.return_value = "NETWORK" + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + CONF_PASSWORD: "secret", + } + ] + } + args = MockArgs() + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + expected_firmware = str( + tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" + ) + mock_run_ota.assert_called_once_with( + ["192.168.1.100"], 3232, "secret", expected_firmware + ) + + +def test_upload_program_ota_with_file_arg( + mock_run_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """Test upload_program with OTA and custom file.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + + mock_get_port_type.return_value = "NETWORK" + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + } + ] + } + args = MockArgs(file="custom.bin") + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + mock_run_ota.assert_called_once_with(["192.168.1.100"], 3232, "", "custom.bin") + + +def test_upload_program_ota_no_config( + mock_get_port_type: Mock, +) -> None: + """Test upload_program with OTA but no OTA config.""" + setup_core(platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "NETWORK" + + config = {} # No OTA config + args = MockArgs() + devices = ["192.168.1.100"] + + with pytest.raises(EsphomeError, match="Cannot upload Over the Air"): + upload_program(config, args, devices) + + +def test_upload_program_ota_with_mqtt_resolution( + mock_mqtt_get_ip: Mock, + mock_is_ip_address: Mock, + mock_run_ota: Mock, + tmp_path: Path, +) -> None: + """Test upload_program with OTA using MQTT for address resolution.""" + setup_core(address="device.local", platform=PLATFORM_ESP32, tmp_path=tmp_path) + + mock_is_ip_address.return_value = False + mock_mqtt_get_ip.return_value = ["192.168.1.100"] + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + } + ], + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + }, + CONF_MDNS: { + CONF_DISABLED: True, + }, + } + args = MockArgs(username="user", password="pass", client_id="client") + devices = ["MQTT"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + mock_mqtt_get_ip.assert_called_once_with(config, "user", "pass", "client") + expected_firmware = str( + tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" + ) + mock_run_ota.assert_called_once_with(["192.168.1.100"], 3232, "", expected_firmware) + + +@patch("esphome.__main__.importlib.import_module") +def test_upload_program_platform_specific_handler( + mock_import: Mock, + mock_get_port_type: Mock, +) -> None: + """Test upload_program with platform-specific upload handler.""" + setup_core(platform="custom_platform") + mock_get_port_type.return_value = "CUSTOM" + + mock_module = MagicMock() + mock_module.upload_program.return_value = True + mock_import.return_value = mock_module + + config = {} + args = MockArgs() + devices = ["custom_device"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "custom_device" + mock_import.assert_called_once_with("esphome.components.custom_platform") + mock_module.upload_program.assert_called_once_with(config, args, "custom_device") + + +def test_show_logs_serial( + mock_get_port_type: Mock, + mock_check_permissions: Mock, + mock_run_miniterm: Mock, +) -> None: + """Test show_logs with serial port.""" + setup_core(config={"logger": {}}, platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "SERIAL" + mock_run_miniterm.return_value = 0 + + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_check_permissions.assert_called_once_with("/dev/ttyUSB0") + mock_run_miniterm.assert_called_once_with(CORE.config, "/dev/ttyUSB0", args) + + +def test_show_logs_no_logger() -> None: + """Test show_logs when logger is not configured.""" + setup_core(config={}, platform=PLATFORM_ESP32) # No logger config + args = MockArgs() + devices = ["/dev/ttyUSB0"] + + with pytest.raises(EsphomeError, match="Logger is not configured"): + show_logs(CORE.config, args, devices) + + +@patch("esphome.components.api.client.run_logs") +def test_show_logs_api( + mock_run_logs: Mock, +) -> None: + """Test show_logs with API.""" + setup_core( + config={ + "logger": {}, + CONF_API: {}, + CONF_MDNS: {CONF_DISABLED: False}, + }, + platform=PLATFORM_ESP32, + ) + mock_run_logs.return_value = 0 + + args = MockArgs() + devices = ["192.168.1.100", "192.168.1.101"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_run_logs.assert_called_once_with( + CORE.config, ["192.168.1.100", "192.168.1.101"] + ) + + +@patch("esphome.components.api.client.run_logs") +def test_show_logs_api_with_mqtt_fallback( + mock_run_logs: Mock, + mock_mqtt_get_ip: Mock, +) -> None: + """Test show_logs with API using MQTT for address resolution.""" + setup_core( + config={ + "logger": {}, + CONF_API: {}, + CONF_MDNS: {CONF_DISABLED: True}, + CONF_MQTT: {CONF_BROKER: "mqtt.local"}, + }, + platform=PLATFORM_ESP32, + ) + mock_run_logs.return_value = 0 + mock_mqtt_get_ip.return_value = ["192.168.1.200"] + + args = MockArgs(username="user", password="pass", client_id="client") + devices = ["device.local"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_mqtt_get_ip.assert_called_once_with(CORE.config, "user", "pass", "client") + mock_run_logs.assert_called_once_with(CORE.config, ["192.168.1.200"]) + + +@patch("esphome.mqtt.show_logs") +def test_show_logs_mqtt( + mock_mqtt_show_logs: Mock, +) -> None: + """Test show_logs with MQTT.""" + setup_core( + config={ + "logger": {}, + "mqtt": {CONF_BROKER: "mqtt.local"}, + }, + platform=PLATFORM_ESP32, + ) + mock_mqtt_show_logs.return_value = 0 + + args = MockArgs( + topic="esphome/logs", + username="user", + password="pass", + client_id="client", + ) + devices = ["MQTT"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_mqtt_show_logs.assert_called_once_with( + CORE.config, "esphome/logs", "user", "pass", "client" + ) + + +@patch("esphome.mqtt.show_logs") +def test_show_logs_network_with_mqtt_only( + mock_mqtt_show_logs: Mock, +) -> None: + """Test show_logs with network port but only MQTT configured.""" + setup_core( + config={ + "logger": {}, + "mqtt": {CONF_BROKER: "mqtt.local"}, + # No API configured + }, + platform=PLATFORM_ESP32, + ) + mock_mqtt_show_logs.return_value = 0 + + args = MockArgs( + topic="esphome/logs", + username="user", + password="pass", + client_id="client", + ) + devices = ["192.168.1.100"] + + result = show_logs(CORE.config, args, devices) + + assert result == 0 + mock_mqtt_show_logs.assert_called_once_with( + CORE.config, "esphome/logs", "user", "pass", "client" + ) + + +def test_show_logs_no_method_configured() -> None: + """Test show_logs when no remote logging method is configured.""" + setup_core( + config={ + "logger": {}, + # No API or MQTT configured + }, + platform=PLATFORM_ESP32, + ) + + args = MockArgs() + devices = ["192.168.1.100"] + + with pytest.raises( + EsphomeError, match="No remote or local logging method configured" + ): + show_logs(CORE.config, args, devices) + + +@patch("esphome.__main__.importlib.import_module") +def test_show_logs_platform_specific_handler( + mock_import: Mock, +) -> None: + """Test show_logs with platform-specific logs handler.""" + setup_core(platform="custom_platform", config={"logger": {}}) + + mock_module = MagicMock() + mock_module.show_logs.return_value = True + mock_import.return_value = mock_module + + config = {"logger": {}} + args = MockArgs() + devices = ["custom_device"] + + result = show_logs(config, args, devices) + + assert result == 0 + mock_import.assert_called_once_with("esphome.components.custom_platform") + mock_module.show_logs.assert_called_once_with(config, args, devices) + + +def test_has_mqtt_logging_no_log_topic() -> None: + """Test has_mqtt_logging returns True when CONF_LOG_TOPIC is not in mqtt_config.""" + + # Setup MQTT config without CONF_LOG_TOPIC (defaults to enabled - this is the missing test case) + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local"}}) + assert has_mqtt_logging() is True + + # Setup MQTT config with CONF_LOG_TOPIC set to None (explicitly disabled) + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local", CONF_LOG_TOPIC: None}}) + assert has_mqtt_logging() is False + + # Setup MQTT config with CONF_LOG_TOPIC set with topic and level (explicitly enabled) + setup_core( + config={ + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + CONF_LOG_TOPIC: {CONF_TOPIC: "esphome/logs", CONF_LEVEL: "DEBUG"}, + } + } + ) + assert has_mqtt_logging() is True + + # Setup MQTT config with CONF_LOG_TOPIC set but level is NONE (disabled) + setup_core( + config={ + CONF_MQTT: { + CONF_BROKER: "mqtt.local", + CONF_LOG_TOPIC: {CONF_TOPIC: "esphome/logs", CONF_LEVEL: "NONE"}, + } + } + ) + assert has_mqtt_logging() is False + + # Setup without MQTT config at all + setup_core(config={}) + assert has_mqtt_logging() is False + + +def test_has_mqtt() -> None: + """Test has_mqtt function.""" + + # Test with MQTT configured + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local"}}) + assert has_mqtt() is True + + # Test without MQTT configured + setup_core(config={}) + assert has_mqtt() is False + + # Test with other components but no MQTT + setup_core(config={CONF_API: {}, CONF_OTA: {}}) + assert has_mqtt() is False + + +def test_get_port_type() -> None: + """Test get_port_type function.""" + + assert get_port_type("/dev/ttyUSB0") == "SERIAL" + assert get_port_type("/dev/ttyACM0") == "SERIAL" + assert get_port_type("COM1") == "SERIAL" + assert get_port_type("COM10") == "SERIAL" + + assert get_port_type("MQTT") == "MQTT" + assert get_port_type("MQTTIP") == "MQTTIP" + + assert get_port_type("192.168.1.100") == "NETWORK" + assert get_port_type("esphome-device.local") == "NETWORK" + assert get_port_type("10.0.0.1") == "NETWORK" + + +def test_has_mqtt_ip_lookup() -> None: + """Test has_mqtt_ip_lookup function.""" + + CONF_DISCOVER_IP = "discover_ip" + + setup_core(config={}) + assert has_mqtt_ip_lookup() is False + + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local"}}) + assert has_mqtt_ip_lookup() is True + + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local", CONF_DISCOVER_IP: True}}) + assert has_mqtt_ip_lookup() is True + + setup_core(config={CONF_MQTT: {CONF_BROKER: "mqtt.local", CONF_DISCOVER_IP: False}}) + assert has_mqtt_ip_lookup() is False + + +def test_has_non_ip_address() -> None: + """Test has_non_ip_address function.""" + + setup_core(address=None) + assert has_non_ip_address() is False + + setup_core(address="192.168.1.100") + assert has_non_ip_address() is False + + setup_core(address="10.0.0.1") + assert has_non_ip_address() is False + + setup_core(address="esphome-device.local") + assert has_non_ip_address() is True + + setup_core(address="my-device") + assert has_non_ip_address() is True + + +def test_has_ip_address() -> None: + """Test has_ip_address function.""" + + setup_core(address=None) + assert has_ip_address() is False + + setup_core(address="192.168.1.100") + assert has_ip_address() is True + + setup_core(address="10.0.0.1") + assert has_ip_address() is True + + setup_core(address="esphome-device.local") + assert has_ip_address() is False + + setup_core(address="my-device") + assert has_ip_address() is False + + +def test_mqtt_get_ip() -> None: + """Test mqtt_get_ip function.""" + config = {CONF_MQTT: {CONF_BROKER: "mqtt.local"}} + + with patch("esphome.mqtt.get_esphome_device_ip") as mock_get_ip: + mock_get_ip.return_value = ["192.168.1.100", "192.168.1.101"] + + result = mqtt_get_ip(config, "user", "pass", "client-id") + + assert result == ["192.168.1.100", "192.168.1.101"] + mock_get_ip.assert_called_once_with(config, "user", "pass", "client-id") + + +def test_has_resolvable_address() -> None: + """Test has_resolvable_address function.""" + + # Test with mDNS enabled and hostname address + setup_core(config={}, address="esphome-device.local") + assert has_resolvable_address() is True + + # Test with mDNS disabled and hostname address + setup_core( + config={CONF_MDNS: {CONF_DISABLED: True}}, address="esphome-device.local" + ) + assert has_resolvable_address() is False + + # Test with IP address (mDNS doesn't matter) + setup_core(config={}, address="192.168.1.100") + assert has_resolvable_address() is True + + # Test with IP address and mDNS disabled + setup_core(config={CONF_MDNS: {CONF_DISABLED: True}}, address="192.168.1.100") + assert has_resolvable_address() is True + + # Test with no address but mDNS enabled (can still resolve mDNS names) + setup_core(config={}, address=None) + assert has_resolvable_address() is True + + # Test with no address and mDNS disabled + setup_core(config={CONF_MDNS: {CONF_DISABLED: True}}, address=None) + assert has_resolvable_address() is False + + +def test_command_wizard(tmp_path: Path) -> None: + """Test command_wizard function.""" + config_file = tmp_path / "test.yaml" + + # Mock wizard.wizard to avoid interactive prompts + with patch("esphome.wizard.wizard") as mock_wizard: + mock_wizard.return_value = 0 + + args = MockArgs(configuration=str(config_file)) + result = command_wizard(args) + + assert result == 0 + mock_wizard.assert_called_once_with(str(config_file)) + + +def test_command_rename_invalid_characters( + tmp_path: Path, capfd: CaptureFixture[str] +) -> None: + """Test command_rename with invalid characters in name.""" + setup_core(tmp_path=tmp_path) + + # Test with invalid character (space) + args = MockArgs(name="invalid name") + result = command_rename(args, {}) + + assert result == 1 + captured = capfd.readouterr() + assert "invalid character" in captured.out.lower() + + +def test_command_rename_complex_yaml( + tmp_path: Path, capfd: CaptureFixture[str] +) -> None: + """Test command_rename with complex YAML that cannot be renamed.""" + config_file = tmp_path / "test.yaml" + config_file.write_text("# Complex YAML without esphome section\nsome_key: value\n") + setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + args = MockArgs(name="newname") + result = command_rename(args, {}) + + assert result == 1 + captured = capfd.readouterr() + assert "complex yaml" in captured.out.lower() + + +def test_command_rename_success( + tmp_path: Path, + capfd: CaptureFixture[str], + mock_run_external_process: Mock, +) -> None: + """Test successful rename of a simple configuration.""" + config_file = tmp_path / "oldname.yaml" + config_file.write_text(""" +esphome: + name: oldname + +esp32: + board: nodemcu-32s + +wifi: + ssid: "test" + password: "test1234" +""") + setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + # Set up CORE.config to avoid ValueError when accessing CORE.address + CORE.config = {CONF_ESPHOME: {CONF_NAME: "oldname"}} + + args = MockArgs(name="newname", dashboard=False) + + # Simulate successful validation and upload + mock_run_external_process.return_value = 0 + + result = command_rename(args, {}) + + assert result == 0 + + # Verify new file was created + new_file = tmp_path / "newname.yaml" + assert new_file.exists() + + # Verify old file was removed + assert not config_file.exists() + + # Verify content was updated + content = new_file.read_text() + assert ( + 'name: "newname"' in content + or "name: 'newname'" in content + or "name: newname" in content + ) + + captured = capfd.readouterr() + assert "SUCCESS" in captured.out + + +def test_command_rename_with_substitutions( + tmp_path: Path, + mock_run_external_process: Mock, +) -> None: + """Test rename with substitutions in YAML.""" + config_file = tmp_path / "oldname.yaml" + config_file.write_text(""" +substitutions: + device_name: oldname + +esphome: + name: ${device_name} + +esp32: + board: nodemcu-32s +""") + setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + # Set up CORE.config to avoid ValueError when accessing CORE.address + CORE.config = { + CONF_ESPHOME: {CONF_NAME: "oldname"}, + CONF_SUBSTITUTIONS: {"device_name": "oldname"}, + } + + args = MockArgs(name="newname", dashboard=False) + + mock_run_external_process.return_value = 0 + + result = command_rename(args, {}) + + assert result == 0 + + # Verify substitution was updated + new_file = tmp_path / "newname.yaml" + content = new_file.read_text() + assert 'device_name: "newname"' in content + + +def test_command_rename_validation_failure( + tmp_path: Path, + capfd: CaptureFixture[str], + mock_run_external_process: Mock, +) -> None: + """Test rename when validation fails.""" + config_file = tmp_path / "oldname.yaml" + config_file.write_text(""" +esphome: + name: oldname + +esp32: + board: nodemcu-32s +""") + setup_core(tmp_path=tmp_path) + CORE.config_path = str(config_file) + + args = MockArgs(name="newname", dashboard=False) + + # First call for validation fails + mock_run_external_process.return_value = 1 + + result = command_rename(args, {}) + + assert result == 1 + + # Verify new file was created but then removed due to failure + new_file = tmp_path / "newname.yaml" + assert not new_file.exists() + + # Verify old file still exists (not removed on failure) + assert config_file.exists() + + captured = capfd.readouterr() + assert "Rename failed" in captured.out diff --git a/tests/unit_tests/test_util.py b/tests/unit_tests/test_util.py index 74d6a74709..34f40a651f 100644 --- a/tests/unit_tests/test_util.py +++ b/tests/unit_tests/test_util.py @@ -141,3 +141,170 @@ def test_list_yaml_files_mixed_extensions(tmp_path: Path) -> None: str(yaml_file), str(yml_file), } + + +def test_list_yaml_files_does_not_recurse_into_subdirectories(tmp_path: Path) -> None: + """Test that list_yaml_files only finds files in specified directory, not subdirectories.""" + # Create directory structure with YAML files at different depths + root = tmp_path / "configs" + root.mkdir() + + # Create YAML files in the root directory + (root / "config1.yaml").write_text("test: 1") + (root / "config2.yml").write_text("test: 2") + (root / "device.yaml").write_text("test: device") + + # Create subdirectory with YAML files (should NOT be found) + subdir = root / "subdir" + subdir.mkdir() + (subdir / "nested1.yaml").write_text("test: nested1") + (subdir / "nested2.yml").write_text("test: nested2") + + # Create deeper subdirectory (should NOT be found) + deep_subdir = subdir / "deeper" + deep_subdir.mkdir() + (deep_subdir / "very_nested.yaml").write_text("test: very_nested") + + # Test listing files from the root directory + result = util.list_yaml_files([str(root)]) + + # Should only find the 3 files in root, not the 3 in subdirectories + assert len(result) == 3 + + # Check that only root-level files are found + assert str(root / "config1.yaml") in result + assert str(root / "config2.yml") in result + assert str(root / "device.yaml") in result + + # Ensure nested files are NOT found + for r in result: + assert "subdir" not in r + assert "deeper" not in r + assert "nested1.yaml" not in r + assert "nested2.yml" not in r + assert "very_nested.yaml" not in r + + +def test_list_yaml_files_excludes_secrets(tmp_path: Path) -> None: + """Test that secrets.yaml and secrets.yml are excluded.""" + root = tmp_path / "configs" + root.mkdir() + + # Create various YAML files including secrets + (root / "config.yaml").write_text("test: config") + (root / "secrets.yaml").write_text("wifi_password: secret123") + (root / "secrets.yml").write_text("api_key: secret456") + (root / "device.yaml").write_text("test: device") + + result = util.list_yaml_files([str(root)]) + + # Should find 2 files (config.yaml and device.yaml), not secrets + assert len(result) == 2 + assert str(root / "config.yaml") in result + assert str(root / "device.yaml") in result + assert str(root / "secrets.yaml") not in result + assert str(root / "secrets.yml") not in result + + +def test_list_yaml_files_excludes_hidden_files(tmp_path: Path) -> None: + """Test that hidden files (starting with .) are excluded.""" + root = tmp_path / "configs" + root.mkdir() + + # Create regular and hidden YAML files + (root / "config.yaml").write_text("test: config") + (root / ".hidden.yaml").write_text("test: hidden") + (root / ".backup.yml").write_text("test: backup") + (root / "device.yaml").write_text("test: device") + + result = util.list_yaml_files([str(root)]) + + # Should find only non-hidden files + assert len(result) == 2 + assert str(root / "config.yaml") in result + assert str(root / "device.yaml") in result + assert str(root / ".hidden.yaml") not in result + assert str(root / ".backup.yml") not in result + + +def test_filter_yaml_files_basic() -> None: + """Test filter_yaml_files function.""" + files = [ + "/path/to/config.yaml", + "/path/to/device.yml", + "/path/to/readme.txt", + "/path/to/script.py", + "/path/to/data.json", + "/path/to/another.yaml", + ] + + result = util.filter_yaml_files(files) + + assert len(result) == 3 + assert "/path/to/config.yaml" in result + assert "/path/to/device.yml" in result + assert "/path/to/another.yaml" in result + assert "/path/to/readme.txt" not in result + assert "/path/to/script.py" not in result + assert "/path/to/data.json" not in result + + +def test_filter_yaml_files_excludes_secrets() -> None: + """Test that filter_yaml_files excludes secrets files.""" + files = [ + "/path/to/config.yaml", + "/path/to/secrets.yaml", + "/path/to/secrets.yml", + "/path/to/device.yaml", + "/some/dir/secrets.yaml", + ] + + result = util.filter_yaml_files(files) + + assert len(result) == 2 + assert "/path/to/config.yaml" in result + assert "/path/to/device.yaml" in result + assert "/path/to/secrets.yaml" not in result + assert "/path/to/secrets.yml" not in result + assert "/some/dir/secrets.yaml" not in result + + +def test_filter_yaml_files_excludes_hidden() -> None: + """Test that filter_yaml_files excludes hidden files.""" + files = [ + "/path/to/config.yaml", + "/path/to/.hidden.yaml", + "/path/to/.backup.yml", + "/path/to/device.yaml", + "/some/dir/.config.yaml", + ] + + result = util.filter_yaml_files(files) + + assert len(result) == 2 + assert "/path/to/config.yaml" in result + assert "/path/to/device.yaml" in result + assert "/path/to/.hidden.yaml" not in result + assert "/path/to/.backup.yml" not in result + assert "/some/dir/.config.yaml" not in result + + +def test_filter_yaml_files_case_sensitive() -> None: + """Test that filter_yaml_files is case-sensitive for extensions.""" + files = [ + "/path/to/config.yaml", + "/path/to/config.YAML", + "/path/to/config.YML", + "/path/to/config.Yaml", + "/path/to/config.yml", + ] + + result = util.filter_yaml_files(files) + + # Should only match lowercase .yaml and .yml + assert len(result) == 2 + assert "/path/to/config.yaml" in result + assert "/path/to/config.yml" in result + assert "/path/to/config.YAML" not in result + assert "/path/to/config.YML" not in result + assert "/path/to/config.Yaml" not in result diff --git a/tests/unit_tests/test_writer.py b/tests/unit_tests/test_writer.py index f47947ff37..f1f86a322e 100644 --- a/tests/unit_tests/test_writer.py +++ b/tests/unit_tests/test_writer.py @@ -1,13 +1,34 @@ """Test writer module functionality.""" from collections.abc import Callable +from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch import pytest +from esphome.core import EsphomeError from esphome.storage_json import StorageJSON -from esphome.writer import storage_should_clean, update_storage_json +from esphome.writer import ( + CPP_AUTO_GENERATE_BEGIN, + CPP_AUTO_GENERATE_END, + CPP_INCLUDE_BEGIN, + CPP_INCLUDE_END, + GITIGNORE_CONTENT, + clean_build, + clean_cmake_cache, + storage_should_clean, + update_storage_json, + write_cpp, + write_gitignore, +) + + +@pytest.fixture +def mock_copy_src_tree(): + """Mock copy_src_tree to avoid side effects during tests.""" + with patch("esphome.writer.copy_src_tree"): + yield @pytest.fixture @@ -218,3 +239,396 @@ def test_update_storage_json_logging_components_removed( # Verify save was called new_storage.save.assert_called_once_with("/test/path") + + +@patch("esphome.writer.CORE") +def test_clean_cmake_cache( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_cmake_cache removes CMakeCache.txt file.""" + # Create directory structure + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + device_dir = pioenvs_dir / "test_device" + device_dir.mkdir() + cmake_cache_file = device_dir / "CMakeCache.txt" + cmake_cache_file.write_text("# CMake cache file") + + # Setup mocks + mock_core.relative_pioenvs_path.side_effect = [ + str(pioenvs_dir), # First call for directory check + str(cmake_cache_file), # Second call for file path + ] + mock_core.name = "test_device" + + # Verify file exists before + assert cmake_cache_file.exists() + + # Call the function + with caplog.at_level("INFO"): + clean_cmake_cache() + + # Verify file was removed + assert not cmake_cache_file.exists() + + # Verify logging + assert "Deleting" in caplog.text + assert "CMakeCache.txt" in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_cmake_cache_no_pioenvs_dir( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test clean_cmake_cache when pioenvs directory doesn't exist.""" + # Setup non-existent directory path + pioenvs_dir = tmp_path / ".pioenvs" + + # Setup mocks + mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + + # Verify directory doesn't exist + assert not pioenvs_dir.exists() + + # Call the function - should not crash + clean_cmake_cache() + + # Verify directory still doesn't exist + assert not pioenvs_dir.exists() + + +@patch("esphome.writer.CORE") +def test_clean_cmake_cache_no_cmake_file( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test clean_cmake_cache when CMakeCache.txt doesn't exist.""" + # Create directory structure without CMakeCache.txt + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + device_dir = pioenvs_dir / "test_device" + device_dir.mkdir() + cmake_cache_file = device_dir / "CMakeCache.txt" + + # Setup mocks + mock_core.relative_pioenvs_path.side_effect = [ + str(pioenvs_dir), # First call for directory check + str(cmake_cache_file), # Second call for file path + ] + mock_core.name = "test_device" + + # Verify file doesn't exist + assert not cmake_cache_file.exists() + + # Call the function - should not crash + clean_cmake_cache() + + # Verify file still doesn't exist + assert not cmake_cache_file.exists() + + +@patch("esphome.writer.CORE") +def test_clean_build( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_build removes all build artifacts.""" + # Create directory structure and files + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + (pioenvs_dir / "test_file.o").write_text("object file") + + piolibdeps_dir = tmp_path / ".piolibdeps" + piolibdeps_dir.mkdir() + (piolibdeps_dir / "library").mkdir() + + dependencies_lock = tmp_path / "dependencies.lock" + dependencies_lock.write_text("lock file") + + # Setup mocks + mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) + mock_core.relative_build_path.return_value = str(dependencies_lock) + + # Verify all exist before + assert pioenvs_dir.exists() + assert piolibdeps_dir.exists() + assert dependencies_lock.exists() + + # Call the function + with caplog.at_level("INFO"): + clean_build() + + # Verify all were removed + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Verify logging + assert "Deleting" in caplog.text + assert ".pioenvs" in caplog.text + assert ".piolibdeps" in caplog.text + assert "dependencies.lock" in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_build_partial_exists( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_build when only some paths exist.""" + # Create only pioenvs directory + pioenvs_dir = tmp_path / ".pioenvs" + pioenvs_dir.mkdir() + (pioenvs_dir / "test_file.o").write_text("object file") + + piolibdeps_dir = tmp_path / ".piolibdeps" + dependencies_lock = tmp_path / "dependencies.lock" + + # Setup mocks + mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) + mock_core.relative_build_path.return_value = str(dependencies_lock) + + # Verify only pioenvs exists + assert pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Call the function + with caplog.at_level("INFO"): + clean_build() + + # Verify only existing path was removed + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Verify logging - only pioenvs should be logged + assert "Deleting" in caplog.text + assert ".pioenvs" in caplog.text + assert ".piolibdeps" not in caplog.text + assert "dependencies.lock" not in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_build_nothing_exists( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test clean_build when no build artifacts exist.""" + # Setup paths that don't exist + pioenvs_dir = tmp_path / ".pioenvs" + piolibdeps_dir = tmp_path / ".piolibdeps" + dependencies_lock = tmp_path / "dependencies.lock" + + # Setup mocks + mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) + mock_core.relative_build_path.return_value = str(dependencies_lock) + + # Verify nothing exists + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + # Call the function - should not crash + clean_build() + + # Verify nothing was created + assert not pioenvs_dir.exists() + assert not piolibdeps_dir.exists() + assert not dependencies_lock.exists() + + +@patch("esphome.writer.CORE") +def test_write_gitignore_creates_new_file( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_gitignore creates a new .gitignore file when it doesn't exist.""" + gitignore_path = tmp_path / ".gitignore" + + # Setup mocks + mock_core.relative_config_path.return_value = str(gitignore_path) + + # Verify file doesn't exist + assert not gitignore_path.exists() + + # Call the function + write_gitignore() + + # Verify file was created with correct content + assert gitignore_path.exists() + assert gitignore_path.read_text() == GITIGNORE_CONTENT + + +@patch("esphome.writer.CORE") +def test_write_gitignore_skips_existing_file( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_gitignore doesn't overwrite existing .gitignore file.""" + gitignore_path = tmp_path / ".gitignore" + existing_content = "# Custom gitignore\n/custom_dir/\n" + gitignore_path.write_text(existing_content) + + # Setup mocks + mock_core.relative_config_path.return_value = str(gitignore_path) + + # Verify file exists with custom content + assert gitignore_path.exists() + assert gitignore_path.read_text() == existing_content + + # Call the function + write_gitignore() + + # Verify file was not modified + assert gitignore_path.exists() + assert gitignore_path.read_text() == existing_content + + +@patch("esphome.writer.write_file_if_changed") # Mock to capture output +@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex +@patch("esphome.writer.CORE") +def test_write_cpp_with_existing_file( + mock_core: MagicMock, + mock_copy_src_tree: MagicMock, + mock_write_file: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp when main.cpp already exists.""" + # Create a real file with markers + main_cpp = tmp_path / "main.cpp" + existing_content = f"""#include "esphome.h" +{CPP_INCLUDE_BEGIN} +// Old includes +{CPP_INCLUDE_END} +void setup() {{ +{CPP_AUTO_GENERATE_BEGIN} +// Old code +{CPP_AUTO_GENERATE_END} +}} +void loop() {{}}""" + main_cpp.write_text(existing_content) + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.cpp_global_section = "// Global section" + + # Call the function + test_code = " // New generated code" + write_cpp(test_code) + + # Verify copy_src_tree was called + mock_copy_src_tree.assert_called_once() + + # Get the content that would be written + mock_write_file.assert_called_once() + written_path, written_content = mock_write_file.call_args[0] + + # Check that markers are preserved and content is updated + assert CPP_INCLUDE_BEGIN in written_content + assert CPP_INCLUDE_END in written_content + assert CPP_AUTO_GENERATE_BEGIN in written_content + assert CPP_AUTO_GENERATE_END in written_content + assert test_code in written_content + assert "// Global section" in written_content + + +@patch("esphome.writer.write_file_if_changed") # Mock to capture output +@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex +@patch("esphome.writer.CORE") +def test_write_cpp_creates_new_file( + mock_core: MagicMock, + mock_copy_src_tree: MagicMock, + mock_write_file: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp when main.cpp doesn't exist.""" + # Setup path for new file + main_cpp = tmp_path / "main.cpp" + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.cpp_global_section = "// Global section" + + # Verify file doesn't exist + assert not main_cpp.exists() + + # Call the function + test_code = " // Generated code" + write_cpp(test_code) + + # Verify copy_src_tree was called + mock_copy_src_tree.assert_called_once() + + # Get the content that would be written + mock_write_file.assert_called_once() + written_path, written_content = mock_write_file.call_args[0] + assert written_path == str(main_cpp) + + # Check that all necessary parts are in the new file + assert '#include "esphome.h"' in written_content + assert CPP_INCLUDE_BEGIN in written_content + assert CPP_INCLUDE_END in written_content + assert CPP_AUTO_GENERATE_BEGIN in written_content + assert CPP_AUTO_GENERATE_END in written_content + assert test_code in written_content + assert "void setup()" in written_content + assert "void loop()" in written_content + assert "App.setup();" in written_content + assert "App.loop();" in written_content + + +@pytest.mark.usefixtures("mock_copy_src_tree") +@patch("esphome.writer.CORE") +def test_write_cpp_with_missing_end_marker( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp raises error when end marker is missing.""" + # Create a file with begin marker but no end marker + main_cpp = tmp_path / "main.cpp" + existing_content = f"""#include "esphome.h" +{CPP_AUTO_GENERATE_BEGIN} +// Code without end marker""" + main_cpp.write_text(existing_content) + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + + # Call should raise an error + with pytest.raises(EsphomeError, match="Could not find auto generated code end"): + write_cpp("// New code") + + +@pytest.mark.usefixtures("mock_copy_src_tree") +@patch("esphome.writer.CORE") +def test_write_cpp_with_duplicate_markers( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test write_cpp raises error when duplicate markers exist.""" + # Create a file with duplicate begin markers + main_cpp = tmp_path / "main.cpp" + existing_content = f"""#include "esphome.h" +{CPP_AUTO_GENERATE_BEGIN} +// First section +{CPP_AUTO_GENERATE_END} +{CPP_AUTO_GENERATE_BEGIN} +// Duplicate section +{CPP_AUTO_GENERATE_END}""" + main_cpp.write_text(existing_content) + + # Setup mocks + mock_core.relative_src_path.return_value = str(main_cpp) + + # Call should raise an error + with pytest.raises(EsphomeError, match="Found multiple auto generate code begins"): + write_cpp("// New code")