mirror of
https://github.com/esphome/esphome.git
synced 2025-02-12 07:58:17 +00:00
use new tornado start methods
This commit is contained in:
parent
0c06abd960
commit
5d8253ddf9
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import codecs
|
||||
@ -47,10 +48,10 @@ from esphome.storage_json import (
|
||||
from esphome.util import get_serial_ports, shlex_quote
|
||||
from esphome.zeroconf import (
|
||||
ESPHOME_SERVICE_TYPE,
|
||||
AsyncEsphomeZeroconf,
|
||||
DashboardBrowser,
|
||||
DashboardImportDiscovery,
|
||||
DashboardStatus,
|
||||
EsphomeZeroconf,
|
||||
)
|
||||
|
||||
from .util import friendly_name_slugify, password_hash
|
||||
@ -975,13 +976,13 @@ class BoardsRequestHandler(BaseHandler):
|
||||
self.write(json.dumps(output))
|
||||
|
||||
|
||||
class MDNSStatusThread(threading.Thread):
|
||||
"""Thread that updates the mdns status."""
|
||||
class MDNSStatus:
|
||||
"""Class that updates the mdns status."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the MDNSStatusThread."""
|
||||
"""Initialize the MDNSStatus class."""
|
||||
super().__init__()
|
||||
self.zeroconf: EsphomeZeroconf | None = None
|
||||
self.aiozc: AsyncEsphomeZeroconf | None = None
|
||||
# This is the current mdns state for each host (True, False, None)
|
||||
self.host_mdns_state: dict[str, bool | None] = {}
|
||||
# This is the hostnames to filenames mapping
|
||||
@ -997,10 +998,10 @@ class MDNSStatusThread(threading.Thread):
|
||||
|
||||
async def async_resolve_host(self, host_name: str) -> str | None:
|
||||
"""Resolve a host name to an address in a thread-safe manner."""
|
||||
if zc := self.zeroconf:
|
||||
if aiozc := self.aiozc:
|
||||
# Currently we do not do any I/O and only
|
||||
# return the cached result (timeout=0)
|
||||
return await zc.async_resolve_host(host_name)
|
||||
return await aiozc.async_resolve_host(host_name)
|
||||
return None
|
||||
|
||||
def _refresh_hosts(self):
|
||||
@ -1034,11 +1035,11 @@ class MDNSStatusThread(threading.Thread):
|
||||
host_name_to_filename[name] = filename
|
||||
filename_to_host_name[filename] = name
|
||||
|
||||
def run(self):
|
||||
async def async_run(self) -> None:
|
||||
global IMPORT_RESULT
|
||||
|
||||
zc = EsphomeZeroconf()
|
||||
self.zeroconf = zc
|
||||
aiozc = AsyncEsphomeZeroconf()
|
||||
self.aiozc = aiozc
|
||||
host_mdns_state = self.host_mdns_state
|
||||
host_name_to_filename = self.host_name_to_filename
|
||||
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
|
||||
@ -1055,18 +1056,20 @@ class MDNSStatusThread(threading.Thread):
|
||||
stat = DashboardStatus(on_update)
|
||||
imports = DashboardImportDiscovery()
|
||||
browser = DashboardBrowser(
|
||||
zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback]
|
||||
aiozc.zeroconf,
|
||||
ESPHOME_SERVICE_TYPE,
|
||||
[stat.browser_callback, imports.browser_callback],
|
||||
)
|
||||
|
||||
while not STOP_EVENT.is_set():
|
||||
self._refresh_hosts()
|
||||
IMPORT_RESULT = imports.import_state
|
||||
PING_REQUEST.wait()
|
||||
PING_REQUEST.clear()
|
||||
await PING_REQUEST.async_wait()
|
||||
PING_REQUEST.async_clear()
|
||||
|
||||
browser.cancel()
|
||||
zc.close()
|
||||
self.zeroconf = None
|
||||
await browser.async_cancel()
|
||||
await aiozc.async_close()
|
||||
self.aiozc = None
|
||||
|
||||
|
||||
class PingStatusThread(threading.Thread):
|
||||
@ -1246,22 +1249,69 @@ class UndoDeleteRequestHandler(BaseHandler):
|
||||
class MDNSContainer:
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the MDNSContainer."""
|
||||
self._mdns: MDNSStatusThread | None = None
|
||||
self._mdns: MDNSStatus | None = None
|
||||
|
||||
def set_mdns(self, mdns: MDNSStatusThread) -> None:
|
||||
"""Set the MDNSStatusThread instance."""
|
||||
def set_mdns(self, mdns: MDNSStatus) -> None:
|
||||
"""Set the MDNSStatus instance."""
|
||||
self._mdns = mdns
|
||||
|
||||
def get_mdns(self) -> MDNSStatusThread | None:
|
||||
"""Return the MDNSStatusThread instance."""
|
||||
def get_mdns(self) -> MDNSStatus | None:
|
||||
"""Return the MDNSStatus instance."""
|
||||
return self._mdns
|
||||
|
||||
|
||||
class ThreadedAsyncEvent:
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the ThreadedAsyncEvent."""
|
||||
self.event = threading.Event()
|
||||
self.async_event: asyncio.Event | None = None
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def async_setup(
|
||||
self, loop: asyncio.AbstractEventLoop, async_event: asyncio.Event
|
||||
) -> None:
|
||||
"""Set the asyncio.Event instance."""
|
||||
self.loop = loop
|
||||
self.async_event = async_event
|
||||
|
||||
def async_set(self) -> None:
|
||||
"""Set the asyncio.Event instance."""
|
||||
self.async_event.set()
|
||||
self.event.set()
|
||||
|
||||
def set(self) -> None:
|
||||
"""Set the event."""
|
||||
self.loop.call_soon_threadsafe(self.async_event.set)
|
||||
self.event.set()
|
||||
|
||||
def wait(self) -> None:
|
||||
"""Wait for the event."""
|
||||
self.event.wait()
|
||||
|
||||
async def async_wait(self) -> None:
|
||||
"""Wait the event async."""
|
||||
await self.async_event.wait()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the event."""
|
||||
self.loop.call_soon_threadsafe(self.async_event.clear)
|
||||
self.event.clear()
|
||||
|
||||
def async_clear(self) -> None:
|
||||
"""Clear the event async."""
|
||||
self.async_event.clear()
|
||||
self.event.clear()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
"""Return if the event is set."""
|
||||
return self.event.is_set()
|
||||
|
||||
|
||||
PING_RESULT: dict = {}
|
||||
IMPORT_RESULT = {}
|
||||
STOP_EVENT = threading.Event()
|
||||
PING_REQUEST = threading.Event()
|
||||
MQTT_PING_REQUEST = threading.Event()
|
||||
STOP_EVENT = ThreadedAsyncEvent()
|
||||
PING_REQUEST = ThreadedAsyncEvent()
|
||||
MQTT_PING_REQUEST = ThreadedAsyncEvent()
|
||||
MDNS_CONTAINER = MDNSContainer()
|
||||
|
||||
|
||||
@ -1525,6 +1575,18 @@ def start_web_server(args):
|
||||
storage.save(path)
|
||||
settings.cookie_secret = storage.cookie_secret
|
||||
|
||||
try:
|
||||
asyncio.run(async_start_web_server(args))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
async def async_start_web_server(args):
|
||||
loop = asyncio.get_event_loop()
|
||||
PING_REQUEST.async_setup(loop, asyncio.Event())
|
||||
MQTT_PING_REQUEST.async_setup(loop, asyncio.Event())
|
||||
STOP_EVENT.async_setup(loop, asyncio.Event())
|
||||
|
||||
app = make_app(args.verbose)
|
||||
if args.socket is not None:
|
||||
_LOGGER.info(
|
||||
@ -1549,27 +1611,36 @@ def start_web_server(args):
|
||||
|
||||
webbrowser.open(f"http://{args.address}:{args.port}")
|
||||
|
||||
mdns_task: asyncio.Task | None = None
|
||||
ping_status_thread: PingStatusThread | None = None
|
||||
if settings.status_use_ping:
|
||||
status_thread = PingStatusThread()
|
||||
ping_status_thread = PingStatusThread()
|
||||
ping_status_thread.start()
|
||||
else:
|
||||
status_thread = MDNSStatusThread()
|
||||
MDNS_CONTAINER.set_mdns(status_thread)
|
||||
status_thread.start()
|
||||
mdns_status = MDNSStatus()
|
||||
MDNS_CONTAINER.set_mdns(mdns_status)
|
||||
mdns_task = asyncio.create_task(mdns_status.async_run())
|
||||
|
||||
if settings.status_use_mqtt:
|
||||
status_thread_mqtt = MqttStatusThread()
|
||||
status_thread_mqtt.start()
|
||||
|
||||
shutdown_event = asyncio.Event()
|
||||
try:
|
||||
tornado.ioloop.IOLoop.current().start()
|
||||
await shutdown_event.wait()
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
finally:
|
||||
_LOGGER.info("Shutting down...")
|
||||
STOP_EVENT.set()
|
||||
PING_REQUEST.set()
|
||||
status_thread.join()
|
||||
if ping_status_thread:
|
||||
ping_status_thread.join()
|
||||
MDNS_CONTAINER.set_mdns(None)
|
||||
mdns_task.cancel()
|
||||
if settings.status_use_mqtt:
|
||||
status_thread_mqtt.join()
|
||||
MQTT_PING_REQUEST.set()
|
||||
if args.socket is not None:
|
||||
os.remove(args.socket)
|
||||
await asyncio.sleep(0)
|
||||
|
@ -1,22 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
from zeroconf import (
|
||||
IPVersion,
|
||||
ServiceBrowser,
|
||||
ServiceInfo,
|
||||
ServiceStateChange,
|
||||
Zeroconf,
|
||||
)
|
||||
from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf
|
||||
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
|
||||
|
||||
from esphome.storage_json import StorageJSON, ext_storage_path
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_BACKGROUND_TASKS: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
class HostResolver(ServiceInfo):
|
||||
"""Resolve a host name to an IP address."""
|
||||
|
||||
@ -65,7 +64,7 @@ class DiscoveredImport:
|
||||
network: str
|
||||
|
||||
|
||||
class DashboardBrowser(ServiceBrowser):
|
||||
class DashboardBrowser(AsyncServiceBrowser):
|
||||
"""A class to browse for ESPHome nodes."""
|
||||
|
||||
|
||||
@ -94,7 +93,28 @@ class DashboardImportDiscovery:
|
||||
# Ignore updates for devices that are not in the import state
|
||||
return
|
||||
|
||||
info = zeroconf.get_service_info(service_type, name)
|
||||
info = AsyncServiceInfo(
|
||||
service_type,
|
||||
name,
|
||||
)
|
||||
if info.load_from_cache(zeroconf):
|
||||
self._process_service_info(name, info)
|
||||
return
|
||||
task = asyncio.create_task(
|
||||
self._async_process_service_info(zeroconf, info, service_type, name)
|
||||
)
|
||||
_BACKGROUND_TASKS.add(task)
|
||||
task.add_done_callback(_BACKGROUND_TASKS.discard)
|
||||
|
||||
async def _async_process_service_info(
|
||||
self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str
|
||||
) -> None:
|
||||
"""Process a service info."""
|
||||
if await info.async_request(zeroconf):
|
||||
self._process_service_info(name, info)
|
||||
|
||||
def _process_service_info(self, name: str, info: ServiceInfo) -> None:
|
||||
"""Process a service info."""
|
||||
_LOGGER.debug("-> resolved info: %s", info)
|
||||
if info is None:
|
||||
return
|
||||
@ -146,16 +166,17 @@ class DashboardImportDiscovery:
|
||||
)
|
||||
|
||||
|
||||
class EsphomeZeroconf(Zeroconf):
|
||||
def _make_host_resolver(self, host: str) -> HostResolver:
|
||||
"""Create a new HostResolver for the given host name."""
|
||||
name = host.partition(".")[0]
|
||||
info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}")
|
||||
return info
|
||||
def _make_host_resolver(host: str) -> HostResolver:
|
||||
"""Create a new HostResolver for the given host name."""
|
||||
name = host.partition(".")[0]
|
||||
info = HostResolver(ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}")
|
||||
return info
|
||||
|
||||
|
||||
class EsphomeZeroconf(Zeroconf):
|
||||
def resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
|
||||
"""Resolve a host name to an IP address."""
|
||||
info = self._make_host_resolver(host)
|
||||
info = _make_host_resolver(host)
|
||||
if (
|
||||
info.load_from_cache(self)
|
||||
or (timeout and info.request(self, timeout * 1000))
|
||||
@ -163,12 +184,14 @@ class EsphomeZeroconf(Zeroconf):
|
||||
return str(addresses[0])
|
||||
return None
|
||||
|
||||
|
||||
class AsyncEsphomeZeroconf(AsyncZeroconf):
|
||||
async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
|
||||
"""Resolve a host name to an IP address."""
|
||||
info = self._make_host_resolver(host)
|
||||
info = _make_host_resolver(host)
|
||||
if (
|
||||
info.load_from_cache(self)
|
||||
or (timeout and await info.async_request(self, timeout * 1000))
|
||||
info.load_from_cache(self.zeroconf)
|
||||
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
|
||||
) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)):
|
||||
return str(addresses[0])
|
||||
return None
|
||||
|
Loading…
x
Reference in New Issue
Block a user