diff --git a/esphome/__main__.py b/esphome/__main__.py index aab3035a5e..e3182ea55f 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -396,25 +396,27 @@ def check_permissions(port: str): ) -def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | str: +def upload_program( + config: ConfigType, args: ArgsProtocol, devices: list[str] +) -> int | str: try: module = importlib.import_module("esphome.components." + CORE.target_platform) - if getattr(module, "upload_program")(config, args, host): + if getattr(module, "upload_program")(config, args, devices[0]): return 0 except AttributeError: pass - if get_port_type(host) == "SERIAL": - check_permissions(host) + if get_port_type(devices[0]) == "SERIAL": + check_permissions(devices[0]) if CORE.target_platform in (PLATFORM_ESP32, PLATFORM_ESP8266): file = getattr(args, "file", None) - return upload_using_esptool(config, host, file, args.upload_speed) + return upload_using_esptool(config, devices[0], file, args.upload_speed) if CORE.target_platform in (PLATFORM_RP2040): - return upload_using_platformio(config, host) + return upload_using_platformio(config, devices[0]) if CORE.is_libretiny: - return upload_using_platformio(config, host) + return upload_using_platformio(config, devices[0]) return 1 # Unknown target platform @@ -433,28 +435,27 @@ def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | s remote_port = int(ota_conf[CONF_PORT]) password = ota_conf.get(CONF_PASSWORD, "") + binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin # Check if we should use MQTT for address resolution # This happens when no device was specified, or the current host is "MQTT"/"OTA" - devices: list[str] = args.device or [] if ( CONF_MQTT in config # pylint: disable=too-many-boolean-expressions - and (not devices or host in ("MQTT", "OTA")) + and (not devices or devices[0] in ("MQTT", "OTA")) and ( ((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address)) - or get_port_type(host) == "MQTT" + or get_port_type(devices[0]) == "MQTT" ) ): from esphome import mqtt - host = mqtt.get_esphome_device_ip( - config, args.username, args.password, args.client_id - ) + devices = [ + mqtt.get_esphome_device_ip( + config, args.username, args.password, args.client_id + ) + ] - if getattr(args, "file", None) is not None: - return espota2.run_ota(host, remote_port, password, args.file) - - return espota2.run_ota(host, remote_port, password, CORE.firmware_bin) + return espota2.run_ota(devices, remote_port, password, binary) def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None: @@ -551,17 +552,11 @@ def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None: purpose="uploading", ) - # Try each device until one succeeds - exit_code = 1 - for device in devices: - _LOGGER.info("Uploading to %s", device) - exit_code = upload_program(config, args, device) - if exit_code == 0: - _LOGGER.info("Successfully uploaded program.") - return 0 - if len(devices) > 1: - _LOGGER.warning("Failed to upload to %s", device) - + exit_code = upload_program(config, args, devices) + if exit_code == 0: + _LOGGER.info("Successfully uploaded program.") + else: + _LOGGER.warning("Failed to upload to %s", devices) return exit_code diff --git a/esphome/espota2.py b/esphome/espota2.py index 279bafee8e..d83f25a303 100644 --- a/esphome/espota2.py +++ b/esphome/espota2.py @@ -308,8 +308,12 @@ def perform_ota( time.sleep(1) -def run_ota_impl_(remote_host, remote_port, password, filename): +def run_ota_impl_( + remote_host: str | list[str], remote_port: int, password: str, filename: str +) -> int: + # Handle both single host and list of hosts try: + # Resolve all hosts at once for parallel DNS resolution res = resolve_ip_address(remote_host, remote_port) except EsphomeError as err: _LOGGER.error( @@ -350,7 +354,9 @@ def run_ota_impl_(remote_host, remote_port, password, filename): return 1 -def run_ota(remote_host, remote_port, password, filename): +def run_ota( + remote_host: str | list[str], remote_port: int, password: str, filename: str +) -> int: try: return run_ota_impl_(remote_host, remote_port, password, filename) except OTAError as err: diff --git a/esphome/helpers.py b/esphome/helpers.py index 377a4e1717..b00c97ff73 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import codecs from contextlib import suppress import ipaddress @@ -11,6 +13,18 @@ from urllib.parse import urlparse from esphome.const import __version__ as ESPHOME_VERSION +# Type aliases for socket address information +AddrInfo = tuple[ + int, # family (AF_INET, AF_INET6, etc.) + int, # type (SOCK_STREAM, SOCK_DGRAM, etc.) + int, # proto (IPPROTO_TCP, etc.) + str, # canonname + tuple[str, int] | tuple[str, int, int, int], # sockaddr (IPv4 or IPv6) +] +IPv4SockAddr = tuple[str, int] # (host, port) +IPv6SockAddr = tuple[str, int, int, int] # (host, port, flowinfo, scope_id) +SockAddr = IPv4SockAddr | IPv6SockAddr + _LOGGER = logging.getLogger(__name__) IS_MACOS = platform.system() == "Darwin" @@ -147,32 +161,7 @@ def is_ip_address(host): return False -def _resolve_with_zeroconf(host): - from esphome.core import EsphomeError - from esphome.zeroconf import EsphomeZeroconf - - try: - zc = EsphomeZeroconf() - except Exception as err: - raise EsphomeError( - "Cannot start mDNS sockets, is this a docker container without " - "host network mode?" - ) from err - try: - info = zc.resolve_host(f"{host}.") - except Exception as err: - raise EsphomeError(f"Error resolving mDNS hostname: {err}") from err - finally: - zc.close() - if info is None: - raise EsphomeError( - "Error resolving address with mDNS: Did not respond. " - "Maybe the device is offline." - ) - return info - - -def addr_preference_(res): +def addr_preference_(res: AddrInfo) -> int: # Trivial alternative to RFC6724 sorting. Put sane IPv6 first, then # Legacy IP, then IPv6 link-local addresses without an actual link. sa = res[4] @@ -184,66 +173,70 @@ def addr_preference_(res): return 1 -def resolve_ip_address(host, port): +def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: import socket - from esphome.core import EsphomeError - # There are five cases here. The host argument could be one of: # • a *list* of IP addresses discovered by MQTT, # • a single IP address specified by the user, # • a .local hostname to be resolved by mDNS, # • a normal hostname to be resolved in DNS, or # • A URL from which we should extract the hostname. - # - # In each of the first three cases, we end up with IP addresses in - # string form which need to be converted to a 5-tuple to be used - # for the socket connection attempt. The easiest way to construct - # those is to pass the IP address string to getaddrinfo(). Which, - # coincidentally, is how we do hostname lookups in the other cases - # too. So first build a list which contains either IP addresses or - # a single hostname, then call getaddrinfo() on each element of - # that list. - errs = [] + hosts: list[str] if isinstance(host, list): - addr_list = host - elif is_ip_address(host): - addr_list = [host] + hosts = host else: - url = urlparse(host) - if url.scheme != "": - host = url.hostname + if not is_ip_address(host): + url = urlparse(host) + if url.scheme != "": + host = url.hostname + hosts = [host] - addr_list = [] - if host.endswith(".local"): + res: list[AddrInfo] = [] + if all(is_ip_address(h) for h in hosts): + # Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST + for addr in hosts: try: - _LOGGER.info("Resolving IP address of %s in mDNS", host) - addr_list = _resolve_with_zeroconf(host) - except EsphomeError as err: - errs.append(str(err)) + res += socket.getaddrinfo( + addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST + ) + except OSError: + _LOGGER.debug("Failed to parse IP address '%s'", addr) + # Sort by preference + res.sort(key=addr_preference_) + return res - # If not mDNS, or if mDNS failed, use normal DNS - if not addr_list: - addr_list = [host] + from esphome.resolver import AsyncResolver - # Now we have a list containing either IP addresses or a hostname - res = [] - for addr in addr_list: - if not is_ip_address(addr): - _LOGGER.info("Resolving IP address of %s", host) - try: - r = socket.getaddrinfo(addr, port, proto=socket.IPPROTO_TCP) - except OSError as err: - errs.append(str(err)) - raise EsphomeError( - f"Error resolving IP address: {', '.join(errs)}" - ) from err + resolver = AsyncResolver() + addr_infos = resolver.run(hosts, port) + # Convert aioesphomeapi AddrInfo to our format + for addr_info in addr_infos: + sockaddr = addr_info.sockaddr + if addr_info.family == socket.AF_INET6: + # IPv6 + sockaddr_tuple = ( + sockaddr.address, + sockaddr.port, + sockaddr.flowinfo, + sockaddr.scope_id, + ) + else: + # IPv4 + sockaddr_tuple = (sockaddr.address, sockaddr.port) - res = res + r + res.append( + ( + addr_info.family, + addr_info.type, + addr_info.proto, + "", # canonname + sockaddr_tuple, + ) + ) - # Zeroconf tends to give us link-local IPv6 addresses without specifying - # the link. Put those last in the list to be attempted. + # Sort by preference res.sort(key=addr_preference_) return res @@ -262,15 +255,7 @@ def sort_ip_addresses(address_list: list[str]) -> list[str]: # First "resolve" all the IP addresses to getaddrinfo() tuples of the form # (family, type, proto, canonname, sockaddr) - res: list[ - tuple[ - int, - int, - int, - str | None, - tuple[str, int] | tuple[str, int, int, int], - ] - ] = [] + res: list[AddrInfo] = [] for addr in address_list: # This should always work as these are supposed to be IP addresses try: diff --git a/esphome/resolver.py b/esphome/resolver.py new file mode 100644 index 0000000000..a245737962 --- /dev/null +++ b/esphome/resolver.py @@ -0,0 +1,61 @@ +"""DNS resolver for ESPHome using aioesphomeapi.""" + +from __future__ import annotations + +import asyncio +import threading + +from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError +import aioesphomeapi.host_resolver as hr + +from esphome.core import EsphomeError + +RESOLVE_TIMEOUT = 10.0 # seconds + + +class AsyncResolver: + """Resolver using aioesphomeapi that runs in a thread for faster results. + + This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution, + including proper .local domain fallback. Running in a thread allows us to get + the result immediately without waiting for asyncio.run() to complete its + cleanup cycle, which can take significant time. + """ + + def __init__(self) -> None: + """Initialize the resolver.""" + self.result: list[hr.AddrInfo] | None = None + self.exception: Exception | None = None + self.event = threading.Event() + + async def _resolve(self, hosts: list[str], port: int) -> None: + """Resolve hostnames to IP addresses.""" + try: + self.result = await hr.async_resolve_host( + hosts, port, timeout=RESOLVE_TIMEOUT + ) + except Exception as e: + self.exception = e + finally: + self.event.set() + + def run(self, hosts: list[str], port: int) -> list[hr.AddrInfo]: + """Run the DNS resolution in a separate thread.""" + thread = threading.Thread( + target=lambda: asyncio.run(self._resolve(hosts, port)), daemon=True + ) + thread.start() + + if not self.event.wait( + timeout=RESOLVE_TIMEOUT + 1.0 + ): # Give it 1 second more than the resolver timeout + raise EsphomeError("Timeout resolving IP address") + + if exc := self.exception: + if isinstance(exc, ResolveAPIError): + raise EsphomeError(f"Error resolving IP address: {exc}") from exc + if isinstance(exc, ResolveTimeoutAPIError): + raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc + raise exc + + return self.result