From 4d3405340d4abfd4e3e0341c668aef8a3278d332 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 11 Sep 2025 18:21:01 -0500 Subject: [PATCH] Fix dashboard dns lookup delay --- esphome/__main__.py | 22 ++++++ esphome/address_cache.py | 131 +++++++++++++++++++++++++++++++ esphome/core/__init__.py | 3 + esphome/dashboard/dns.py | 11 +++ esphome/dashboard/status/mdns.py | 16 ++++ esphome/dashboard/web_server.py | 70 ++++++++++------- esphome/espota2.py | 6 +- esphome/helpers.py | 88 ++++++++++++++------- tests/unit_tests/test_helpers.py | 87 ++++++++++++++++++++ tests/unit_tests/test_main.py | 71 +++++++++++++++++ 10 files changed, 448 insertions(+), 57 deletions(-) create mode 100644 esphome/address_cache.py diff --git a/esphome/__main__.py b/esphome/__main__.py index bba254436e..15c29e6cdf 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -889,6 +889,18 @@ def parse_args(argv): help="Add a substitution", metavar=("key", "value"), ) + options_parser.add_argument( + "--mdns-lookup-cache", + help="mDNS lookup cache mapping in format 'hostname=ip1,ip2'", + action="append", + default=[], + ) + options_parser.add_argument( + "--dns-lookup-cache", + help="DNS lookup cache mapping in format 'hostname=ip1,ip2'", + action="append", + default=[], + ) parser = argparse.ArgumentParser( description=f"ESPHome {const.__version__}", parents=[options_parser] @@ -1136,9 +1148,19 @@ def parse_args(argv): def run_esphome(argv): + from esphome.address_cache import AddressCache + args = parse_args(argv) CORE.dashboard = args.dashboard + # Create address cache from command-line arguments + address_cache = AddressCache.from_cli_args( + args.mdns_lookup_cache, args.dns_lookup_cache + ) + + # Store cache in CORE for access throughout the application + CORE.address_cache = address_cache + # Override log level if verbose is set if args.verbose: args.log_level = "DEBUG" diff --git a/esphome/address_cache.py b/esphome/address_cache.py new file mode 100644 index 0000000000..6e5881716d --- /dev/null +++ b/esphome/address_cache.py @@ -0,0 +1,131 @@ +"""Address cache for DNS and mDNS lookups.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + +_LOGGER = logging.getLogger(__name__) + + +def normalize_hostname(hostname: str) -> str: + """Normalize hostname for cache lookups. + + Removes trailing dots and converts to lowercase. + """ + return hostname.rstrip(".").lower() + + +class AddressCache: + """Cache for DNS and mDNS address lookups. + + This cache stores pre-resolved addresses from command-line arguments + to avoid slow DNS/mDNS lookups during builds. + """ + + def __init__( + self, + mdns_cache: dict[str, list[str]] | None = None, + dns_cache: dict[str, list[str]] | None = None, + ) -> None: + """Initialize the address cache. + + Args: + mdns_cache: Pre-populated mDNS addresses (hostname -> IPs) + dns_cache: Pre-populated DNS addresses (hostname -> IPs) + """ + self.mdns_cache = mdns_cache or {} + self.dns_cache = dns_cache or {} + + def get_mdns_addresses(self, hostname: str) -> list[str] | None: + """Get cached mDNS addresses for a hostname. + + Args: + hostname: The hostname to look up (should end with .local) + + Returns: + List of IP addresses if found in cache, None otherwise + """ + normalized = normalize_hostname(hostname) + if addresses := self.mdns_cache.get(normalized): + _LOGGER.debug("Using mDNS cache for %s: %s", hostname, addresses) + return addresses + return None + + def get_dns_addresses(self, hostname: str) -> list[str] | None: + """Get cached DNS addresses for a hostname. + + Args: + hostname: The hostname to look up + + Returns: + List of IP addresses if found in cache, None otherwise + """ + normalized = normalize_hostname(hostname) + if addresses := self.dns_cache.get(normalized): + _LOGGER.debug("Using DNS cache for %s: %s", hostname, addresses) + return addresses + return None + + def get_addresses(self, hostname: str) -> list[str] | None: + """Get cached addresses for a hostname. + + Checks mDNS cache for .local domains, DNS cache otherwise. + + Args: + hostname: The hostname to look up + + Returns: + List of IP addresses if found in cache, None otherwise + """ + normalized = normalize_hostname(hostname) + if normalized.endswith(".local"): + return self.get_mdns_addresses(hostname) + return self.get_dns_addresses(hostname) + + def has_cache(self) -> bool: + """Check if any cache entries exist.""" + return bool(self.mdns_cache or self.dns_cache) + + @classmethod + def from_cli_args( + cls, mdns_args: Iterable[str], dns_args: Iterable[str] + ) -> AddressCache: + """Create cache from command-line arguments. + + Args: + mdns_args: List of mDNS cache entries like ['host=ip1,ip2'] + dns_args: List of DNS cache entries like ['host=ip1,ip2'] + + Returns: + Configured AddressCache instance + """ + mdns_cache = cls._parse_cache_args(mdns_args) + dns_cache = cls._parse_cache_args(dns_args) + return cls(mdns_cache=mdns_cache, dns_cache=dns_cache) + + @staticmethod + def _parse_cache_args(cache_args: Iterable[str]) -> dict[str, list[str]]: + """Parse cache arguments into a dictionary. + + Args: + cache_args: List of cache mappings like ['host1=ip1,ip2', 'host2=ip3'] + + Returns: + Dictionary mapping normalized hostnames to list of IP addresses + """ + cache: dict[str, list[str]] = {} + for arg in cache_args: + if "=" not in arg: + _LOGGER.warning( + "Invalid cache format: %s (expected 'hostname=ip1,ip2')", arg + ) + continue + hostname, ips = arg.split("=", 1) + # Normalize hostname for consistent lookups + normalized = normalize_hostname(hostname) + cache[normalized] = [ip.strip() for ip in ips.split(",")] + return cache diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 89e3eff7d8..0d4ddf56d4 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -583,6 +583,8 @@ class EsphomeCore: self.id_classes = {} # The current component being processed during validation self.current_component: str | None = None + # Address cache for DNS and mDNS lookups from command line arguments + self.address_cache: object | None = None def reset(self): from esphome.pins import PIN_SCHEMA_REGISTRY @@ -610,6 +612,7 @@ class EsphomeCore: self.platform_counts = defaultdict(int) self.unique_ids = {} self.current_component = None + self.address_cache = None PIN_SCHEMA_REGISTRY.reset() @contextmanager diff --git a/esphome/dashboard/dns.py b/esphome/dashboard/dns.py index 98134062f4..4f1ef71dd0 100644 --- a/esphome/dashboard/dns.py +++ b/esphome/dashboard/dns.py @@ -28,6 +28,17 @@ class DNSCache: self._cache: dict[str, tuple[float, list[str] | Exception]] = {} self._ttl = ttl + def get_cached(self, hostname: str, now_monotonic: float) -> list[str] | None: + """Get cached address without triggering resolution. + + Returns None if not in cache, list of addresses if found. + """ + if expire_time_addresses := self._cache.get(hostname): + expire_time, addresses = expire_time_addresses + if expire_time > now_monotonic and not isinstance(addresses, Exception): + return addresses + return None + async def async_resolve( self, hostname: str, now_monotonic: float ) -> list[str] | Exception: diff --git a/esphome/dashboard/status/mdns.py b/esphome/dashboard/status/mdns.py index f9ac7b4289..0977a89c3a 100644 --- a/esphome/dashboard/status/mdns.py +++ b/esphome/dashboard/status/mdns.py @@ -50,6 +50,22 @@ class MDNSStatus: return await aiozc.async_resolve_host(host_name) return None + def get_cached_addresses(self, host_name: str) -> list[str] | None: + """Get cached addresses for a host without triggering resolution. + + Returns None if not in cache or no zeroconf available. + """ + if not self.aiozc: + return None + + from zeroconf import AddressResolver, IPVersion + + # Try to load from zeroconf cache without triggering resolution + info = AddressResolver(f"{host_name.partition('.')[0]}.local.") + if info.load_from_cache(self.aiozc.zeroconf): + return info.parsed_scoped_addresses(IPVersion.All) + return None + async def async_refresh_hosts(self) -> None: """Refresh the hosts to track.""" dashboard = self.dashboard diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index 294a180794..767144fd19 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -326,52 +326,64 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): configuration = json_message["configuration"] config_file = settings.rel_path(configuration) port = json_message["port"] + + # Only get cached addresses - no async resolution addresses: list[str] = [] + cache_args: list[str] = [] + if ( port == "OTA" # pylint: disable=too-many-boolean-expressions and (entry := entries.get(config_file)) and entry.loaded_integrations and "api" in entry.loaded_integrations ): - # First priority: entry.address AKA use_address - if ( - (use_address := entry.address) - and ( - address_list := await dashboard.dns_cache.async_resolve( - use_address, time.monotonic() - ) - ) - and not isinstance(address_list, Exception) - ): - addresses.extend(sort_ip_addresses(address_list)) + now = time.monotonic() - # Second priority: mDNS - if ( - (mdns := dashboard.mdns_status) - and (address_list := await mdns.async_resolve_host(entry.name)) - and ( - new_addresses := [ - addr for addr in address_list if addr not in addresses - ] - ) + # Collect all cached addresses for this device + dns_cache_entries: dict[str, set[str]] = {} + mdns_cache_entries: dict[str, set[str]] = {} + + # First priority: entry.address AKA use_address (from DNS cache only) + if (use_address := entry.address) and ( + cached := dashboard.dns_cache.get_cached(use_address, now) ): - # Use the IP address if available but only - # if the API is loaded and the device is online - # since MQTT logging will not work otherwise - addresses.extend(sort_ip_addresses(new_addresses)) + addresses.extend(sort_ip_addresses(cached)) + dns_cache_entries[use_address] = set(cached) + + # Second priority: mDNS cache for device name + if entry.name and not addresses: # Only if we don't have addresses yet + if entry.name.endswith(".local"): + # Check mDNS cache (zeroconf) + if (mdns := dashboard.mdns_status) and ( + cached := mdns.get_cached_addresses(entry.name) + ): + addresses.extend(sort_ip_addresses(cached)) + mdns_cache_entries[entry.name] = set(cached) + # Check DNS cache for non-.local names + elif cached := dashboard.dns_cache.get_cached(entry.name, now): + addresses.extend(sort_ip_addresses(cached)) + dns_cache_entries[entry.name] = set(cached) + + # Build cache arguments to pass to CLI + for hostname, addrs in dns_cache_entries.items(): + cache_args.extend( + ["--dns-lookup-cache", f"{hostname}={','.join(sorted(addrs))}"] + ) + for hostname, addrs in mdns_cache_entries.items(): + cache_args.extend( + ["--mdns-lookup-cache", f"{hostname}={','.join(sorted(addrs))}"] + ) if not addresses: - # If no address was found, use the port directly - # as otherwise they will get the chooser which - # does not work with the dashboard as there is no - # interactive way to get keyboard input + # If no cached address was found, use the port directly + # The CLI will do the resolution with the cache hints we provide addresses = [port] device_args: list[str] = [ arg for address in addresses for arg in ("--device", address) ] - return [*DASHBOARD_COMMAND, *args, config_file, *device_args] + return [*DASHBOARD_COMMAND, *args, config_file, *device_args, *cache_args] class EsphomeLogsHandler(EsphomePortCommandWebSocket): diff --git a/esphome/espota2.py b/esphome/espota2.py index 3d25af985b..f808d558d7 100644 --- a/esphome/espota2.py +++ b/esphome/espota2.py @@ -311,10 +311,14 @@ def perform_ota( def run_ota_impl_( remote_host: str | list[str], remote_port: int, password: str, filename: str ) -> tuple[int, str | None]: + from esphome.core import CORE + # Handle both single host and list of hosts try: # Resolve all hosts at once for parallel DNS resolution - res = resolve_ip_address(remote_host, remote_port) + res = resolve_ip_address( + remote_host, remote_port, address_cache=getattr(CORE, "address_cache", None) + ) except EsphomeError as err: _LOGGER.error( "Error resolving IP address of %s. Is it connected to WiFi?", diff --git a/esphome/helpers.py b/esphome/helpers.py index 6beaa24a96..f4b321b26f 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -173,7 +173,9 @@ def addr_preference_(res: AddrInfo) -> int: return 1 -def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: +def resolve_ip_address( + host: str | list[str], port: int, address_cache: object | None = None +) -> list[AddrInfo]: import socket # There are five cases here. The host argument could be one of: @@ -194,8 +196,9 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: hosts = [host] res: list[AddrInfo] = [] + + # Fast path: if all hosts are already IP addresses if all(is_ip_address(h) for h in hosts): - # Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST for addr in hosts: try: res += socket.getaddrinfo( @@ -207,34 +210,65 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: res.sort(key=addr_preference_) return res - from esphome.resolver import AsyncResolver + # Check if we have cached addresses for these hosts + cached_hosts: list[str] = [] + uncached_hosts: list[str] = [] - resolver = AsyncResolver(hosts, port) - addr_infos = resolver.resolve() - # Convert aioesphomeapi AddrInfo to our format - for addr_info in addr_infos: - sockaddr = addr_info.sockaddr - if addr_info.family == socket.AF_INET6: - # IPv6 - sockaddr_tuple = ( - sockaddr.address, - sockaddr.port, - sockaddr.flowinfo, - sockaddr.scope_id, - ) - else: - # IPv4 - sockaddr_tuple = (sockaddr.address, sockaddr.port) + for h in hosts: + # Check if it's already an IP address + if is_ip_address(h): + cached_hosts.append(h) + continue - res.append( - ( - addr_info.family, - addr_info.type, - addr_info.proto, - "", # canonname - sockaddr_tuple, + # Check cache if provided + if address_cache and (cached_addresses := address_cache.get_addresses(h)): + cached_hosts.extend(cached_addresses) + continue + + # Not in cache, need to resolve + if address_cache and address_cache.has_cache(): + _LOGGER.info("Host %s not in cache, will need to resolve", h) + uncached_hosts.append(h) + + # Process cached addresses (all should be IP addresses) + for addr in cached_hosts: + try: + res += socket.getaddrinfo( + addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST + ) + except OSError: + _LOGGER.debug("Failed to parse IP address '%s'", addr) + + # If we have uncached hosts, resolve them + if uncached_hosts: + from esphome.resolver import AsyncResolver + + resolver = AsyncResolver(uncached_hosts, port) + addr_infos = resolver.resolve() + # Convert aioesphomeapi AddrInfo to our format + for addr_info in addr_infos: + sockaddr = addr_info.sockaddr + if addr_info.family == socket.AF_INET6: + # IPv6 + sockaddr_tuple = ( + sockaddr.address, + sockaddr.port, + sockaddr.flowinfo, + sockaddr.scope_id, + ) + else: + # IPv4 + sockaddr_tuple = (sockaddr.address, sockaddr.port) + + res.append( + ( + addr_info.family, + addr_info.type, + addr_info.proto, + "", # canonname + sockaddr_tuple, + ) ) - ) # Sort by preference res.sort(key=addr_preference_) diff --git a/tests/unit_tests/test_helpers.py b/tests/unit_tests/test_helpers.py index 9f51206ff9..631d6a878e 100644 --- a/tests/unit_tests/test_helpers.py +++ b/tests/unit_tests/test_helpers.py @@ -594,3 +594,90 @@ def test_resolve_ip_address_sorting() -> None: assert result[0][4][0] == "2001:db8::1" # IPv6 (preference 1) assert result[1][4][0] == "192.168.1.100" # IPv4 (preference 2) assert result[2][4][0] == "fe80::1" # Link-local no scope (preference 3) + + +def test_resolve_ip_address_with_cache() -> None: + """Test that the cache is used when provided.""" + from esphome.address_cache import AddressCache + + cache = AddressCache( + mdns_cache={"test.local": ["192.168.1.100", "192.168.1.101"]}, + dns_cache={ + "example.com": ["93.184.216.34", "2606:2800:220:1:248:1893:25c8:1946"] + }, + ) + + # Test mDNS cache hit + result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache) + + # Should return cached addresses without calling resolver + assert len(result) == 2 + assert result[0][4][0] == "192.168.1.100" + assert result[1][4][0] == "192.168.1.101" + + # Test DNS cache hit + result = helpers.resolve_ip_address("example.com", 6053, address_cache=cache) + + # Should return cached addresses with IPv6 first due to preference + assert len(result) == 2 + assert result[0][4][0] == "2606:2800:220:1:248:1893:25c8:1946" # IPv6 first + assert result[1][4][0] == "93.184.216.34" # IPv4 second + + +def test_resolve_ip_address_cache_miss() -> None: + """Test that resolver is called when not in cache.""" + from esphome.address_cache import AddressCache + + cache = AddressCache(mdns_cache={"other.local": ["192.168.1.200"]}) + + mock_addr_info = AddrInfo( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053), + ) + + with patch("esphome.resolver.AsyncResolver") as MockResolver: + mock_resolver = MockResolver.return_value + mock_resolver.resolve.return_value = [mock_addr_info] + + result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache) + + # Should call resolver since test.local is not in cache + MockResolver.assert_called_once_with(["test.local"], 6053) + assert len(result) == 1 + assert result[0][4][0] == "192.168.1.100" + + +def test_resolve_ip_address_mixed_cached_uncached() -> None: + """Test resolution with mix of cached and uncached hosts.""" + from esphome.address_cache import AddressCache + + cache = AddressCache(mdns_cache={"cached.local": ["192.168.1.50"]}) + + mock_addr_info = AddrInfo( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053), + ) + + with patch("esphome.resolver.AsyncResolver") as MockResolver: + mock_resolver = MockResolver.return_value + mock_resolver.resolve.return_value = [mock_addr_info] + + # Pass a list with cached IP, cached hostname, and uncached hostname + result = helpers.resolve_ip_address( + ["192.168.1.10", "cached.local", "uncached.local"], + 6053, + address_cache=cache, + ) + + # Should only resolve uncached.local + MockResolver.assert_called_once_with(["uncached.local"], 6053) + + # Results should include all addresses + addresses = [r[4][0] for r in result] + assert "192.168.1.10" in addresses # Direct IP + assert "192.168.1.50" in addresses # From cache + assert "192.168.1.100" in addresses # From resolver diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index 2c7236c7f8..ce19f18a1f 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -10,6 +10,7 @@ from unittest.mock import Mock, patch import pytest from esphome.__main__ import choose_upload_log_host +from esphome.address_cache import AddressCache from esphome.const import CONF_BROKER, CONF_MQTT, CONF_USE_ADDRESS, CONF_WIFI from esphome.core import CORE @@ -510,3 +511,73 @@ def test_choose_upload_log_host_no_address_with_ota_config() -> None: show_api=False, ) assert result == [] + + +def test_address_cache_from_cli_args() -> None: + """Test parsing address cache from CLI arguments.""" + # Test empty lists + cache = AddressCache.from_cli_args([], []) + assert cache.mdns_cache == {} + assert cache.dns_cache == {} + + # Test single entry with single IP + cache = AddressCache.from_cli_args( + ["host.local=192.168.1.1"], ["example.com=10.0.0.1"] + ) + assert cache.mdns_cache == {"host.local": ["192.168.1.1"]} + assert cache.dns_cache == {"example.com": ["10.0.0.1"]} + + # Test multiple IPs + cache = AddressCache.from_cli_args(["host.local=192.168.1.1,192.168.1.2"], []) + assert cache.mdns_cache == {"host.local": ["192.168.1.1", "192.168.1.2"]} + + # Test multiple entries + cache = AddressCache.from_cli_args( + ["host1.local=192.168.1.1", "host2.local=192.168.1.2"], + ["example.com=10.0.0.1", "test.org=10.0.0.2,10.0.0.3"], + ) + assert cache.mdns_cache == { + "host1.local": ["192.168.1.1"], + "host2.local": ["192.168.1.2"], + } + assert cache.dns_cache == { + "example.com": ["10.0.0.1"], + "test.org": ["10.0.0.2", "10.0.0.3"], + } + + # Test with IPv6 + cache = AddressCache.from_cli_args(["host.local=2001:db8::1,fe80::1"], []) + assert cache.mdns_cache == {"host.local": ["2001:db8::1", "fe80::1"]} + + # Test invalid format (should be skipped with warning) + with patch("esphome.address_cache._LOGGER") as mock_logger: + cache = AddressCache.from_cli_args(["invalid_format"], []) + assert cache.mdns_cache == {} + mock_logger.warning.assert_called_once() + + +def test_address_cache_get_methods() -> None: + """Test the AddressCache get methods.""" + cache = AddressCache( + mdns_cache={"test.local": ["192.168.1.1"]}, + dns_cache={"example.com": ["10.0.0.1"]}, + ) + + # Test mDNS lookup + assert cache.get_mdns_addresses("test.local") == ["192.168.1.1"] + assert cache.get_mdns_addresses("other.local") is None + + # Test DNS lookup + assert cache.get_dns_addresses("example.com") == ["10.0.0.1"] + assert cache.get_dns_addresses("other.com") is None + + # Test automatic selection based on domain + assert cache.get_addresses("test.local") == ["192.168.1.1"] + assert cache.get_addresses("example.com") == ["10.0.0.1"] + assert cache.get_addresses("unknown.local") is None + assert cache.get_addresses("unknown.com") is None + + # Test has_cache + assert cache.has_cache() is True + empty_cache = AddressCache() + assert empty_cache.has_cache() is False