1
0
mirror of https://github.com/esphome/esphome.git synced 2025-01-18 20:10:55 +00:00

dashboard: Small cleanups to dashboard (#5841)

This commit is contained in:
J. Nick Koston 2023-11-27 16:39:24 -06:00 committed by GitHub
parent 460362b11f
commit 4e6d3729e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 74 deletions

View File

@ -262,7 +262,7 @@ class DashboardEntry:
self.state = EntryState.UNKNOWN self.state = EntryState.UNKNOWN
self._to_dict: dict[str, Any] | None = None self._to_dict: dict[str, Any] | None = None
def __repr__(self): def __repr__(self) -> str:
"""Return the representation of this entry.""" """Return the representation of this entry."""
return ( return (
f"DashboardEntry(path={self.path} " f"DashboardEntry(path={self.path} "

View File

@ -23,45 +23,45 @@ class DashboardSettings:
self.cookie_secret: str | None = None self.cookie_secret: str | None = None
self.absolute_config_dir: Path | None = None self.absolute_config_dir: Path | None = None
def parse_args(self, args): def parse_args(self, args: Any) -> None:
self.on_ha_addon: bool = args.ha_addon self.on_ha_addon: bool = args.ha_addon
password: str = args.password or os.getenv("PASSWORD", "") password = args.password or os.getenv("PASSWORD") or ""
if not self.on_ha_addon: if not self.on_ha_addon:
self.username: str = args.username or os.getenv("USERNAME", "") self.username = args.username or os.getenv("USERNAME") or ""
self.using_password = bool(password) self.using_password = bool(password)
if self.using_password: if self.using_password:
self.password_hash = password_hash(password) self.password_hash = password_hash(password)
self.config_dir: str = args.configuration self.config_dir = args.configuration
self.absolute_config_dir: Path = Path(self.config_dir).resolve() self.absolute_config_dir = Path(self.config_dir).resolve()
CORE.config_path = os.path.join(self.config_dir, ".") CORE.config_path = os.path.join(self.config_dir, ".")
@property @property
def relative_url(self): def relative_url(self) -> str:
return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL", "/") return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL") or "/"
@property @property
def status_use_ping(self): def status_use_ping(self):
return get_bool_env("ESPHOME_DASHBOARD_USE_PING") return get_bool_env("ESPHOME_DASHBOARD_USE_PING")
@property @property
def status_use_mqtt(self): def status_use_mqtt(self) -> bool:
return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT") return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT")
@property @property
def using_ha_addon_auth(self): def using_ha_addon_auth(self) -> bool:
if not self.on_ha_addon: if not self.on_ha_addon:
return False return False
return not get_bool_env("DISABLE_HA_AUTHENTICATION") return not get_bool_env("DISABLE_HA_AUTHENTICATION")
@property @property
def using_auth(self): def using_auth(self) -> bool:
return self.using_password or self.using_ha_addon_auth return self.using_password or self.using_ha_addon_auth
@property @property
def streamer_mode(self): def streamer_mode(self) -> bool:
return get_bool_env("ESPHOME_STREAMER_MODE") return get_bool_env("ESPHOME_STREAMER_MODE")
def check_password(self, username, password): def check_password(self, username: str, password: str) -> bool:
if not self.using_auth: if not self.using_auth:
return True return True
if username != self.username: if username != self.username:

View File

@ -14,12 +14,14 @@ import shutil
import subprocess import subprocess
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Callable, TypeVar
from collections.abc import Iterable
import tornado import tornado
import tornado.concurrent import tornado.concurrent
import tornado.gen import tornado.gen
import tornado.httpserver import tornado.httpserver
import tornado.httputil
import tornado.ioloop import tornado.ioloop
import tornado.iostream import tornado.iostream
import tornado.netutil import tornado.netutil
@ -27,9 +29,9 @@ import tornado.process
import tornado.queues import tornado.queues
import tornado.web import tornado.web
import tornado.websocket import tornado.websocket
import tornado.httputil
import yaml import yaml
from tornado.log import access_log from tornado.log import access_log
from yaml.nodes import Node
from esphome import const, platformio_api, yaml_util from esphome import const, platformio_api, yaml_util
from esphome.helpers import get_bool_env, mkdir_p from esphome.helpers import get_bool_env, mkdir_p
@ -54,7 +56,7 @@ cookie_authenticated_yes = b"yes"
settings = DASHBOARD.settings settings = DASHBOARD.settings
def template_args(): def template_args() -> dict[str, Any]:
version = const.__version__ version = const.__version__
if "b" in version: if "b" in version:
docs_link = "https://beta.esphome.io/" docs_link = "https://beta.esphome.io/"
@ -73,9 +75,12 @@ def template_args():
} }
def authenticated(func): T = TypeVar("T", bound=Callable[..., Any])
def authenticated(func: T) -> T:
@functools.wraps(func) @functools.wraps(func)
def decorator(self, *args, **kwargs): def decorator(self, *args: Any, **kwargs: Any):
if not is_authenticated(self): if not is_authenticated(self):
self.redirect("./login") self.redirect("./login")
return None return None
@ -209,7 +214,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
tornado.ioloop.IOLoop.current().spawn_callback(self._redirect_stdout) tornado.ioloop.IOLoop.current().spawn_callback(self._redirect_stdout)
@property @property
def is_process_active(self): def is_process_active(self) -> bool:
return self._proc is not None and self._proc.returncode is None return self._proc is not None and self._proc.returncode is None
@websocket_method("stdin") @websocket_method("stdin")
@ -398,7 +403,7 @@ class EsphomeUpdateAllHandler(EsphomeCommandWebSocket):
class SerialPortRequestHandler(BaseHandler): class SerialPortRequestHandler(BaseHandler):
@authenticated @authenticated
async def get(self): async def get(self) -> None:
ports = await asyncio.get_running_loop().run_in_executor(None, get_serial_ports) ports = await asyncio.get_running_loop().run_in_executor(None, get_serial_ports)
data = [] data = []
for port in ports: for port in ports:
@ -418,7 +423,7 @@ class SerialPortRequestHandler(BaseHandler):
class WizardRequestHandler(BaseHandler): class WizardRequestHandler(BaseHandler):
@authenticated @authenticated
def post(self): def post(self) -> None:
from esphome import wizard from esphome import wizard
kwargs = { kwargs = {
@ -449,7 +454,7 @@ class WizardRequestHandler(BaseHandler):
class ImportRequestHandler(BaseHandler): class ImportRequestHandler(BaseHandler):
@authenticated @authenticated
def post(self): def post(self) -> None:
from esphome.components.dashboard_import import import_config from esphome.components.dashboard_import import import_config
dashboard = DASHBOARD dashboard = DASHBOARD
@ -504,7 +509,7 @@ class ImportRequestHandler(BaseHandler):
class DownloadListRequestHandler(BaseHandler): class DownloadListRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
def get(self, configuration=None): def get(self, configuration: str | None = None) -> None:
storage_path = ext_storage_path(configuration) storage_path = ext_storage_path(configuration)
storage_json = StorageJSON.load(storage_path) storage_json = StorageJSON.load(storage_path)
if storage_json is None: if storage_json is None:
@ -512,26 +517,29 @@ class DownloadListRequestHandler(BaseHandler):
return return
from esphome.components.esp32 import VARIANTS as ESP32_VARIANTS from esphome.components.esp32 import VARIANTS as ESP32_VARIANTS
from esphome.components.esp32 import get_download_types as esp32_types
from esphome.components.esp8266 import get_download_types as esp8266_types
from esphome.components.libretiny import get_download_types as libretiny_types
from esphome.components.rp2040 import get_download_types as rp2040_types
downloads = [] downloads = []
platform = storage_json.target_platform.lower() platform: str = storage_json.target_platform.lower()
if platform == const.PLATFORM_RP2040: if platform == const.PLATFORM_RP2040:
from esphome.components.rp2040 import get_download_types as rp2040_types
downloads = rp2040_types(storage_json) downloads = rp2040_types(storage_json)
elif platform == const.PLATFORM_ESP8266: elif platform == const.PLATFORM_ESP8266:
from esphome.components.esp8266 import get_download_types as esp8266_types
downloads = esp8266_types(storage_json) downloads = esp8266_types(storage_json)
elif platform.upper() in ESP32_VARIANTS: elif platform.upper() in ESP32_VARIANTS:
from esphome.components.esp32 import get_download_types as esp32_types
downloads = esp32_types(storage_json) downloads = esp32_types(storage_json)
elif platform == const.PLATFORM_BK72XX: elif platform in (const.PLATFORM_RTL87XX, const.PLATFORM_BK72XX):
downloads = libretiny_types(storage_json) from esphome.components.libretiny import (
elif platform == const.PLATFORM_RTL87XX: get_download_types as libretiny_types,
)
downloads = libretiny_types(storage_json) downloads = libretiny_types(storage_json)
else: else:
self.send_error(418) raise ValueError(f"Unknown platform {platform}")
return
self.set_status(200) self.set_status(200)
self.set_header("content-type", "application/json") self.set_header("content-type", "application/json")
@ -551,7 +559,7 @@ class DownloadBinaryRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
async def get(self, configuration: str | None = None): async def get(self, configuration: str | None = None) -> None:
"""Download a binary file.""" """Download a binary file."""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
compressed = self.get_argument("compressed", "0") == "1" compressed = self.get_argument("compressed", "0") == "1"
@ -618,7 +626,7 @@ class DownloadBinaryRequestHandler(BaseHandler):
class EsphomeVersionHandler(BaseHandler): class EsphomeVersionHandler(BaseHandler):
@authenticated @authenticated
def get(self): def get(self) -> None:
self.set_header("Content-Type", "application/json") self.set_header("Content-Type", "application/json")
self.write(json.dumps({"version": const.__version__})) self.write(json.dumps({"version": const.__version__}))
self.finish() self.finish()
@ -626,7 +634,7 @@ class EsphomeVersionHandler(BaseHandler):
class ListDevicesHandler(BaseHandler): class ListDevicesHandler(BaseHandler):
@authenticated @authenticated
async def get(self): async def get(self) -> None:
dashboard = DASHBOARD dashboard = DASHBOARD
await dashboard.entries.async_request_update_entries() await dashboard.entries.async_request_update_entries()
entries = dashboard.entries.async_all() entries = dashboard.entries.async_all()
@ -656,7 +664,7 @@ class ListDevicesHandler(BaseHandler):
class MainRequestHandler(BaseHandler): class MainRequestHandler(BaseHandler):
@authenticated @authenticated
def get(self): def get(self) -> None:
begin = bool(self.get_argument("begin", False)) begin = bool(self.get_argument("begin", False))
self.render( self.render(
@ -669,7 +677,7 @@ class MainRequestHandler(BaseHandler):
class PrometheusServiceDiscoveryHandler(BaseHandler): class PrometheusServiceDiscoveryHandler(BaseHandler):
@authenticated @authenticated
async def get(self): async def get(self) -> None:
dashboard = DASHBOARD dashboard = DASHBOARD
await dashboard.entries.async_request_update_entries() await dashboard.entries.async_request_update_entries()
entries = dashboard.entries.async_all() entries = dashboard.entries.async_all()
@ -698,29 +706,34 @@ class PrometheusServiceDiscoveryHandler(BaseHandler):
class BoardsRequestHandler(BaseHandler): class BoardsRequestHandler(BaseHandler):
@authenticated @authenticated
def get(self, platform: str): def get(self, platform: str) -> None:
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS
from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS
from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS
from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS
platform_to_boards = {
const.PLATFORM_ESP32: ESP32_BOARDS,
const.PLATFORM_ESP8266: ESP8266_BOARDS,
const.PLATFORM_RP2040: RP2040_BOARDS,
const.PLATFORM_BK72XX: BK72XX_BOARDS,
const.PLATFORM_RTL87XX: RTL87XX_BOARDS,
}
# filter all ESP32 variants by requested platform # filter all ESP32 variants by requested platform
if platform.startswith("esp32"): if platform.startswith("esp32"):
from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS
boards = { boards = {
k: v k: v
for k, v in platform_to_boards[const.PLATFORM_ESP32].items() for k, v in ESP32_BOARDS.items()
if v[const.KEY_VARIANT] == platform.upper() if v[const.KEY_VARIANT] == platform.upper()
} }
elif platform == const.PLATFORM_ESP8266:
from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS
boards = ESP8266_BOARDS
elif platform == const.PLATFORM_RP2040:
from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS
boards = RP2040_BOARDS
elif platform == const.PLATFORM_BK72XX:
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
boards = BK72XX_BOARDS
elif platform == const.PLATFORM_RTL87XX:
from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS
boards = RTL87XX_BOARDS
else: else:
boards = platform_to_boards[platform] raise ValueError(f"Unknown platform {platform}")
# map to a {board_name: board_title} dict # map to a {board_name: board_title} dict
platform_boards = {key: val[const.KEY_NAME] for key, val in boards.items()} platform_boards = {key: val[const.KEY_NAME] for key, val in boards.items()}
@ -734,7 +747,7 @@ class BoardsRequestHandler(BaseHandler):
class PingRequestHandler(BaseHandler): class PingRequestHandler(BaseHandler):
@authenticated @authenticated
def get(self): def get(self) -> None:
dashboard = DASHBOARD dashboard = DASHBOARD
dashboard.ping_request.set() dashboard.ping_request.set()
if settings.status_use_mqtt: if settings.status_use_mqtt:
@ -754,7 +767,7 @@ class PingRequestHandler(BaseHandler):
class InfoRequestHandler(BaseHandler): class InfoRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
async def get(self, configuration=None): async def get(self, configuration: str | None = None) -> None:
yaml_path = settings.rel_path(configuration) yaml_path = settings.rel_path(configuration)
dashboard = DASHBOARD dashboard = DASHBOARD
entry = dashboard.entries.get(yaml_path) entry = dashboard.entries.get(yaml_path)
@ -770,7 +783,7 @@ class InfoRequestHandler(BaseHandler):
class EditRequestHandler(BaseHandler): class EditRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
async def get(self, configuration: str | None = None): async def get(self, configuration: str | None = None) -> None:
"""Get the content of a file.""" """Get the content of a file."""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
filename = settings.rel_path(configuration) filename = settings.rel_path(configuration)
@ -788,7 +801,7 @@ class EditRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
async def post(self, configuration: str | None = None): async def post(self, configuration: str | None = None) -> None:
"""Write the content of a file.""" """Write the content of a file."""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
config_file = settings.rel_path(configuration) config_file = settings.rel_path(configuration)
@ -805,7 +818,7 @@ class EditRequestHandler(BaseHandler):
class DeleteRequestHandler(BaseHandler): class DeleteRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
def post(self, configuration=None): def post(self, configuration: str | None = None) -> None:
config_file = settings.rel_path(configuration) config_file = settings.rel_path(configuration)
storage_path = ext_storage_path(configuration) storage_path = ext_storage_path(configuration)
@ -825,20 +838,20 @@ class DeleteRequestHandler(BaseHandler):
class UndoDeleteRequestHandler(BaseHandler): class UndoDeleteRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
def post(self, configuration=None): def post(self, configuration: str | None = None) -> None:
config_file = settings.rel_path(configuration) config_file = settings.rel_path(configuration)
trash_path = trash_storage_path() trash_path = trash_storage_path()
shutil.move(os.path.join(trash_path, configuration), config_file) shutil.move(os.path.join(trash_path, configuration), config_file)
class LoginHandler(BaseHandler): class LoginHandler(BaseHandler):
def get(self): def get(self) -> None:
if is_authenticated(self): if is_authenticated(self):
self.redirect("./") self.redirect("./")
else: else:
self.render_login_page() self.render_login_page()
def render_login_page(self, error=None): def render_login_page(self, error: str | None = None) -> None:
self.render( self.render(
"login.template.html", "login.template.html",
error=error, error=error,
@ -847,7 +860,7 @@ class LoginHandler(BaseHandler):
**template_args(), **template_args(),
) )
def post_ha_addon_login(self): def post_ha_addon_login(self) -> None:
import requests import requests
headers = { headers = {
@ -874,7 +887,7 @@ class LoginHandler(BaseHandler):
self.set_status(401) self.set_status(401)
self.render_login_page(error="Invalid username or password") self.render_login_page(error="Invalid username or password")
def post_native_login(self): def post_native_login(self) -> None:
username = self.get_argument("username", "") username = self.get_argument("username", "")
password = self.get_argument("password", "") password = self.get_argument("password", "")
if settings.check_password(username, password): if settings.check_password(username, password):
@ -887,7 +900,7 @@ class LoginHandler(BaseHandler):
self.set_status(401) self.set_status(401)
self.render_login_page(error=error_str) self.render_login_page(error=error_str)
def post(self): def post(self) -> None:
if settings.using_ha_addon_auth: if settings.using_ha_addon_auth:
self.post_ha_addon_login() self.post_ha_addon_login()
else: else:
@ -896,14 +909,14 @@ class LoginHandler(BaseHandler):
class LogoutHandler(BaseHandler): class LogoutHandler(BaseHandler):
@authenticated @authenticated
def get(self): def get(self) -> None:
self.clear_cookie("authenticated") self.clear_cookie("authenticated")
self.redirect("./login") self.redirect("./login")
class SecretKeysRequestHandler(BaseHandler): class SecretKeysRequestHandler(BaseHandler):
@authenticated @authenticated
def get(self): def get(self) -> None:
filename = None filename = None
for secret_filename in const.SECRETS_FILES: for secret_filename in const.SECRETS_FILES:
@ -923,10 +936,10 @@ class SecretKeysRequestHandler(BaseHandler):
class SafeLoaderIgnoreUnknown(FastestAvailableSafeLoader): class SafeLoaderIgnoreUnknown(FastestAvailableSafeLoader):
def ignore_unknown(self, node): def ignore_unknown(self, node: Node) -> str:
return f"{node.tag} {node.value}" return f"{node.tag} {node.value}"
def construct_yaml_binary(self, node) -> str: def construct_yaml_binary(self, node: Node) -> str:
return super().construct_yaml_binary(node).decode("ascii") return super().construct_yaml_binary(node).decode("ascii")
@ -939,7 +952,7 @@ SafeLoaderIgnoreUnknown.add_constructor(
class JsonConfigRequestHandler(BaseHandler): class JsonConfigRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
async def get(self, configuration=None): async def get(self, configuration: str | None = None) -> None:
filename = settings.rel_path(configuration) filename = settings.rel_path(configuration)
if not os.path.isfile(filename): if not os.path.isfile(filename):
self.send_error(404) self.send_error(404)
@ -959,7 +972,7 @@ class JsonConfigRequestHandler(BaseHandler):
self.finish() self.finish()
def get_base_frontend_path(): def get_base_frontend_path() -> str:
if ENV_DEV not in os.environ: if ENV_DEV not in os.environ:
import esphome_dashboard import esphome_dashboard
@ -973,12 +986,12 @@ def get_base_frontend_path():
return os.path.abspath(os.path.join(os.getcwd(), static_path, "esphome_dashboard")) return os.path.abspath(os.path.join(os.getcwd(), static_path, "esphome_dashboard"))
def get_static_path(*args): def get_static_path(*args: Iterable[str]) -> str:
return os.path.join(get_base_frontend_path(), "static", *args) return os.path.join(get_base_frontend_path(), "static", *args)
@functools.cache @functools.cache
def get_static_file_url(name): def get_static_file_url(name: str) -> str:
base = f"./static/{name}" base = f"./static/{name}"
if ENV_DEV in os.environ: if ENV_DEV in os.environ:
@ -997,7 +1010,7 @@ def get_static_file_url(name):
def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application: def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application:
def log_function(handler): def log_function(handler: tornado.web.RequestHandler) -> None:
if handler.get_status() < 400: if handler.get_status() < 400:
log_method = access_log.info log_method = access_log.info