1
0
mirror of https://github.com/esphome/esphome.git synced 2025-10-03 18:42:23 +01:00

[dashboard] Replace polling with WebSocket for real-time updates (#10893)

This commit is contained in:
J. Nick Koston
2025-09-30 13:03:52 -05:00
committed by GitHub
parent d75b7708a5
commit c69603d916
11 changed files with 1125 additions and 118 deletions

View File

@@ -1,9 +1,26 @@
from __future__ import annotations from __future__ import annotations
EVENT_ENTRY_ADDED = "entry_added" from esphome.enum import StrEnum
EVENT_ENTRY_REMOVED = "entry_removed"
EVENT_ENTRY_UPDATED = "entry_updated"
EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" class DashboardEvent(StrEnum):
"""Dashboard WebSocket event types."""
# Server -> Client events (backend sends to frontend)
ENTRY_ADDED = "entry_added"
ENTRY_REMOVED = "entry_removed"
ENTRY_UPDATED = "entry_updated"
ENTRY_STATE_CHANGED = "entry_state_changed"
IMPORTABLE_DEVICE_ADDED = "importable_device_added"
IMPORTABLE_DEVICE_REMOVED = "importable_device_removed"
INITIAL_STATE = "initial_state" # Sent on WebSocket connection
PONG = "pong" # Response to client ping
# Client -> Server events (frontend sends to backend)
PING = "ping" # WebSocket keepalive from client
REFRESH = "refresh" # Force backend to poll for changes
MAX_EXECUTOR_WORKERS = 48 MAX_EXECUTOR_WORKERS = 48

View File

@@ -13,6 +13,7 @@ from typing import Any
from esphome.storage_json import ignored_devices_storage_path from esphome.storage_json import ignored_devices_storage_path
from ..zeroconf import DiscoveredImport from ..zeroconf import DiscoveredImport
from .const import DashboardEvent
from .dns import DNSCache from .dns import DNSCache
from .entries import DashboardEntries from .entries import DashboardEntries
from .settings import DashboardSettings from .settings import DashboardSettings
@@ -30,7 +31,7 @@ MDNS_BOOTSTRAP_TIME = 7.5
class Event: class Event:
"""Dashboard Event.""" """Dashboard Event."""
event_type: str event_type: DashboardEvent
data: dict[str, Any] data: dict[str, Any]
@@ -39,22 +40,24 @@ class EventBus:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the Dashboard event bus.""" """Initialize the Dashboard event bus."""
self._listeners: dict[str, set[Callable[[Event], None]]] = {} self._listeners: dict[DashboardEvent, set[Callable[[Event], None]]] = {}
def async_add_listener( def async_add_listener(
self, event_type: str, listener: Callable[[Event], None] self, event_type: DashboardEvent, listener: Callable[[Event], None]
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Add a listener to the event bus.""" """Add a listener to the event bus."""
self._listeners.setdefault(event_type, set()).add(listener) self._listeners.setdefault(event_type, set()).add(listener)
return partial(self._async_remove_listener, event_type, listener) return partial(self._async_remove_listener, event_type, listener)
def _async_remove_listener( def _async_remove_listener(
self, event_type: str, listener: Callable[[Event], None] self, event_type: DashboardEvent, listener: Callable[[Event], None]
) -> None: ) -> None:
"""Remove a listener from the event bus.""" """Remove a listener from the event bus."""
self._listeners[event_type].discard(listener) self._listeners[event_type].discard(listener)
def async_fire(self, event_type: str, event_data: dict[str, Any]) -> None: def async_fire(
self, event_type: DashboardEvent, event_data: dict[str, Any]
) -> None:
"""Fire an event.""" """Fire an event."""
event = Event(event_type, event_data) event = Event(event_type, event_data)

View File

@@ -12,13 +12,7 @@ from esphome import const, util
from esphome.enum import StrEnum from esphome.enum import StrEnum
from esphome.storage_json import StorageJSON, ext_storage_path from esphome.storage_json import StorageJSON, ext_storage_path
from .const import ( from .const import DASHBOARD_COMMAND, DashboardEvent
DASHBOARD_COMMAND,
EVENT_ENTRY_ADDED,
EVENT_ENTRY_REMOVED,
EVENT_ENTRY_STATE_CHANGED,
EVENT_ENTRY_UPDATED,
)
from .util.subprocess import async_run_system_command from .util.subprocess import async_run_system_command
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -102,12 +96,12 @@ class DashboardEntries:
# "path/to/file.yaml": DashboardEntry, # "path/to/file.yaml": DashboardEntry,
# ... # ...
# } # }
self._entries: dict[str, DashboardEntry] = {} self._entries: dict[Path, DashboardEntry] = {}
self._loaded_entries = False self._loaded_entries = False
self._update_lock = asyncio.Lock() self._update_lock = asyncio.Lock()
self._name_to_entry: dict[str, set[DashboardEntry]] = defaultdict(set) self._name_to_entry: dict[str, set[DashboardEntry]] = defaultdict(set)
def get(self, path: str) -> DashboardEntry | None: def get(self, path: Path) -> DashboardEntry | None:
"""Get an entry by path.""" """Get an entry by path."""
return self._entries.get(path) return self._entries.get(path)
@@ -192,7 +186,7 @@ class DashboardEntries:
return return
entry.state = state entry.state = state
self._dashboard.bus.async_fire( self._dashboard.bus.async_fire(
EVENT_ENTRY_STATE_CHANGED, {"entry": entry, "state": state} DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state}
) )
async def async_request_update_entries(self) -> None: async def async_request_update_entries(self) -> None:
@@ -260,22 +254,22 @@ class DashboardEntries:
for entry in added: for entry in added:
entries[entry.path] = entry entries[entry.path] = entry
name_to_entry[entry.name].add(entry) name_to_entry[entry.name].add(entry)
bus.async_fire(EVENT_ENTRY_ADDED, {"entry": entry}) bus.async_fire(DashboardEvent.ENTRY_ADDED, {"entry": entry})
for entry in removed: for entry in removed:
del entries[entry.path] del entries[entry.path]
name_to_entry[entry.name].discard(entry) name_to_entry[entry.name].discard(entry)
bus.async_fire(EVENT_ENTRY_REMOVED, {"entry": entry}) bus.async_fire(DashboardEvent.ENTRY_REMOVED, {"entry": entry})
for entry in updated: for entry in updated:
if (original_name := original_names[entry]) != (current_name := entry.name): if (original_name := original_names[entry]) != (current_name := entry.name):
name_to_entry[original_name].discard(entry) name_to_entry[original_name].discard(entry)
name_to_entry[current_name].add(entry) name_to_entry[current_name].add(entry)
bus.async_fire(EVENT_ENTRY_UPDATED, {"entry": entry}) bus.async_fire(DashboardEvent.ENTRY_UPDATED, {"entry": entry})
def _get_path_to_cache_key(self) -> dict[str, DashboardCacheKeyType]: def _get_path_to_cache_key(self) -> dict[Path, DashboardCacheKeyType]:
"""Return a dict of path to cache key.""" """Return a dict of path to cache key."""
path_to_cache_key: dict[str, DashboardCacheKeyType] = {} path_to_cache_key: dict[Path, DashboardCacheKeyType] = {}
# #
# The cache key is (inode, device, mtime, size) # The cache key is (inode, device, mtime, size)
# which allows us to avoid locking since it ensures # which allows us to avoid locking since it ensures

View File

@@ -0,0 +1,76 @@
"""Data models and builders for the dashboard."""
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from esphome.zeroconf import DiscoveredImport
from .core import ESPHomeDashboard
from .entries import DashboardEntry
class ImportableDeviceDict(TypedDict):
"""Dictionary representation of an importable device."""
name: str
friendly_name: str | None
package_import_url: str
project_name: str
project_version: str
network: str
ignored: bool
class ConfiguredDeviceDict(TypedDict, total=False):
"""Dictionary representation of a configured device."""
name: str
friendly_name: str | None
configuration: str
loaded_integrations: list[str] | None
deployed_version: str | None
current_version: str | None
path: str
comment: str | None
address: str | None
web_port: int | None
target_platform: str | None
class DeviceListResponse(TypedDict):
"""Response for device list API."""
configured: list[ConfiguredDeviceDict]
importable: list[ImportableDeviceDict]
def build_importable_device_dict(
dashboard: ESPHomeDashboard, discovered: DiscoveredImport
) -> ImportableDeviceDict:
"""Build the importable device dictionary."""
return ImportableDeviceDict(
name=discovered.device_name,
friendly_name=discovered.friendly_name,
package_import_url=discovered.package_import_url,
project_name=discovered.project_name,
project_version=discovered.project_version,
network=discovered.network,
ignored=discovered.device_name in dashboard.ignored_devices,
)
def build_device_list_response(
dashboard: ESPHomeDashboard, entries: list[DashboardEntry]
) -> DeviceListResponse:
"""Build the device list response data."""
configured = {entry.name for entry in entries}
return DeviceListResponse(
configured=[entry.to_dict() for entry in entries],
importable=[
build_importable_device_dict(dashboard, res)
for res in dashboard.import_result.values()
if res.device_name not in configured
],
)

View File

@@ -13,10 +13,12 @@ from esphome.zeroconf import (
DashboardBrowser, DashboardBrowser,
DashboardImportDiscovery, DashboardImportDiscovery,
DashboardStatus, DashboardStatus,
DiscoveredImport,
) )
from ..const import SENTINEL from ..const import SENTINEL, DashboardEvent
from ..entries import DashboardEntry, EntryStateSource, bool_to_entry_state from ..entries import DashboardEntry, EntryStateSource, bool_to_entry_state
from ..models import build_importable_device_dict
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from ..core import ESPHomeDashboard from ..core import ESPHomeDashboard
@@ -77,6 +79,20 @@ class MDNSStatus:
_LOGGER.debug("Not found in zeroconf cache: %s", resolver_name) _LOGGER.debug("Not found in zeroconf cache: %s", resolver_name)
return None return None
def _on_import_update(self, name: str, discovered: DiscoveredImport | None) -> None:
"""Handle importable device updates."""
if discovered is None:
# Device removed
self.dashboard.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED, {"name": name}
)
else:
# Device added
self.dashboard.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED,
{"device": build_importable_device_dict(self.dashboard, discovered)},
)
async def async_refresh_hosts(self) -> None: async def async_refresh_hosts(self) -> None:
"""Refresh the hosts to track.""" """Refresh the hosts to track."""
dashboard = self.dashboard dashboard = self.dashboard
@@ -133,7 +149,8 @@ class MDNSStatus:
self._async_set_state(entry, result) self._async_set_state(entry, result)
stat = DashboardStatus(on_update) stat = DashboardStatus(on_update)
imports = DashboardImportDiscovery()
imports = DashboardImportDiscovery(self._on_import_update)
dashboard.import_result = imports.import_state dashboard.import_result = imports.import_state
browser = DashboardBrowser( browser = DashboardBrowser(

View File

@@ -4,8 +4,10 @@ import asyncio
import base64 import base64
import binascii import binascii
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
import contextlib
import datetime import datetime
import functools import functools
from functools import partial
import gzip import gzip
import hashlib import hashlib
import importlib import importlib
@@ -50,9 +52,10 @@ from esphome.util import get_serial_ports, shlex_quote
from esphome.yaml_util import FastestAvailableSafeLoader from esphome.yaml_util import FastestAvailableSafeLoader
from ..helpers import write_file from ..helpers import write_file
from .const import DASHBOARD_COMMAND from .const import DASHBOARD_COMMAND, DashboardEvent
from .core import DASHBOARD, ESPHomeDashboard from .core import DASHBOARD, ESPHomeDashboard, Event
from .entries import UNKNOWN_STATE, DashboardEntry, entry_state_to_bool from .entries import UNKNOWN_STATE, DashboardEntry, entry_state_to_bool
from .models import build_device_list_response
from .util.subprocess import async_run_system_command from .util.subprocess import async_run_system_command
from .util.text import friendly_name_slugify from .util.text import friendly_name_slugify
@@ -520,6 +523,243 @@ class EsphomeUpdateAllHandler(EsphomeCommandWebSocket):
return [*DASHBOARD_COMMAND, "update-all", settings.config_dir] return [*DASHBOARD_COMMAND, "update-all", settings.config_dir]
# Dashboard polling constants
DASHBOARD_POLL_INTERVAL = 2 # seconds
DASHBOARD_ENTRIES_UPDATE_INTERVAL = 10 # seconds
DASHBOARD_ENTRIES_UPDATE_ITERATIONS = (
DASHBOARD_ENTRIES_UPDATE_INTERVAL // DASHBOARD_POLL_INTERVAL
)
class DashboardSubscriber:
"""Manages dashboard event polling task lifecycle based on active subscribers."""
def __init__(self) -> None:
"""Initialize the dashboard subscriber."""
self._subscribers: set[DashboardEventsWebSocket] = set()
self._event_loop_task: asyncio.Task | None = None
self._refresh_event: asyncio.Event = asyncio.Event()
def subscribe(self, subscriber: DashboardEventsWebSocket) -> Callable[[], None]:
"""Subscribe to dashboard updates and start event loop if needed."""
self._subscribers.add(subscriber)
if not self._event_loop_task or self._event_loop_task.done():
self._event_loop_task = asyncio.create_task(self._event_loop())
_LOGGER.info("Started dashboard event loop")
return partial(self._unsubscribe, subscriber)
def _unsubscribe(self, subscriber: DashboardEventsWebSocket) -> None:
"""Unsubscribe from dashboard updates and stop event loop if no subscribers."""
self._subscribers.discard(subscriber)
if (
not self._subscribers
and self._event_loop_task
and not self._event_loop_task.done()
):
self._event_loop_task.cancel()
self._event_loop_task = None
_LOGGER.info("Stopped dashboard event loop - no subscribers")
def request_refresh(self) -> None:
"""Signal the polling loop to refresh immediately."""
self._refresh_event.set()
async def _event_loop(self) -> None:
"""Run the event polling loop while there are subscribers."""
dashboard = DASHBOARD
entries_update_counter = 0
while self._subscribers:
# Signal that we need ping updates (non-blocking)
dashboard.ping_request.set()
if settings.status_use_mqtt:
dashboard.mqtt_ping_request.set()
# Check if it's time to update entries or if refresh was requested
entries_update_counter += 1
if (
entries_update_counter >= DASHBOARD_ENTRIES_UPDATE_ITERATIONS
or self._refresh_event.is_set()
):
entries_update_counter = 0
await dashboard.entries.async_request_update_entries()
# Clear the refresh event if it was set
self._refresh_event.clear()
# Wait for either timeout or refresh event
try:
async with asyncio.timeout(DASHBOARD_POLL_INTERVAL):
await self._refresh_event.wait()
# If we get here, refresh was requested - continue loop immediately
except TimeoutError:
# Normal timeout - continue with regular polling
pass
# Global dashboard subscriber instance
DASHBOARD_SUBSCRIBER = DashboardSubscriber()
@websocket_class
class DashboardEventsWebSocket(tornado.websocket.WebSocketHandler):
"""WebSocket handler for real-time dashboard events."""
_event_listeners: list[Callable[[], None]] | None = None
_dashboard_unsubscribe: Callable[[], None] | None = None
async def get(self, *args: str, **kwargs: str) -> None:
"""Handle WebSocket upgrade request."""
if not is_authenticated(self):
self.set_status(401)
self.finish("Unauthorized")
return
await super().get(*args, **kwargs)
async def open(self, *args: str, **kwargs: str) -> None: # pylint: disable=invalid-overridden-method
"""Handle new WebSocket connection."""
# Ensure messages are sent immediately to avoid
# a 200-500ms delay when nodelay is not set.
self.set_nodelay(True)
# Update entries first
await DASHBOARD.entries.async_request_update_entries()
# Send initial state
self._send_initial_state()
# Subscribe to events
self._subscribe_to_events()
# Subscribe to dashboard updates
self._dashboard_unsubscribe = DASHBOARD_SUBSCRIBER.subscribe(self)
_LOGGER.debug("Dashboard status WebSocket opened")
def _send_initial_state(self) -> None:
"""Send initial device list and ping status."""
entries = DASHBOARD.entries.async_all()
# Send initial state
self._safe_send_message(
{
"event": DashboardEvent.INITIAL_STATE,
"data": {
"devices": build_device_list_response(DASHBOARD, entries),
"ping": {
entry.filename: entry_state_to_bool(entry.state)
for entry in entries
},
},
}
)
def _subscribe_to_events(self) -> None:
"""Subscribe to dashboard events."""
async_add_listener = DASHBOARD.bus.async_add_listener
# Subscribe to all events
self._event_listeners = [
async_add_listener(
DashboardEvent.ENTRY_STATE_CHANGED, self._on_entry_state_changed
),
async_add_listener(
DashboardEvent.ENTRY_ADDED,
self._make_entry_handler(DashboardEvent.ENTRY_ADDED),
),
async_add_listener(
DashboardEvent.ENTRY_REMOVED,
self._make_entry_handler(DashboardEvent.ENTRY_REMOVED),
),
async_add_listener(
DashboardEvent.ENTRY_UPDATED,
self._make_entry_handler(DashboardEvent.ENTRY_UPDATED),
),
async_add_listener(
DashboardEvent.IMPORTABLE_DEVICE_ADDED, self._on_importable_added
),
async_add_listener(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED,
self._on_importable_removed,
),
]
def _on_entry_state_changed(self, event: Event) -> None:
"""Handle entry state change event."""
entry = event.data["entry"]
state = event.data["state"]
self._safe_send_message(
{
"event": DashboardEvent.ENTRY_STATE_CHANGED,
"data": {
"filename": entry.filename,
"name": entry.name,
"state": entry_state_to_bool(state),
},
}
)
def _make_entry_handler(
self, event_type: DashboardEvent
) -> Callable[[Event], None]:
"""Create an entry event handler."""
def handler(event: Event) -> None:
self._safe_send_message(
{"event": event_type, "data": {"device": event.data["entry"].to_dict()}}
)
return handler
def _on_importable_added(self, event: Event) -> None:
"""Handle importable device added event."""
# Don't send if device is already configured
device_name = event.data.get("device", {}).get("name")
if device_name and DASHBOARD.entries.get_by_name(device_name):
return
self._safe_send_message(
{"event": DashboardEvent.IMPORTABLE_DEVICE_ADDED, "data": event.data}
)
def _on_importable_removed(self, event: Event) -> None:
"""Handle importable device removed event."""
self._safe_send_message(
{"event": DashboardEvent.IMPORTABLE_DEVICE_REMOVED, "data": event.data}
)
def _safe_send_message(self, message: dict[str, Any]) -> None:
"""Send a message to the WebSocket client, ignoring closed errors."""
with contextlib.suppress(tornado.websocket.WebSocketClosedError):
self.write_message(json.dumps(message))
def on_message(self, message: str) -> None:
"""Handle incoming WebSocket messages."""
_LOGGER.debug("WebSocket received message: %s", message)
try:
data = json.loads(message)
except json.JSONDecodeError as err:
_LOGGER.debug("Failed to parse WebSocket message: %s", err)
return
event = data.get("event")
_LOGGER.debug("WebSocket message event: %s", event)
if event == DashboardEvent.PING:
# Send pong response for client ping
_LOGGER.debug("Received client ping, sending pong")
self._safe_send_message({"event": DashboardEvent.PONG})
elif event == DashboardEvent.REFRESH:
# Signal the polling loop to refresh immediately
_LOGGER.debug("Received refresh request, signaling polling loop")
DASHBOARD_SUBSCRIBER.request_refresh()
def on_close(self) -> None:
"""Handle WebSocket close."""
# Unsubscribe from dashboard updates
if self._dashboard_unsubscribe:
self._dashboard_unsubscribe()
self._dashboard_unsubscribe = None
# Unsubscribe from events
for remove_listener in self._event_listeners or []:
remove_listener()
_LOGGER.debug("Dashboard status WebSocket closed")
class SerialPortRequestHandler(BaseHandler): class SerialPortRequestHandler(BaseHandler):
@authenticated @authenticated
async def get(self) -> None: async def get(self) -> None:
@@ -874,28 +1114,7 @@ class ListDevicesHandler(BaseHandler):
await dashboard.entries.async_request_update_entries() await dashboard.entries.async_request_update_entries()
entries = dashboard.entries.async_all() entries = dashboard.entries.async_all()
self.set_header("content-type", "application/json") self.set_header("content-type", "application/json")
configured = {entry.name for entry in entries} self.write(json.dumps(build_device_list_response(dashboard, entries)))
self.write(
json.dumps(
{
"configured": [entry.to_dict() for entry in entries],
"importable": [
{
"name": res.device_name,
"friendly_name": res.friendly_name,
"package_import_url": res.package_import_url,
"project_name": res.project_name,
"project_version": res.project_version,
"network": res.network,
"ignored": res.device_name in dashboard.ignored_devices,
}
for res in dashboard.import_result.values()
if res.device_name not in configured
],
}
)
)
class MainRequestHandler(BaseHandler): class MainRequestHandler(BaseHandler):
@@ -1351,6 +1570,7 @@ def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application:
(f"{rel}wizard", WizardRequestHandler), (f"{rel}wizard", WizardRequestHandler),
(f"{rel}static/(.*)", StaticFileHandler, {"path": get_static_path()}), (f"{rel}static/(.*)", StaticFileHandler, {"path": get_static_path()}),
(f"{rel}devices", ListDevicesHandler), (f"{rel}devices", ListDevicesHandler),
(f"{rel}events", DashboardEventsWebSocket),
(f"{rel}import", ImportRequestHandler), (f"{rel}import", ImportRequestHandler),
(f"{rel}secret_keys", SecretKeysRequestHandler), (f"{rel}secret_keys", SecretKeysRequestHandler),
(f"{rel}json-config", JsonConfigRequestHandler), (f"{rel}json-config", JsonConfigRequestHandler),

View File

@@ -68,8 +68,11 @@ class DashboardBrowser(AsyncServiceBrowser):
class DashboardImportDiscovery: class DashboardImportDiscovery:
def __init__(self) -> None: def __init__(
self, on_update: Callable[[str, DiscoveredImport | None], None] | None = None
) -> None:
self.import_state: dict[str, DiscoveredImport] = {} self.import_state: dict[str, DiscoveredImport] = {}
self.on_update = on_update
def browser_callback( def browser_callback(
self, self,
@@ -85,7 +88,9 @@ class DashboardImportDiscovery:
state_change, state_change,
) )
if state_change == ServiceStateChange.Removed: if state_change == ServiceStateChange.Removed:
self.import_state.pop(name, None) removed = self.import_state.pop(name, None)
if removed and self.on_update:
self.on_update(name, None)
return return
if state_change == ServiceStateChange.Updated and name not in self.import_state: if state_change == ServiceStateChange.Updated and name not in self.import_state:
@@ -139,7 +144,7 @@ class DashboardImportDiscovery:
if friendly_name is not None: if friendly_name is not None:
friendly_name = friendly_name.decode() friendly_name = friendly_name.decode()
self.import_state[name] = DiscoveredImport( discovered = DiscoveredImport(
friendly_name=friendly_name, friendly_name=friendly_name,
device_name=node_name, device_name=node_name,
package_import_url=import_url, package_import_url=import_url,
@@ -147,6 +152,10 @@ class DashboardImportDiscovery:
project_version=project_version, project_version=project_version,
network=network, network=network,
) )
is_new = name not in self.import_state
self.import_state[name] = discovered
if is_new and self.on_update:
self.on_update(name, discovered)
def update_device_mdns(self, node_name: str, version: str): def update_device_mdns(self, node_name: str, version: str):
storage_path = ext_storage_path(node_name + ".yaml") storage_path = ext_storage_path(node_name + ".yaml")

View File

@@ -2,20 +2,42 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import Mock from pathlib import Path
from unittest.mock import MagicMock, Mock
import pytest import pytest
import pytest_asyncio
from esphome.dashboard.core import ESPHomeDashboard from esphome.dashboard.core import ESPHomeDashboard
from esphome.dashboard.entries import DashboardEntries
@pytest.fixture @pytest.fixture
def mock_dashboard() -> Mock: def mock_settings(tmp_path: Path) -> MagicMock:
"""Create mock dashboard settings."""
settings = MagicMock()
settings.config_dir = str(tmp_path)
settings.absolute_config_dir = tmp_path
return settings
@pytest.fixture
def mock_dashboard(mock_settings: MagicMock) -> Mock:
"""Create a mock dashboard.""" """Create a mock dashboard."""
dashboard = Mock(spec=ESPHomeDashboard) dashboard = Mock(spec=ESPHomeDashboard)
dashboard.settings = mock_settings
dashboard.entries = Mock() dashboard.entries = Mock()
dashboard.entries.async_all.return_value = [] dashboard.entries.async_all.return_value = []
dashboard.stop_event = Mock() dashboard.stop_event = Mock()
dashboard.stop_event.is_set.return_value = True dashboard.stop_event.is_set.return_value = True
dashboard.ping_request = Mock() dashboard.ping_request = Mock()
dashboard.ignored_devices = set()
dashboard.bus = Mock()
dashboard.bus.async_fire = Mock()
return dashboard return dashboard
@pytest_asyncio.fixture
async def dashboard_entries(mock_dashboard: Mock) -> DashboardEntries:
"""Create a DashboardEntries instance for testing."""
return DashboardEntries(mock_dashboard)

View File

@@ -8,7 +8,9 @@ import pytest
import pytest_asyncio import pytest_asyncio
from zeroconf import AddressResolver, IPVersion from zeroconf import AddressResolver, IPVersion
from esphome.dashboard.const import DashboardEvent
from esphome.dashboard.status.mdns import MDNSStatus from esphome.dashboard.status.mdns import MDNSStatus
from esphome.zeroconf import DiscoveredImport
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -166,3 +168,73 @@ async def test_async_setup_failure(mock_dashboard: Mock) -> None:
result = mdns_status.async_setup() result = mdns_status.async_setup()
assert result is False assert result is False
assert mdns_status.aiozc is None assert mdns_status.aiozc is None
@pytest.mark.asyncio
async def test_on_import_update_device_added(mdns_status: MDNSStatus) -> None:
"""Test _on_import_update when a device is added."""
# Create a DiscoveredImport object
discovered = DiscoveredImport(
device_name="test_device",
friendly_name="Test Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="wifi",
)
# Call _on_import_update with a device
mdns_status._on_import_update("test_device", discovered)
# Should fire IMPORTABLE_DEVICE_ADDED event
mock_dashboard = mdns_status.dashboard
mock_dashboard.bus.async_fire.assert_called_once()
call_args = mock_dashboard.bus.async_fire.call_args
assert call_args[0][0] == DashboardEvent.IMPORTABLE_DEVICE_ADDED
assert "device" in call_args[0][1]
device_data = call_args[0][1]["device"]
assert device_data["name"] == "test_device"
assert device_data["friendly_name"] == "Test Device"
assert device_data["project_name"] == "test_project"
assert device_data["ignored"] is False
@pytest.mark.asyncio
async def test_on_import_update_device_ignored(mdns_status: MDNSStatus) -> None:
"""Test _on_import_update when a device is ignored."""
# Add device to ignored list
mdns_status.dashboard.ignored_devices.add("ignored_device")
# Create a DiscoveredImport object for ignored device
discovered = DiscoveredImport(
device_name="ignored_device",
friendly_name="Ignored Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="ethernet",
)
# Call _on_import_update with an ignored device
mdns_status._on_import_update("ignored_device", discovered)
# Should fire IMPORTABLE_DEVICE_ADDED event with ignored=True
mock_dashboard = mdns_status.dashboard
mock_dashboard.bus.async_fire.assert_called_once()
call_args = mock_dashboard.bus.async_fire.call_args
assert call_args[0][0] == DashboardEvent.IMPORTABLE_DEVICE_ADDED
device_data = call_args[0][1]["device"]
assert device_data["name"] == "ignored_device"
assert device_data["ignored"] is True
@pytest.mark.asyncio
async def test_on_import_update_device_removed(mdns_status: MDNSStatus) -> None:
"""Test _on_import_update when a device is removed."""
# Call _on_import_update with None (device removed)
mdns_status._on_import_update("removed_device", None)
# Should fire IMPORTABLE_DEVICE_REMOVED event
mdns_status.dashboard.bus.async_fire.assert_called_once_with(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED, {"name": "removed_device"}
)

View File

@@ -2,14 +2,15 @@
from __future__ import annotations from __future__ import annotations
import os
from pathlib import Path from pathlib import Path
import tempfile import tempfile
from unittest.mock import MagicMock from unittest.mock import Mock
import pytest import pytest
import pytest_asyncio
from esphome.core import CORE from esphome.core import CORE
from esphome.dashboard.const import DashboardEvent
from esphome.dashboard.entries import DashboardEntries, DashboardEntry from esphome.dashboard.entries import DashboardEntries, DashboardEntry
@@ -27,21 +28,6 @@ def setup_core():
CORE.reset() CORE.reset()
@pytest.fixture
def mock_settings() -> MagicMock:
"""Create mock dashboard settings."""
settings = MagicMock()
settings.config_dir = "/test/config"
settings.absolute_config_dir = Path("/test/config")
return settings
@pytest_asyncio.fixture
async def dashboard_entries(mock_settings: MagicMock) -> DashboardEntries:
"""Create a DashboardEntries instance for testing."""
return DashboardEntries(mock_settings)
def test_dashboard_entry_path_initialization() -> None: def test_dashboard_entry_path_initialization() -> None:
"""Test DashboardEntry initializes with path correctly.""" """Test DashboardEntry initializes with path correctly."""
test_path = Path("/test/config/device.yaml") test_path = Path("/test/config/device.yaml")
@@ -78,15 +64,24 @@ def test_dashboard_entry_path_with_relative_path() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dashboard_entries_get_by_path( async def test_dashboard_entries_get_by_path(
dashboard_entries: DashboardEntries, dashboard_entries: DashboardEntries, tmp_path: Path
) -> None: ) -> None:
"""Test getting entry by path.""" """Test getting entry by path."""
test_path = Path("/test/config/device.yaml") # Create a test file
entry = DashboardEntry(test_path, create_cache_key()) test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
dashboard_entries._entries[str(test_path)] = entry # Update entries to load the file
await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(test_path)) # Verify the entry was loaded
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
entry = all_entries[0]
assert entry.path == test_file
# Also verify get() works with Path
result = dashboard_entries.get(test_file)
assert result == entry assert result == entry
@@ -101,45 +96,54 @@ async def test_dashboard_entries_get_nonexistent_path(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dashboard_entries_path_normalization( async def test_dashboard_entries_path_normalization(
dashboard_entries: DashboardEntries, dashboard_entries: DashboardEntries, tmp_path: Path
) -> None: ) -> None:
"""Test that paths are handled consistently.""" """Test that paths are handled consistently."""
path1 = Path("/test/config/device.yaml") # Create a test file
test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
entry = DashboardEntry(path1, create_cache_key()) # Update entries to load the file
dashboard_entries._entries[str(path1)] = entry await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(path1)) # Get the entry by path
assert result == entry result = dashboard_entries.get(test_file)
assert result is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dashboard_entries_path_with_spaces( async def test_dashboard_entries_path_with_spaces(
dashboard_entries: DashboardEntries, dashboard_entries: DashboardEntries, tmp_path: Path
) -> None: ) -> None:
"""Test handling paths with spaces.""" """Test handling paths with spaces."""
test_path = Path("/test/config/my device.yaml") # Create a test file with spaces in name
entry = DashboardEntry(test_path, create_cache_key()) test_file = tmp_path / "my device.yaml"
test_file.write_text("test config")
dashboard_entries._entries[str(test_path)] = entry # Update entries to load the file
await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(test_path)) # Get the entry by path
assert result == entry result = dashboard_entries.get(test_file)
assert result.path == test_path assert result is not None
assert result.path == test_file
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dashboard_entries_path_with_special_chars( async def test_dashboard_entries_path_with_special_chars(
dashboard_entries: DashboardEntries, dashboard_entries: DashboardEntries, tmp_path: Path
) -> None: ) -> None:
"""Test handling paths with special characters.""" """Test handling paths with special characters."""
test_path = Path("/test/config/device-01_test.yaml") # Create a test file with special characters
entry = DashboardEntry(test_path, create_cache_key()) test_file = tmp_path / "device-01_test.yaml"
test_file.write_text("test config")
dashboard_entries._entries[str(test_path)] = entry # Update entries to load the file
await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(test_path)) # Get the entry by path
assert result == entry result = dashboard_entries.get(test_file)
assert result is not None
def test_dashboard_entries_windows_path() -> None: def test_dashboard_entries_windows_path() -> None:
@@ -154,22 +158,25 @@ def test_dashboard_entries_windows_path() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dashboard_entries_path_to_cache_key_mapping( async def test_dashboard_entries_path_to_cache_key_mapping(
dashboard_entries: DashboardEntries, dashboard_entries: DashboardEntries, tmp_path: Path
) -> None: ) -> None:
"""Test internal entries storage with paths and cache keys.""" """Test internal entries storage with paths and cache keys."""
path1 = Path("/test/config/device1.yaml") # Create test files
path2 = Path("/test/config/device2.yaml") file1 = tmp_path / "device1.yaml"
file2 = tmp_path / "device2.yaml"
file1.write_text("test config 1")
file2.write_text("test config 2")
entry1 = DashboardEntry(path1, create_cache_key()) # Update entries to load the files
entry2 = DashboardEntry(path2, (1, 1, 1.0, 1)) await dashboard_entries.async_update_entries()
dashboard_entries._entries[str(path1)] = entry1 # Get entries and verify they have different cache keys
dashboard_entries._entries[str(path2)] = entry2 entry1 = dashboard_entries.get(file1)
entry2 = dashboard_entries.get(file2)
assert str(path1) in dashboard_entries._entries assert entry1 is not None
assert str(path2) in dashboard_entries._entries assert entry2 is not None
assert dashboard_entries._entries[str(path1)].cache_key == create_cache_key() assert entry1.cache_key != entry2.cache_key
assert dashboard_entries._entries[str(path2)].cache_key == (1, 1, 1.0, 1)
def test_dashboard_entry_path_property() -> None: def test_dashboard_entry_path_property() -> None:
@@ -183,21 +190,99 @@ def test_dashboard_entry_path_property() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dashboard_entries_all_returns_entries_with_paths( async def test_dashboard_entries_all_returns_entries_with_paths(
dashboard_entries: DashboardEntries, dashboard_entries: DashboardEntries, tmp_path: Path
) -> None: ) -> None:
"""Test that all() returns entries with their paths intact.""" """Test that all() returns entries with their paths intact."""
paths = [ # Create test files
Path("/test/config/device1.yaml"), files = [
Path("/test/config/device2.yaml"), tmp_path / "device1.yaml",
Path("/test/config/subfolder/device3.yaml"), tmp_path / "device2.yaml",
tmp_path / "device3.yaml",
] ]
for path in paths: for file in files:
entry = DashboardEntry(path, create_cache_key()) file.write_text("test config")
dashboard_entries._entries[str(path)] = entry
# Update entries to load the files
await dashboard_entries.async_update_entries()
all_entries = dashboard_entries.async_all() all_entries = dashboard_entries.async_all()
assert len(all_entries) == len(paths) assert len(all_entries) == len(files)
retrieved_paths = [entry.path for entry in all_entries] retrieved_paths = [entry.path for entry in all_entries]
assert set(retrieved_paths) == set(paths) assert set(retrieved_paths) == set(files)
@pytest.mark.asyncio
async def test_async_update_entries_removed_path(
dashboard_entries: DashboardEntries, mock_dashboard: Mock, tmp_path: Path
) -> None:
"""Test that removed files trigger ENTRY_REMOVED event."""
# Create a test file
test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
# First update to add the entry
await dashboard_entries.async_update_entries()
# Verify entry was added
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
entry = all_entries[0]
# Delete the file
test_file.unlink()
# Second update to detect removal
await dashboard_entries.async_update_entries()
# Verify entry was removed
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 0
# Verify ENTRY_REMOVED event was fired
mock_dashboard.bus.async_fire.assert_any_call(
DashboardEvent.ENTRY_REMOVED, {"entry": entry}
)
@pytest.mark.asyncio
async def test_async_update_entries_updated_path(
dashboard_entries: DashboardEntries, mock_dashboard: Mock, tmp_path: Path
) -> None:
"""Test that modified files trigger ENTRY_UPDATED event."""
# Create a test file
test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
# First update to add the entry
await dashboard_entries.async_update_entries()
# Verify entry was added
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
entry = all_entries[0]
original_cache_key = entry.cache_key
# Modify the file to change its mtime
test_file.write_text("updated config")
# Explicitly change the mtime to ensure it's different
stat = test_file.stat()
os.utime(test_file, (stat.st_atime, stat.st_mtime + 1))
# Second update to detect modification
await dashboard_entries.async_update_entries()
# Verify entry is still there with updated cache key
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
updated_entry = all_entries[0]
assert updated_entry == entry # Same entry object
assert updated_entry.cache_key != original_cache_key # But cache key updated
# Verify ENTRY_UPDATED event was fired
mock_dashboard.bus.async_fire.assert_any_call(
DashboardEvent.ENTRY_UPDATED, {"entry": entry}
)

View File

@@ -2,11 +2,12 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Generator from collections.abc import Generator
from contextlib import asynccontextmanager
import gzip import gzip
import json import json
import os import os
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@@ -14,9 +15,19 @@ from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse
from tornado.httpserver import HTTPServer from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.testing import bind_unused_port from tornado.testing import bind_unused_port
from tornado.websocket import WebSocketClientConnection, websocket_connect
from esphome.dashboard import web_server from esphome.dashboard import web_server
from esphome.dashboard.const import DashboardEvent
from esphome.dashboard.core import DASHBOARD from esphome.dashboard.core import DASHBOARD
from esphome.dashboard.entries import (
DashboardEntry,
EntryStateSource,
bool_to_entry_state,
)
from esphome.dashboard.models import build_importable_device_dict
from esphome.dashboard.web_server import DashboardSubscriber
from esphome.zeroconf import DiscoveredImport
from .common import get_fixture_path from .common import get_fixture_path
@@ -126,6 +137,33 @@ async def dashboard() -> DashboardTestHelper:
io_loop.close() io_loop.close()
@asynccontextmanager
async def websocket_connection(dashboard: DashboardTestHelper):
"""Async context manager for WebSocket connections."""
url = f"ws://127.0.0.1:{dashboard.port}/events"
ws = await websocket_connect(url)
try:
yield ws
finally:
if ws:
ws.close()
@pytest_asyncio.fixture
async def websocket_client(dashboard: DashboardTestHelper) -> WebSocketClientConnection:
"""Create a WebSocket connection for testing."""
url = f"ws://127.0.0.1:{dashboard.port}/events"
ws = await websocket_connect(url)
# Read and discard initial state message
await ws.read_message()
yield ws
if ws:
ws.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_main_page(dashboard: DashboardTestHelper) -> None: async def test_main_page(dashboard: DashboardTestHelper) -> None:
response = await dashboard.fetch("/") response = await dashboard.fetch("/")
@@ -810,3 +848,457 @@ def test_build_cache_arguments_name_without_address(mock_dashboard: Mock) -> Non
mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with( mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with(
"my-device.local" "my-device.local"
) )
@pytest.mark.asyncio
async def test_websocket_connection_initial_state(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket connection and initial state."""
async with websocket_connection(dashboard) as ws:
# Should receive initial state with configured and importable devices
msg = await ws.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "initial_state"
assert "devices" in data["data"]
assert "configured" in data["data"]["devices"]
assert "importable" in data["data"]["devices"]
# Check configured devices
configured = data["data"]["devices"]["configured"]
assert len(configured) > 0
assert configured[0]["name"] == "pico" # From test fixtures
@pytest.mark.asyncio
async def test_websocket_ping_pong(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket ping/pong mechanism."""
# Send ping
await websocket_client.write_message(json.dumps({"event": "ping"}))
# Should receive pong
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "pong"
@pytest.mark.asyncio
async def test_websocket_invalid_json(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket handling of invalid JSON."""
# Send invalid JSON
await websocket_client.write_message("not valid json {]")
# Send a valid ping to verify connection is still alive
await websocket_client.write_message(json.dumps({"event": "ping"}))
# Should receive pong, confirming the connection wasn't closed by invalid JSON
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "pong"
@pytest.mark.asyncio
async def test_websocket_authentication_required(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket authentication when auth is required."""
with patch(
"esphome.dashboard.web_server.is_authenticated"
) as mock_is_authenticated:
mock_is_authenticated.return_value = False
# Try to connect - should be rejected with 401
url = f"ws://127.0.0.1:{dashboard.port}/events"
with pytest.raises(HTTPClientError) as exc_info:
await websocket_connect(url)
# Should get HTTP 401 Unauthorized
assert exc_info.value.code == 401
@pytest.mark.asyncio
async def test_websocket_authentication_not_required(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket connection when no auth is required."""
with patch(
"esphome.dashboard.web_server.is_authenticated"
) as mock_is_authenticated:
mock_is_authenticated.return_value = True
# Should be able to connect successfully
async with websocket_connection(dashboard) as ws:
msg = await ws.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "initial_state"
@pytest.mark.asyncio
async def test_websocket_entry_state_changed(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket entry state changed event."""
# Simulate entry state change
entry = DASHBOARD.entries.async_all()[0]
state = bool_to_entry_state(True, EntryStateSource.MDNS)
DASHBOARD.bus.async_fire(
DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state}
)
# Should receive state change event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "entry_state_changed"
assert data["data"]["filename"] == entry.filename
assert data["data"]["name"] == entry.name
assert data["data"]["state"] is True
@pytest.mark.asyncio
async def test_websocket_entry_added(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket entry added event."""
# Create a mock entry
mock_entry = Mock(spec=DashboardEntry)
mock_entry.filename = "test.yaml"
mock_entry.name = "test_device"
mock_entry.to_dict.return_value = {
"name": "test_device",
"filename": "test.yaml",
"configuration": "test.yaml",
}
# Simulate entry added
DASHBOARD.bus.async_fire(DashboardEvent.ENTRY_ADDED, {"entry": mock_entry})
# Should receive entry added event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "entry_added"
assert data["data"]["device"]["name"] == "test_device"
assert data["data"]["device"]["filename"] == "test.yaml"
@pytest.mark.asyncio
async def test_websocket_entry_removed(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket entry removed event."""
# Create a mock entry
mock_entry = Mock(spec=DashboardEntry)
mock_entry.filename = "removed.yaml"
mock_entry.name = "removed_device"
mock_entry.to_dict.return_value = {
"name": "removed_device",
"filename": "removed.yaml",
"configuration": "removed.yaml",
}
# Simulate entry removed
DASHBOARD.bus.async_fire(DashboardEvent.ENTRY_REMOVED, {"entry": mock_entry})
# Should receive entry removed event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "entry_removed"
assert data["data"]["device"]["name"] == "removed_device"
assert data["data"]["device"]["filename"] == "removed.yaml"
@pytest.mark.asyncio
async def test_websocket_importable_device_added(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket importable device added event with real DiscoveredImport."""
# Create a real DiscoveredImport object
discovered = DiscoveredImport(
device_name="new_import_device",
friendly_name="New Import Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="wifi",
)
# Directly fire the event as the mDNS system would
device_dict = build_importable_device_dict(DASHBOARD, discovered)
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED, {"device": device_dict}
)
# Should receive importable device added event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "importable_device_added"
assert data["data"]["device"]["name"] == "new_import_device"
assert data["data"]["device"]["friendly_name"] == "New Import Device"
assert data["data"]["device"]["project_name"] == "test_project"
assert data["data"]["device"]["network"] == "wifi"
assert data["data"]["device"]["ignored"] is False
@pytest.mark.asyncio
async def test_websocket_importable_device_added_ignored(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket importable device added event for ignored device."""
# Add device to ignored list
DASHBOARD.ignored_devices.add("ignored_device")
# Create a real DiscoveredImport object
discovered = DiscoveredImport(
device_name="ignored_device",
friendly_name="Ignored Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="ethernet",
)
# Directly fire the event as the mDNS system would
device_dict = build_importable_device_dict(DASHBOARD, discovered)
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED, {"device": device_dict}
)
# Should receive importable device added event with ignored=True
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "importable_device_added"
assert data["data"]["device"]["name"] == "ignored_device"
assert data["data"]["device"]["friendly_name"] == "Ignored Device"
assert data["data"]["device"]["network"] == "ethernet"
assert data["data"]["device"]["ignored"] is True
@pytest.mark.asyncio
async def test_websocket_importable_device_removed(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket importable device removed event."""
# Simulate importable device removed
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED,
{"name": "removed_import_device"},
)
# Should receive importable device removed event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "importable_device_removed"
assert data["data"]["name"] == "removed_import_device"
@pytest.mark.asyncio
async def test_websocket_importable_device_already_configured(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test that importable device event is not sent if device is already configured."""
# Get an existing configured device name
existing_entry = DASHBOARD.entries.async_all()[0]
# Simulate importable device added with same name as configured device
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED,
{
"device": {
"name": existing_entry.name,
"friendly_name": "Should Not Be Sent",
"package_import_url": "https://example.com/package",
"project_name": "test_project",
"project_version": "1.0.0",
"network": "wifi",
}
},
)
# Send a ping to ensure connection is still alive
await websocket_client.write_message(json.dumps({"event": "ping"}))
# Should only receive pong, not the importable device event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "pong"
@pytest.mark.asyncio
async def test_websocket_multiple_connections(dashboard: DashboardTestHelper) -> None:
"""Test multiple WebSocket connections."""
async with (
websocket_connection(dashboard) as ws1,
websocket_connection(dashboard) as ws2,
):
# Both should receive initial state
msg1 = await ws1.read_message()
assert msg1 is not None
data1 = json.loads(msg1)
assert data1["event"] == "initial_state"
msg2 = await ws2.read_message()
assert msg2 is not None
data2 = json.loads(msg2)
assert data2["event"] == "initial_state"
# Fire an event - both should receive it
entry = DASHBOARD.entries.async_all()[0]
state = bool_to_entry_state(False, EntryStateSource.MDNS)
DASHBOARD.bus.async_fire(
DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state}
)
msg1 = await ws1.read_message()
assert msg1 is not None
data1 = json.loads(msg1)
assert data1["event"] == "entry_state_changed"
msg2 = await ws2.read_message()
assert msg2 is not None
data2 = json.loads(msg2)
assert data2["event"] == "entry_state_changed"
@pytest.mark.asyncio
async def test_dashboard_subscriber_lifecycle(dashboard: DashboardTestHelper) -> None:
"""Test DashboardSubscriber lifecycle."""
subscriber = DashboardSubscriber()
# Initially no subscribers
assert len(subscriber._subscribers) == 0
assert subscriber._event_loop_task is None
# Add a subscriber
mock_websocket = Mock()
unsubscribe = subscriber.subscribe(mock_websocket)
# Should have started the event loop task
assert len(subscriber._subscribers) == 1
assert subscriber._event_loop_task is not None
# Unsubscribe
unsubscribe()
# Should have stopped the task
assert len(subscriber._subscribers) == 0
@pytest.mark.asyncio
async def test_dashboard_subscriber_entries_update_interval(
dashboard: DashboardTestHelper,
) -> None:
"""Test DashboardSubscriber entries update interval."""
# Patch the constants to make the test run faster
with (
patch("esphome.dashboard.web_server.DASHBOARD_POLL_INTERVAL", 0.01),
patch("esphome.dashboard.web_server.DASHBOARD_ENTRIES_UPDATE_ITERATIONS", 2),
patch("esphome.dashboard.web_server.settings") as mock_settings,
patch("esphome.dashboard.web_server.DASHBOARD") as mock_dashboard,
):
mock_settings.status_use_mqtt = False
# Mock dashboard dependencies
mock_dashboard.ping_request = Mock()
mock_dashboard.ping_request.set = Mock()
mock_dashboard.entries = Mock()
mock_dashboard.entries.async_request_update_entries = Mock()
subscriber = DashboardSubscriber()
mock_websocket = Mock()
# Subscribe to start the event loop
unsubscribe = subscriber.subscribe(mock_websocket)
# Wait for a few iterations to ensure entries update is called
await asyncio.sleep(0.05) # Should be enough for 2+ iterations
# Unsubscribe to stop the task
unsubscribe()
# Verify entries update was called
assert mock_dashboard.entries.async_request_update_entries.call_count >= 1
# Verify ping request was set multiple times
assert mock_dashboard.ping_request.set.call_count >= 2
@pytest.mark.asyncio
async def test_websocket_refresh_command(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket refresh command triggers dashboard update."""
with patch("esphome.dashboard.web_server.DASHBOARD_SUBSCRIBER") as mock_subscriber:
mock_subscriber.request_refresh = Mock()
# Send refresh command
await websocket_client.write_message(json.dumps({"event": "refresh"}))
# Give it a moment to process
await asyncio.sleep(0.01)
# Verify request_refresh was called
mock_subscriber.request_refresh.assert_called_once()
@pytest.mark.asyncio
async def test_dashboard_subscriber_refresh_event(
dashboard: DashboardTestHelper,
) -> None:
"""Test DashboardSubscriber refresh event triggers immediate update."""
# Patch the constants to make the test run faster
with (
patch(
"esphome.dashboard.web_server.DASHBOARD_POLL_INTERVAL", 1.0
), # Long timeout
patch(
"esphome.dashboard.web_server.DASHBOARD_ENTRIES_UPDATE_ITERATIONS", 100
), # Won't reach naturally
patch("esphome.dashboard.web_server.settings") as mock_settings,
patch("esphome.dashboard.web_server.DASHBOARD") as mock_dashboard,
):
mock_settings.status_use_mqtt = False
# Mock dashboard dependencies
mock_dashboard.ping_request = Mock()
mock_dashboard.ping_request.set = Mock()
mock_dashboard.entries = Mock()
mock_dashboard.entries.async_request_update_entries = AsyncMock()
subscriber = DashboardSubscriber()
mock_websocket = Mock()
# Subscribe to start the event loop
unsubscribe = subscriber.subscribe(mock_websocket)
# Wait a bit to ensure loop is running
await asyncio.sleep(0.01)
# Verify entries update hasn't been called yet (iterations not reached)
assert mock_dashboard.entries.async_request_update_entries.call_count == 0
# Request refresh
subscriber.request_refresh()
# Wait for the refresh to be processed
await asyncio.sleep(0.01)
# Now entries update should have been called
assert mock_dashboard.entries.async_request_update_entries.call_count == 1
# Unsubscribe to stop the task
unsubscribe()
# Give it a moment to clean up
await asyncio.sleep(0.01)