mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	dashboard: refactor ping implementation to be more efficient (#6002)
This commit is contained in:
		| @@ -4,5 +4,7 @@ EVENT_ENTRY_ADDED = "entry_added" | |||||||
| EVENT_ENTRY_REMOVED = "entry_removed" | EVENT_ENTRY_REMOVED = "entry_removed" | ||||||
| EVENT_ENTRY_UPDATED = "entry_updated" | EVENT_ENTRY_UPDATED = "entry_updated" | ||||||
| EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" | EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" | ||||||
|  | MAX_EXECUTOR_WORKERS = 48 | ||||||
|  |  | ||||||
|  |  | ||||||
| SENTINEL = object() | SENTINEL = object() | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ from functools import partial | |||||||
| from typing import TYPE_CHECKING, Any, Callable | from typing import TYPE_CHECKING, Any, Callable | ||||||
|  |  | ||||||
| from ..zeroconf import DiscoveredImport | from ..zeroconf import DiscoveredImport | ||||||
|  | from .dns import DNSCache | ||||||
| from .entries import DashboardEntries | from .entries import DashboardEntries | ||||||
| from .settings import DashboardSettings | from .settings import DashboardSettings | ||||||
|  |  | ||||||
| @@ -69,6 +70,7 @@ class ESPHomeDashboard: | |||||||
|         "mqtt_ping_request", |         "mqtt_ping_request", | ||||||
|         "mdns_status", |         "mdns_status", | ||||||
|         "settings", |         "settings", | ||||||
|  |         "dns_cache", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     def __init__(self) -> None: |     def __init__(self) -> None: | ||||||
| @@ -81,7 +83,8 @@ class ESPHomeDashboard: | |||||||
|         self.ping_request: asyncio.Event | None = None |         self.ping_request: asyncio.Event | None = None | ||||||
|         self.mqtt_ping_request = threading.Event() |         self.mqtt_ping_request = threading.Event() | ||||||
|         self.mdns_status: MDNSStatus | None = None |         self.mdns_status: MDNSStatus | None = None | ||||||
|         self.settings: DashboardSettings = DashboardSettings() |         self.settings = DashboardSettings() | ||||||
|  |         self.dns_cache = DNSCache() | ||||||
|  |  | ||||||
|     async def async_setup(self) -> None: |     async def async_setup(self) -> None: | ||||||
|         """Setup the dashboard.""" |         """Setup the dashboard.""" | ||||||
|   | |||||||
| @@ -1,11 +1,19 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| import asyncio | import asyncio | ||||||
|  | import logging | ||||||
| import os | import os | ||||||
| import socket | import socket | ||||||
|  | import threading | ||||||
|  | import traceback | ||||||
|  | from asyncio import events | ||||||
|  | from concurrent.futures import ThreadPoolExecutor | ||||||
|  | from time import monotonic | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
| from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path | from esphome.storage_json import EsphomeStorageJSON, esphome_storage_path | ||||||
|  |  | ||||||
|  | from .const import MAX_EXECUTOR_WORKERS | ||||||
| from .core import DASHBOARD | from .core import DASHBOARD | ||||||
| from .web_server import make_app, start_web_server | from .web_server import make_app, start_web_server | ||||||
|  |  | ||||||
| @@ -14,6 +22,95 @@ ENV_DEV = "ESPHOME_DASHBOARD_DEV" | |||||||
| settings = DASHBOARD.settings | settings = DASHBOARD.settings | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def can_use_pidfd() -> bool: | ||||||
|  |     """Check if pidfd_open is available. | ||||||
|  |  | ||||||
|  |     Back ported from cpython 3.12 | ||||||
|  |     """ | ||||||
|  |     if not hasattr(os, "pidfd_open"): | ||||||
|  |         return False | ||||||
|  |     try: | ||||||
|  |         pid = os.getpid() | ||||||
|  |         os.close(os.pidfd_open(pid, 0)) | ||||||
|  |     except OSError: | ||||||
|  |         # blocked by security policy like SECCOMP | ||||||
|  |         return False | ||||||
|  |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DashboardEventLoopPolicy(asyncio.DefaultEventLoopPolicy): | ||||||
|  |     """Event loop policy for Home Assistant.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, debug: bool) -> None: | ||||||
|  |         """Init the event loop policy.""" | ||||||
|  |         super().__init__() | ||||||
|  |         self.debug = debug | ||||||
|  |         self._watcher: asyncio.AbstractChildWatcher | None = None | ||||||
|  |  | ||||||
|  |     def _init_watcher(self) -> None: | ||||||
|  |         """Initialize the watcher for child processes. | ||||||
|  |  | ||||||
|  |         Back ported from cpython 3.12 | ||||||
|  |         """ | ||||||
|  |         with events._lock:  # type: ignore[attr-defined] # pylint: disable=protected-access | ||||||
|  |             if self._watcher is None:  # pragma: no branch | ||||||
|  |                 if can_use_pidfd(): | ||||||
|  |                     self._watcher = asyncio.PidfdChildWatcher() | ||||||
|  |                 else: | ||||||
|  |                     self._watcher = asyncio.ThreadedChildWatcher() | ||||||
|  |                 if threading.current_thread() is threading.main_thread(): | ||||||
|  |                     self._watcher.attach_loop( | ||||||
|  |                         self._local._loop  # type: ignore[attr-defined] # pylint: disable=protected-access | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def loop_name(self) -> str: | ||||||
|  |         """Return name of the loop.""" | ||||||
|  |         return self._loop_factory.__name__  # type: ignore[no-any-return,attr-defined] | ||||||
|  |  | ||||||
|  |     def new_event_loop(self) -> asyncio.AbstractEventLoop: | ||||||
|  |         """Get the event loop.""" | ||||||
|  |         loop: asyncio.AbstractEventLoop = super().new_event_loop() | ||||||
|  |         loop.set_exception_handler(_async_loop_exception_handler) | ||||||
|  |  | ||||||
|  |         if self.debug: | ||||||
|  |             loop.set_debug(True) | ||||||
|  |  | ||||||
|  |         executor = ThreadPoolExecutor( | ||||||
|  |             thread_name_prefix="SyncWorker", max_workers=MAX_EXECUTOR_WORKERS | ||||||
|  |         ) | ||||||
|  |         loop.set_default_executor(executor) | ||||||
|  |         # bind the built-in time.monotonic directly as loop.time to avoid the | ||||||
|  |         # overhead of the additional method call since its the most called loop | ||||||
|  |         # method and its roughly 10%+ of all the call time in base_events.py | ||||||
|  |         loop.time = monotonic  # type: ignore[method-assign] | ||||||
|  |         return loop | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None: | ||||||
|  |     """Handle all exception inside the core loop.""" | ||||||
|  |     kwargs = {} | ||||||
|  |     if exception := context.get("exception"): | ||||||
|  |         kwargs["exc_info"] = (type(exception), exception, exception.__traceback__) | ||||||
|  |  | ||||||
|  |     logger = logging.getLogger(__package__) | ||||||
|  |     if source_traceback := context.get("source_traceback"): | ||||||
|  |         stack_summary = "".join(traceback.format_list(source_traceback)) | ||||||
|  |         logger.error( | ||||||
|  |             "Error doing job: %s: %s", | ||||||
|  |             context["message"], | ||||||
|  |             stack_summary, | ||||||
|  |             **kwargs,  # type: ignore[arg-type] | ||||||
|  |         ) | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     logger.error( | ||||||
|  |         "Error doing job: %s", | ||||||
|  |         context["message"], | ||||||
|  |         **kwargs,  # type: ignore[arg-type] | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def start_dashboard(args) -> None: | def start_dashboard(args) -> None: | ||||||
|     """Start the dashboard.""" |     """Start the dashboard.""" | ||||||
|     settings.parse_args(args) |     settings.parse_args(args) | ||||||
| @@ -26,6 +123,8 @@ def start_dashboard(args) -> None: | |||||||
|             storage.save(path) |             storage.save(path) | ||||||
|         settings.cookie_secret = storage.cookie_secret |         settings.cookie_secret = storage.cookie_secret | ||||||
|  |  | ||||||
|  |     asyncio.set_event_loop_policy(DashboardEventLoopPolicy(settings.verbose)) | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         asyncio.run(async_start(args)) |         asyncio.run(async_start(args)) | ||||||
|     except KeyboardInterrupt: |     except KeyboardInterrupt: | ||||||
|   | |||||||
							
								
								
									
										43
									
								
								esphome/dashboard/dns.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								esphome/dashboard/dns.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | import asyncio | ||||||
|  | import sys | ||||||
|  |  | ||||||
|  | from icmplib import NameLookupError, async_resolve | ||||||
|  |  | ||||||
|  | if sys.version_info >= (3, 11): | ||||||
|  |     from asyncio import timeout as async_timeout | ||||||
|  | else: | ||||||
|  |     from async_timeout import timeout as async_timeout | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception: | ||||||
|  |     """Wrap the icmplib async_resolve function.""" | ||||||
|  |     try: | ||||||
|  |         async with async_timeout(2): | ||||||
|  |             return await async_resolve(hostname) | ||||||
|  |     except (asyncio.TimeoutError, NameLookupError, UnicodeError) as ex: | ||||||
|  |         return ex | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DNSCache: | ||||||
|  |     """DNS cache for the dashboard.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, ttl: int | None = 120) -> None: | ||||||
|  |         """Initialize the DNSCache.""" | ||||||
|  |         self._cache: dict[str, tuple[float, list[str] | Exception]] = {} | ||||||
|  |         self._ttl = ttl | ||||||
|  |  | ||||||
|  |     async def async_resolve( | ||||||
|  |         self, hostname: str, now_monotonic: float | ||||||
|  |     ) -> list[str] | Exception: | ||||||
|  |         """Resolve a hostname to a list of IP address.""" | ||||||
|  |         if expire_time_addresses := self._cache.get(hostname): | ||||||
|  |             expire_time, addresses = expire_time_addresses | ||||||
|  |             if expire_time > now_monotonic: | ||||||
|  |                 return addresses | ||||||
|  |  | ||||||
|  |         expires = now_monotonic + self._ttl | ||||||
|  |         addresses = await _async_resolve_wrapper(hostname) | ||||||
|  |         self._cache[hostname] = (expires, addresses) | ||||||
|  |         return addresses | ||||||
| @@ -14,7 +14,19 @@ from .util.password import password_hash | |||||||
| class DashboardSettings: | class DashboardSettings: | ||||||
|     """Settings for the dashboard.""" |     """Settings for the dashboard.""" | ||||||
|  |  | ||||||
|  |     __slots__ = ( | ||||||
|  |         "config_dir", | ||||||
|  |         "password_hash", | ||||||
|  |         "username", | ||||||
|  |         "using_password", | ||||||
|  |         "on_ha_addon", | ||||||
|  |         "cookie_secret", | ||||||
|  |         "absolute_config_dir", | ||||||
|  |         "verbose", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     def __init__(self) -> None: |     def __init__(self) -> None: | ||||||
|  |         """Initialize the dashboard settings.""" | ||||||
|         self.config_dir: str = "" |         self.config_dir: str = "" | ||||||
|         self.password_hash: str = "" |         self.password_hash: str = "" | ||||||
|         self.username: str = "" |         self.username: str = "" | ||||||
| @@ -22,8 +34,10 @@ class DashboardSettings: | |||||||
|         self.on_ha_addon: bool = False |         self.on_ha_addon: bool = False | ||||||
|         self.cookie_secret: str | None = None |         self.cookie_secret: str | None = None | ||||||
|         self.absolute_config_dir: Path | None = None |         self.absolute_config_dir: Path | None = None | ||||||
|  |         self.verbose: bool = False | ||||||
|  |  | ||||||
|     def parse_args(self, args: Any) -> None: |     def parse_args(self, args: Any) -> None: | ||||||
|  |         """Parse the arguments.""" | ||||||
|         self.on_ha_addon: bool = args.ha_addon |         self.on_ha_addon: bool = args.ha_addon | ||||||
|         password = args.password or os.getenv("PASSWORD") or "" |         password = args.password or os.getenv("PASSWORD") or "" | ||||||
|         if not self.on_ha_addon: |         if not self.on_ha_addon: | ||||||
| @@ -33,6 +47,7 @@ class DashboardSettings: | |||||||
|             self.password_hash = password_hash(password) |             self.password_hash = password_hash(password) | ||||||
|         self.config_dir = args.configuration |         self.config_dir = args.configuration | ||||||
|         self.absolute_config_dir = Path(self.config_dir).resolve() |         self.absolute_config_dir = Path(self.config_dir).resolve() | ||||||
|  |         self.verbose = args.verbose | ||||||
|         CORE.config_path = os.path.join(self.config_dir, ".") |         CORE.config_path = os.path.join(self.config_dir, ".") | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|   | |||||||
| @@ -1,20 +1,20 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| import asyncio | import asyncio | ||||||
| import os | import logging | ||||||
|  | import time | ||||||
| from typing import cast | from typing import cast | ||||||
|  |  | ||||||
|  | from icmplib import Host, SocketPermissionError, async_ping | ||||||
|  |  | ||||||
|  | from ..const import MAX_EXECUTOR_WORKERS | ||||||
| from ..core import DASHBOARD | from ..core import DASHBOARD | ||||||
| from ..entries import DashboardEntry, bool_to_entry_state | from ..entries import DashboardEntry, EntryState, bool_to_entry_state | ||||||
| from ..util.itertools import chunked | from ..util.itertools import chunked | ||||||
| from ..util.subprocess import async_system_command_status |  | ||||||
|  |  | ||||||
|  | _LOGGER = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| async def _async_ping_host(host: str) -> bool: | GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2) | ||||||
|     """Ping a host.""" |  | ||||||
|     return await async_system_command_status( |  | ||||||
|         ["ping", "-n" if os.name == "nt" else "-c", "1", host] |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PingStatus: | class PingStatus: | ||||||
| @@ -27,6 +27,10 @@ class PingStatus: | |||||||
|         """Run the ping status.""" |         """Run the ping status.""" | ||||||
|         dashboard = DASHBOARD |         dashboard = DASHBOARD | ||||||
|         entries = dashboard.entries |         entries = dashboard.entries | ||||||
|  |         privileged = await _can_use_icmp_lib_with_privilege() | ||||||
|  |         if privileged is None: | ||||||
|  |             _LOGGER.warning("Cannot use icmplib because privileges are insufficient") | ||||||
|  |             return | ||||||
|  |  | ||||||
|         while not dashboard.stop_event.is_set(): |         while not dashboard.stop_event.is_set(): | ||||||
|             # Only ping if the dashboard is open |             # Only ping if the dashboard is open | ||||||
| @@ -36,15 +40,68 @@ class PingStatus: | |||||||
|             to_ping: list[DashboardEntry] = [ |             to_ping: list[DashboardEntry] = [ | ||||||
|                 entry for entry in current_entries if entry.address is not None |                 entry for entry in current_entries if entry.address is not None | ||||||
|             ] |             ] | ||||||
|             for ping_group in chunked(to_ping, 16): |  | ||||||
|  |             # Resolve DNS for all entries | ||||||
|  |             entries_with_addresses: dict[DashboardEntry, list[str]] = {} | ||||||
|  |             for ping_group in chunked(to_ping, GROUP_SIZE): | ||||||
|                 ping_group = cast(list[DashboardEntry], ping_group) |                 ping_group = cast(list[DashboardEntry], ping_group) | ||||||
|                 results = await asyncio.gather( |                 now_monotonic = time.monotonic() | ||||||
|                     *(_async_ping_host(entry.address) for entry in ping_group), |                 dns_results = await asyncio.gather( | ||||||
|  |                     *( | ||||||
|  |                         dashboard.dns_cache.async_resolve(entry.address, now_monotonic) | ||||||
|  |                         for entry in ping_group | ||||||
|  |                     ), | ||||||
|                     return_exceptions=True, |                     return_exceptions=True, | ||||||
|                 ) |                 ) | ||||||
|                 for entry, result in zip(ping_group, results): |  | ||||||
|  |                 for entry, result in zip(ping_group, dns_results): | ||||||
|                     if isinstance(result, Exception): |                     if isinstance(result, Exception): | ||||||
|                         result = False |                         entries.async_set_state(entry, EntryState.UNKNOWN) | ||||||
|  |                         continue | ||||||
|  |                     if isinstance(result, BaseException): | ||||||
|  |                         raise result | ||||||
|  |                     entries_with_addresses[entry] = result | ||||||
|  |  | ||||||
|  |             # Ping all entries with valid addresses | ||||||
|  |             for ping_group in chunked(entries_with_addresses.items(), GROUP_SIZE): | ||||||
|  |                 entry_addresses = cast(tuple[DashboardEntry, list[str]], ping_group) | ||||||
|  |  | ||||||
|  |                 results = await asyncio.gather( | ||||||
|  |                     *( | ||||||
|  |                         async_ping(addresses[0], privileged=privileged) | ||||||
|  |                         for _, addresses in entry_addresses | ||||||
|  |                     ), | ||||||
|  |                     return_exceptions=True, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |                 for entry_addresses, result in zip(entry_addresses, results): | ||||||
|  |                     if isinstance(result, Exception): | ||||||
|  |                         ping_result = False | ||||||
|                     elif isinstance(result, BaseException): |                     elif isinstance(result, BaseException): | ||||||
|                         raise result |                         raise result | ||||||
|                     entries.async_set_state(entry, bool_to_entry_state(result)) |                     else: | ||||||
|  |                         host: Host = result | ||||||
|  |                         ping_result = host.is_alive | ||||||
|  |                     entry, _ = entry_addresses | ||||||
|  |                     entries.async_set_state(entry, bool_to_entry_state(ping_result)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def _can_use_icmp_lib_with_privilege() -> None | bool: | ||||||
|  |     """Verify we can create a raw socket.""" | ||||||
|  |     try: | ||||||
|  |         await async_ping("127.0.0.1", count=0, timeout=0, privileged=True) | ||||||
|  |     except SocketPermissionError: | ||||||
|  |         try: | ||||||
|  |             await async_ping("127.0.0.1", count=0, timeout=0, privileged=False) | ||||||
|  |         except SocketPermissionError: | ||||||
|  |             _LOGGER.debug( | ||||||
|  |                 "Cannot use icmplib because privileges are insufficient to create the" | ||||||
|  |                 " socket" | ||||||
|  |             ) | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         _LOGGER.debug("Using icmplib in privileged=False mode") | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     _LOGGER.debug("Using icmplib in privileged=True mode") | ||||||
|  |     return True | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import hashlib | |||||||
| import json | import json | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
|  | import time | ||||||
| import secrets | import secrets | ||||||
| import shutil | import shutil | ||||||
| import subprocess | import subprocess | ||||||
| @@ -302,16 +303,28 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): | |||||||
|         port = json_message["port"] |         port = json_message["port"] | ||||||
|         if ( |         if ( | ||||||
|             port == "OTA"  # pylint: disable=too-many-boolean-expressions |             port == "OTA"  # pylint: disable=too-many-boolean-expressions | ||||||
|             and (mdns := dashboard.mdns_status) |  | ||||||
|             and (entry := entries.get(config_file)) |             and (entry := entries.get(config_file)) | ||||||
|             and entry.loaded_integrations |             and entry.loaded_integrations | ||||||
|             and "api" in entry.loaded_integrations |             and "api" in entry.loaded_integrations | ||||||
|             and (address := await mdns.async_resolve_host(entry.name)) |  | ||||||
|         ): |         ): | ||||||
|             # Use the IP address if available but only |             if (mdns := dashboard.mdns_status) and ( | ||||||
|             # if the API is loaded and the device is online |                 address := await mdns.async_resolve_host(entry.name) | ||||||
|             # since MQTT logging will not work otherwise |             ): | ||||||
|             port = address |                 # Use the IP address if available but only | ||||||
|  |                 # if the API is loaded and the device is online | ||||||
|  |                 # since MQTT logging will not work otherwise | ||||||
|  |                 port = address | ||||||
|  |             elif ( | ||||||
|  |                 entry.address | ||||||
|  |                 and ( | ||||||
|  |                     address_list := await dashboard.dns_cache.async_resolve( | ||||||
|  |                         entry.address, time.monotonic() | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  |                 and not isinstance(address_list, Exception) | ||||||
|  |             ): | ||||||
|  |                 # If mdns is not available, try to use the DNS cache | ||||||
|  |                 port = address_list[0] | ||||||
|  |  | ||||||
|         return [ |         return [ | ||||||
|             *DASHBOARD_COMMAND, |             *DASHBOARD_COMMAND, | ||||||
|   | |||||||
| @@ -1,7 +1,9 @@ | |||||||
|  | async_timeout==4.0.3; python_version <= "3.10" | ||||||
| voluptuous==0.14.1 | voluptuous==0.14.1 | ||||||
| PyYAML==6.0.1 | PyYAML==6.0.1 | ||||||
| paho-mqtt==1.6.1 | paho-mqtt==1.6.1 | ||||||
| colorama==0.4.6 | colorama==0.4.6 | ||||||
|  | icmplib==3.0.4 | ||||||
| tornado==6.4 | tornado==6.4 | ||||||
| tzlocal==5.2    # from time | tzlocal==5.2    # from time | ||||||
| tzdata>=2021.1  # from time | tzdata>=2021.1  # from time | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user