mirror of
https://github.com/esphome/esphome.git
synced 2025-09-15 09:42:19 +01:00
redesign
This commit is contained in:
@@ -209,8 +209,8 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
|
||||
|
||||
from esphome.resolver import AsyncResolver
|
||||
|
||||
resolver = AsyncResolver()
|
||||
addr_infos = resolver.run(hosts, port)
|
||||
resolver = AsyncResolver(hosts, port)
|
||||
addr_infos = resolver.resolve()
|
||||
# Convert aioesphomeapi AddrInfo to our format
|
||||
for addr_info in addr_infos:
|
||||
sockaddr = addr_info.sockaddr
|
||||
|
@@ -13,7 +13,7 @@ from esphome.core import EsphomeError
|
||||
RESOLVE_TIMEOUT = 10.0 # seconds
|
||||
|
||||
|
||||
class AsyncResolver:
|
||||
class AsyncResolver(threading.Thread):
|
||||
"""Resolver using aioesphomeapi that runs in a thread for faster results.
|
||||
|
||||
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution,
|
||||
@@ -22,17 +22,20 @@ class AsyncResolver:
|
||||
cleanup cycle, which can take significant time.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, hosts: list[str], port: int) -> None:
|
||||
"""Initialize the resolver."""
|
||||
super().__init__(daemon=True)
|
||||
self.hosts = hosts
|
||||
self.port = port
|
||||
self.result: list[hr.AddrInfo] | None = None
|
||||
self.exception: Exception | None = None
|
||||
self.event = threading.Event()
|
||||
|
||||
async def _resolve(self, hosts: list[str], port: int) -> None:
|
||||
async def _resolve(self) -> None:
|
||||
"""Resolve hostnames to IP addresses."""
|
||||
try:
|
||||
self.result = await hr.async_resolve_host(
|
||||
hosts, port, timeout=RESOLVE_TIMEOUT
|
||||
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# We need to catch all exceptions to ensure the event is set
|
||||
@@ -41,12 +44,13 @@ class AsyncResolver:
|
||||
finally:
|
||||
self.event.set()
|
||||
|
||||
def run(self, hosts: list[str], port: int) -> list[hr.AddrInfo]:
|
||||
"""Run the DNS resolution in a separate thread."""
|
||||
thread = threading.Thread(
|
||||
target=lambda: asyncio.run(self._resolve(hosts, port)), daemon=True
|
||||
)
|
||||
thread.start()
|
||||
def run(self) -> None:
|
||||
"""Run the DNS resolution."""
|
||||
asyncio.run(self._resolve())
|
||||
|
||||
def resolve(self) -> list[hr.AddrInfo]:
|
||||
"""Start the thread and wait for the result."""
|
||||
self.start()
|
||||
|
||||
if not self.event.wait(
|
||||
timeout=RESOLVE_TIMEOUT + 1.0
|
||||
|
@@ -423,14 +423,15 @@ def test_resolve_ip_address_hostname() -> None:
|
||||
|
||||
with patch("esphome.resolver.AsyncResolver") as MockResolver:
|
||||
mock_resolver = MockResolver.return_value
|
||||
mock_resolver.run.return_value = [mock_addr_info]
|
||||
mock_resolver.resolve.return_value = [mock_addr_info]
|
||||
|
||||
result = helpers.resolve_ip_address("test.local", 6053)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == socket.AF_INET
|
||||
assert result[0][4] == ("192.168.1.100", 6053)
|
||||
mock_resolver.run.assert_called_once_with(["test.local"], 6053)
|
||||
MockResolver.assert_called_once_with(["test.local"], 6053)
|
||||
mock_resolver.resolve.assert_called_once()
|
||||
|
||||
|
||||
def test_resolve_ip_address_mixed_list() -> None:
|
||||
@@ -444,14 +445,15 @@ def test_resolve_ip_address_mixed_list() -> None:
|
||||
|
||||
with patch("esphome.resolver.AsyncResolver") as MockResolver:
|
||||
mock_resolver = MockResolver.return_value
|
||||
mock_resolver.run.return_value = [mock_addr_info]
|
||||
mock_resolver.resolve.return_value = [mock_addr_info]
|
||||
|
||||
# Mix of IP and hostname - should use async resolver
|
||||
result = helpers.resolve_ip_address(["192.168.1.100", "test.local"], 6053)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][4][0] == "192.168.1.200"
|
||||
mock_resolver.run.assert_called_once_with(["192.168.1.100", "test.local"], 6053)
|
||||
MockResolver.assert_called_once_with(["192.168.1.100", "test.local"], 6053)
|
||||
mock_resolver.resolve.assert_called_once()
|
||||
|
||||
|
||||
def test_resolve_ip_address_url() -> None:
|
||||
@@ -465,12 +467,13 @@ def test_resolve_ip_address_url() -> None:
|
||||
|
||||
with patch("esphome.resolver.AsyncResolver") as MockResolver:
|
||||
mock_resolver = MockResolver.return_value
|
||||
mock_resolver.run.return_value = [mock_addr_info]
|
||||
mock_resolver.resolve.return_value = [mock_addr_info]
|
||||
|
||||
result = helpers.resolve_ip_address("http://test.local", 6053)
|
||||
|
||||
assert len(result) == 1
|
||||
mock_resolver.run.assert_called_once_with(["test.local"], 6053)
|
||||
MockResolver.assert_called_once_with(["test.local"], 6053)
|
||||
mock_resolver.resolve.assert_called_once()
|
||||
|
||||
|
||||
def test_resolve_ip_address_ipv6_conversion() -> None:
|
||||
@@ -484,7 +487,7 @@ def test_resolve_ip_address_ipv6_conversion() -> None:
|
||||
|
||||
with patch("esphome.resolver.AsyncResolver") as MockResolver:
|
||||
mock_resolver = MockResolver.return_value
|
||||
mock_resolver.run.return_value = [mock_addr_info]
|
||||
mock_resolver.resolve.return_value = [mock_addr_info]
|
||||
|
||||
result = helpers.resolve_ip_address("test.local", 6053)
|
||||
|
||||
@@ -497,7 +500,7 @@ def test_resolve_ip_address_error_handling() -> None:
|
||||
"""Test error handling from AsyncResolver."""
|
||||
with patch("esphome.resolver.AsyncResolver") as MockResolver:
|
||||
mock_resolver = MockResolver.return_value
|
||||
mock_resolver.run.side_effect = EsphomeError("Resolution failed")
|
||||
mock_resolver.resolve.side_effect = EsphomeError("Resolution failed")
|
||||
|
||||
with pytest.raises(EsphomeError, match="Resolution failed"):
|
||||
helpers.resolve_ip_address("test.local", 6053)
|
||||
@@ -583,7 +586,7 @@ def test_resolve_ip_address_sorting() -> None:
|
||||
|
||||
with patch("esphome.resolver.AsyncResolver") as MockResolver:
|
||||
mock_resolver = MockResolver.return_value
|
||||
mock_resolver.run.return_value = mock_addr_infos
|
||||
mock_resolver.resolve.return_value = mock_addr_infos
|
||||
|
||||
result = helpers.resolve_ip_address("test.local", 6053)
|
||||
|
||||
|
@@ -2,10 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import socket
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
|
||||
@@ -44,8 +42,8 @@ def test_async_resolver_successful_resolution(mock_addr_info_ipv4: AddrInfo) ->
|
||||
"esphome.resolver.hr.async_resolve_host",
|
||||
return_value=[mock_addr_info_ipv4],
|
||||
) as mock_resolve:
|
||||
resolver = AsyncResolver()
|
||||
result = resolver.run(["test.local"], 6053)
|
||||
resolver = AsyncResolver(["test.local"], 6053)
|
||||
result = resolver.resolve()
|
||||
|
||||
assert result == [mock_addr_info_ipv4]
|
||||
mock_resolve.assert_called_once_with(
|
||||
@@ -63,8 +61,8 @@ def test_async_resolver_multiple_hosts(
|
||||
"esphome.resolver.hr.async_resolve_host",
|
||||
return_value=mock_results,
|
||||
) as mock_resolve:
|
||||
resolver = AsyncResolver()
|
||||
result = resolver.run(["test1.local", "test2.local"], 6053)
|
||||
resolver = AsyncResolver(["test1.local", "test2.local"], 6053)
|
||||
result = resolver.resolve()
|
||||
|
||||
assert result == mock_results
|
||||
mock_resolve.assert_called_once_with(
|
||||
@@ -79,11 +77,11 @@ def test_async_resolver_resolve_api_error() -> None:
|
||||
"esphome.resolver.hr.async_resolve_host",
|
||||
side_effect=ResolveAPIError(error_msg),
|
||||
):
|
||||
resolver = AsyncResolver()
|
||||
resolver = AsyncResolver(["test.local"], 6053)
|
||||
with pytest.raises(
|
||||
EsphomeError, match=re.escape(f"Error resolving IP address: {error_msg}")
|
||||
):
|
||||
resolver.run(["test.local"], 6053)
|
||||
resolver.resolve()
|
||||
|
||||
|
||||
def test_async_resolver_timeout_error() -> None:
|
||||
@@ -94,14 +92,14 @@ def test_async_resolver_timeout_error() -> None:
|
||||
"esphome.resolver.hr.async_resolve_host",
|
||||
side_effect=ResolveTimeoutAPIError(error_msg),
|
||||
):
|
||||
resolver = AsyncResolver()
|
||||
resolver = AsyncResolver(["test.local"], 6053)
|
||||
# Match either "Timeout" or "Error" since ResolveTimeoutAPIError is a subclass of ResolveAPIError
|
||||
# and depending on import order/test execution context, it might be caught as either
|
||||
with pytest.raises(
|
||||
EsphomeError,
|
||||
match=f"(Timeout|Error) resolving IP address: {re.escape(error_msg)}",
|
||||
):
|
||||
resolver.run(["test.local"], 6053)
|
||||
resolver.resolve()
|
||||
|
||||
|
||||
def test_async_resolver_generic_exception() -> None:
|
||||
@@ -111,34 +109,30 @@ def test_async_resolver_generic_exception() -> None:
|
||||
"esphome.resolver.hr.async_resolve_host",
|
||||
side_effect=error,
|
||||
):
|
||||
resolver = AsyncResolver()
|
||||
resolver = AsyncResolver(["test.local"], 6053)
|
||||
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||
resolver.run(["test.local"], 6053)
|
||||
resolver.resolve()
|
||||
|
||||
|
||||
def test_async_resolver_thread_timeout() -> None:
|
||||
"""Test timeout when thread doesn't complete in time."""
|
||||
# Use an event to control when the async function completes
|
||||
test_event = threading.Event()
|
||||
|
||||
async def slow_resolve(hosts, port, timeout):
|
||||
# Wait for the test to signal completion
|
||||
await asyncio.get_event_loop().run_in_executor(None, test_event.wait, 0.5)
|
||||
return []
|
||||
|
||||
with patch("esphome.resolver.hr.async_resolve_host", slow_resolve):
|
||||
resolver = AsyncResolver()
|
||||
# Override event.wait to simulate timeout
|
||||
# Mock the start method to prevent actual thread execution
|
||||
with (
|
||||
patch.object(AsyncResolver, "start"),
|
||||
patch("esphome.resolver.hr.async_resolve_host"),
|
||||
):
|
||||
resolver = AsyncResolver(["test.local"], 6053)
|
||||
# Override event.wait to simulate timeout (return False = timeout occurred)
|
||||
with (
|
||||
patch.object(resolver.event, "wait", return_value=False),
|
||||
pytest.raises(
|
||||
EsphomeError, match=re.escape("Timeout resolving IP address")
|
||||
),
|
||||
):
|
||||
resolver.run(["test.local"], 6053)
|
||||
resolver.resolve()
|
||||
|
||||
# Signal the async function to complete and give it time to clean up
|
||||
test_event.set()
|
||||
# Verify thread start was called
|
||||
resolver.start.assert_called_once()
|
||||
|
||||
|
||||
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None:
|
||||
@@ -147,8 +141,8 @@ def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None:
|
||||
"esphome.resolver.hr.async_resolve_host",
|
||||
return_value=[mock_addr_info_ipv4],
|
||||
) as mock_resolve:
|
||||
resolver = AsyncResolver()
|
||||
result = resolver.run(["192.168.1.100"], 6053)
|
||||
resolver = AsyncResolver(["192.168.1.100"], 6053)
|
||||
result = resolver.resolve()
|
||||
|
||||
assert result == [mock_addr_info_ipv4]
|
||||
mock_resolve.assert_called_once_with(
|
||||
@@ -166,8 +160,8 @@ def test_async_resolver_mixed_addresses(
|
||||
"esphome.resolver.hr.async_resolve_host",
|
||||
return_value=mock_results,
|
||||
) as mock_resolve:
|
||||
resolver = AsyncResolver()
|
||||
result = resolver.run(["test.local", "192.168.1.100", "::1"], 6053)
|
||||
resolver = AsyncResolver(["test.local", "192.168.1.100", "::1"], 6053)
|
||||
result = resolver.resolve()
|
||||
|
||||
assert result == mock_results
|
||||
mock_resolve.assert_called_once_with(
|
||||
|
Reference in New Issue
Block a user