1
0
mirror of https://github.com/esphome/esphome.git synced 2025-10-30 06:33:51 +00:00

Fix dashboard dns lookup delay

This commit is contained in:
J. Nick Koston
2025-09-11 18:21:01 -05:00
parent bbef0e173e
commit 4d3405340d
10 changed files with 448 additions and 57 deletions

View File

@@ -889,6 +889,18 @@ def parse_args(argv):
help="Add a substitution",
metavar=("key", "value"),
)
options_parser.add_argument(
"--mdns-lookup-cache",
help="mDNS lookup cache mapping in format 'hostname=ip1,ip2'",
action="append",
default=[],
)
options_parser.add_argument(
"--dns-lookup-cache",
help="DNS lookup cache mapping in format 'hostname=ip1,ip2'",
action="append",
default=[],
)
parser = argparse.ArgumentParser(
description=f"ESPHome {const.__version__}", parents=[options_parser]
@@ -1136,9 +1148,19 @@ def parse_args(argv):
def run_esphome(argv):
from esphome.address_cache import AddressCache
args = parse_args(argv)
CORE.dashboard = args.dashboard
# Create address cache from command-line arguments
address_cache = AddressCache.from_cli_args(
args.mdns_lookup_cache, args.dns_lookup_cache
)
# Store cache in CORE for access throughout the application
CORE.address_cache = address_cache
# Override log level if verbose is set
if args.verbose:
args.log_level = "DEBUG"

131
esphome/address_cache.py Normal file
View File

@@ -0,0 +1,131 @@
"""Address cache for DNS and mDNS lookups."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterable
_LOGGER = logging.getLogger(__name__)
def normalize_hostname(hostname: str) -> str:
"""Normalize hostname for cache lookups.
Removes trailing dots and converts to lowercase.
"""
return hostname.rstrip(".").lower()
class AddressCache:
"""Cache for DNS and mDNS address lookups.
This cache stores pre-resolved addresses from command-line arguments
to avoid slow DNS/mDNS lookups during builds.
"""
def __init__(
self,
mdns_cache: dict[str, list[str]] | None = None,
dns_cache: dict[str, list[str]] | None = None,
) -> None:
"""Initialize the address cache.
Args:
mdns_cache: Pre-populated mDNS addresses (hostname -> IPs)
dns_cache: Pre-populated DNS addresses (hostname -> IPs)
"""
self.mdns_cache = mdns_cache or {}
self.dns_cache = dns_cache or {}
def get_mdns_addresses(self, hostname: str) -> list[str] | None:
"""Get cached mDNS addresses for a hostname.
Args:
hostname: The hostname to look up (should end with .local)
Returns:
List of IP addresses if found in cache, None otherwise
"""
normalized = normalize_hostname(hostname)
if addresses := self.mdns_cache.get(normalized):
_LOGGER.debug("Using mDNS cache for %s: %s", hostname, addresses)
return addresses
return None
def get_dns_addresses(self, hostname: str) -> list[str] | None:
"""Get cached DNS addresses for a hostname.
Args:
hostname: The hostname to look up
Returns:
List of IP addresses if found in cache, None otherwise
"""
normalized = normalize_hostname(hostname)
if addresses := self.dns_cache.get(normalized):
_LOGGER.debug("Using DNS cache for %s: %s", hostname, addresses)
return addresses
return None
def get_addresses(self, hostname: str) -> list[str] | None:
"""Get cached addresses for a hostname.
Checks mDNS cache for .local domains, DNS cache otherwise.
Args:
hostname: The hostname to look up
Returns:
List of IP addresses if found in cache, None otherwise
"""
normalized = normalize_hostname(hostname)
if normalized.endswith(".local"):
return self.get_mdns_addresses(hostname)
return self.get_dns_addresses(hostname)
def has_cache(self) -> bool:
"""Check if any cache entries exist."""
return bool(self.mdns_cache or self.dns_cache)
@classmethod
def from_cli_args(
cls, mdns_args: Iterable[str], dns_args: Iterable[str]
) -> AddressCache:
"""Create cache from command-line arguments.
Args:
mdns_args: List of mDNS cache entries like ['host=ip1,ip2']
dns_args: List of DNS cache entries like ['host=ip1,ip2']
Returns:
Configured AddressCache instance
"""
mdns_cache = cls._parse_cache_args(mdns_args)
dns_cache = cls._parse_cache_args(dns_args)
return cls(mdns_cache=mdns_cache, dns_cache=dns_cache)
@staticmethod
def _parse_cache_args(cache_args: Iterable[str]) -> dict[str, list[str]]:
"""Parse cache arguments into a dictionary.
Args:
cache_args: List of cache mappings like ['host1=ip1,ip2', 'host2=ip3']
Returns:
Dictionary mapping normalized hostnames to list of IP addresses
"""
cache: dict[str, list[str]] = {}
for arg in cache_args:
if "=" not in arg:
_LOGGER.warning(
"Invalid cache format: %s (expected 'hostname=ip1,ip2')", arg
)
continue
hostname, ips = arg.split("=", 1)
# Normalize hostname for consistent lookups
normalized = normalize_hostname(hostname)
cache[normalized] = [ip.strip() for ip in ips.split(",")]
return cache

View File

@@ -583,6 +583,8 @@ class EsphomeCore:
self.id_classes = {}
# The current component being processed during validation
self.current_component: str | None = None
# Address cache for DNS and mDNS lookups from command line arguments
self.address_cache: object | None = None
def reset(self):
from esphome.pins import PIN_SCHEMA_REGISTRY
@@ -610,6 +612,7 @@ class EsphomeCore:
self.platform_counts = defaultdict(int)
self.unique_ids = {}
self.current_component = None
self.address_cache = None
PIN_SCHEMA_REGISTRY.reset()
@contextmanager

View File

@@ -28,6 +28,17 @@ class DNSCache:
self._cache: dict[str, tuple[float, list[str] | Exception]] = {}
self._ttl = ttl
def get_cached(self, hostname: str, now_monotonic: float) -> list[str] | None:
"""Get cached address without triggering resolution.
Returns None if not in cache, list of addresses if found.
"""
if expire_time_addresses := self._cache.get(hostname):
expire_time, addresses = expire_time_addresses
if expire_time > now_monotonic and not isinstance(addresses, Exception):
return addresses
return None
async def async_resolve(
self, hostname: str, now_monotonic: float
) -> list[str] | Exception:

View File

@@ -50,6 +50,22 @@ class MDNSStatus:
return await aiozc.async_resolve_host(host_name)
return None
def get_cached_addresses(self, host_name: str) -> list[str] | None:
"""Get cached addresses for a host without triggering resolution.
Returns None if not in cache or no zeroconf available.
"""
if not self.aiozc:
return None
from zeroconf import AddressResolver, IPVersion
# Try to load from zeroconf cache without triggering resolution
info = AddressResolver(f"{host_name.partition('.')[0]}.local.")
if info.load_from_cache(self.aiozc.zeroconf):
return info.parsed_scoped_addresses(IPVersion.All)
return None
async def async_refresh_hosts(self) -> None:
"""Refresh the hosts to track."""
dashboard = self.dashboard

View File

@@ -326,52 +326,64 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
configuration = json_message["configuration"]
config_file = settings.rel_path(configuration)
port = json_message["port"]
# Only get cached addresses - no async resolution
addresses: list[str] = []
cache_args: list[str] = []
if (
port == "OTA" # pylint: disable=too-many-boolean-expressions
and (entry := entries.get(config_file))
and entry.loaded_integrations
and "api" in entry.loaded_integrations
):
# First priority: entry.address AKA use_address
if (
(use_address := entry.address)
and (
address_list := await dashboard.dns_cache.async_resolve(
use_address, time.monotonic()
)
)
and not isinstance(address_list, Exception)
):
addresses.extend(sort_ip_addresses(address_list))
now = time.monotonic()
# Second priority: mDNS
if (
(mdns := dashboard.mdns_status)
and (address_list := await mdns.async_resolve_host(entry.name))
and (
new_addresses := [
addr for addr in address_list if addr not in addresses
]
)
# Collect all cached addresses for this device
dns_cache_entries: dict[str, set[str]] = {}
mdns_cache_entries: dict[str, set[str]] = {}
# First priority: entry.address AKA use_address (from DNS cache only)
if (use_address := entry.address) and (
cached := dashboard.dns_cache.get_cached(use_address, now)
):
# Use the IP address if available but only
# if the API is loaded and the device is online
# since MQTT logging will not work otherwise
addresses.extend(sort_ip_addresses(new_addresses))
addresses.extend(sort_ip_addresses(cached))
dns_cache_entries[use_address] = set(cached)
# Second priority: mDNS cache for device name
if entry.name and not addresses: # Only if we don't have addresses yet
if entry.name.endswith(".local"):
# Check mDNS cache (zeroconf)
if (mdns := dashboard.mdns_status) and (
cached := mdns.get_cached_addresses(entry.name)
):
addresses.extend(sort_ip_addresses(cached))
mdns_cache_entries[entry.name] = set(cached)
# Check DNS cache for non-.local names
elif cached := dashboard.dns_cache.get_cached(entry.name, now):
addresses.extend(sort_ip_addresses(cached))
dns_cache_entries[entry.name] = set(cached)
# Build cache arguments to pass to CLI
for hostname, addrs in dns_cache_entries.items():
cache_args.extend(
["--dns-lookup-cache", f"{hostname}={','.join(sorted(addrs))}"]
)
for hostname, addrs in mdns_cache_entries.items():
cache_args.extend(
["--mdns-lookup-cache", f"{hostname}={','.join(sorted(addrs))}"]
)
if not addresses:
# If no address was found, use the port directly
# as otherwise they will get the chooser which
# does not work with the dashboard as there is no
# interactive way to get keyboard input
# If no cached address was found, use the port directly
# The CLI will do the resolution with the cache hints we provide
addresses = [port]
device_args: list[str] = [
arg for address in addresses for arg in ("--device", address)
]
return [*DASHBOARD_COMMAND, *args, config_file, *device_args]
return [*DASHBOARD_COMMAND, *args, config_file, *device_args, *cache_args]
class EsphomeLogsHandler(EsphomePortCommandWebSocket):

View File

@@ -311,10 +311,14 @@ def perform_ota(
def run_ota_impl_(
remote_host: str | list[str], remote_port: int, password: str, filename: str
) -> tuple[int, str | None]:
from esphome.core import CORE
# Handle both single host and list of hosts
try:
# Resolve all hosts at once for parallel DNS resolution
res = resolve_ip_address(remote_host, remote_port)
res = resolve_ip_address(
remote_host, remote_port, address_cache=getattr(CORE, "address_cache", None)
)
except EsphomeError as err:
_LOGGER.error(
"Error resolving IP address of %s. Is it connected to WiFi?",

View File

@@ -173,7 +173,9 @@ def addr_preference_(res: AddrInfo) -> int:
return 1
def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
def resolve_ip_address(
host: str | list[str], port: int, address_cache: object | None = None
) -> list[AddrInfo]:
import socket
# There are five cases here. The host argument could be one of:
@@ -194,8 +196,9 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
hosts = [host]
res: list[AddrInfo] = []
# Fast path: if all hosts are already IP addresses
if all(is_ip_address(h) for h in hosts):
# Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST
for addr in hosts:
try:
res += socket.getaddrinfo(
@@ -207,34 +210,65 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
res.sort(key=addr_preference_)
return res
from esphome.resolver import AsyncResolver
# Check if we have cached addresses for these hosts
cached_hosts: list[str] = []
uncached_hosts: list[str] = []
resolver = AsyncResolver(hosts, port)
addr_infos = resolver.resolve()
# Convert aioesphomeapi AddrInfo to our format
for addr_info in addr_infos:
sockaddr = addr_info.sockaddr
if addr_info.family == socket.AF_INET6:
# IPv6
sockaddr_tuple = (
sockaddr.address,
sockaddr.port,
sockaddr.flowinfo,
sockaddr.scope_id,
)
else:
# IPv4
sockaddr_tuple = (sockaddr.address, sockaddr.port)
for h in hosts:
# Check if it's already an IP address
if is_ip_address(h):
cached_hosts.append(h)
continue
res.append(
(
addr_info.family,
addr_info.type,
addr_info.proto,
"", # canonname
sockaddr_tuple,
# Check cache if provided
if address_cache and (cached_addresses := address_cache.get_addresses(h)):
cached_hosts.extend(cached_addresses)
continue
# Not in cache, need to resolve
if address_cache and address_cache.has_cache():
_LOGGER.info("Host %s not in cache, will need to resolve", h)
uncached_hosts.append(h)
# Process cached addresses (all should be IP addresses)
for addr in cached_hosts:
try:
res += socket.getaddrinfo(
addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
)
except OSError:
_LOGGER.debug("Failed to parse IP address '%s'", addr)
# If we have uncached hosts, resolve them
if uncached_hosts:
from esphome.resolver import AsyncResolver
resolver = AsyncResolver(uncached_hosts, port)
addr_infos = resolver.resolve()
# Convert aioesphomeapi AddrInfo to our format
for addr_info in addr_infos:
sockaddr = addr_info.sockaddr
if addr_info.family == socket.AF_INET6:
# IPv6
sockaddr_tuple = (
sockaddr.address,
sockaddr.port,
sockaddr.flowinfo,
sockaddr.scope_id,
)
else:
# IPv4
sockaddr_tuple = (sockaddr.address, sockaddr.port)
res.append(
(
addr_info.family,
addr_info.type,
addr_info.proto,
"", # canonname
sockaddr_tuple,
)
)
)
# Sort by preference
res.sort(key=addr_preference_)