From 9194c8221341f8eeb933544049c5d931565a86f3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 17:33:13 -0600 Subject: [PATCH 01/11] tornado support native coros --- esphome/dashboard/dashboard.py | 46 ++++++++++++++++++---------------- esphome/zeroconf.py | 19 ++++++++++++-- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 330cc06b3b..b56cddbe92 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -289,7 +289,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._use_popen = os.name == "nt" @authenticated - def on_message(self, message): + async def on_message(self, message): # Messages are always JSON, 500 when not json_message = json.loads(message) type_ = json_message["type"] @@ -299,14 +299,14 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): _LOGGER.warning("Requested unknown message type %s", type_) return - handlers[type_](self, json_message) + await handlers[type_](self, json_message) @websocket_method("spawn") - def handle_spawn(self, json_message): + async def handle_spawn(self, json_message): if self._proc is not None: # spawn can only be called once return - command = self.build_command(json_message) + command = await self.build_command(json_message) _LOGGER.info("Running command '%s'", " ".join(shlex_quote(x) for x in command)) if self._use_popen: @@ -337,7 +337,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): return self._proc is not None and self._proc.returncode is None @websocket_method("stdin") - def handle_stdin(self, json_message): + async def handle_stdin(self, json_message): if not self.is_process_active: return data = json_message["data"] @@ -395,7 +395,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): # Shutdown proc on WS close self._is_closed = True - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: raise NotImplementedError @@ -405,7 +405,9 @@ DASHBOARD_COMMAND = ["esphome", "--dashboard"] class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): """Base class for commands that require a port.""" - def run_command(self, args: list[str], json_message: dict[str, Any]) -> list[str]: + async def run_command( + self, args: list[str], json_message: dict[str, Any] + ) -> list[str]: """Build the command to run.""" configuration = json_message["configuration"] config_file = settings.rel_path(configuration) @@ -414,7 +416,7 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): port == "OTA" and (mdns := MDNS_CONTAINER.get_mdns()) and (host_name := mdns.filename_to_host_name_thread_safe(configuration)) - and (address := mdns.resolve_host_thread_safe(host_name)) + and (address := await mdns.async_resolve_host(host_name)) ): port = address @@ -428,15 +430,15 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): class EsphomeLogsHandler(EsphomePortCommandWebSocket): - def build_command(self, json_message: dict[str, Any]) -> list[str]: + async def build_command(self, json_message: dict[str, Any]) -> list[str]: """Build the command to run.""" - return self.run_command(["logs"], json_message) + return await self.run_command(["logs"], json_message) class EsphomeRenameHandler(EsphomeCommandWebSocket): old_name: str - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) self.old_name = json_message["configuration"] return [ @@ -457,19 +459,19 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket): class EsphomeUploadHandler(EsphomePortCommandWebSocket): - def build_command(self, json_message: dict[str, Any]) -> list[str]: + async def build_command(self, json_message: dict[str, Any]) -> list[str]: """Build the command to run.""" return self.run_command(["upload"], json_message) class EsphomeRunHandler(EsphomePortCommandWebSocket): - def build_command(self, json_message: dict[str, Any]) -> list[str]: + async def build_command(self, json_message: dict[str, Any]) -> list[str]: """Build the command to run.""" return self.run_command(["run"], json_message) class EsphomeCompileHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) command = [*DASHBOARD_COMMAND, "compile"] if json_message.get("only_generate", False): @@ -479,7 +481,7 @@ class EsphomeCompileHandler(EsphomeCommandWebSocket): class EsphomeValidateHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) command = [*DASHBOARD_COMMAND, "config", config_file] if not settings.streamer_mode: @@ -488,29 +490,29 @@ class EsphomeValidateHandler(EsphomeCommandWebSocket): class EsphomeCleanMqttHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) return [*DASHBOARD_COMMAND, "clean-mqtt", config_file] class EsphomeCleanHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) return [*DASHBOARD_COMMAND, "clean", config_file] class EsphomeVscodeHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: return [*DASHBOARD_COMMAND, "-q", "vscode", "dummy"] class EsphomeAceEditorHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: return [*DASHBOARD_COMMAND, "-q", "vscode", "--ace", settings.config_dir] class EsphomeUpdateAllHandler(EsphomeCommandWebSocket): - def build_command(self, json_message): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: return [*DASHBOARD_COMMAND, "update-all", settings.config_dir] @@ -990,12 +992,12 @@ class MDNSStatusThread(threading.Thread): """Resolve a filename to an address in a thread-safe manner.""" return self.filename_to_host_name.get(filename) - def resolve_host_thread_safe(self, host_name: str) -> str | None: + async def async_resolve_host(self, host_name: str) -> str | None: """Resolve a host name to an address in a thread-safe manner.""" if zc := self.zeroconf: # Currently we do not do any I/O and only # return the cached result (timeout=0) - return zc.resolve_host(host_name, 0) + return await zc.async_resolve_host(host_name) return None def _refresh_hosts(self): diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index f4cb7f080b..5e4fc70d00 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -147,13 +147,28 @@ class DashboardImportDiscovery: class EsphomeZeroconf(Zeroconf): - def resolve_host(self, host: str, timeout: float = 3.0) -> str | None: - """Resolve a host name to an IP address.""" + def _make_host_resolver(self, host: str) -> HostResolver: + """Create a new HostResolver for the given host name.""" name = host.partition(".")[0] info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") + return info + + def resolve_host(self, host: str, timeout: float = 3.0) -> str | None: + """Resolve a host name to an IP address.""" + info = self._make_host_resolver(host) if ( info.load_from_cache(self) or (timeout and info.request(self, timeout * 1000)) ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): return str(addresses[0]) return None + + async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None: + """Resolve a host name to an IP address.""" + info = self._make_host_resolver(host) + if ( + info.load_from_cache(self) + or (timeout and await info.async_request(self, timeout * 1000)) + ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): + return str(addresses[0]) + return None From 0c06abd960475c30b7ac4e7b90f8cf322365ee41 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 17:39:09 -0600 Subject: [PATCH 02/11] lint --- esphome/dashboard/dashboard.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index b56cddbe92..dbd0a6629d 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -289,7 +289,10 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._use_popen = os.name == "nt" @authenticated - async def on_message(self, message): + async def on_message( # pylint: disable=invalid-overridden-method + self, message: str + ) -> None: + # Since tornado 4.5, on_message is allowed to be a coroutine # Messages are always JSON, 500 when not json_message = json.loads(message) type_ = json_message["type"] From 5d8253ddf9cde203db3e7d6c78b161fbd69084c9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 18:19:44 -0600 Subject: [PATCH 03/11] use new tornado start methods --- esphome/dashboard/dashboard.py | 131 +++++++++++++++++++++++++-------- esphome/zeroconf.py | 61 ++++++++++----- 2 files changed, 143 insertions(+), 49 deletions(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index dbd0a6629d..9b39aca973 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import base64 import binascii import codecs @@ -47,10 +48,10 @@ from esphome.storage_json import ( from esphome.util import get_serial_ports, shlex_quote from esphome.zeroconf import ( ESPHOME_SERVICE_TYPE, + AsyncEsphomeZeroconf, DashboardBrowser, DashboardImportDiscovery, DashboardStatus, - EsphomeZeroconf, ) from .util import friendly_name_slugify, password_hash @@ -975,13 +976,13 @@ class BoardsRequestHandler(BaseHandler): self.write(json.dumps(output)) -class MDNSStatusThread(threading.Thread): - """Thread that updates the mdns status.""" +class MDNSStatus: + """Class that updates the mdns status.""" def __init__(self) -> None: - """Initialize the MDNSStatusThread.""" + """Initialize the MDNSStatus class.""" super().__init__() - self.zeroconf: EsphomeZeroconf | None = None + self.aiozc: AsyncEsphomeZeroconf | None = None # This is the current mdns state for each host (True, False, None) self.host_mdns_state: dict[str, bool | None] = {} # This is the hostnames to filenames mapping @@ -997,10 +998,10 @@ class MDNSStatusThread(threading.Thread): async def async_resolve_host(self, host_name: str) -> str | None: """Resolve a host name to an address in a thread-safe manner.""" - if zc := self.zeroconf: + if aiozc := self.aiozc: # Currently we do not do any I/O and only # return the cached result (timeout=0) - return await zc.async_resolve_host(host_name) + return await aiozc.async_resolve_host(host_name) return None def _refresh_hosts(self): @@ -1034,11 +1035,11 @@ class MDNSStatusThread(threading.Thread): host_name_to_filename[name] = filename filename_to_host_name[filename] = name - def run(self): + async def async_run(self) -> None: global IMPORT_RESULT - zc = EsphomeZeroconf() - self.zeroconf = zc + aiozc = AsyncEsphomeZeroconf() + self.aiozc = aiozc host_mdns_state = self.host_mdns_state host_name_to_filename = self.host_name_to_filename host_name_with_mdns_enabled = self.host_name_with_mdns_enabled @@ -1055,18 +1056,20 @@ class MDNSStatusThread(threading.Thread): stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() browser = DashboardBrowser( - zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback] + aiozc.zeroconf, + ESPHOME_SERVICE_TYPE, + [stat.browser_callback, imports.browser_callback], ) while not STOP_EVENT.is_set(): self._refresh_hosts() IMPORT_RESULT = imports.import_state - PING_REQUEST.wait() - PING_REQUEST.clear() + await PING_REQUEST.async_wait() + PING_REQUEST.async_clear() - browser.cancel() - zc.close() - self.zeroconf = None + await browser.async_cancel() + await aiozc.async_close() + self.aiozc = None class PingStatusThread(threading.Thread): @@ -1246,22 +1249,69 @@ class UndoDeleteRequestHandler(BaseHandler): class MDNSContainer: def __init__(self) -> None: """Initialize the MDNSContainer.""" - self._mdns: MDNSStatusThread | None = None + self._mdns: MDNSStatus | None = None - def set_mdns(self, mdns: MDNSStatusThread) -> None: - """Set the MDNSStatusThread instance.""" + def set_mdns(self, mdns: MDNSStatus) -> None: + """Set the MDNSStatus instance.""" self._mdns = mdns - def get_mdns(self) -> MDNSStatusThread | None: - """Return the MDNSStatusThread instance.""" + def get_mdns(self) -> MDNSStatus | None: + """Return the MDNSStatus instance.""" return self._mdns +class ThreadedAsyncEvent: + def __init__(self) -> None: + """Initialize the ThreadedAsyncEvent.""" + self.event = threading.Event() + self.async_event: asyncio.Event | None = None + self.loop: asyncio.AbstractEventLoop | None = None + + def async_setup( + self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event + ) -> None: + """Set the asyncio.Event instance.""" + self.loop = loop + self.async_event = async_event + + def async_set(self) -> None: + """Set the asyncio.Event instance.""" + self.async_event.set() + self.event.set() + + def set(self) -> None: + """Set the event.""" + self.loop.call_soon_threadsafe(self.async_event.set) + self.event.set() + + def wait(self) -> None: + """Wait for the event.""" + self.event.wait() + + async def async_wait(self) -> None: + """Wait the event async.""" + await self.async_event.wait() + + def clear(self) -> None: + """Clear the event.""" + self.loop.call_soon_threadsafe(self.async_event.clear) + self.event.clear() + + def async_clear(self) -> None: + """Clear the event async.""" + self.async_event.clear() + self.event.clear() + + def is_set(self) -> bool: + """Return if the event is set.""" + return self.event.is_set() + + PING_RESULT: dict = {} IMPORT_RESULT = {} -STOP_EVENT = threading.Event() -PING_REQUEST = threading.Event() -MQTT_PING_REQUEST = threading.Event() +STOP_EVENT = ThreadedAsyncEvent() +PING_REQUEST = ThreadedAsyncEvent() +MQTT_PING_REQUEST = ThreadedAsyncEvent() MDNS_CONTAINER = MDNSContainer() @@ -1525,6 +1575,18 @@ def start_web_server(args): storage.save(path) settings.cookie_secret = storage.cookie_secret + try: + asyncio.run(async_start_web_server(args)) + except KeyboardInterrupt: + pass + + +async def async_start_web_server(args): + loop = asyncio.get_event_loop() + PING_REQUEST.async_setup(loop, asyncio.Event()) + MQTT_PING_REQUEST.async_setup(loop, asyncio.Event()) + STOP_EVENT.async_setup(loop, asyncio.Event()) + app = make_app(args.verbose) if args.socket is not None: _LOGGER.info( @@ -1549,27 +1611,36 @@ def start_web_server(args): webbrowser.open(f"http://{args.address}:{args.port}") + mdns_task: asyncio.Task | None = None + ping_status_thread: PingStatusThread | None = None if settings.status_use_ping: - status_thread = PingStatusThread() + ping_status_thread = PingStatusThread() + ping_status_thread.start() else: - status_thread = MDNSStatusThread() - MDNS_CONTAINER.set_mdns(status_thread) - status_thread.start() + mdns_status = MDNSStatus() + MDNS_CONTAINER.set_mdns(mdns_status) + mdns_task = asyncio.create_task(mdns_status.async_run()) if settings.status_use_mqtt: status_thread_mqtt = MqttStatusThread() status_thread_mqtt.start() + shutdown_event = asyncio.Event() try: - tornado.ioloop.IOLoop.current().start() + await shutdown_event.wait() except KeyboardInterrupt: + raise + finally: _LOGGER.info("Shutting down...") STOP_EVENT.set() PING_REQUEST.set() - status_thread.join() + if ping_status_thread: + ping_status_thread.join() MDNS_CONTAINER.set_mdns(None) + mdns_task.cancel() if settings.status_use_mqtt: status_thread_mqtt.join() MQTT_PING_REQUEST.set() if args.socket is not None: os.remove(args.socket) + await asyncio.sleep(0) diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index 5e4fc70d00..956e348e07 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -1,22 +1,21 @@ from __future__ import annotations +import asyncio import logging from dataclasses import dataclass from typing import Callable -from zeroconf import ( - IPVersion, - ServiceBrowser, - ServiceInfo, - ServiceStateChange, - Zeroconf, -) +from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf from esphome.storage_json import StorageJSON, ext_storage_path _LOGGER = logging.getLogger(__name__) +_BACKGROUND_TASKS: set[asyncio.Task] = set() + + class HostResolver(ServiceInfo): """Resolve a host name to an IP address.""" @@ -65,7 +64,7 @@ class DiscoveredImport: network: str -class DashboardBrowser(ServiceBrowser): +class DashboardBrowser(AsyncServiceBrowser): """A class to browse for ESPHome nodes.""" @@ -94,7 +93,28 @@ class DashboardImportDiscovery: # Ignore updates for devices that are not in the import state return - info = zeroconf.get_service_info(service_type, name) + info = AsyncServiceInfo( + service_type, + name, + ) + if info.load_from_cache(zeroconf): + self._process_service_info(name, info) + return + task = asyncio.create_task( + self._async_process_service_info(zeroconf, info, service_type, name) + ) + _BACKGROUND_TASKS.add(task) + task.add_done_callback(_BACKGROUND_TASKS.discard) + + async def _async_process_service_info( + self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str + ) -> None: + """Process a service info.""" + if await info.async_request(zeroconf): + self._process_service_info(name, info) + + def _process_service_info(self, name: str, info: ServiceInfo) -> None: + """Process a service info.""" _LOGGER.debug("-> resolved info: %s", info) if info is None: return @@ -146,16 +166,17 @@ class DashboardImportDiscovery: ) -class EsphomeZeroconf(Zeroconf): - def _make_host_resolver(self, host: str) -> HostResolver: - """Create a new HostResolver for the given host name.""" - name = host.partition(".")[0] - info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") - return info +def _make_host_resolver(host: str) -> HostResolver: + """Create a new HostResolver for the given host name.""" + name = host.partition(".")[0] + info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") + return info + +class EsphomeZeroconf(Zeroconf): def resolve_host(self, host: str, timeout: float = 3.0) -> str | None: """Resolve a host name to an IP address.""" - info = self._make_host_resolver(host) + info = _make_host_resolver(host) if ( info.load_from_cache(self) or (timeout and info.request(self, timeout * 1000)) @@ -163,12 +184,14 @@ class EsphomeZeroconf(Zeroconf): return str(addresses[0]) return None + +class AsyncEsphomeZeroconf(AsyncZeroconf): async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None: """Resolve a host name to an IP address.""" - info = self._make_host_resolver(host) + info = _make_host_resolver(host) if ( - info.load_from_cache(self) - or (timeout and await info.async_request(self, timeout * 1000)) + info.load_from_cache(self.zeroconf) + or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): return str(addresses[0]) return None From 8ad691cbde09ee15af2b34ad69c53109d5e73479 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 18:21:08 -0600 Subject: [PATCH 04/11] use new tornado start methods --- esphome/dashboard/dashboard.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 9b39aca973..608202f85f 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1261,6 +1261,11 @@ class MDNSContainer: class ThreadedAsyncEvent: + """This is a shim to allow the asyncio event to be used in a threaded context. + + When more of the code is moved to asyncio, this can be removed. + """ + def __init__(self) -> None: """Initialize the ThreadedAsyncEvent.""" self.event = threading.Event() From e388fa5a35ae07ca12a76ce6c96a1fde339f918c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 18:22:14 -0600 Subject: [PATCH 05/11] use new tornado start methods --- esphome/dashboard/dashboard.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 608202f85f..358af5b753 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1314,9 +1314,9 @@ class ThreadedAsyncEvent: PING_RESULT: dict = {} IMPORT_RESULT = {} -STOP_EVENT = ThreadedAsyncEvent() +STOP_EVENT = threading.Event() PING_REQUEST = ThreadedAsyncEvent() -MQTT_PING_REQUEST = ThreadedAsyncEvent() +MQTT_PING_REQUEST = threading.Event() MDNS_CONTAINER = MDNSContainer() @@ -1589,8 +1589,6 @@ def start_web_server(args): async def async_start_web_server(args): loop = asyncio.get_event_loop() PING_REQUEST.async_setup(loop, asyncio.Event()) - MQTT_PING_REQUEST.async_setup(loop, asyncio.Event()) - STOP_EVENT.async_setup(loop, asyncio.Event()) app = make_app(args.verbose) if args.socket is not None: From 4d984ce4129d368fbae798454e9669347256bb3b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 18:26:01 -0600 Subject: [PATCH 06/11] break --- esphome/dashboard/async_adapter.py | 56 ++++++++++++++++++++++++++++++ esphome/dashboard/dashboard.py | 53 +--------------------------- 2 files changed, 57 insertions(+), 52 deletions(-) create mode 100644 esphome/dashboard/async_adapter.py diff --git a/esphome/dashboard/async_adapter.py b/esphome/dashboard/async_adapter.py new file mode 100644 index 0000000000..d6f4f6e1ff --- /dev/null +++ b/esphome/dashboard/async_adapter.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import asyncio +import threading + + +class ThreadedAsyncEvent: + """This is a shim to allow the asyncio event to be used in a threaded context. + + When more of the code is moved to asyncio, this can be removed. + """ + + def __init__(self) -> None: + """Initialize the ThreadedAsyncEvent.""" + self.event = threading.Event() + self.async_event: asyncio.Event | None = None + self.loop: asyncio.AbstractEventLoop | None = None + + def async_setup( + self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event + ) -> None: + """Set the asyncio.Event instance.""" + self.loop = loop + self.async_event = async_event + + def async_set(self) -> None: + """Set the asyncio.Event instance.""" + self.async_event.set() + self.event.set() + + def set(self) -> None: + """Set the event.""" + self.loop.call_soon_threadsafe(self.async_event.set) + self.event.set() + + def wait(self) -> None: + """Wait for the event.""" + self.event.wait() + + async def async_wait(self) -> None: + """Wait the event async.""" + await self.async_event.wait() + + def clear(self) -> None: + """Clear the event.""" + self.loop.call_soon_threadsafe(self.async_event.clear) + self.event.clear() + + def async_clear(self) -> None: + """Clear the event async.""" + self.async_event.clear() + self.event.clear() + + def is_set(self) -> bool: + """Return if the event is set.""" + return self.event.is_set() diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 358af5b753..43870d0f34 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -53,6 +53,7 @@ from esphome.zeroconf import ( DashboardImportDiscovery, DashboardStatus, ) +from .async_adapter import ThreadedAsyncEvent from .util import friendly_name_slugify, password_hash @@ -1260,58 +1261,6 @@ class MDNSContainer: return self._mdns -class ThreadedAsyncEvent: - """This is a shim to allow the asyncio event to be used in a threaded context. - - When more of the code is moved to asyncio, this can be removed. - """ - - def __init__(self) -> None: - """Initialize the ThreadedAsyncEvent.""" - self.event = threading.Event() - self.async_event: asyncio.Event | None = None - self.loop: asyncio.AbstractEventLoop | None = None - - def async_setup( - self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event - ) -> None: - """Set the asyncio.Event instance.""" - self.loop = loop - self.async_event = async_event - - def async_set(self) -> None: - """Set the asyncio.Event instance.""" - self.async_event.set() - self.event.set() - - def set(self) -> None: - """Set the event.""" - self.loop.call_soon_threadsafe(self.async_event.set) - self.event.set() - - def wait(self) -> None: - """Wait for the event.""" - self.event.wait() - - async def async_wait(self) -> None: - """Wait the event async.""" - await self.async_event.wait() - - def clear(self) -> None: - """Clear the event.""" - self.loop.call_soon_threadsafe(self.async_event.clear) - self.event.clear() - - def async_clear(self) -> None: - """Clear the event async.""" - self.async_event.clear() - self.event.clear() - - def is_set(self) -> bool: - """Return if the event is set.""" - return self.event.is_set() - - PING_RESULT: dict = {} IMPORT_RESULT = {} STOP_EVENT = threading.Event() From 250d82e0c88f3488fa6df4f1573efd9207caefd1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 18:29:39 -0600 Subject: [PATCH 07/11] lint --- esphome/dashboard/dashboard.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 43870d0f34..5fb0a06bf3 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1580,8 +1580,6 @@ async def async_start_web_server(args): shutdown_event = asyncio.Event() try: await shutdown_event.wait() - except KeyboardInterrupt: - raise finally: _LOGGER.info("Shutting down...") STOP_EVENT.set() From b047038adbf9f455312095fb10092573861c2478 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 18:33:28 -0600 Subject: [PATCH 08/11] lint --- esphome/dashboard/dashboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 5fb0a06bf3..d5bcee4156 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -307,7 +307,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): await handlers[type_](self, json_message) @websocket_method("spawn") - async def handle_spawn(self, json_message): + async def handle_spawn(self, json_message: dict[str, Any]) -> None: if self._proc is not None: # spawn can only be called once return From 26d0f0ebeac04e55471c6f40e6a508e04f718543 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 18:38:48 -0600 Subject: [PATCH 09/11] typing, missing awaits --- esphome/dashboard/dashboard.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index d5bcee4156..d9239da229 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -342,7 +342,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): return self._proc is not None and self._proc.returncode is None @websocket_method("stdin") - async def handle_stdin(self, json_message): + async def handle_stdin(self, json_message: dict[str, Any]) -> None: if not self.is_process_active: return data = json_message["data"] @@ -351,7 +351,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._proc.stdin.write(data) @tornado.gen.coroutine - def _redirect_stdout(self): + def _redirect_stdout(self) -> None: reg = b"[\n\r]" while True: @@ -370,7 +370,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): _LOGGER.debug("> stdout: %s", data) self.write_message({"event": "line", "data": data}) - def _stdout_thread(self): + def _stdout_thread(self) -> None: if not self._use_popen: return while True: @@ -383,13 +383,13 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): self._proc.wait(1.0) self._queue.put_nowait(None) - def _proc_on_exit(self, returncode): + def _proc_on_exit(self, returncode: int) -> None: if not self._is_closed: # Check if the proc was not forcibly closed _LOGGER.info("Process exited with return code %s", returncode) self.write_message({"event": "exit", "code": returncode}) - def on_close(self): + def on_close(self) -> None: # Check if proc exists (if 'start' has been run) if self.is_process_active: _LOGGER.debug("Terminating process") @@ -466,13 +466,13 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket): class EsphomeUploadHandler(EsphomePortCommandWebSocket): async def build_command(self, json_message: dict[str, Any]) -> list[str]: """Build the command to run.""" - return self.run_command(["upload"], json_message) + return await self.run_command(["upload"], json_message) class EsphomeRunHandler(EsphomePortCommandWebSocket): async def build_command(self, json_message: dict[str, Any]) -> list[str]: """Build the command to run.""" - return self.run_command(["run"], json_message) + return await self.run_command(["run"], json_message) class EsphomeCompileHandler(EsphomeCommandWebSocket): From 53f3385c496f22dce6a2bcfbcad990cc217b0efa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 19:36:56 -0600 Subject: [PATCH 10/11] Migrate to using aioesphomeapi for the log runner to fix multiple issues (#5733) --- esphome/components/api/client.py | 76 +++++++++++++++----------------- 1 file changed, 35 insertions(+), 41 deletions(-) diff --git a/esphome/components/api/client.py b/esphome/components/api/client.py index 819055ccf4..2c43eca70c 100644 --- a/esphome/components/api/client.py +++ b/esphome/components/api/client.py @@ -1,71 +1,65 @@ +from __future__ import annotations + import asyncio import logging from datetime import datetime -from typing import Optional +from typing import Any -from aioesphomeapi import APIClient, ReconnectLogic, APIConnectionError, LogLevel -import zeroconf +from aioesphomeapi import APIClient +from aioesphomeapi.api_pb2 import SubscribeLogsResponse +from aioesphomeapi.log_runner import async_run +from zeroconf.asyncio import AsyncZeroconf + +from esphome.const import CONF_KEY, CONF_PASSWORD, CONF_PORT, __version__ +from esphome.core import CORE -from esphome.const import CONF_KEY, CONF_PORT, CONF_PASSWORD, __version__ -from esphome.util import safe_print from . import CONF_ENCRYPTION _LOGGER = logging.getLogger(__name__) async def async_run_logs(config, address): + """Run the logs command in the event loop.""" conf = config["api"] port: int = int(conf[CONF_PORT]) password: str = conf[CONF_PASSWORD] - noise_psk: Optional[str] = None + noise_psk: str | None = None if CONF_ENCRYPTION in conf: noise_psk = conf[CONF_ENCRYPTION][CONF_KEY] _LOGGER.info("Starting log output from %s using esphome API", address) + aiozc = AsyncZeroconf() + cli = APIClient( address, port, password, client_info=f"ESPHome Logs {__version__}", noise_psk=noise_psk, + zeroconf_instance=aiozc.zeroconf, ) - first_connect = True + dashboard = CORE.dashboard - def on_log(msg): - time_ = datetime.now().time().strftime("[%H:%M:%S]") - text = msg.message.decode("utf8", "backslashreplace") - safe_print(time_ + text) - - async def on_connect(): - nonlocal first_connect - try: - await cli.subscribe_logs( - on_log, - log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE, - dump_config=first_connect, - ) - first_connect = False - except APIConnectionError: - cli.disconnect() - - async def on_disconnect(expected_disconnect: bool) -> None: - _LOGGER.warning("Disconnected from API") - - zc = zeroconf.Zeroconf() - reconnect = ReconnectLogic( - client=cli, - on_connect=on_connect, - on_disconnect=on_disconnect, - zeroconf_instance=zc, - ) - await reconnect.start() + def on_log(msg: SubscribeLogsResponse) -> None: + """Handle a new log message.""" + time_ = datetime.now() + message: bytes = msg.message + text = message.decode("utf8", "backslashreplace") + if dashboard: + text = text.replace("\033", "\\033") + print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}") + stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc) try: while True: await asyncio.sleep(60) + finally: + await aiozc.async_close() + await stop() + + +def run_logs(config: dict[str, Any], address: str) -> None: + """Run the logs command.""" + try: + asyncio.run(async_run_logs(config, address)) except KeyboardInterrupt: - await reconnect.stop() - zc.close() - - -def run_logs(config, address): - asyncio.run(async_run_logs(config, address)) + pass From 1bd90e4f515897fd89fd0eab01c77efe6ff4731d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 22:50:41 -0600 Subject: [PATCH 11/11] io in executor --- esphome/dashboard/dashboard.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index d9239da229..d0e9a82523 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -991,7 +991,7 @@ class MDNSStatus: self.filename_to_host_name: dict[str, str] = {} # This is a set of host names to track (i.e no_mdns = false) self.host_name_with_mdns_enabled: set[set] = set() - self._refresh_hosts() + self._loop = asyncio.get_running_loop() def filename_to_host_name_thread_safe(self, filename: str) -> str | None: """Resolve a filename to an address in a thread-safe manner.""" @@ -1005,9 +1005,9 @@ class MDNSStatus: return await aiozc.async_resolve_host(host_name) return None - def _refresh_hosts(self): + async def async_refresh_hosts(self): """Refresh the hosts to track.""" - entries = _list_dashboard_entries() + entries = await self._loop.run_in_executor(None, _list_dashboard_entries) host_name_with_mdns_enabled = self.host_name_with_mdns_enabled host_mdns_state = self.host_mdns_state host_name_to_filename = self.host_name_to_filename @@ -1053,7 +1053,6 @@ class MDNSStatus: filename = host_name_to_filename[name] PING_RESULT[filename] = result - self._refresh_hosts() stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() browser = DashboardBrowser( @@ -1063,7 +1062,7 @@ class MDNSStatus: ) while not STOP_EVENT.is_set(): - self._refresh_hosts() + await self.async_refresh_hosts() IMPORT_RESULT = imports.import_state await PING_REQUEST.async_wait() PING_REQUEST.async_clear() @@ -1570,6 +1569,7 @@ async def async_start_web_server(args): ping_status_thread.start() else: mdns_status = MDNSStatus() + await mdns_status.async_refresh_hosts() MDNS_CONTAINER.set_mdns(mdns_status) mdns_task = asyncio.create_task(mdns_status.async_run())