mirror of
				https://github.com/esphome/esphome.git
				synced 2025-11-04 00:51:49 +00:00 
			
		
		
		
	[wizard] Fix KeyError when running wizard with empty OTA password (#10753)
This commit is contained in:
		@@ -1,6 +1,7 @@
 | 
			
		||||
import os
 | 
			
		||||
import random
 | 
			
		||||
import string
 | 
			
		||||
from typing import Literal, NotRequired, TypedDict, Unpack
 | 
			
		||||
import unicodedata
 | 
			
		||||
 | 
			
		||||
import voluptuous as vol
 | 
			
		||||
@@ -103,11 +104,25 @@ HARDWARE_BASE_CONFIGS = {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sanitize_double_quotes(value):
 | 
			
		||||
def sanitize_double_quotes(value: str) -> str:
 | 
			
		||||
    return value.replace("\\", "\\\\").replace('"', '\\"')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def wizard_file(**kwargs):
 | 
			
		||||
class WizardFileKwargs(TypedDict):
 | 
			
		||||
    """Keyword arguments for wizard_file function."""
 | 
			
		||||
 | 
			
		||||
    name: str
 | 
			
		||||
    platform: Literal["ESP8266", "ESP32", "RP2040", "BK72XX", "LN882X", "RTL87XX"]
 | 
			
		||||
    board: str
 | 
			
		||||
    ssid: NotRequired[str]
 | 
			
		||||
    psk: NotRequired[str]
 | 
			
		||||
    password: NotRequired[str]
 | 
			
		||||
    ota_password: NotRequired[str]
 | 
			
		||||
    api_encryption_key: NotRequired[str]
 | 
			
		||||
    friendly_name: NotRequired[str]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def wizard_file(**kwargs: Unpack[WizardFileKwargs]) -> str:
 | 
			
		||||
    letters = string.ascii_letters + string.digits
 | 
			
		||||
    ap_name_base = kwargs["name"].replace("_", " ").title()
 | 
			
		||||
    ap_name = f"{ap_name_base} Fallback Hotspot"
 | 
			
		||||
@@ -180,7 +195,25 @@ captive_portal:
 | 
			
		||||
    return config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def wizard_write(path, **kwargs):
 | 
			
		||||
class WizardWriteKwargs(TypedDict):
 | 
			
		||||
    """Keyword arguments for wizard_write function."""
 | 
			
		||||
 | 
			
		||||
    name: str
 | 
			
		||||
    type: Literal["basic", "empty", "upload"]
 | 
			
		||||
    # Required for "basic" type
 | 
			
		||||
    board: NotRequired[str]
 | 
			
		||||
    platform: NotRequired[str]
 | 
			
		||||
    ssid: NotRequired[str]
 | 
			
		||||
    psk: NotRequired[str]
 | 
			
		||||
    password: NotRequired[str]
 | 
			
		||||
    ota_password: NotRequired[str]
 | 
			
		||||
    api_encryption_key: NotRequired[str]
 | 
			
		||||
    friendly_name: NotRequired[str]
 | 
			
		||||
    # Required for "upload" type
 | 
			
		||||
    file_text: NotRequired[str]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def wizard_write(path: str, **kwargs: Unpack[WizardWriteKwargs]) -> bool:
 | 
			
		||||
    from esphome.components.bk72xx import boards as bk72xx_boards
 | 
			
		||||
    from esphome.components.esp32 import boards as esp32_boards
 | 
			
		||||
    from esphome.components.esp8266 import boards as esp8266_boards
 | 
			
		||||
@@ -237,14 +270,14 @@ def wizard_write(path, **kwargs):
 | 
			
		||||
 | 
			
		||||
if get_bool_env(ENV_QUICKWIZARD):
 | 
			
		||||
 | 
			
		||||
    def sleep(time):
 | 
			
		||||
    def sleep(time: float) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
else:
 | 
			
		||||
    from time import sleep
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def safe_print_step(step, big):
 | 
			
		||||
def safe_print_step(step: int, big: str) -> None:
 | 
			
		||||
    safe_print()
 | 
			
		||||
    safe_print()
 | 
			
		||||
    safe_print(f"============= STEP {step} =============")
 | 
			
		||||
@@ -253,14 +286,14 @@ def safe_print_step(step, big):
 | 
			
		||||
    sleep(0.25)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def default_input(text, default):
 | 
			
		||||
def default_input(text: str, default: str) -> str:
 | 
			
		||||
    safe_print()
 | 
			
		||||
    safe_print(f"Press ENTER for default ({default})")
 | 
			
		||||
    return safe_input(text.format(default)) or default
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# From https://stackoverflow.com/a/518232/8924614
 | 
			
		||||
def strip_accents(value):
 | 
			
		||||
def strip_accents(value: str) -> str:
 | 
			
		||||
    return "".join(
 | 
			
		||||
        c
 | 
			
		||||
        for c in unicodedata.normalize("NFD", str(value))
 | 
			
		||||
@@ -268,7 +301,7 @@ def strip_accents(value):
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def wizard(path):
 | 
			
		||||
def wizard(path: str) -> int:
 | 
			
		||||
    from esphome.components.bk72xx import boards as bk72xx_boards
 | 
			
		||||
    from esphome.components.esp32 import boards as esp32_boards
 | 
			
		||||
    from esphome.components.esp8266 import boards as esp8266_boards
 | 
			
		||||
@@ -509,6 +542,7 @@ def wizard(path):
 | 
			
		||||
        ssid=ssid,
 | 
			
		||||
        psk=psk,
 | 
			
		||||
        password=password,
 | 
			
		||||
        type="basic",
 | 
			
		||||
    ):
 | 
			
		||||
        return 1
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,12 @@
 | 
			
		||||
"""Tests for the wizard.py file."""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any
 | 
			
		||||
from unittest.mock import MagicMock
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from pytest import MonkeyPatch
 | 
			
		||||
 | 
			
		||||
from esphome.components.bk72xx.boards import BK72XX_BOARD_PINS
 | 
			
		||||
from esphome.components.esp32.boards import ESP32_BOARD_PINS
 | 
			
		||||
@@ -15,7 +18,7 @@ import esphome.wizard as wz
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def default_config():
 | 
			
		||||
def default_config() -> dict[str, Any]:
 | 
			
		||||
    return {
 | 
			
		||||
        "type": "basic",
 | 
			
		||||
        "name": "test-name",
 | 
			
		||||
@@ -28,7 +31,7 @@ def default_config():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def wizard_answers():
 | 
			
		||||
def wizard_answers() -> list[str]:
 | 
			
		||||
    return [
 | 
			
		||||
        "test-node",  # Name of the node
 | 
			
		||||
        "ESP8266",  # platform
 | 
			
		||||
@@ -53,7 +56,9 @@ def test_sanitize_quotes_replaces_with_escaped_char():
 | 
			
		||||
    assert output_str == '\\"key\\": \\"value\\"'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_config_file_fallback_ap_includes_descriptive_name(default_config):
 | 
			
		||||
def test_config_file_fallback_ap_includes_descriptive_name(
 | 
			
		||||
    default_config: dict[str, Any],
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    The fallback AP should include the node and a descriptive name
 | 
			
		||||
    """
 | 
			
		||||
@@ -67,7 +72,9 @@ def test_config_file_fallback_ap_includes_descriptive_name(default_config):
 | 
			
		||||
    assert 'ssid: "Test Node Fallback Hotspot"' in config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_config_file_fallback_ap_name_less_than_32_chars(default_config):
 | 
			
		||||
def test_config_file_fallback_ap_name_less_than_32_chars(
 | 
			
		||||
    default_config: dict[str, Any],
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    The fallback AP name must be less than 32 chars.
 | 
			
		||||
    Since it is composed of the node name and "Fallback Hotspot" this can be too long and needs truncating
 | 
			
		||||
@@ -82,7 +89,7 @@ def test_config_file_fallback_ap_name_less_than_32_chars(default_config):
 | 
			
		||||
    assert 'ssid: "A Very Long Name For This Node"' in config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_config_file_should_include_ota(default_config):
 | 
			
		||||
def test_config_file_should_include_ota(default_config: dict[str, Any]):
 | 
			
		||||
    """
 | 
			
		||||
    The Over-The-Air update should be enabled by default
 | 
			
		||||
    """
 | 
			
		||||
@@ -95,7 +102,9 @@ def test_config_file_should_include_ota(default_config):
 | 
			
		||||
    assert "ota:" in config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_config_file_should_include_ota_when_password_set(default_config):
 | 
			
		||||
def test_config_file_should_include_ota_when_password_set(
 | 
			
		||||
    default_config: dict[str, Any],
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    The Over-The-Air update should be enabled when a password is set
 | 
			
		||||
    """
 | 
			
		||||
@@ -109,7 +118,9 @@ def test_config_file_should_include_ota_when_password_set(default_config):
 | 
			
		||||
    assert "ota:" in config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch):
 | 
			
		||||
def test_wizard_write_sets_platform(
 | 
			
		||||
    default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards
 | 
			
		||||
    """
 | 
			
		||||
@@ -126,7 +137,7 @@ def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch):
 | 
			
		||||
    assert "esp8266:" in generated_config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_empty_config(tmp_path, monkeypatch):
 | 
			
		||||
def test_wizard_empty_config(tmp_path: Path, monkeypatch: MonkeyPatch):
 | 
			
		||||
    """
 | 
			
		||||
    The wizard should be able to create an empty configuration
 | 
			
		||||
    """
 | 
			
		||||
@@ -146,7 +157,7 @@ def test_wizard_empty_config(tmp_path, monkeypatch):
 | 
			
		||||
    assert generated_config == ""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_upload_config(tmp_path, monkeypatch):
 | 
			
		||||
def test_wizard_upload_config(tmp_path: Path, monkeypatch: MonkeyPatch):
 | 
			
		||||
    """
 | 
			
		||||
    The wizard should be able to import an base64 encoded configuration
 | 
			
		||||
    """
 | 
			
		||||
@@ -168,7 +179,7 @@ def test_wizard_upload_config(tmp_path, monkeypatch):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_write_defaults_platform_from_board_esp8266(
 | 
			
		||||
    default_config, tmp_path, monkeypatch
 | 
			
		||||
    default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards
 | 
			
		||||
@@ -189,7 +200,7 @@ def test_wizard_write_defaults_platform_from_board_esp8266(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_write_defaults_platform_from_board_esp32(
 | 
			
		||||
    default_config, tmp_path, monkeypatch
 | 
			
		||||
    default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    If the platform is not explicitly set, use "ESP32" if the board is one of the ESP32 boards
 | 
			
		||||
@@ -210,7 +221,7 @@ def test_wizard_write_defaults_platform_from_board_esp32(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_write_defaults_platform_from_board_bk72xx(
 | 
			
		||||
    default_config, tmp_path, monkeypatch
 | 
			
		||||
    default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    If the platform is not explicitly set, use "BK72XX" if the board is one of BK72XX boards
 | 
			
		||||
@@ -231,7 +242,7 @@ def test_wizard_write_defaults_platform_from_board_bk72xx(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_write_defaults_platform_from_board_ln882x(
 | 
			
		||||
    default_config, tmp_path, monkeypatch
 | 
			
		||||
    default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    If the platform is not explicitly set, use "LN882X" if the board is one of LN882X boards
 | 
			
		||||
@@ -252,7 +263,7 @@ def test_wizard_write_defaults_platform_from_board_ln882x(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_write_defaults_platform_from_board_rtl87xx(
 | 
			
		||||
    default_config, tmp_path, monkeypatch
 | 
			
		||||
    default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    If the platform is not explicitly set, use "RTL87XX" if the board is one of RTL87XX boards
 | 
			
		||||
@@ -272,7 +283,7 @@ def test_wizard_write_defaults_platform_from_board_rtl87xx(
 | 
			
		||||
    assert "rtl87xx:" in generated_config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_safe_print_step_prints_step_number_and_description(monkeypatch):
 | 
			
		||||
def test_safe_print_step_prints_step_number_and_description(monkeypatch: MonkeyPatch):
 | 
			
		||||
    """
 | 
			
		||||
    The safe_print_step function prints the step number and the passed description
 | 
			
		||||
    """
 | 
			
		||||
@@ -296,7 +307,7 @@ def test_safe_print_step_prints_step_number_and_description(monkeypatch):
 | 
			
		||||
    assert any(f"STEP {step_num}" in arg for arg in all_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_default_input_uses_default_if_no_input_supplied(monkeypatch):
 | 
			
		||||
def test_default_input_uses_default_if_no_input_supplied(monkeypatch: MonkeyPatch):
 | 
			
		||||
    """
 | 
			
		||||
    The default_input() function should return the supplied default value if the user doesn't enter anything
 | 
			
		||||
    """
 | 
			
		||||
@@ -312,7 +323,7 @@ def test_default_input_uses_default_if_no_input_supplied(monkeypatch):
 | 
			
		||||
    assert retval == default_string
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_default_input_uses_user_supplied_value(monkeypatch):
 | 
			
		||||
def test_default_input_uses_user_supplied_value(monkeypatch: MonkeyPatch):
 | 
			
		||||
    """
 | 
			
		||||
    The default_input() function should return the value that the user enters
 | 
			
		||||
    """
 | 
			
		||||
@@ -376,7 +387,9 @@ def test_wizard_rejects_existing_files(tmpdir):
 | 
			
		||||
    assert retval == 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
def test_wizard_accepts_default_answers_esp8266(
 | 
			
		||||
    tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    The wizard should accept the given default answers for esp8266
 | 
			
		||||
    """
 | 
			
		||||
@@ -396,7 +409,9 @@ def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answ
 | 
			
		||||
    assert retval == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
def test_wizard_accepts_default_answers_esp32(
 | 
			
		||||
    tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    The wizard should accept the given default answers for esp32
 | 
			
		||||
    """
 | 
			
		||||
@@ -418,7 +433,9 @@ def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answer
 | 
			
		||||
    assert retval == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
def test_wizard_offers_better_node_name(
 | 
			
		||||
    tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    When the node name does not conform, a better alternative is offered
 | 
			
		||||
    * Removes special chars
 | 
			
		||||
@@ -449,7 +466,9 @@ def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
    assert wz.default_input.call_args.args[1] == expected_name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
def test_wizard_requires_correct_platform(
 | 
			
		||||
    tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    When the platform is not either esp32 or esp8266, the wizard should reject it
 | 
			
		||||
    """
 | 
			
		||||
@@ -471,7 +490,9 @@ def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
    assert retval == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
def test_wizard_requires_correct_board(
 | 
			
		||||
    tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    When the board is not a valid esp8266 board, the wizard should reject it
 | 
			
		||||
    """
 | 
			
		||||
@@ -493,7 +514,9 @@ def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
    assert retval == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
def test_wizard_requires_valid_ssid(
 | 
			
		||||
    tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    When the board is not a valid esp8266 board, the wizard should reject it
 | 
			
		||||
    """
 | 
			
		||||
@@ -515,7 +538,9 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers):
 | 
			
		||||
    assert retval == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_wizard_write_protects_existing_config(tmpdir, default_config, monkeypatch):
 | 
			
		||||
def test_wizard_write_protects_existing_config(
 | 
			
		||||
    tmpdir, default_config: dict[str, Any], monkeypatch: MonkeyPatch
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    The wizard_write function should not overwrite existing config files and return False
 | 
			
		||||
    """
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user