1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-23 21:52:23 +01:00

Merge branch 'integration' into memory_api

This commit is contained in:
J. Nick Koston
2025-09-14 21:56:20 -05:00
34 changed files with 3424 additions and 136 deletions

View File

@@ -114,6 +114,14 @@ class Purpose(StrEnum):
LOGGING = "logging" LOGGING = "logging"
def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]:
"""Resolve an address using cache if available, otherwise return the address itself."""
if CORE.address_cache and (cached := CORE.address_cache.get_addresses(address)):
_LOGGER.debug("Using cached addresses for %s: %s", purpose.value, cached)
return cached
return [address]
def choose_upload_log_host( def choose_upload_log_host(
default: list[str] | str | None, default: list[str] | str | None,
check_default: str | None, check_default: str | None,
@@ -142,7 +150,7 @@ def choose_upload_log_host(
(purpose == Purpose.LOGGING and has_api()) (purpose == Purpose.LOGGING and has_api())
or (purpose == Purpose.UPLOADING and has_ota()) or (purpose == Purpose.UPLOADING and has_ota())
): ):
resolved.append(CORE.address) resolved.extend(_resolve_with_cache(CORE.address, purpose))
if purpose == Purpose.LOGGING: if purpose == Purpose.LOGGING:
if has_api() and has_mqtt_ip_lookup(): if has_api() and has_mqtt_ip_lookup():
@@ -152,15 +160,14 @@ def choose_upload_log_host(
resolved.append("MQTT") resolved.append("MQTT")
if has_api() and has_non_ip_address(): if has_api() and has_non_ip_address():
resolved.append(CORE.address) resolved.extend(_resolve_with_cache(CORE.address, purpose))
elif purpose == Purpose.UPLOADING: elif purpose == Purpose.UPLOADING:
if has_ota() and has_mqtt_ip_lookup(): if has_ota() and has_mqtt_ip_lookup():
resolved.append("MQTTIP") resolved.append("MQTTIP")
if has_ota() and has_non_ip_address(): if has_ota() and has_non_ip_address():
resolved.append(CORE.address) resolved.extend(_resolve_with_cache(CORE.address, purpose))
else: else:
resolved.append(device) resolved.append(device)
if not resolved: if not resolved:
@@ -972,6 +979,18 @@ def parse_args(argv):
help="Add a substitution", help="Add a substitution",
metavar=("key", "value"), metavar=("key", "value"),
) )
options_parser.add_argument(
"--mdns-address-cache",
help="mDNS address cache mapping in format 'hostname=ip1,ip2'",
action="append",
default=[],
)
options_parser.add_argument(
"--dns-address-cache",
help="DNS address cache mapping in format 'hostname=ip1,ip2'",
action="append",
default=[],
)
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=f"ESPHome {const.__version__}", parents=[options_parser] description=f"ESPHome {const.__version__}", parents=[options_parser]
@@ -1230,9 +1249,15 @@ def parse_args(argv):
def run_esphome(argv): def run_esphome(argv):
from esphome.address_cache import AddressCache
args = parse_args(argv) args = parse_args(argv)
CORE.dashboard = args.dashboard CORE.dashboard = args.dashboard
# Create address cache from command-line arguments
CORE.address_cache = AddressCache.from_cli_args(
args.mdns_address_cache, args.dns_address_cache
)
# Override log level if verbose is set # Override log level if verbose is set
if args.verbose: if args.verbose:
args.log_level = "DEBUG" args.log_level = "DEBUG"

142
esphome/address_cache.py Normal file
View File

@@ -0,0 +1,142 @@
"""Address cache for DNS and mDNS lookups."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterable
_LOGGER = logging.getLogger(__name__)
def normalize_hostname(hostname: str) -> str:
"""Normalize hostname for cache lookups.
Removes trailing dots and converts to lowercase.
"""
return hostname.rstrip(".").lower()
class AddressCache:
"""Cache for DNS and mDNS address lookups.
This cache stores pre-resolved addresses from command-line arguments
to avoid slow DNS/mDNS lookups during builds.
"""
def __init__(
self,
mdns_cache: dict[str, list[str]] | None = None,
dns_cache: dict[str, list[str]] | None = None,
) -> None:
"""Initialize the address cache.
Args:
mdns_cache: Pre-populated mDNS addresses (hostname -> IPs)
dns_cache: Pre-populated DNS addresses (hostname -> IPs)
"""
self.mdns_cache = mdns_cache or {}
self.dns_cache = dns_cache or {}
def _get_cached_addresses(
self, hostname: str, cache: dict[str, list[str]], cache_type: str
) -> list[str] | None:
"""Get cached addresses from a specific cache.
Args:
hostname: The hostname to look up
cache: The cache dictionary to check
cache_type: Type of cache for logging ("mDNS" or "DNS")
Returns:
List of IP addresses if found in cache, None otherwise
"""
normalized = normalize_hostname(hostname)
if addresses := cache.get(normalized):
_LOGGER.debug("Using %s cache for %s: %s", cache_type, hostname, addresses)
return addresses
return None
def get_mdns_addresses(self, hostname: str) -> list[str] | None:
"""Get cached mDNS addresses for a hostname.
Args:
hostname: The hostname to look up (should end with .local)
Returns:
List of IP addresses if found in cache, None otherwise
"""
return self._get_cached_addresses(hostname, self.mdns_cache, "mDNS")
def get_dns_addresses(self, hostname: str) -> list[str] | None:
"""Get cached DNS addresses for a hostname.
Args:
hostname: The hostname to look up
Returns:
List of IP addresses if found in cache, None otherwise
"""
return self._get_cached_addresses(hostname, self.dns_cache, "DNS")
def get_addresses(self, hostname: str) -> list[str] | None:
"""Get cached addresses for a hostname.
Checks mDNS cache for .local domains, DNS cache otherwise.
Args:
hostname: The hostname to look up
Returns:
List of IP addresses if found in cache, None otherwise
"""
normalized = normalize_hostname(hostname)
if normalized.endswith(".local"):
return self.get_mdns_addresses(hostname)
return self.get_dns_addresses(hostname)
def has_cache(self) -> bool:
"""Check if any cache entries exist."""
return bool(self.mdns_cache or self.dns_cache)
@classmethod
def from_cli_args(
cls, mdns_args: Iterable[str], dns_args: Iterable[str]
) -> AddressCache:
"""Create cache from command-line arguments.
Args:
mdns_args: List of mDNS cache entries like ['host=ip1,ip2']
dns_args: List of DNS cache entries like ['host=ip1,ip2']
Returns:
Configured AddressCache instance
"""
mdns_cache = cls._parse_cache_args(mdns_args)
dns_cache = cls._parse_cache_args(dns_args)
return cls(mdns_cache=mdns_cache, dns_cache=dns_cache)
@staticmethod
def _parse_cache_args(cache_args: Iterable[str]) -> dict[str, list[str]]:
"""Parse cache arguments into a dictionary.
Args:
cache_args: List of cache mappings like ['host1=ip1,ip2', 'host2=ip3']
Returns:
Dictionary mapping normalized hostnames to list of IP addresses
"""
cache: dict[str, list[str]] = {}
for arg in cache_args:
if "=" not in arg:
_LOGGER.warning(
"Invalid cache format: %s (expected 'hostname=ip1,ip2')", arg
)
continue
hostname, ips = arg.split("=", 1)
# Normalize hostname for consistent lookups
normalized = normalize_hostname(hostname)
cache[normalized] = [ip.strip() for ip in ips.split(",")]
return cache

View File

@@ -7,7 +7,7 @@ service APIConnection {
option (needs_setup_connection) = false; option (needs_setup_connection) = false;
option (needs_authentication) = false; option (needs_authentication) = false;
} }
rpc connect (ConnectRequest) returns (ConnectResponse) { rpc authenticate (AuthenticationRequest) returns (AuthenticationResponse) {
option (needs_setup_connection) = false; option (needs_setup_connection) = false;
option (needs_authentication) = false; option (needs_authentication) = false;
} }
@@ -129,7 +129,7 @@ message HelloResponse {
// Message sent at the beginning of each connection to authenticate the client // Message sent at the beginning of each connection to authenticate the client
// Can only be sent by the client and only at the beginning of the connection // Can only be sent by the client and only at the beginning of the connection
message ConnectRequest { message AuthenticationRequest {
option (id) = 3; option (id) = 3;
option (source) = SOURCE_CLIENT; option (source) = SOURCE_CLIENT;
option (no_delay) = true; option (no_delay) = true;
@@ -141,7 +141,7 @@ message ConnectRequest {
// Confirmation of successful connection. After this the connection is available for all traffic. // Confirmation of successful connection. After this the connection is available for all traffic.
// Can only be sent by the server and only at the beginning of the connection // Can only be sent by the server and only at the beginning of the connection
message ConnectResponse { message AuthenticationResponse {
option (id) = 4; option (id) = 4;
option (source) = SOURCE_SERVER; option (source) = SOURCE_SERVER;
option (no_delay) = true; option (no_delay) = true;

View File

@@ -1387,14 +1387,14 @@ bool APIConnection::send_hello_response(const HelloRequest &msg) {
return this->send_message(resp, HelloResponse::MESSAGE_TYPE); return this->send_message(resp, HelloResponse::MESSAGE_TYPE);
} }
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
bool APIConnection::send_connect_response(const ConnectRequest &msg) { bool APIConnection::send_authenticate_response(const AuthenticationRequest &msg) {
ConnectResponse resp; AuthenticationResponse resp;
// bool invalid_password = 1; // bool invalid_password = 1;
resp.invalid_password = !this->parent_->check_password(msg.password); resp.invalid_password = !this->parent_->check_password(msg.password);
if (!resp.invalid_password) { if (!resp.invalid_password) {
this->complete_authentication_(); this->complete_authentication_();
} }
return this->send_message(resp, ConnectResponse::MESSAGE_TYPE); return this->send_message(resp, AuthenticationResponse::MESSAGE_TYPE);
} }
#endif // USE_API_PASSWORD #endif // USE_API_PASSWORD

View File

@@ -198,7 +198,7 @@ class APIConnection final : public APIServerConnection {
#endif #endif
bool send_hello_response(const HelloRequest &msg) override; bool send_hello_response(const HelloRequest &msg) override;
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
bool send_connect_response(const ConnectRequest &msg) override; bool send_authenticate_response(const AuthenticationRequest &msg) override;
#endif #endif
bool send_disconnect_response(const DisconnectRequest &msg) override; bool send_disconnect_response(const DisconnectRequest &msg) override;
bool send_ping_response(const PingRequest &msg) override; bool send_ping_response(const PingRequest &msg) override;

View File

@@ -43,7 +43,7 @@ void HelloResponse::calculate_size(ProtoSize &size) const {
size.add_length(1, this->name_ref_.size()); size.add_length(1, this->name_ref_.size());
} }
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
bool ConnectRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { bool AuthenticationRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
switch (field_id) { switch (field_id) {
case 1: case 1:
this->password = value.as_string(); this->password = value.as_string();
@@ -53,8 +53,8 @@ bool ConnectRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value
} }
return true; return true;
} }
void ConnectResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->invalid_password); } void AuthenticationResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->invalid_password); }
void ConnectResponse::calculate_size(ProtoSize &size) const { size.add_bool(1, this->invalid_password); } void AuthenticationResponse::calculate_size(ProtoSize &size) const { size.add_bool(1, this->invalid_password); }
#endif #endif
#ifdef USE_AREAS #ifdef USE_AREAS
void AreaInfo::encode(ProtoWriteBuffer buffer) const { void AreaInfo::encode(ProtoWriteBuffer buffer) const {

View File

@@ -361,12 +361,12 @@ class HelloResponse final : public ProtoMessage {
protected: protected:
}; };
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
class ConnectRequest final : public ProtoDecodableMessage { class AuthenticationRequest final : public ProtoDecodableMessage {
public: public:
static constexpr uint8_t MESSAGE_TYPE = 3; static constexpr uint8_t MESSAGE_TYPE = 3;
static constexpr uint8_t ESTIMATED_SIZE = 9; static constexpr uint8_t ESTIMATED_SIZE = 9;
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
const char *message_name() const override { return "connect_request"; } const char *message_name() const override { return "authentication_request"; }
#endif #endif
std::string password{}; std::string password{};
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
@@ -376,12 +376,12 @@ class ConnectRequest final : public ProtoDecodableMessage {
protected: protected:
bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
}; };
class ConnectResponse final : public ProtoMessage { class AuthenticationResponse final : public ProtoMessage {
public: public:
static constexpr uint8_t MESSAGE_TYPE = 4; static constexpr uint8_t MESSAGE_TYPE = 4;
static constexpr uint8_t ESTIMATED_SIZE = 2; static constexpr uint8_t ESTIMATED_SIZE = 2;
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
const char *message_name() const override { return "connect_response"; } const char *message_name() const override { return "authentication_response"; }
#endif #endif
bool invalid_password{false}; bool invalid_password{false};
void encode(ProtoWriteBuffer buffer) const override; void encode(ProtoWriteBuffer buffer) const override;

View File

@@ -670,8 +670,11 @@ void HelloResponse::dump_to(std::string &out) const {
dump_field(out, "name", this->name_ref_); dump_field(out, "name", this->name_ref_);
} }
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
void ConnectRequest::dump_to(std::string &out) const { dump_field(out, "password", this->password); } void AuthenticationRequest::dump_to(std::string &out) const { dump_field(out, "password", this->password); }
void ConnectResponse::dump_to(std::string &out) const { dump_field(out, "invalid_password", this->invalid_password); } void AuthenticationResponse::dump_to(std::string &out) const {
MessageDumpHelper helper(out, "AuthenticationResponse");
dump_field(out, "invalid_password", this->invalid_password);
}
#endif #endif
void DisconnectRequest::dump_to(std::string &out) const { out.append("DisconnectRequest {}"); } void DisconnectRequest::dump_to(std::string &out) const { out.append("DisconnectRequest {}"); }
void DisconnectResponse::dump_to(std::string &out) const { out.append("DisconnectResponse {}"); } void DisconnectResponse::dump_to(std::string &out) const { out.append("DisconnectResponse {}"); }

View File

@@ -25,13 +25,13 @@ void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type,
break; break;
} }
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
case ConnectRequest::MESSAGE_TYPE: { case AuthenticationRequest::MESSAGE_TYPE: {
ConnectRequest msg; AuthenticationRequest msg;
msg.decode(msg_data, msg_size); msg.decode(msg_data, msg_size);
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
ESP_LOGVV(TAG, "on_connect_request: %s", msg.dump().c_str()); ESP_LOGVV(TAG, "on_authentication_request: %s", msg.dump().c_str());
#endif #endif
this->on_connect_request(msg); this->on_authentication_request(msg);
break; break;
} }
#endif #endif
@@ -600,8 +600,8 @@ void APIServerConnection::on_hello_request(const HelloRequest &msg) {
} }
} }
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
void APIServerConnection::on_connect_request(const ConnectRequest &msg) { void APIServerConnection::on_authentication_request(const AuthenticationRequest &msg) {
if (!this->send_connect_response(msg)) { if (!this->send_authenticate_response(msg)) {
this->on_fatal_error(); this->on_fatal_error();
} }
} }

View File

@@ -27,7 +27,7 @@ class APIServerConnectionBase : public ProtoService {
virtual void on_hello_request(const HelloRequest &value){}; virtual void on_hello_request(const HelloRequest &value){};
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
virtual void on_connect_request(const ConnectRequest &value){}; virtual void on_authentication_request(const AuthenticationRequest &value){};
#endif #endif
virtual void on_disconnect_request(const DisconnectRequest &value){}; virtual void on_disconnect_request(const DisconnectRequest &value){};
@@ -216,7 +216,7 @@ class APIServerConnection : public APIServerConnectionBase {
public: public:
virtual bool send_hello_response(const HelloRequest &msg) = 0; virtual bool send_hello_response(const HelloRequest &msg) = 0;
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
virtual bool send_connect_response(const ConnectRequest &msg) = 0; virtual bool send_authenticate_response(const AuthenticationRequest &msg) = 0;
#endif #endif
virtual bool send_disconnect_response(const DisconnectRequest &msg) = 0; virtual bool send_disconnect_response(const DisconnectRequest &msg) = 0;
virtual bool send_ping_response(const PingRequest &msg) = 0; virtual bool send_ping_response(const PingRequest &msg) = 0;
@@ -339,7 +339,7 @@ class APIServerConnection : public APIServerConnectionBase {
protected: protected:
void on_hello_request(const HelloRequest &msg) override; void on_hello_request(const HelloRequest &msg) override;
#ifdef USE_API_PASSWORD #ifdef USE_API_PASSWORD
void on_connect_request(const ConnectRequest &msg) override; void on_authentication_request(const AuthenticationRequest &msg) override;
#endif #endif
void on_disconnect_request(const DisconnectRequest &msg) override; void on_disconnect_request(const DisconnectRequest &msg) override;
void on_ping_request(const PingRequest &msg) override; void on_ping_request(const PingRequest &msg) override;

View File

@@ -18,6 +18,7 @@ from esphome.const import (
DEVICE_CLASS_TEMPERATURE, DEVICE_CLASS_TEMPERATURE,
DEVICE_CLASS_VOLTAGE, DEVICE_CLASS_VOLTAGE,
STATE_CLASS_MEASUREMENT, STATE_CLASS_MEASUREMENT,
STATE_CLASS_TOTAL_INCREASING,
UNIT_AMPERE, UNIT_AMPERE,
UNIT_CELSIUS, UNIT_CELSIUS,
UNIT_VOLT, UNIT_VOLT,
@@ -162,7 +163,7 @@ INA2XX_SCHEMA = cv.Schema(
unit_of_measurement=UNIT_WATT_HOURS, unit_of_measurement=UNIT_WATT_HOURS,
accuracy_decimals=8, accuracy_decimals=8,
device_class=DEVICE_CLASS_ENERGY, device_class=DEVICE_CLASS_ENERGY,
state_class=STATE_CLASS_MEASUREMENT, state_class=STATE_CLASS_TOTAL_INCREASING,
), ),
key=CONF_NAME, key=CONF_NAME,
), ),
@@ -170,7 +171,8 @@ INA2XX_SCHEMA = cv.Schema(
sensor.sensor_schema( sensor.sensor_schema(
unit_of_measurement=UNIT_JOULE, unit_of_measurement=UNIT_JOULE,
accuracy_decimals=8, accuracy_decimals=8,
state_class=STATE_CLASS_MEASUREMENT, device_class=DEVICE_CLASS_ENERGY,
state_class=STATE_CLASS_TOTAL_INCREASING,
), ),
key=CONF_NAME, key=CONF_NAME,
), ),

View File

@@ -39,6 +39,8 @@ from esphome.helpers import ensure_unique_string, get_str_env, is_ha_addon
from esphome.util import OrderedDict from esphome.util import OrderedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from esphome.address_cache import AddressCache
from ..cpp_generator import MockObj, MockObjClass, Statement from ..cpp_generator import MockObj, MockObjClass, Statement
from ..types import ConfigType, EntityMetadata from ..types import ConfigType, EntityMetadata
@@ -583,6 +585,8 @@ class EsphomeCore:
self.id_classes = {} self.id_classes = {}
# The current component being processed during validation # The current component being processed during validation
self.current_component: str | None = None self.current_component: str | None = None
# Address cache for DNS and mDNS lookups from command line arguments
self.address_cache: AddressCache | None = None
def reset(self): def reset(self):
from esphome.pins import PIN_SCHEMA_REGISTRY from esphome.pins import PIN_SCHEMA_REGISTRY
@@ -610,6 +614,7 @@ class EsphomeCore:
self.platform_counts = defaultdict(int) self.platform_counts = defaultdict(int)
self.unique_ids = {} self.unique_ids = {}
self.current_component = None self.current_component = None
self.address_cache = None
PIN_SCHEMA_REGISTRY.reset() PIN_SCHEMA_REGISTRY.reset()
@contextmanager @contextmanager

View File

@@ -28,6 +28,21 @@ class DNSCache:
self._cache: dict[str, tuple[float, list[str] | Exception]] = {} self._cache: dict[str, tuple[float, list[str] | Exception]] = {}
self._ttl = ttl self._ttl = ttl
def get_cached_addresses(
self, hostname: str, now_monotonic: float
) -> list[str] | None:
"""Get cached addresses without triggering resolution.
Returns None if not in cache, list of addresses if found.
"""
# Normalize hostname for consistent lookups
normalized = hostname.rstrip(".").lower()
if expire_time_addresses := self._cache.get(normalized):
expire_time, addresses = expire_time_addresses
if expire_time > now_monotonic and not isinstance(addresses, Exception):
return addresses
return None
async def async_resolve( async def async_resolve(
self, hostname: str, now_monotonic: float self, hostname: str, now_monotonic: float
) -> list[str] | Exception: ) -> list[str] | Exception:

View File

@@ -4,6 +4,9 @@ import asyncio
import logging import logging
import typing import typing
from zeroconf import AddressResolver, IPVersion
from esphome.address_cache import normalize_hostname
from esphome.zeroconf import ( from esphome.zeroconf import (
ESPHOME_SERVICE_TYPE, ESPHOME_SERVICE_TYPE,
AsyncEsphomeZeroconf, AsyncEsphomeZeroconf,
@@ -50,6 +53,30 @@ class MDNSStatus:
return await aiozc.async_resolve_host(host_name) return await aiozc.async_resolve_host(host_name)
return None return None
def get_cached_addresses(self, host_name: str) -> list[str] | None:
"""Get cached addresses for a host without triggering resolution.
Returns None if not in cache or no zeroconf available.
"""
if not self.aiozc:
_LOGGER.debug("No zeroconf instance available for %s", host_name)
return None
# Normalize hostname and get the base name
normalized = normalize_hostname(host_name)
base_name = normalized.partition(".")[0]
# Try to load from zeroconf cache without triggering resolution
resolver_name = f"{base_name}.local."
info = AddressResolver(resolver_name)
# Let zeroconf use its own current time for cache checking
if info.load_from_cache(self.aiozc.zeroconf):
addresses = info.parsed_scoped_addresses(IPVersion.All)
_LOGGER.debug("Found %s in zeroconf cache: %s", resolver_name, addresses)
return addresses
_LOGGER.debug("Not found in zeroconf cache: %s", resolver_name)
return None
async def async_refresh_hosts(self) -> None: async def async_refresh_hosts(self) -> None:
"""Refresh the hosts to track.""" """Refresh the hosts to track."""
dashboard = self.dashboard dashboard = self.dashboard

View File

@@ -50,8 +50,8 @@ from esphome.util import get_serial_ports, shlex_quote
from esphome.yaml_util import FastestAvailableSafeLoader from esphome.yaml_util import FastestAvailableSafeLoader
from .const import DASHBOARD_COMMAND from .const import DASHBOARD_COMMAND
from .core import DASHBOARD from .core import DASHBOARD, ESPHomeDashboard
from .entries import UNKNOWN_STATE, entry_state_to_bool from .entries import UNKNOWN_STATE, DashboardEntry, entry_state_to_bool
from .util.file import write_file from .util.file import write_file
from .util.subprocess import async_run_system_command from .util.subprocess import async_run_system_command
from .util.text import friendly_name_slugify from .util.text import friendly_name_slugify
@@ -314,6 +314,73 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
raise NotImplementedError raise NotImplementedError
def build_cache_arguments(
entry: DashboardEntry | None,
dashboard: ESPHomeDashboard,
now: float,
) -> list[str]:
"""Build cache arguments for passing to CLI.
Args:
entry: Dashboard entry for the configuration
dashboard: Dashboard instance with cache access
now: Current monotonic time for DNS cache expiry checks
Returns:
List of cache arguments to pass to CLI
"""
cache_args: list[str] = []
if not entry:
return cache_args
_LOGGER.debug(
"Building cache for entry (address=%s, name=%s)",
entry.address,
entry.name,
)
def add_cache_entry(hostname: str, addresses: list[str], cache_type: str) -> None:
"""Add a cache entry to the command arguments."""
if not addresses:
return
normalized = hostname.rstrip(".").lower()
cache_args.extend(
[
f"--{cache_type}-address-cache",
f"{normalized}={','.join(sort_ip_addresses(addresses))}",
]
)
# Check entry.address for cached addresses
if use_address := entry.address:
if use_address.endswith(".local"):
# mDNS cache for .local addresses
if (mdns := dashboard.mdns_status) and (
cached := mdns.get_cached_addresses(use_address)
):
_LOGGER.debug("mDNS cache hit for %s: %s", use_address, cached)
add_cache_entry(use_address, cached, "mdns")
# DNS cache for non-.local addresses
elif cached := dashboard.dns_cache.get_cached_addresses(use_address, now):
_LOGGER.debug("DNS cache hit for %s: %s", use_address, cached)
add_cache_entry(use_address, cached, "dns")
# Check entry.name if we haven't already cached via address
# For mDNS devices, entry.name typically doesn't have .local suffix
if entry.name and not use_address:
mdns_name = (
f"{entry.name}.local" if not entry.name.endswith(".local") else entry.name
)
if (mdns := dashboard.mdns_status) and (
cached := mdns.get_cached_addresses(mdns_name)
):
_LOGGER.debug("mDNS cache hit for %s: %s", mdns_name, cached)
add_cache_entry(mdns_name, cached, "mdns")
return cache_args
class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
"""Base class for commands that require a port.""" """Base class for commands that require a port."""
@@ -326,52 +393,22 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
configuration = json_message["configuration"] configuration = json_message["configuration"]
config_file = settings.rel_path(configuration) config_file = settings.rel_path(configuration)
port = json_message["port"] port = json_message["port"]
addresses: list[str] = []
# Build cache arguments to pass to CLI
cache_args: list[str] = []
if ( if (
port == "OTA" # pylint: disable=too-many-boolean-expressions port == "OTA" # pylint: disable=too-many-boolean-expressions
and (entry := entries.get(config_file)) and (entry := entries.get(config_file))
and entry.loaded_integrations and entry.loaded_integrations
and "api" in entry.loaded_integrations and "api" in entry.loaded_integrations
): ):
# First priority: entry.address AKA use_address cache_args = build_cache_arguments(entry, dashboard, time.monotonic())
if (
(use_address := entry.address)
and (
address_list := await dashboard.dns_cache.async_resolve(
use_address, time.monotonic()
)
)
and not isinstance(address_list, Exception)
):
addresses.extend(sort_ip_addresses(address_list))
# Second priority: mDNS # Cache arguments must come before the subcommand
if ( cmd = [*DASHBOARD_COMMAND, *cache_args, *args, config_file, "--device", port]
(mdns := dashboard.mdns_status) _LOGGER.debug("Built command: %s", cmd)
and (address_list := await mdns.async_resolve_host(entry.name)) return cmd
and (
new_addresses := [
addr for addr in address_list if addr not in addresses
]
)
):
# Use the IP address if available but only
# if the API is loaded and the device is online
# since MQTT logging will not work otherwise
addresses.extend(sort_ip_addresses(new_addresses))
if not addresses:
# If no address was found, use the port directly
# as otherwise they will get the chooser which
# does not work with the dashboard as there is no
# interactive way to get keyboard input
addresses = [port]
device_args: list[str] = [
arg for address in addresses for arg in ("--device", address)
]
return [*DASHBOARD_COMMAND, *args, config_file, *device_args]
class EsphomeLogsHandler(EsphomePortCommandWebSocket): class EsphomeLogsHandler(EsphomePortCommandWebSocket):

View File

@@ -311,10 +311,14 @@ def perform_ota(
def run_ota_impl_( def run_ota_impl_(
remote_host: str | list[str], remote_port: int, password: str, filename: str remote_host: str | list[str], remote_port: int, password: str, filename: str
) -> tuple[int, str | None]: ) -> tuple[int, str | None]:
from esphome.core import CORE
# Handle both single host and list of hosts # Handle both single host and list of hosts
try: try:
# Resolve all hosts at once for parallel DNS resolution # Resolve all hosts at once for parallel DNS resolution
res = resolve_ip_address(remote_host, remote_port) res = resolve_ip_address(
remote_host, remote_port, address_cache=CORE.address_cache
)
except EsphomeError as err: except EsphomeError as err:
_LOGGER.error( _LOGGER.error(
"Error resolving IP address of %s. Is it connected to WiFi?", "Error resolving IP address of %s. Is it connected to WiFi?",

View File

@@ -9,10 +9,14 @@ from pathlib import Path
import platform import platform
import re import re
import tempfile import tempfile
from typing import TYPE_CHECKING
from urllib.parse import urlparse from urllib.parse import urlparse
from esphome.const import __version__ as ESPHOME_VERSION from esphome.const import __version__ as ESPHOME_VERSION
if TYPE_CHECKING:
from esphome.address_cache import AddressCache
# Type aliases for socket address information # Type aliases for socket address information
AddrInfo = tuple[ AddrInfo = tuple[
int, # family (AF_INET, AF_INET6, etc.) int, # family (AF_INET, AF_INET6, etc.)
@@ -173,7 +177,24 @@ def addr_preference_(res: AddrInfo) -> int:
return 1 return 1
def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: def _add_ip_addresses_to_addrinfo(
addresses: list[str], port: int, res: list[AddrInfo]
) -> None:
"""Helper to add IP addresses to addrinfo results with error handling."""
import socket
for addr in addresses:
try:
res += socket.getaddrinfo(
addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
)
except OSError:
_LOGGER.debug("Failed to parse IP address '%s'", addr)
def resolve_ip_address(
host: str | list[str], port: int, address_cache: AddressCache | None = None
) -> list[AddrInfo]:
import socket import socket
# There are five cases here. The host argument could be one of: # There are five cases here. The host argument could be one of:
@@ -194,47 +215,69 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
hosts = [host] hosts = [host]
res: list[AddrInfo] = [] res: list[AddrInfo] = []
# Fast path: if all hosts are already IP addresses
if all(is_ip_address(h) for h in hosts): if all(is_ip_address(h) for h in hosts):
# Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST _add_ip_addresses_to_addrinfo(hosts, port, res)
for addr in hosts:
try:
res += socket.getaddrinfo(
addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
)
except OSError:
_LOGGER.debug("Failed to parse IP address '%s'", addr)
# Sort by preference # Sort by preference
res.sort(key=addr_preference_) res.sort(key=addr_preference_)
return res return res
from esphome.resolver import AsyncResolver # Process hosts
cached_addresses: list[str] = []
uncached_hosts: list[str] = []
has_cache = address_cache is not None
resolver = AsyncResolver(hosts, port) for h in hosts:
addr_infos = resolver.resolve() if is_ip_address(h):
# Convert aioesphomeapi AddrInfo to our format if has_cache:
for addr_info in addr_infos: # If we have a cache, treat IPs as cached
sockaddr = addr_info.sockaddr cached_addresses.append(h)
if addr_info.family == socket.AF_INET6: else:
# IPv6 # If no cache, pass IPs through to resolver with hostnames
sockaddr_tuple = ( uncached_hosts.append(h)
sockaddr.address, elif address_cache and (cached := address_cache.get_addresses(h)):
sockaddr.port, # Found in cache
sockaddr.flowinfo, cached_addresses.extend(cached)
sockaddr.scope_id,
)
else: else:
# IPv4 # Not cached, need to resolve
sockaddr_tuple = (sockaddr.address, sockaddr.port) if address_cache and address_cache.has_cache():
_LOGGER.info("Host %s not in cache, will need to resolve", h)
uncached_hosts.append(h)
res.append( # Process cached addresses (includes direct IPs and cached lookups)
( _add_ip_addresses_to_addrinfo(cached_addresses, port, res)
addr_info.family,
addr_info.type, # If we have uncached hosts (only non-IP hostnames), resolve them
addr_info.proto, if uncached_hosts:
"", # canonname from esphome.resolver import AsyncResolver
sockaddr_tuple,
resolver = AsyncResolver(uncached_hosts, port)
addr_infos = resolver.resolve()
# Convert aioesphomeapi AddrInfo to our format
for addr_info in addr_infos:
sockaddr = addr_info.sockaddr
if addr_info.family == socket.AF_INET6:
# IPv6
sockaddr_tuple = (
sockaddr.address,
sockaddr.port,
sockaddr.flowinfo,
sockaddr.scope_id,
)
else:
# IPv4
sockaddr_tuple = (sockaddr.address, sockaddr.port)
res.append(
(
addr_info.family,
addr_info.type,
addr_info.proto,
"", # canonname
sockaddr_tuple,
)
) )
)
# Sort by preference # Sort by preference
res.sort(key=addr_preference_) res.sort(key=addr_preference_)
@@ -256,14 +299,7 @@ def sort_ip_addresses(address_list: list[str]) -> list[str]:
# First "resolve" all the IP addresses to getaddrinfo() tuples of the form # First "resolve" all the IP addresses to getaddrinfo() tuples of the form
# (family, type, proto, canonname, sockaddr) # (family, type, proto, canonname, sockaddr)
res: list[AddrInfo] = [] res: list[AddrInfo] = []
for addr in address_list: _add_ip_addresses_to_addrinfo(address_list, 0, res)
# This should always work as these are supposed to be IP addresses
try:
res += socket.getaddrinfo(
addr, 0, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
)
except OSError:
_LOGGER.info("Failed to parse IP address '%s'", addr)
# Now use that information to sort them. # Now use that information to sort them.
res.sort(key=addr_preference_) res.sort(key=addr_preference_)

View File

@@ -12,7 +12,7 @@ platformio==6.1.18 # When updating platformio, also update /docker/Dockerfile
esptool==5.0.2 esptool==5.0.2
click==8.1.7 click==8.1.7
esphome-dashboard==20250904.0 esphome-dashboard==20250904.0
aioesphomeapi==40.2.0 aioesphomeapi==41.0.0
zeroconf==0.147.2 zeroconf==0.147.2
puremagic==1.30 puremagic==1.30
ruamel.yaml==0.18.15 # dashboard_import ruamel.yaml==0.18.15 # dashboard_import

View File

@@ -1,8 +1,4 @@
substitutions: substitutions:
network_enable_ipv6: "true" network_enable_ipv6: "true"
bk72xx:
framework:
version: 1.7.0
<<: !include common.yaml <<: !include common.yaml

View File

@@ -0,0 +1 @@
<<: !include common.yaml

View File

@@ -0,0 +1,21 @@
"""Common fixtures for dashboard tests."""
from __future__ import annotations
from unittest.mock import Mock
import pytest
from esphome.dashboard.core import ESPHomeDashboard
@pytest.fixture
def mock_dashboard() -> Mock:
"""Create a mock dashboard."""
dashboard = Mock(spec=ESPHomeDashboard)
dashboard.entries = Mock()
dashboard.entries.async_all.return_value = []
dashboard.stop_event = Mock()
dashboard.stop_event.is_set.return_value = True
dashboard.ping_request = Mock()
return dashboard

View File

View File

@@ -0,0 +1,121 @@
"""Unit tests for esphome.dashboard.dns module."""
from __future__ import annotations
import time
from unittest.mock import patch
import pytest
from esphome.dashboard.dns import DNSCache
@pytest.fixture
def dns_cache_fixture() -> DNSCache:
"""Create a DNSCache instance."""
return DNSCache()
def test_get_cached_addresses_not_in_cache(dns_cache_fixture: DNSCache) -> None:
"""Test get_cached_addresses when hostname is not in cache."""
now = time.monotonic()
result = dns_cache_fixture.get_cached_addresses("unknown.example.com", now)
assert result is None
def test_get_cached_addresses_expired(dns_cache_fixture: DNSCache) -> None:
"""Test get_cached_addresses when cache entry is expired."""
now = time.monotonic()
# Add entry that's already expired
dns_cache_fixture._cache["example.com"] = (now - 1, ["192.168.1.10"])
result = dns_cache_fixture.get_cached_addresses("example.com", now)
assert result is None
# Expired entry should still be in cache (not removed by get_cached_addresses)
assert "example.com" in dns_cache_fixture._cache
def test_get_cached_addresses_valid(dns_cache_fixture: DNSCache) -> None:
"""Test get_cached_addresses with valid cache entry."""
now = time.monotonic()
# Add entry that expires in 60 seconds
dns_cache_fixture._cache["example.com"] = (
now + 60,
["192.168.1.10", "192.168.1.11"],
)
result = dns_cache_fixture.get_cached_addresses("example.com", now)
assert result == ["192.168.1.10", "192.168.1.11"]
# Entry should still be in cache
assert "example.com" in dns_cache_fixture._cache
def test_get_cached_addresses_hostname_normalization(
dns_cache_fixture: DNSCache,
) -> None:
"""Test get_cached_addresses normalizes hostname."""
now = time.monotonic()
# Add entry with lowercase hostname
dns_cache_fixture._cache["example.com"] = (now + 60, ["192.168.1.10"])
# Test with various forms
assert dns_cache_fixture.get_cached_addresses("EXAMPLE.COM", now) == [
"192.168.1.10"
]
assert dns_cache_fixture.get_cached_addresses("example.com.", now) == [
"192.168.1.10"
]
assert dns_cache_fixture.get_cached_addresses("EXAMPLE.COM.", now) == [
"192.168.1.10"
]
def test_get_cached_addresses_ipv6(dns_cache_fixture: DNSCache) -> None:
"""Test get_cached_addresses with IPv6 addresses."""
now = time.monotonic()
dns_cache_fixture._cache["example.com"] = (now + 60, ["2001:db8::1", "fe80::1"])
result = dns_cache_fixture.get_cached_addresses("example.com", now)
assert result == ["2001:db8::1", "fe80::1"]
def test_get_cached_addresses_empty_list(dns_cache_fixture: DNSCache) -> None:
"""Test get_cached_addresses with empty address list."""
now = time.monotonic()
dns_cache_fixture._cache["example.com"] = (now + 60, [])
result = dns_cache_fixture.get_cached_addresses("example.com", now)
assert result == []
def test_get_cached_addresses_exception_in_cache(dns_cache_fixture: DNSCache) -> None:
"""Test get_cached_addresses when cache contains an exception."""
now = time.monotonic()
# Store an exception (from failed resolution)
dns_cache_fixture._cache["example.com"] = (now + 60, OSError("Resolution failed"))
result = dns_cache_fixture.get_cached_addresses("example.com", now)
assert result is None # Should return None for exceptions
def test_async_resolve_not_called(dns_cache_fixture: DNSCache) -> None:
"""Test that get_cached_addresses never calls async_resolve."""
now = time.monotonic()
with patch.object(dns_cache_fixture, "async_resolve") as mock_resolve:
# Test non-cached
result = dns_cache_fixture.get_cached_addresses("uncached.com", now)
assert result is None
mock_resolve.assert_not_called()
# Test expired
dns_cache_fixture._cache["expired.com"] = (now - 1, ["192.168.1.10"])
result = dns_cache_fixture.get_cached_addresses("expired.com", now)
assert result is None
mock_resolve.assert_not_called()
# Test valid
dns_cache_fixture._cache["valid.com"] = (now + 60, ["192.168.1.10"])
result = dns_cache_fixture.get_cached_addresses("valid.com", now)
assert result == ["192.168.1.10"]
mock_resolve.assert_not_called()

View File

@@ -0,0 +1,168 @@
"""Unit tests for esphome.dashboard.status.mdns module."""
from __future__ import annotations
from unittest.mock import Mock, patch
import pytest
import pytest_asyncio
from zeroconf import AddressResolver, IPVersion
from esphome.dashboard.status.mdns import MDNSStatus
@pytest_asyncio.fixture
async def mdns_status(mock_dashboard: Mock) -> MDNSStatus:
"""Create an MDNSStatus instance in async context."""
# We're in an async context so get_running_loop will work
return MDNSStatus(mock_dashboard)
@pytest.mark.asyncio
async def test_get_cached_addresses_no_zeroconf(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses when no zeroconf instance is available."""
mdns_status.aiozc = None
result = mdns_status.get_cached_addresses("device.local")
assert result is None
@pytest.mark.asyncio
async def test_get_cached_addresses_not_in_cache(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses when address is not in cache."""
mdns_status.aiozc = Mock()
mdns_status.aiozc.zeroconf = Mock()
with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver:
mock_info = Mock(spec=AddressResolver)
mock_info.load_from_cache.return_value = False
mock_resolver.return_value = mock_info
result = mdns_status.get_cached_addresses("device.local")
assert result is None
mock_info.load_from_cache.assert_called_once_with(mdns_status.aiozc.zeroconf)
@pytest.mark.asyncio
async def test_get_cached_addresses_found_in_cache(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses when address is found in cache."""
mdns_status.aiozc = Mock()
mdns_status.aiozc.zeroconf = Mock()
with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver:
mock_info = Mock(spec=AddressResolver)
mock_info.load_from_cache.return_value = True
mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10", "fe80::1"]
mock_resolver.return_value = mock_info
result = mdns_status.get_cached_addresses("device.local")
assert result == ["192.168.1.10", "fe80::1"]
mock_info.load_from_cache.assert_called_once_with(mdns_status.aiozc.zeroconf)
mock_info.parsed_scoped_addresses.assert_called_once_with(IPVersion.All)
@pytest.mark.asyncio
async def test_get_cached_addresses_with_trailing_dot(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses with hostname having trailing dot."""
mdns_status.aiozc = Mock()
mdns_status.aiozc.zeroconf = Mock()
with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver:
mock_info = Mock(spec=AddressResolver)
mock_info.load_from_cache.return_value = True
mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"]
mock_resolver.return_value = mock_info
result = mdns_status.get_cached_addresses("device.local.")
assert result == ["192.168.1.10"]
# Should normalize to device.local. for zeroconf
mock_resolver.assert_called_once_with("device.local.")
@pytest.mark.asyncio
async def test_get_cached_addresses_uppercase_hostname(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses with uppercase hostname."""
mdns_status.aiozc = Mock()
mdns_status.aiozc.zeroconf = Mock()
with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver:
mock_info = Mock(spec=AddressResolver)
mock_info.load_from_cache.return_value = True
mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"]
mock_resolver.return_value = mock_info
result = mdns_status.get_cached_addresses("DEVICE.LOCAL")
assert result == ["192.168.1.10"]
# Should normalize to device.local. for zeroconf
mock_resolver.assert_called_once_with("device.local.")
@pytest.mark.asyncio
async def test_get_cached_addresses_simple_hostname(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses with simple hostname (no domain)."""
mdns_status.aiozc = Mock()
mdns_status.aiozc.zeroconf = Mock()
with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver:
mock_info = Mock(spec=AddressResolver)
mock_info.load_from_cache.return_value = True
mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"]
mock_resolver.return_value = mock_info
result = mdns_status.get_cached_addresses("device")
assert result == ["192.168.1.10"]
# Should append .local. for zeroconf
mock_resolver.assert_called_once_with("device.local.")
@pytest.mark.asyncio
async def test_get_cached_addresses_ipv6_only(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses returning only IPv6 addresses."""
mdns_status.aiozc = Mock()
mdns_status.aiozc.zeroconf = Mock()
with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver:
mock_info = Mock(spec=AddressResolver)
mock_info.load_from_cache.return_value = True
mock_info.parsed_scoped_addresses.return_value = ["fe80::1", "2001:db8::1"]
mock_resolver.return_value = mock_info
result = mdns_status.get_cached_addresses("device.local")
assert result == ["fe80::1", "2001:db8::1"]
@pytest.mark.asyncio
async def test_get_cached_addresses_empty_list(mdns_status: MDNSStatus) -> None:
"""Test get_cached_addresses returning empty list from cache."""
mdns_status.aiozc = Mock()
mdns_status.aiozc.zeroconf = Mock()
with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver:
mock_info = Mock(spec=AddressResolver)
mock_info.load_from_cache.return_value = True
mock_info.parsed_scoped_addresses.return_value = []
mock_resolver.return_value = mock_info
result = mdns_status.get_cached_addresses("device.local")
assert result == []
@pytest.mark.asyncio
async def test_async_setup_success(mock_dashboard: Mock) -> None:
"""Test successful async_setup."""
mdns_status = MDNSStatus(mock_dashboard)
with patch("esphome.dashboard.status.mdns.AsyncEsphomeZeroconf") as mock_zc:
mock_zc.return_value = Mock()
result = mdns_status.async_setup()
assert result is True
assert mdns_status.aiozc is not None
@pytest.mark.asyncio
async def test_async_setup_failure(mock_dashboard: Mock) -> None:
"""Test async_setup with OSError."""
mdns_status = MDNSStatus(mock_dashboard)
with patch("esphome.dashboard.status.mdns.AsyncEsphomeZeroconf") as mock_zc:
mock_zc.side_effect = OSError("Network error")
result = mdns_status.async_setup()
assert result is False
assert mdns_status.aiozc is None

View File

@@ -639,3 +639,83 @@ def test_start_web_server_with_unix_socket(tmp_path: Path) -> None:
mock_server_class.assert_called_once_with(app) mock_server_class.assert_called_once_with(app)
mock_bind.assert_called_once_with(str(socket_path), mode=0o666) mock_bind.assert_called_once_with(str(socket_path), mode=0o666)
server.add_socket.assert_called_once() server.add_socket.assert_called_once()
def test_build_cache_arguments_no_entry(mock_dashboard: Mock) -> None:
"""Test with no entry returns empty list."""
result = web_server.build_cache_arguments(None, mock_dashboard, 0.0)
assert result == []
def test_build_cache_arguments_no_address_no_name(mock_dashboard: Mock) -> None:
"""Test with entry but no address or name."""
entry = Mock(spec=web_server.DashboardEntry)
entry.address = None
entry.name = None
result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0)
assert result == []
def test_build_cache_arguments_mdns_address_cached(mock_dashboard: Mock) -> None:
"""Test with .local address that has cached mDNS results."""
entry = Mock(spec=web_server.DashboardEntry)
entry.address = "device.local"
entry.name = None
mock_dashboard.mdns_status = Mock()
mock_dashboard.mdns_status.get_cached_addresses.return_value = [
"192.168.1.10",
"fe80::1",
]
result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0)
assert result == [
"--mdns-address-cache",
"device.local=192.168.1.10,fe80::1",
]
mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with(
"device.local"
)
def test_build_cache_arguments_dns_address_cached(mock_dashboard: Mock) -> None:
"""Test with non-.local address that has cached DNS results."""
entry = Mock(spec=web_server.DashboardEntry)
entry.address = "example.com"
entry.name = None
mock_dashboard.dns_cache = Mock()
mock_dashboard.dns_cache.get_cached_addresses.return_value = [
"93.184.216.34",
"2606:2800:220:1:248:1893:25c8:1946",
]
now = 100.0
result = web_server.build_cache_arguments(entry, mock_dashboard, now)
# IPv6 addresses are sorted before IPv4
assert result == [
"--dns-address-cache",
"example.com=2606:2800:220:1:248:1893:25c8:1946,93.184.216.34",
]
mock_dashboard.dns_cache.get_cached_addresses.assert_called_once_with(
"example.com", now
)
def test_build_cache_arguments_name_without_address(mock_dashboard: Mock) -> None:
"""Test with name but no address - should check mDNS with .local suffix."""
entry = Mock(spec=web_server.DashboardEntry)
entry.name = "my-device"
entry.address = None
mock_dashboard.mdns_status = Mock()
mock_dashboard.mdns_status.get_cached_addresses.return_value = ["192.168.1.20"]
result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0)
assert result == [
"--mdns-address-cache",
"my-device.local=192.168.1.20",
]
mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with(
"my-device.local"
)

View File

@@ -0,0 +1,188 @@
"""Tests for esphome.build_gen.platformio module."""
from __future__ import annotations
from collections.abc import Generator
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from esphome.build_gen import platformio
from esphome.core import CORE
@pytest.fixture
def mock_update_storage_json() -> Generator[MagicMock]:
"""Mock update_storage_json for all tests."""
with patch("esphome.build_gen.platformio.update_storage_json") as mock:
yield mock
@pytest.fixture
def mock_write_file_if_changed() -> Generator[MagicMock]:
"""Mock write_file_if_changed for tests."""
with patch("esphome.build_gen.platformio.write_file_if_changed") as mock:
yield mock
def test_write_ini_creates_new_file(
tmp_path: Path, mock_update_storage_json: MagicMock
) -> None:
"""Test write_ini creates a new platformio.ini file."""
CORE.build_path = str(tmp_path)
content = """
[env:test]
platform = espressif32
board = esp32dev
framework = arduino
"""
platformio.write_ini(content)
ini_file = tmp_path / "platformio.ini"
assert ini_file.exists()
file_content = ini_file.read_text()
assert content in file_content
assert platformio.INI_AUTO_GENERATE_BEGIN in file_content
assert platformio.INI_AUTO_GENERATE_END in file_content
def test_write_ini_updates_existing_file(
tmp_path: Path, mock_update_storage_json: MagicMock
) -> None:
"""Test write_ini updates existing platformio.ini file."""
CORE.build_path = str(tmp_path)
# Create existing file with custom content
ini_file = tmp_path / "platformio.ini"
existing_content = f"""
; Custom header
[platformio]
default_envs = test
{platformio.INI_AUTO_GENERATE_BEGIN}
; Old auto-generated content
[env:old]
platform = old
{platformio.INI_AUTO_GENERATE_END}
; Custom footer
"""
ini_file.write_text(existing_content)
# New content to write
new_content = """
[env:test]
platform = espressif32
board = esp32dev
framework = arduino
"""
platformio.write_ini(new_content)
file_content = ini_file.read_text()
# Check that custom parts are preserved
assert "; Custom header" in file_content
assert "[platformio]" in file_content
assert "default_envs = test" in file_content
assert "; Custom footer" in file_content
# Check that new content replaced old auto-generated content
assert new_content in file_content
assert "[env:old]" not in file_content
assert "platform = old" not in file_content
def test_write_ini_preserves_custom_sections(
tmp_path: Path, mock_update_storage_json: MagicMock
) -> None:
"""Test write_ini preserves custom sections outside auto-generate markers."""
CORE.build_path = str(tmp_path)
# Create existing file with multiple custom sections
ini_file = tmp_path / "platformio.ini"
existing_content = f"""
[platformio]
src_dir = .
include_dir = .
[common]
lib_deps =
Wire
SPI
{platformio.INI_AUTO_GENERATE_BEGIN}
[env:old]
platform = old
{platformio.INI_AUTO_GENERATE_END}
[env:custom]
upload_speed = 921600
monitor_speed = 115200
"""
ini_file.write_text(existing_content)
new_content = "[env:auto]\nplatform = new"
platformio.write_ini(new_content)
file_content = ini_file.read_text()
# All custom sections should be preserved
assert "[platformio]" in file_content
assert "src_dir = ." in file_content
assert "[common]" in file_content
assert "lib_deps" in file_content
assert "[env:custom]" in file_content
assert "upload_speed = 921600" in file_content
# New auto-generated content should replace old
assert "[env:auto]" in file_content
assert "platform = new" in file_content
assert "[env:old]" not in file_content
def test_write_ini_no_change_when_content_same(
tmp_path: Path,
mock_update_storage_json: MagicMock,
mock_write_file_if_changed: MagicMock,
) -> None:
"""Test write_ini doesn't rewrite file when content is unchanged."""
CORE.build_path = str(tmp_path)
content = "[env:test]\nplatform = esp32"
full_content = (
f"{platformio.INI_BASE_FORMAT[0]}"
f"{platformio.INI_AUTO_GENERATE_BEGIN}\n"
f"{content}"
f"{platformio.INI_AUTO_GENERATE_END}"
f"{platformio.INI_BASE_FORMAT[1]}"
)
ini_file = tmp_path / "platformio.ini"
ini_file.write_text(full_content)
mock_write_file_if_changed.return_value = False # Indicate no change
platformio.write_ini(content)
# write_file_if_changed should be called with the same content
mock_write_file_if_changed.assert_called_once()
call_args = mock_write_file_if_changed.call_args[0]
assert call_args[0] == str(ini_file)
assert content in call_args[1]
def test_write_ini_calls_update_storage_json(
tmp_path: Path, mock_update_storage_json: MagicMock
) -> None:
"""Test write_ini calls update_storage_json."""
CORE.build_path = str(tmp_path)
content = "[env:test]\nplatform = esp32"
platformio.write_ini(content)
mock_update_storage_json.assert_called_once()

View File

@@ -9,8 +9,10 @@ not be part of a unit test suite.
""" """
from collections.abc import Generator
from pathlib import Path from pathlib import Path
import sys import sys
from unittest.mock import Mock, patch
import pytest import pytest
@@ -43,3 +45,45 @@ def setup_core(tmp_path: Path) -> Path:
"""Set up CORE with test paths.""" """Set up CORE with test paths."""
CORE.config_path = str(tmp_path / "test.yaml") CORE.config_path = str(tmp_path / "test.yaml")
return tmp_path return tmp_path
@pytest.fixture
def mock_write_file_if_changed() -> Generator[Mock, None, None]:
"""Mock write_file_if_changed for storage_json."""
with patch("esphome.storage_json.write_file_if_changed") as mock:
yield mock
@pytest.fixture
def mock_copy_file_if_changed() -> Generator[Mock, None, None]:
"""Mock copy_file_if_changed for core.config."""
with patch("esphome.core.config.copy_file_if_changed") as mock:
yield mock
@pytest.fixture
def mock_run_platformio_cli() -> Generator[Mock, None, None]:
"""Mock run_platformio_cli for platformio_api."""
with patch("esphome.platformio_api.run_platformio_cli") as mock:
yield mock
@pytest.fixture
def mock_run_platformio_cli_run() -> Generator[Mock, None, None]:
"""Mock run_platformio_cli_run for platformio_api."""
with patch("esphome.platformio_api.run_platformio_cli_run") as mock:
yield mock
@pytest.fixture
def mock_decode_pc() -> Generator[Mock, None, None]:
"""Mock _decode_pc for platformio_api."""
with patch("esphome.platformio_api._decode_pc") as mock:
yield mock
@pytest.fixture
def mock_run_external_command() -> Generator[Mock, None, None]:
"""Mock run_external_command for platformio_api."""
with patch("esphome.platformio_api.run_external_command") as mock:
yield mock

View File

@@ -1,21 +1,56 @@
"""Unit tests for core config functionality including areas and devices.""" """Unit tests for core config functionality including areas and devices."""
from collections.abc import Callable from collections.abc import Callable
import os
from pathlib import Path from pathlib import Path
import types
from typing import Any from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from esphome import config_validation as cv, core from esphome import config_validation as cv, core
from esphome.const import CONF_AREA, CONF_AREAS, CONF_DEVICES from esphome.const import (
from esphome.core import config CONF_AREA,
from esphome.core.config import Area, validate_area_config CONF_AREAS,
CONF_BUILD_PATH,
CONF_DEVICES,
CONF_ESPHOME,
CONF_NAME,
CONF_NAME_ADD_MAC_SUFFIX,
KEY_CORE,
)
from esphome.core import CORE, config
from esphome.core.config import (
Area,
preload_core_config,
valid_include,
valid_project_name,
validate_area_config,
validate_hostname,
)
from .common import load_config_from_fixture from .common import load_config_from_fixture
FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" / "core" / "config" FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" / "core" / "config"
@pytest.fixture
def mock_cg_with_include_capture() -> tuple[Mock, list[str]]:
"""Mock code generation with include capture."""
includes_added: list[str] = []
with patch("esphome.core.config.cg") as mock_cg:
mock_raw_statement = MagicMock()
def capture_include(text: str) -> MagicMock:
includes_added.append(text)
return mock_raw_statement
mock_cg.RawStatement.side_effect = capture_include
yield mock_cg, includes_added
def test_validate_area_config_with_string() -> None: def test_validate_area_config_with_string() -> None:
"""Test that string area config is converted to structured format.""" """Test that string area config is converted to structured format."""
result = validate_area_config("Living Room") result = validate_area_config("Living Room")
@@ -245,3 +280,566 @@ def test_add_platform_defines_priority() -> None:
f"_add_platform_defines priority ({config._add_platform_defines.priority}) must be lower than " f"_add_platform_defines priority ({config._add_platform_defines.priority}) must be lower than "
f"globals priority ({globals_to_code.priority}) to fix issue #10431 (sensor count bug with lambdas)" f"globals priority ({globals_to_code.priority}) to fix issue #10431 (sensor count bug with lambdas)"
) )
def test_valid_include_with_angle_brackets() -> None:
"""Test valid_include accepts angle bracket includes."""
assert valid_include("<ArduinoJson.h>") == "<ArduinoJson.h>"
def test_valid_include_with_valid_file(tmp_path: Path) -> None:
"""Test valid_include accepts valid include files."""
CORE.config_path = str(tmp_path / "test.yaml")
include_file = tmp_path / "include.h"
include_file.touch()
assert valid_include(str(include_file)) == str(include_file)
def test_valid_include_with_valid_directory(tmp_path: Path) -> None:
"""Test valid_include accepts valid directories."""
CORE.config_path = str(tmp_path / "test.yaml")
include_dir = tmp_path / "includes"
include_dir.mkdir()
assert valid_include(str(include_dir)) == str(include_dir)
def test_valid_include_invalid_extension(tmp_path: Path) -> None:
"""Test valid_include rejects files with invalid extensions."""
CORE.config_path = str(tmp_path / "test.yaml")
invalid_file = tmp_path / "file.txt"
invalid_file.touch()
with pytest.raises(cv.Invalid, match="Include has invalid file extension"):
valid_include(str(invalid_file))
def test_valid_project_name_valid() -> None:
"""Test valid_project_name accepts valid project names."""
assert valid_project_name("esphome.my_project") == "esphome.my_project"
def test_valid_project_name_no_namespace() -> None:
"""Test valid_project_name rejects names without namespace."""
with pytest.raises(cv.Invalid, match="project name needs to have a namespace"):
valid_project_name("my_project")
def test_valid_project_name_multiple_dots() -> None:
"""Test valid_project_name rejects names with multiple dots."""
with pytest.raises(cv.Invalid, match="project name needs to have a namespace"):
valid_project_name("esphome.my.project")
def test_validate_hostname_valid() -> None:
"""Test validate_hostname accepts valid hostnames."""
config = {CONF_NAME: "my-device", CONF_NAME_ADD_MAC_SUFFIX: False}
assert validate_hostname(config) == config
def test_validate_hostname_too_long() -> None:
"""Test validate_hostname rejects hostnames that are too long."""
config = {
CONF_NAME: "a" * 32, # 32 chars, max is 31
CONF_NAME_ADD_MAC_SUFFIX: False,
}
with pytest.raises(cv.Invalid, match="Hostnames can only be 31 characters long"):
validate_hostname(config)
def test_validate_hostname_too_long_with_mac_suffix() -> None:
"""Test validate_hostname accounts for MAC suffix length."""
config = {
CONF_NAME: "a" * 25, # 25 chars, max is 24 with MAC suffix
CONF_NAME_ADD_MAC_SUFFIX: True,
}
with pytest.raises(cv.Invalid, match="Hostnames can only be 24 characters long"):
validate_hostname(config)
def test_validate_hostname_with_underscore(caplog) -> None:
"""Test validate_hostname warns about underscores."""
config = {CONF_NAME: "my_device", CONF_NAME_ADD_MAC_SUFFIX: False}
assert validate_hostname(config) == config
assert (
"Using the '_' (underscore) character in the hostname is discouraged"
in caplog.text
)
def test_preload_core_config_basic(setup_core: Path) -> None:
"""Test preload_core_config sets basic CORE attributes."""
config = {
CONF_ESPHOME: {
CONF_NAME: "test_device",
},
"esp32": {},
}
result = {}
platform = preload_core_config(config, result)
assert CORE.name == "test_device"
assert platform == "esp32"
assert KEY_CORE in CORE.data
assert CONF_BUILD_PATH in config[CONF_ESPHOME]
def test_preload_core_config_with_build_path(setup_core: Path) -> None:
"""Test preload_core_config uses provided build path."""
config = {
CONF_ESPHOME: {
CONF_NAME: "test_device",
CONF_BUILD_PATH: "/custom/build/path",
},
"esp8266": {},
}
result = {}
platform = preload_core_config(config, result)
assert config[CONF_ESPHOME][CONF_BUILD_PATH] == "/custom/build/path"
assert platform == "esp8266"
def test_preload_core_config_env_build_path(setup_core: Path) -> None:
"""Test preload_core_config uses ESPHOME_BUILD_PATH env var."""
config = {
CONF_ESPHOME: {
CONF_NAME: "test_device",
},
"rp2040": {},
}
result = {}
with patch.dict(os.environ, {"ESPHOME_BUILD_PATH": "/env/build"}):
platform = preload_core_config(config, result)
assert CONF_BUILD_PATH in config[CONF_ESPHOME]
assert "test_device" in config[CONF_ESPHOME][CONF_BUILD_PATH]
assert platform == "rp2040"
def test_preload_core_config_no_platform(setup_core: Path) -> None:
"""Test preload_core_config raises when no platform is specified."""
config = {
CONF_ESPHOME: {
CONF_NAME: "test_device",
},
}
result = {}
# Mock _is_target_platform to avoid expensive component loading
with patch("esphome.core.config._is_target_platform") as mock_is_platform:
# Return True for known platforms
mock_is_platform.side_effect = lambda name: name in [
"esp32",
"esp8266",
"rp2040",
]
with pytest.raises(cv.Invalid, match="Platform missing"):
preload_core_config(config, result)
def test_preload_core_config_multiple_platforms(setup_core: Path) -> None:
"""Test preload_core_config raises when multiple platforms are specified."""
config = {
CONF_ESPHOME: {
CONF_NAME: "test_device",
},
"esp32": {},
"esp8266": {},
}
result = {}
# Mock _is_target_platform to avoid expensive component loading
with patch("esphome.core.config._is_target_platform") as mock_is_platform:
# Return True for known platforms
mock_is_platform.side_effect = lambda name: name in [
"esp32",
"esp8266",
"rp2040",
]
with pytest.raises(cv.Invalid, match="Found multiple target platform blocks"):
preload_core_config(config, result)
def test_include_file_header(tmp_path: Path, mock_copy_file_if_changed: Mock) -> None:
"""Test include_file adds include statement for header files."""
src_file = tmp_path / "source.h"
src_file.write_text("// Header content")
CORE.build_path = str(tmp_path / "build")
with patch("esphome.core.config.cg") as mock_cg:
# Mock RawStatement to capture the text
mock_raw_statement = MagicMock()
mock_raw_statement.text = ""
def raw_statement_side_effect(text):
mock_raw_statement.text = text
return mock_raw_statement
mock_cg.RawStatement.side_effect = raw_statement_side_effect
config.include_file(str(src_file), "test.h")
mock_copy_file_if_changed.assert_called_once()
mock_cg.add_global.assert_called_once()
# Check that include statement was added
assert '#include "test.h"' in mock_raw_statement.text
def test_include_file_cpp(tmp_path: Path, mock_copy_file_if_changed: Mock) -> None:
"""Test include_file does not add include for cpp files."""
src_file = tmp_path / "source.cpp"
src_file.write_text("// CPP content")
CORE.build_path = str(tmp_path / "build")
with patch("esphome.core.config.cg") as mock_cg:
config.include_file(str(src_file), "test.cpp")
mock_copy_file_if_changed.assert_called_once()
# Should not add include statement for .cpp files
mock_cg.add_global.assert_not_called()
def test_get_usable_cpu_count() -> None:
"""Test get_usable_cpu_count returns CPU count."""
count = config.get_usable_cpu_count()
assert isinstance(count, int)
assert count > 0
def test_get_usable_cpu_count_with_process_cpu_count() -> None:
"""Test get_usable_cpu_count uses process_cpu_count when available."""
# Test with process_cpu_count (Python 3.13+)
# Create a mock os module with process_cpu_count
mock_os = types.SimpleNamespace(process_cpu_count=lambda: 8, cpu_count=lambda: 4)
with patch("esphome.core.config.os", mock_os):
# When process_cpu_count exists, it should be used
count = config.get_usable_cpu_count()
assert count == 8
# Test fallback to cpu_count when process_cpu_count not available
mock_os_no_process = types.SimpleNamespace(cpu_count=lambda: 4)
with patch("esphome.core.config.os", mock_os_no_process):
count = config.get_usable_cpu_count()
assert count == 4
def test_list_target_platforms(tmp_path: Path) -> None:
"""Test _list_target_platforms returns available platforms."""
# Create mock components directory structure
components_dir = tmp_path / "components"
components_dir.mkdir()
# Create platform and non-platform directories with __init__.py
platforms = ["esp32", "esp8266", "rp2040", "libretiny", "host"]
non_platforms = ["sensor"]
for component in platforms + non_platforms:
component_dir = components_dir / component
component_dir.mkdir()
(component_dir / "__init__.py").touch()
# Create a file (not a directory)
(components_dir / "README.md").touch()
# Create a directory without __init__.py
(components_dir / "no_init").mkdir()
# Mock Path(__file__).parents[1] to return our tmp_path
with patch("esphome.core.config.Path") as mock_path:
mock_file_path = MagicMock()
mock_file_path.parents = [MagicMock(), tmp_path]
mock_path.return_value = mock_file_path
platforms = config._list_target_platforms()
assert isinstance(platforms, list)
# Should include platform components
assert "esp32" in platforms
assert "esp8266" in platforms
assert "rp2040" in platforms
assert "libretiny" in platforms
assert "host" in platforms
# Should not include non-platform components
assert "sensor" not in platforms
assert "README.md" not in platforms
assert "no_init" not in platforms
def test_is_target_platform() -> None:
"""Test _is_target_platform identifies valid platforms."""
assert config._is_target_platform("esp32") is True
assert config._is_target_platform("esp8266") is True
assert config._is_target_platform("rp2040") is True
assert config._is_target_platform("invalid_platform") is False
assert config._is_target_platform("api") is False # Component but not platform
@pytest.mark.asyncio
async def test_add_includes_with_single_file(
tmp_path: Path,
mock_copy_file_if_changed: Mock,
mock_cg_with_include_capture: tuple[Mock, list[str]],
) -> None:
"""Test add_includes copies a single header file to build directory."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create include file
include_file = tmp_path / "my_header.h"
include_file.write_text("#define MY_CONSTANT 42")
mock_cg, includes_added = mock_cg_with_include_capture
await config.add_includes([str(include_file)])
# Verify copy_file_if_changed was called to copy the file
# Note: add_includes adds files to a src/ subdirectory
mock_copy_file_if_changed.assert_called_once_with(
str(include_file), str(Path(CORE.build_path) / "src" / "my_header.h")
)
# Verify include statement was added
assert any('#include "my_header.h"' in inc for inc in includes_added)
@pytest.mark.asyncio
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test")
async def test_add_includes_with_directory_unix(
tmp_path: Path,
mock_copy_file_if_changed: Mock,
mock_cg_with_include_capture: tuple[Mock, list[str]],
) -> None:
"""Test add_includes copies all files from a directory on Unix."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create include directory with files
include_dir = tmp_path / "includes"
include_dir.mkdir()
(include_dir / "header1.h").write_text("#define HEADER1")
(include_dir / "header2.hpp").write_text("#define HEADER2")
(include_dir / "source.cpp").write_text("// Implementation")
(include_dir / "README.md").write_text(
"# Documentation"
) # Should be copied but not included
# Create subdirectory with files
subdir = include_dir / "subdir"
subdir.mkdir()
(subdir / "nested.h").write_text("#define NESTED")
mock_cg, includes_added = mock_cg_with_include_capture
await config.add_includes([str(include_dir)])
# Verify copy_file_if_changed was called for all files
assert mock_copy_file_if_changed.call_count == 5 # 4 code files + 1 README
# Verify include statements were added for valid extensions
include_strings = " ".join(includes_added)
assert "includes/header1.h" in include_strings
assert "includes/header2.hpp" in include_strings
assert "includes/subdir/nested.h" in include_strings
# CPP files are copied but not included
assert "source.cpp" not in include_strings or "#include" not in include_strings
# README.md should not have an include statement
assert "README.md" not in include_strings
@pytest.mark.asyncio
@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test")
async def test_add_includes_with_directory_windows(
tmp_path: Path,
mock_copy_file_if_changed: Mock,
mock_cg_with_include_capture: tuple[Mock, list[str]],
) -> None:
"""Test add_includes copies all files from a directory on Windows."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create include directory with files
include_dir = tmp_path / "includes"
include_dir.mkdir()
(include_dir / "header1.h").write_text("#define HEADER1")
(include_dir / "header2.hpp").write_text("#define HEADER2")
(include_dir / "source.cpp").write_text("// Implementation")
(include_dir / "README.md").write_text(
"# Documentation"
) # Should be copied but not included
# Create subdirectory with files
subdir = include_dir / "subdir"
subdir.mkdir()
(subdir / "nested.h").write_text("#define NESTED")
mock_cg, includes_added = mock_cg_with_include_capture
await config.add_includes([str(include_dir)])
# Verify copy_file_if_changed was called for all files
assert mock_copy_file_if_changed.call_count == 5 # 4 code files + 1 README
# Verify include statements were added for valid extensions
include_strings = " ".join(includes_added)
assert "includes\\header1.h" in include_strings
assert "includes\\header2.hpp" in include_strings
assert "includes\\subdir\\nested.h" in include_strings
# CPP files are copied but not included
assert "source.cpp" not in include_strings or "#include" not in include_strings
# README.md should not have an include statement
assert "README.md" not in include_strings
@pytest.mark.asyncio
async def test_add_includes_with_multiple_sources(
tmp_path: Path, mock_copy_file_if_changed: Mock
) -> None:
"""Test add_includes with multiple files and directories."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create various include sources
single_file = tmp_path / "single.h"
single_file.write_text("#define SINGLE")
dir1 = tmp_path / "dir1"
dir1.mkdir()
(dir1 / "file1.h").write_text("#define FILE1")
dir2 = tmp_path / "dir2"
dir2.mkdir()
(dir2 / "file2.cpp").write_text("// File2")
with patch("esphome.core.config.cg"):
await config.add_includes([str(single_file), str(dir1), str(dir2)])
# Verify copy_file_if_changed was called for all files
assert mock_copy_file_if_changed.call_count == 3 # 3 files total
@pytest.mark.asyncio
async def test_add_includes_empty_directory(
tmp_path: Path, mock_copy_file_if_changed: Mock
) -> None:
"""Test add_includes with an empty directory doesn't fail."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create empty directory
empty_dir = tmp_path / "empty"
empty_dir.mkdir()
with patch("esphome.core.config.cg"):
# Should not raise any errors
await config.add_includes([str(empty_dir)])
# No files to copy from empty directory
mock_copy_file_if_changed.assert_not_called()
@pytest.mark.asyncio
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test")
async def test_add_includes_preserves_directory_structure_unix(
tmp_path: Path, mock_copy_file_if_changed: Mock
) -> None:
"""Test that add_includes preserves relative directory structure on Unix."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create nested directory structure
lib_dir = tmp_path / "lib"
lib_dir.mkdir()
src_dir = lib_dir / "src"
src_dir.mkdir()
(src_dir / "core.h").write_text("#define CORE")
utils_dir = lib_dir / "utils"
utils_dir.mkdir()
(utils_dir / "helper.h").write_text("#define HELPER")
with patch("esphome.core.config.cg"):
await config.add_includes([str(lib_dir)])
# Verify copy_file_if_changed was called with correct paths
calls = mock_copy_file_if_changed.call_args_list
dest_paths = [call[0][1] for call in calls]
# Check that relative paths are preserved
assert any("lib/src/core.h" in path for path in dest_paths)
assert any("lib/utils/helper.h" in path for path in dest_paths)
@pytest.mark.asyncio
@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test")
async def test_add_includes_preserves_directory_structure_windows(
tmp_path: Path, mock_copy_file_if_changed: Mock
) -> None:
"""Test that add_includes preserves relative directory structure on Windows."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create nested directory structure
lib_dir = tmp_path / "lib"
lib_dir.mkdir()
src_dir = lib_dir / "src"
src_dir.mkdir()
(src_dir / "core.h").write_text("#define CORE")
utils_dir = lib_dir / "utils"
utils_dir.mkdir()
(utils_dir / "helper.h").write_text("#define HELPER")
with patch("esphome.core.config.cg"):
await config.add_includes([str(lib_dir)])
# Verify copy_file_if_changed was called with correct paths
calls = mock_copy_file_if_changed.call_args_list
dest_paths = [call[0][1] for call in calls]
# Check that relative paths are preserved
assert any("lib\\src\\core.h" in path for path in dest_paths)
assert any("lib\\utils\\helper.h" in path for path in dest_paths)
@pytest.mark.asyncio
async def test_add_includes_overwrites_existing_files(
tmp_path: Path, mock_copy_file_if_changed: Mock
) -> None:
"""Test that add_includes overwrites existing files in build directory."""
CORE.config_path = str(tmp_path / "config.yaml")
CORE.build_path = str(tmp_path / "build")
os.makedirs(CORE.build_path, exist_ok=True)
# Create include file
include_file = tmp_path / "header.h"
include_file.write_text("#define NEW_VALUE 42")
with patch("esphome.core.config.cg"):
await config.add_includes([str(include_file)])
# Verify copy_file_if_changed was called (it handles overwriting)
# Note: add_includes adds files to a src/ subdirectory
mock_copy_file_if_changed.assert_called_once_with(
str(include_file), str(Path(CORE.build_path) / "src" / "header.h")
)

View File

@@ -0,0 +1,305 @@
"""Tests for the address_cache module."""
from __future__ import annotations
import logging
import pytest
from pytest import LogCaptureFixture
from esphome.address_cache import AddressCache, normalize_hostname
def test_normalize_simple_hostname() -> None:
"""Test normalizing a simple hostname."""
assert normalize_hostname("device") == "device"
assert normalize_hostname("device.local") == "device.local"
assert normalize_hostname("server.example.com") == "server.example.com"
def test_normalize_removes_trailing_dots() -> None:
"""Test that trailing dots are removed."""
assert normalize_hostname("device.") == "device"
assert normalize_hostname("device.local.") == "device.local"
assert normalize_hostname("server.example.com.") == "server.example.com"
assert normalize_hostname("device...") == "device"
def test_normalize_converts_to_lowercase() -> None:
"""Test that hostnames are converted to lowercase."""
assert normalize_hostname("DEVICE") == "device"
assert normalize_hostname("Device.Local") == "device.local"
assert normalize_hostname("Server.Example.COM") == "server.example.com"
def test_normalize_combined() -> None:
"""Test combination of trailing dots and case conversion."""
assert normalize_hostname("DEVICE.LOCAL.") == "device.local"
assert normalize_hostname("Server.Example.COM...") == "server.example.com"
def test_init_empty() -> None:
"""Test initialization with empty caches."""
cache = AddressCache()
assert cache.mdns_cache == {}
assert cache.dns_cache == {}
assert not cache.has_cache()
def test_init_with_caches() -> None:
"""Test initialization with provided caches."""
mdns_cache: dict[str, list[str]] = {"device.local": ["192.168.1.10"]}
dns_cache: dict[str, list[str]] = {"server.com": ["10.0.0.1"]}
cache = AddressCache(mdns_cache=mdns_cache, dns_cache=dns_cache)
assert cache.mdns_cache == mdns_cache
assert cache.dns_cache == dns_cache
assert cache.has_cache()
def test_get_mdns_addresses() -> None:
"""Test getting mDNS addresses."""
cache = AddressCache(mdns_cache={"device.local": ["192.168.1.10", "192.168.1.11"]})
# Direct lookup
assert cache.get_mdns_addresses("device.local") == [
"192.168.1.10",
"192.168.1.11",
]
# Case insensitive lookup
assert cache.get_mdns_addresses("Device.Local") == [
"192.168.1.10",
"192.168.1.11",
]
# With trailing dot
assert cache.get_mdns_addresses("device.local.") == [
"192.168.1.10",
"192.168.1.11",
]
# Not found
assert cache.get_mdns_addresses("unknown.local") is None
def test_get_dns_addresses() -> None:
"""Test getting DNS addresses."""
cache = AddressCache(dns_cache={"server.com": ["10.0.0.1", "10.0.0.2"]})
# Direct lookup
assert cache.get_dns_addresses("server.com") == ["10.0.0.1", "10.0.0.2"]
# Case insensitive lookup
assert cache.get_dns_addresses("Server.COM") == ["10.0.0.1", "10.0.0.2"]
# With trailing dot
assert cache.get_dns_addresses("server.com.") == ["10.0.0.1", "10.0.0.2"]
# Not found
assert cache.get_dns_addresses("unknown.com") is None
def test_get_addresses_auto_detection() -> None:
"""Test automatic cache selection based on hostname."""
cache = AddressCache(
mdns_cache={"device.local": ["192.168.1.10"]},
dns_cache={"server.com": ["10.0.0.1"]},
)
# Should use mDNS cache for .local domains
assert cache.get_addresses("device.local") == ["192.168.1.10"]
assert cache.get_addresses("device.local.") == ["192.168.1.10"]
assert cache.get_addresses("Device.Local") == ["192.168.1.10"]
# Should use DNS cache for non-.local domains
assert cache.get_addresses("server.com") == ["10.0.0.1"]
assert cache.get_addresses("server.com.") == ["10.0.0.1"]
assert cache.get_addresses("Server.COM") == ["10.0.0.1"]
# Not found
assert cache.get_addresses("unknown.local") is None
assert cache.get_addresses("unknown.com") is None
def test_has_cache() -> None:
"""Test checking if cache has entries."""
# Empty cache
cache = AddressCache()
assert not cache.has_cache()
# Only mDNS cache
cache = AddressCache(mdns_cache={"device.local": ["192.168.1.10"]})
assert cache.has_cache()
# Only DNS cache
cache = AddressCache(dns_cache={"server.com": ["10.0.0.1"]})
assert cache.has_cache()
# Both caches
cache = AddressCache(
mdns_cache={"device.local": ["192.168.1.10"]},
dns_cache={"server.com": ["10.0.0.1"]},
)
assert cache.has_cache()
def test_from_cli_args_empty() -> None:
"""Test creating cache from empty CLI arguments."""
cache = AddressCache.from_cli_args([], [])
assert cache.mdns_cache == {}
assert cache.dns_cache == {}
def test_from_cli_args_single_entry() -> None:
"""Test creating cache from single CLI argument."""
mdns_args: list[str] = ["device.local=192.168.1.10"]
dns_args: list[str] = ["server.com=10.0.0.1"]
cache = AddressCache.from_cli_args(mdns_args, dns_args)
assert cache.mdns_cache == {"device.local": ["192.168.1.10"]}
assert cache.dns_cache == {"server.com": ["10.0.0.1"]}
def test_from_cli_args_multiple_ips() -> None:
"""Test creating cache with multiple IPs per host."""
mdns_args: list[str] = ["device.local=192.168.1.10,192.168.1.11"]
dns_args: list[str] = ["server.com=10.0.0.1,10.0.0.2,10.0.0.3"]
cache = AddressCache.from_cli_args(mdns_args, dns_args)
assert cache.mdns_cache == {"device.local": ["192.168.1.10", "192.168.1.11"]}
assert cache.dns_cache == {"server.com": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}
def test_from_cli_args_multiple_entries() -> None:
"""Test creating cache with multiple host entries."""
mdns_args: list[str] = [
"device1.local=192.168.1.10",
"device2.local=192.168.1.20,192.168.1.21",
]
dns_args: list[str] = ["server1.com=10.0.0.1", "server2.com=10.0.0.2"]
cache = AddressCache.from_cli_args(mdns_args, dns_args)
assert cache.mdns_cache == {
"device1.local": ["192.168.1.10"],
"device2.local": ["192.168.1.20", "192.168.1.21"],
}
assert cache.dns_cache == {
"server1.com": ["10.0.0.1"],
"server2.com": ["10.0.0.2"],
}
def test_from_cli_args_normalization() -> None:
"""Test that CLI arguments are normalized."""
mdns_args: list[str] = ["Device1.Local.=192.168.1.10", "DEVICE2.LOCAL=192.168.1.20"]
dns_args: list[str] = ["Server1.COM.=10.0.0.1", "SERVER2.com=10.0.0.2"]
cache = AddressCache.from_cli_args(mdns_args, dns_args)
# Hostnames should be normalized (lowercase, no trailing dots)
assert cache.mdns_cache == {
"device1.local": ["192.168.1.10"],
"device2.local": ["192.168.1.20"],
}
assert cache.dns_cache == {
"server1.com": ["10.0.0.1"],
"server2.com": ["10.0.0.2"],
}
def test_from_cli_args_whitespace_handling() -> None:
"""Test that whitespace in IPs is handled."""
mdns_args: list[str] = ["device.local= 192.168.1.10 , 192.168.1.11 "]
dns_args: list[str] = ["server.com= 10.0.0.1 , 10.0.0.2 "]
cache = AddressCache.from_cli_args(mdns_args, dns_args)
assert cache.mdns_cache == {"device.local": ["192.168.1.10", "192.168.1.11"]}
assert cache.dns_cache == {"server.com": ["10.0.0.1", "10.0.0.2"]}
def test_from_cli_args_invalid_format(caplog: LogCaptureFixture) -> None:
"""Test handling of invalid argument format."""
mdns_args: list[str] = ["invalid_format", "device.local=192.168.1.10"]
dns_args: list[str] = ["server.com=10.0.0.1", "also_invalid"]
cache = AddressCache.from_cli_args(mdns_args, dns_args)
# Valid entries should still be processed
assert cache.mdns_cache == {"device.local": ["192.168.1.10"]}
assert cache.dns_cache == {"server.com": ["10.0.0.1"]}
# Check that warnings were logged for invalid entries
assert "Invalid cache format: invalid_format" in caplog.text
assert "Invalid cache format: also_invalid" in caplog.text
def test_from_cli_args_ipv6() -> None:
"""Test handling of IPv6 addresses."""
mdns_args: list[str] = ["device.local=fe80::1,2001:db8::1"]
dns_args: list[str] = ["server.com=2001:db8::2,::1"]
cache = AddressCache.from_cli_args(mdns_args, dns_args)
assert cache.mdns_cache == {"device.local": ["fe80::1", "2001:db8::1"]}
assert cache.dns_cache == {"server.com": ["2001:db8::2", "::1"]}
def test_logging_output(caplog: LogCaptureFixture) -> None:
"""Test that appropriate debug logging occurs."""
caplog.set_level(logging.DEBUG)
cache = AddressCache(
mdns_cache={"device.local": ["192.168.1.10"]},
dns_cache={"server.com": ["10.0.0.1"]},
)
# Test successful lookups log at debug level
result: list[str] | None = cache.get_mdns_addresses("device.local")
assert result == ["192.168.1.10"]
assert "Using mDNS cache for device.local" in caplog.text
caplog.clear()
result = cache.get_dns_addresses("server.com")
assert result == ["10.0.0.1"]
assert "Using DNS cache for server.com" in caplog.text
# Test that failed lookups don't log
caplog.clear()
result = cache.get_mdns_addresses("unknown.local")
assert result is None
assert "Using mDNS cache" not in caplog.text
@pytest.mark.parametrize(
"hostname,expected",
[
("test.local", "test.local"),
("Test.Local.", "test.local"),
("TEST.LOCAL...", "test.local"),
("example.com", "example.com"),
("EXAMPLE.COM.", "example.com"),
],
)
def test_normalize_hostname_parametrized(hostname: str, expected: str) -> None:
"""Test hostname normalization with various inputs."""
assert normalize_hostname(hostname) == expected
@pytest.mark.parametrize(
"mdns_arg,expected",
[
("host=1.2.3.4", {"host": ["1.2.3.4"]}),
("Host.Local=1.2.3.4,5.6.7.8", {"host.local": ["1.2.3.4", "5.6.7.8"]}),
("HOST.LOCAL.=::1", {"host.local": ["::1"]}),
],
)
def test_parse_cache_args_parametrized(
mdns_arg: str, expected: dict[str, list[str]]
) -> None:
"""Test parsing of cache arguments with various formats."""
cache = AddressCache.from_cli_args([mdns_arg], [])
assert cache.mdns_cache == expected

View File

@@ -1,3 +1,6 @@
import os
from unittest.mock import patch
from hypothesis import given from hypothesis import given
import pytest import pytest
from strategies import mac_addr_strings from strategies import mac_addr_strings
@@ -577,3 +580,83 @@ class TestEsphomeCore:
assert target.is_esp32 is False assert target.is_esp32 is False
assert target.is_esp8266 is True assert target.is_esp8266 is True
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test")
def test_data_dir_default_unix(self, target):
"""Test data_dir returns .esphome in config directory by default on Unix."""
target.config_path = "/home/user/config.yaml"
assert target.data_dir == "/home/user/.esphome"
@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test")
def test_data_dir_default_windows(self, target):
"""Test data_dir returns .esphome in config directory by default on Windows."""
target.config_path = "D:\\home\\user\\config.yaml"
assert target.data_dir == "D:\\home\\user\\.esphome"
def test_data_dir_ha_addon(self, target):
"""Test data_dir returns /data when running as Home Assistant addon."""
target.config_path = "/config/test.yaml"
with patch.dict(os.environ, {"ESPHOME_IS_HA_ADDON": "true"}):
assert target.data_dir == "/data"
def test_data_dir_env_override(self, target):
"""Test data_dir uses ESPHOME_DATA_DIR environment variable when set."""
target.config_path = "/home/user/config.yaml"
with patch.dict(os.environ, {"ESPHOME_DATA_DIR": "/custom/data/path"}):
assert target.data_dir == "/custom/data/path"
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test")
def test_data_dir_priority_unix(self, target):
"""Test data_dir priority on Unix: HA addon > env var > default."""
target.config_path = "/config/test.yaml"
expected_default = "/config/.esphome"
# Test HA addon takes priority over env var
with patch.dict(
os.environ,
{"ESPHOME_IS_HA_ADDON": "true", "ESPHOME_DATA_DIR": "/custom/path"},
):
assert target.data_dir == "/data"
# Test env var is used when not HA addon
with patch.dict(
os.environ,
{"ESPHOME_IS_HA_ADDON": "false", "ESPHOME_DATA_DIR": "/custom/path"},
):
assert target.data_dir == "/custom/path"
# Test default when neither is set
with patch.dict(os.environ, {}, clear=True):
# Ensure these env vars are not set
os.environ.pop("ESPHOME_IS_HA_ADDON", None)
os.environ.pop("ESPHOME_DATA_DIR", None)
assert target.data_dir == expected_default
@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test")
def test_data_dir_priority_windows(self, target):
"""Test data_dir priority on Windows: HA addon > env var > default."""
target.config_path = "D:\\config\\test.yaml"
expected_default = "D:\\config\\.esphome"
# Test HA addon takes priority over env var
with patch.dict(
os.environ,
{"ESPHOME_IS_HA_ADDON": "true", "ESPHOME_DATA_DIR": "/custom/path"},
):
assert target.data_dir == "/data"
# Test env var is used when not HA addon
with patch.dict(
os.environ,
{"ESPHOME_IS_HA_ADDON": "false", "ESPHOME_DATA_DIR": "/custom/path"},
):
assert target.data_dir == "/custom/path"
# Test default when neither is set
with patch.dict(os.environ, {}, clear=True):
# Ensure these env vars are not set
os.environ.pop("ESPHOME_IS_HA_ADDON", None)
os.environ.pop("ESPHOME_DATA_DIR", None)
assert target.data_dir == expected_default

View File

@@ -1,5 +1,8 @@
import logging import logging
import os
from pathlib import Path
import socket import socket
import stat
from unittest.mock import patch from unittest.mock import patch
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr
@@ -8,6 +11,7 @@ from hypothesis.strategies import ip_addresses
import pytest import pytest
from esphome import helpers from esphome import helpers
from esphome.address_cache import AddressCache
from esphome.core import EsphomeError from esphome.core import EsphomeError
@@ -554,6 +558,239 @@ def test_addr_preference_ipv6_link_local_with_scope() -> None:
assert helpers.addr_preference_(addr_info) == 1 # Has scope, so it's usable assert helpers.addr_preference_(addr_info) == 1 # Has scope, so it's usable
def test_mkdir_p(tmp_path: Path) -> None:
"""Test mkdir_p creates directories recursively."""
# Test creating nested directories
nested_path = tmp_path / "level1" / "level2" / "level3"
helpers.mkdir_p(nested_path)
assert nested_path.exists()
assert nested_path.is_dir()
# Test that mkdir_p is idempotent (doesn't fail if directory exists)
helpers.mkdir_p(nested_path)
assert nested_path.exists()
# Test with empty path (should do nothing)
helpers.mkdir_p("")
# Test with existing directory
existing_dir = tmp_path / "existing"
existing_dir.mkdir()
helpers.mkdir_p(existing_dir)
assert existing_dir.exists()
def test_mkdir_p_file_exists_error(tmp_path: Path) -> None:
"""Test mkdir_p raises error when path is a file."""
# Create a file
file_path = tmp_path / "test_file.txt"
file_path.write_text("test content")
# Try to create directory with same name as existing file
with pytest.raises(EsphomeError, match=r"Error creating directories"):
helpers.mkdir_p(file_path)
def test_mkdir_p_with_existing_file_raises_error(tmp_path: Path) -> None:
"""Test mkdir_p raises error when trying to create dir over existing file."""
# Create a file where we want to create a directory
file_path = tmp_path / "existing_file"
file_path.write_text("content")
# Try to create a directory with a path that goes through the file
dir_path = file_path / "subdir"
with pytest.raises(EsphomeError, match=r"Error creating directories"):
helpers.mkdir_p(dir_path)
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test")
def test_read_file_unix(tmp_path: Path) -> None:
"""Test read_file reads file content correctly on Unix."""
# Test reading regular file
test_file = tmp_path / "test.txt"
expected_content = "Test content\nLine 2\n"
test_file.write_text(expected_content)
content = helpers.read_file(test_file)
assert content == expected_content
# Test reading file with UTF-8 characters
utf8_file = tmp_path / "utf8.txt"
utf8_content = "Hello 世界 🌍"
utf8_file.write_text(utf8_content, encoding="utf-8")
content = helpers.read_file(utf8_file)
assert content == utf8_content
@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test")
def test_read_file_windows(tmp_path: Path) -> None:
"""Test read_file reads file content correctly on Windows."""
# Test reading regular file
test_file = tmp_path / "test.txt"
expected_content = "Test content\nLine 2\n"
test_file.write_text(expected_content)
content = helpers.read_file(test_file)
# On Windows, text mode reading converts \n to \r\n
assert content == expected_content.replace("\n", "\r\n")
# Test reading file with UTF-8 characters
utf8_file = tmp_path / "utf8.txt"
utf8_content = "Hello 世界 🌍"
utf8_file.write_text(utf8_content, encoding="utf-8")
content = helpers.read_file(utf8_file)
assert content == utf8_content
def test_read_file_not_found() -> None:
"""Test read_file raises error for non-existent file."""
with pytest.raises(EsphomeError, match=r"Error reading file"):
helpers.read_file("/nonexistent/file.txt")
def test_read_file_unicode_decode_error(tmp_path: Path) -> None:
"""Test read_file raises error for invalid UTF-8."""
test_file = tmp_path / "invalid.txt"
# Write invalid UTF-8 bytes
test_file.write_bytes(b"\xff\xfe")
with pytest.raises(EsphomeError, match=r"Error reading file"):
helpers.read_file(test_file)
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test")
def test_write_file_unix(tmp_path: Path) -> None:
"""Test write_file writes content correctly on Unix."""
# Test writing string content
test_file = tmp_path / "test.txt"
content = "Test content\nLine 2"
helpers.write_file(test_file, content)
assert test_file.read_text() == content
# Check file permissions
assert oct(test_file.stat().st_mode)[-3:] == "644"
# Test overwriting existing file
new_content = "New content"
helpers.write_file(test_file, new_content)
assert test_file.read_text() == new_content
# Test writing to nested directories (should create them)
nested_file = tmp_path / "dir1" / "dir2" / "file.txt"
helpers.write_file(nested_file, content)
assert nested_file.read_text() == content
@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test")
def test_write_file_windows(tmp_path: Path) -> None:
"""Test write_file writes content correctly on Windows."""
# Test writing string content
test_file = tmp_path / "test.txt"
content = "Test content\nLine 2"
helpers.write_file(test_file, content)
assert test_file.read_text() == content
# Windows doesn't have Unix-style 644 permissions
# Test overwriting existing file
new_content = "New content"
helpers.write_file(test_file, new_content)
assert test_file.read_text() == new_content
# Test writing to nested directories (should create them)
nested_file = tmp_path / "dir1" / "dir2" / "file.txt"
helpers.write_file(nested_file, content)
assert nested_file.read_text() == content
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific permission test")
def test_write_file_to_non_writable_directory_unix(tmp_path: Path) -> None:
"""Test write_file raises error when directory is not writable on Unix."""
# Create a directory and make it read-only
read_only_dir = tmp_path / "readonly"
read_only_dir.mkdir()
test_file = read_only_dir / "test.txt"
# Make directory read-only (no write permission)
read_only_dir.chmod(0o555)
try:
with pytest.raises(EsphomeError, match=r"Could not write file"):
helpers.write_file(test_file, "content")
finally:
# Restore write permissions for cleanup
read_only_dir.chmod(0o755)
@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test")
def test_write_file_to_non_writable_directory_windows(tmp_path: Path) -> None:
"""Test write_file error handling on Windows."""
# Windows handles permissions differently - test a different error case
# Try to write to a file path that contains an existing file as a directory component
existing_file = tmp_path / "file.txt"
existing_file.write_text("content")
# Try to write to a path that treats the file as a directory
invalid_path = existing_file / "subdir" / "test.txt"
with pytest.raises(EsphomeError, match=r"Could not write file"):
helpers.write_file(invalid_path, "content")
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific permission test")
def test_write_file_with_permission_bits_unix(tmp_path: Path) -> None:
"""Test that write_file sets correct permissions on Unix."""
test_file = tmp_path / "test.txt"
helpers.write_file(test_file, "content")
# Check that file has 644 permissions
file_mode = test_file.stat().st_mode
assert stat.S_IMODE(file_mode) == 0o644
@pytest.mark.skipif(os.name == "nt", reason="Unix-specific permission test")
def test_copy_file_if_changed_permission_recovery_unix(tmp_path: Path) -> None:
"""Test copy_file_if_changed handles permission errors correctly on Unix."""
# Test with read-only destination file
src = tmp_path / "source.txt"
dst = tmp_path / "dest.txt"
src.write_text("new content")
dst.write_text("old content")
dst.chmod(0o444) # Make destination read-only
try:
# Should handle permission error by deleting and retrying
helpers.copy_file_if_changed(src, dst)
assert dst.read_text() == "new content"
finally:
# Restore write permissions for cleanup
if dst.exists():
dst.chmod(0o644)
def test_copy_file_if_changed_creates_directories(tmp_path: Path) -> None:
"""Test copy_file_if_changed creates missing directories."""
src = tmp_path / "source.txt"
dst = tmp_path / "subdir" / "nested" / "dest.txt"
src.write_text("content")
helpers.copy_file_if_changed(src, dst)
assert dst.exists()
assert dst.read_text() == "content"
def test_copy_file_if_changed_nonexistent_source(tmp_path: Path) -> None:
"""Test copy_file_if_changed with non-existent source."""
src = tmp_path / "nonexistent.txt"
dst = tmp_path / "dest.txt"
with pytest.raises(EsphomeError, match=r"Error copying file"):
helpers.copy_file_if_changed(src, dst)
def test_resolve_ip_address_sorting() -> None: def test_resolve_ip_address_sorting() -> None:
"""Test that results are sorted by preference.""" """Test that results are sorted by preference."""
# Create multiple address infos with different preferences # Create multiple address infos with different preferences
@@ -594,3 +831,84 @@ def test_resolve_ip_address_sorting() -> None:
assert result[0][4][0] == "2001:db8::1" # IPv6 (preference 1) assert result[0][4][0] == "2001:db8::1" # IPv6 (preference 1)
assert result[1][4][0] == "192.168.1.100" # IPv4 (preference 2) assert result[1][4][0] == "192.168.1.100" # IPv4 (preference 2)
assert result[2][4][0] == "fe80::1" # Link-local no scope (preference 3) assert result[2][4][0] == "fe80::1" # Link-local no scope (preference 3)
def test_resolve_ip_address_with_cache() -> None:
"""Test that the cache is used when provided."""
cache = AddressCache(
mdns_cache={"test.local": ["192.168.1.100", "192.168.1.101"]},
dns_cache={
"example.com": ["93.184.216.34", "2606:2800:220:1:248:1893:25c8:1946"]
},
)
# Test mDNS cache hit
result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache)
# Should return cached addresses without calling resolver
assert len(result) == 2
assert result[0][4][0] == "192.168.1.100"
assert result[1][4][0] == "192.168.1.101"
# Test DNS cache hit
result = helpers.resolve_ip_address("example.com", 6053, address_cache=cache)
# Should return cached addresses with IPv6 first due to preference
assert len(result) == 2
assert result[0][4][0] == "2606:2800:220:1:248:1893:25c8:1946" # IPv6 first
assert result[1][4][0] == "93.184.216.34" # IPv4 second
def test_resolve_ip_address_cache_miss() -> None:
"""Test that resolver is called when not in cache."""
cache = AddressCache(mdns_cache={"other.local": ["192.168.1.200"]})
mock_addr_info = AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053),
)
with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value
mock_resolver.resolve.return_value = [mock_addr_info]
result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache)
# Should call resolver since test.local is not in cache
MockResolver.assert_called_once_with(["test.local"], 6053)
assert len(result) == 1
assert result[0][4][0] == "192.168.1.100"
def test_resolve_ip_address_mixed_cached_uncached() -> None:
"""Test resolution with mix of cached and uncached hosts."""
cache = AddressCache(mdns_cache={"cached.local": ["192.168.1.50"]})
mock_addr_info = AddrInfo(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053),
)
with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value
mock_resolver.resolve.return_value = [mock_addr_info]
# Pass a list with cached IP, cached hostname, and uncached hostname
result = helpers.resolve_ip_address(
["192.168.1.10", "cached.local", "uncached.local"],
6053,
address_cache=cache,
)
# Should only resolve uncached.local
MockResolver.assert_called_once_with(["uncached.local"], 6053)
# Results should include all addresses
addresses = [r[4][0] for r in result]
assert "192.168.1.10" in addresses # Direct IP
assert "192.168.1.50" in addresses # From cache
assert "192.168.1.100" in addresses # From resolver

View File

@@ -1,10 +1,16 @@
"""Tests for platformio_api.py path functions.""" """Tests for platformio_api.py path functions."""
import json
import os
from pathlib import Path from pathlib import Path
from unittest.mock import patch import shutil
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
from esphome import platformio_api from esphome import platformio_api
from esphome.core import CORE from esphome.core import CORE, EsphomeError
def test_idedata_firmware_elf_path(setup_core: Path) -> None: def test_idedata_firmware_elf_path(setup_core: Path) -> None:
@@ -104,7 +110,9 @@ def test_flash_image_dataclass() -> None:
assert image.offset == "0x10000" assert image.offset == "0x10000"
def test_load_idedata_returns_dict(setup_core: Path) -> None: def test_load_idedata_returns_dict(
setup_core: Path, mock_run_platformio_cli_run
) -> None:
"""Test _load_idedata returns parsed idedata dict when successful.""" """Test _load_idedata returns parsed idedata dict when successful."""
CORE.build_path = str(setup_core / "build" / "test") CORE.build_path = str(setup_core / "build" / "test")
CORE.name = "test" CORE.name = "test"
@@ -118,12 +126,511 @@ def test_load_idedata_returns_dict(setup_core: Path) -> None:
idedata_path.parent.mkdir(parents=True, exist_ok=True) idedata_path.parent.mkdir(parents=True, exist_ok=True)
idedata_path.write_text('{"prog_path": "/test/firmware.elf"}') idedata_path.write_text('{"prog_path": "/test/firmware.elf"}')
with patch("esphome.platformio_api.run_platformio_cli_run") as mock_run: mock_run_platformio_cli_run.return_value = '{"prog_path": "/test/firmware.elf"}'
mock_run.return_value = '{"prog_path": "/test/firmware.elf"}'
config = {"name": "test"} config = {"name": "test"}
result = platformio_api._load_idedata(config) result = platformio_api._load_idedata(config)
assert result is not None assert result is not None
assert isinstance(result, dict) assert isinstance(result, dict)
assert result["prog_path"] == "/test/firmware.elf" assert result["prog_path"] == "/test/firmware.elf"
def test_load_idedata_uses_cache_when_valid(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""Test _load_idedata uses cached data when unchanged."""
CORE.build_path = str(setup_core / "build" / "test")
CORE.name = "test"
# Create platformio.ini
platformio_ini = setup_core / "build" / "test" / "platformio.ini"
platformio_ini.parent.mkdir(parents=True, exist_ok=True)
platformio_ini.write_text("content")
# Create idedata cache file that's newer
idedata_path = setup_core / ".esphome" / "idedata" / "test.json"
idedata_path.parent.mkdir(parents=True, exist_ok=True)
idedata_path.write_text('{"prog_path": "/cached/firmware.elf"}')
# Make idedata newer than platformio.ini
platformio_ini_mtime = platformio_ini.stat().st_mtime
os.utime(idedata_path, (platformio_ini_mtime + 1, platformio_ini_mtime + 1))
config = {"name": "test"}
result = platformio_api._load_idedata(config)
# Should not call _run_idedata since cache is valid
mock_run_platformio_cli_run.assert_not_called()
assert result["prog_path"] == "/cached/firmware.elf"
def test_load_idedata_regenerates_when_platformio_ini_newer(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""Test _load_idedata regenerates when platformio.ini is newer."""
CORE.build_path = str(setup_core / "build" / "test")
CORE.name = "test"
# Create idedata cache file first
idedata_path = setup_core / ".esphome" / "idedata" / "test.json"
idedata_path.parent.mkdir(parents=True, exist_ok=True)
idedata_path.write_text('{"prog_path": "/old/firmware.elf"}')
# Create platformio.ini that's newer
idedata_mtime = idedata_path.stat().st_mtime
platformio_ini = setup_core / "build" / "test" / "platformio.ini"
platformio_ini.parent.mkdir(parents=True, exist_ok=True)
platformio_ini.write_text("content")
# Make platformio.ini newer than idedata
os.utime(platformio_ini, (idedata_mtime + 1, idedata_mtime + 1))
# Mock platformio to return new data
new_data = {"prog_path": "/new/firmware.elf"}
mock_run_platformio_cli_run.return_value = json.dumps(new_data)
config = {"name": "test"}
result = platformio_api._load_idedata(config)
# Should call _run_idedata since platformio.ini is newer
mock_run_platformio_cli_run.assert_called_once()
assert result["prog_path"] == "/new/firmware.elf"
def test_load_idedata_regenerates_on_corrupted_cache(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""Test _load_idedata regenerates when cache file is corrupted."""
CORE.build_path = str(setup_core / "build" / "test")
CORE.name = "test"
# Create platformio.ini
platformio_ini = setup_core / "build" / "test" / "platformio.ini"
platformio_ini.parent.mkdir(parents=True, exist_ok=True)
platformio_ini.write_text("content")
# Create corrupted idedata cache file
idedata_path = setup_core / ".esphome" / "idedata" / "test.json"
idedata_path.parent.mkdir(parents=True, exist_ok=True)
idedata_path.write_text('{"prog_path": invalid json')
# Make idedata newer so it would be used if valid
platformio_ini_mtime = platformio_ini.stat().st_mtime
os.utime(idedata_path, (platformio_ini_mtime + 1, platformio_ini_mtime + 1))
# Mock platformio to return new data
new_data = {"prog_path": "/new/firmware.elf"}
mock_run_platformio_cli_run.return_value = json.dumps(new_data)
config = {"name": "test"}
result = platformio_api._load_idedata(config)
# Should call _run_idedata since cache is corrupted
mock_run_platformio_cli_run.assert_called_once()
assert result["prog_path"] == "/new/firmware.elf"
def test_run_idedata_parses_json_from_output(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""Test _run_idedata extracts JSON from platformio output."""
config = {"name": "test"}
expected_data = {
"prog_path": "/path/to/firmware.elf",
"cc_path": "/path/to/gcc",
"extra": {"flash_images": []},
}
# Simulate platformio output with JSON embedded
mock_run_platformio_cli_run.return_value = (
f"Some preamble\n{json.dumps(expected_data)}\nSome postamble"
)
result = platformio_api._run_idedata(config)
assert result == expected_data
def test_run_idedata_raises_on_no_json(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""Test _run_idedata raises EsphomeError when no JSON found."""
config = {"name": "test"}
mock_run_platformio_cli_run.return_value = "No JSON in this output"
with pytest.raises(EsphomeError):
platformio_api._run_idedata(config)
def test_run_idedata_raises_on_invalid_json(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""Test _run_idedata raises on malformed JSON."""
config = {"name": "test"}
mock_run_platformio_cli_run.return_value = '{"invalid": json"}'
# The ValueError from json.loads is re-raised
with pytest.raises(ValueError):
platformio_api._run_idedata(config)
def test_run_platformio_cli_sets_environment_variables(
setup_core: Path, mock_run_external_command: Mock
) -> None:
"""Test run_platformio_cli sets correct environment variables."""
CORE.build_path = str(setup_core / "build" / "test")
with patch.dict(os.environ, {}, clear=False):
mock_run_external_command.return_value = 0
platformio_api.run_platformio_cli("test", "arg")
# Check environment variables were set
assert os.environ["PLATFORMIO_FORCE_COLOR"] == "true"
assert (
setup_core / "build" / "test"
in Path(os.environ["PLATFORMIO_BUILD_DIR"]).parents
or Path(os.environ["PLATFORMIO_BUILD_DIR"]) == setup_core / "build" / "test"
)
assert "PLATFORMIO_LIBDEPS_DIR" in os.environ
assert "PYTHONWARNINGS" in os.environ
# Check command was called correctly
mock_run_external_command.assert_called_once()
args = mock_run_external_command.call_args[0]
assert "platformio" in args
assert "test" in args
assert "arg" in args
def test_run_platformio_cli_run_builds_command(
setup_core: Path, mock_run_platformio_cli: Mock
) -> None:
"""Test run_platformio_cli_run builds correct command."""
CORE.build_path = str(setup_core / "build" / "test")
mock_run_platformio_cli.return_value = 0
config = {"name": "test"}
platformio_api.run_platformio_cli_run(config, True, "extra", "args")
mock_run_platformio_cli.assert_called_once_with(
"run", "-d", CORE.build_path, "-v", "extra", "args"
)
def test_run_compile(setup_core: Path, mock_run_platformio_cli_run: Mock) -> None:
"""Test run_compile with process limit."""
from esphome.const import CONF_COMPILE_PROCESS_LIMIT, CONF_ESPHOME
CORE.build_path = str(setup_core / "build" / "test")
config = {CONF_ESPHOME: {CONF_COMPILE_PROCESS_LIMIT: 4}}
mock_run_platformio_cli_run.return_value = 0
platformio_api.run_compile(config, verbose=True)
mock_run_platformio_cli_run.assert_called_once_with(config, True, "-j4")
def test_get_idedata_caches_result(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""Test get_idedata caches result in CORE.data."""
from esphome.const import KEY_CORE
CORE.build_path = str(setup_core / "build" / "test")
CORE.name = "test"
CORE.data[KEY_CORE] = {}
# Create platformio.ini to avoid regeneration
platformio_ini = setup_core / "build" / "test" / "platformio.ini"
platformio_ini.parent.mkdir(parents=True, exist_ok=True)
platformio_ini.write_text("content")
# Mock platformio to return data
idedata = {"prog_path": "/test/firmware.elf"}
mock_run_platformio_cli_run.return_value = json.dumps(idedata)
config = {"name": "test"}
# First call should load and cache
result1 = platformio_api.get_idedata(config)
mock_run_platformio_cli_run.assert_called_once()
# Second call should use cache from CORE.data
result2 = platformio_api.get_idedata(config)
mock_run_platformio_cli_run.assert_called_once() # Still only called once
assert result1 is result2
assert isinstance(result1, platformio_api.IDEData)
assert result1.firmware_elf_path == "/test/firmware.elf"
def test_idedata_addr2line_path_windows(setup_core: Path) -> None:
"""Test IDEData.addr2line_path on Windows."""
raw_data = {"prog_path": "/path/to/firmware.elf", "cc_path": "C:\\tools\\gcc.exe"}
idedata = platformio_api.IDEData(raw_data)
result = idedata.addr2line_path
assert result == "C:\\tools\\addr2line.exe"
def test_idedata_addr2line_path_unix(setup_core: Path) -> None:
"""Test IDEData.addr2line_path on Unix."""
raw_data = {"prog_path": "/path/to/firmware.elf", "cc_path": "/usr/bin/gcc"}
idedata = platformio_api.IDEData(raw_data)
result = idedata.addr2line_path
assert result == "/usr/bin/addr2line"
def test_patch_structhash(setup_core: Path) -> None:
"""Test patch_structhash monkey patches platformio functions."""
# Create simple namespace objects to act as modules
mock_cli = SimpleNamespace()
mock_helpers = SimpleNamespace()
mock_run = SimpleNamespace(cli=mock_cli, helpers=mock_helpers)
# Mock platformio modules
with patch.dict(
"sys.modules",
{
"platformio.run.cli": mock_cli,
"platformio.run.helpers": mock_helpers,
"platformio.run": mock_run,
"platformio.project.helpers": MagicMock(),
"platformio.fs": MagicMock(),
"platformio": MagicMock(),
},
):
# Call patch_structhash
platformio_api.patch_structhash()
# Verify both modules had clean_build_dir patched
# Check that clean_build_dir was set on both modules
assert hasattr(mock_cli, "clean_build_dir")
assert hasattr(mock_helpers, "clean_build_dir")
# Verify they got the same function assigned
assert mock_cli.clean_build_dir is mock_helpers.clean_build_dir
# Verify it's a real function (not a Mock)
assert callable(mock_cli.clean_build_dir)
assert mock_cli.clean_build_dir.__name__ == "patched_clean_build_dir"
def test_patched_clean_build_dir_removes_outdated(setup_core: Path) -> None:
"""Test patched_clean_build_dir removes build dir when platformio.ini is newer."""
build_dir = setup_core / "build"
build_dir.mkdir()
platformio_ini = setup_core / "platformio.ini"
platformio_ini.write_text("config")
# Make platformio.ini newer than build_dir
build_mtime = build_dir.stat().st_mtime
os.utime(platformio_ini, (build_mtime + 1, build_mtime + 1))
# Track if directory was removed
removed_paths: list[str] = []
def track_rmtree(path: str) -> None:
removed_paths.append(path)
shutil.rmtree(path)
# Create mock modules that patch_structhash expects
mock_cli = SimpleNamespace()
mock_helpers = SimpleNamespace()
mock_project_helpers = MagicMock()
mock_project_helpers.get_project_dir.return_value = str(setup_core)
mock_fs = SimpleNamespace(rmtree=track_rmtree)
with patch.dict(
"sys.modules",
{
"platformio": SimpleNamespace(fs=mock_fs),
"platformio.fs": mock_fs,
"platformio.project.helpers": mock_project_helpers,
"platformio.run": SimpleNamespace(cli=mock_cli, helpers=mock_helpers),
"platformio.run.cli": mock_cli,
"platformio.run.helpers": mock_helpers,
},
):
# Call patch_structhash to install the patched function
platformio_api.patch_structhash()
# Call the patched function
mock_helpers.clean_build_dir(str(build_dir), [])
# Verify directory was removed and recreated
assert len(removed_paths) == 1
assert removed_paths[0] == str(build_dir)
assert build_dir.exists() # makedirs recreated it
def test_patched_clean_build_dir_keeps_updated(setup_core: Path) -> None:
"""Test patched_clean_build_dir keeps build dir when it's up to date."""
build_dir = setup_core / "build"
build_dir.mkdir()
test_file = build_dir / "test.txt"
test_file.write_text("test content")
platformio_ini = setup_core / "platformio.ini"
platformio_ini.write_text("config")
# Make build_dir newer than platformio.ini
ini_mtime = platformio_ini.stat().st_mtime
os.utime(build_dir, (ini_mtime + 1, ini_mtime + 1))
# Track if rmtree is called
removed_paths: list[str] = []
def track_rmtree(path: str) -> None:
removed_paths.append(path)
# Create mock modules
mock_cli = SimpleNamespace()
mock_helpers = SimpleNamespace()
mock_project_helpers = MagicMock()
mock_project_helpers.get_project_dir.return_value = str(setup_core)
mock_fs = SimpleNamespace(rmtree=track_rmtree)
with patch.dict(
"sys.modules",
{
"platformio": SimpleNamespace(fs=mock_fs),
"platformio.fs": mock_fs,
"platformio.project.helpers": mock_project_helpers,
"platformio.run": SimpleNamespace(cli=mock_cli, helpers=mock_helpers),
"platformio.run.cli": mock_cli,
"platformio.run.helpers": mock_helpers,
},
):
# Call patch_structhash to install the patched function
platformio_api.patch_structhash()
# Call the patched function
mock_helpers.clean_build_dir(str(build_dir), [])
# Verify rmtree was NOT called
assert len(removed_paths) == 0
# Verify directory and file still exist
assert build_dir.exists()
assert test_file.exists()
assert test_file.read_text() == "test content"
def test_patched_clean_build_dir_creates_missing(setup_core: Path) -> None:
"""Test patched_clean_build_dir creates build dir when it doesn't exist."""
build_dir = setup_core / "build"
platformio_ini = setup_core / "platformio.ini"
platformio_ini.write_text("config")
# Ensure build_dir doesn't exist
assert not build_dir.exists()
# Track if rmtree is called
removed_paths: list[str] = []
def track_rmtree(path: str) -> None:
removed_paths.append(path)
# Create mock modules
mock_cli = SimpleNamespace()
mock_helpers = SimpleNamespace()
mock_project_helpers = MagicMock()
mock_project_helpers.get_project_dir.return_value = str(setup_core)
mock_fs = SimpleNamespace(rmtree=track_rmtree)
with patch.dict(
"sys.modules",
{
"platformio": SimpleNamespace(fs=mock_fs),
"platformio.fs": mock_fs,
"platformio.project.helpers": mock_project_helpers,
"platformio.run": SimpleNamespace(cli=mock_cli, helpers=mock_helpers),
"platformio.run.cli": mock_cli,
"platformio.run.helpers": mock_helpers,
},
):
# Call patch_structhash to install the patched function
platformio_api.patch_structhash()
# Call the patched function
mock_helpers.clean_build_dir(str(build_dir), [])
# Verify rmtree was NOT called
assert len(removed_paths) == 0
# Verify directory was created
assert build_dir.exists()
def test_process_stacktrace_esp8266_exception(setup_core: Path, caplog) -> None:
"""Test process_stacktrace handles ESP8266 exceptions."""
config = {"name": "test"}
# Test exception type parsing
line = "Exception (28):"
backtrace_state = False
result = platformio_api.process_stacktrace(config, line, backtrace_state)
assert "Access to invalid address: LOAD (wild pointer?)" in caplog.text
assert result is False
def test_process_stacktrace_esp8266_backtrace(
setup_core: Path, mock_decode_pc: Mock
) -> None:
"""Test process_stacktrace handles ESP8266 multi-line backtrace."""
config = {"name": "test"}
# Start of backtrace
line1 = ">>>stack>>>"
state = platformio_api.process_stacktrace(config, line1, False)
assert state is True
# Backtrace content with addresses
line2 = "40201234 40205678"
state = platformio_api.process_stacktrace(config, line2, state)
assert state is True
assert mock_decode_pc.call_count == 2
# End of backtrace
line3 = "<<<stack<<<"
state = platformio_api.process_stacktrace(config, line3, state)
assert state is False
def test_process_stacktrace_esp32_backtrace(
setup_core: Path, mock_decode_pc: Mock
) -> None:
"""Test process_stacktrace handles ESP32 single-line backtrace."""
config = {"name": "test"}
line = "Backtrace: 0x40081234:0x3ffb1234 0x40085678:0x3ffb5678"
state = platformio_api.process_stacktrace(config, line, False)
# Should decode both addresses
assert mock_decode_pc.call_count == 2
mock_decode_pc.assert_any_call(config, "40081234")
mock_decode_pc.assert_any_call(config, "40085678")
assert state is False
def test_process_stacktrace_bad_alloc(
setup_core: Path, mock_decode_pc: Mock, caplog
) -> None:
"""Test process_stacktrace handles bad alloc messages."""
config = {"name": "test"}
line = "last failed alloc call: 40201234(512)"
state = platformio_api.process_stacktrace(config, line, False)
assert "Memory allocation of 512 bytes failed at 40201234" in caplog.text
mock_decode_pc.assert_called_once_with(config, "40201234")
assert state is False

View File

@@ -1,12 +1,15 @@
"""Tests for storage_json.py path functions.""" """Tests for storage_json.py path functions."""
from datetime import datetime
import json
from pathlib import Path from pathlib import Path
import sys import sys
from unittest.mock import patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from esphome import storage_json from esphome import storage_json
from esphome.const import CONF_DISABLED, CONF_MDNS
from esphome.core import CORE from esphome.core import CORE
@@ -115,7 +118,9 @@ def test_storage_json_firmware_bin_path_property(setup_core: Path) -> None:
assert storage.firmware_bin_path == "/path/to/firmware.bin" assert storage.firmware_bin_path == "/path/to/firmware.bin"
def test_storage_json_save_creates_directory(setup_core: Path, tmp_path: Path) -> None: def test_storage_json_save_creates_directory(
setup_core: Path, tmp_path: Path, mock_write_file_if_changed: Mock
) -> None:
"""Test StorageJSON.save creates storage directory if it doesn't exist.""" """Test StorageJSON.save creates storage directory if it doesn't exist."""
storage_dir = tmp_path / "new_data" / "storage" storage_dir = tmp_path / "new_data" / "storage"
storage_file = storage_dir / "test.json" storage_file = storage_dir / "test.json"
@@ -139,11 +144,10 @@ def test_storage_json_save_creates_directory(setup_core: Path, tmp_path: Path) -
no_mdns=False, no_mdns=False,
) )
with patch("esphome.storage_json.write_file_if_changed") as mock_write: storage.save(str(storage_file))
storage.save(str(storage_file)) mock_write_file_if_changed.assert_called_once()
mock_write.assert_called_once() call_args = mock_write_file_if_changed.call_args[0]
call_args = mock_write.call_args[0] assert call_args[0] == str(storage_file)
assert call_args[0] == str(storage_file)
def test_storage_json_from_wizard(setup_core: Path) -> None: def test_storage_json_from_wizard(setup_core: Path) -> None:
@@ -180,3 +184,477 @@ def test_storage_paths_with_ha_addon(mock_is_ha_addon: bool, tmp_path: Path) ->
result = storage_json.esphome_storage_path() result = storage_json.esphome_storage_path()
expected = str(Path("/data") / "esphome.json") expected = str(Path("/data") / "esphome.json")
assert result == expected assert result == expected
def test_storage_json_as_dict() -> None:
"""Test StorageJSON.as_dict returns correct dictionary."""
storage = storage_json.StorageJSON(
storage_version=1,
name="test_device",
friendly_name="Test Device",
comment="Test comment",
esphome_version="2024.1.0",
src_version=1,
address="192.168.1.100",
web_port=80,
target_platform="ESP32",
build_path="/path/to/build",
firmware_bin_path="/path/to/firmware.bin",
loaded_integrations={"wifi", "api", "ota"},
loaded_platforms={"sensor", "binary_sensor"},
no_mdns=True,
framework="arduino",
core_platform="esp32",
)
result = storage.as_dict()
assert result["storage_version"] == 1
assert result["name"] == "test_device"
assert result["friendly_name"] == "Test Device"
assert result["comment"] == "Test comment"
assert result["esphome_version"] == "2024.1.0"
assert result["src_version"] == 1
assert result["address"] == "192.168.1.100"
assert result["web_port"] == 80
assert result["esp_platform"] == "ESP32"
assert result["build_path"] == "/path/to/build"
assert result["firmware_bin_path"] == "/path/to/firmware.bin"
assert "api" in result["loaded_integrations"]
assert "wifi" in result["loaded_integrations"]
assert "ota" in result["loaded_integrations"]
assert result["loaded_integrations"] == sorted(
["wifi", "api", "ota"]
) # Should be sorted
assert "sensor" in result["loaded_platforms"]
assert result["loaded_platforms"] == sorted(
["sensor", "binary_sensor"]
) # Should be sorted
assert result["no_mdns"] is True
assert result["framework"] == "arduino"
assert result["core_platform"] == "esp32"
def test_storage_json_to_json() -> None:
"""Test StorageJSON.to_json returns valid JSON string."""
storage = storage_json.StorageJSON(
storage_version=1,
name="test",
friendly_name="Test",
comment=None,
esphome_version="2024.1.0",
src_version=None,
address="test.local",
web_port=None,
target_platform="ESP8266",
build_path=None,
firmware_bin_path=None,
loaded_integrations=set(),
loaded_platforms=set(),
no_mdns=False,
)
json_str = storage.to_json()
# Should be valid JSON
parsed = json.loads(json_str)
assert parsed["name"] == "test"
assert parsed["storage_version"] == 1
# Should end with newline
assert json_str.endswith("\n")
def test_storage_json_save(tmp_path: Path) -> None:
"""Test StorageJSON.save writes file correctly."""
storage = storage_json.StorageJSON(
storage_version=1,
name="test",
friendly_name="Test",
comment=None,
esphome_version="2024.1.0",
src_version=None,
address="test.local",
web_port=None,
target_platform="ESP32",
build_path=None,
firmware_bin_path=None,
loaded_integrations=set(),
loaded_platforms=set(),
no_mdns=False,
)
save_path = tmp_path / "test.json"
with patch("esphome.storage_json.write_file_if_changed") as mock_write:
storage.save(str(save_path))
mock_write.assert_called_once_with(str(save_path), storage.to_json())
def test_storage_json_from_esphome_core(setup_core: Path) -> None:
"""Test StorageJSON.from_esphome_core creates correct storage object."""
# Mock CORE object
mock_core = MagicMock()
mock_core.name = "my_device"
mock_core.friendly_name = "My Device"
mock_core.comment = "A test device"
mock_core.address = "192.168.1.50"
mock_core.web_port = 8080
mock_core.target_platform = "esp32"
mock_core.is_esp32 = True
mock_core.build_path = "/build/my_device"
mock_core.firmware_bin = "/build/my_device/firmware.bin"
mock_core.loaded_integrations = {"wifi", "api"}
mock_core.loaded_platforms = {"sensor"}
mock_core.config = {CONF_MDNS: {CONF_DISABLED: True}}
mock_core.target_framework = "esp-idf"
with patch("esphome.components.esp32.get_esp32_variant") as mock_variant:
mock_variant.return_value = "ESP32-C3"
result = storage_json.StorageJSON.from_esphome_core(mock_core, old=None)
assert result.name == "my_device"
assert result.friendly_name == "My Device"
assert result.comment == "A test device"
assert result.address == "192.168.1.50"
assert result.web_port == 8080
assert result.target_platform == "ESP32-C3"
assert result.build_path == "/build/my_device"
assert result.firmware_bin_path == "/build/my_device/firmware.bin"
assert result.loaded_integrations == {"wifi", "api"}
assert result.loaded_platforms == {"sensor"}
assert result.no_mdns is True
assert result.framework == "esp-idf"
assert result.core_platform == "esp32"
def test_storage_json_from_esphome_core_mdns_enabled(setup_core: Path) -> None:
"""Test from_esphome_core with mDNS enabled."""
mock_core = MagicMock()
mock_core.name = "test"
mock_core.friendly_name = "Test"
mock_core.comment = None
mock_core.address = "test.local"
mock_core.web_port = None
mock_core.target_platform = "esp8266"
mock_core.is_esp32 = False
mock_core.build_path = "/build"
mock_core.firmware_bin = "/build/firmware.bin"
mock_core.loaded_integrations = set()
mock_core.loaded_platforms = set()
mock_core.config = {} # No MDNS config means enabled
mock_core.target_framework = "arduino"
result = storage_json.StorageJSON.from_esphome_core(mock_core, old=None)
assert result.no_mdns is False
def test_storage_json_load_valid_file(tmp_path: Path) -> None:
"""Test StorageJSON.load with valid JSON file."""
storage_data = {
"storage_version": 1,
"name": "loaded_device",
"friendly_name": "Loaded Device",
"comment": "Loaded from file",
"esphome_version": "2024.1.0",
"src_version": 2,
"address": "10.0.0.1",
"web_port": 8080,
"esp_platform": "ESP32",
"build_path": "/loaded/build",
"firmware_bin_path": "/loaded/firmware.bin",
"loaded_integrations": ["wifi", "api"],
"loaded_platforms": ["sensor"],
"no_mdns": True,
"framework": "arduino",
"core_platform": "esp32",
}
file_path = tmp_path / "storage.json"
file_path.write_text(json.dumps(storage_data))
result = storage_json.StorageJSON.load(str(file_path))
assert result is not None
assert result.name == "loaded_device"
assert result.friendly_name == "Loaded Device"
assert result.comment == "Loaded from file"
assert result.esphome_version == "2024.1.0"
assert result.src_version == 2
assert result.address == "10.0.0.1"
assert result.web_port == 8080
assert result.target_platform == "ESP32"
assert result.build_path == "/loaded/build"
assert result.firmware_bin_path == "/loaded/firmware.bin"
assert result.loaded_integrations == {"wifi", "api"}
assert result.loaded_platforms == {"sensor"}
assert result.no_mdns is True
assert result.framework == "arduino"
assert result.core_platform == "esp32"
def test_storage_json_load_invalid_file(tmp_path: Path) -> None:
"""Test StorageJSON.load with invalid JSON file."""
file_path = tmp_path / "invalid.json"
file_path.write_text("not valid json{")
result = storage_json.StorageJSON.load(str(file_path))
assert result is None
def test_storage_json_load_nonexistent_file() -> None:
"""Test StorageJSON.load with non-existent file."""
result = storage_json.StorageJSON.load("/nonexistent/file.json")
assert result is None
def test_storage_json_equality() -> None:
"""Test StorageJSON equality comparison."""
storage1 = storage_json.StorageJSON(
storage_version=1,
name="test",
friendly_name="Test",
comment=None,
esphome_version="2024.1.0",
src_version=1,
address="test.local",
web_port=80,
target_platform="ESP32",
build_path="/build",
firmware_bin_path="/firmware.bin",
loaded_integrations={"wifi"},
loaded_platforms=set(),
no_mdns=False,
)
storage2 = storage_json.StorageJSON(
storage_version=1,
name="test",
friendly_name="Test",
comment=None,
esphome_version="2024.1.0",
src_version=1,
address="test.local",
web_port=80,
target_platform="ESP32",
build_path="/build",
firmware_bin_path="/firmware.bin",
loaded_integrations={"wifi"},
loaded_platforms=set(),
no_mdns=False,
)
storage3 = storage_json.StorageJSON(
storage_version=1,
name="different", # Different name
friendly_name="Test",
comment=None,
esphome_version="2024.1.0",
src_version=1,
address="test.local",
web_port=80,
target_platform="ESP32",
build_path="/build",
firmware_bin_path="/firmware.bin",
loaded_integrations={"wifi"},
loaded_platforms=set(),
no_mdns=False,
)
assert storage1 == storage2
assert storage1 != storage3
assert storage1 != "not a storage object"
def test_esphome_storage_json_as_dict() -> None:
"""Test EsphomeStorageJSON.as_dict returns correct dictionary."""
storage = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="secret123",
last_update_check="2024-01-15T10:30:00",
remote_version="2024.1.1",
)
result = storage.as_dict()
assert result["storage_version"] == 1
assert result["cookie_secret"] == "secret123"
assert result["last_update_check"] == "2024-01-15T10:30:00"
assert result["remote_version"] == "2024.1.1"
def test_esphome_storage_json_last_update_check_property() -> None:
"""Test EsphomeStorageJSON.last_update_check property."""
storage = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="secret",
last_update_check="2024-01-15T10:30:00",
remote_version=None,
)
# Test getter
result = storage.last_update_check
assert isinstance(result, datetime)
assert result.year == 2024
assert result.month == 1
assert result.day == 15
assert result.hour == 10
assert result.minute == 30
# Test setter
new_date = datetime(2024, 2, 20, 15, 45, 30)
storage.last_update_check = new_date
assert storage.last_update_check_str == "2024-02-20T15:45:30"
def test_esphome_storage_json_last_update_check_invalid() -> None:
"""Test EsphomeStorageJSON.last_update_check with invalid date."""
storage = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="secret",
last_update_check="invalid date",
remote_version=None,
)
result = storage.last_update_check
assert result is None
def test_esphome_storage_json_to_json() -> None:
"""Test EsphomeStorageJSON.to_json returns valid JSON string."""
storage = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="mysecret",
last_update_check="2024-01-15T10:30:00",
remote_version="2024.1.1",
)
json_str = storage.to_json()
# Should be valid JSON
parsed = json.loads(json_str)
assert parsed["cookie_secret"] == "mysecret"
assert parsed["storage_version"] == 1
# Should end with newline
assert json_str.endswith("\n")
def test_esphome_storage_json_save(tmp_path: Path) -> None:
"""Test EsphomeStorageJSON.save writes file correctly."""
storage = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="secret",
last_update_check=None,
remote_version=None,
)
save_path = tmp_path / "esphome.json"
with patch("esphome.storage_json.write_file_if_changed") as mock_write:
storage.save(str(save_path))
mock_write.assert_called_once_with(str(save_path), storage.to_json())
def test_esphome_storage_json_load_valid_file(tmp_path: Path) -> None:
"""Test EsphomeStorageJSON.load with valid JSON file."""
storage_data = {
"storage_version": 1,
"cookie_secret": "loaded_secret",
"last_update_check": "2024-01-20T14:30:00",
"remote_version": "2024.1.2",
}
file_path = tmp_path / "esphome.json"
file_path.write_text(json.dumps(storage_data))
result = storage_json.EsphomeStorageJSON.load(str(file_path))
assert result is not None
assert result.storage_version == 1
assert result.cookie_secret == "loaded_secret"
assert result.last_update_check_str == "2024-01-20T14:30:00"
assert result.remote_version == "2024.1.2"
def test_esphome_storage_json_load_invalid_file(tmp_path: Path) -> None:
"""Test EsphomeStorageJSON.load with invalid JSON file."""
file_path = tmp_path / "invalid.json"
file_path.write_text("not valid json{")
result = storage_json.EsphomeStorageJSON.load(str(file_path))
assert result is None
def test_esphome_storage_json_load_nonexistent_file() -> None:
"""Test EsphomeStorageJSON.load with non-existent file."""
result = storage_json.EsphomeStorageJSON.load("/nonexistent/file.json")
assert result is None
def test_esphome_storage_json_get_default() -> None:
"""Test EsphomeStorageJSON.get_default creates default storage."""
with patch("esphome.storage_json.os.urandom") as mock_urandom:
# Mock urandom to return predictable bytes
mock_urandom.return_value = b"test" * 16 # 64 bytes
result = storage_json.EsphomeStorageJSON.get_default()
assert result.storage_version == 1
assert len(result.cookie_secret) == 128 # 64 bytes hex = 128 chars
assert result.last_update_check is None
assert result.remote_version is None
def test_esphome_storage_json_equality() -> None:
"""Test EsphomeStorageJSON equality comparison."""
storage1 = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="secret",
last_update_check="2024-01-15T10:30:00",
remote_version="2024.1.1",
)
storage2 = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="secret",
last_update_check="2024-01-15T10:30:00",
remote_version="2024.1.1",
)
storage3 = storage_json.EsphomeStorageJSON(
storage_version=1,
cookie_secret="different", # Different secret
last_update_check="2024-01-15T10:30:00",
remote_version="2024.1.1",
)
assert storage1 == storage2
assert storage1 != storage3
assert storage1 != "not a storage object"
def test_storage_json_load_legacy_esphomeyaml_version(tmp_path: Path) -> None:
"""Test loading storage with legacy esphomeyaml_version field."""
storage_data = {
"storage_version": 1,
"name": "legacy_device",
"friendly_name": "Legacy Device",
"esphomeyaml_version": "1.14.0", # Legacy field name
"address": "legacy.local",
"esp_platform": "ESP8266",
}
file_path = tmp_path / "legacy.json"
file_path.write_text(json.dumps(storage_data))
result = storage_json.StorageJSON.load(str(file_path))
assert result is not None
assert result.esphome_version == "1.14.0" # Should map to esphome_version

View File

@@ -1,5 +1,7 @@
"""Tests for esphome.util module.""" """Tests for esphome.util module."""
from __future__ import annotations
from pathlib import Path from pathlib import Path
import pytest import pytest
@@ -308,3 +310,85 @@ def test_filter_yaml_files_case_sensitive() -> None:
assert "/path/to/config.YAML" not in result assert "/path/to/config.YAML" not in result
assert "/path/to/config.YML" not in result assert "/path/to/config.YML" not in result
assert "/path/to/config.Yaml" not in result assert "/path/to/config.Yaml" not in result
@pytest.mark.parametrize(
("input_str", "expected"),
[
# Empty string
("", "''"),
# Simple strings that don't need quoting
("hello", "hello"),
("test123", "test123"),
("file.txt", "file.txt"),
("/path/to/file", "/path/to/file"),
("user@host", "user@host"),
("value:123", "value:123"),
("item,list", "item,list"),
("path-with-dash", "path-with-dash"),
# Strings that need quoting
("hello world", "'hello world'"),
("test\ttab", "'test\ttab'"),
("line\nbreak", "'line\nbreak'"),
("semicolon;here", "'semicolon;here'"),
("pipe|symbol", "'pipe|symbol'"),
("redirect>file", "'redirect>file'"),
("redirect<file", "'redirect<file'"),
("background&", "'background&'"),
("dollar$sign", "'dollar$sign'"),
("backtick`cmd", "'backtick`cmd'"),
('double"quote', "'double\"quote'"),
("backslash\\path", "'backslash\\path'"),
("question?mark", "'question?mark'"),
("asterisk*wild", "'asterisk*wild'"),
("bracket[test]", "'bracket[test]'"),
("paren(test)", "'paren(test)'"),
("curly{brace}", "'curly{brace}'"),
# Single quotes in string (special escaping)
("it's", "'it'\"'\"'s'"),
("don't", "'don'\"'\"'t'"),
("'quoted'", "''\"'\"'quoted'\"'\"''"),
# Complex combinations
("test 'with' quotes", "'test '\"'\"'with'\"'\"' quotes'"),
("path/to/file's.txt", "'path/to/file'\"'\"'s.txt'"),
],
)
def test_shlex_quote(input_str: str, expected: str) -> None:
"""Test shlex_quote properly escapes shell arguments."""
assert util.shlex_quote(input_str) == expected
def test_shlex_quote_safe_characters() -> None:
"""Test that safe characters are not quoted."""
# These characters are considered safe and shouldn't be quoted
safe_chars = (
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789@%+=:,./-_"
)
for char in safe_chars:
assert util.shlex_quote(char) == char
assert util.shlex_quote(f"test{char}test") == f"test{char}test"
def test_shlex_quote_unsafe_characters() -> None:
"""Test that unsafe characters trigger quoting."""
# These characters should trigger quoting
unsafe_chars = ' \t\n;|>&<$`"\\?*[](){}!#~^'
for char in unsafe_chars:
result = util.shlex_quote(f"test{char}test")
assert result.startswith("'")
assert result.endswith("'")
def test_shlex_quote_edge_cases() -> None:
"""Test edge cases for shlex_quote."""
# Multiple single quotes
assert util.shlex_quote("'''") == "''\"'\"''\"'\"''\"'\"''"
# Mixed quotes
assert util.shlex_quote('"\'"') == "'\"'\"'\"'\"'"
# Only whitespace
assert util.shlex_quote(" ") == "' '"
assert util.shlex_quote("\t") == "'\t'"
assert util.shlex_quote("\n") == "'\n'"
assert util.shlex_quote(" ") == "' '"