From 6dfdcff66caf3f62de6442f7ecb2f194d1232c11 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 8 Jan 2024 15:35:43 -1000 Subject: [PATCH] dashboard: refactor ping implementation to be more efficient (#6002) --- esphome/dashboard/const.py | 2 + esphome/dashboard/core.py | 5 +- esphome/dashboard/dashboard.py | 99 ++++++++++++++++++++++++++++++++ esphome/dashboard/dns.py | 43 ++++++++++++++ esphome/dashboard/settings.py | 15 +++++ esphome/dashboard/status/ping.py | 85 ++++++++++++++++++++++----- esphome/dashboard/web_server.py | 25 ++++++-- requirements.txt | 2 + 8 files changed, 255 insertions(+), 21 deletions(-) create mode 100644 esphome/dashboard/dns.py diff --git a/esphome/dashboard/const.py b/esphome/dashboard/const.py index ed2b81d3e8..190d6c4a9a 100644 --- a/esphome/dashboard/const.py +++ b/esphome/dashboard/const.py @@ -4,5 +4,7 @@ EVENT_ENTRY_ADDED = "entry_added" EVENT_ENTRY_REMOVED = "entry_removed" EVENT_ENTRY_UPDATED = "entry_updated" EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" +MAX_EXECUTOR_WORKERS = 48 + SENTINEL = object() diff --git a/esphome/dashboard/core.py b/esphome/dashboard/core.py index ffec9784e8..e22d95fba9 100644 --- a/esphome/dashboard/core.py +++ b/esphome/dashboard/core.py @@ -8,6 +8,7 @@ from functools import partial from typing import TYPE_CHECKING, Any, Callable from ..zeroconf import DiscoveredImport +from .dns import DNSCache from .entries import DashboardEntries from .settings import DashboardSettings @@ -69,6 +70,7 @@ class ESPHomeDashboard: "mqtt_ping_request", "mdns_status", "settings", + "dns_cache", ) def __init__(self) -> None: @@ -81,7 +83,8 @@ class ESPHomeDashboard: self.ping_request: asyncio.Event | None = None self.mqtt_ping_request = threading.Event() self.mdns_status: MDNSStatus | None = None - self.settings: DashboardSettings = DashboardSettings() + self.settings = DashboardSettings() + self.dns_cache = DNSCache() async def async_setup(self) -> None: """Setup the dashboard.""" diff --git 
a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index 789b14653c..2be98ab3e4 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1,11 +1,19 @@ from __future__ import annotations import asyncio +import logging import os import socket +import threading +import traceback +from asyncio import events +from concurrent.futures import ThreadPoolExecutor +from time import monotonic +from typing import Any from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path +from .const import MAX_EXECUTOR_WORKERS from .core import DASHBOARD from .web_server import make_app, start_web_server @@ -14,6 +22,95 @@ ENV_DEV = "ESPHOME_DASHBOARD_DEV" settings = DASHBOARD.settings +def can_use_pidfd() -> bool: + """Check if pidfd_open is available. + + Back ported from cpython 3.12 + """ + if not hasattr(os, "pidfd_open"): + return False + try: + pid = os.getpid() + os.close(os.pidfd_open(pid, 0)) + except OSError: + # blocked by security policy like SECCOMP + return False + return True + + +class DashboardEventLoopPolicy(asyncio.DefaultEventLoopPolicy): + """Event loop policy for the ESPHome dashboard.""" + + def __init__(self, debug: bool) -> None: + """Init the event loop policy.""" + super().__init__() + self.debug = debug + self._watcher: asyncio.AbstractChildWatcher | None = None + + def _init_watcher(self) -> None: + """Initialize the watcher for child processes.
+ + Back ported from cpython 3.12 + """ + with events._lock: # type: ignore[attr-defined] # pylint: disable=protected-access + if self._watcher is None: # pragma: no branch + if can_use_pidfd(): + self._watcher = asyncio.PidfdChildWatcher() + else: + self._watcher = asyncio.ThreadedChildWatcher() + if threading.current_thread() is threading.main_thread(): + self._watcher.attach_loop( + self._local._loop # type: ignore[attr-defined] # pylint: disable=protected-access + ) + + @property + def loop_name(self) -> str: + """Return name of the loop.""" + return self._loop_factory.__name__ # type: ignore[no-any-return,attr-defined] + + def new_event_loop(self) -> asyncio.AbstractEventLoop: + """Get the event loop.""" + loop: asyncio.AbstractEventLoop = super().new_event_loop() + loop.set_exception_handler(_async_loop_exception_handler) + + if self.debug: + loop.set_debug(True) + + executor = ThreadPoolExecutor( + thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS + ) + loop.set_default_executor(executor) + # bind the built-in time.monotonic directly as loop.time to avoid the + # overhead of the additional method call since it's the most called loop + # method and it's roughly 10%+ of all the call time in base_events.py + loop.time = monotonic # type: ignore[method-assign] + return loop + + +def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None: + """Handle all exceptions inside the core loop.""" + kwargs = {} + if exception := context.get("exception"): + kwargs["exc_info"] = (type(exception), exception, exception.__traceback__) + + logger = logging.getLogger(__package__) + if source_traceback := context.get("source_traceback"): + stack_summary = "".join(traceback.format_list(source_traceback)) + logger.error( + "Error doing job: %s: %s", + context["message"], + stack_summary, + **kwargs, # type: ignore[arg-type] + ) + return + + logger.error( + "Error doing job: %s", + context["message"], + **kwargs, # type: ignore[arg-type] + ) + + def
start_dashboard(args) -> None: """Start the dashboard.""" settings.parse_args(args) @@ -26,6 +123,8 @@ def start_dashboard(args) -> None: storage.save(path) settings.cookie_secret = storage.cookie_secret + asyncio.set_event_loop_policy(DashboardEventLoopPolicy(settings.verbose)) + try: asyncio.run(async_start(args)) except KeyboardInterrupt: diff --git a/esphome/dashboard/dns.py b/esphome/dashboard/dns.py new file mode 100644 index 0000000000..b78a909220 --- /dev/null +++ b/esphome/dashboard/dns.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import asyncio +import sys + +from icmplib import NameLookupError, async_resolve + +if sys.version_info >= (3, 11): + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + + +async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception: + """Wrap the icmplib async_resolve function.""" + try: + async with async_timeout(2): + return await async_resolve(hostname) + except (asyncio.TimeoutError, NameLookupError, UnicodeError) as ex: + return ex + + +class DNSCache: + """DNS cache for the dashboard.""" + + def __init__(self, ttl: int | None = 120) -> None: + """Initialize the DNSCache.""" + self._cache: dict[str, tuple[float, list[str] | Exception]] = {} + self._ttl = ttl + + async def async_resolve( + self, hostname: str, now_monotonic: float + ) -> list[str] | Exception: + """Resolve a hostname to a list of IP addresses.""" + if expire_time_addresses := self._cache.get(hostname): + expire_time, addresses = expire_time_addresses + if expire_time > now_monotonic: + return addresses + + expires = now_monotonic + self._ttl + addresses = await _async_resolve_wrapper(hostname) + self._cache[hostname] = (expires, addresses) + return addresses diff --git a/esphome/dashboard/settings.py b/esphome/dashboard/settings.py index 1a5b1620e8..1f05abab4c 100644 --- a/esphome/dashboard/settings.py +++ b/esphome/dashboard/settings.py @@ -14,7 +14,19 @@ from .util.password
import password_hash class DashboardSettings: """Settings for the dashboard.""" + __slots__ = ( + "config_dir", + "password_hash", + "username", + "using_password", + "on_ha_addon", + "cookie_secret", + "absolute_config_dir", + "verbose", + ) + def __init__(self) -> None: + """Initialize the dashboard settings.""" self.config_dir: str = "" self.password_hash: str = "" self.username: str = "" @@ -22,8 +34,10 @@ class DashboardSettings: self.on_ha_addon: bool = False self.cookie_secret: str | None = None self.absolute_config_dir: Path | None = None + self.verbose: bool = False def parse_args(self, args: Any) -> None: + """Parse the arguments.""" self.on_ha_addon: bool = args.ha_addon password = args.password or os.getenv("PASSWORD") or "" if not self.on_ha_addon: @@ -33,6 +47,7 @@ class DashboardSettings: self.password_hash = password_hash(password) self.config_dir = args.configuration self.absolute_config_dir = Path(self.config_dir).resolve() + self.verbose = args.verbose CORE.config_path = os.path.join(self.config_dir, ".") @property diff --git a/esphome/dashboard/status/ping.py b/esphome/dashboard/status/ping.py index 989cd1570f..6630f03c9d 100644 --- a/esphome/dashboard/status/ping.py +++ b/esphome/dashboard/status/ping.py @@ -1,20 +1,20 @@ from __future__ import annotations import asyncio -import os +import logging +import time from typing import cast +from icmplib import Host, SocketPermissionError, async_ping + +from ..const import MAX_EXECUTOR_WORKERS from ..core import DASHBOARD -from ..entries import DashboardEntry, bool_to_entry_state +from ..entries import DashboardEntry, EntryState, bool_to_entry_state from ..util.itertools import chunked -from ..util.subprocess import async_system_command_status +_LOGGER = logging.getLogger(__name__) -async def _async_ping_host(host: str) -> bool: - """Ping a host.""" - return await async_system_command_status( - ["ping", "-n" if os.name == "nt" else "-c", "1", host] - ) +GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2) class 
PingStatus: @@ -27,6 +27,10 @@ class PingStatus: """Run the ping status.""" dashboard = DASHBOARD entries = dashboard.entries + privileged = await _can_use_icmp_lib_with_privilege() + if privileged is None: + _LOGGER.warning("Cannot use icmplib because privileges are insufficient") + return while not dashboard.stop_event.is_set(): # Only ping if the dashboard is open @@ -36,15 +40,68 @@ class PingStatus: to_ping: list[DashboardEntry] = [ entry for entry in current_entries if entry.address is not None ] - for ping_group in chunked(to_ping, 16): + + # Resolve DNS for all entries + entries_with_addresses: dict[DashboardEntry, list[str]] = {} + for ping_group in chunked(to_ping, GROUP_SIZE): ping_group = cast(list[DashboardEntry], ping_group) - results = await asyncio.gather( - *(_async_ping_host(entry.address) for entry in ping_group), + now_monotonic = time.monotonic() + dns_results = await asyncio.gather( + *( + dashboard.dns_cache.async_resolve(entry.address, now_monotonic) + for entry in ping_group + ), return_exceptions=True, ) - for entry, result in zip(ping_group, results): + + for entry, result in zip(ping_group, dns_results): if isinstance(result, Exception): - result = False + entries.async_set_state(entry, EntryState.UNKNOWN) + continue + if isinstance(result, BaseException): + raise result + entries_with_addresses[entry] = result + + # Ping all entries with valid addresses + for ping_group in chunked(entries_with_addresses.items(), GROUP_SIZE): + entry_addresses = cast(tuple[DashboardEntry, list[str]], ping_group) + + results = await asyncio.gather( + *( + async_ping(addresses[0], privileged=privileged) + for _, addresses in entry_addresses + ), + return_exceptions=True, + ) + + for entry_addresses, result in zip(entry_addresses, results): + if isinstance(result, Exception): + ping_result = False elif isinstance(result, BaseException): raise result - entries.async_set_state(entry, bool_to_entry_state(result)) + else: + host: Host = result + ping_result = 
host.is_alive + entry, _ = entry_addresses + entries.async_set_state(entry, bool_to_entry_state(ping_result)) + + +async def _can_use_icmp_lib_with_privilege() -> None | bool: + """Verify we can create a raw socket.""" + try: + await async_ping("127.0.0.1", count=0, timeout=0, privileged=True) + except SocketPermissionError: + try: + await async_ping("127.0.0.1", count=0, timeout=0, privileged=False) + except SocketPermissionError: + _LOGGER.debug( + "Cannot use icmplib because privileges are insufficient to create the" + " socket" + ) + return None + + _LOGGER.debug("Using icmplib in privileged=False mode") + return False + + _LOGGER.debug("Using icmplib in privileged=True mode") + return True diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index 6a80865906..c16461d174 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -9,6 +9,7 @@ import hashlib import json import logging import os +import time import secrets import shutil import subprocess @@ -302,16 +303,28 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): port = json_message["port"] if ( port == "OTA" # pylint: disable=too-many-boolean-expressions - and (mdns := dashboard.mdns_status) and (entry := entries.get(config_file)) and entry.loaded_integrations and "api" in entry.loaded_integrations - and (address := await mdns.async_resolve_host(entry.name)) ): - # Use the IP address if available but only - # if the API is loaded and the device is online - # since MQTT logging will not work otherwise - port = address + if (mdns := dashboard.mdns_status) and ( + address := await mdns.async_resolve_host(entry.name) + ): + # Use the IP address if available but only + # if the API is loaded and the device is online + # since MQTT logging will not work otherwise + port = address + elif ( + entry.address + and ( + address_list := await dashboard.dns_cache.async_resolve( + entry.address, time.monotonic() + ) + ) + and not isinstance(address_list, 
Exception) + ): + # If mdns is not available, try to use the DNS cache + port = address_list[0] return [ *DASHBOARD_COMMAND, diff --git a/requirements.txt b/requirements.txt index 115f85de3e..5281b64e66 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ +async_timeout==4.0.3; python_version <= "3.10" voluptuous==0.14.1 PyYAML==6.0.1 paho-mqtt==1.6.1 colorama==0.4.6 +icmplib==3.0.4 tornado==6.4 tzlocal==5.2 # from time tzdata>=2021.1 # from time