diff --git a/docker/ha-addon-rootfs/etc/s6-overlay/s6-rc.d/esphome/run b/docker/ha-addon-rootfs/etc/s6-overlay/s6-rc.d/esphome/run index f973dfcaf8..cdbaff6c04 100755 --- a/docker/ha-addon-rootfs/etc/s6-overlay/s6-rc.d/esphome/run +++ b/docker/ha-addon-rootfs/etc/s6-overlay/s6-rc.d/esphome/run @@ -23,10 +23,6 @@ if bashio::config.true 'streamer_mode'; then export ESPHOME_STREAMER_MODE=true fi -if bashio::config.true 'status_use_ping'; then - export ESPHOME_DASHBOARD_USE_PING=true -fi - if bashio::config.has_value 'relative_url'; then export ESPHOME_DASHBOARD_RELATIVE_URL=$(bashio::config 'relative_url') fi diff --git a/esphome/dashboard/core.py b/esphome/dashboard/core.py index f53cb7ffb1..416442c426 100644 --- a/esphome/dashboard/core.py +++ b/esphome/dashboard/core.py @@ -9,7 +9,7 @@ import json import logging from pathlib import Path import threading -from typing import TYPE_CHECKING, Any, Callable +from typing import Any, Callable from esphome.storage_json import ignored_devices_storage_path @@ -17,15 +17,15 @@ from ..zeroconf import DiscoveredImport from .dns import DNSCache from .entries import DashboardEntries from .settings import DashboardSettings - -if TYPE_CHECKING: - from .status.mdns import MDNSStatus - +from .status.mdns import MDNSStatus +from .status.ping import PingStatus _LOGGER = logging.getLogger(__name__) IGNORED_DEVICES_STORAGE_PATH = "ignored-devices.json" +MDNS_BOOTSTRAP_TIME = 7.5 + @dataclass class Event: @@ -81,6 +81,7 @@ class ESPHomeDashboard: "dns_cache", "_background_tasks", "ignored_devices", + "_ping_status_task", ) def __init__(self) -> None: @@ -97,6 +98,7 @@ class ESPHomeDashboard: self.dns_cache = DNSCache() self._background_tasks: set[asyncio.Task] = set() self.ignored_devices: set[str] = set() + self._ping_status_task: asyncio.Task | None = None async def async_setup(self) -> None: """Setup the dashboard.""" @@ -121,41 +123,48 @@ class ESPHomeDashboard: {"ignored_devices": sorted(self.ignored_devices)}, indent=2, fp=f_handle ) + def _async_start_ping_status(self, ping_status: PingStatus) -> None: + self._ping_status_task = asyncio.create_task(ping_status.async_run()) + async def async_run(self) -> None: """Run the dashboard.""" settings = self.settings mdns_task: asyncio.Task | None = None - ping_status_task: asyncio.Task | None = None await self.entries.async_update_entries() - if settings.status_use_ping: - from .status.ping import PingStatus + mdns_status = MDNSStatus(self) + ping_status = PingStatus(self) + start_ping_timer: asyncio.TimerHandle | None = None - ping_status = PingStatus() - ping_status_task = asyncio.create_task(ping_status.async_run()) - else: - from .status.mdns import MDNSStatus - - mdns_status = MDNSStatus() - await mdns_status.async_refresh_hosts() - self.mdns_status = mdns_status + self.mdns_status = mdns_status + if mdns_status.async_setup(): mdns_task = asyncio.create_task(mdns_status.async_run()) + # Start ping MDNS_BOOTSTRAP_TIME seconds after startup to ensure + # MDNS has had a chance to resolve the devices + start_ping_timer = self.loop.call_later( + MDNS_BOOTSTRAP_TIME, self._async_start_ping_status, ping_status + ) + else: + # If mDNS is not available, start the ping status immediately + self._async_start_ping_status(ping_status) if settings.status_use_mqtt: from .status.mqtt import MqttStatusThread - status_thread_mqtt = MqttStatusThread() + status_thread_mqtt = MqttStatusThread(self) status_thread_mqtt.start() - shutdown_event = asyncio.Event() try: - await shutdown_event.wait() + await asyncio.Event().wait() finally: _LOGGER.info("Shutting down...") self.stop_event.set() self.ping_request.set() - if ping_status_task: - ping_status_task.cancel() + if start_ping_timer: + start_ping_timer.cancel() + if self._ping_status_task: + self._ping_status_task.cancel() + self._ping_status_task = None if mdns_task: mdns_task.cancel() if settings.status_use_mqtt: diff --git a/esphome/dashboard/dns.py b/esphome/dashboard/dns.py index b78a909220..ea85d338bf 100644 --- a/esphome/dashboard/dns.py +++ b/esphome/dashboard/dns.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +from contextlib import suppress +from ipaddress import ip_address import sys from icmplib import NameLookupError, async_resolve @@ -10,11 +12,15 @@ if sys.version_info >= (3, 11): else: from async_timeout import timeout as async_timeout +RESOLVE_TIMEOUT = 3.0 + async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception: """Wrap the icmplib async_resolve function.""" + with suppress(ValueError): + return [str(ip_address(hostname))] try: - async with async_timeout(2): + async with async_timeout(RESOLVE_TIMEOUT): return await async_resolve(hostname) except (asyncio.TimeoutError, NameLookupError, UnicodeError) as ex: return ex diff --git a/esphome/dashboard/entries.py b/esphome/dashboard/entries.py index cb0d4a3772..e4825298f7 100644 --- a/esphome/dashboard/entries.py +++ b/esphome/dashboard/entries.py @@ -2,6 +2,8 @@ from __future__ import annotations import asyncio from collections import defaultdict +from dataclasses import dataclass +from functools import lru_cache import logging import os from typing import TYPE_CHECKING, Any @@ -27,37 +29,53 @@ _LOGGER = logging.getLogger(__name__) DashboardCacheKeyType = tuple[int, int, float, int] -# Currently EntryState is a simple -# online/offline/unknown enum, but in the future -# it may be expanded to include more states + +@dataclass(frozen=True) +class EntryState: + """Represents the state of an entry.""" + + reachable: ReachableState + source: EntryStateSource -class EntryState(StrEnum): - ONLINE = "online" - OFFLINE = "offline" +class EntryStateSource(StrEnum): + MDNS = "mdns" + PING = "ping" + MQTT = "mqtt" UNKNOWN = "unknown" -_BOOL_TO_ENTRY_STATE = { - True: EntryState.ONLINE, - False: EntryState.OFFLINE, - None: EntryState.UNKNOWN, -} -_ENTRY_STATE_TO_BOOL = { - EntryState.ONLINE: True, - EntryState.OFFLINE: False, - EntryState.UNKNOWN: None, -} +class ReachableState(StrEnum): + ONLINE = "online" + OFFLINE = "offline" + DNS_FAILURE = "dns_failure" + UNKNOWN = "unknown" -def bool_to_entry_state(value: bool) -> EntryState: +_BOOL_TO_REACHABLE_STATE = { + True: ReachableState.ONLINE, + False: ReachableState.OFFLINE, + None: ReachableState.UNKNOWN, +} +_REACHABLE_STATE_TO_BOOL = { + ReachableState.ONLINE: True, + ReachableState.OFFLINE: False, + ReachableState.DNS_FAILURE: False, + ReachableState.UNKNOWN: None, +} + +UNKNOWN_STATE = EntryState(ReachableState.UNKNOWN, EntryStateSource.UNKNOWN) + + +@lru_cache # creating frozen dataclass instances is expensive, so we cache them +def bool_to_entry_state(value: bool | None, source: EntryStateSource) -> EntryState: """Convert a bool to an entry state.""" - return _BOOL_TO_ENTRY_STATE[value] + return EntryState(_BOOL_TO_REACHABLE_STATE[value], source) def entry_state_to_bool(value: EntryState) -> bool | None: """Convert an entry state to a bool.""" - return _ENTRY_STATE_TO_BOOL[value] + return _REACHABLE_STATE_TO_BOOL[value.reachable] class DashboardEntries: @@ -119,6 +137,55 @@ class DashboardEntries: """Set the state for an entry.""" self.async_set_state(entry, state) + def set_state_if_online_or_source( + self, entry: DashboardEntry, state: EntryState + ) -> None: + """Set the state for an entry if its online or provided by the source or unknown.""" + asyncio.run_coroutine_threadsafe( + self._async_set_state_if_online_or_source(entry, state), self._loop + ).result() + + async def _async_set_state_if_online_or_source( + self, entry: DashboardEntry, state: EntryState + ) -> None: + """Set the state for an entry if its online or provided by the source or unknown.""" + self.async_set_state_if_online_or_source(entry, state) + + def async_set_state_if_online_or_source( + self, entry: DashboardEntry, state: EntryState + ) -> None: + """Set the state for an entry if its online or provided by the source or unknown.""" + if ( + state.reachable is ReachableState.ONLINE + and entry.state.reachable is not ReachableState.ONLINE + ) or entry.state.source in ( + EntryStateSource.UNKNOWN, + state.source, + ): + self.async_set_state(entry, state) + + def set_state_if_source(self, entry: DashboardEntry, state: EntryState) -> None: + """Set the state for an entry if provided by the source or unknown.""" + asyncio.run_coroutine_threadsafe( + self._async_set_state_if_source(entry, state), self._loop + ).result() + + async def _async_set_state_if_source( + self, entry: DashboardEntry, state: EntryState + ) -> None: + """Set the state for an entry if rovided by the source or unknown.""" + self.async_set_state_if_source(entry, state) + + def async_set_state_if_source( + self, entry: DashboardEntry, state: EntryState + ) -> None: + """Set the state for an entry if provided by the source or unknown.""" + if entry.state.source in ( + EntryStateSource.UNKNOWN, + state.source, + ): + self.async_set_state(entry, state) + def async_set_state(self, entry: DashboardEntry, state: EntryState) -> None: """Set the state for an entry.""" if entry.state == state: @@ -269,7 +336,7 @@ class DashboardEntry: self._storage_path = ext_storage_path(self.filename) self.cache_key = cache_key self.storage: StorageJSON | None = None - self.state = EntryState.UNKNOWN + self.state = UNKNOWN_STATE self._to_dict: dict[str, Any] | None = None def __repr__(self) -> str: diff --git a/esphome/dashboard/settings.py b/esphome/dashboard/settings.py index 1f05abab4c..fa39b55016 100644 --- a/esphome/dashboard/settings.py +++ b/esphome/dashboard/settings.py @@ -54,10 +54,6 @@ class DashboardSettings: def relative_url(self) -> str: return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL") or "/" - @property - def status_use_ping(self): - return get_bool_env("ESPHOME_DASHBOARD_USE_PING") - @property def status_use_mqtt(self) -> bool: return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT") diff --git a/esphome/dashboard/status/mdns.py b/esphome/dashboard/status/mdns.py index 9f6399ca8b..f9ac7b4289 100644 --- a/esphome/dashboard/status/mdns.py +++ b/esphome/dashboard/status/mdns.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +import logging +import typing from esphome.zeroconf import ( ESPHOME_SERVICE_TYPE, @@ -11,20 +13,36 @@ from esphome.zeroconf import ( ) from ..const import SENTINEL -from ..core import DASHBOARD -from ..entries import DashboardEntry, bool_to_entry_state +from ..entries import DashboardEntry, EntryStateSource, bool_to_entry_state + +if typing.TYPE_CHECKING: + from ..core import ESPHomeDashboard + +_LOGGER = logging.getLogger(__name__) class MDNSStatus: """Class that updates the mdns status.""" - def __init__(self) -> None: + def __init__(self, dashboard: ESPHomeDashboard) -> None: """Initialize the MDNSStatus class.""" super().__init__() self.aiozc: AsyncEsphomeZeroconf | None = None # This is the current mdns state for each host (True, False, None) self.host_mdns_state: dict[str, bool | None] = {} self._loop = asyncio.get_running_loop() + self.dashboard = dashboard + + def async_setup(self) -> bool: + """Set up the MDNSStatus class.""" + try: + self.aiozc = AsyncEsphomeZeroconf() + except OSError as e: + _LOGGER.warning( + "Failed to initialize zeroconf, will fallback to ping: %s", e + ) + return False + return True async def async_resolve_host(self, host_name: str) -> list[str] | None: """Resolve a host name to an address in a thread-safe manner.""" @@ -32,9 +50,9 @@ class MDNSStatus: return await aiozc.async_resolve_host(host_name) return None - async def async_refresh_hosts(self): + async def async_refresh_hosts(self) -> None: """Refresh the hosts to track.""" - dashboard = DASHBOARD + dashboard = self.dashboard host_mdns_state = self.host_mdns_state entries = dashboard.entries poll_names: dict[str, set[DashboardEntry]] = {} @@ -49,7 +67,7 @@ class MDNSStatus: # the device won't respond to a request to ._esphomelib._tcp.local. poll_names.setdefault(entry.name, set()).add(entry) elif (online := host_mdns_state.get(entry.name, SENTINEL)) != SENTINEL: - entries.async_set_state(entry, bool_to_entry_state(online)) + self._async_set_state(entry, online) if poll_names and self.aiozc: results = await asyncio.gather( *(self.aiozc.async_resolve_host(name) for name in poll_names) @@ -58,13 +76,25 @@ class MDNSStatus: result = bool(address_list) host_mdns_state[name] = result for entry in poll_names[name]: - entries.async_set_state(entry, bool_to_entry_state(result)) + self._async_set_state(entry, result) + + def _async_set_state(self, entry: DashboardEntry, result: bool | None) -> None: + """Set the state of an entry.""" + state = bool_to_entry_state(result, EntryStateSource.MDNS) + if result: + # If we can reach it via mDNS, we always set it online + # since its the fastest source if its working + self.dashboard.entries.async_set_state(entry, state) + else: + # However if we can't reach it via mDNS + # we only set it to offline if the state is unknown + # or from mDNS + self.dashboard.entries.async_set_state_if_source(entry, state) async def async_run(self) -> None: - dashboard = DASHBOARD + """Run the mdns status.""" + dashboard = self.dashboard entries = dashboard.entries - aiozc = AsyncEsphomeZeroconf() - self.aiozc = aiozc host_mdns_state = self.host_mdns_state def on_update(dat: dict[str, bool | None]) -> None: @@ -73,15 +103,14 @@ class MDNSStatus: host_mdns_state[name] = result if matching_entries := entries.get_by_name(name): for entry in matching_entries: - if not entry.no_mdns: - entries.async_set_state(entry, bool_to_entry_state(result)) + self._async_set_state(entry, result) stat = DashboardStatus(on_update) imports = DashboardImportDiscovery() dashboard.import_result = imports.import_state browser = DashboardBrowser( - aiozc.zeroconf, + self.aiozc.zeroconf, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback], ) @@ -93,5 +122,5 @@ class MDNSStatus: ping_request.clear() await browser.async_cancel() - await aiozc.async_close() + await self.aiozc.async_close() self.aiozc = None diff --git a/esphome/dashboard/status/mqtt.py b/esphome/dashboard/status/mqtt.py index 8c35dd2535..70eb0b58b5 100644 --- a/esphome/dashboard/status/mqtt.py +++ b/esphome/dashboard/status/mqtt.py @@ -4,19 +4,27 @@ import binascii import json import os import threading +import typing from esphome import mqtt -from ..core import DASHBOARD -from ..entries import EntryState +from ..entries import EntryStateSource, bool_to_entry_state + +if typing.TYPE_CHECKING: + from ..core import ESPHomeDashboard class MqttStatusThread(threading.Thread): """Status thread to get the status of the devices via MQTT.""" + def __init__(self, dashboard: ESPHomeDashboard) -> None: + """Initialize the status thread.""" + super().__init__() + self.dashboard = dashboard + def run(self) -> None: """Run the status thread.""" - dashboard = DASHBOARD + dashboard = self.dashboard entries = dashboard.entries current_entries = entries.all() @@ -31,10 +39,13 @@ class MqttStatusThread(threading.Thread): data = json.loads(payload) if "name" not in data: return - for entry in current_entries: - if entry.name == data["name"]: - entries.set_state(entry, EntryState.ONLINE) - return + if matching_entries := entries.get_by_name(data["name"]): + for entry in matching_entries: + # Only override state if we don't have a state from another source + # or we have a state from MQTT and the device is reachable + entries.set_state_if_online_or_source( + entry, bool_to_entry_state(True, EntryStateSource.MQTT) + ) def on_connect(client, userdata, flags, return_code): client.publish("esphome/discover", None, retain=False) @@ -56,8 +67,10 @@ class MqttStatusThread(threading.Thread): current_entries = entries.all() # will be set to true on on_message for entry in current_entries: - if entry.no_mdns: - entries.set_state(entry, EntryState.OFFLINE) + # Only override state if we don't have a state from another source + entries.set_state_if_source( + entry, bool_to_entry_state(False, EntryStateSource.MQTT) + ) client.publish("esphome/discover", None, retain=False) dashboard.mqtt_ping_request.wait() diff --git a/esphome/dashboard/status/ping.py b/esphome/dashboard/status/ping.py index 6630f03c9d..b4f106d21a 100644 --- a/esphome/dashboard/status/ping.py +++ b/esphome/dashboard/status/ping.py @@ -3,29 +3,44 @@ from __future__ import annotations import asyncio import logging import time +import typing from typing import cast from icmplib import Host, SocketPermissionError, async_ping from ..const import MAX_EXECUTOR_WORKERS -from ..core import DASHBOARD -from ..entries import DashboardEntry, EntryState, bool_to_entry_state +from ..entries import ( + DashboardEntry, + EntryState, + EntryStateSource, + ReachableState, + bool_to_entry_state, +) from ..util.itertools import chunked +if typing.TYPE_CHECKING: + from ..core import ESPHomeDashboard + + _LOGGER = logging.getLogger(__name__) GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2) +DNS_FAILURE_STATE = EntryState(ReachableState.DNS_FAILURE, EntryStateSource.PING) + +MIN_PING_INTERVAL = 5 # ensure we don't ping too often + class PingStatus: - def __init__(self) -> None: + def __init__(self, dashboard: ESPHomeDashboard) -> None: """Initialize the PingStatus class.""" super().__init__() self._loop = asyncio.get_running_loop() + self.dashboard = dashboard async def async_run(self) -> None: """Run the ping status.""" - dashboard = DASHBOARD + dashboard = self.dashboard entries = dashboard.entries privileged = await _can_use_icmp_lib_with_privilege() if privileged is None: @@ -36,10 +51,24 @@ class PingStatus: # Only ping if the dashboard is open await dashboard.ping_request.wait() dashboard.ping_request.clear() + iteration_start = time.monotonic() current_entries = dashboard.entries.async_all() - to_ping: list[DashboardEntry] = [ - entry for entry in current_entries if entry.address is not None - ] + to_ping: list[DashboardEntry] = [] + + for entry in current_entries: + if entry.address is None: + # No address or we already have a state from another source + # so no need to ping + continue + if ( + entry.state.reachable is ReachableState.ONLINE + and entry.state.source + not in (EntryStateSource.PING, EntryStateSource.UNKNOWN) + ): + # If we already have a state from another source and + # it's online, we don't need to ping + continue + to_ping.append(entry) # Resolve DNS for all entries entries_with_addresses: dict[DashboardEntry, list[str]] = {} @@ -56,7 +85,10 @@ class PingStatus: for entry, result in zip(ping_group, dns_results): if isinstance(result, Exception): - entries.async_set_state(entry, EntryState.UNKNOWN) + # Only update state if its unknown or from ping + # so we don't mark it as offline if we have a state + # from mDNS or MQTT + entries.async_set_state_if_source(entry, DNS_FAILURE_STATE) continue if isinstance(result, BaseException): raise result @@ -82,8 +114,20 @@ class PingStatus: else: host: Host = result ping_result = host.is_alive - entry, _ = entry_addresses - entries.async_set_state(entry, bool_to_entry_state(ping_result)) + entry: DashboardEntry = entry_addresses[0] + # If we can reach it via ping, we always set it + # online, however if we can't reach it via ping + # we only set it to offline if the state is unknown + # or from ping + entries.async_set_state_if_online_or_source( + entry, + bool_to_entry_state(ping_result, EntryStateSource.PING), + ) + + if not dashboard.stop_event.is_set(): + iteration_duration = time.monotonic() - iteration_start + if iteration_duration < MIN_PING_INTERVAL: + await asyncio.sleep(MIN_PING_INTERVAL - iteration_duration) async def _can_use_icmp_lib_with_privilege() -> None | bool: diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index f78f17b093..f7888ce6ed 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -45,7 +45,7 @@ from esphome.yaml_util import FastestAvailableSafeLoader from .const import DASHBOARD_COMMAND from .core import DASHBOARD -from .entries import EntryState, entry_state_to_bool +from .entries import UNKNOWN_STATE, entry_state_to_bool from .util.file import write_file from .util.subprocess import async_run_system_command from .util.text import friendly_name_slugify @@ -381,7 +381,7 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket): # Remove the old ping result from the cache entries = DASHBOARD.entries if entry := entries.get(self.old_name): - entries.async_set_state(entry, EntryState.UNKNOWN) + entries.async_set_state(entry, UNKNOWN_STATE) class EsphomeUploadHandler(EsphomePortCommandWebSocket): diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index 5a92a4ed7c..0e2c431d4b 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -5,7 +5,13 @@ from dataclasses import dataclass import logging from typing import Callable -from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf +from zeroconf import ( + AddressResolver, + IPVersion, + ServiceInfo, + ServiceStateChange, + Zeroconf, +) from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf from esphome.storage_json import StorageJSON, ext_storage_path @@ -16,15 +22,6 @@ _LOGGER = logging.getLogger(__name__) _BACKGROUND_TASKS: set[asyncio.Task] = set() -class HostResolver(ServiceInfo): - """Resolve a host name to an IP address.""" - - @property - def _is_complete(self) -> bool: - """The ServiceInfo has all expected properties.""" - return bool(self._ipv4_addresses) - - class DashboardStatus: def __init__(self, on_update: Callable[[dict[str, bool | None], []]]) -> None: """Initialize the dashboard status.""" @@ -166,19 +163,10 @@ class DashboardImportDiscovery: ) -def _make_host_resolver(host: str) -> HostResolver: - """Create a new HostResolver for the given host name.""" - name = host.partition(".")[0] - info = HostResolver( - ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}", server=f"{name}.local." - ) - return info - - class EsphomeZeroconf(Zeroconf): def resolve_host(self, host: str, timeout: float = 3.0) -> list[str] | None: """Resolve a host name to an IP address.""" - info = _make_host_resolver(host) + info = AddressResolver(f'{host.partition(".")[0]}.local.') if ( info.load_from_cache(self) or (timeout and info.request(self, timeout * 1000)) @@ -192,7 +180,7 @@ class AsyncEsphomeZeroconf(AsyncZeroconf): self, host: str, timeout: float = 3.0 ) -> list[str] | None: """Resolve a host name to an IP address.""" - info = _make_host_resolver(host) + info = AddressResolver(f'{host.partition(".")[0]}.local.') if ( info.load_from_cache(self.zeroconf) or (timeout and await info.async_request(self.zeroconf, timeout * 1000))