diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index dbd0a6629d..9b39aca973 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import base64 import binascii import codecs @@ -47,10 +48,10 @@ from esphome.storage_json import ( from esphome.util import get_serial_ports, shlex_quote from esphome.zeroconf import ( ESPHOME_SERVICE_TYPE, + AsyncEsphomeZeroconf, DashboardBrowser, DashboardImportDiscovery, DashboardStatus, - EsphomeZeroconf, ) from .util import friendly_name_slugify, password_hash @@ -975,13 +976,13 @@ class BoardsRequestHandler(BaseHandler): self.write(json.dumps(output)) -class MDNSStatusThread(threading.Thread): - """Thread that updates the mdns status.""" +class MDNSStatus: + """Class that updates the mdns status.""" def __init__(self) -> None: - """Initialize the MDNSStatusThread.""" + """Initialize the MDNSStatus class.""" super().__init__() - self.zeroconf: EsphomeZeroconf | None = None + self.aiozc: AsyncEsphomeZeroconf | None = None # This is the current mdns state for each host (True, False, None) self.host_mdns_state: dict[str, bool | None] = {} # This is the hostnames to filenames mapping @@ -997,10 +998,10 @@ class MDNSStatusThread(threading.Thread): async def async_resolve_host(self, host_name: str) -> str | None: """Resolve a host name to an address in a thread-safe manner.""" - if zc := self.zeroconf: + if aiozc := self.aiozc: # Currently we do not do any I/O and only # return the cached result (timeout=0) - return await zc.async_resolve_host(host_name) + return await aiozc.async_resolve_host(host_name) return None def _refresh_hosts(self): @@ -1034,11 +1035,11 @@ class MDNSStatusThread(threading.Thread): host_name_to_filename[name] = filename filename_to_host_name[filename] = name - def run(self): + async def async_run(self) -> None: global IMPORT_RESULT - zc = EsphomeZeroconf() - self.zeroconf = zc + aiozc = AsyncEsphomeZeroconf() + self.aiozc = aiozc host_mdns_state = self.host_mdns_state host_name_to_filename = self.host_name_to_filename host_name_with_mdns_enabled = self.host_name_with_mdns_enabled @@ -1055,18 +1056,20 @@ class MDNSStatusThread(threading.Thread): stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() browser = DashboardBrowser( - zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback] + aiozc.zeroconf, + ESPHOME_SERVICE_TYPE, + [stat.browser_callback, imports.browser_callback], ) while not STOP_EVENT.is_set(): self._refresh_hosts() IMPORT_RESULT = imports.import_state - PING_REQUEST.wait() - PING_REQUEST.clear() + await PING_REQUEST.async_wait() + PING_REQUEST.async_clear() - browser.cancel() - zc.close() - self.zeroconf = None + await browser.async_cancel() + await aiozc.async_close() + self.aiozc = None class PingStatusThread(threading.Thread): @@ -1246,22 +1249,69 @@ class UndoDeleteRequestHandler(BaseHandler): class MDNSContainer: def __init__(self) -> None: """Initialize the MDNSContainer.""" - self._mdns: MDNSStatusThread | None = None + self._mdns: MDNSStatus | None = None - def set_mdns(self, mdns: MDNSStatusThread) -> None: - """Set the MDNSStatusThread instance.""" + def set_mdns(self, mdns: MDNSStatus) -> None: + """Set the MDNSStatus instance.""" self._mdns = mdns - def get_mdns(self) -> MDNSStatusThread | None: - """Return the MDNSStatusThread instance.""" + def get_mdns(self) -> MDNSStatus | None: + """Return the MDNSStatus instance.""" return self._mdns +class ThreadedAsyncEvent: + def __init__(self) -> None: + """Initialize the ThreadedAsyncEvent.""" + self.event = threading.Event() + self.async_event: asyncio.Event | None = None + self.loop: asyncio.AbstractEventLoop | None = None + + def async_setup( + self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event + ) -> None: + """Set the asyncio.Event instance.""" + self.loop = loop + self.async_event = async_event + + def async_set(self) -> None: + """Set the asyncio.Event instance.""" + self.async_event.set() + self.event.set() + + def set(self) -> None: + """Set the event.""" + self.loop.call_soon_threadsafe(self.async_event.set) + self.event.set() + + def wait(self) -> None: + """Wait for the event.""" + self.event.wait() + + async def async_wait(self) -> None: + """Wait the event async.""" + await self.async_event.wait() + + def clear(self) -> None: + """Clear the event.""" + self.loop.call_soon_threadsafe(self.async_event.clear) + self.event.clear() + + def async_clear(self) -> None: + """Clear the event async.""" + self.async_event.clear() + self.event.clear() + + def is_set(self) -> bool: + """Return if the event is set.""" + return self.event.is_set() + + PING_RESULT: dict = {} IMPORT_RESULT = {} -STOP_EVENT = threading.Event() -PING_REQUEST = threading.Event() -MQTT_PING_REQUEST = threading.Event() +STOP_EVENT = ThreadedAsyncEvent() +PING_REQUEST = ThreadedAsyncEvent() +MQTT_PING_REQUEST = ThreadedAsyncEvent() MDNS_CONTAINER = MDNSContainer() @@ -1525,6 +1575,18 @@ def start_web_server(args): storage.save(path) settings.cookie_secret = storage.cookie_secret + try: + asyncio.run(async_start_web_server(args)) + except KeyboardInterrupt: + pass + + +async def async_start_web_server(args): + loop = asyncio.get_event_loop() + PING_REQUEST.async_setup(loop, asyncio.Event()) + MQTT_PING_REQUEST.async_setup(loop, asyncio.Event()) + STOP_EVENT.async_setup(loop, asyncio.Event()) + app = make_app(args.verbose) if args.socket is not None: _LOGGER.info( @@ -1549,27 +1611,36 @@ def start_web_server(args): webbrowser.open(f"http://{args.address}:{args.port}") + mdns_task: asyncio.Task | None = None + ping_status_thread: PingStatusThread | None = None if settings.status_use_ping: - status_thread = PingStatusThread() + ping_status_thread = PingStatusThread() + ping_status_thread.start() else: - status_thread = MDNSStatusThread() - MDNS_CONTAINER.set_mdns(status_thread) - status_thread.start() + mdns_status = MDNSStatus() + MDNS_CONTAINER.set_mdns(mdns_status) + mdns_task = asyncio.create_task(mdns_status.async_run()) if settings.status_use_mqtt: status_thread_mqtt = MqttStatusThread() status_thread_mqtt.start() + shutdown_event = asyncio.Event() try: - tornado.ioloop.IOLoop.current().start() + await shutdown_event.wait() except KeyboardInterrupt: + raise + finally: _LOGGER.info("Shutting down...") STOP_EVENT.set() PING_REQUEST.set() - status_thread.join() + if ping_status_thread: + ping_status_thread.join() MDNS_CONTAINER.set_mdns(None) + mdns_task.cancel() if settings.status_use_mqtt: status_thread_mqtt.join() MQTT_PING_REQUEST.set() if args.socket is not None: os.remove(args.socket) + await asyncio.sleep(0) diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index 5e4fc70d00..956e348e07 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -1,22 +1,21 @@ from __future__ import annotations +import asyncio import logging from dataclasses import dataclass from typing import Callable -from zeroconf import ( - IPVersion, - ServiceBrowser, - ServiceInfo, - ServiceStateChange, - Zeroconf, -) +from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf from esphome.storage_json import StorageJSON, ext_storage_path _LOGGER = logging.getLogger(__name__) +_BACKGROUND_TASKS: set[asyncio.Task] = set() + + class HostResolver(ServiceInfo): """Resolve a host name to an IP address.""" @@ -65,7 +64,7 @@ class DiscoveredImport: network: str -class DashboardBrowser(ServiceBrowser): +class DashboardBrowser(AsyncServiceBrowser): """A class to browse for ESPHome nodes.""" @@ -94,7 +93,28 @@ class DashboardImportDiscovery: # Ignore updates for devices that are not in the import state return - info = zeroconf.get_service_info(service_type, name) + info = AsyncServiceInfo( + service_type, + name, + ) + if info.load_from_cache(zeroconf): + self._process_service_info(name, info) + return + task = asyncio.create_task( + self._async_process_service_info(zeroconf, info, service_type, name) + ) + _BACKGROUND_TASKS.add(task) + task.add_done_callback(_BACKGROUND_TASKS.discard) + + async def _async_process_service_info( + self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str + ) -> None: + """Process a service info.""" + if await info.async_request(zeroconf): + self._process_service_info(name, info) + + def _process_service_info(self, name: str, info: ServiceInfo) -> None: + """Process a service info.""" _LOGGER.debug("-> resolved info: %s", info) if info is None: return @@ -146,16 +166,17 @@ class DashboardImportDiscovery: ) -class EsphomeZeroconf(Zeroconf): - def _make_host_resolver(self, host: str) -> HostResolver: - """Create a new HostResolver for the given host name.""" - name = host.partition(".")[0] - info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") - return info +def _make_host_resolver(host: str) -> HostResolver: + """Create a new HostResolver for the given host name.""" + name = host.partition(".")[0] + info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") + return info + +class EsphomeZeroconf(Zeroconf): def resolve_host(self, host: str, timeout: float = 3.0) -> str | None: """Resolve a host name to an IP address.""" - info = self._make_host_resolver(host) + info = _make_host_resolver(host) if ( info.load_from_cache(self) or (timeout and info.request(self, timeout * 1000)) @@ -163,12 +184,14 @@ class EsphomeZeroconf(Zeroconf): return str(addresses[0]) return None + +class AsyncEsphomeZeroconf(AsyncZeroconf): async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None: """Resolve a host name to an IP address.""" - info = self._make_host_resolver(host) + info = _make_host_resolver(host) if ( - info.load_from_cache(self) - or (timeout and await info.async_request(self, timeout * 1000)) + info.load_from_cache(self.zeroconf) + or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): return str(addresses[0]) return None