mirror of
https://github.com/esphome/esphome.git
synced 2025-09-02 19:32:19 +01:00
Refactor dashboard zeroconf support (#5681)
This commit is contained in:
@@ -1,130 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Callable
|
||||
|
||||
from zeroconf import (
|
||||
DNSAddress,
|
||||
DNSOutgoing,
|
||||
DNSQuestion,
|
||||
RecordUpdate,
|
||||
RecordUpdateListener,
|
||||
IPVersion,
|
||||
ServiceBrowser,
|
||||
ServiceInfo,
|
||||
ServiceStateChange,
|
||||
Zeroconf,
|
||||
current_time_millis,
|
||||
)
|
||||
|
||||
from esphome.storage_json import StorageJSON, ext_storage_path
|
||||
|
||||
_CLASS_IN = 1
|
||||
_FLAGS_QR_QUERY = 0x0000 # query
|
||||
_TYPE_A = 1
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HostResolver(RecordUpdateListener):
|
||||
class HostResolver(ServiceInfo):
|
||||
"""Resolve a host name to an IP address."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.address: Optional[bytes] = None
|
||||
|
||||
def async_update_records(
|
||||
self, zc: Zeroconf, now: float, records: list[RecordUpdate]
|
||||
) -> None:
|
||||
"""Update multiple records in one shot.
|
||||
|
||||
This will run in zeroconf's event loop thread so it
|
||||
must be thread-safe.
|
||||
"""
|
||||
for record_update in records:
|
||||
record, _ = record_update
|
||||
if record is None:
|
||||
continue
|
||||
if record.type == _TYPE_A:
|
||||
assert isinstance(record, DNSAddress)
|
||||
if record.name == self.name:
|
||||
self.address = record.address
|
||||
|
||||
def request(self, zc: Zeroconf, timeout: float) -> bool:
|
||||
now = time.time()
|
||||
delay = 0.2
|
||||
next_ = now + delay
|
||||
last = now + timeout
|
||||
|
||||
try:
|
||||
zc.add_listener(self, None)
|
||||
while self.address is None:
|
||||
if last <= now:
|
||||
# Timeout
|
||||
return False
|
||||
if next_ <= now:
|
||||
out = DNSOutgoing(_FLAGS_QR_QUERY)
|
||||
out.add_question(DNSQuestion(self.name, _TYPE_A, _CLASS_IN))
|
||||
zc.send(out)
|
||||
next_ = now + delay
|
||||
delay *= 2
|
||||
|
||||
time.sleep(min(next_, last) - now)
|
||||
now = time.time()
|
||||
finally:
|
||||
zc.remove_listener(self)
|
||||
|
||||
return True
|
||||
@property
|
||||
def _is_complete(self) -> bool:
|
||||
"""The ServiceInfo has all expected properties."""
|
||||
return bool(self._ipv4_addresses)
|
||||
|
||||
|
||||
class DashboardStatus(threading.Thread):
|
||||
PING_AFTER = 15 * 1000 # Send new mDNS request after 15 seconds
|
||||
OFFLINE_AFTER = PING_AFTER * 2 # Offline if no mDNS response after 30 seconds
|
||||
|
||||
def __init__(self, zc: Zeroconf, on_update) -> None:
|
||||
threading.Thread.__init__(self)
|
||||
self.zc = zc
|
||||
self.query_hosts: set[str] = set()
|
||||
self.key_to_host: dict[str, str] = {}
|
||||
self.stop_event = threading.Event()
|
||||
self.query_event = threading.Event()
|
||||
class DashboardStatus:
|
||||
def __init__(self, on_update: Callable[[dict[str, bool | None], []]]) -> None:
|
||||
"""Initialize the dashboard status."""
|
||||
self.on_update = on_update
|
||||
|
||||
def request_query(self, hosts: dict[str, str]) -> None:
|
||||
self.query_hosts = set(hosts.values())
|
||||
self.key_to_host = hosts
|
||||
self.query_event.set()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.stop_event.set()
|
||||
self.query_event.set()
|
||||
|
||||
def host_status(self, key: str) -> bool:
|
||||
entries = self.zc.cache.entries_with_name(key)
|
||||
if not entries:
|
||||
return False
|
||||
now = current_time_millis()
|
||||
|
||||
return any(
|
||||
(entry.created + DashboardStatus.OFFLINE_AFTER) >= now for entry in entries
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
while not self.stop_event.is_set():
|
||||
self.on_update(
|
||||
{key: self.host_status(host) for key, host in self.key_to_host.items()}
|
||||
)
|
||||
now = current_time_millis()
|
||||
for host in self.query_hosts:
|
||||
entries = self.zc.cache.entries_with_name(host)
|
||||
if not entries or all(
|
||||
(entry.created + DashboardStatus.PING_AFTER) <= now
|
||||
for entry in entries
|
||||
):
|
||||
out = DNSOutgoing(_FLAGS_QR_QUERY)
|
||||
out.add_question(DNSQuestion(host, _TYPE_A, _CLASS_IN))
|
||||
self.zc.send(out)
|
||||
self.query_event.wait()
|
||||
self.query_event.clear()
|
||||
def browser_callback(
|
||||
self,
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
name: str,
|
||||
state_change: ServiceStateChange,
|
||||
) -> None:
|
||||
"""Handle a service update."""
|
||||
short_name = name.partition(".")[0]
|
||||
if state_change == ServiceStateChange.Removed:
|
||||
self.on_update({short_name: False})
|
||||
elif state_change in (ServiceStateChange.Updated, ServiceStateChange.Added):
|
||||
self.on_update({short_name: True})
|
||||
|
||||
|
||||
ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local."
|
||||
@@ -138,7 +57,7 @@ TXT_RECORD_VERSION = b"version"
|
||||
|
||||
@dataclass
|
||||
class DiscoveredImport:
|
||||
friendly_name: Optional[str]
|
||||
friendly_name: str | None
|
||||
device_name: str
|
||||
package_import_url: str
|
||||
project_name: str
|
||||
@@ -146,15 +65,15 @@ class DiscoveredImport:
|
||||
network: str
|
||||
|
||||
|
||||
class DashboardBrowser(ServiceBrowser):
|
||||
"""A class to browse for ESPHome nodes."""
|
||||
|
||||
|
||||
class DashboardImportDiscovery:
|
||||
def __init__(self, zc: Zeroconf) -> None:
|
||||
self.zc = zc
|
||||
self.service_browser = ServiceBrowser(
|
||||
self.zc, ESPHOME_SERVICE_TYPE, [self._on_update]
|
||||
)
|
||||
def __init__(self) -> None:
|
||||
self.import_state: dict[str, DiscoveredImport] = {}
|
||||
|
||||
def _on_update(
|
||||
def browser_callback(
|
||||
self,
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
@@ -167,8 +86,6 @@ class DashboardImportDiscovery:
|
||||
name,
|
||||
state_change,
|
||||
)
|
||||
if service_type != ESPHOME_SERVICE_TYPE:
|
||||
return
|
||||
if state_change == ServiceStateChange.Removed:
|
||||
self.import_state.pop(name, None)
|
||||
return
|
||||
@@ -212,9 +129,6 @@ class DashboardImportDiscovery:
|
||||
network=network,
|
||||
)
|
||||
|
||||
def cancel(self) -> None:
|
||||
self.service_browser.cancel()
|
||||
|
||||
def update_device_mdns(self, node_name: str, version: str):
|
||||
storage_path = ext_storage_path(node_name + ".yaml")
|
||||
storage_json = StorageJSON.load(storage_path)
|
||||
@@ -234,7 +148,11 @@ class DashboardImportDiscovery:
|
||||
|
||||
class EsphomeZeroconf(Zeroconf):
|
||||
def resolve_host(self, host: str, timeout=3.0):
|
||||
info = HostResolver(host)
|
||||
if info.request(self, timeout):
|
||||
return socket.inet_ntoa(info.address)
|
||||
"""Resolve a host name to an IP address."""
|
||||
name = host.partition(".")[0]
|
||||
info = HostResolver(f"{name}.{ESPHOME_SERVICE_TYPE}", ESPHOME_SERVICE_TYPE)
|
||||
if (info.load_from_cache(self) or info.request(self, timeout * 1000)) and (
|
||||
addresses := info.ip_addresses_by_version(IPVersion.V4Only)
|
||||
):
|
||||
return str(addresses[0])
|
||||
return None
|
||||
|
Reference in New Issue
Block a user