From 2bb64a189dc4a04eb90505298ee9281f4131e48d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 18 Sep 2025 20:13:13 -0500 Subject: [PATCH] [dashboard] Transfer DNS/mDNS cache from dashboard to CLI to avoid blocking (#10685) --- esphome/__main__.py | 33 ++- esphome/address_cache.py | 142 ++++++++++++ esphome/core/__init__.py | 5 + esphome/dashboard/dns.py | 15 ++ esphome/dashboard/status/mdns.py | 27 +++ esphome/dashboard/web_server.py | 119 ++++++---- esphome/espota2.py | 6 +- esphome/helpers.py | 118 ++++++---- tests/dashboard/conftest.py | 21 ++ tests/dashboard/status/__init__.py | 0 tests/dashboard/status/test_dns.py | 121 ++++++++++ tests/dashboard/status/test_mdns.py | 168 ++++++++++++++ tests/dashboard/test_web_server.py | 80 +++++++ tests/unit_tests/test_address_cache.py | 305 +++++++++++++++++++++++++ tests/unit_tests/test_helpers.py | 82 +++++++ 15 files changed, 1155 insertions(+), 87 deletions(-) create mode 100644 esphome/address_cache.py create mode 100644 tests/dashboard/conftest.py create mode 100644 tests/dashboard/status/__init__.py create mode 100644 tests/dashboard/status/test_dns.py create mode 100644 tests/dashboard/status/test_mdns.py create mode 100644 tests/unit_tests/test_address_cache.py diff --git a/esphome/__main__.py b/esphome/__main__.py index f54fa8e3c6..07cd267c96 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -114,6 +114,14 @@ class Purpose(StrEnum): LOGGING = "logging" +def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]: + """Resolve an address using cache if available, otherwise return the address itself.""" + if CORE.address_cache and (cached := CORE.address_cache.get_addresses(address)): + _LOGGER.debug("Using cached addresses for %s: %s", purpose.value, cached) + return cached + return [address] + + def choose_upload_log_host( default: list[str] | str | None, check_default: str | None, @@ -142,7 +150,7 @@ def choose_upload_log_host( (purpose == Purpose.LOGGING and has_api()) or (purpose == Purpose.UPLOADING and has_ota()) ): - resolved.append(CORE.address) + resolved.extend(_resolve_with_cache(CORE.address, purpose)) if purpose == Purpose.LOGGING: if has_api() and has_mqtt_ip_lookup(): @@ -152,15 +160,14 @@ def choose_upload_log_host( resolved.append("MQTT") if has_api() and has_non_ip_address(): - resolved.append(CORE.address) + resolved.extend(_resolve_with_cache(CORE.address, purpose)) elif purpose == Purpose.UPLOADING: if has_ota() and has_mqtt_ip_lookup(): resolved.append("MQTTIP") if has_ota() and has_non_ip_address(): - resolved.append(CORE.address) - + resolved.extend(_resolve_with_cache(CORE.address, purpose)) else: resolved.append(device) if not resolved: @@ -965,6 +972,18 @@ def parse_args(argv): help="Add a substitution", metavar=("key", "value"), ) + options_parser.add_argument( + "--mdns-address-cache", + help="mDNS address cache mapping in format 'hostname=ip1,ip2'", + action="append", + default=[], + ) + options_parser.add_argument( + "--dns-address-cache", + help="DNS address cache mapping in format 'hostname=ip1,ip2'", + action="append", + default=[], + ) parser = argparse.ArgumentParser( description=f"ESPHome {const.__version__}", parents=[options_parser] @@ -1212,9 +1231,15 @@ def parse_args(argv): def run_esphome(argv): + from esphome.address_cache import AddressCache + args = parse_args(argv) CORE.dashboard = args.dashboard + # Create address cache from command-line arguments + CORE.address_cache = AddressCache.from_cli_args( + args.mdns_address_cache, args.dns_address_cache + ) # Override log level if verbose is set if args.verbose: args.log_level = "DEBUG" diff --git a/esphome/address_cache.py b/esphome/address_cache.py new file mode 100644 index 0000000000..7c20be90f0 --- /dev/null +++ b/esphome/address_cache.py @@ -0,0 +1,142 @@ +"""Address cache for DNS and mDNS lookups.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + +_LOGGER = logging.getLogger(__name__) + + +def normalize_hostname(hostname: str) -> str: + """Normalize hostname for cache lookups. + + Removes trailing dots and converts to lowercase. + """ + return hostname.rstrip(".").lower() + + +class AddressCache: + """Cache for DNS and mDNS address lookups. + + This cache stores pre-resolved addresses from command-line arguments + to avoid slow DNS/mDNS lookups during builds. + """ + + def __init__( + self, + mdns_cache: dict[str, list[str]] | None = None, + dns_cache: dict[str, list[str]] | None = None, + ) -> None: + """Initialize the address cache. + + Args: + mdns_cache: Pre-populated mDNS addresses (hostname -> IPs) + dns_cache: Pre-populated DNS addresses (hostname -> IPs) + """ + self.mdns_cache = mdns_cache or {} + self.dns_cache = dns_cache or {} + + def _get_cached_addresses( + self, hostname: str, cache: dict[str, list[str]], cache_type: str + ) -> list[str] | None: + """Get cached addresses from a specific cache. + + Args: + hostname: The hostname to look up + cache: The cache dictionary to check + cache_type: Type of cache for logging ("mDNS" or "DNS") + + Returns: + List of IP addresses if found in cache, None otherwise + """ + normalized = normalize_hostname(hostname) + if addresses := cache.get(normalized): + _LOGGER.debug("Using %s cache for %s: %s", cache_type, hostname, addresses) + return addresses + return None + + def get_mdns_addresses(self, hostname: str) -> list[str] | None: + """Get cached mDNS addresses for a hostname. + + Args: + hostname: The hostname to look up (should end with .local) + + Returns: + List of IP addresses if found in cache, None otherwise + """ + return self._get_cached_addresses(hostname, self.mdns_cache, "mDNS") + + def get_dns_addresses(self, hostname: str) -> list[str] | None: + """Get cached DNS addresses for a hostname. + + Args: + hostname: The hostname to look up + + Returns: + List of IP addresses if found in cache, None otherwise + """ + return self._get_cached_addresses(hostname, self.dns_cache, "DNS") + + def get_addresses(self, hostname: str) -> list[str] | None: + """Get cached addresses for a hostname. + + Checks mDNS cache for .local domains, DNS cache otherwise. + + Args: + hostname: The hostname to look up + + Returns: + List of IP addresses if found in cache, None otherwise + """ + normalized = normalize_hostname(hostname) + if normalized.endswith(".local"): + return self.get_mdns_addresses(hostname) + return self.get_dns_addresses(hostname) + + def has_cache(self) -> bool: + """Check if any cache entries exist.""" + return bool(self.mdns_cache or self.dns_cache) + + @classmethod + def from_cli_args( + cls, mdns_args: Iterable[str], dns_args: Iterable[str] + ) -> AddressCache: + """Create cache from command-line arguments. + + Args: + mdns_args: List of mDNS cache entries like ['host=ip1,ip2'] + dns_args: List of DNS cache entries like ['host=ip1,ip2'] + + Returns: + Configured AddressCache instance + """ + mdns_cache = cls._parse_cache_args(mdns_args) + dns_cache = cls._parse_cache_args(dns_args) + return cls(mdns_cache=mdns_cache, dns_cache=dns_cache) + + @staticmethod + def _parse_cache_args(cache_args: Iterable[str]) -> dict[str, list[str]]: + """Parse cache arguments into a dictionary. + + Args: + cache_args: List of cache mappings like ['host1=ip1,ip2', 'host2=ip3'] + + Returns: + Dictionary mapping normalized hostnames to list of IP addresses + """ + cache: dict[str, list[str]] = {} + for arg in cache_args: + if "=" not in arg: + _LOGGER.warning( + "Invalid cache format: %s (expected 'hostname=ip1,ip2')", arg + ) + continue + hostname, ips = arg.split("=", 1) + # Normalize hostname for consistent lookups + normalized = normalize_hostname(hostname) + cache[normalized] = [ip.strip() for ip in ips.split(",")] + return cache diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 89e3eff7d8..242a6854df 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -39,6 +39,8 @@ from esphome.helpers import ensure_unique_string, get_str_env, is_ha_addon from esphome.util import OrderedDict if TYPE_CHECKING: + from esphome.address_cache import AddressCache + from ..cpp_generator import MockObj, MockObjClass, Statement from ..types import ConfigType, EntityMetadata @@ -583,6 +585,8 @@ class EsphomeCore: self.id_classes = {} # The current component being processed during validation self.current_component: str | None = None + # Address cache for DNS and mDNS lookups from command line arguments + self.address_cache: AddressCache | None = None def reset(self): from esphome.pins import PIN_SCHEMA_REGISTRY @@ -610,6 +614,7 @@ class EsphomeCore: self.platform_counts = defaultdict(int) self.unique_ids = {} self.current_component = None + self.address_cache = None PIN_SCHEMA_REGISTRY.reset() @contextmanager diff --git a/esphome/dashboard/dns.py b/esphome/dashboard/dns.py index 98134062f4..58867f7bc1 100644 --- a/esphome/dashboard/dns.py +++ b/esphome/dashboard/dns.py @@ -28,6 +28,21 @@ class DNSCache: self._cache: dict[str, tuple[float, list[str] | Exception]] = {} self._ttl = ttl + def get_cached_addresses( + self, hostname: str, now_monotonic: float + ) -> list[str] | None: + """Get cached addresses without triggering resolution. + + Returns None if not in cache, list of addresses if found. + """ + # Normalize hostname for consistent lookups + normalized = hostname.rstrip(".").lower() + if expire_time_addresses := self._cache.get(normalized): + expire_time, addresses = expire_time_addresses + if expire_time > now_monotonic and not isinstance(addresses, Exception): + return addresses + return None + async def async_resolve( self, hostname: str, now_monotonic: float ) -> list[str] | Exception: diff --git a/esphome/dashboard/status/mdns.py b/esphome/dashboard/status/mdns.py index f9ac7b4289..989517e1c3 100644 --- a/esphome/dashboard/status/mdns.py +++ b/esphome/dashboard/status/mdns.py @@ -4,6 +4,9 @@ import asyncio import logging import typing +from zeroconf import AddressResolver, IPVersion + +from esphome.address_cache import normalize_hostname from esphome.zeroconf import ( ESPHOME_SERVICE_TYPE, AsyncEsphomeZeroconf, @@ -50,6 +53,30 @@ class MDNSStatus: return await aiozc.async_resolve_host(host_name) return None + def get_cached_addresses(self, host_name: str) -> list[str] | None: + """Get cached addresses for a host without triggering resolution. + + Returns None if not in cache or no zeroconf available. + """ + if not self.aiozc: + _LOGGER.debug("No zeroconf instance available for %s", host_name) + return None + + # Normalize hostname and get the base name + normalized = normalize_hostname(host_name) + base_name = normalized.partition(".")[0] + + # Try to load from zeroconf cache without triggering resolution + resolver_name = f"{base_name}.local." + info = AddressResolver(resolver_name) + # Let zeroconf use its own current time for cache checking + if info.load_from_cache(self.aiozc.zeroconf): + addresses = info.parsed_scoped_addresses(IPVersion.All) + _LOGGER.debug("Found %s in zeroconf cache: %s", resolver_name, addresses) + return addresses + _LOGGER.debug("Not found in zeroconf cache: %s", resolver_name) + return None + async def async_refresh_hosts(self) -> None: """Refresh the hosts to track.""" dashboard = self.dashboard diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index e6c5fd3d84..24595eb942 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -50,8 +50,8 @@ from esphome.util import get_serial_ports, shlex_quote from esphome.yaml_util import FastestAvailableSafeLoader from .const import DASHBOARD_COMMAND -from .core import DASHBOARD -from .entries import UNKNOWN_STATE, entry_state_to_bool +from .core import DASHBOARD, ESPHomeDashboard +from .entries import UNKNOWN_STATE, DashboardEntry, entry_state_to_bool from .util.file import write_file from .util.subprocess import async_run_system_command from .util.text import friendly_name_slugify @@ -314,6 +314,73 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): raise NotImplementedError +def build_cache_arguments( + entry: DashboardEntry | None, + dashboard: ESPHomeDashboard, + now: float, +) -> list[str]: + """Build cache arguments for passing to CLI. + + Args: + entry: Dashboard entry for the configuration + dashboard: Dashboard instance with cache access + now: Current monotonic time for DNS cache expiry checks + + Returns: + List of cache arguments to pass to CLI + """ + cache_args: list[str] = [] + + if not entry: + return cache_args + + _LOGGER.debug( + "Building cache for entry (address=%s, name=%s)", + entry.address, + entry.name, + ) + + def add_cache_entry(hostname: str, addresses: list[str], cache_type: str) -> None: + """Add a cache entry to the command arguments.""" + if not addresses: + return + normalized = hostname.rstrip(".").lower() + cache_args.extend( + [ + f"--{cache_type}-address-cache", + f"{normalized}={','.join(sort_ip_addresses(addresses))}", + ] + ) + + # Check entry.address for cached addresses + if use_address := entry.address: + if use_address.endswith(".local"): + # mDNS cache for .local addresses + if (mdns := dashboard.mdns_status) and ( + cached := mdns.get_cached_addresses(use_address) + ): + _LOGGER.debug("mDNS cache hit for %s: %s", use_address, cached) + add_cache_entry(use_address, cached, "mdns") + # DNS cache for non-.local addresses + elif cached := dashboard.dns_cache.get_cached_addresses(use_address, now): + _LOGGER.debug("DNS cache hit for %s: %s", use_address, cached) + add_cache_entry(use_address, cached, "dns") + + # Check entry.name if we haven't already cached via address + # For mDNS devices, entry.name typically doesn't have .local suffix + if entry.name and not use_address: + mdns_name = ( + f"{entry.name}.local" if not entry.name.endswith(".local") else entry.name + ) + if (mdns := dashboard.mdns_status) and ( + cached := mdns.get_cached_addresses(mdns_name) + ): + _LOGGER.debug("mDNS cache hit for %s: %s", mdns_name, cached) + add_cache_entry(mdns_name, cached, "mdns") + + return cache_args + + class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): """Base class for commands that require a port.""" @@ -326,52 +393,22 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): configuration = json_message["configuration"] config_file = settings.rel_path(configuration) port = json_message["port"] - addresses: list[str] = [] + + # Build cache arguments to pass to CLI + cache_args: list[str] = [] + if ( port == "OTA" # pylint: disable=too-many-boolean-expressions and (entry := entries.get(config_file)) and entry.loaded_integrations and "api" in entry.loaded_integrations ): - # First priority: entry.address AKA use_address - if ( - (use_address := entry.address) - and ( - address_list := await dashboard.dns_cache.async_resolve( - use_address, time.monotonic() - ) - ) - and not isinstance(address_list, Exception) - ): - addresses.extend(sort_ip_addresses(address_list)) + cache_args = build_cache_arguments(entry, dashboard, time.monotonic()) - # Second priority: mDNS - if ( - (mdns := dashboard.mdns_status) - and (address_list := await mdns.async_resolve_host(entry.name)) - and ( - new_addresses := [ - addr for addr in address_list if addr not in addresses - ] - ) - ): - # Use the IP address if available but only - # if the API is loaded and the device is online - # since MQTT logging will not work otherwise - addresses.extend(sort_ip_addresses(new_addresses)) - - if not addresses: - # If no address was found, use the port directly - # as otherwise they will get the chooser which - # does not work with the dashboard as there is no - # interactive way to get keyboard input - addresses = [port] - - device_args: list[str] = [ - arg for address in addresses for arg in ("--device", address) - ] - - return [*DASHBOARD_COMMAND, *args, config_file, *device_args] + # Cache arguments must come before the subcommand + cmd = [*DASHBOARD_COMMAND, *cache_args, *args, config_file, "--device", port] + _LOGGER.debug("Built command: %s", cmd) + return cmd class EsphomeLogsHandler(EsphomePortCommandWebSocket): diff --git a/esphome/espota2.py b/esphome/espota2.py index 3d25af985b..99c91d94e2 100644 --- a/esphome/espota2.py +++ b/esphome/espota2.py @@ -311,10 +311,14 @@ def perform_ota( def run_ota_impl_( remote_host: str | list[str], remote_port: int, password: str, filename: str ) -> tuple[int, str | None]: + from esphome.core import CORE + # Handle both single host and list of hosts try: # Resolve all hosts at once for parallel DNS resolution - res = resolve_ip_address(remote_host, remote_port) + res = resolve_ip_address( + remote_host, remote_port, address_cache=CORE.address_cache + ) except EsphomeError as err: _LOGGER.error( "Error resolving IP address of %s. Is it connected to WiFi?", diff --git a/esphome/helpers.py b/esphome/helpers.py index 6beaa24a96..2b7221355c 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -9,10 +9,14 @@ from pathlib import Path import platform import re import tempfile +from typing import TYPE_CHECKING from urllib.parse import urlparse from esphome.const import __version__ as ESPHOME_VERSION +if TYPE_CHECKING: + from esphome.address_cache import AddressCache + # Type aliases for socket address information AddrInfo = tuple[ int, # family (AF_INET, AF_INET6, etc.) @@ -173,7 +177,24 @@ def addr_preference_(res: AddrInfo) -> int: return 1 -def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: +def _add_ip_addresses_to_addrinfo( + addresses: list[str], port: int, res: list[AddrInfo] +) -> None: + """Helper to add IP addresses to addrinfo results with error handling.""" + import socket + + for addr in addresses: + try: + res += socket.getaddrinfo( + addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST + ) + except OSError: + _LOGGER.debug("Failed to parse IP address '%s'", addr) + + +def resolve_ip_address( + host: str | list[str], port: int, address_cache: AddressCache | None = None +) -> list[AddrInfo]: import socket # There are five cases here. The host argument could be one of: @@ -194,47 +215,69 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: hosts = [host] res: list[AddrInfo] = [] + + # Fast path: if all hosts are already IP addresses if all(is_ip_address(h) for h in hosts): - # Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST - for addr in hosts: - try: - res += socket.getaddrinfo( - addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST - ) - except OSError: - _LOGGER.debug("Failed to parse IP address '%s'", addr) + _add_ip_addresses_to_addrinfo(hosts, port, res) # Sort by preference res.sort(key=addr_preference_) return res - from esphome.resolver import AsyncResolver + # Process hosts + cached_addresses: list[str] = [] + uncached_hosts: list[str] = [] + has_cache = address_cache is not None - resolver = AsyncResolver(hosts, port) - addr_infos = resolver.resolve() - # Convert aioesphomeapi AddrInfo to our format - for addr_info in addr_infos: - sockaddr = addr_info.sockaddr - if addr_info.family == socket.AF_INET6: - # IPv6 - sockaddr_tuple = ( - sockaddr.address, - sockaddr.port, - sockaddr.flowinfo, - sockaddr.scope_id, - ) + for h in hosts: + if is_ip_address(h): + if has_cache: + # If we have a cache, treat IPs as cached + cached_addresses.append(h) + else: + # If no cache, pass IPs through to resolver with hostnames + uncached_hosts.append(h) + elif address_cache and (cached := address_cache.get_addresses(h)): + # Found in cache + cached_addresses.extend(cached) else: - # IPv4 - sockaddr_tuple = (sockaddr.address, sockaddr.port) + # Not cached, need to resolve + if address_cache and address_cache.has_cache(): + _LOGGER.info("Host %s not in cache, will need to resolve", h) + uncached_hosts.append(h) - res.append( - ( - addr_info.family, - addr_info.type, - addr_info.proto, - "", # canonname - sockaddr_tuple, + # Process cached addresses (includes direct IPs and cached lookups) + _add_ip_addresses_to_addrinfo(cached_addresses, port, res) + + # If we have uncached hosts (only non-IP hostnames), resolve them + if uncached_hosts: + from esphome.resolver import AsyncResolver + + resolver = AsyncResolver(uncached_hosts, port) + addr_infos = resolver.resolve() + # Convert aioesphomeapi AddrInfo to our format + for addr_info in addr_infos: + sockaddr = addr_info.sockaddr + if addr_info.family == socket.AF_INET6: + # IPv6 + sockaddr_tuple = ( + sockaddr.address, + sockaddr.port, + sockaddr.flowinfo, + sockaddr.scope_id, + ) + else: + # IPv4 + sockaddr_tuple = (sockaddr.address, sockaddr.port) + + res.append( + ( + addr_info.family, + addr_info.type, + addr_info.proto, + "", # canonname + sockaddr_tuple, + ) ) - ) # Sort by preference res.sort(key=addr_preference_) @@ -256,14 +299,7 @@ def sort_ip_addresses(address_list: list[str]) -> list[str]: # First "resolve" all the IP addresses to getaddrinfo() tuples of the form # (family, type, proto, canonname, sockaddr) res: list[AddrInfo] = [] - for addr in address_list: - # This should always work as these are supposed to be IP addresses - try: - res += socket.getaddrinfo( - addr, 0, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST - ) - except OSError: - _LOGGER.info("Failed to parse IP address '%s'", addr) + _add_ip_addresses_to_addrinfo(address_list, 0, res) # Now use that information to sort them. res.sort(key=addr_preference_) diff --git a/tests/dashboard/conftest.py b/tests/dashboard/conftest.py new file mode 100644 index 0000000000..358be1bf5d --- /dev/null +++ b/tests/dashboard/conftest.py @@ -0,0 +1,21 @@ +"""Common fixtures for dashboard tests.""" + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from esphome.dashboard.core import ESPHomeDashboard + + +@pytest.fixture +def mock_dashboard() -> Mock: + """Create a mock dashboard.""" + dashboard = Mock(spec=ESPHomeDashboard) + dashboard.entries = Mock() + dashboard.entries.async_all.return_value = [] + dashboard.stop_event = Mock() + dashboard.stop_event.is_set.return_value = True + dashboard.ping_request = Mock() + return dashboard diff --git a/tests/dashboard/status/__init__.py b/tests/dashboard/status/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dashboard/status/test_dns.py b/tests/dashboard/status/test_dns.py new file mode 100644 index 0000000000..9ca48ba2d8 --- /dev/null +++ b/tests/dashboard/status/test_dns.py @@ -0,0 +1,121 @@ +"""Unit tests for esphome.dashboard.dns module.""" + +from __future__ import annotations + +import time +from unittest.mock import patch + +import pytest + +from esphome.dashboard.dns import DNSCache + + +@pytest.fixture +def dns_cache_fixture() -> DNSCache: + """Create a DNSCache instance.""" + return DNSCache() + + +def test_get_cached_addresses_not_in_cache(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses when hostname is not in cache.""" + now = time.monotonic() + result = dns_cache_fixture.get_cached_addresses("unknown.example.com", now) + assert result is None + + +def test_get_cached_addresses_expired(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses when cache entry is expired.""" + now = time.monotonic() + # Add entry that's already expired + dns_cache_fixture._cache["example.com"] = (now - 1, ["192.168.1.10"]) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result is None + # Expired entry should still be in cache (not removed by get_cached_addresses) + assert "example.com" in dns_cache_fixture._cache + + +def test_get_cached_addresses_valid(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses with valid cache entry.""" + now = time.monotonic() + # Add entry that expires in 60 seconds + dns_cache_fixture._cache["example.com"] = ( + now + 60, + ["192.168.1.10", "192.168.1.11"], + ) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result == ["192.168.1.10", "192.168.1.11"] + # Entry should still be in cache + assert "example.com" in dns_cache_fixture._cache + + +def test_get_cached_addresses_hostname_normalization( + dns_cache_fixture: DNSCache, +) -> None: + """Test get_cached_addresses normalizes hostname.""" + now = time.monotonic() + # Add entry with lowercase hostname + dns_cache_fixture._cache["example.com"] = (now + 60, ["192.168.1.10"]) + + # Test with various forms + assert dns_cache_fixture.get_cached_addresses("EXAMPLE.COM", now) == [ + "192.168.1.10" + ] + assert dns_cache_fixture.get_cached_addresses("example.com.", now) == [ + "192.168.1.10" + ] + assert dns_cache_fixture.get_cached_addresses("EXAMPLE.COM.", now) == [ + "192.168.1.10" + ] + + +def test_get_cached_addresses_ipv6(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses with IPv6 addresses.""" + now = time.monotonic() + dns_cache_fixture._cache["example.com"] = (now + 60, ["2001:db8::1", "fe80::1"]) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result == ["2001:db8::1", "fe80::1"] + + +def test_get_cached_addresses_empty_list(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses with empty address list.""" + now = time.monotonic() + dns_cache_fixture._cache["example.com"] = (now + 60, []) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result == [] + + +def test_get_cached_addresses_exception_in_cache(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses when cache contains an exception.""" + now = time.monotonic() + # Store an exception (from failed resolution) + dns_cache_fixture._cache["example.com"] = (now + 60, OSError("Resolution failed")) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result is None # Should return None for exceptions + + +def test_async_resolve_not_called(dns_cache_fixture: DNSCache) -> None: + """Test that get_cached_addresses never calls async_resolve.""" + now = time.monotonic() + + with patch.object(dns_cache_fixture, "async_resolve") as mock_resolve: + # Test non-cached + result = dns_cache_fixture.get_cached_addresses("uncached.com", now) + assert result is None + mock_resolve.assert_not_called() + + # Test expired + dns_cache_fixture._cache["expired.com"] = (now - 1, ["192.168.1.10"]) + result = dns_cache_fixture.get_cached_addresses("expired.com", now) + assert result is None + mock_resolve.assert_not_called() + + # Test valid + dns_cache_fixture._cache["valid.com"] = (now + 60, ["192.168.1.10"]) + result = dns_cache_fixture.get_cached_addresses("valid.com", now) + assert result == ["192.168.1.10"] + mock_resolve.assert_not_called() diff --git a/tests/dashboard/status/test_mdns.py b/tests/dashboard/status/test_mdns.py new file mode 100644 index 0000000000..7130c2c73a --- /dev/null +++ b/tests/dashboard/status/test_mdns.py @@ -0,0 +1,168 @@ +"""Unit tests for esphome.dashboard.status.mdns module.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +import pytest_asyncio +from zeroconf import AddressResolver, IPVersion + +from esphome.dashboard.status.mdns import MDNSStatus + + +@pytest_asyncio.fixture +async def mdns_status(mock_dashboard: Mock) -> MDNSStatus: + """Create an MDNSStatus instance in async context.""" + # We're in an async context so get_running_loop will work + return MDNSStatus(mock_dashboard) + + +@pytest.mark.asyncio +async def test_get_cached_addresses_no_zeroconf(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses when no zeroconf instance is available.""" + mdns_status.aiozc = None + result = mdns_status.get_cached_addresses("device.local") + assert result is None + + +@pytest.mark.asyncio +async def test_get_cached_addresses_not_in_cache(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses when address is not in cache.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = False + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result is None + mock_info.load_from_cache.assert_called_once_with(mdns_status.aiozc.zeroconf) + + +@pytest.mark.asyncio +async def test_get_cached_addresses_found_in_cache(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses when address is found in cache.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10", "fe80::1"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result == ["192.168.1.10", "fe80::1"] + mock_info.load_from_cache.assert_called_once_with(mdns_status.aiozc.zeroconf) + mock_info.parsed_scoped_addresses.assert_called_once_with(IPVersion.All) + + +@pytest.mark.asyncio +async def test_get_cached_addresses_with_trailing_dot(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses with hostname having trailing dot.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local.") + assert result == ["192.168.1.10"] + # Should normalize to device.local. for zeroconf + mock_resolver.assert_called_once_with("device.local.") + + +@pytest.mark.asyncio +async def test_get_cached_addresses_uppercase_hostname(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses with uppercase hostname.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("DEVICE.LOCAL") + assert result == ["192.168.1.10"] + # Should normalize to device.local. for zeroconf + mock_resolver.assert_called_once_with("device.local.") + + +@pytest.mark.asyncio +async def test_get_cached_addresses_simple_hostname(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses with simple hostname (no domain).""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device") + assert result == ["192.168.1.10"] + # Should append .local. for zeroconf + mock_resolver.assert_called_once_with("device.local.") + + +@pytest.mark.asyncio +async def test_get_cached_addresses_ipv6_only(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses returning only IPv6 addresses.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["fe80::1", "2001:db8::1"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result == ["fe80::1", "2001:db8::1"] + + +@pytest.mark.asyncio +async def test_get_cached_addresses_empty_list(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses returning empty list from cache.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = [] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result == [] + + +@pytest.mark.asyncio +async def test_async_setup_success(mock_dashboard: Mock) -> None: + """Test successful async_setup.""" + mdns_status = MDNSStatus(mock_dashboard) + with patch("esphome.dashboard.status.mdns.AsyncEsphomeZeroconf") as mock_zc: + mock_zc.return_value = Mock() + result = mdns_status.async_setup() + assert result is True + assert mdns_status.aiozc is not None + + +@pytest.mark.asyncio +async def test_async_setup_failure(mock_dashboard: Mock) -> None: + """Test async_setup with OSError.""" + mdns_status = MDNSStatus(mock_dashboard) + with patch("esphome.dashboard.status.mdns.AsyncEsphomeZeroconf") as mock_zc: + mock_zc.side_effect = OSError("Network error") + result = mdns_status.async_setup() + assert result is False + assert mdns_status.aiozc is None diff --git a/tests/dashboard/test_web_server.py b/tests/dashboard/test_web_server.py index 1938617f20..605df4e02c 100644 --- a/tests/dashboard/test_web_server.py +++ b/tests/dashboard/test_web_server.py @@ -730,3 +730,83 @@ def test_start_web_server_with_unix_socket(tmp_path: Path) -> None: mock_server_class.assert_called_once_with(app) mock_bind.assert_called_once_with(str(socket_path), mode=0o666) server.add_socket.assert_called_once() + + +def test_build_cache_arguments_no_entry(mock_dashboard: Mock) -> None: + """Test with no entry returns empty list.""" + result = web_server.build_cache_arguments(None, mock_dashboard, 0.0) + assert result == [] + + +def test_build_cache_arguments_no_address_no_name(mock_dashboard: Mock) -> None: + """Test with entry but no address or name.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.address = None + entry.name = None + result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0) + assert result == [] + + +def test_build_cache_arguments_mdns_address_cached(mock_dashboard: Mock) -> None: + """Test with .local address that has cached mDNS results.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.address = "device.local" + entry.name = None + mock_dashboard.mdns_status = Mock() + mock_dashboard.mdns_status.get_cached_addresses.return_value = [ + "192.168.1.10", + "fe80::1", + ] + + result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0) + + assert result == [ + "--mdns-address-cache", + "device.local=192.168.1.10,fe80::1", + ] + mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with( + "device.local" + ) + + +def test_build_cache_arguments_dns_address_cached(mock_dashboard: Mock) -> None: + """Test with non-.local address that has cached DNS results.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.address = "example.com" + entry.name = None + mock_dashboard.dns_cache = Mock() + mock_dashboard.dns_cache.get_cached_addresses.return_value = [ + "93.184.216.34", + "2606:2800:220:1:248:1893:25c8:1946", + ] + + now = 100.0 + result = web_server.build_cache_arguments(entry, mock_dashboard, now) + + # IPv6 addresses are sorted before IPv4 + assert result == [ + "--dns-address-cache", + "example.com=2606:2800:220:1:248:1893:25c8:1946,93.184.216.34", + ] + mock_dashboard.dns_cache.get_cached_addresses.assert_called_once_with( + "example.com", now + ) + + +def test_build_cache_arguments_name_without_address(mock_dashboard: Mock) -> None: + """Test with name but no address - should check mDNS with .local suffix.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.name = "my-device" + entry.address = None + mock_dashboard.mdns_status = Mock() + mock_dashboard.mdns_status.get_cached_addresses.return_value = ["192.168.1.20"] + + result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0) + + assert result == [ + "--mdns-address-cache", + "my-device.local=192.168.1.20", + ] + mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with( + "my-device.local" + ) diff --git a/tests/unit_tests/test_address_cache.py b/tests/unit_tests/test_address_cache.py new file mode 100644 index 0000000000..de43830d53 --- /dev/null +++ b/tests/unit_tests/test_address_cache.py @@ -0,0 +1,305 @@ +"""Tests for the address_cache module.""" + +from __future__ import annotations + +import logging + +import pytest +from pytest import LogCaptureFixture + +from esphome.address_cache import AddressCache, normalize_hostname + + +def test_normalize_simple_hostname() -> None: + """Test normalizing a simple hostname.""" + assert normalize_hostname("device") == "device" + assert normalize_hostname("device.local") == "device.local" + assert normalize_hostname("server.example.com") == "server.example.com" + + +def test_normalize_removes_trailing_dots() -> None: + """Test that trailing dots are removed.""" + assert normalize_hostname("device.") == "device" + assert normalize_hostname("device.local.") == "device.local" + assert normalize_hostname("server.example.com.") == "server.example.com" + assert normalize_hostname("device...") == "device" + + +def test_normalize_converts_to_lowercase() -> None: + """Test that hostnames are converted to lowercase.""" + assert normalize_hostname("DEVICE") == "device" + assert normalize_hostname("Device.Local") == "device.local" + assert normalize_hostname("Server.Example.COM") == "server.example.com" + + +def test_normalize_combined() -> None: + """Test combination of trailing dots and case conversion.""" + assert normalize_hostname("DEVICE.LOCAL.") == "device.local" + assert normalize_hostname("Server.Example.COM...") == "server.example.com" + + +def test_init_empty() -> None: + """Test initialization with empty caches.""" + cache = AddressCache() + assert cache.mdns_cache == {} + assert cache.dns_cache == {} + assert not cache.has_cache() + + +def test_init_with_caches() -> None: + """Test initialization with provided caches.""" + mdns_cache: dict[str, list[str]] = {"device.local": ["192.168.1.10"]} + dns_cache: dict[str, list[str]] = {"server.com": ["10.0.0.1"]} + cache = AddressCache(mdns_cache=mdns_cache, dns_cache=dns_cache) + assert cache.mdns_cache == mdns_cache + assert cache.dns_cache == dns_cache + assert cache.has_cache() + + +def test_get_mdns_addresses() -> None: + """Test getting mDNS addresses.""" + cache = AddressCache(mdns_cache={"device.local": ["192.168.1.10", "192.168.1.11"]}) + + # Direct lookup + assert cache.get_mdns_addresses("device.local") == [ + "192.168.1.10", + "192.168.1.11", + ] + + # Case insensitive lookup + assert cache.get_mdns_addresses("Device.Local") == [ + "192.168.1.10", + "192.168.1.11", + ] + + # With trailing dot + assert cache.get_mdns_addresses("device.local.") == [ + "192.168.1.10", + "192.168.1.11", + ] + + # Not found + assert cache.get_mdns_addresses("unknown.local") is None + + +def test_get_dns_addresses() -> None: + """Test getting DNS addresses.""" + cache = AddressCache(dns_cache={"server.com": ["10.0.0.1", "10.0.0.2"]}) + + # Direct lookup + assert cache.get_dns_addresses("server.com") == ["10.0.0.1", "10.0.0.2"] + + # Case insensitive lookup + assert cache.get_dns_addresses("Server.COM") == ["10.0.0.1", "10.0.0.2"] + + # With trailing dot + assert cache.get_dns_addresses("server.com.") == ["10.0.0.1", "10.0.0.2"] + + # Not found + assert cache.get_dns_addresses("unknown.com") is None + + +def test_get_addresses_auto_detection() -> None: + """Test automatic cache selection based on hostname.""" + cache = AddressCache( + mdns_cache={"device.local": ["192.168.1.10"]}, + dns_cache={"server.com": ["10.0.0.1"]}, + ) + + # Should use mDNS cache for .local domains + assert cache.get_addresses("device.local") == ["192.168.1.10"] + assert cache.get_addresses("device.local.") == ["192.168.1.10"] + assert cache.get_addresses("Device.Local") == ["192.168.1.10"] + + # Should use DNS cache for non-.local domains + assert cache.get_addresses("server.com") == ["10.0.0.1"] + assert cache.get_addresses("server.com.") == ["10.0.0.1"] + assert cache.get_addresses("Server.COM") == ["10.0.0.1"] + + # Not found + assert cache.get_addresses("unknown.local") is None + assert cache.get_addresses("unknown.com") is None + + +def test_has_cache() -> None: + """Test checking if cache has entries.""" + # Empty cache + cache = AddressCache() + assert not cache.has_cache() + + # Only mDNS cache + cache = AddressCache(mdns_cache={"device.local": ["192.168.1.10"]}) + assert cache.has_cache() + + # Only DNS cache + cache = AddressCache(dns_cache={"server.com": ["10.0.0.1"]}) + assert cache.has_cache() + + # Both caches + cache = AddressCache( + mdns_cache={"device.local": ["192.168.1.10"]}, + dns_cache={"server.com": ["10.0.0.1"]}, + ) + assert cache.has_cache() + + +def test_from_cli_args_empty() -> None: + """Test creating cache from empty CLI arguments.""" + cache = AddressCache.from_cli_args([], []) + assert cache.mdns_cache == {} + assert cache.dns_cache == {} + + +def test_from_cli_args_single_entry() -> None: + """Test creating cache from single CLI argument.""" + mdns_args: list[str] = ["device.local=192.168.1.10"] + dns_args: list[str] = ["server.com=10.0.0.1"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["192.168.1.10"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1"]} + + +def test_from_cli_args_multiple_ips() -> None: + """Test creating cache with multiple IPs per host.""" + mdns_args: list[str] = ["device.local=192.168.1.10,192.168.1.11"] + dns_args: list[str] = ["server.com=10.0.0.1,10.0.0.2,10.0.0.3"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["192.168.1.10", "192.168.1.11"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]} + + +def test_from_cli_args_multiple_entries() -> None: + """Test creating cache with multiple host entries.""" + mdns_args: list[str] = [ + "device1.local=192.168.1.10", + "device2.local=192.168.1.20,192.168.1.21", + ] + dns_args: list[str] = ["server1.com=10.0.0.1", "server2.com=10.0.0.2"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == { + "device1.local": ["192.168.1.10"], + "device2.local": ["192.168.1.20", "192.168.1.21"], + } + assert cache.dns_cache == { + "server1.com": ["10.0.0.1"], + "server2.com": ["10.0.0.2"], + } + + +def test_from_cli_args_normalization() -> None: + """Test that CLI arguments are normalized.""" + mdns_args: list[str] = ["Device1.Local.=192.168.1.10", "DEVICE2.LOCAL=192.168.1.20"] + dns_args: list[str] = ["Server1.COM.=10.0.0.1", "SERVER2.com=10.0.0.2"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + # Hostnames should be normalized (lowercase, no trailing dots) + assert cache.mdns_cache == { + "device1.local": ["192.168.1.10"], + "device2.local": ["192.168.1.20"], + } + assert cache.dns_cache == { + "server1.com": ["10.0.0.1"], + "server2.com": ["10.0.0.2"], + } + + +def test_from_cli_args_whitespace_handling() -> None: + """Test that whitespace in IPs is handled.""" + mdns_args: list[str] = ["device.local= 192.168.1.10 , 192.168.1.11 "] + dns_args: list[str] = ["server.com= 10.0.0.1 , 10.0.0.2 "] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["192.168.1.10", "192.168.1.11"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1", "10.0.0.2"]} + + +def test_from_cli_args_invalid_format(caplog: LogCaptureFixture) -> None: + """Test handling of invalid argument format.""" + mdns_args: list[str] = ["invalid_format", "device.local=192.168.1.10"] + dns_args: list[str] = ["server.com=10.0.0.1", "also_invalid"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + # Valid entries should still be processed + assert cache.mdns_cache == {"device.local": ["192.168.1.10"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1"]} + + # Check that warnings were logged for invalid entries + assert "Invalid cache format: invalid_format" in caplog.text + assert "Invalid cache format: also_invalid" in caplog.text + + +def test_from_cli_args_ipv6() -> None: + """Test handling of IPv6 addresses.""" + mdns_args: list[str] = ["device.local=fe80::1,2001:db8::1"] + dns_args: list[str] = ["server.com=2001:db8::2,::1"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["fe80::1", "2001:db8::1"]} + assert cache.dns_cache == {"server.com": ["2001:db8::2", "::1"]} + + +def test_logging_output(caplog: LogCaptureFixture) -> None: + """Test that appropriate debug logging occurs.""" + caplog.set_level(logging.DEBUG) + + cache = AddressCache( + mdns_cache={"device.local": ["192.168.1.10"]}, + dns_cache={"server.com": ["10.0.0.1"]}, + ) + + # Test successful lookups log at debug level + result: list[str] | None = cache.get_mdns_addresses("device.local") + assert result == ["192.168.1.10"] + assert "Using mDNS cache for device.local" in caplog.text + + caplog.clear() + result = cache.get_dns_addresses("server.com") + assert result == ["10.0.0.1"] + assert "Using DNS cache for server.com" in caplog.text + + # Test that failed lookups don't log + caplog.clear() + result = cache.get_mdns_addresses("unknown.local") + assert result is None + assert "Using mDNS cache" not in caplog.text + + +@pytest.mark.parametrize( + "hostname,expected", + [ + ("test.local", "test.local"), + ("Test.Local.", "test.local"), + ("TEST.LOCAL...", "test.local"), + ("example.com", "example.com"), + ("EXAMPLE.COM.", "example.com"), + ], +) +def test_normalize_hostname_parametrized(hostname: str, expected: str) -> None: + """Test hostname normalization with various inputs.""" + assert normalize_hostname(hostname) == expected + + +@pytest.mark.parametrize( + "mdns_arg,expected", + [ + ("host=1.2.3.4", {"host": ["1.2.3.4"]}), + ("Host.Local=1.2.3.4,5.6.7.8", {"host.local": ["1.2.3.4", "5.6.7.8"]}), + ("HOST.LOCAL.=::1", {"host.local": ["::1"]}), + ], +) +def test_parse_cache_args_parametrized( + mdns_arg: str, expected: dict[str, list[str]] +) -> None: + """Test parsing of cache arguments with various formats.""" + cache = AddressCache.from_cli_args([mdns_arg], []) + assert cache.mdns_cache == expected diff --git a/tests/unit_tests/test_helpers.py b/tests/unit_tests/test_helpers.py index cc65d9747e..b49e5797c1 100644 --- a/tests/unit_tests/test_helpers.py +++ b/tests/unit_tests/test_helpers.py @@ -11,6 +11,7 @@ from hypothesis.strategies import ip_addresses import pytest from esphome import helpers +from esphome.address_cache import AddressCache from esphome.core import EsphomeError @@ -830,3 +831,84 @@ def test_resolve_ip_address_sorting() -> None: assert result[0][4][0] == "2001:db8::1" # IPv6 (preference 1) assert result[1][4][0] == "192.168.1.100" # IPv4 (preference 2) assert result[2][4][0] == "fe80::1" # Link-local no scope (preference 3) + + +def test_resolve_ip_address_with_cache() -> None: + """Test that the cache is used when provided.""" + cache = AddressCache( + mdns_cache={"test.local": ["192.168.1.100", "192.168.1.101"]}, + dns_cache={ + "example.com": ["93.184.216.34", "2606:2800:220:1:248:1893:25c8:1946"] + }, + ) + + # Test mDNS cache hit + result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache) + + # Should return cached addresses without calling resolver + assert len(result) == 2 + assert result[0][4][0] == "192.168.1.100" + assert result[1][4][0] == "192.168.1.101" + + # Test DNS cache hit + result = helpers.resolve_ip_address("example.com", 6053, address_cache=cache) + + # Should return cached addresses with IPv6 first due to preference + assert len(result) == 2 + assert result[0][4][0] == "2606:2800:220:1:248:1893:25c8:1946" # IPv6 first + assert result[1][4][0] == "93.184.216.34" # IPv4 second + + +def test_resolve_ip_address_cache_miss() -> None: + """Test that resolver is called when not in cache.""" + cache = AddressCache(mdns_cache={"other.local": ["192.168.1.200"]}) + + mock_addr_info = AddrInfo( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053), + ) + + with patch("esphome.resolver.AsyncResolver") as MockResolver: + mock_resolver = MockResolver.return_value + mock_resolver.resolve.return_value = [mock_addr_info] + + result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache) + + # Should call resolver since test.local is not in cache + MockResolver.assert_called_once_with(["test.local"], 6053) + assert len(result) == 1 + assert result[0][4][0] == "192.168.1.100" + + +def test_resolve_ip_address_mixed_cached_uncached() -> None: + """Test resolution with mix of cached and uncached hosts.""" + cache = AddressCache(mdns_cache={"cached.local": ["192.168.1.50"]}) + + mock_addr_info = AddrInfo( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053), + ) + + with patch("esphome.resolver.AsyncResolver") as MockResolver: + mock_resolver = MockResolver.return_value + mock_resolver.resolve.return_value = [mock_addr_info] + + # Pass a list with cached IP, cached hostname, and uncached hostname + result = helpers.resolve_ip_address( + ["192.168.1.10", "cached.local", "uncached.local"], + 6053, + address_cache=cache, + ) + + # Should only resolve uncached.local + MockResolver.assert_called_once_with(["uncached.local"], 6053) + + # Results should include all addresses + addresses = [r[4][0] for r in result] + assert "192.168.1.10" in addresses # Direct IP + assert "192.168.1.50" in addresses # From cache + assert "192.168.1.100" in addresses # From resolver