1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-15 09:42:19 +01:00
This commit is contained in:
J. Nick Koston
2025-09-04 21:39:50 -05:00
parent 6ab0581c93
commit 830b9a881a
4 changed files with 52 additions and 51 deletions

View File

@@ -209,8 +209,8 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
from esphome.resolver import AsyncResolver from esphome.resolver import AsyncResolver
resolver = AsyncResolver() resolver = AsyncResolver(hosts, port)
addr_infos = resolver.run(hosts, port) addr_infos = resolver.resolve()
# Convert aioesphomeapi AddrInfo to our format # Convert aioesphomeapi AddrInfo to our format
for addr_info in addr_infos: for addr_info in addr_infos:
sockaddr = addr_info.sockaddr sockaddr = addr_info.sockaddr

View File

@@ -13,7 +13,7 @@ from esphome.core import EsphomeError
RESOLVE_TIMEOUT = 10.0 # seconds RESOLVE_TIMEOUT = 10.0 # seconds
class AsyncResolver: class AsyncResolver(threading.Thread):
"""Resolver using aioesphomeapi that runs in a thread for faster results. """Resolver using aioesphomeapi that runs in a thread for faster results.
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution, This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution,
@@ -22,17 +22,20 @@ class AsyncResolver:
cleanup cycle, which can take significant time. cleanup cycle, which can take significant time.
""" """
def __init__(self) -> None: def __init__(self, hosts: list[str], port: int) -> None:
"""Initialize the resolver.""" """Initialize the resolver."""
super().__init__(daemon=True)
self.hosts = hosts
self.port = port
self.result: list[hr.AddrInfo] | None = None self.result: list[hr.AddrInfo] | None = None
self.exception: Exception | None = None self.exception: Exception | None = None
self.event = threading.Event() self.event = threading.Event()
async def _resolve(self, hosts: list[str], port: int) -> None: async def _resolve(self) -> None:
"""Resolve hostnames to IP addresses.""" """Resolve hostnames to IP addresses."""
try: try:
self.result = await hr.async_resolve_host( self.result = await hr.async_resolve_host(
hosts, port, timeout=RESOLVE_TIMEOUT self.hosts, self.port, timeout=RESOLVE_TIMEOUT
) )
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
# We need to catch all exceptions to ensure the event is set # We need to catch all exceptions to ensure the event is set
@@ -41,12 +44,13 @@ class AsyncResolver:
finally: finally:
self.event.set() self.event.set()
def run(self, hosts: list[str], port: int) -> list[hr.AddrInfo]: def run(self) -> None:
"""Run the DNS resolution in a separate thread.""" """Run the DNS resolution."""
thread = threading.Thread( asyncio.run(self._resolve())
target=lambda: asyncio.run(self._resolve(hosts, port)), daemon=True
) def resolve(self) -> list[hr.AddrInfo]:
thread.start() """Start the thread and wait for the result."""
self.start()
if not self.event.wait( if not self.event.wait(
timeout=RESOLVE_TIMEOUT + 1.0 timeout=RESOLVE_TIMEOUT + 1.0

View File

@@ -423,14 +423,15 @@ def test_resolve_ip_address_hostname() -> None:
with patch("esphome.resolver.AsyncResolver") as MockResolver: with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value mock_resolver = MockResolver.return_value
mock_resolver.run.return_value = [mock_addr_info] mock_resolver.resolve.return_value = [mock_addr_info]
result = helpers.resolve_ip_address("test.local", 6053) result = helpers.resolve_ip_address("test.local", 6053)
assert len(result) == 1 assert len(result) == 1
assert result[0][0] == socket.AF_INET assert result[0][0] == socket.AF_INET
assert result[0][4] == ("192.168.1.100", 6053) assert result[0][4] == ("192.168.1.100", 6053)
mock_resolver.run.assert_called_once_with(["test.local"], 6053) MockResolver.assert_called_once_with(["test.local"], 6053)
mock_resolver.resolve.assert_called_once()
def test_resolve_ip_address_mixed_list() -> None: def test_resolve_ip_address_mixed_list() -> None:
@@ -444,14 +445,15 @@ def test_resolve_ip_address_mixed_list() -> None:
with patch("esphome.resolver.AsyncResolver") as MockResolver: with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value mock_resolver = MockResolver.return_value
mock_resolver.run.return_value = [mock_addr_info] mock_resolver.resolve.return_value = [mock_addr_info]
# Mix of IP and hostname - should use async resolver # Mix of IP and hostname - should use async resolver
result = helpers.resolve_ip_address(["192.168.1.100", "test.local"], 6053) result = helpers.resolve_ip_address(["192.168.1.100", "test.local"], 6053)
assert len(result) == 1 assert len(result) == 1
assert result[0][4][0] == "192.168.1.200" assert result[0][4][0] == "192.168.1.200"
mock_resolver.run.assert_called_once_with(["192.168.1.100", "test.local"], 6053) MockResolver.assert_called_once_with(["192.168.1.100", "test.local"], 6053)
mock_resolver.resolve.assert_called_once()
def test_resolve_ip_address_url() -> None: def test_resolve_ip_address_url() -> None:
@@ -465,12 +467,13 @@ def test_resolve_ip_address_url() -> None:
with patch("esphome.resolver.AsyncResolver") as MockResolver: with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value mock_resolver = MockResolver.return_value
mock_resolver.run.return_value = [mock_addr_info] mock_resolver.resolve.return_value = [mock_addr_info]
result = helpers.resolve_ip_address("http://test.local", 6053) result = helpers.resolve_ip_address("http://test.local", 6053)
assert len(result) == 1 assert len(result) == 1
mock_resolver.run.assert_called_once_with(["test.local"], 6053) MockResolver.assert_called_once_with(["test.local"], 6053)
mock_resolver.resolve.assert_called_once()
def test_resolve_ip_address_ipv6_conversion() -> None: def test_resolve_ip_address_ipv6_conversion() -> None:
@@ -484,7 +487,7 @@ def test_resolve_ip_address_ipv6_conversion() -> None:
with patch("esphome.resolver.AsyncResolver") as MockResolver: with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value mock_resolver = MockResolver.return_value
mock_resolver.run.return_value = [mock_addr_info] mock_resolver.resolve.return_value = [mock_addr_info]
result = helpers.resolve_ip_address("test.local", 6053) result = helpers.resolve_ip_address("test.local", 6053)
@@ -497,7 +500,7 @@ def test_resolve_ip_address_error_handling() -> None:
"""Test error handling from AsyncResolver.""" """Test error handling from AsyncResolver."""
with patch("esphome.resolver.AsyncResolver") as MockResolver: with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value mock_resolver = MockResolver.return_value
mock_resolver.run.side_effect = EsphomeError("Resolution failed") mock_resolver.resolve.side_effect = EsphomeError("Resolution failed")
with pytest.raises(EsphomeError, match="Resolution failed"): with pytest.raises(EsphomeError, match="Resolution failed"):
helpers.resolve_ip_address("test.local", 6053) helpers.resolve_ip_address("test.local", 6053)
@@ -583,7 +586,7 @@ def test_resolve_ip_address_sorting() -> None:
with patch("esphome.resolver.AsyncResolver") as MockResolver: with patch("esphome.resolver.AsyncResolver") as MockResolver:
mock_resolver = MockResolver.return_value mock_resolver = MockResolver.return_value
mock_resolver.run.return_value = mock_addr_infos mock_resolver.resolve.return_value = mock_addr_infos
result = helpers.resolve_ip_address("test.local", 6053) result = helpers.resolve_ip_address("test.local", 6053)

View File

@@ -2,10 +2,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import re import re
import socket import socket
import threading
from unittest.mock import patch from unittest.mock import patch
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
@@ -44,8 +42,8 @@ def test_async_resolver_successful_resolution(mock_addr_info_ipv4: AddrInfo) ->
"esphome.resolver.hr.async_resolve_host", "esphome.resolver.hr.async_resolve_host",
return_value=[mock_addr_info_ipv4], return_value=[mock_addr_info_ipv4],
) as mock_resolve: ) as mock_resolve:
resolver = AsyncResolver() resolver = AsyncResolver(["test.local"], 6053)
result = resolver.run(["test.local"], 6053) result = resolver.resolve()
assert result == [mock_addr_info_ipv4] assert result == [mock_addr_info_ipv4]
mock_resolve.assert_called_once_with( mock_resolve.assert_called_once_with(
@@ -63,8 +61,8 @@ def test_async_resolver_multiple_hosts(
"esphome.resolver.hr.async_resolve_host", "esphome.resolver.hr.async_resolve_host",
return_value=mock_results, return_value=mock_results,
) as mock_resolve: ) as mock_resolve:
resolver = AsyncResolver() resolver = AsyncResolver(["test1.local", "test2.local"], 6053)
result = resolver.run(["test1.local", "test2.local"], 6053) result = resolver.resolve()
assert result == mock_results assert result == mock_results
mock_resolve.assert_called_once_with( mock_resolve.assert_called_once_with(
@@ -79,11 +77,11 @@ def test_async_resolver_resolve_api_error() -> None:
"esphome.resolver.hr.async_resolve_host", "esphome.resolver.hr.async_resolve_host",
side_effect=ResolveAPIError(error_msg), side_effect=ResolveAPIError(error_msg),
): ):
resolver = AsyncResolver() resolver = AsyncResolver(["test.local"], 6053)
with pytest.raises( with pytest.raises(
EsphomeError, match=re.escape(f"Error resolving IP address: {error_msg}") EsphomeError, match=re.escape(f"Error resolving IP address: {error_msg}")
): ):
resolver.run(["test.local"], 6053) resolver.resolve()
def test_async_resolver_timeout_error() -> None: def test_async_resolver_timeout_error() -> None:
@@ -94,14 +92,14 @@ def test_async_resolver_timeout_error() -> None:
"esphome.resolver.hr.async_resolve_host", "esphome.resolver.hr.async_resolve_host",
side_effect=ResolveTimeoutAPIError(error_msg), side_effect=ResolveTimeoutAPIError(error_msg),
): ):
resolver = AsyncResolver() resolver = AsyncResolver(["test.local"], 6053)
# Match either "Timeout" or "Error" since ResolveTimeoutAPIError is a subclass of ResolveAPIError # Match either "Timeout" or "Error" since ResolveTimeoutAPIError is a subclass of ResolveAPIError
# and depending on import order/test execution context, it might be caught as either # and depending on import order/test execution context, it might be caught as either
with pytest.raises( with pytest.raises(
EsphomeError, EsphomeError,
match=f"(Timeout|Error) resolving IP address: {re.escape(error_msg)}", match=f"(Timeout|Error) resolving IP address: {re.escape(error_msg)}",
): ):
resolver.run(["test.local"], 6053) resolver.resolve()
def test_async_resolver_generic_exception() -> None: def test_async_resolver_generic_exception() -> None:
@@ -111,34 +109,30 @@ def test_async_resolver_generic_exception() -> None:
"esphome.resolver.hr.async_resolve_host", "esphome.resolver.hr.async_resolve_host",
side_effect=error, side_effect=error,
): ):
resolver = AsyncResolver() resolver = AsyncResolver(["test.local"], 6053)
with pytest.raises(RuntimeError, match="Unexpected error"): with pytest.raises(RuntimeError, match="Unexpected error"):
resolver.run(["test.local"], 6053) resolver.resolve()
def test_async_resolver_thread_timeout() -> None: def test_async_resolver_thread_timeout() -> None:
"""Test timeout when thread doesn't complete in time.""" """Test timeout when thread doesn't complete in time."""
# Use an event to control when the async function completes # Mock the start method to prevent actual thread execution
test_event = threading.Event() with (
patch.object(AsyncResolver, "start"),
async def slow_resolve(hosts, port, timeout): patch("esphome.resolver.hr.async_resolve_host"),
# Wait for the test to signal completion ):
await asyncio.get_event_loop().run_in_executor(None, test_event.wait, 0.5) resolver = AsyncResolver(["test.local"], 6053)
return [] # Override event.wait to simulate timeout (return False = timeout occurred)
with patch("esphome.resolver.hr.async_resolve_host", slow_resolve):
resolver = AsyncResolver()
# Override event.wait to simulate timeout
with ( with (
patch.object(resolver.event, "wait", return_value=False), patch.object(resolver.event, "wait", return_value=False),
pytest.raises( pytest.raises(
EsphomeError, match=re.escape("Timeout resolving IP address") EsphomeError, match=re.escape("Timeout resolving IP address")
), ),
): ):
resolver.run(["test.local"], 6053) resolver.resolve()
# Signal the async function to complete and give it time to clean up # Verify thread start was called
test_event.set() resolver.start.assert_called_once()
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None:
@@ -147,8 +141,8 @@ def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None:
"esphome.resolver.hr.async_resolve_host", "esphome.resolver.hr.async_resolve_host",
return_value=[mock_addr_info_ipv4], return_value=[mock_addr_info_ipv4],
) as mock_resolve: ) as mock_resolve:
resolver = AsyncResolver() resolver = AsyncResolver(["192.168.1.100"], 6053)
result = resolver.run(["192.168.1.100"], 6053) result = resolver.resolve()
assert result == [mock_addr_info_ipv4] assert result == [mock_addr_info_ipv4]
mock_resolve.assert_called_once_with( mock_resolve.assert_called_once_with(
@@ -166,8 +160,8 @@ def test_async_resolver_mixed_addresses(
"esphome.resolver.hr.async_resolve_host", "esphome.resolver.hr.async_resolve_host",
return_value=mock_results, return_value=mock_results,
) as mock_resolve: ) as mock_resolve:
resolver = AsyncResolver() resolver = AsyncResolver(["test.local", "192.168.1.100", "::1"], 6053)
result = resolver.run(["test.local", "192.168.1.100", "::1"], 6053) result = resolver.resolve()
assert result == mock_results assert result == mock_results
mock_resolve.assert_called_once_with( mock_resolve.assert_called_once_with(