1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-08 06:12:20 +01:00

Fix DNS resolution inconsistency between logs and OTA operations

This commit is contained in:
J. Nick Koston
2025-09-04 19:58:44 -05:00
parent 365a427b57
commit ccbe629f8d
4 changed files with 155 additions and 108 deletions

View File

@@ -396,25 +396,27 @@ def check_permissions(port: str):
) )
def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | str: def upload_program(
config: ConfigType, args: ArgsProtocol, devices: list[str]
) -> int | str:
try: try:
module = importlib.import_module("esphome.components." + CORE.target_platform) module = importlib.import_module("esphome.components." + CORE.target_platform)
if getattr(module, "upload_program")(config, args, host): if getattr(module, "upload_program")(config, args, devices[0]):
return 0 return 0
except AttributeError: except AttributeError:
pass pass
if get_port_type(host) == "SERIAL": if get_port_type(devices[0]) == "SERIAL":
check_permissions(host) check_permissions(devices[0])
if CORE.target_platform in (PLATFORM_ESP32, PLATFORM_ESP8266): if CORE.target_platform in (PLATFORM_ESP32, PLATFORM_ESP8266):
file = getattr(args, "file", None) file = getattr(args, "file", None)
return upload_using_esptool(config, host, file, args.upload_speed) return upload_using_esptool(config, devices[0], file, args.upload_speed)
if CORE.target_platform in (PLATFORM_RP2040): if CORE.target_platform in (PLATFORM_RP2040):
return upload_using_platformio(config, host) return upload_using_platformio(config, devices[0])
if CORE.is_libretiny: if CORE.is_libretiny:
return upload_using_platformio(config, host) return upload_using_platformio(config, devices[0])
return 1 # Unknown target platform return 1 # Unknown target platform
@@ -433,28 +435,27 @@ def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | s
remote_port = int(ota_conf[CONF_PORT]) remote_port = int(ota_conf[CONF_PORT])
password = ota_conf.get(CONF_PASSWORD, "") password = ota_conf.get(CONF_PASSWORD, "")
binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin
# Check if we should use MQTT for address resolution # Check if we should use MQTT for address resolution
# This happens when no device was specified, or the current host is "MQTT"/"OTA" # This happens when no device was specified, or the current host is "MQTT"/"OTA"
devices: list[str] = args.device or []
if ( if (
CONF_MQTT in config # pylint: disable=too-many-boolean-expressions CONF_MQTT in config # pylint: disable=too-many-boolean-expressions
and (not devices or host in ("MQTT", "OTA")) and (not devices or devices[0] in ("MQTT", "OTA"))
and ( and (
((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address)) ((config[CONF_MDNS][CONF_DISABLED]) and not is_ip_address(CORE.address))
or get_port_type(host) == "MQTT" or get_port_type(devices[0]) == "MQTT"
) )
): ):
from esphome import mqtt from esphome import mqtt
host = mqtt.get_esphome_device_ip( devices = [
config, args.username, args.password, args.client_id mqtt.get_esphome_device_ip(
) config, args.username, args.password, args.client_id
)
]
if getattr(args, "file", None) is not None: return espota2.run_ota(devices, remote_port, password, binary)
return espota2.run_ota(host, remote_port, password, args.file)
return espota2.run_ota(host, remote_port, password, CORE.firmware_bin)
def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None: def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None:
@@ -551,17 +552,11 @@ def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None:
purpose="uploading", purpose="uploading",
) )
# Try each device until one succeeds exit_code = upload_program(config, args, devices)
exit_code = 1 if exit_code == 0:
for device in devices: _LOGGER.info("Successfully uploaded program.")
_LOGGER.info("Uploading to %s", device) else:
exit_code = upload_program(config, args, device) _LOGGER.warning("Failed to upload to %s", devices)
if exit_code == 0:
_LOGGER.info("Successfully uploaded program.")
return 0
if len(devices) > 1:
_LOGGER.warning("Failed to upload to %s", device)
return exit_code return exit_code

View File

@@ -308,8 +308,12 @@ def perform_ota(
time.sleep(1) time.sleep(1)
def run_ota_impl_(remote_host, remote_port, password, filename): def run_ota_impl_(
remote_host: str | list[str], remote_port: int, password: str, filename: str
) -> int:
# Handle both single host and list of hosts
try: try:
# Resolve all hosts at once for parallel DNS resolution
res = resolve_ip_address(remote_host, remote_port) res = resolve_ip_address(remote_host, remote_port)
except EsphomeError as err: except EsphomeError as err:
_LOGGER.error( _LOGGER.error(
@@ -350,7 +354,9 @@ def run_ota_impl_(remote_host, remote_port, password, filename):
return 1 return 1
def run_ota(remote_host, remote_port, password, filename): def run_ota(
remote_host: str | list[str], remote_port: int, password: str, filename: str
) -> int:
try: try:
return run_ota_impl_(remote_host, remote_port, password, filename) return run_ota_impl_(remote_host, remote_port, password, filename)
except OTAError as err: except OTAError as err:

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import codecs import codecs
from contextlib import suppress from contextlib import suppress
import ipaddress import ipaddress
@@ -11,6 +13,18 @@ from urllib.parse import urlparse
from esphome.const import __version__ as ESPHOME_VERSION from esphome.const import __version__ as ESPHOME_VERSION
# Type aliases for socket address information
AddrInfo = tuple[
int, # family (AF_INET, AF_INET6, etc.)
int, # type (SOCK_STREAM, SOCK_DGRAM, etc.)
int, # proto (IPPROTO_TCP, etc.)
str, # canonname
tuple[str, int] | tuple[str, int, int, int], # sockaddr (IPv4 or IPv6)
]
IPv4SockAddr = tuple[str, int] # (host, port)
IPv6SockAddr = tuple[str, int, int, int] # (host, port, flowinfo, scope_id)
SockAddr = IPv4SockAddr | IPv6SockAddr
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
IS_MACOS = platform.system() == "Darwin" IS_MACOS = platform.system() == "Darwin"
@@ -147,32 +161,7 @@ def is_ip_address(host):
return False return False
def _resolve_with_zeroconf(host): def addr_preference_(res: AddrInfo) -> int:
from esphome.core import EsphomeError
from esphome.zeroconf import EsphomeZeroconf
try:
zc = EsphomeZeroconf()
except Exception as err:
raise EsphomeError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
) from err
try:
info = zc.resolve_host(f"{host}.")
except Exception as err:
raise EsphomeError(f"Error resolving mDNS hostname: {err}") from err
finally:
zc.close()
if info is None:
raise EsphomeError(
"Error resolving address with mDNS: Did not respond. "
"Maybe the device is offline."
)
return info
def addr_preference_(res):
# Trivial alternative to RFC6724 sorting. Put sane IPv6 first, then # Trivial alternative to RFC6724 sorting. Put sane IPv6 first, then
# Legacy IP, then IPv6 link-local addresses without an actual link. # Legacy IP, then IPv6 link-local addresses without an actual link.
sa = res[4] sa = res[4]
@@ -184,66 +173,70 @@ def addr_preference_(res):
return 1 return 1
def resolve_ip_address(host, port): def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
import socket import socket
from esphome.core import EsphomeError
# There are five cases here. The host argument could be one of: # There are five cases here. The host argument could be one of:
# • a *list* of IP addresses discovered by MQTT, # • a *list* of IP addresses discovered by MQTT,
# • a single IP address specified by the user, # • a single IP address specified by the user,
# • a .local hostname to be resolved by mDNS, # • a .local hostname to be resolved by mDNS,
# • a normal hostname to be resolved in DNS, or # • a normal hostname to be resolved in DNS, or
# • A URL from which we should extract the hostname. # • A URL from which we should extract the hostname.
#
# In each of the first three cases, we end up with IP addresses in
# string form which need to be converted to a 5-tuple to be used
# for the socket connection attempt. The easiest way to construct
# those is to pass the IP address string to getaddrinfo(). Which,
# coincidentally, is how we do hostname lookups in the other cases
# too. So first build a list which contains either IP addresses or
# a single hostname, then call getaddrinfo() on each element of
# that list.
errs = [] hosts: list[str]
if isinstance(host, list): if isinstance(host, list):
addr_list = host hosts = host
elif is_ip_address(host):
addr_list = [host]
else: else:
url = urlparse(host) if not is_ip_address(host):
if url.scheme != "": url = urlparse(host)
host = url.hostname if url.scheme != "":
host = url.hostname
hosts = [host]
addr_list = [] res: list[AddrInfo] = []
if host.endswith(".local"): if all(is_ip_address(h) for h in hosts):
# Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST
for addr in hosts:
try: try:
_LOGGER.info("Resolving IP address of %s in mDNS", host) res += socket.getaddrinfo(
addr_list = _resolve_with_zeroconf(host) addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
except EsphomeError as err: )
errs.append(str(err)) except OSError:
_LOGGER.debug("Failed to parse IP address '%s'", addr)
# Sort by preference
res.sort(key=addr_preference_)
return res
# If not mDNS, or if mDNS failed, use normal DNS from esphome.resolver import AsyncResolver
if not addr_list:
addr_list = [host]
# Now we have a list containing either IP addresses or a hostname resolver = AsyncResolver()
res = [] addr_infos = resolver.run(hosts, port)
for addr in addr_list: # Convert aioesphomeapi AddrInfo to our format
if not is_ip_address(addr): for addr_info in addr_infos:
_LOGGER.info("Resolving IP address of %s", host) sockaddr = addr_info.sockaddr
try: if addr_info.family == socket.AF_INET6:
r = socket.getaddrinfo(addr, port, proto=socket.IPPROTO_TCP) # IPv6
except OSError as err: sockaddr_tuple = (
errs.append(str(err)) sockaddr.address,
raise EsphomeError( sockaddr.port,
f"Error resolving IP address: {', '.join(errs)}" sockaddr.flowinfo,
) from err sockaddr.scope_id,
)
else:
# IPv4
sockaddr_tuple = (sockaddr.address, sockaddr.port)
res = res + r res.append(
(
addr_info.family,
addr_info.type,
addr_info.proto,
"", # canonname
sockaddr_tuple,
)
)
# Zeroconf tends to give us link-local IPv6 addresses without specifying # Sort by preference
# the link. Put those last in the list to be attempted.
res.sort(key=addr_preference_) res.sort(key=addr_preference_)
return res return res
@@ -262,15 +255,7 @@ def sort_ip_addresses(address_list: list[str]) -> list[str]:
# First "resolve" all the IP addresses to getaddrinfo() tuples of the form # First "resolve" all the IP addresses to getaddrinfo() tuples of the form
# (family, type, proto, canonname, sockaddr) # (family, type, proto, canonname, sockaddr)
res: list[ res: list[AddrInfo] = []
tuple[
int,
int,
int,
str | None,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
for addr in address_list: for addr in address_list:
# This should always work as these are supposed to be IP addresses # This should always work as these are supposed to be IP addresses
try: try:

61
esphome/resolver.py Normal file
View File

@@ -0,0 +1,61 @@
"""DNS resolver for ESPHome using aioesphomeapi."""
from __future__ import annotations
import asyncio
import threading
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
import aioesphomeapi.host_resolver as hr
from esphome.core import EsphomeError
RESOLVE_TIMEOUT = 10.0 # seconds
class AsyncResolver:
"""Resolver using aioesphomeapi that runs in a thread for faster results.
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution,
including proper .local domain fallback. Running in a thread allows us to get
the result immediately without waiting for asyncio.run() to complete its
cleanup cycle, which can take significant time.
"""
def __init__(self) -> None:
"""Initialize the resolver."""
self.result: list[hr.AddrInfo] | None = None
self.exception: Exception | None = None
self.event = threading.Event()
async def _resolve(self, hosts: list[str], port: int) -> None:
"""Resolve hostnames to IP addresses."""
try:
self.result = await hr.async_resolve_host(
hosts, port, timeout=RESOLVE_TIMEOUT
)
except Exception as e:
self.exception = e
finally:
self.event.set()
def run(self, hosts: list[str], port: int) -> list[hr.AddrInfo]:
"""Run the DNS resolution in a separate thread."""
thread = threading.Thread(
target=lambda: asyncio.run(self._resolve(hosts, port)), daemon=True
)
thread.start()
if not self.event.wait(
timeout=RESOLVE_TIMEOUT + 1.0
): # Give it 1 second more than the resolver timeout
raise EsphomeError("Timeout resolving IP address")
if exc := self.exception:
if isinstance(exc, ResolveAPIError):
raise EsphomeError(f"Error resolving IP address: {exc}") from exc
if isinstance(exc, ResolveTimeoutAPIError):
raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc
raise exc
return self.result