mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	use new tornado start methods
This commit is contained in:
		| @@ -1,5 +1,6 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| import base64 | ||||
| import binascii | ||||
| import codecs | ||||
| @@ -47,10 +48,10 @@ from esphome.storage_json import ( | ||||
| from esphome.util import get_serial_ports, shlex_quote | ||||
| from esphome.zeroconf import ( | ||||
|     ESPHOME_SERVICE_TYPE, | ||||
|     AsyncEsphomeZeroconf, | ||||
|     DashboardBrowser, | ||||
|     DashboardImportDiscovery, | ||||
|     DashboardStatus, | ||||
|     EsphomeZeroconf, | ||||
| ) | ||||
|  | ||||
| from .util import friendly_name_slugify, password_hash | ||||
| @@ -975,13 +976,13 @@ class BoardsRequestHandler(BaseHandler): | ||||
|         self.write(json.dumps(output)) | ||||
|  | ||||
|  | ||||
| class MDNSStatusThread(threading.Thread): | ||||
|     """Thread that updates the mdns status.""" | ||||
| class MDNSStatus: | ||||
|     """Class that updates the mdns status.""" | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         """Initialize the MDNSStatusThread.""" | ||||
|         """Initialize the MDNSStatus class.""" | ||||
|         super().__init__() | ||||
|         self.zeroconf: EsphomeZeroconf | None = None | ||||
|         self.aiozc: AsyncEsphomeZeroconf | None = None | ||||
|         # This is the current mdns state for each host (True, False, None) | ||||
|         self.host_mdns_state: dict[str, bool | None] = {} | ||||
|         # This is the hostnames to filenames mapping | ||||
| @@ -997,10 +998,10 @@ class MDNSStatusThread(threading.Thread): | ||||
|  | ||||
|     async def async_resolve_host(self, host_name: str) -> str | None: | ||||
|         """Resolve a host name to an address in a thread-safe manner.""" | ||||
|         if zc := self.zeroconf: | ||||
|         if aiozc := self.aiozc: | ||||
|             # Currently we do not do any I/O and only | ||||
|             # return the cached result (timeout=0) | ||||
|             return await zc.async_resolve_host(host_name) | ||||
|             return await aiozc.async_resolve_host(host_name) | ||||
|         return None | ||||
|  | ||||
|     def _refresh_hosts(self): | ||||
| @@ -1034,11 +1035,11 @@ class MDNSStatusThread(threading.Thread): | ||||
|             host_name_to_filename[name] = filename | ||||
|             filename_to_host_name[filename] = name | ||||
|  | ||||
|     def run(self): | ||||
|     async def async_run(self) -> None: | ||||
|         global IMPORT_RESULT | ||||
|  | ||||
|         zc = EsphomeZeroconf() | ||||
|         self.zeroconf = zc | ||||
|         aiozc = AsyncEsphomeZeroconf() | ||||
|         self.aiozc = aiozc | ||||
|         host_mdns_state = self.host_mdns_state | ||||
|         host_name_to_filename = self.host_name_to_filename | ||||
|         host_name_with_mdns_enabled = self.host_name_with_mdns_enabled | ||||
| @@ -1055,18 +1056,20 @@ class MDNSStatusThread(threading.Thread): | ||||
|         stat = DashboardStatus(on_update) | ||||
|         imports = DashboardImportDiscovery() | ||||
|         browser = DashboardBrowser( | ||||
|             zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback] | ||||
|             aiozc.zeroconf, | ||||
|             ESPHOME_SERVICE_TYPE, | ||||
|             [stat.browser_callback, imports.browser_callback], | ||||
|         ) | ||||
|  | ||||
|         while not STOP_EVENT.is_set(): | ||||
|             self._refresh_hosts() | ||||
|             IMPORT_RESULT = imports.import_state | ||||
|             PING_REQUEST.wait() | ||||
|             PING_REQUEST.clear() | ||||
|             await PING_REQUEST.async_wait() | ||||
|             PING_REQUEST.async_clear() | ||||
|  | ||||
|         browser.cancel() | ||||
|         zc.close() | ||||
|         self.zeroconf = None | ||||
|         await browser.async_cancel() | ||||
|         await aiozc.async_close() | ||||
|         self.aiozc = None | ||||
|  | ||||
|  | ||||
| class PingStatusThread(threading.Thread): | ||||
| @@ -1246,22 +1249,69 @@ class UndoDeleteRequestHandler(BaseHandler): | ||||
| class MDNSContainer: | ||||
|     def __init__(self) -> None: | ||||
|         """Initialize the MDNSContainer.""" | ||||
|         self._mdns: MDNSStatusThread | None = None | ||||
|         self._mdns: MDNSStatus | None = None | ||||
|  | ||||
|     def set_mdns(self, mdns: MDNSStatusThread) -> None: | ||||
|         """Set the MDNSStatusThread instance.""" | ||||
|     def set_mdns(self, mdns: MDNSStatus) -> None: | ||||
|         """Set the MDNSStatus instance.""" | ||||
|         self._mdns = mdns | ||||
|  | ||||
|     def get_mdns(self) -> MDNSStatusThread | None: | ||||
|         """Return the MDNSStatusThread instance.""" | ||||
|     def get_mdns(self) -> MDNSStatus | None: | ||||
|         """Return the MDNSStatus instance.""" | ||||
|         return self._mdns | ||||
|  | ||||
|  | ||||
| class ThreadedAsyncEvent: | ||||
|     def __init__(self) -> None: | ||||
|         """Initialize the ThreadedAsyncEvent.""" | ||||
|         self.event = threading.Event() | ||||
|         self.async_event: asyncio.Event | None = None | ||||
|         self.loop: asyncio.AbstractEventLoop | None = None | ||||
|  | ||||
|     def async_setup( | ||||
|         self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event | ||||
|     ) -> None: | ||||
|         """Set the asyncio.Event instance.""" | ||||
|         self.loop = loop | ||||
|         self.async_event = async_event | ||||
|  | ||||
|     def async_set(self) -> None: | ||||
|         """Set the asyncio.Event instance.""" | ||||
|         self.async_event.set() | ||||
|         self.event.set() | ||||
|  | ||||
|     def set(self) -> None: | ||||
|         """Set the event.""" | ||||
|         self.loop.call_soon_threadsafe(self.async_event.set) | ||||
|         self.event.set() | ||||
|  | ||||
|     def wait(self) -> None: | ||||
|         """Wait for the event.""" | ||||
|         self.event.wait() | ||||
|  | ||||
|     async def async_wait(self) -> None: | ||||
|         """Wait the event async.""" | ||||
|         await self.async_event.wait() | ||||
|  | ||||
|     def clear(self) -> None: | ||||
|         """Clear the event.""" | ||||
|         self.loop.call_soon_threadsafe(self.async_event.clear) | ||||
|         self.event.clear() | ||||
|  | ||||
|     def async_clear(self) -> None: | ||||
|         """Clear the event async.""" | ||||
|         self.async_event.clear() | ||||
|         self.event.clear() | ||||
|  | ||||
|     def is_set(self) -> bool: | ||||
|         """Return if the event is set.""" | ||||
|         return self.event.is_set() | ||||
|  | ||||
|  | ||||
| PING_RESULT: dict = {} | ||||
| IMPORT_RESULT = {} | ||||
| STOP_EVENT = threading.Event() | ||||
| PING_REQUEST = threading.Event() | ||||
| MQTT_PING_REQUEST = threading.Event() | ||||
| STOP_EVENT = ThreadedAsyncEvent() | ||||
| PING_REQUEST = ThreadedAsyncEvent() | ||||
| MQTT_PING_REQUEST = ThreadedAsyncEvent() | ||||
| MDNS_CONTAINER = MDNSContainer() | ||||
|  | ||||
|  | ||||
| @@ -1525,6 +1575,18 @@ def start_web_server(args): | ||||
|             storage.save(path) | ||||
|         settings.cookie_secret = storage.cookie_secret | ||||
|  | ||||
|     try: | ||||
|         asyncio.run(async_start_web_server(args)) | ||||
|     except KeyboardInterrupt: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| async def async_start_web_server(args): | ||||
|     loop = asyncio.get_event_loop() | ||||
|     PING_REQUEST.async_setup(loop, asyncio.Event()) | ||||
|     MQTT_PING_REQUEST.async_setup(loop, asyncio.Event()) | ||||
|     STOP_EVENT.async_setup(loop, asyncio.Event()) | ||||
|  | ||||
|     app = make_app(args.verbose) | ||||
|     if args.socket is not None: | ||||
|         _LOGGER.info( | ||||
| @@ -1549,27 +1611,36 @@ def start_web_server(args): | ||||
|  | ||||
|             webbrowser.open(f"http://{args.address}:{args.port}") | ||||
|  | ||||
|     mdns_task: asyncio.Task | None = None | ||||
|     ping_status_thread: PingStatusThread | None = None | ||||
|     if settings.status_use_ping: | ||||
|         status_thread = PingStatusThread() | ||||
|         ping_status_thread = PingStatusThread() | ||||
|         ping_status_thread.start() | ||||
|     else: | ||||
|         status_thread = MDNSStatusThread() | ||||
|         MDNS_CONTAINER.set_mdns(status_thread) | ||||
|     status_thread.start() | ||||
|         mdns_status = MDNSStatus() | ||||
|         MDNS_CONTAINER.set_mdns(mdns_status) | ||||
|         mdns_task = asyncio.create_task(mdns_status.async_run()) | ||||
|  | ||||
|     if settings.status_use_mqtt: | ||||
|         status_thread_mqtt = MqttStatusThread() | ||||
|         status_thread_mqtt.start() | ||||
|  | ||||
|     shutdown_event = asyncio.Event() | ||||
|     try: | ||||
|         tornado.ioloop.IOLoop.current().start() | ||||
|         await shutdown_event.wait() | ||||
|     except KeyboardInterrupt: | ||||
|         raise | ||||
|     finally: | ||||
|         _LOGGER.info("Shutting down...") | ||||
|         STOP_EVENT.set() | ||||
|         PING_REQUEST.set() | ||||
|         status_thread.join() | ||||
|         if ping_status_thread: | ||||
|             ping_status_thread.join() | ||||
|         MDNS_CONTAINER.set_mdns(None) | ||||
|         mdns_task.cancel() | ||||
|         if settings.status_use_mqtt: | ||||
|             status_thread_mqtt.join() | ||||
|             MQTT_PING_REQUEST.set() | ||||
|         if args.socket is not None: | ||||
|             os.remove(args.socket) | ||||
|         await asyncio.sleep(0) | ||||
|   | ||||
| @@ -1,22 +1,21 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| import logging | ||||
| from dataclasses import dataclass | ||||
| from typing import Callable | ||||
|  | ||||
| from zeroconf import ( | ||||
|     IPVersion, | ||||
|     ServiceBrowser, | ||||
|     ServiceInfo, | ||||
|     ServiceStateChange, | ||||
|     Zeroconf, | ||||
| ) | ||||
| from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf | ||||
| from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf | ||||
|  | ||||
| from esphome.storage_json import StorageJSON, ext_storage_path | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| _BACKGROUND_TASKS: set[asyncio.Task] = set() | ||||
|  | ||||
|  | ||||
| class HostResolver(ServiceInfo): | ||||
|     """Resolve a host name to an IP address.""" | ||||
|  | ||||
| @@ -65,7 +64,7 @@ class DiscoveredImport: | ||||
|     network: str | ||||
|  | ||||
|  | ||||
| class DashboardBrowser(ServiceBrowser): | ||||
| class DashboardBrowser(AsyncServiceBrowser): | ||||
|     """A class to browse for ESPHome nodes.""" | ||||
|  | ||||
|  | ||||
| @@ -94,7 +93,28 @@ class DashboardImportDiscovery: | ||||
|             # Ignore updates for devices that are not in the import state | ||||
|             return | ||||
|  | ||||
|         info = zeroconf.get_service_info(service_type, name) | ||||
|         info = AsyncServiceInfo( | ||||
|             service_type, | ||||
|             name, | ||||
|         ) | ||||
|         if info.load_from_cache(zeroconf): | ||||
|             self._process_service_info(name, info) | ||||
|             return | ||||
|         task = asyncio.create_task( | ||||
|             self._async_process_service_info(zeroconf, info, service_type, name) | ||||
|         ) | ||||
|         _BACKGROUND_TASKS.add(task) | ||||
|         task.add_done_callback(_BACKGROUND_TASKS.discard) | ||||
|  | ||||
|     async def _async_process_service_info( | ||||
|         self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str | ||||
|     ) -> None: | ||||
|         """Process a service info.""" | ||||
|         if await info.async_request(zeroconf): | ||||
|             self._process_service_info(name, info) | ||||
|  | ||||
|     def _process_service_info(self, name: str, info: ServiceInfo) -> None: | ||||
|         """Process a service info.""" | ||||
|         _LOGGER.debug("-> resolved info: %s", info) | ||||
|         if info is None: | ||||
|             return | ||||
| @@ -146,16 +166,17 @@ class DashboardImportDiscovery: | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| class EsphomeZeroconf(Zeroconf): | ||||
|     def _make_host_resolver(self, host: str) -> HostResolver: | ||||
|         """Create a new HostResolver for the given host name.""" | ||||
|         name = host.partition(".")[0] | ||||
|         info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") | ||||
|         return info | ||||
| def _make_host_resolver(host: str) -> HostResolver: | ||||
|     """Create a new HostResolver for the given host name.""" | ||||
|     name = host.partition(".")[0] | ||||
|     info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}") | ||||
|     return info | ||||
|  | ||||
|  | ||||
| class EsphomeZeroconf(Zeroconf): | ||||
|     def resolve_host(self, host: str, timeout: float = 3.0) -> str | None: | ||||
|         """Resolve a host name to an IP address.""" | ||||
|         info = self._make_host_resolver(host) | ||||
|         info = _make_host_resolver(host) | ||||
|         if ( | ||||
|             info.load_from_cache(self) | ||||
|             or (timeout and info.request(self, timeout * 1000)) | ||||
| @@ -163,12 +184,14 @@ class EsphomeZeroconf(Zeroconf): | ||||
|             return str(addresses[0]) | ||||
|         return None | ||||
|  | ||||
|  | ||||
| class AsyncEsphomeZeroconf(AsyncZeroconf): | ||||
|     async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None: | ||||
|         """Resolve a host name to an IP address.""" | ||||
|         info = self._make_host_resolver(host) | ||||
|         info = _make_host_resolver(host) | ||||
|         if ( | ||||
|             info.load_from_cache(self) | ||||
|             or (timeout and await info.async_request(self, timeout * 1000)) | ||||
|             info.load_from_cache(self.zeroconf) | ||||
|             or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) | ||||
|         ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): | ||||
|             return str(addresses[0]) | ||||
|         return None | ||||
|   | ||||
		Reference in New Issue
	
	Block a user