Mirror of https://github.com/esphome/esphome.git, synced 2025-09-15 17:52:19 +01:00

Merge remote-tracking branch 'upstream/dev' into zwave_proxy

kbx81 committed 2025-09-13 00:09:57 -05:00
8 changed files with 1914 additions and 173 deletions


@@ -15,9 +15,11 @@ import argcomplete
 from esphome import const, writer, yaml_util
 import esphome.codegen as cg
+from esphome.components.mqtt import CONF_DISCOVER_IP
 from esphome.config import iter_component_configs, read_config, strip_default_ids
 from esphome.const import (
     ALLOWED_NAME_CHARS,
+    CONF_API,
     CONF_BAUD_RATE,
     CONF_BROKER,
     CONF_DEASSERT_RTS_DTR,
@@ -43,6 +45,7 @@ from esphome.const import (
     SECRETS_FILES,
 )
 from esphome.core import CORE, EsphomeError, coroutine
+from esphome.enum import StrEnum
 from esphome.helpers import get_bool_env, indent, is_ip_address
 from esphome.log import AnsiFore, color, setup_log
 from esphome.types import ConfigType
@@ -106,13 +109,15 @@ def choose_prompt(options, purpose: str = None):
     return options[opt - 1][1]


+class Purpose(StrEnum):
+    UPLOADING = "uploading"
+    LOGGING = "logging"
+
+
 def choose_upload_log_host(
     default: list[str] | str | None,
     check_default: str | None,
-    show_ota: bool,
-    show_mqtt: bool,
-    show_api: bool,
-    purpose: str | None = None,
+    purpose: Purpose,
 ) -> list[str]:
     # Convert to list for uniform handling
     defaults = [default] if isinstance(default, str) else default or []
@@ -132,13 +137,30 @@ def choose_upload_log_host(
             ]
             resolved.append(choose_prompt(options, purpose=purpose))
         elif device == "OTA":
-            if CORE.address and (
-                (show_ota and "ota" in CORE.config)
-                or (show_api and "api" in CORE.config)
+            # ensure IP adresses are used first
+            if is_ip_address(CORE.address) and (
+                (purpose == Purpose.LOGGING and has_api())
+                or (purpose == Purpose.UPLOADING and has_ota())
             ):
                 resolved.append(CORE.address)
-            elif show_mqtt and has_mqtt_logging():
-                resolved.append("MQTT")
+
+            if purpose == Purpose.LOGGING:
+                if has_api() and has_mqtt_ip_lookup():
+                    resolved.append("MQTTIP")
+                if has_mqtt_logging():
+                    resolved.append("MQTT")
+                if has_api() and has_non_ip_address():
+                    resolved.append(CORE.address)
+            elif purpose == Purpose.UPLOADING:
+                if has_ota() and has_mqtt_ip_lookup():
+                    resolved.append("MQTTIP")
+                if has_ota() and has_non_ip_address():
+                    resolved.append(CORE.address)
         else:
             resolved.append(device)
     if not resolved:
@@ -149,39 +171,111 @@ def choose_upload_log_host(
     options = [
         (f"{port.path} ({port.description})", port.path) for port in get_serial_ports()
     ]

-    if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config):
-        options.append((f"Over The Air ({CORE.address})", CORE.address))
-    if show_mqtt and has_mqtt_logging():
-        mqtt_config = CORE.config[CONF_MQTT]
-        options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
+    if purpose == Purpose.LOGGING:
+        if has_mqtt_logging():
+            mqtt_config = CORE.config[CONF_MQTT]
+            options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
+
+        if has_api():
+            if has_resolvable_address():
+                options.append((f"Over The Air ({CORE.address})", CORE.address))
+            if has_mqtt_ip_lookup():
+                options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
+
+    elif purpose == Purpose.UPLOADING and has_ota():
+        if has_resolvable_address():
+            options.append((f"Over The Air ({CORE.address})", CORE.address))
+        if has_mqtt_ip_lookup():
+            options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))

     if check_default is not None and check_default in [opt[1] for opt in options]:
         return [check_default]
     return [choose_prompt(options, purpose=purpose)]


-def mqtt_logging_enabled(mqtt_config):
+def has_mqtt_logging() -> bool:
+    """Check if MQTT logging is available."""
+    if CONF_MQTT not in CORE.config:
+        return False
+
+    mqtt_config = CORE.config[CONF_MQTT]
+
+    # enabled by default
+    if CONF_LOG_TOPIC not in mqtt_config:
+        return True
+
     log_topic = mqtt_config[CONF_LOG_TOPIC]
     if log_topic is None:
         return False
     if CONF_TOPIC not in log_topic:
         return False
-    return log_topic[CONF_LEVEL] != "NONE"
+    return log_topic.get(CONF_LEVEL, None) != "NONE"


-def has_mqtt_logging() -> bool:
-    """Check if MQTT logging is available."""
-    return (mqtt_config := CORE.config.get(CONF_MQTT)) and mqtt_logging_enabled(
-        mqtt_config
-    )
+def has_mqtt() -> bool:
+    """Check if MQTT is available."""
+    return CONF_MQTT in CORE.config
+
+
+def has_api() -> bool:
+    """Check if API is available."""
+    return CONF_API in CORE.config
+
+
+def has_ota() -> bool:
+    """Check if OTA is available."""
+    return CONF_OTA in CORE.config
+
+
+def has_mqtt_ip_lookup() -> bool:
+    """Check if MQTT is available and IP lookup is supported."""
+    if CONF_MQTT not in CORE.config:
+        return False
+    # Default Enabled
+    if CONF_DISCOVER_IP not in CORE.config[CONF_MQTT]:
+        return True
+    return CORE.config[CONF_MQTT][CONF_DISCOVER_IP]
+
+
+def has_mdns() -> bool:
+    """Check if MDNS is available."""
+    return CONF_MDNS not in CORE.config or not CORE.config[CONF_MDNS][CONF_DISABLED]
+
+
+def has_non_ip_address() -> bool:
+    """Check if CORE.address is set and is not an IP address."""
+    return CORE.address is not None and not is_ip_address(CORE.address)
+
+
+def has_ip_address() -> bool:
+    """Check if CORE.address is a valid IP address."""
+    return CORE.address is not None and is_ip_address(CORE.address)
+
+
+def has_resolvable_address() -> bool:
+    """Check if CORE.address is resolvable (via mDNS or is an IP address)."""
+    return has_mdns() or has_ip_address()
+
+
+def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str):
+    from esphome import mqtt
+
+    return mqtt.get_esphome_device_ip(config, username, password, client_id)
+
+
+_PORT_TO_PORT_TYPE = {
+    "MQTT": "MQTT",
+    "MQTTIP": "MQTTIP",
+}


 def get_port_type(port: str) -> str:
     if port.startswith("/") or port.startswith("COM"):
         return "SERIAL"
-    if port == "MQTT":
-        return "MQTT"
-    return "NETWORK"
+    return _PORT_TO_PORT_TYPE.get(port, "NETWORK")


 def run_miniterm(config: ConfigType, port: str, args) -> int:
@@ -439,23 +533,9 @@ def upload_program(
     password = ota_conf.get(CONF_PASSWORD, "")
     binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin

-    # Check if we should use MQTT for address resolution
-    # This happens when no device was specified, or the current host is "MQTT"/"OTA"
-    if (
-        CONF_MQTT in config  # pylint: disable=too-many-boolean-expressions
-        and (not devices or host in ("MQTT", "OTA"))
-        and (
-            ((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address))
-            or get_port_type(host) == "MQTT"
-        )
-    ):
-        from esphome import mqtt
-
-        devices = [
-            mqtt.get_esphome_device_ip(
-                config, args.username, args.password, args.client_id
-            )
-        ]
+    # MQTT address resolution
+    if get_port_type(host) in ("MQTT", "MQTTIP"):
+        devices = mqtt_get_ip(config, args.username, args.password, args.client_id)

     return espota2.run_ota(devices, remote_port, password, binary)
@@ -476,20 +556,28 @@ def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int
     if get_port_type(port) == "SERIAL":
         check_permissions(port)
         return run_miniterm(config, port, args)
-    if get_port_type(port) == "NETWORK" and "api" in config:
-        addresses_to_use = devices
-        if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config:
-            from esphome import mqtt
-
-            mqtt_address = mqtt.get_esphome_device_ip(
-                config, args.username, args.password, args.client_id
-            )[0]
-            addresses_to_use = [mqtt_address]
-
-        from esphome.components.api.client import run_logs
-
-        return run_logs(config, addresses_to_use)
-    if get_port_type(port) in ("NETWORK", "MQTT") and "mqtt" in config:
+
+    port_type = get_port_type(port)
+
+    # Check if we should use API for logging
+    if has_api():
+        addresses_to_use: list[str] | None = None
+        if port_type == "NETWORK" and (has_mdns() or is_ip_address(port)):
+            addresses_to_use = devices
+        elif port_type in ("NETWORK", "MQTT", "MQTTIP") and has_mqtt_ip_lookup():
+            # Only use MQTT IP lookup if the first condition didn't match
+            # (for MQTT/MQTTIP types, or for NETWORK when mdns/ip check fails)
+            addresses_to_use = mqtt_get_ip(
+                config, args.username, args.password, args.client_id
+            )
+
+        if addresses_to_use is not None:
+            from esphome.components.api.client import run_logs
+
+            return run_logs(config, addresses_to_use)
+
+    if port_type in ("NETWORK", "MQTT") and has_mqtt_logging():
         from esphome import mqtt

         return mqtt.show_logs(
@@ -555,10 +643,7 @@ def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None:
     devices = choose_upload_log_host(
         default=args.device,
         check_default=None,
-        show_ota=True,
-        show_mqtt=False,
-        show_api=False,
-        purpose="uploading",
+        purpose=Purpose.UPLOADING,
     )

     exit_code, _ = upload_program(config, args, devices)
@@ -583,10 +668,7 @@ def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None:
     devices = choose_upload_log_host(
         default=args.device,
         check_default=None,
-        show_ota=False,
-        show_mqtt=True,
-        show_api=True,
-        purpose="logging",
+        purpose=Purpose.LOGGING,
    )

     return show_logs(config, args, devices)
@@ -612,10 +694,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None:
     devices = choose_upload_log_host(
         default=args.device,
         check_default=None,
-        show_ota=True,
-        show_mqtt=False,
-        show_api=True,
-        purpose="uploading",
+        purpose=Purpose.UPLOADING,
     )

     exit_code, successful_device = upload_program(config, args, devices)
@@ -632,10 +711,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None:
     devices = choose_upload_log_host(
         default=successful_device,
         check_default=successful_device,
-        show_ota=False,
-        show_mqtt=True,
-        show_api=True,
-        purpose="logging",
+        purpose=Purpose.LOGGING,
     )

     return show_logs(config, args, devices)
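Taken together, these CLI changes collapse the old show_ota/show_mqtt/show_api booleans into a single required Purpose argument and route device resolution through the new has_*() helper predicates. Because Purpose derives from StrEnum, its members still behave like the old literal strings ("uploading"/"logging") where they are compared or interpolated, so the prompt text built by choose_prompt() should read the same as before. A small illustrative sketch, not part of this commit:

# Illustrative sketch only; mirrors the definitions shown in the diff above.
from esphome.enum import StrEnum


class Purpose(StrEnum):
    UPLOADING = "uploading"
    LOGGING = "logging"


# Members are real strings, so existing string-based checks keep working:
assert Purpose.UPLOADING == "uploading"
assert isinstance(Purpose.LOGGING, str)

# Call sites now pass the enum instead of three booleans, for example:
#   choose_upload_log_host(default=args.device, check_default=None,
#                          purpose=Purpose.UPLOADING)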


@@ -270,6 +270,7 @@ void PacketTransport::add_binary_data_(uint8_t key, const char *id, bool data) {
   auto len = 1 + 1 + 1 + strlen(id);
   if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) {
     this->flush_();
+    this->init_data_();
   }
   add(this->data_, key);
   add(this->data_, (uint8_t) data);
@@ -284,6 +285,7 @@ void PacketTransport::add_data_(uint8_t key, const char *id, uint32_t data) {
   auto len = 4 + 1 + 1 + strlen(id);
   if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) {
     this->flush_();
+    this->init_data_();
   }
   add(this->data_, key);
   add(this->data_, data);


@@ -114,6 +114,7 @@ CONF_AND = "and"
 CONF_ANGLE = "angle"
 CONF_ANY = "any"
 CONF_AP = "ap"
+CONF_API = "api"
 CONF_APPARENT_POWER = "apparent_power"
 CONF_ARDUINO_VERSION = "arduino_version"
 CONF_AREA = "area"


@@ -8,7 +8,7 @@ pre-commit
 pytest==8.4.2
 pytest-cov==7.0.0
 pytest-mock==3.15.0
-pytest-asyncio==1.1.0
+pytest-asyncio==1.2.0
 pytest-xdist==3.8.0
 asyncmock==0.4.2
 hypothesis==6.92.1


@@ -556,6 +556,66 @@ def test_start_web_server_with_address_port(
assert (archive_dir / "old.yaml").exists()
@pytest.mark.asyncio
async def test_edit_request_handler_get(dashboard: DashboardTestHelper) -> None:
"""Test EditRequestHandler.get method."""
# Test getting a valid yaml file
response = await dashboard.fetch("/edit?configuration=pico.yaml")
assert response.code == 200
assert response.headers["content-type"] == "application/yaml"
content = response.body.decode()
assert "esphome:" in content # Verify it's a valid ESPHome config
# Test getting a non-existent file
with pytest.raises(HTTPClientError) as exc_info:
await dashboard.fetch("/edit?configuration=nonexistent.yaml")
assert exc_info.value.code == 404
# Test getting a non-yaml file
with pytest.raises(HTTPClientError) as exc_info:
await dashboard.fetch("/edit?configuration=test.txt")
assert exc_info.value.code == 404
# Test path traversal attempt
with pytest.raises(HTTPClientError) as exc_info:
await dashboard.fetch("/edit?configuration=../../../etc/passwd")
assert exc_info.value.code == 404
@pytest.mark.asyncio
async def test_archive_request_handler_post(
dashboard: DashboardTestHelper,
mock_archive_storage_path: MagicMock,
mock_ext_storage_path: MagicMock,
tmp_path: Path,
) -> None:
"""Test ArchiveRequestHandler.post method."""
# Set up temp directories
config_dir = Path(get_fixture_path("conf"))
archive_dir = tmp_path / "archive"
# Create a test configuration file
test_config = config_dir / "test_archive.yaml"
test_config.write_text("esphome:\n name: test_archive\n")
# Archive the configuration
response = await dashboard.fetch(
"/archive",
method="POST",
body="configuration=test_archive.yaml",
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
assert response.code == 200
# Verify file was moved to archive
assert not test_config.exists()
assert (archive_dir / "test_archive.yaml").exists()
assert (
archive_dir / "test_archive.yaml"
).read_text() == "esphome:\n name: test_archive\n"
@pytest.mark.skipif(os.name == "nt", reason="Unix sockets are not supported on Windows")
@pytest.mark.usefixtures("mock_trash_storage_path", "mock_archive_storage_path")
def test_start_web_server_with_unix_socket(tmp_path: Path) -> None:

File diff suppressed because it is too large


@@ -141,3 +141,170 @@ def test_list_yaml_files_mixed_extensions(tmp_path: Path) -> None:
str(yaml_file),
str(yml_file),
}
def test_list_yaml_files_does_not_recurse_into_subdirectories(tmp_path: Path) -> None:
"""Test that list_yaml_files only finds files in specified directory, not subdirectories."""
# Create directory structure with YAML files at different depths
root = tmp_path / "configs"
root.mkdir()
# Create YAML files in the root directory
(root / "config1.yaml").write_text("test: 1")
(root / "config2.yml").write_text("test: 2")
(root / "device.yaml").write_text("test: device")
# Create subdirectory with YAML files (should NOT be found)
subdir = root / "subdir"
subdir.mkdir()
(subdir / "nested1.yaml").write_text("test: nested1")
(subdir / "nested2.yml").write_text("test: nested2")
# Create deeper subdirectory (should NOT be found)
deep_subdir = subdir / "deeper"
deep_subdir.mkdir()
(deep_subdir / "very_nested.yaml").write_text("test: very_nested")
# Test listing files from the root directory
result = util.list_yaml_files([str(root)])
# Should only find the 3 files in root, not the 3 in subdirectories
assert len(result) == 3
# Check that only root-level files are found
assert str(root / "config1.yaml") in result
assert str(root / "config2.yml") in result
assert str(root / "device.yaml") in result
# Ensure nested files are NOT found
for r in result:
assert "subdir" not in r
assert "deeper" not in r
assert "nested1.yaml" not in r
assert "nested2.yml" not in r
assert "very_nested.yaml" not in r
def test_list_yaml_files_excludes_secrets(tmp_path: Path) -> None:
"""Test that secrets.yaml and secrets.yml are excluded."""
root = tmp_path / "configs"
root.mkdir()
# Create various YAML files including secrets
(root / "config.yaml").write_text("test: config")
(root / "secrets.yaml").write_text("wifi_password: secret123")
(root / "secrets.yml").write_text("api_key: secret456")
(root / "device.yaml").write_text("test: device")
result = util.list_yaml_files([str(root)])
# Should find 2 files (config.yaml and device.yaml), not secrets
assert len(result) == 2
assert str(root / "config.yaml") in result
assert str(root / "device.yaml") in result
assert str(root / "secrets.yaml") not in result
assert str(root / "secrets.yml") not in result
def test_list_yaml_files_excludes_hidden_files(tmp_path: Path) -> None:
"""Test that hidden files (starting with .) are excluded."""
root = tmp_path / "configs"
root.mkdir()
# Create regular and hidden YAML files
(root / "config.yaml").write_text("test: config")
(root / ".hidden.yaml").write_text("test: hidden")
(root / ".backup.yml").write_text("test: backup")
(root / "device.yaml").write_text("test: device")
result = util.list_yaml_files([str(root)])
# Should find only non-hidden files
assert len(result) == 2
assert str(root / "config.yaml") in result
assert str(root / "device.yaml") in result
assert str(root / ".hidden.yaml") not in result
assert str(root / ".backup.yml") not in result
def test_filter_yaml_files_basic() -> None:
"""Test filter_yaml_files function."""
files = [
"/path/to/config.yaml",
"/path/to/device.yml",
"/path/to/readme.txt",
"/path/to/script.py",
"/path/to/data.json",
"/path/to/another.yaml",
]
result = util.filter_yaml_files(files)
assert len(result) == 3
assert "/path/to/config.yaml" in result
assert "/path/to/device.yml" in result
assert "/path/to/another.yaml" in result
assert "/path/to/readme.txt" not in result
assert "/path/to/script.py" not in result
assert "/path/to/data.json" not in result
def test_filter_yaml_files_excludes_secrets() -> None:
"""Test that filter_yaml_files excludes secrets files."""
files = [
"/path/to/config.yaml",
"/path/to/secrets.yaml",
"/path/to/secrets.yml",
"/path/to/device.yaml",
"/some/dir/secrets.yaml",
]
result = util.filter_yaml_files(files)
assert len(result) == 2
assert "/path/to/config.yaml" in result
assert "/path/to/device.yaml" in result
assert "/path/to/secrets.yaml" not in result
assert "/path/to/secrets.yml" not in result
assert "/some/dir/secrets.yaml" not in result
def test_filter_yaml_files_excludes_hidden() -> None:
"""Test that filter_yaml_files excludes hidden files."""
files = [
"/path/to/config.yaml",
"/path/to/.hidden.yaml",
"/path/to/.backup.yml",
"/path/to/device.yaml",
"/some/dir/.config.yaml",
]
result = util.filter_yaml_files(files)
assert len(result) == 2
assert "/path/to/config.yaml" in result
assert "/path/to/device.yaml" in result
assert "/path/to/.hidden.yaml" not in result
assert "/path/to/.backup.yml" not in result
assert "/some/dir/.config.yaml" not in result
def test_filter_yaml_files_case_sensitive() -> None:
"""Test that filter_yaml_files is case-sensitive for extensions."""
files = [
"/path/to/config.yaml",
"/path/to/config.YAML",
"/path/to/config.YML",
"/path/to/config.Yaml",
"/path/to/config.yml",
]
result = util.filter_yaml_files(files)
# Should only match lowercase .yaml and .yml
assert len(result) == 2
assert "/path/to/config.yaml" in result
assert "/path/to/config.yml" in result
assert "/path/to/config.YAML" not in result
assert "/path/to/config.YML" not in result
assert "/path/to/config.Yaml" not in result


@@ -1,13 +1,34 @@
"""Test writer module functionality.""" """Test writer module functionality."""
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from esphome.core import EsphomeError
from esphome.storage_json import StorageJSON from esphome.storage_json import StorageJSON
from esphome.writer import storage_should_clean, update_storage_json from esphome.writer import (
CPP_AUTO_GENERATE_BEGIN,
CPP_AUTO_GENERATE_END,
CPP_INCLUDE_BEGIN,
CPP_INCLUDE_END,
GITIGNORE_CONTENT,
clean_build,
clean_cmake_cache,
storage_should_clean,
update_storage_json,
write_cpp,
write_gitignore,
)
@pytest.fixture
def mock_copy_src_tree():
"""Mock copy_src_tree to avoid side effects during tests."""
with patch("esphome.writer.copy_src_tree"):
yield
@pytest.fixture @pytest.fixture
@@ -218,3 +239,396 @@ def test_update_storage_json_logging_components_removed(
# Verify save was called
new_storage.save.assert_called_once_with("/test/path")
@patch("esphome.writer.CORE")
def test_clean_cmake_cache(
mock_core: MagicMock,
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test clean_cmake_cache removes CMakeCache.txt file."""
# Create directory structure
pioenvs_dir = tmp_path / ".pioenvs"
pioenvs_dir.mkdir()
device_dir = pioenvs_dir / "test_device"
device_dir.mkdir()
cmake_cache_file = device_dir / "CMakeCache.txt"
cmake_cache_file.write_text("# CMake cache file")
# Setup mocks
mock_core.relative_pioenvs_path.side_effect = [
str(pioenvs_dir), # First call for directory check
str(cmake_cache_file), # Second call for file path
]
mock_core.name = "test_device"
# Verify file exists before
assert cmake_cache_file.exists()
# Call the function
with caplog.at_level("INFO"):
clean_cmake_cache()
# Verify file was removed
assert not cmake_cache_file.exists()
# Verify logging
assert "Deleting" in caplog.text
assert "CMakeCache.txt" in caplog.text
@patch("esphome.writer.CORE")
def test_clean_cmake_cache_no_pioenvs_dir(
mock_core: MagicMock,
tmp_path: Path,
) -> None:
"""Test clean_cmake_cache when pioenvs directory doesn't exist."""
# Setup non-existent directory path
pioenvs_dir = tmp_path / ".pioenvs"
# Setup mocks
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
# Verify directory doesn't exist
assert not pioenvs_dir.exists()
# Call the function - should not crash
clean_cmake_cache()
# Verify directory still doesn't exist
assert not pioenvs_dir.exists()
@patch("esphome.writer.CORE")
def test_clean_cmake_cache_no_cmake_file(
mock_core: MagicMock,
tmp_path: Path,
) -> None:
"""Test clean_cmake_cache when CMakeCache.txt doesn't exist."""
# Create directory structure without CMakeCache.txt
pioenvs_dir = tmp_path / ".pioenvs"
pioenvs_dir.mkdir()
device_dir = pioenvs_dir / "test_device"
device_dir.mkdir()
cmake_cache_file = device_dir / "CMakeCache.txt"
# Setup mocks
mock_core.relative_pioenvs_path.side_effect = [
str(pioenvs_dir), # First call for directory check
str(cmake_cache_file), # Second call for file path
]
mock_core.name = "test_device"
# Verify file doesn't exist
assert not cmake_cache_file.exists()
# Call the function - should not crash
clean_cmake_cache()
# Verify file still doesn't exist
assert not cmake_cache_file.exists()
@patch("esphome.writer.CORE")
def test_clean_build(
mock_core: MagicMock,
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test clean_build removes all build artifacts."""
# Create directory structure and files
pioenvs_dir = tmp_path / ".pioenvs"
pioenvs_dir.mkdir()
(pioenvs_dir / "test_file.o").write_text("object file")
piolibdeps_dir = tmp_path / ".piolibdeps"
piolibdeps_dir.mkdir()
(piolibdeps_dir / "library").mkdir()
dependencies_lock = tmp_path / "dependencies.lock"
dependencies_lock.write_text("lock file")
# Setup mocks
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir)
mock_core.relative_build_path.return_value = str(dependencies_lock)
# Verify all exist before
assert pioenvs_dir.exists()
assert piolibdeps_dir.exists()
assert dependencies_lock.exists()
# Call the function
with caplog.at_level("INFO"):
clean_build()
# Verify all were removed
assert not pioenvs_dir.exists()
assert not piolibdeps_dir.exists()
assert not dependencies_lock.exists()
# Verify logging
assert "Deleting" in caplog.text
assert ".pioenvs" in caplog.text
assert ".piolibdeps" in caplog.text
assert "dependencies.lock" in caplog.text
@patch("esphome.writer.CORE")
def test_clean_build_partial_exists(
mock_core: MagicMock,
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test clean_build when only some paths exist."""
# Create only pioenvs directory
pioenvs_dir = tmp_path / ".pioenvs"
pioenvs_dir.mkdir()
(pioenvs_dir / "test_file.o").write_text("object file")
piolibdeps_dir = tmp_path / ".piolibdeps"
dependencies_lock = tmp_path / "dependencies.lock"
# Setup mocks
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir)
mock_core.relative_build_path.return_value = str(dependencies_lock)
# Verify only pioenvs exists
assert pioenvs_dir.exists()
assert not piolibdeps_dir.exists()
assert not dependencies_lock.exists()
# Call the function
with caplog.at_level("INFO"):
clean_build()
# Verify only existing path was removed
assert not pioenvs_dir.exists()
assert not piolibdeps_dir.exists()
assert not dependencies_lock.exists()
# Verify logging - only pioenvs should be logged
assert "Deleting" in caplog.text
assert ".pioenvs" in caplog.text
assert ".piolibdeps" not in caplog.text
assert "dependencies.lock" not in caplog.text
@patch("esphome.writer.CORE")
def test_clean_build_nothing_exists(
mock_core: MagicMock,
tmp_path: Path,
) -> None:
"""Test clean_build when no build artifacts exist."""
# Setup paths that don't exist
pioenvs_dir = tmp_path / ".pioenvs"
piolibdeps_dir = tmp_path / ".piolibdeps"
dependencies_lock = tmp_path / "dependencies.lock"
# Setup mocks
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir)
mock_core.relative_build_path.return_value = str(dependencies_lock)
# Verify nothing exists
assert not pioenvs_dir.exists()
assert not piolibdeps_dir.exists()
assert not dependencies_lock.exists()
# Call the function - should not crash
clean_build()
# Verify nothing was created
assert not pioenvs_dir.exists()
assert not piolibdeps_dir.exists()
assert not dependencies_lock.exists()
@patch("esphome.writer.CORE")
def test_write_gitignore_creates_new_file(
mock_core: MagicMock,
tmp_path: Path,
) -> None:
"""Test write_gitignore creates a new .gitignore file when it doesn't exist."""
gitignore_path = tmp_path / ".gitignore"
# Setup mocks
mock_core.relative_config_path.return_value = str(gitignore_path)
# Verify file doesn't exist
assert not gitignore_path.exists()
# Call the function
write_gitignore()
# Verify file was created with correct content
assert gitignore_path.exists()
assert gitignore_path.read_text() == GITIGNORE_CONTENT
@patch("esphome.writer.CORE")
def test_write_gitignore_skips_existing_file(
mock_core: MagicMock,
tmp_path: Path,
) -> None:
"""Test write_gitignore doesn't overwrite existing .gitignore file."""
gitignore_path = tmp_path / ".gitignore"
existing_content = "# Custom gitignore\n/custom_dir/\n"
gitignore_path.write_text(existing_content)
# Setup mocks
mock_core.relative_config_path.return_value = str(gitignore_path)
# Verify file exists with custom content
assert gitignore_path.exists()
assert gitignore_path.read_text() == existing_content
# Call the function
write_gitignore()
# Verify file was not modified
assert gitignore_path.exists()
assert gitignore_path.read_text() == existing_content
@patch("esphome.writer.write_file_if_changed") # Mock to capture output
@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex
@patch("esphome.writer.CORE")
def test_write_cpp_with_existing_file(
mock_core: MagicMock,
mock_copy_src_tree: MagicMock,
mock_write_file: MagicMock,
tmp_path: Path,
) -> None:
"""Test write_cpp when main.cpp already exists."""
# Create a real file with markers
main_cpp = tmp_path / "main.cpp"
existing_content = f"""#include "esphome.h"
{CPP_INCLUDE_BEGIN}
// Old includes
{CPP_INCLUDE_END}
void setup() {{
{CPP_AUTO_GENERATE_BEGIN}
// Old code
{CPP_AUTO_GENERATE_END}
}}
void loop() {{}}"""
main_cpp.write_text(existing_content)
# Setup mocks
mock_core.relative_src_path.return_value = str(main_cpp)
mock_core.cpp_global_section = "// Global section"
# Call the function
test_code = " // New generated code"
write_cpp(test_code)
# Verify copy_src_tree was called
mock_copy_src_tree.assert_called_once()
# Get the content that would be written
mock_write_file.assert_called_once()
written_path, written_content = mock_write_file.call_args[0]
# Check that markers are preserved and content is updated
assert CPP_INCLUDE_BEGIN in written_content
assert CPP_INCLUDE_END in written_content
assert CPP_AUTO_GENERATE_BEGIN in written_content
assert CPP_AUTO_GENERATE_END in written_content
assert test_code in written_content
assert "// Global section" in written_content
@patch("esphome.writer.write_file_if_changed") # Mock to capture output
@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex
@patch("esphome.writer.CORE")
def test_write_cpp_creates_new_file(
mock_core: MagicMock,
mock_copy_src_tree: MagicMock,
mock_write_file: MagicMock,
tmp_path: Path,
) -> None:
"""Test write_cpp when main.cpp doesn't exist."""
# Setup path for new file
main_cpp = tmp_path / "main.cpp"
# Setup mocks
mock_core.relative_src_path.return_value = str(main_cpp)
mock_core.cpp_global_section = "// Global section"
# Verify file doesn't exist
assert not main_cpp.exists()
# Call the function
test_code = " // Generated code"
write_cpp(test_code)
# Verify copy_src_tree was called
mock_copy_src_tree.assert_called_once()
# Get the content that would be written
mock_write_file.assert_called_once()
written_path, written_content = mock_write_file.call_args[0]
assert written_path == str(main_cpp)
# Check that all necessary parts are in the new file
assert '#include "esphome.h"' in written_content
assert CPP_INCLUDE_BEGIN in written_content
assert CPP_INCLUDE_END in written_content
assert CPP_AUTO_GENERATE_BEGIN in written_content
assert CPP_AUTO_GENERATE_END in written_content
assert test_code in written_content
assert "void setup()" in written_content
assert "void loop()" in written_content
assert "App.setup();" in written_content
assert "App.loop();" in written_content
@pytest.mark.usefixtures("mock_copy_src_tree")
@patch("esphome.writer.CORE")
def test_write_cpp_with_missing_end_marker(
mock_core: MagicMock,
tmp_path: Path,
) -> None:
"""Test write_cpp raises error when end marker is missing."""
# Create a file with begin marker but no end marker
main_cpp = tmp_path / "main.cpp"
existing_content = f"""#include "esphome.h"
{CPP_AUTO_GENERATE_BEGIN}
// Code without end marker"""
main_cpp.write_text(existing_content)
# Setup mocks
mock_core.relative_src_path.return_value = str(main_cpp)
# Call should raise an error
with pytest.raises(EsphomeError, match="Could not find auto generated code end"):
write_cpp("// New code")
@pytest.mark.usefixtures("mock_copy_src_tree")
@patch("esphome.writer.CORE")
def test_write_cpp_with_duplicate_markers(
mock_core: MagicMock,
tmp_path: Path,
) -> None:
"""Test write_cpp raises error when duplicate markers exist."""
# Create a file with duplicate begin markers
main_cpp = tmp_path / "main.cpp"
existing_content = f"""#include "esphome.h"
{CPP_AUTO_GENERATE_BEGIN}
// First section
{CPP_AUTO_GENERATE_END}
{CPP_AUTO_GENERATE_BEGIN}
// Duplicate section
{CPP_AUTO_GENERATE_END}"""
main_cpp.write_text(existing_content)
# Setup mocks
mock_core.relative_src_path.return_value = str(main_cpp)
# Call should raise an error
with pytest.raises(EsphomeError, match="Found multiple auto generate code begins"):
write_cpp("// New code")