diff --git a/esphome/helpers.py b/esphome/helpers.py index b00c97ff73..6beaa24a96 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -209,8 +209,8 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: from esphome.resolver import AsyncResolver - resolver = AsyncResolver() - addr_infos = resolver.run(hosts, port) + resolver = AsyncResolver(hosts, port) + addr_infos = resolver.resolve() # Convert aioesphomeapi AddrInfo to our format for addr_info in addr_infos: sockaddr = addr_info.sockaddr diff --git a/esphome/resolver.py b/esphome/resolver.py index 24972a456f..99482aa20e 100644 --- a/esphome/resolver.py +++ b/esphome/resolver.py @@ -13,7 +13,7 @@ from esphome.core import EsphomeError RESOLVE_TIMEOUT = 10.0 # seconds -class AsyncResolver: +class AsyncResolver(threading.Thread): """Resolver using aioesphomeapi that runs in a thread for faster results. This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution, @@ -22,17 +22,20 @@ class AsyncResolver: cleanup cycle, which can take significant time. """ - def __init__(self) -> None: + def __init__(self, hosts: list[str], port: int) -> None: """Initialize the resolver.""" + super().__init__(daemon=True) + self.hosts = hosts + self.port = port self.result: list[hr.AddrInfo] | None = None self.exception: Exception | None = None self.event = threading.Event() - async def _resolve(self, hosts: list[str], port: int) -> None: + async def _resolve(self) -> None: """Resolve hostnames to IP addresses.""" try: self.result = await hr.async_resolve_host( - hosts, port, timeout=RESOLVE_TIMEOUT + self.hosts, self.port, timeout=RESOLVE_TIMEOUT ) except Exception as e: # pylint: disable=broad-except # We need to catch all exceptions to ensure the event is set @@ -41,12 +44,13 @@ class AsyncResolver: finally: self.event.set() - def run(self, hosts: list[str], port: int) -> list[hr.AddrInfo]: - """Run the DNS resolution in a separate thread.""" - thread = threading.Thread( - target=lambda: asyncio.run(self._resolve(hosts, port)), daemon=True - ) - thread.start() + def run(self) -> None: + """Run the DNS resolution.""" + asyncio.run(self._resolve()) + + def resolve(self) -> list[hr.AddrInfo]: + """Start the thread and wait for the result.""" + self.start() if not self.event.wait( timeout=RESOLVE_TIMEOUT + 1.0 diff --git a/tests/unit_tests/test_helpers.py b/tests/unit_tests/test_helpers.py index 9a052ad9c2..9f51206ff9 100644 --- a/tests/unit_tests/test_helpers.py +++ b/tests/unit_tests/test_helpers.py @@ -423,14 +423,15 @@ def test_resolve_ip_address_hostname() -> None: with patch("esphome.resolver.AsyncResolver") as MockResolver: mock_resolver = MockResolver.return_value - mock_resolver.run.return_value = [mock_addr_info] + mock_resolver.resolve.return_value = [mock_addr_info] result = helpers.resolve_ip_address("test.local", 6053) assert len(result) == 1 assert result[0][0] == socket.AF_INET assert result[0][4] == ("192.168.1.100", 6053) - mock_resolver.run.assert_called_once_with(["test.local"], 6053) + MockResolver.assert_called_once_with(["test.local"], 6053) + mock_resolver.resolve.assert_called_once() def test_resolve_ip_address_mixed_list() -> None: @@ -444,14 +445,15 @@ def test_resolve_ip_address_mixed_list() -> None: with patch("esphome.resolver.AsyncResolver") as MockResolver: mock_resolver = MockResolver.return_value - mock_resolver.run.return_value = [mock_addr_info] + mock_resolver.resolve.return_value = [mock_addr_info] # Mix of IP and hostname - should use async resolver result = helpers.resolve_ip_address(["192.168.1.100", "test.local"], 6053) assert len(result) == 1 assert result[0][4][0] == "192.168.1.200" - mock_resolver.run.assert_called_once_with(["192.168.1.100", "test.local"], 6053) + MockResolver.assert_called_once_with(["192.168.1.100", "test.local"], 6053) + mock_resolver.resolve.assert_called_once() def test_resolve_ip_address_url() -> None: @@ -465,12 +467,13 @@ def test_resolve_ip_address_url() -> None: with patch("esphome.resolver.AsyncResolver") as MockResolver: mock_resolver = MockResolver.return_value - mock_resolver.run.return_value = [mock_addr_info] + mock_resolver.resolve.return_value = [mock_addr_info] result = helpers.resolve_ip_address("http://test.local", 6053) assert len(result) == 1 - mock_resolver.run.assert_called_once_with(["test.local"], 6053) + MockResolver.assert_called_once_with(["test.local"], 6053) + mock_resolver.resolve.assert_called_once() def test_resolve_ip_address_ipv6_conversion() -> None: @@ -484,7 +487,7 @@ def test_resolve_ip_address_ipv6_conversion() -> None: with patch("esphome.resolver.AsyncResolver") as MockResolver: mock_resolver = MockResolver.return_value - mock_resolver.run.return_value = [mock_addr_info] + mock_resolver.resolve.return_value = [mock_addr_info] result = helpers.resolve_ip_address("test.local", 6053) @@ -497,7 +500,7 @@ def test_resolve_ip_address_error_handling() -> None: """Test error handling from AsyncResolver.""" with patch("esphome.resolver.AsyncResolver") as MockResolver: mock_resolver = MockResolver.return_value - mock_resolver.run.side_effect = EsphomeError("Resolution failed") + mock_resolver.resolve.side_effect = EsphomeError("Resolution failed") with pytest.raises(EsphomeError, match="Resolution failed"): helpers.resolve_ip_address("test.local", 6053) @@ -583,7 +586,7 @@ def test_resolve_ip_address_sorting() -> None: with patch("esphome.resolver.AsyncResolver") as MockResolver: mock_resolver = MockResolver.return_value - mock_resolver.run.return_value = mock_addr_infos + mock_resolver.resolve.return_value = mock_addr_infos result = helpers.resolve_ip_address("test.local", 6053) diff --git a/tests/unit_tests/test_resolver.py b/tests/unit_tests/test_resolver.py index 0dbe89b206..b4cca05d9f 100644 --- a/tests/unit_tests/test_resolver.py +++ b/tests/unit_tests/test_resolver.py @@ -2,10 +2,8 @@ from __future__ import annotations -import asyncio import re import socket -import threading from unittest.mock import patch from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError @@ -44,8 +42,8 @@ def test_async_resolver_successful_resolution(mock_addr_info_ipv4: AddrInfo) -> "esphome.resolver.hr.async_resolve_host", return_value=[mock_addr_info_ipv4], ) as mock_resolve: - resolver = AsyncResolver() - result = resolver.run(["test.local"], 6053) + resolver = AsyncResolver(["test.local"], 6053) + result = resolver.resolve() assert result == [mock_addr_info_ipv4] mock_resolve.assert_called_once_with( @@ -63,8 +61,8 @@ def test_async_resolver_multiple_hosts( "esphome.resolver.hr.async_resolve_host", return_value=mock_results, ) as mock_resolve: - resolver = AsyncResolver() - result = resolver.run(["test1.local", "test2.local"], 6053) + resolver = AsyncResolver(["test1.local", "test2.local"], 6053) + result = resolver.resolve() assert result == mock_results mock_resolve.assert_called_once_with( @@ -79,11 +77,11 @@ def test_async_resolver_resolve_api_error() -> None: "esphome.resolver.hr.async_resolve_host", side_effect=ResolveAPIError(error_msg), ): - resolver = AsyncResolver() + resolver = AsyncResolver(["test.local"], 6053) with pytest.raises( EsphomeError, match=re.escape(f"Error resolving IP address: {error_msg}") ): - resolver.run(["test.local"], 6053) + resolver.resolve() def test_async_resolver_timeout_error() -> None: @@ -94,14 +92,14 @@ def test_async_resolver_timeout_error() -> None: "esphome.resolver.hr.async_resolve_host", side_effect=ResolveTimeoutAPIError(error_msg), ): - resolver = AsyncResolver() + resolver = AsyncResolver(["test.local"], 6053) # Match either "Timeout" or "Error" since ResolveTimeoutAPIError is a subclass of ResolveAPIError # and depending on import order/test execution context, it might be caught as either with pytest.raises( EsphomeError, match=f"(Timeout|Error) resolving IP address: {re.escape(error_msg)}", ): - resolver.run(["test.local"], 6053) + resolver.resolve() def test_async_resolver_generic_exception() -> None: @@ -111,34 +109,30 @@ def test_async_resolver_generic_exception() -> None: "esphome.resolver.hr.async_resolve_host", side_effect=error, ): - resolver = AsyncResolver() + resolver = AsyncResolver(["test.local"], 6053) with pytest.raises(RuntimeError, match="Unexpected error"): - resolver.run(["test.local"], 6053) + resolver.resolve() def test_async_resolver_thread_timeout() -> None: """Test timeout when thread doesn't complete in time.""" - # Use an event to control when the async function completes - test_event = threading.Event() - - async def slow_resolve(hosts, port, timeout): - # Wait for the test to signal completion - await asyncio.get_event_loop().run_in_executor(None, test_event.wait, 0.5) - return [] - - with patch("esphome.resolver.hr.async_resolve_host", slow_resolve): - resolver = AsyncResolver() - # Override event.wait to simulate timeout + # Mock the start method to prevent actual thread execution + with ( + patch.object(AsyncResolver, "start"), + patch("esphome.resolver.hr.async_resolve_host"), + ): + resolver = AsyncResolver(["test.local"], 6053) + # Override event.wait to simulate timeout (return False = timeout occurred) with ( patch.object(resolver.event, "wait", return_value=False), pytest.raises( EsphomeError, match=re.escape("Timeout resolving IP address") ), ): - resolver.run(["test.local"], 6053) + resolver.resolve() - # Signal the async function to complete and give it time to clean up - test_event.set() + # Verify thread start was called + resolver.start.assert_called_once() def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: @@ -147,8 +141,8 @@ def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: "esphome.resolver.hr.async_resolve_host", return_value=[mock_addr_info_ipv4], ) as mock_resolve: - resolver = AsyncResolver() - result = resolver.run(["192.168.1.100"], 6053) + resolver = AsyncResolver(["192.168.1.100"], 6053) + result = resolver.resolve() assert result == [mock_addr_info_ipv4] mock_resolve.assert_called_once_with( @@ -166,8 +160,8 @@ def test_async_resolver_mixed_addresses( "esphome.resolver.hr.async_resolve_host", return_value=mock_results, ) as mock_resolve: - resolver = AsyncResolver() - result = resolver.run(["test.local", "192.168.1.100", "::1"], 6053) + resolver = AsyncResolver(["test.local", "192.168.1.100", "::1"], 6053) + result = resolver.resolve() assert result == mock_results mock_resolve.assert_called_once_with(