1
0
mirror of https://github.com/esphome/esphome.git synced 2025-02-27 23:38:17 +00:00

dashboard: Implement automatic ping fallback (#8263)

This commit is contained in:
J. Nick Koston 2025-02-27 15:17:07 +00:00 committed by GitHub
parent 63a7234767
commit 3048f303c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 254 additions and 106 deletions

View File

@ -23,10 +23,6 @@ if bashio::config.true 'streamer_mode'; then
export ESPHOME_STREAMER_MODE=true
fi
if bashio::config.true 'status_use_ping'; then
export ESPHOME_DASHBOARD_USE_PING=true
fi
if bashio::config.has_value 'relative_url'; then
export ESPHOME_DASHBOARD_RELATIVE_URL=$(bashio::config 'relative_url')
fi

View File

@ -9,7 +9,7 @@ import json
import logging
from pathlib import Path
import threading
from typing import TYPE_CHECKING, Any, Callable
from typing import Any, Callable
from esphome.storage_json import ignored_devices_storage_path
@ -17,15 +17,15 @@ from ..zeroconf import DiscoveredImport
from .dns import DNSCache
from .entries import DashboardEntries
from .settings import DashboardSettings
if TYPE_CHECKING:
from .status.mdns import MDNSStatus
from .status.mdns import MDNSStatus
from .status.ping import PingStatus
_LOGGER = logging.getLogger(__name__)
IGNORED_DEVICES_STORAGE_PATH = "ignored-devices.json"
MDNS_BOOTSTRAP_TIME = 7.5
@dataclass
class Event:
@ -81,6 +81,7 @@ class ESPHomeDashboard:
"dns_cache",
"_background_tasks",
"ignored_devices",
"_ping_status_task",
)
def __init__(self) -> None:
@ -97,6 +98,7 @@ class ESPHomeDashboard:
self.dns_cache = DNSCache()
self._background_tasks: set[asyncio.Task] = set()
self.ignored_devices: set[str] = set()
self._ping_status_task: asyncio.Task | None = None
async def async_setup(self) -> None:
"""Setup the dashboard."""
@ -121,41 +123,48 @@ class ESPHomeDashboard:
{"ignored_devices": sorted(self.ignored_devices)}, indent=2, fp=f_handle
)
def _async_start_ping_status(self, ping_status: PingStatus) -> None:
self._ping_status_task = asyncio.create_task(ping_status.async_run())
async def async_run(self) -> None:
"""Run the dashboard."""
settings = self.settings
mdns_task: asyncio.Task | None = None
ping_status_task: asyncio.Task | None = None
await self.entries.async_update_entries()
if settings.status_use_ping:
from .status.ping import PingStatus
mdns_status = MDNSStatus(self)
ping_status = PingStatus(self)
start_ping_timer: asyncio.TimerHandle | None = None
ping_status = PingStatus()
ping_status_task = asyncio.create_task(ping_status.async_run())
else:
from .status.mdns import MDNSStatus
mdns_status = MDNSStatus()
await mdns_status.async_refresh_hosts()
self.mdns_status = mdns_status
self.mdns_status = mdns_status
if mdns_status.async_setup():
mdns_task = asyncio.create_task(mdns_status.async_run())
# Start ping MDNS_BOOTSTRAP_TIME seconds after startup to ensure
# MDNS has had a chance to resolve the devices
start_ping_timer = self.loop.call_later(
MDNS_BOOTSTRAP_TIME, self._async_start_ping_status, ping_status
)
else:
# If mDNS is not available, start the ping status immediately
self._async_start_ping_status(ping_status)
if settings.status_use_mqtt:
from .status.mqtt import MqttStatusThread
status_thread_mqtt = MqttStatusThread()
status_thread_mqtt = MqttStatusThread(self)
status_thread_mqtt.start()
shutdown_event = asyncio.Event()
try:
await shutdown_event.wait()
await asyncio.Event().wait()
finally:
_LOGGER.info("Shutting down...")
self.stop_event.set()
self.ping_request.set()
if ping_status_task:
ping_status_task.cancel()
if start_ping_timer:
start_ping_timer.cancel()
if self._ping_status_task:
self._ping_status_task.cancel()
self._ping_status_task = None
if mdns_task:
mdns_task.cancel()
if settings.status_use_mqtt:

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
from contextlib import suppress
from ipaddress import ip_address
import sys
from icmplib import NameLookupError, async_resolve
@ -10,11 +12,15 @@ if sys.version_info >= (3, 11):
else:
from async_timeout import timeout as async_timeout
RESOLVE_TIMEOUT = 3.0
async def _async_resolve_wrapper(hostname: str) -> list[str] | Exception:
"""Wrap the icmplib async_resolve function."""
with suppress(ValueError):
return [str(ip_address(hostname))]
try:
async with async_timeout(2):
async with async_timeout(RESOLVE_TIMEOUT):
return await async_resolve(hostname)
except (asyncio.TimeoutError, NameLookupError, UnicodeError) as ex:
return ex

View File

@ -2,6 +2,8 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
import logging
import os
from typing import TYPE_CHECKING, Any
@ -27,37 +29,53 @@ _LOGGER = logging.getLogger(__name__)
DashboardCacheKeyType = tuple[int, int, float, int]
# Currently EntryState is a simple
# online/offline/unknown enum, but in the future
# it may be expanded to include more states
@dataclass(frozen=True)
class EntryState:
"""Represents the state of an entry."""
reachable: ReachableState
source: EntryStateSource
class EntryState(StrEnum):
ONLINE = "online"
OFFLINE = "offline"
class EntryStateSource(StrEnum):
MDNS = "mdns"
PING = "ping"
MQTT = "mqtt"
UNKNOWN = "unknown"
_BOOL_TO_ENTRY_STATE = {
True: EntryState.ONLINE,
False: EntryState.OFFLINE,
None: EntryState.UNKNOWN,
}
_ENTRY_STATE_TO_BOOL = {
EntryState.ONLINE: True,
EntryState.OFFLINE: False,
EntryState.UNKNOWN: None,
}
class ReachableState(StrEnum):
ONLINE = "online"
OFFLINE = "offline"
DNS_FAILURE = "dns_failure"
UNKNOWN = "unknown"
def bool_to_entry_state(value: bool) -> EntryState:
_BOOL_TO_REACHABLE_STATE = {
True: ReachableState.ONLINE,
False: ReachableState.OFFLINE,
None: ReachableState.UNKNOWN,
}
_REACHABLE_STATE_TO_BOOL = {
ReachableState.ONLINE: True,
ReachableState.OFFLINE: False,
ReachableState.DNS_FAILURE: False,
ReachableState.UNKNOWN: None,
}
UNKNOWN_STATE = EntryState(ReachableState.UNKNOWN, EntryStateSource.UNKNOWN)
@lru_cache # creating frozen dataclass instances is expensive, so we cache them
def bool_to_entry_state(value: bool | None, source: EntryStateSource) -> EntryState:
"""Convert a bool to an entry state."""
return _BOOL_TO_ENTRY_STATE[value]
return EntryState(_BOOL_TO_REACHABLE_STATE[value], source)
def entry_state_to_bool(value: EntryState) -> bool | None:
"""Convert an entry state to a bool."""
return _ENTRY_STATE_TO_BOOL[value]
return _REACHABLE_STATE_TO_BOOL[value.reachable]
class DashboardEntries:
@ -119,6 +137,55 @@ class DashboardEntries:
"""Set the state for an entry."""
self.async_set_state(entry, state)
def set_state_if_online_or_source(
self, entry: DashboardEntry, state: EntryState
) -> None:
"""Set the state for an entry if its online or provided by the source or unknown."""
asyncio.run_coroutine_threadsafe(
self._async_set_state_if_online_or_source(entry, state), self._loop
).result()
async def _async_set_state_if_online_or_source(
self, entry: DashboardEntry, state: EntryState
) -> None:
"""Set the state for an entry if its online or provided by the source or unknown."""
self.async_set_state_if_online_or_source(entry, state)
def async_set_state_if_online_or_source(
self, entry: DashboardEntry, state: EntryState
) -> None:
"""Set the state for an entry if its online or provided by the source or unknown."""
if (
state.reachable is ReachableState.ONLINE
and entry.state.reachable is not ReachableState.ONLINE
) or entry.state.source in (
EntryStateSource.UNKNOWN,
state.source,
):
self.async_set_state(entry, state)
def set_state_if_source(self, entry: DashboardEntry, state: EntryState) -> None:
"""Set the state for an entry if provided by the source or unknown."""
asyncio.run_coroutine_threadsafe(
self._async_set_state_if_source(entry, state), self._loop
).result()
async def _async_set_state_if_source(
self, entry: DashboardEntry, state: EntryState
) -> None:
"""Set the state for an entry if rovided by the source or unknown."""
self.async_set_state_if_source(entry, state)
def async_set_state_if_source(
self, entry: DashboardEntry, state: EntryState
) -> None:
"""Set the state for an entry if provided by the source or unknown."""
if entry.state.source in (
EntryStateSource.UNKNOWN,
state.source,
):
self.async_set_state(entry, state)
def async_set_state(self, entry: DashboardEntry, state: EntryState) -> None:
"""Set the state for an entry."""
if entry.state == state:
@ -269,7 +336,7 @@ class DashboardEntry:
self._storage_path = ext_storage_path(self.filename)
self.cache_key = cache_key
self.storage: StorageJSON | None = None
self.state = EntryState.UNKNOWN
self.state = UNKNOWN_STATE
self._to_dict: dict[str, Any] | None = None
def __repr__(self) -> str:

View File

@ -54,10 +54,6 @@ class DashboardSettings:
def relative_url(self) -> str:
return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL") or "/"
@property
def status_use_ping(self):
return get_bool_env("ESPHOME_DASHBOARD_USE_PING")
@property
def status_use_mqtt(self) -> bool:
return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT")

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import logging
import typing
from esphome.zeroconf import (
ESPHOME_SERVICE_TYPE,
@ -11,20 +13,36 @@ from esphome.zeroconf import (
)
from ..const import SENTINEL
from ..core import DASHBOARD
from ..entries import DashboardEntry, bool_to_entry_state
from ..entries import DashboardEntry, EntryStateSource, bool_to_entry_state
if typing.TYPE_CHECKING:
from ..core import ESPHomeDashboard
_LOGGER = logging.getLogger(__name__)
class MDNSStatus:
"""Class that updates the mdns status."""
def __init__(self) -> None:
def __init__(self, dashboard: ESPHomeDashboard) -> None:
"""Initialize the MDNSStatus class."""
super().__init__()
self.aiozc: AsyncEsphomeZeroconf | None = None
# This is the current mdns state for each host (True, False, None)
self.host_mdns_state: dict[str, bool | None] = {}
self._loop = asyncio.get_running_loop()
self.dashboard = dashboard
def async_setup(self) -> bool:
"""Set up the MDNSStatus class."""
try:
self.aiozc = AsyncEsphomeZeroconf()
except OSError as e:
_LOGGER.warning(
"Failed to initialize zeroconf, will fallback to ping: %s", e
)
return False
return True
async def async_resolve_host(self, host_name: str) -> list[str] | None:
"""Resolve a host name to an address in a thread-safe manner."""
@ -32,9 +50,9 @@ class MDNSStatus:
return await aiozc.async_resolve_host(host_name)
return None
async def async_refresh_hosts(self):
async def async_refresh_hosts(self) -> None:
"""Refresh the hosts to track."""
dashboard = DASHBOARD
dashboard = self.dashboard
host_mdns_state = self.host_mdns_state
entries = dashboard.entries
poll_names: dict[str, set[DashboardEntry]] = {}
@ -49,7 +67,7 @@ class MDNSStatus:
# the device won't respond to a request to ._esphomelib._tcp.local.
poll_names.setdefault(entry.name, set()).add(entry)
elif (online := host_mdns_state.get(entry.name, SENTINEL)) != SENTINEL:
entries.async_set_state(entry, bool_to_entry_state(online))
self._async_set_state(entry, online)
if poll_names and self.aiozc:
results = await asyncio.gather(
*(self.aiozc.async_resolve_host(name) for name in poll_names)
@ -58,13 +76,25 @@ class MDNSStatus:
result = bool(address_list)
host_mdns_state[name] = result
for entry in poll_names[name]:
entries.async_set_state(entry, bool_to_entry_state(result))
self._async_set_state(entry, result)
def _async_set_state(self, entry: DashboardEntry, result: bool | None) -> None:
"""Set the state of an entry."""
state = bool_to_entry_state(result, EntryStateSource.MDNS)
if result:
# If we can reach it via mDNS, we always set it online
# since its the fastest source if its working
self.dashboard.entries.async_set_state(entry, state)
else:
# However if we can't reach it via mDNS
# we only set it to offline if the state is unknown
# or from mDNS
self.dashboard.entries.async_set_state_if_source(entry, state)
async def async_run(self) -> None:
dashboard = DASHBOARD
"""Run the mdns status."""
dashboard = self.dashboard
entries = dashboard.entries
aiozc = AsyncEsphomeZeroconf()
self.aiozc = aiozc
host_mdns_state = self.host_mdns_state
def on_update(dat: dict[str, bool | None]) -> None:
@ -73,15 +103,14 @@ class MDNSStatus:
host_mdns_state[name] = result
if matching_entries := entries.get_by_name(name):
for entry in matching_entries:
if not entry.no_mdns:
entries.async_set_state(entry, bool_to_entry_state(result))
self._async_set_state(entry, result)
stat = DashboardStatus(on_update)
imports = DashboardImportDiscovery()
dashboard.import_result = imports.import_state
browser = DashboardBrowser(
aiozc.zeroconf,
self.aiozc.zeroconf,
ESPHOME_SERVICE_TYPE,
[stat.browser_callback, imports.browser_callback],
)
@ -93,5 +122,5 @@ class MDNSStatus:
ping_request.clear()
await browser.async_cancel()
await aiozc.async_close()
await self.aiozc.async_close()
self.aiozc = None

View File

@ -4,19 +4,27 @@ import binascii
import json
import os
import threading
import typing
from esphome import mqtt
from ..core import DASHBOARD
from ..entries import EntryState
from ..entries import EntryStateSource, bool_to_entry_state
if typing.TYPE_CHECKING:
from ..core import ESPHomeDashboard
class MqttStatusThread(threading.Thread):
"""Status thread to get the status of the devices via MQTT."""
def __init__(self, dashboard: ESPHomeDashboard) -> None:
"""Initialize the status thread."""
super().__init__()
self.dashboard = dashboard
def run(self) -> None:
"""Run the status thread."""
dashboard = DASHBOARD
dashboard = self.dashboard
entries = dashboard.entries
current_entries = entries.all()
@ -31,10 +39,13 @@ class MqttStatusThread(threading.Thread):
data = json.loads(payload)
if "name" not in data:
return
for entry in current_entries:
if entry.name == data["name"]:
entries.set_state(entry, EntryState.ONLINE)
return
if matching_entries := entries.get_by_name(data["name"]):
for entry in matching_entries:
# Only override state if we don't have a state from another source
# or we have a state from MQTT and the device is reachable
entries.set_state_if_online_or_source(
entry, bool_to_entry_state(True, EntryStateSource.MQTT)
)
def on_connect(client, userdata, flags, return_code):
client.publish("esphome/discover", None, retain=False)
@ -56,8 +67,10 @@ class MqttStatusThread(threading.Thread):
current_entries = entries.all()
# will be set to true on on_message
for entry in current_entries:
if entry.no_mdns:
entries.set_state(entry, EntryState.OFFLINE)
# Only override state if we don't have a state from another source
entries.set_state_if_source(
entry, bool_to_entry_state(False, EntryStateSource.MQTT)
)
client.publish("esphome/discover", None, retain=False)
dashboard.mqtt_ping_request.wait()

View File

@ -3,29 +3,44 @@ from __future__ import annotations
import asyncio
import logging
import time
import typing
from typing import cast
from icmplib import Host, SocketPermissionError, async_ping
from ..const import MAX_EXECUTOR_WORKERS
from ..core import DASHBOARD
from ..entries import DashboardEntry, EntryState, bool_to_entry_state
from ..entries import (
DashboardEntry,
EntryState,
EntryStateSource,
ReachableState,
bool_to_entry_state,
)
from ..util.itertools import chunked
if typing.TYPE_CHECKING:
from ..core import ESPHomeDashboard
_LOGGER = logging.getLogger(__name__)
GROUP_SIZE = int(MAX_EXECUTOR_WORKERS / 2)
DNS_FAILURE_STATE = EntryState(ReachableState.DNS_FAILURE, EntryStateSource.PING)
MIN_PING_INTERVAL = 5 # ensure we don't ping too often
class PingStatus:
def __init__(self) -> None:
def __init__(self, dashboard: ESPHomeDashboard) -> None:
"""Initialize the PingStatus class."""
super().__init__()
self._loop = asyncio.get_running_loop()
self.dashboard = dashboard
async def async_run(self) -> None:
"""Run the ping status."""
dashboard = DASHBOARD
dashboard = self.dashboard
entries = dashboard.entries
privileged = await _can_use_icmp_lib_with_privilege()
if privileged is None:
@ -36,10 +51,24 @@ class PingStatus:
# Only ping if the dashboard is open
await dashboard.ping_request.wait()
dashboard.ping_request.clear()
iteration_start = time.monotonic()
current_entries = dashboard.entries.async_all()
to_ping: list[DashboardEntry] = [
entry for entry in current_entries if entry.address is not None
]
to_ping: list[DashboardEntry] = []
for entry in current_entries:
if entry.address is None:
# No address or we already have a state from another source
# so no need to ping
continue
if (
entry.state.reachable is ReachableState.ONLINE
and entry.state.source
not in (EntryStateSource.PING, EntryStateSource.UNKNOWN)
):
# If we already have a state from another source and
# it's online, we don't need to ping
continue
to_ping.append(entry)
# Resolve DNS for all entries
entries_with_addresses: dict[DashboardEntry, list[str]] = {}
@ -56,7 +85,10 @@ class PingStatus:
for entry, result in zip(ping_group, dns_results):
if isinstance(result, Exception):
entries.async_set_state(entry, EntryState.UNKNOWN)
# Only update state if its unknown or from ping
# so we don't mark it as offline if we have a state
# from mDNS or MQTT
entries.async_set_state_if_source(entry, DNS_FAILURE_STATE)
continue
if isinstance(result, BaseException):
raise result
@ -82,8 +114,20 @@ class PingStatus:
else:
host: Host = result
ping_result = host.is_alive
entry, _ = entry_addresses
entries.async_set_state(entry, bool_to_entry_state(ping_result))
entry: DashboardEntry = entry_addresses[0]
# If we can reach it via ping, we always set it
# online, however if we can't reach it via ping
# we only set it to offline if the state is unknown
# or from ping
entries.async_set_state_if_online_or_source(
entry,
bool_to_entry_state(ping_result, EntryStateSource.PING),
)
if not dashboard.stop_event.is_set():
iteration_duration = time.monotonic() - iteration_start
if iteration_duration < MIN_PING_INTERVAL:
await asyncio.sleep(MIN_PING_INTERVAL - iteration_duration)
async def _can_use_icmp_lib_with_privilege() -> None | bool:

View File

@ -45,7 +45,7 @@ from esphome.yaml_util import FastestAvailableSafeLoader
from .const import DASHBOARD_COMMAND
from .core import DASHBOARD
from .entries import EntryState, entry_state_to_bool
from .entries import UNKNOWN_STATE, entry_state_to_bool
from .util.file import write_file
from .util.subprocess import async_run_system_command
from .util.text import friendly_name_slugify
@ -381,7 +381,7 @@ class EsphomeRenameHandler(EsphomeCommandWebSocket):
# Remove the old ping result from the cache
entries = DASHBOARD.entries
if entry := entries.get(self.old_name):
entries.async_set_state(entry, EntryState.UNKNOWN)
entries.async_set_state(entry, UNKNOWN_STATE)
class EsphomeUploadHandler(EsphomePortCommandWebSocket):

View File

@ -5,7 +5,13 @@ from dataclasses import dataclass
import logging
from typing import Callable
from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf
from zeroconf import (
AddressResolver,
IPVersion,
ServiceInfo,
ServiceStateChange,
Zeroconf,
)
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
from esphome.storage_json import StorageJSON, ext_storage_path
@ -16,15 +22,6 @@ _LOGGER = logging.getLogger(__name__)
_BACKGROUND_TASKS: set[asyncio.Task] = set()
class HostResolver(ServiceInfo):
"""Resolve a host name to an IP address."""
@property
def _is_complete(self) -> bool:
"""The ServiceInfo has all expected properties."""
return bool(self._ipv4_addresses)
class DashboardStatus:
def __init__(self, on_update: Callable[[dict[str, bool | None], []]]) -> None:
"""Initialize the dashboard status."""
@ -166,19 +163,10 @@ class DashboardImportDiscovery:
)
def _make_host_resolver(host: str) -> HostResolver:
"""Create a new HostResolver for the given host name."""
name = host.partition(".")[0]
info = HostResolver(
ESPHOME_SERVICE_TYPE, f"{name}.{ESPHOME_SERVICE_TYPE}", server=f"{name}.local."
)
return info
class EsphomeZeroconf(Zeroconf):
def resolve_host(self, host: str, timeout: float = 3.0) -> list[str] | None:
"""Resolve a host name to an IP address."""
info = _make_host_resolver(host)
info = AddressResolver(f'{host.partition(".")[0]}.local.')
if (
info.load_from_cache(self)
or (timeout and info.request(self, timeout * 1000))
@ -192,7 +180,7 @@ class AsyncEsphomeZeroconf(AsyncZeroconf):
self, host: str, timeout: float = 3.0
) -> list[str] | None:
"""Resolve a host name to an IP address."""
info = _make_host_resolver(host)
info = AddressResolver(f'{host.partition(".")[0]}.local.')
if (
info.load_from_cache(self.zeroconf)
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))