mirror of
				https://github.com/esphome/esphome.git
				synced 2025-11-04 09:01:49 +00:00 
			
		
		
		
	dashboard: convert ping thread to use asyncio (#5749)
This commit is contained in:
		@@ -1,18 +1,13 @@
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import threading
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ThreadedAsyncEvent:
 | 
			
		||||
    """This is a shim to allow the asyncio event to be used in a threaded context.
 | 
			
		||||
 | 
			
		||||
    When more of the code is moved to asyncio, this can be removed.
 | 
			
		||||
    """
 | 
			
		||||
class AsyncEvent:
 | 
			
		||||
    """This is a shim around asyncio.Event."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        """Initialize the ThreadedAsyncEvent."""
 | 
			
		||||
        self.event = threading.Event()
 | 
			
		||||
        self.async_event: asyncio.Event | None = None
 | 
			
		||||
        self.loop: asyncio.AbstractEventLoop | None = None
 | 
			
		||||
 | 
			
		||||
@@ -26,31 +21,11 @@ class ThreadedAsyncEvent:
 | 
			
		||||
    def async_set(self) -> None:
 | 
			
		||||
        """Set the asyncio.Event instance."""
 | 
			
		||||
        self.async_event.set()
 | 
			
		||||
        self.event.set()
 | 
			
		||||
 | 
			
		||||
    def set(self) -> None:
 | 
			
		||||
        """Set the event."""
 | 
			
		||||
        self.loop.call_soon_threadsafe(self.async_event.set)
 | 
			
		||||
        self.event.set()
 | 
			
		||||
 | 
			
		||||
    def wait(self) -> None:
 | 
			
		||||
        """Wait for the event."""
 | 
			
		||||
        self.event.wait()
 | 
			
		||||
 | 
			
		||||
    async def async_wait(self) -> None:
 | 
			
		||||
        """Wait the event async."""
 | 
			
		||||
        await self.async_event.wait()
 | 
			
		||||
 | 
			
		||||
    def clear(self) -> None:
 | 
			
		||||
        """Clear the event."""
 | 
			
		||||
        self.loop.call_soon_threadsafe(self.async_event.clear)
 | 
			
		||||
        self.event.clear()
 | 
			
		||||
 | 
			
		||||
    def async_clear(self) -> None:
 | 
			
		||||
        """Clear the event async."""
 | 
			
		||||
        self.async_event.clear()
 | 
			
		||||
        self.event.clear()
 | 
			
		||||
 | 
			
		||||
    def is_set(self) -> bool:
 | 
			
		||||
        """Return if the event is set."""
 | 
			
		||||
        return self.event.is_set()
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,6 @@ from __future__ import annotations
 | 
			
		||||
import asyncio
 | 
			
		||||
import base64
 | 
			
		||||
import binascii
 | 
			
		||||
import collections
 | 
			
		||||
import datetime
 | 
			
		||||
import functools
 | 
			
		||||
import gzip
 | 
			
		||||
@@ -11,14 +10,13 @@ import hashlib
 | 
			
		||||
import hmac
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import multiprocessing
 | 
			
		||||
import os
 | 
			
		||||
import secrets
 | 
			
		||||
import shutil
 | 
			
		||||
import subprocess
 | 
			
		||||
import threading
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any
 | 
			
		||||
from typing import Any, cast
 | 
			
		||||
 | 
			
		||||
import tornado
 | 
			
		||||
import tornado.concurrent
 | 
			
		||||
@@ -52,9 +50,9 @@ from esphome.zeroconf import (
 | 
			
		||||
    DashboardImportDiscovery,
 | 
			
		||||
    DashboardStatus,
 | 
			
		||||
)
 | 
			
		||||
from .async_adapter import ThreadedAsyncEvent
 | 
			
		||||
 | 
			
		||||
from .util import friendly_name_slugify, password_hash
 | 
			
		||||
from .async_adapter import AsyncEvent
 | 
			
		||||
from .util import chunked, friendly_name_slugify, password_hash
 | 
			
		||||
 | 
			
		||||
_LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@@ -603,7 +601,7 @@ class ImportRequestHandler(BaseHandler):
 | 
			
		||||
                encryption,
 | 
			
		||||
            )
 | 
			
		||||
            # Make sure the device gets marked online right away
 | 
			
		||||
            PING_REQUEST.set()
 | 
			
		||||
            PING_REQUEST.async_set()
 | 
			
		||||
        except FileExistsError:
 | 
			
		||||
            self.set_status(500)
 | 
			
		||||
            self.write("File already exists")
 | 
			
		||||
@@ -905,15 +903,6 @@ class MainRequestHandler(BaseHandler):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _ping_func(filename, address):
 | 
			
		||||
    if os.name == "nt":
 | 
			
		||||
        command = ["ping", "-n", "1", address]
 | 
			
		||||
    else:
 | 
			
		||||
        command = ["ping", "-c", "1", address]
 | 
			
		||||
    rc, _, _ = run_system_command(*command)
 | 
			
		||||
    return filename, rc == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PrometheusServiceDiscoveryHandler(BaseHandler):
 | 
			
		||||
    @authenticated
 | 
			
		||||
    def get(self):
 | 
			
		||||
@@ -1070,47 +1059,48 @@ class MDNSStatus:
 | 
			
		||||
        self.aiozc = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PingStatusThread(threading.Thread):
 | 
			
		||||
    def run(self):
 | 
			
		||||
        with multiprocessing.Pool(processes=8) as pool:
 | 
			
		||||
            while not STOP_EVENT.wait(2):
 | 
			
		||||
                # Only do pings if somebody has the dashboard open
 | 
			
		||||
async def _async_ping_host(host: str) -> bool:
 | 
			
		||||
    """Ping a host."""
 | 
			
		||||
    ping_command = ["ping", "-n" if os.name == "nt" else "-c", "1"]
 | 
			
		||||
    process = await asyncio.create_subprocess_exec(
 | 
			
		||||
        *ping_command,
 | 
			
		||||
        host,
 | 
			
		||||
        stdin=asyncio.subprocess.DEVNULL,
 | 
			
		||||
        stdout=asyncio.subprocess.DEVNULL,
 | 
			
		||||
        stderr=asyncio.subprocess.DEVNULL,
 | 
			
		||||
    )
 | 
			
		||||
    await process.wait()
 | 
			
		||||
    return process.returncode == 0
 | 
			
		||||
 | 
			
		||||
                def callback(ret):
 | 
			
		||||
                    PING_RESULT[ret[0]] = ret[1]
 | 
			
		||||
 | 
			
		||||
                entries = _list_dashboard_entries()
 | 
			
		||||
                queue = collections.deque()
 | 
			
		||||
                for entry in entries:
 | 
			
		||||
                    if entry.address is None:
 | 
			
		||||
                        PING_RESULT[entry.filename] = None
 | 
			
		||||
                        continue
 | 
			
		||||
class PingStatus:
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        """Initialize the PingStatus class."""
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._loop = asyncio.get_running_loop()
 | 
			
		||||
 | 
			
		||||
                    result = pool.apply_async(
 | 
			
		||||
                        _ping_func, (entry.filename, entry.address), callback=callback
 | 
			
		||||
                    )
 | 
			
		||||
                    queue.append(result)
 | 
			
		||||
 | 
			
		||||
                while queue:
 | 
			
		||||
                    item = queue[0]
 | 
			
		||||
                    if item.ready():
 | 
			
		||||
                        queue.popleft()
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                    try:
 | 
			
		||||
                        item.get(0.1)
 | 
			
		||||
                    except OSError:
 | 
			
		||||
                        # ping not installed
 | 
			
		||||
                        pass
 | 
			
		||||
                    except multiprocessing.TimeoutError:
 | 
			
		||||
                        pass
 | 
			
		||||
 | 
			
		||||
                    if STOP_EVENT.is_set():
 | 
			
		||||
                        pool.terminate()
 | 
			
		||||
                        return
 | 
			
		||||
 | 
			
		||||
                PING_REQUEST.wait()
 | 
			
		||||
                PING_REQUEST.clear()
 | 
			
		||||
    async def async_run(self) -> None:
 | 
			
		||||
        """Run the ping status."""
 | 
			
		||||
        while not STOP_EVENT.is_set():
 | 
			
		||||
            # Only ping if the dashboard is open
 | 
			
		||||
            await PING_REQUEST.async_wait()
 | 
			
		||||
            PING_REQUEST.async_clear()
 | 
			
		||||
            entries = await self._loop.run_in_executor(None, _list_dashboard_entries)
 | 
			
		||||
            to_ping: list[DashboardEntry] = [
 | 
			
		||||
                entry for entry in entries if entry.address is not None
 | 
			
		||||
            ]
 | 
			
		||||
            for ping_group in chunked(to_ping, 16):
 | 
			
		||||
                ping_group = cast(list[DashboardEntry], ping_group)
 | 
			
		||||
                results = await asyncio.gather(
 | 
			
		||||
                    *(_async_ping_host(entry.address) for entry in ping_group),
 | 
			
		||||
                    return_exceptions=True,
 | 
			
		||||
                )
 | 
			
		||||
                for entry, result in zip(ping_group, results):
 | 
			
		||||
                    if isinstance(result, Exception):
 | 
			
		||||
                        result = False
 | 
			
		||||
                    elif isinstance(result, BaseException):
 | 
			
		||||
                        raise result
 | 
			
		||||
                    PING_RESULT[entry.filename] = result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MqttStatusThread(threading.Thread):
 | 
			
		||||
@@ -1171,7 +1161,7 @@ class MqttStatusThread(threading.Thread):
 | 
			
		||||
class PingRequestHandler(BaseHandler):
 | 
			
		||||
    @authenticated
 | 
			
		||||
    def get(self):
 | 
			
		||||
        PING_REQUEST.set()
 | 
			
		||||
        PING_REQUEST.async_set()
 | 
			
		||||
        if settings.status_use_mqtt:
 | 
			
		||||
            MQTT_PING_REQUEST.set()
 | 
			
		||||
        self.set_header("content-type", "application/json")
 | 
			
		||||
@@ -1261,7 +1251,7 @@ class MDNSContainer:
 | 
			
		||||
PING_RESULT: dict = {}
 | 
			
		||||
IMPORT_RESULT = {}
 | 
			
		||||
STOP_EVENT = threading.Event()
 | 
			
		||||
PING_REQUEST = ThreadedAsyncEvent()
 | 
			
		||||
PING_REQUEST = AsyncEvent()
 | 
			
		||||
MQTT_PING_REQUEST = threading.Event()
 | 
			
		||||
MDNS_CONTAINER = MDNSContainer()
 | 
			
		||||
 | 
			
		||||
@@ -1561,10 +1551,10 @@ async def async_start_web_server(args):
 | 
			
		||||
            webbrowser.open(f"http://{args.address}:{args.port}")
 | 
			
		||||
 | 
			
		||||
    mdns_task: asyncio.Task | None = None
 | 
			
		||||
    ping_status_thread: PingStatusThread | None = None
 | 
			
		||||
    ping_status_task: asyncio.Task | None = None
 | 
			
		||||
    if settings.status_use_ping:
 | 
			
		||||
        ping_status_thread = PingStatusThread()
 | 
			
		||||
        ping_status_thread.start()
 | 
			
		||||
        ping_status = PingStatus()
 | 
			
		||||
        ping_status_task = asyncio.create_task(ping_status.async_run())
 | 
			
		||||
    else:
 | 
			
		||||
        mdns_status = MDNSStatus()
 | 
			
		||||
        await mdns_status.async_refresh_hosts()
 | 
			
		||||
@@ -1581,9 +1571,9 @@ async def async_start_web_server(args):
 | 
			
		||||
    finally:
 | 
			
		||||
        _LOGGER.info("Shutting down...")
 | 
			
		||||
        STOP_EVENT.set()
 | 
			
		||||
        PING_REQUEST.set()
 | 
			
		||||
        if ping_status_thread:
 | 
			
		||||
            ping_status_thread.join()
 | 
			
		||||
        PING_REQUEST.async_set()
 | 
			
		||||
        if ping_status_task:
 | 
			
		||||
            ping_status_task.cancel()
 | 
			
		||||
        MDNS_CONTAINER.set_mdns(None)
 | 
			
		||||
        if mdns_task:
 | 
			
		||||
            mdns_task.cancel()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,9 @@
 | 
			
		||||
import hashlib
 | 
			
		||||
import unicodedata
 | 
			
		||||
from collections.abc import Iterable
 | 
			
		||||
from functools import partial
 | 
			
		||||
from itertools import islice
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from esphome.const import ALLOWED_NAME_CHARS
 | 
			
		||||
 | 
			
		||||
@@ -30,3 +34,19 @@ def friendly_name_slugify(value):
 | 
			
		||||
        .strip("-")
 | 
			
		||||
    )
 | 
			
		||||
    return "".join(c for c in value if c in ALLOWED_NAME_CHARS)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def take(take_num: int, iterable: Iterable) -> list[Any]:
 | 
			
		||||
    """Return first n items of the iterable as a list.
 | 
			
		||||
 | 
			
		||||
    From itertools recipes
 | 
			
		||||
    """
 | 
			
		||||
    return list(islice(iterable, take_num))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]:
 | 
			
		||||
    """Break *iterable* into lists of length *n*.
 | 
			
		||||
 | 
			
		||||
    From more-itertools
 | 
			
		||||
    """
 | 
			
		||||
    return iter(partial(take, chunked_num, iter(iterable)), [])
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user