diff --git a/esphome/wizard.py b/esphome/wizard.py index cb599df59a..3edf519816 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -1,6 +1,7 @@ import os import random import string +from typing import Literal, NotRequired, TypedDict, Unpack import unicodedata import voluptuous as vol @@ -103,11 +104,25 @@ HARDWARE_BASE_CONFIGS = { } -def sanitize_double_quotes(value): +def sanitize_double_quotes(value: str) -> str: return value.replace("\\", "\\\\").replace('"', '\\"') -def wizard_file(**kwargs): +class WizardFileKwargs(TypedDict): + """Keyword arguments for wizard_file function.""" + + name: str + platform: Literal["ESP8266", "ESP32", "RP2040", "BK72XX", "LN882X", "RTL87XX"] + board: str + ssid: NotRequired[str] + psk: NotRequired[str] + password: NotRequired[str] + ota_password: NotRequired[str] + api_encryption_key: NotRequired[str] + friendly_name: NotRequired[str] + + +def wizard_file(**kwargs: Unpack[WizardFileKwargs]) -> str: letters = string.ascii_letters + string.digits ap_name_base = kwargs["name"].replace("_", " ").title() ap_name = f"{ap_name_base} Fallback Hotspot" @@ -180,7 +195,25 @@ captive_portal: return config -def wizard_write(path, **kwargs): +class WizardWriteKwargs(TypedDict): + """Keyword arguments for wizard_write function.""" + + name: str + type: Literal["basic", "empty", "upload"] + # Required for "basic" type + board: NotRequired[str] + platform: NotRequired[str] + ssid: NotRequired[str] + psk: NotRequired[str] + password: NotRequired[str] + ota_password: NotRequired[str] + api_encryption_key: NotRequired[str] + friendly_name: NotRequired[str] + # Required for "upload" type + file_text: NotRequired[str] + + +def wizard_write(path: str, **kwargs: Unpack[WizardWriteKwargs]) -> bool: from esphome.components.bk72xx import boards as bk72xx_boards from esphome.components.esp32 import boards as esp32_boards from esphome.components.esp8266 import boards as esp8266_boards @@ -237,14 +270,14 @@ def wizard_write(path, **kwargs): if get_bool_env(ENV_QUICKWIZARD): - def sleep(time): + def sleep(time: float) -> None: pass else: from time import sleep -def safe_print_step(step, big): +def safe_print_step(step: int, big: str) -> None: safe_print() safe_print() safe_print(f"============= STEP {step} =============") @@ -253,14 +286,14 @@ def safe_print_step(step, big): sleep(0.25) -def default_input(text, default): +def default_input(text: str, default: str) -> str: safe_print() safe_print(f"Press ENTER for default ({default})") return safe_input(text.format(default)) or default # From https://stackoverflow.com/a/518232/8924614 -def strip_accents(value): +def strip_accents(value: str) -> str: return "".join( c for c in unicodedata.normalize("NFD", str(value)) @@ -268,7 +301,7 @@ def strip_accents(value): ) -def wizard(path): +def wizard(path: str) -> int: from esphome.components.bk72xx import boards as bk72xx_boards from esphome.components.esp32 import boards as esp32_boards from esphome.components.esp8266 import boards as esp8266_boards @@ -509,6 +542,7 @@ def wizard(path): ssid=ssid, psk=psk, password=password, + type="basic", ): return 1 diff --git a/tests/unit_tests/test_wizard.py b/tests/unit_tests/test_wizard.py index fea2fb5558..7af4db813a 100644 --- a/tests/unit_tests/test_wizard.py +++ b/tests/unit_tests/test_wizard.py @@ -1,9 +1,12 @@ """Tests for the wizard.py file.""" import os +from pathlib import Path +from typing import Any from unittest.mock import MagicMock import pytest +from pytest import MonkeyPatch from esphome.components.bk72xx.boards import BK72XX_BOARD_PINS from esphome.components.esp32.boards import ESP32_BOARD_PINS @@ -15,7 +18,7 @@ import esphome.wizard as wz @pytest.fixture -def default_config(): +def default_config() -> dict[str, Any]: return { "type": "basic", "name": "test-name", @@ -28,7 +31,7 @@ def default_config(): @pytest.fixture -def wizard_answers(): +def wizard_answers() -> list[str]: return [ "test-node", # Name of the node "ESP8266", # platform @@ -53,7 +56,9 @@ def test_sanitize_quotes_replaces_with_escaped_char(): assert output_str == '\\"key\\": \\"value\\"' -def test_config_file_fallback_ap_includes_descriptive_name(default_config): +def test_config_file_fallback_ap_includes_descriptive_name( + default_config: dict[str, Any], +): """ The fallback AP should include the node and a descriptive name """ @@ -67,7 +72,9 @@ def test_config_file_fallback_ap_includes_descriptive_name(default_config): assert 'ssid: "Test Node Fallback Hotspot"' in config -def test_config_file_fallback_ap_name_less_than_32_chars(default_config): +def test_config_file_fallback_ap_name_less_than_32_chars( + default_config: dict[str, Any], +): """ The fallback AP name must be less than 32 chars. Since it is composed of the node name and "Fallback Hotspot" this can be too long and needs truncating @@ -82,7 +89,7 @@ def test_config_file_fallback_ap_name_less_than_32_chars(default_config): assert 'ssid: "A Very Long Name For This Node"' in config -def test_config_file_should_include_ota(default_config): +def test_config_file_should_include_ota(default_config: dict[str, Any]): """ The Over-The-Air update should be enabled by default """ @@ -95,7 +102,9 @@ def test_config_file_should_include_ota(default_config): assert "ota:" in config -def test_config_file_should_include_ota_when_password_set(default_config): +def test_config_file_should_include_ota_when_password_set( + default_config: dict[str, Any], +): """ The Over-The-Air update should be enabled when a password is set """ @@ -109,7 +118,9 @@ def test_config_file_should_include_ota_when_password_set(default_config): assert "ota:" in config -def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch): +def test_wizard_write_sets_platform( + default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch +): """ If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards """ @@ -126,7 +137,7 @@ def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch): assert "esp8266:" in generated_config -def test_wizard_empty_config(tmp_path, monkeypatch): +def test_wizard_empty_config(tmp_path: Path, monkeypatch: MonkeyPatch): """ The wizard should be able to create an empty configuration """ @@ -146,7 +157,7 @@ def test_wizard_empty_config(tmp_path, monkeypatch): assert generated_config == "" -def test_wizard_upload_config(tmp_path, monkeypatch): +def test_wizard_upload_config(tmp_path: Path, monkeypatch: MonkeyPatch): """ The wizard should be able to import an base64 encoded configuration """ @@ -168,7 +179,7 @@ def test_wizard_upload_config(tmp_path, monkeypatch): def test_wizard_write_defaults_platform_from_board_esp8266( - default_config, tmp_path, monkeypatch + default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch ): """ If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards @@ -189,7 +200,7 @@ def test_wizard_write_defaults_platform_from_board_esp8266( def test_wizard_write_defaults_platform_from_board_esp32( - default_config, tmp_path, monkeypatch + default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch ): """ If the platform is not explicitly set, use "ESP32" if the board is one of the ESP32 boards @@ -210,7 +221,7 @@ def test_wizard_write_defaults_platform_from_board_esp32( def test_wizard_write_defaults_platform_from_board_bk72xx( - default_config, tmp_path, monkeypatch + default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch ): """ If the platform is not explicitly set, use "BK72XX" if the board is one of BK72XX boards @@ -231,7 +242,7 @@ def test_wizard_write_defaults_platform_from_board_bk72xx( def test_wizard_write_defaults_platform_from_board_ln882x( - default_config, tmp_path, monkeypatch + default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch ): """ If the platform is not explicitly set, use "LN882X" if the board is one of LN882X boards @@ -252,7 +263,7 @@ def test_wizard_write_defaults_platform_from_board_ln882x( def test_wizard_write_defaults_platform_from_board_rtl87xx( - default_config, tmp_path, monkeypatch + default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch ): """ If the platform is not explicitly set, use "RTL87XX" if the board is one of RTL87XX boards @@ -272,7 +283,7 @@ def test_wizard_write_defaults_platform_from_board_rtl87xx( assert "rtl87xx:" in generated_config -def test_safe_print_step_prints_step_number_and_description(monkeypatch): +def test_safe_print_step_prints_step_number_and_description(monkeypatch: MonkeyPatch): """ The safe_print_step function prints the step number and the passed description """ @@ -296,7 +307,7 @@ def test_safe_print_step_prints_step_number_and_description(monkeypatch): assert any(f"STEP {step_num}" in arg for arg in all_args) -def test_default_input_uses_default_if_no_input_supplied(monkeypatch): +def test_default_input_uses_default_if_no_input_supplied(monkeypatch: MonkeyPatch): """ The default_input() function should return the supplied default value if the user doesn't enter anything """ @@ -312,7 +323,7 @@ def test_default_input_uses_default_if_no_input_supplied(monkeypatch): assert retval == default_string -def test_default_input_uses_user_supplied_value(monkeypatch): +def test_default_input_uses_user_supplied_value(monkeypatch: MonkeyPatch): """ The default_input() function should return the value that the user enters """ @@ -376,7 +387,9 @@ def test_wizard_rejects_existing_files(tmpdir): assert retval == 2 -def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answers): +def test_wizard_accepts_default_answers_esp8266( + tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] +): """ The wizard should accept the given default answers for esp8266 """ @@ -396,7 +409,9 @@ def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answ assert retval == 0 -def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answers): +def test_wizard_accepts_default_answers_esp32( + tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] +): """ The wizard should accept the given default answers for esp32 """ @@ -418,7 +433,9 @@ def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answer assert retval == 0 -def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers): +def test_wizard_offers_better_node_name( + tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] +): """ When the node name does not conform, a better alternative is offered * Removes special chars @@ -449,7 +466,9 @@ def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers): assert wz.default_input.call_args.args[1] == expected_name -def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers): +def test_wizard_requires_correct_platform( + tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] +): """ When the platform is not either esp32 or esp8266, the wizard should reject it """ @@ -471,7 +490,9 @@ def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers): assert retval == 0 -def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers): +def test_wizard_requires_correct_board( + tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] +): """ When the board is not a valid esp8266 board, the wizard should reject it """ @@ -493,7 +514,9 @@ def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers): assert retval == 0 -def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers): +def test_wizard_requires_valid_ssid( + tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] +): """ When the board is not a valid esp8266 board, the wizard should reject it """ @@ -515,7 +538,9 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers): assert retval == 0 -def test_wizard_write_protects_existing_config(tmpdir, default_config, monkeypatch): +def test_wizard_write_protects_existing_config( + tmpdir, default_config: dict[str, Any], monkeypatch: MonkeyPatch +): """ The wizard_write function should not overwrite existing config files and return False """