mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	dashboard: refactor ping implementation to be more efficient (#6002)
This commit is contained in:
		| @@ -4,5 +4,7 @@ EVENT_ENTRY_ADDED = "entry_added" | ||||
| EVENT_ENTRY_REMOVED = "entry_removed" | ||||
| EVENT_ENTRY_UPDATED = "entry_updated" | ||||
| EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" | ||||
| MAX_EXECUTOR_WORKERS = 48 | ||||
|  | ||||
|  | ||||
| SENTINEL = object() | ||||
|   | ||||
| @@ -8,6 +8,7 @@ from functools import partial | ||||
| from typing import TYPE_CHECKING, Any, Callable | ||||
|  | ||||
| from ..zeroconf import DiscoveredImport | ||||
| from .dns import DNSCache | ||||
| from .entries import DashboardEntries | ||||
| from .settings import DashboardSettings | ||||
|  | ||||
| @@ -69,6 +70,7 @@ class ESPHomeDashboard: | ||||
|         "mqtt_ping_request", | ||||
|         "mdns_status", | ||||
|         "settings", | ||||
|         "dns_cache", | ||||
|     ) | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
| @@ -81,7 +83,8 @@ class ESPHomeDashboard: | ||||
|         self.ping_request: asyncio.Event | None = None | ||||
|         self.mqtt_ping_request = threading.Event() | ||||
|         self.mdns_status: MDNSStatus | None = None | ||||
|         self.settings: DashboardSettings = DashboardSettings() | ||||
|         self.settings = DashboardSettings() | ||||
|         self.dns_cache = DNSCache() | ||||
|  | ||||
|     async def async_setup(self) -> None: | ||||
|         """Setup the dashboard.""" | ||||
|   | ||||
| @@ -1,11 +1,19 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| import logging | ||||
| import os | ||||
| import socket | ||||
| import threading | ||||
| import traceback | ||||
| from asyncio import events | ||||
| from concurrent.futures import ThreadPoolExecutor | ||||
| from time import monotonic | ||||
| from typing import Any | ||||
|  | ||||
| from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path | ||||
|  | ||||
| from .const import MAX_EXECUTOR_WORKERS | ||||
| from .core import DASHBOARD | ||||
| from .web_server import make_app, start_web_server | ||||
|  | ||||
| @@ -14,6 +22,95 @@ ENV_DEV = "ESPHOME_DASHBOARD_DEV" | ||||
| settings = DASHBOARD.settings | ||||
|  | ||||
|  | ||||
| def can_use_pidfd() -> bool: | ||||
|     """Check if pidfd_open is available. | ||||
|  | ||||
|     Back ported from cpython 3.12 | ||||
|     """ | ||||
|     if not hasattr(os, "pidfd_open"): | ||||
|         return False | ||||
|     try: | ||||
|         pid = os.getpid() | ||||
|         os.close(os.pidfd_open(pid, 0)) | ||||
|     except OSError: | ||||
|         # blocked by security policy like SECCOMP | ||||
|         return False | ||||
|     return True | ||||
|  | ||||
|  | ||||
| class DashboardEventLoopPolicy(asyncio.DefaultEventLoopPolicy): | ||||
|     """Event loop policy for Home Assistant.""" | ||||
|  | ||||
|     def __init__(self, debug: bool) -> None: | ||||
|         """Init the event loop policy.""" | ||||
|         super().__init__() | ||||
|         self.debug = debug | ||||
|         self._watcher: asyncio.AbstractChildWatcher | None = None | ||||
|  | ||||
|     def _init_watcher(self) -> None: | ||||
|         """Initialize the watcher for child processes. | ||||
|  | ||||
|         Back ported from cpython 3.12 | ||||
|         """ | ||||
|         with events._lock:  # type: ignore[attr-defined] # pylint: disable=protected-access | ||||
|             if self._watcher is None:  # pragma: no branch | ||||
|                 if can_use_pidfd(): | ||||
|                     self._watcher = asyncio.PidfdChildWatcher() | ||||
|                 else: | ||||
|                     self._watcher = asyncio.ThreadedChildWatcher() | ||||
|                 if threading.current_thread() is threading.main_thread(): | ||||
|                     self._watcher.attach_loop( | ||||
|                         self._local._loop  # type: ignore[attr-defined] # pylint: disable=protected-access | ||||
|                     ) | ||||
|  | ||||
|     @property | ||||
|     def loop_name(self) -> str: | ||||
|         """Return name of the loop.""" | ||||
|         return self._loop_factory.__name__  # type: ignore[no-any-return,attr-defined] | ||||
|  | ||||
|     def new_event_loop(self) -> asyncio.AbstractEventLoop: | ||||
|         """Get the event loop.""" | ||||
|         loop: asyncio.AbstractEventLoop = super().new_event_loop() | ||||
|         loop.set_exception_handler(_async_loop_exception_handler) | ||||
|  | ||||
|         if self.debug: | ||||
|             loop.set_debug(True) | ||||
|  | ||||
|         executor = ThreadPoolExecutor( | ||||
|             thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS | ||||
|         ) | ||||
|         loop.set_default_executor(executor) | ||||
|         # bind the built-in time.monotonic directly as loop.time to avoid the | ||||
|         # overhead of the additional method call since its the most called loop | ||||
|         # method and its roughly 10%+ of all the call time in base_events.py | ||||
|         loop.time = monotonic  # type: ignore[method-assign] | ||||
|         return loop | ||||
|  | ||||
|  | ||||
| def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None: | ||||
|     """Handle all exception inside the core loop.""" | ||||
|     kwargs = {} | ||||
|     if exception := context.get("exception"): | ||||
|         kwargs["exc_info"] = (type(exception), exception, exception.__traceback__) | ||||
|  | ||||
|     logger = logging.getLogger(__package__) | ||||
|     if source_traceback := context.get("source_traceback"): | ||||
|         stack_summary = "".join(traceback.format_list(source_traceback)) | ||||
|         logger.error( | ||||
|             "Error doing job: %s: %s", | ||||
|             context["message"], | ||||
|             stack_summary, | ||||
|             **kwargs,  # type: ignore[arg-type] | ||||
|         ) | ||||
|         return | ||||
|  | ||||
|     logger.error( | ||||
|         "Error doing job: %s", | ||||
|         context["message"], | ||||
|         **kwargs,  # type: ignore[arg-type] | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def start_dashboard(args) -> None: | ||||
|     """Start the dashboard.""" | ||||
|     settings.parse_args(args) | ||||
| @@ -26,6 +123,8 @@ def start_dashboard(args) -> None: | ||||
|             storage.save(path) | ||||
|         settings.cookie_secret = storage.cookie_secret | ||||
|  | ||||
|     asyncio.set_event_loop_policy(DashboardEventLoopPolicy(settings.verbose)) | ||||
|  | ||||
|     try: | ||||
|         asyncio.run(async_start(args)) | ||||
|     except KeyboardInterrupt: | ||||
|   | ||||
							
								
								
									
										43
									
								
								esphome/dashboard/dns.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								esphome/dashboard/dns.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| import sys | ||||
|  | ||||
| from icmplib import NameLookupError, async_resolve | ||||
|  | ||||
| if sys.version_info >= (3, 11): | ||||
|     from asyncio import timeout as async_timeout | ||||
| else: | ||||
|     from async_timeout import timeout as async_timeout | ||||
|  | ||||
|  | ||||
| async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception: | ||||
|     """Wrap the icmplib async_resolve function.""" | ||||
|     try: | ||||
|         async with async_timeout(2): | ||||
|             return await async_resolve(hostname) | ||||
|     except (asyncio.TimeoutError, NameLookupError, UnicodeError) as ex: | ||||
|         return ex | ||||
|  | ||||
|  | ||||
| class DNSCache: | ||||
|     """DNS cache for the dashboard.""" | ||||
|  | ||||
|     def __init__(self, ttl: int | None = 120) -> None: | ||||
|         """Initialize the DNSCache.""" | ||||
|         self._cache: dict[str, tuple[float, list[str] | Exception]] = {} | ||||
|         self._ttl = ttl | ||||
|  | ||||
|     async def async_resolve( | ||||
|         self, hostname: str, now_monotonic: float | ||||
|     ) -> list[str] | Exception: | ||||
|         """Resolve a hostname to a list of IP address.""" | ||||
|         if expire_time_addresses := self._cache.get(hostname): | ||||
|             expire_time, addresses = expire_time_addresses | ||||
|             if expire_time > now_monotonic: | ||||
|                 return addresses | ||||
|  | ||||
|         expires = now_monotonic + self._ttl | ||||
|         addresses = await _async_resolve_wrapper(hostname) | ||||
|         self._cache[hostname] = (expires, addresses) | ||||
|         return addresses | ||||
| @@ -14,7 +14,19 @@ from .util.password import password_hash | ||||
| class DashboardSettings: | ||||
|     """Settings for the dashboard.""" | ||||
|  | ||||
|     __slots__ = ( | ||||
|         "config_dir", | ||||
|         "password_hash", | ||||
|         "username", | ||||
|         "using_password", | ||||
|         "on_ha_addon", | ||||
|         "cookie_secret", | ||||
|         "absolute_config_dir", | ||||
|         "verbose", | ||||
|     ) | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         """Initialize the dashboard settings.""" | ||||
|         self.config_dir: str = "" | ||||
|         self.password_hash: str = "" | ||||
|         self.username: str = "" | ||||
| @@ -22,8 +34,10 @@ class DashboardSettings: | ||||
|         self.on_ha_addon: bool = False | ||||
|         self.cookie_secret: str | None = None | ||||
|         self.absolute_config_dir: Path | None = None | ||||
|         self.verbose: bool = False | ||||
|  | ||||
|     def parse_args(self, args: Any) -> None: | ||||
|         """Parse the arguments.""" | ||||
|         self.on_ha_addon: bool = args.ha_addon | ||||
|         password = args.password or os.getenv("PASSWORD") or "" | ||||
|         if not self.on_ha_addon: | ||||
| @@ -33,6 +47,7 @@ class DashboardSettings: | ||||
|             self.password_hash = password_hash(password) | ||||
|         self.config_dir = args.configuration | ||||
|         self.absolute_config_dir = Path(self.config_dir).resolve() | ||||
|         self.verbose = args.verbose | ||||
|         CORE.config_path = os.path.join(self.config_dir, ".") | ||||
|  | ||||
|     @property | ||||
|   | ||||
| @@ -1,20 +1,20 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| import os | ||||
| import logging | ||||
| import time | ||||
| from typing import cast | ||||
|  | ||||
| from icmplib import Host, SocketPermissionError, async_ping | ||||
|  | ||||
| from ..const import MAX_EXECUTOR_WORKERS | ||||
| from ..core import DASHBOARD | ||||
| from ..entries import DashboardEntry, bool_to_entry_state | ||||
| from ..entries import DashboardEntry, EntryState, bool_to_entry_state | ||||
| from ..util.itertools import chunked | ||||
| from ..util.subprocess import async_system_command_status | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
| async def _async_ping_host(host: str) -> bool: | ||||
|     """Ping a host.""" | ||||
|     return await async_system_command_status( | ||||
|         ["ping", "-n" if os.name == "nt" else "-c", "1", host] | ||||
|     ) | ||||
| GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2) | ||||
|  | ||||
|  | ||||
| class PingStatus: | ||||
| @@ -27,6 +27,10 @@ class PingStatus: | ||||
|         """Run the ping status.""" | ||||
|         dashboard = DASHBOARD | ||||
|         entries = dashboard.entries | ||||
|         privileged = await _can_use_icmp_lib_with_privilege() | ||||
|         if privileged is None: | ||||
|             _LOGGER.warning("Cannot use icmplib because privileges are insufficient") | ||||
|             return | ||||
|  | ||||
|         while not dashboard.stop_event.is_set(): | ||||
|             # Only ping if the dashboard is open | ||||
| @@ -36,15 +40,68 @@ class PingStatus: | ||||
|             to_ping: list[DashboardEntry] = [ | ||||
|                 entry for entry in current_entries if entry.address is not None | ||||
|             ] | ||||
|             for ping_group in chunked(to_ping, 16): | ||||
|  | ||||
|             # Resolve DNS for all entries | ||||
|             entries_with_addresses: dict[DashboardEntry, list[str]] = {} | ||||
|             for ping_group in chunked(to_ping, GROUP_SIZE): | ||||
|                 ping_group = cast(list[DashboardEntry], ping_group) | ||||
|                 results = await asyncio.gather( | ||||
|                     *(_async_ping_host(entry.address) for entry in ping_group), | ||||
|                 now_monotonic = time.monotonic() | ||||
|                 dns_results = await asyncio.gather( | ||||
|                     *( | ||||
|                         dashboard.dns_cache.async_resolve(entry.address, now_monotonic) | ||||
|                         for entry in ping_group | ||||
|                     ), | ||||
|                     return_exceptions=True, | ||||
|                 ) | ||||
|                 for entry, result in zip(ping_group, results): | ||||
|  | ||||
|                 for entry, result in zip(ping_group, dns_results): | ||||
|                     if isinstance(result, Exception): | ||||
|                         result = False | ||||
|                         entries.async_set_state(entry, EntryState.UNKNOWN) | ||||
|                         continue | ||||
|                     if isinstance(result, BaseException): | ||||
|                         raise result | ||||
|                     entries_with_addresses[entry] = result | ||||
|  | ||||
|             # Ping all entries with valid addresses | ||||
|             for ping_group in chunked(entries_with_addresses.items(), GROUP_SIZE): | ||||
|                 entry_addresses = cast(tuple[DashboardEntry, list[str]], ping_group) | ||||
|  | ||||
|                 results = await asyncio.gather( | ||||
|                     *( | ||||
|                         async_ping(addresses[0], privileged=privileged) | ||||
|                         for _, addresses in entry_addresses | ||||
|                     ), | ||||
|                     return_exceptions=True, | ||||
|                 ) | ||||
|  | ||||
|                 for entry_addresses, result in zip(entry_addresses, results): | ||||
|                     if isinstance(result, Exception): | ||||
|                         ping_result = False | ||||
|                     elif isinstance(result, BaseException): | ||||
|                         raise result | ||||
|                     entries.async_set_state(entry, bool_to_entry_state(result)) | ||||
|                     else: | ||||
|                         host: Host = result | ||||
|                         ping_result = host.is_alive | ||||
|                     entry, _ = entry_addresses | ||||
|                     entries.async_set_state(entry, bool_to_entry_state(ping_result)) | ||||
|  | ||||
|  | ||||
| async def _can_use_icmp_lib_with_privilege() -> None | bool: | ||||
|     """Verify we can create a raw socket.""" | ||||
|     try: | ||||
|         await async_ping("127.0.0.1", count=0, timeout=0, privileged=True) | ||||
|     except SocketPermissionError: | ||||
|         try: | ||||
|             await async_ping("127.0.0.1", count=0, timeout=0, privileged=False) | ||||
|         except SocketPermissionError: | ||||
|             _LOGGER.debug( | ||||
|                 "Cannot use icmplib because privileges are insufficient to create the" | ||||
|                 " socket" | ||||
|             ) | ||||
|             return None | ||||
|  | ||||
|         _LOGGER.debug("Using icmplib in privileged=False mode") | ||||
|         return False | ||||
|  | ||||
|     _LOGGER.debug("Using icmplib in privileged=True mode") | ||||
|     return True | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import hashlib | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| import time | ||||
| import secrets | ||||
| import shutil | ||||
| import subprocess | ||||
| @@ -302,16 +303,28 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): | ||||
|         port = json_message["port"] | ||||
|         if ( | ||||
|             port == "OTA"  # pylint: disable=too-many-boolean-expressions | ||||
|             and (mdns := dashboard.mdns_status) | ||||
|             and (entry := entries.get(config_file)) | ||||
|             and entry.loaded_integrations | ||||
|             and "api" in entry.loaded_integrations | ||||
|             and (address := await mdns.async_resolve_host(entry.name)) | ||||
|         ): | ||||
|             # Use the IP address if available but only | ||||
|             # if the API is loaded and the device is online | ||||
|             # since MQTT logging will not work otherwise | ||||
|             port = address | ||||
|             if (mdns := dashboard.mdns_status) and ( | ||||
|                 address := await mdns.async_resolve_host(entry.name) | ||||
|             ): | ||||
|                 # Use the IP address if available but only | ||||
|                 # if the API is loaded and the device is online | ||||
|                 # since MQTT logging will not work otherwise | ||||
|                 port = address | ||||
|             elif ( | ||||
|                 entry.address | ||||
|                 and ( | ||||
|                     address_list := await dashboard.dns_cache.async_resolve( | ||||
|                         entry.address, time.monotonic() | ||||
|                     ) | ||||
|                 ) | ||||
|                 and not isinstance(address_list, Exception) | ||||
|             ): | ||||
|                 # If mdns is not available, try to use the DNS cache | ||||
|                 port = address_list[0] | ||||
|  | ||||
|         return [ | ||||
|             *DASHBOARD_COMMAND, | ||||
|   | ||||
| @@ -1,7 +1,9 @@ | ||||
| async_timeout==4.0.3; python_version <= "3.10" | ||||
| voluptuous==0.14.1 | ||||
| PyYAML==6.0.1 | ||||
| paho-mqtt==1.6.1 | ||||
| colorama==0.4.6 | ||||
| icmplib==3.0.4 | ||||
| tornado==6.4 | ||||
| tzlocal==5.2    # from time | ||||
| tzdata>=2021.1  # from time | ||||
|   | ||||
		Reference in New Issue
	
	Block a user