mirror of
https://github.com/esphome/esphome.git
synced 2025-09-15 17:52:19 +01:00
Merge remote-tracking branch 'upstream/dev' into dashboard_dns_lookup_delay
This commit is contained in:
@@ -15,9 +15,11 @@ import argcomplete
|
|||||||
|
|
||||||
from esphome import const, writer, yaml_util
|
from esphome import const, writer, yaml_util
|
||||||
import esphome.codegen as cg
|
import esphome.codegen as cg
|
||||||
|
from esphome.components.mqtt import CONF_DISCOVER_IP
|
||||||
from esphome.config import iter_component_configs, read_config, strip_default_ids
|
from esphome.config import iter_component_configs, read_config, strip_default_ids
|
||||||
from esphome.const import (
|
from esphome.const import (
|
||||||
ALLOWED_NAME_CHARS,
|
ALLOWED_NAME_CHARS,
|
||||||
|
CONF_API,
|
||||||
CONF_BAUD_RATE,
|
CONF_BAUD_RATE,
|
||||||
CONF_BROKER,
|
CONF_BROKER,
|
||||||
CONF_DEASSERT_RTS_DTR,
|
CONF_DEASSERT_RTS_DTR,
|
||||||
@@ -43,6 +45,7 @@ from esphome.const import (
|
|||||||
SECRETS_FILES,
|
SECRETS_FILES,
|
||||||
)
|
)
|
||||||
from esphome.core import CORE, EsphomeError, coroutine
|
from esphome.core import CORE, EsphomeError, coroutine
|
||||||
|
from esphome.enum import StrEnum
|
||||||
from esphome.helpers import get_bool_env, indent, is_ip_address
|
from esphome.helpers import get_bool_env, indent, is_ip_address
|
||||||
from esphome.log import AnsiFore, color, setup_log
|
from esphome.log import AnsiFore, color, setup_log
|
||||||
from esphome.types import ConfigType
|
from esphome.types import ConfigType
|
||||||
@@ -106,13 +109,15 @@ def choose_prompt(options, purpose: str = None):
|
|||||||
return options[opt - 1][1]
|
return options[opt - 1][1]
|
||||||
|
|
||||||
|
|
||||||
|
class Purpose(StrEnum):
|
||||||
|
UPLOADING = "uploading"
|
||||||
|
LOGGING = "logging"
|
||||||
|
|
||||||
|
|
||||||
def choose_upload_log_host(
|
def choose_upload_log_host(
|
||||||
default: list[str] | str | None,
|
default: list[str] | str | None,
|
||||||
check_default: str | None,
|
check_default: str | None,
|
||||||
show_ota: bool,
|
purpose: Purpose,
|
||||||
show_mqtt: bool,
|
|
||||||
show_api: bool,
|
|
||||||
purpose: str | None = None,
|
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# Convert to list for uniform handling
|
# Convert to list for uniform handling
|
||||||
defaults = [default] if isinstance(default, str) else default or []
|
defaults = [default] if isinstance(default, str) else default or []
|
||||||
@@ -132,9 +137,10 @@ def choose_upload_log_host(
|
|||||||
]
|
]
|
||||||
resolved.append(choose_prompt(options, purpose=purpose))
|
resolved.append(choose_prompt(options, purpose=purpose))
|
||||||
elif device == "OTA":
|
elif device == "OTA":
|
||||||
if CORE.address and (
|
# ensure IP adresses are used first
|
||||||
(show_ota and "ota" in CORE.config)
|
if is_ip_address(CORE.address) and (
|
||||||
or (show_api and "api" in CORE.config)
|
(purpose == Purpose.LOGGING and has_api())
|
||||||
|
or (purpose == Purpose.UPLOADING and has_ota())
|
||||||
):
|
):
|
||||||
# Check if we have cached addresses for CORE.address
|
# Check if we have cached addresses for CORE.address
|
||||||
if CORE.address_cache and (
|
if CORE.address_cache and (
|
||||||
@@ -144,8 +150,41 @@ def choose_upload_log_host(
|
|||||||
resolved.extend(cached)
|
resolved.extend(cached)
|
||||||
else:
|
else:
|
||||||
resolved.append(CORE.address)
|
resolved.append(CORE.address)
|
||||||
elif show_mqtt and has_mqtt_logging():
|
|
||||||
resolved.append("MQTT")
|
if purpose == Purpose.LOGGING:
|
||||||
|
if has_api() and has_mqtt_ip_lookup():
|
||||||
|
resolved.append("MQTTIP")
|
||||||
|
|
||||||
|
if has_mqtt_logging():
|
||||||
|
resolved.append("MQTT")
|
||||||
|
|
||||||
|
if has_api() and has_non_ip_address():
|
||||||
|
# Check if we have cached addresses for CORE.address
|
||||||
|
if CORE.address_cache and (
|
||||||
|
cached := CORE.address_cache.get_addresses(CORE.address)
|
||||||
|
):
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Using cached addresses for logging: %s", cached
|
||||||
|
)
|
||||||
|
resolved.extend(cached)
|
||||||
|
else:
|
||||||
|
resolved.append(CORE.address)
|
||||||
|
|
||||||
|
elif purpose == Purpose.UPLOADING:
|
||||||
|
if has_ota() and has_mqtt_ip_lookup():
|
||||||
|
resolved.append("MQTTIP")
|
||||||
|
|
||||||
|
if has_ota() and has_non_ip_address():
|
||||||
|
# Check if we have cached addresses for CORE.address
|
||||||
|
if CORE.address_cache and (
|
||||||
|
cached := CORE.address_cache.get_addresses(CORE.address)
|
||||||
|
):
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Using cached addresses for uploading: %s", cached
|
||||||
|
)
|
||||||
|
resolved.extend(cached)
|
||||||
|
else:
|
||||||
|
resolved.append(CORE.address)
|
||||||
else:
|
else:
|
||||||
resolved.append(device)
|
resolved.append(device)
|
||||||
if not resolved:
|
if not resolved:
|
||||||
@@ -156,39 +195,111 @@ def choose_upload_log_host(
|
|||||||
options = [
|
options = [
|
||||||
(f"{port.path} ({port.description})", port.path) for port in get_serial_ports()
|
(f"{port.path} ({port.description})", port.path) for port in get_serial_ports()
|
||||||
]
|
]
|
||||||
if (show_ota and "ota" in CORE.config) or (show_api and "api" in CORE.config):
|
|
||||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
if purpose == Purpose.LOGGING:
|
||||||
if show_mqtt and has_mqtt_logging():
|
if has_mqtt_logging():
|
||||||
mqtt_config = CORE.config[CONF_MQTT]
|
mqtt_config = CORE.config[CONF_MQTT]
|
||||||
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
||||||
|
|
||||||
|
if has_api():
|
||||||
|
if has_resolvable_address():
|
||||||
|
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||||
|
if has_mqtt_ip_lookup():
|
||||||
|
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
||||||
|
|
||||||
|
elif purpose == Purpose.UPLOADING and has_ota():
|
||||||
|
if has_resolvable_address():
|
||||||
|
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||||
|
if has_mqtt_ip_lookup():
|
||||||
|
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
||||||
|
|
||||||
if check_default is not None and check_default in [opt[1] for opt in options]:
|
if check_default is not None and check_default in [opt[1] for opt in options]:
|
||||||
return [check_default]
|
return [check_default]
|
||||||
return [choose_prompt(options, purpose=purpose)]
|
return [choose_prompt(options, purpose=purpose)]
|
||||||
|
|
||||||
|
|
||||||
def mqtt_logging_enabled(mqtt_config):
|
def has_mqtt_logging() -> bool:
|
||||||
|
"""Check if MQTT logging is available."""
|
||||||
|
if CONF_MQTT not in CORE.config:
|
||||||
|
return False
|
||||||
|
|
||||||
|
mqtt_config = CORE.config[CONF_MQTT]
|
||||||
|
|
||||||
|
# enabled by default
|
||||||
|
if CONF_LOG_TOPIC not in mqtt_config:
|
||||||
|
return True
|
||||||
|
|
||||||
log_topic = mqtt_config[CONF_LOG_TOPIC]
|
log_topic = mqtt_config[CONF_LOG_TOPIC]
|
||||||
if log_topic is None:
|
if log_topic is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if CONF_TOPIC not in log_topic:
|
if CONF_TOPIC not in log_topic:
|
||||||
return False
|
return False
|
||||||
return log_topic.get(CONF_LEVEL, None) != "NONE"
|
|
||||||
|
return log_topic[CONF_LEVEL] != "NONE"
|
||||||
|
|
||||||
|
|
||||||
def has_mqtt_logging() -> bool:
|
def has_mqtt() -> bool:
|
||||||
"""Check if MQTT logging is available."""
|
"""Check if MQTT is available."""
|
||||||
return (mqtt_config := CORE.config.get(CONF_MQTT)) and mqtt_logging_enabled(
|
return CONF_MQTT in CORE.config
|
||||||
mqtt_config
|
|
||||||
)
|
|
||||||
|
def has_api() -> bool:
|
||||||
|
"""Check if API is available."""
|
||||||
|
return CONF_API in CORE.config
|
||||||
|
|
||||||
|
|
||||||
|
def has_ota() -> bool:
|
||||||
|
"""Check if OTA is available."""
|
||||||
|
return CONF_OTA in CORE.config
|
||||||
|
|
||||||
|
|
||||||
|
def has_mqtt_ip_lookup() -> bool:
|
||||||
|
"""Check if MQTT is available and IP lookup is supported."""
|
||||||
|
if CONF_MQTT not in CORE.config:
|
||||||
|
return False
|
||||||
|
# Default Enabled
|
||||||
|
if CONF_DISCOVER_IP not in CORE.config[CONF_MQTT]:
|
||||||
|
return True
|
||||||
|
return CORE.config[CONF_MQTT][CONF_DISCOVER_IP]
|
||||||
|
|
||||||
|
|
||||||
|
def has_mdns() -> bool:
|
||||||
|
"""Check if MDNS is available."""
|
||||||
|
return CONF_MDNS not in CORE.config or not CORE.config[CONF_MDNS][CONF_DISABLED]
|
||||||
|
|
||||||
|
|
||||||
|
def has_non_ip_address() -> bool:
|
||||||
|
"""Check if CORE.address is set and is not an IP address."""
|
||||||
|
return CORE.address is not None and not is_ip_address(CORE.address)
|
||||||
|
|
||||||
|
|
||||||
|
def has_ip_address() -> bool:
|
||||||
|
"""Check if CORE.address is a valid IP address."""
|
||||||
|
return CORE.address is not None and is_ip_address(CORE.address)
|
||||||
|
|
||||||
|
|
||||||
|
def has_resolvable_address() -> bool:
|
||||||
|
"""Check if CORE.address is resolvable (via mDNS or is an IP address)."""
|
||||||
|
return has_mdns() or has_ip_address()
|
||||||
|
|
||||||
|
|
||||||
|
def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str):
|
||||||
|
from esphome import mqtt
|
||||||
|
|
||||||
|
return mqtt.get_esphome_device_ip(config, username, password, client_id)
|
||||||
|
|
||||||
|
|
||||||
|
_PORT_TO_PORT_TYPE = {
|
||||||
|
"MQTT": "MQTT",
|
||||||
|
"MQTTIP": "MQTTIP",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_port_type(port: str) -> str:
|
def get_port_type(port: str) -> str:
|
||||||
if port.startswith("/") or port.startswith("COM"):
|
if port.startswith("/") or port.startswith("COM"):
|
||||||
return "SERIAL"
|
return "SERIAL"
|
||||||
if port == "MQTT":
|
return _PORT_TO_PORT_TYPE.get(port, "NETWORK")
|
||||||
return "MQTT"
|
|
||||||
return "NETWORK"
|
|
||||||
|
|
||||||
|
|
||||||
def run_miniterm(config: ConfigType, port: str, args) -> int:
|
def run_miniterm(config: ConfigType, port: str, args) -> int:
|
||||||
@@ -446,23 +557,9 @@ def upload_program(
|
|||||||
password = ota_conf.get(CONF_PASSWORD, "")
|
password = ota_conf.get(CONF_PASSWORD, "")
|
||||||
binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin
|
binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin
|
||||||
|
|
||||||
# Check if we should use MQTT for address resolution
|
# MQTT address resolution
|
||||||
# This happens when no device was specified, or the current host is "MQTT"/"OTA"
|
if get_port_type(host) in ("MQTT", "MQTTIP"):
|
||||||
if (
|
devices = mqtt_get_ip(config, args.username, args.password, args.client_id)
|
||||||
CONF_MQTT in config # pylint: disable=too-many-boolean-expressions
|
|
||||||
and (not devices or host in ("MQTT", "OTA"))
|
|
||||||
and (
|
|
||||||
((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address))
|
|
||||||
or get_port_type(host) == "MQTT"
|
|
||||||
)
|
|
||||||
):
|
|
||||||
from esphome import mqtt
|
|
||||||
|
|
||||||
devices = [
|
|
||||||
mqtt.get_esphome_device_ip(
|
|
||||||
config, args.username, args.password, args.client_id
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return espota2.run_ota(devices, remote_port, password, binary)
|
return espota2.run_ota(devices, remote_port, password, binary)
|
||||||
|
|
||||||
@@ -483,20 +580,28 @@ def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int
|
|||||||
if get_port_type(port) == "SERIAL":
|
if get_port_type(port) == "SERIAL":
|
||||||
check_permissions(port)
|
check_permissions(port)
|
||||||
return run_miniterm(config, port, args)
|
return run_miniterm(config, port, args)
|
||||||
if get_port_type(port) == "NETWORK" and "api" in config:
|
|
||||||
addresses_to_use = devices
|
|
||||||
if config[CONF_MDNS][CONF_DISABLED] and CONF_MQTT in config:
|
|
||||||
from esphome import mqtt
|
|
||||||
|
|
||||||
mqtt_address = mqtt.get_esphome_device_ip(
|
port_type = get_port_type(port)
|
||||||
|
|
||||||
|
# Check if we should use API for logging
|
||||||
|
if has_api():
|
||||||
|
addresses_to_use: list[str] | None = None
|
||||||
|
|
||||||
|
if port_type == "NETWORK" and (has_mdns() or is_ip_address(port)):
|
||||||
|
addresses_to_use = devices
|
||||||
|
elif port_type in ("NETWORK", "MQTT", "MQTTIP") and has_mqtt_ip_lookup():
|
||||||
|
# Only use MQTT IP lookup if the first condition didn't match
|
||||||
|
# (for MQTT/MQTTIP types, or for NETWORK when mdns/ip check fails)
|
||||||
|
addresses_to_use = mqtt_get_ip(
|
||||||
config, args.username, args.password, args.client_id
|
config, args.username, args.password, args.client_id
|
||||||
)[0]
|
)
|
||||||
addresses_to_use = [mqtt_address]
|
|
||||||
|
|
||||||
from esphome.components.api.client import run_logs
|
if addresses_to_use is not None:
|
||||||
|
from esphome.components.api.client import run_logs
|
||||||
|
|
||||||
return run_logs(config, addresses_to_use)
|
return run_logs(config, addresses_to_use)
|
||||||
if get_port_type(port) in ("NETWORK", "MQTT") and "mqtt" in config:
|
|
||||||
|
if port_type in ("NETWORK", "MQTT") and has_mqtt_logging():
|
||||||
from esphome import mqtt
|
from esphome import mqtt
|
||||||
|
|
||||||
return mqtt.show_logs(
|
return mqtt.show_logs(
|
||||||
@@ -562,10 +667,7 @@ def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None:
|
|||||||
devices = choose_upload_log_host(
|
devices = choose_upload_log_host(
|
||||||
default=args.device,
|
default=args.device,
|
||||||
check_default=None,
|
check_default=None,
|
||||||
show_ota=True,
|
purpose=Purpose.UPLOADING,
|
||||||
show_mqtt=False,
|
|
||||||
show_api=False,
|
|
||||||
purpose="uploading",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
exit_code, _ = upload_program(config, args, devices)
|
exit_code, _ = upload_program(config, args, devices)
|
||||||
@@ -590,10 +692,7 @@ def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None:
|
|||||||
devices = choose_upload_log_host(
|
devices = choose_upload_log_host(
|
||||||
default=args.device,
|
default=args.device,
|
||||||
check_default=None,
|
check_default=None,
|
||||||
show_ota=False,
|
purpose=Purpose.LOGGING,
|
||||||
show_mqtt=True,
|
|
||||||
show_api=True,
|
|
||||||
purpose="logging",
|
|
||||||
)
|
)
|
||||||
return show_logs(config, args, devices)
|
return show_logs(config, args, devices)
|
||||||
|
|
||||||
@@ -619,10 +718,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None:
|
|||||||
devices = choose_upload_log_host(
|
devices = choose_upload_log_host(
|
||||||
default=args.device,
|
default=args.device,
|
||||||
check_default=None,
|
check_default=None,
|
||||||
show_ota=True,
|
purpose=Purpose.UPLOADING,
|
||||||
show_mqtt=False,
|
|
||||||
show_api=True,
|
|
||||||
purpose="uploading",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
exit_code, successful_device = upload_program(config, args, devices)
|
exit_code, successful_device = upload_program(config, args, devices)
|
||||||
@@ -639,10 +735,7 @@ def command_run(args: ArgsProtocol, config: ConfigType) -> int | None:
|
|||||||
devices = choose_upload_log_host(
|
devices = choose_upload_log_host(
|
||||||
default=successful_device,
|
default=successful_device,
|
||||||
check_default=successful_device,
|
check_default=successful_device,
|
||||||
show_ota=False,
|
purpose=Purpose.LOGGING,
|
||||||
show_mqtt=True,
|
|
||||||
show_api=True,
|
|
||||||
purpose="logging",
|
|
||||||
)
|
)
|
||||||
return show_logs(config, args, devices)
|
return show_logs(config, args, devices)
|
||||||
|
|
||||||
|
@@ -270,6 +270,7 @@ void PacketTransport::add_binary_data_(uint8_t key, const char *id, bool data) {
|
|||||||
auto len = 1 + 1 + 1 + strlen(id);
|
auto len = 1 + 1 + 1 + strlen(id);
|
||||||
if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) {
|
if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) {
|
||||||
this->flush_();
|
this->flush_();
|
||||||
|
this->init_data_();
|
||||||
}
|
}
|
||||||
add(this->data_, key);
|
add(this->data_, key);
|
||||||
add(this->data_, (uint8_t) data);
|
add(this->data_, (uint8_t) data);
|
||||||
@@ -284,6 +285,7 @@ void PacketTransport::add_data_(uint8_t key, const char *id, uint32_t data) {
|
|||||||
auto len = 4 + 1 + 1 + strlen(id);
|
auto len = 4 + 1 + 1 + strlen(id);
|
||||||
if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) {
|
if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) {
|
||||||
this->flush_();
|
this->flush_();
|
||||||
|
this->init_data_();
|
||||||
}
|
}
|
||||||
add(this->data_, key);
|
add(this->data_, key);
|
||||||
add(this->data_, data);
|
add(this->data_, data);
|
||||||
|
@@ -114,6 +114,7 @@ CONF_AND = "and"
|
|||||||
CONF_ANGLE = "angle"
|
CONF_ANGLE = "angle"
|
||||||
CONF_ANY = "any"
|
CONF_ANY = "any"
|
||||||
CONF_AP = "ap"
|
CONF_AP = "ap"
|
||||||
|
CONF_API = "api"
|
||||||
CONF_APPARENT_POWER = "apparent_power"
|
CONF_APPARENT_POWER = "apparent_power"
|
||||||
CONF_ARDUINO_VERSION = "arduino_version"
|
CONF_ARDUINO_VERSION = "arduino_version"
|
||||||
CONF_AREA = "area"
|
CONF_AREA = "area"
|
||||||
|
@@ -8,7 +8,7 @@ pre-commit
|
|||||||
pytest==8.4.2
|
pytest==8.4.2
|
||||||
pytest-cov==7.0.0
|
pytest-cov==7.0.0
|
||||||
pytest-mock==3.15.0
|
pytest-mock==3.15.0
|
||||||
pytest-asyncio==1.1.0
|
pytest-asyncio==1.2.0
|
||||||
pytest-xdist==3.8.0
|
pytest-xdist==3.8.0
|
||||||
asyncmock==0.4.2
|
asyncmock==0.4.2
|
||||||
hypothesis==6.92.1
|
hypothesis==6.92.1
|
||||||
|
@@ -556,6 +556,66 @@ def test_start_web_server_with_address_port(
|
|||||||
assert (archive_dir / "old.yaml").exists()
|
assert (archive_dir / "old.yaml").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_edit_request_handler_get(dashboard: DashboardTestHelper) -> None:
|
||||||
|
"""Test EditRequestHandler.get method."""
|
||||||
|
# Test getting a valid yaml file
|
||||||
|
response = await dashboard.fetch("/edit?configuration=pico.yaml")
|
||||||
|
assert response.code == 200
|
||||||
|
assert response.headers["content-type"] == "application/yaml"
|
||||||
|
content = response.body.decode()
|
||||||
|
assert "esphome:" in content # Verify it's a valid ESPHome config
|
||||||
|
|
||||||
|
# Test getting a non-existent file
|
||||||
|
with pytest.raises(HTTPClientError) as exc_info:
|
||||||
|
await dashboard.fetch("/edit?configuration=nonexistent.yaml")
|
||||||
|
assert exc_info.value.code == 404
|
||||||
|
|
||||||
|
# Test getting a non-yaml file
|
||||||
|
with pytest.raises(HTTPClientError) as exc_info:
|
||||||
|
await dashboard.fetch("/edit?configuration=test.txt")
|
||||||
|
assert exc_info.value.code == 404
|
||||||
|
|
||||||
|
# Test path traversal attempt
|
||||||
|
with pytest.raises(HTTPClientError) as exc_info:
|
||||||
|
await dashboard.fetch("/edit?configuration=../../../etc/passwd")
|
||||||
|
assert exc_info.value.code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archive_request_handler_post(
|
||||||
|
dashboard: DashboardTestHelper,
|
||||||
|
mock_archive_storage_path: MagicMock,
|
||||||
|
mock_ext_storage_path: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test ArchiveRequestHandler.post method."""
|
||||||
|
|
||||||
|
# Set up temp directories
|
||||||
|
config_dir = Path(get_fixture_path("conf"))
|
||||||
|
archive_dir = tmp_path / "archive"
|
||||||
|
|
||||||
|
# Create a test configuration file
|
||||||
|
test_config = config_dir / "test_archive.yaml"
|
||||||
|
test_config.write_text("esphome:\n name: test_archive\n")
|
||||||
|
|
||||||
|
# Archive the configuration
|
||||||
|
response = await dashboard.fetch(
|
||||||
|
"/archive",
|
||||||
|
method="POST",
|
||||||
|
body="configuration=test_archive.yaml",
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
assert response.code == 200
|
||||||
|
|
||||||
|
# Verify file was moved to archive
|
||||||
|
assert not test_config.exists()
|
||||||
|
assert (archive_dir / "test_archive.yaml").exists()
|
||||||
|
assert (
|
||||||
|
archive_dir / "test_archive.yaml"
|
||||||
|
).read_text() == "esphome:\n name: test_archive\n"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(os.name == "nt", reason="Unix sockets are not supported on Windows")
|
@pytest.mark.skipif(os.name == "nt", reason="Unix sockets are not supported on Windows")
|
||||||
@pytest.mark.usefixtures("mock_trash_storage_path", "mock_archive_storage_path")
|
@pytest.mark.usefixtures("mock_trash_storage_path", "mock_archive_storage_path")
|
||||||
def test_start_web_server_with_unix_socket(tmp_path: Path) -> None:
|
def test_start_web_server_with_unix_socket(tmp_path: Path) -> None:
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -141,3 +141,170 @@ def test_list_yaml_files_mixed_extensions(tmp_path: Path) -> None:
|
|||||||
str(yaml_file),
|
str(yaml_file),
|
||||||
str(yml_file),
|
str(yml_file),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_yaml_files_does_not_recurse_into_subdirectories(tmp_path: Path) -> None:
|
||||||
|
"""Test that list_yaml_files only finds files in specified directory, not subdirectories."""
|
||||||
|
# Create directory structure with YAML files at different depths
|
||||||
|
root = tmp_path / "configs"
|
||||||
|
root.mkdir()
|
||||||
|
|
||||||
|
# Create YAML files in the root directory
|
||||||
|
(root / "config1.yaml").write_text("test: 1")
|
||||||
|
(root / "config2.yml").write_text("test: 2")
|
||||||
|
(root / "device.yaml").write_text("test: device")
|
||||||
|
|
||||||
|
# Create subdirectory with YAML files (should NOT be found)
|
||||||
|
subdir = root / "subdir"
|
||||||
|
subdir.mkdir()
|
||||||
|
(subdir / "nested1.yaml").write_text("test: nested1")
|
||||||
|
(subdir / "nested2.yml").write_text("test: nested2")
|
||||||
|
|
||||||
|
# Create deeper subdirectory (should NOT be found)
|
||||||
|
deep_subdir = subdir / "deeper"
|
||||||
|
deep_subdir.mkdir()
|
||||||
|
(deep_subdir / "very_nested.yaml").write_text("test: very_nested")
|
||||||
|
|
||||||
|
# Test listing files from the root directory
|
||||||
|
result = util.list_yaml_files([str(root)])
|
||||||
|
|
||||||
|
# Should only find the 3 files in root, not the 3 in subdirectories
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
# Check that only root-level files are found
|
||||||
|
assert str(root / "config1.yaml") in result
|
||||||
|
assert str(root / "config2.yml") in result
|
||||||
|
assert str(root / "device.yaml") in result
|
||||||
|
|
||||||
|
# Ensure nested files are NOT found
|
||||||
|
for r in result:
|
||||||
|
assert "subdir" not in r
|
||||||
|
assert "deeper" not in r
|
||||||
|
assert "nested1.yaml" not in r
|
||||||
|
assert "nested2.yml" not in r
|
||||||
|
assert "very_nested.yaml" not in r
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_yaml_files_excludes_secrets(tmp_path: Path) -> None:
|
||||||
|
"""Test that secrets.yaml and secrets.yml are excluded."""
|
||||||
|
root = tmp_path / "configs"
|
||||||
|
root.mkdir()
|
||||||
|
|
||||||
|
# Create various YAML files including secrets
|
||||||
|
(root / "config.yaml").write_text("test: config")
|
||||||
|
(root / "secrets.yaml").write_text("wifi_password: secret123")
|
||||||
|
(root / "secrets.yml").write_text("api_key: secret456")
|
||||||
|
(root / "device.yaml").write_text("test: device")
|
||||||
|
|
||||||
|
result = util.list_yaml_files([str(root)])
|
||||||
|
|
||||||
|
# Should find 2 files (config.yaml and device.yaml), not secrets
|
||||||
|
assert len(result) == 2
|
||||||
|
assert str(root / "config.yaml") in result
|
||||||
|
assert str(root / "device.yaml") in result
|
||||||
|
assert str(root / "secrets.yaml") not in result
|
||||||
|
assert str(root / "secrets.yml") not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_yaml_files_excludes_hidden_files(tmp_path: Path) -> None:
|
||||||
|
"""Test that hidden files (starting with .) are excluded."""
|
||||||
|
root = tmp_path / "configs"
|
||||||
|
root.mkdir()
|
||||||
|
|
||||||
|
# Create regular and hidden YAML files
|
||||||
|
(root / "config.yaml").write_text("test: config")
|
||||||
|
(root / ".hidden.yaml").write_text("test: hidden")
|
||||||
|
(root / ".backup.yml").write_text("test: backup")
|
||||||
|
(root / "device.yaml").write_text("test: device")
|
||||||
|
|
||||||
|
result = util.list_yaml_files([str(root)])
|
||||||
|
|
||||||
|
# Should find only non-hidden files
|
||||||
|
assert len(result) == 2
|
||||||
|
assert str(root / "config.yaml") in result
|
||||||
|
assert str(root / "device.yaml") in result
|
||||||
|
assert str(root / ".hidden.yaml") not in result
|
||||||
|
assert str(root / ".backup.yml") not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_yaml_files_basic() -> None:
|
||||||
|
"""Test filter_yaml_files function."""
|
||||||
|
files = [
|
||||||
|
"/path/to/config.yaml",
|
||||||
|
"/path/to/device.yml",
|
||||||
|
"/path/to/readme.txt",
|
||||||
|
"/path/to/script.py",
|
||||||
|
"/path/to/data.json",
|
||||||
|
"/path/to/another.yaml",
|
||||||
|
]
|
||||||
|
|
||||||
|
result = util.filter_yaml_files(files)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
assert "/path/to/config.yaml" in result
|
||||||
|
assert "/path/to/device.yml" in result
|
||||||
|
assert "/path/to/another.yaml" in result
|
||||||
|
assert "/path/to/readme.txt" not in result
|
||||||
|
assert "/path/to/script.py" not in result
|
||||||
|
assert "/path/to/data.json" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_yaml_files_excludes_secrets() -> None:
|
||||||
|
"""Test that filter_yaml_files excludes secrets files."""
|
||||||
|
files = [
|
||||||
|
"/path/to/config.yaml",
|
||||||
|
"/path/to/secrets.yaml",
|
||||||
|
"/path/to/secrets.yml",
|
||||||
|
"/path/to/device.yaml",
|
||||||
|
"/some/dir/secrets.yaml",
|
||||||
|
]
|
||||||
|
|
||||||
|
result = util.filter_yaml_files(files)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "/path/to/config.yaml" in result
|
||||||
|
assert "/path/to/device.yaml" in result
|
||||||
|
assert "/path/to/secrets.yaml" not in result
|
||||||
|
assert "/path/to/secrets.yml" not in result
|
||||||
|
assert "/some/dir/secrets.yaml" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_yaml_files_excludes_hidden() -> None:
|
||||||
|
"""Test that filter_yaml_files excludes hidden files."""
|
||||||
|
files = [
|
||||||
|
"/path/to/config.yaml",
|
||||||
|
"/path/to/.hidden.yaml",
|
||||||
|
"/path/to/.backup.yml",
|
||||||
|
"/path/to/device.yaml",
|
||||||
|
"/some/dir/.config.yaml",
|
||||||
|
]
|
||||||
|
|
||||||
|
result = util.filter_yaml_files(files)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "/path/to/config.yaml" in result
|
||||||
|
assert "/path/to/device.yaml" in result
|
||||||
|
assert "/path/to/.hidden.yaml" not in result
|
||||||
|
assert "/path/to/.backup.yml" not in result
|
||||||
|
assert "/some/dir/.config.yaml" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_yaml_files_case_sensitive() -> None:
|
||||||
|
"""Test that filter_yaml_files is case-sensitive for extensions."""
|
||||||
|
files = [
|
||||||
|
"/path/to/config.yaml",
|
||||||
|
"/path/to/config.YAML",
|
||||||
|
"/path/to/config.YML",
|
||||||
|
"/path/to/config.Yaml",
|
||||||
|
"/path/to/config.yml",
|
||||||
|
]
|
||||||
|
|
||||||
|
result = util.filter_yaml_files(files)
|
||||||
|
|
||||||
|
# Should only match lowercase .yaml and .yml
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "/path/to/config.yaml" in result
|
||||||
|
assert "/path/to/config.yml" in result
|
||||||
|
assert "/path/to/config.YAML" not in result
|
||||||
|
assert "/path/to/config.YML" not in result
|
||||||
|
assert "/path/to/config.Yaml" not in result
|
||||||
|
@@ -1,13 +1,34 @@
|
|||||||
"""Test writer module functionality."""
|
"""Test writer module functionality."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from esphome.core import EsphomeError
|
||||||
from esphome.storage_json import StorageJSON
|
from esphome.storage_json import StorageJSON
|
||||||
from esphome.writer import storage_should_clean, update_storage_json
|
from esphome.writer import (
|
||||||
|
CPP_AUTO_GENERATE_BEGIN,
|
||||||
|
CPP_AUTO_GENERATE_END,
|
||||||
|
CPP_INCLUDE_BEGIN,
|
||||||
|
CPP_INCLUDE_END,
|
||||||
|
GITIGNORE_CONTENT,
|
||||||
|
clean_build,
|
||||||
|
clean_cmake_cache,
|
||||||
|
storage_should_clean,
|
||||||
|
update_storage_json,
|
||||||
|
write_cpp,
|
||||||
|
write_gitignore,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_copy_src_tree():
|
||||||
|
"""Mock copy_src_tree to avoid side effects during tests."""
|
||||||
|
with patch("esphome.writer.copy_src_tree"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -218,3 +239,396 @@ def test_update_storage_json_logging_components_removed(
|
|||||||
|
|
||||||
# Verify save was called
|
# Verify save was called
|
||||||
new_storage.save.assert_called_once_with("/test/path")
|
new_storage.save.assert_called_once_with("/test/path")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_clean_cmake_cache(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test clean_cmake_cache removes CMakeCache.txt file."""
|
||||||
|
# Create directory structure
|
||||||
|
pioenvs_dir = tmp_path / ".pioenvs"
|
||||||
|
pioenvs_dir.mkdir()
|
||||||
|
device_dir = pioenvs_dir / "test_device"
|
||||||
|
device_dir.mkdir()
|
||||||
|
cmake_cache_file = device_dir / "CMakeCache.txt"
|
||||||
|
cmake_cache_file.write_text("# CMake cache file")
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_pioenvs_path.side_effect = [
|
||||||
|
str(pioenvs_dir), # First call for directory check
|
||||||
|
str(cmake_cache_file), # Second call for file path
|
||||||
|
]
|
||||||
|
mock_core.name = "test_device"
|
||||||
|
|
||||||
|
# Verify file exists before
|
||||||
|
assert cmake_cache_file.exists()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
with caplog.at_level("INFO"):
|
||||||
|
clean_cmake_cache()
|
||||||
|
|
||||||
|
# Verify file was removed
|
||||||
|
assert not cmake_cache_file.exists()
|
||||||
|
|
||||||
|
# Verify logging
|
||||||
|
assert "Deleting" in caplog.text
|
||||||
|
assert "CMakeCache.txt" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_clean_cmake_cache_no_pioenvs_dir(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test clean_cmake_cache when pioenvs directory doesn't exist."""
|
||||||
|
# Setup non-existent directory path
|
||||||
|
pioenvs_dir = tmp_path / ".pioenvs"
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
|
||||||
|
|
||||||
|
# Verify directory doesn't exist
|
||||||
|
assert not pioenvs_dir.exists()
|
||||||
|
|
||||||
|
# Call the function - should not crash
|
||||||
|
clean_cmake_cache()
|
||||||
|
|
||||||
|
# Verify directory still doesn't exist
|
||||||
|
assert not pioenvs_dir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_clean_cmake_cache_no_cmake_file(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test clean_cmake_cache when CMakeCache.txt doesn't exist."""
|
||||||
|
# Create directory structure without CMakeCache.txt
|
||||||
|
pioenvs_dir = tmp_path / ".pioenvs"
|
||||||
|
pioenvs_dir.mkdir()
|
||||||
|
device_dir = pioenvs_dir / "test_device"
|
||||||
|
device_dir.mkdir()
|
||||||
|
cmake_cache_file = device_dir / "CMakeCache.txt"
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_pioenvs_path.side_effect = [
|
||||||
|
str(pioenvs_dir), # First call for directory check
|
||||||
|
str(cmake_cache_file), # Second call for file path
|
||||||
|
]
|
||||||
|
mock_core.name = "test_device"
|
||||||
|
|
||||||
|
# Verify file doesn't exist
|
||||||
|
assert not cmake_cache_file.exists()
|
||||||
|
|
||||||
|
# Call the function - should not crash
|
||||||
|
clean_cmake_cache()
|
||||||
|
|
||||||
|
# Verify file still doesn't exist
|
||||||
|
assert not cmake_cache_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_clean_build(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test clean_build removes all build artifacts."""
|
||||||
|
# Create directory structure and files
|
||||||
|
pioenvs_dir = tmp_path / ".pioenvs"
|
||||||
|
pioenvs_dir.mkdir()
|
||||||
|
(pioenvs_dir / "test_file.o").write_text("object file")
|
||||||
|
|
||||||
|
piolibdeps_dir = tmp_path / ".piolibdeps"
|
||||||
|
piolibdeps_dir.mkdir()
|
||||||
|
(piolibdeps_dir / "library").mkdir()
|
||||||
|
|
||||||
|
dependencies_lock = tmp_path / "dependencies.lock"
|
||||||
|
dependencies_lock.write_text("lock file")
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
|
||||||
|
mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir)
|
||||||
|
mock_core.relative_build_path.return_value = str(dependencies_lock)
|
||||||
|
|
||||||
|
# Verify all exist before
|
||||||
|
assert pioenvs_dir.exists()
|
||||||
|
assert piolibdeps_dir.exists()
|
||||||
|
assert dependencies_lock.exists()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
with caplog.at_level("INFO"):
|
||||||
|
clean_build()
|
||||||
|
|
||||||
|
# Verify all were removed
|
||||||
|
assert not pioenvs_dir.exists()
|
||||||
|
assert not piolibdeps_dir.exists()
|
||||||
|
assert not dependencies_lock.exists()
|
||||||
|
|
||||||
|
# Verify logging
|
||||||
|
assert "Deleting" in caplog.text
|
||||||
|
assert ".pioenvs" in caplog.text
|
||||||
|
assert ".piolibdeps" in caplog.text
|
||||||
|
assert "dependencies.lock" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_clean_build_partial_exists(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test clean_build when only some paths exist."""
|
||||||
|
# Create only pioenvs directory
|
||||||
|
pioenvs_dir = tmp_path / ".pioenvs"
|
||||||
|
pioenvs_dir.mkdir()
|
||||||
|
(pioenvs_dir / "test_file.o").write_text("object file")
|
||||||
|
|
||||||
|
piolibdeps_dir = tmp_path / ".piolibdeps"
|
||||||
|
dependencies_lock = tmp_path / "dependencies.lock"
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
|
||||||
|
mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir)
|
||||||
|
mock_core.relative_build_path.return_value = str(dependencies_lock)
|
||||||
|
|
||||||
|
# Verify only pioenvs exists
|
||||||
|
assert pioenvs_dir.exists()
|
||||||
|
assert not piolibdeps_dir.exists()
|
||||||
|
assert not dependencies_lock.exists()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
with caplog.at_level("INFO"):
|
||||||
|
clean_build()
|
||||||
|
|
||||||
|
# Verify only existing path was removed
|
||||||
|
assert not pioenvs_dir.exists()
|
||||||
|
assert not piolibdeps_dir.exists()
|
||||||
|
assert not dependencies_lock.exists()
|
||||||
|
|
||||||
|
# Verify logging - only pioenvs should be logged
|
||||||
|
assert "Deleting" in caplog.text
|
||||||
|
assert ".pioenvs" in caplog.text
|
||||||
|
assert ".piolibdeps" not in caplog.text
|
||||||
|
assert "dependencies.lock" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_clean_build_nothing_exists(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test clean_build when no build artifacts exist."""
|
||||||
|
# Setup paths that don't exist
|
||||||
|
pioenvs_dir = tmp_path / ".pioenvs"
|
||||||
|
piolibdeps_dir = tmp_path / ".piolibdeps"
|
||||||
|
dependencies_lock = tmp_path / "dependencies.lock"
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir)
|
||||||
|
mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir)
|
||||||
|
mock_core.relative_build_path.return_value = str(dependencies_lock)
|
||||||
|
|
||||||
|
# Verify nothing exists
|
||||||
|
assert not pioenvs_dir.exists()
|
||||||
|
assert not piolibdeps_dir.exists()
|
||||||
|
assert not dependencies_lock.exists()
|
||||||
|
|
||||||
|
# Call the function - should not crash
|
||||||
|
clean_build()
|
||||||
|
|
||||||
|
# Verify nothing was created
|
||||||
|
assert not pioenvs_dir.exists()
|
||||||
|
assert not piolibdeps_dir.exists()
|
||||||
|
assert not dependencies_lock.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_write_gitignore_creates_new_file(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test write_gitignore creates a new .gitignore file when it doesn't exist."""
|
||||||
|
gitignore_path = tmp_path / ".gitignore"
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_config_path.return_value = str(gitignore_path)
|
||||||
|
|
||||||
|
# Verify file doesn't exist
|
||||||
|
assert not gitignore_path.exists()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
write_gitignore()
|
||||||
|
|
||||||
|
# Verify file was created with correct content
|
||||||
|
assert gitignore_path.exists()
|
||||||
|
assert gitignore_path.read_text() == GITIGNORE_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_write_gitignore_skips_existing_file(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test write_gitignore doesn't overwrite existing .gitignore file."""
|
||||||
|
gitignore_path = tmp_path / ".gitignore"
|
||||||
|
existing_content = "# Custom gitignore\n/custom_dir/\n"
|
||||||
|
gitignore_path.write_text(existing_content)
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_config_path.return_value = str(gitignore_path)
|
||||||
|
|
||||||
|
# Verify file exists with custom content
|
||||||
|
assert gitignore_path.exists()
|
||||||
|
assert gitignore_path.read_text() == existing_content
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
write_gitignore()
|
||||||
|
|
||||||
|
# Verify file was not modified
|
||||||
|
assert gitignore_path.exists()
|
||||||
|
assert gitignore_path.read_text() == existing_content
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.write_file_if_changed") # Mock to capture output
|
||||||
|
@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_write_cpp_with_existing_file(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
mock_copy_src_tree: MagicMock,
|
||||||
|
mock_write_file: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test write_cpp when main.cpp already exists."""
|
||||||
|
# Create a real file with markers
|
||||||
|
main_cpp = tmp_path / "main.cpp"
|
||||||
|
existing_content = f"""#include "esphome.h"
|
||||||
|
{CPP_INCLUDE_BEGIN}
|
||||||
|
// Old includes
|
||||||
|
{CPP_INCLUDE_END}
|
||||||
|
void setup() {{
|
||||||
|
{CPP_AUTO_GENERATE_BEGIN}
|
||||||
|
// Old code
|
||||||
|
{CPP_AUTO_GENERATE_END}
|
||||||
|
}}
|
||||||
|
void loop() {{}}"""
|
||||||
|
main_cpp.write_text(existing_content)
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_src_path.return_value = str(main_cpp)
|
||||||
|
mock_core.cpp_global_section = "// Global section"
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
test_code = " // New generated code"
|
||||||
|
write_cpp(test_code)
|
||||||
|
|
||||||
|
# Verify copy_src_tree was called
|
||||||
|
mock_copy_src_tree.assert_called_once()
|
||||||
|
|
||||||
|
# Get the content that would be written
|
||||||
|
mock_write_file.assert_called_once()
|
||||||
|
written_path, written_content = mock_write_file.call_args[0]
|
||||||
|
|
||||||
|
# Check that markers are preserved and content is updated
|
||||||
|
assert CPP_INCLUDE_BEGIN in written_content
|
||||||
|
assert CPP_INCLUDE_END in written_content
|
||||||
|
assert CPP_AUTO_GENERATE_BEGIN in written_content
|
||||||
|
assert CPP_AUTO_GENERATE_END in written_content
|
||||||
|
assert test_code in written_content
|
||||||
|
assert "// Global section" in written_content
|
||||||
|
|
||||||
|
|
||||||
|
@patch("esphome.writer.write_file_if_changed") # Mock to capture output
|
||||||
|
@patch("esphome.writer.copy_src_tree") # Keep this mock as it's complex
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_write_cpp_creates_new_file(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
mock_copy_src_tree: MagicMock,
|
||||||
|
mock_write_file: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test write_cpp when main.cpp doesn't exist."""
|
||||||
|
# Setup path for new file
|
||||||
|
main_cpp = tmp_path / "main.cpp"
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_src_path.return_value = str(main_cpp)
|
||||||
|
mock_core.cpp_global_section = "// Global section"
|
||||||
|
|
||||||
|
# Verify file doesn't exist
|
||||||
|
assert not main_cpp.exists()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
test_code = " // Generated code"
|
||||||
|
write_cpp(test_code)
|
||||||
|
|
||||||
|
# Verify copy_src_tree was called
|
||||||
|
mock_copy_src_tree.assert_called_once()
|
||||||
|
|
||||||
|
# Get the content that would be written
|
||||||
|
mock_write_file.assert_called_once()
|
||||||
|
written_path, written_content = mock_write_file.call_args[0]
|
||||||
|
assert written_path == str(main_cpp)
|
||||||
|
|
||||||
|
# Check that all necessary parts are in the new file
|
||||||
|
assert '#include "esphome.h"' in written_content
|
||||||
|
assert CPP_INCLUDE_BEGIN in written_content
|
||||||
|
assert CPP_INCLUDE_END in written_content
|
||||||
|
assert CPP_AUTO_GENERATE_BEGIN in written_content
|
||||||
|
assert CPP_AUTO_GENERATE_END in written_content
|
||||||
|
assert test_code in written_content
|
||||||
|
assert "void setup()" in written_content
|
||||||
|
assert "void loop()" in written_content
|
||||||
|
assert "App.setup();" in written_content
|
||||||
|
assert "App.loop();" in written_content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_copy_src_tree")
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_write_cpp_with_missing_end_marker(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test write_cpp raises error when end marker is missing."""
|
||||||
|
# Create a file with begin marker but no end marker
|
||||||
|
main_cpp = tmp_path / "main.cpp"
|
||||||
|
existing_content = f"""#include "esphome.h"
|
||||||
|
{CPP_AUTO_GENERATE_BEGIN}
|
||||||
|
// Code without end marker"""
|
||||||
|
main_cpp.write_text(existing_content)
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_src_path.return_value = str(main_cpp)
|
||||||
|
|
||||||
|
# Call should raise an error
|
||||||
|
with pytest.raises(EsphomeError, match="Could not find auto generated code end"):
|
||||||
|
write_cpp("// New code")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_copy_src_tree")
|
||||||
|
@patch("esphome.writer.CORE")
|
||||||
|
def test_write_cpp_with_duplicate_markers(
|
||||||
|
mock_core: MagicMock,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test write_cpp raises error when duplicate markers exist."""
|
||||||
|
# Create a file with duplicate begin markers
|
||||||
|
main_cpp = tmp_path / "main.cpp"
|
||||||
|
existing_content = f"""#include "esphome.h"
|
||||||
|
{CPP_AUTO_GENERATE_BEGIN}
|
||||||
|
// First section
|
||||||
|
{CPP_AUTO_GENERATE_END}
|
||||||
|
{CPP_AUTO_GENERATE_BEGIN}
|
||||||
|
// Duplicate section
|
||||||
|
{CPP_AUTO_GENERATE_END}"""
|
||||||
|
main_cpp.write_text(existing_content)
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_core.relative_src_path.return_value = str(main_cpp)
|
||||||
|
|
||||||
|
# Call should raise an error
|
||||||
|
with pytest.raises(EsphomeError, match="Found multiple auto generate code begins"):
|
||||||
|
write_cpp("// New code")
|
||||||
|
Reference in New Issue
Block a user