1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-18 11:12:20 +01:00

[wizard] Fix KeyError when running wizard with empty OTA password (#10753)

This commit is contained in:
J. Nick Koston
2025-09-16 14:56:54 -05:00
committed by GitHub
parent 1f4b10f523
commit 22989592f0
2 changed files with 91 additions and 32 deletions

View File

@@ -1,6 +1,7 @@
import os import os
import random import random
import string import string
from typing import Literal, NotRequired, TypedDict, Unpack
import unicodedata import unicodedata
import voluptuous as vol import voluptuous as vol
@@ -103,11 +104,25 @@ HARDWARE_BASE_CONFIGS = {
} }
def sanitize_double_quotes(value): def sanitize_double_quotes(value: str) -> str:
return value.replace("\\", "\\\\").replace('"', '\\"') return value.replace("\\", "\\\\").replace('"', '\\"')
def wizard_file(**kwargs): class WizardFileKwargs(TypedDict):
"""Keyword arguments for wizard_file function."""
name: str
platform: Literal["ESP8266", "ESP32", "RP2040", "BK72XX", "LN882X", "RTL87XX"]
board: str
ssid: NotRequired[str]
psk: NotRequired[str]
password: NotRequired[str]
ota_password: NotRequired[str]
api_encryption_key: NotRequired[str]
friendly_name: NotRequired[str]
def wizard_file(**kwargs: Unpack[WizardFileKwargs]) -> str:
letters = string.ascii_letters + string.digits letters = string.ascii_letters + string.digits
ap_name_base = kwargs["name"].replace("_", " ").title() ap_name_base = kwargs["name"].replace("_", " ").title()
ap_name = f"{ap_name_base} Fallback Hotspot" ap_name = f"{ap_name_base} Fallback Hotspot"
@@ -180,7 +195,25 @@ captive_portal:
return config return config
def wizard_write(path, **kwargs): class WizardWriteKwargs(TypedDict):
"""Keyword arguments for wizard_write function."""
name: str
type: Literal["basic", "empty", "upload"]
# Required for "basic" type
board: NotRequired[str]
platform: NotRequired[str]
ssid: NotRequired[str]
psk: NotRequired[str]
password: NotRequired[str]
ota_password: NotRequired[str]
api_encryption_key: NotRequired[str]
friendly_name: NotRequired[str]
# Required for "upload" type
file_text: NotRequired[str]
def wizard_write(path: str, **kwargs: Unpack[WizardWriteKwargs]) -> bool:
from esphome.components.bk72xx import boards as bk72xx_boards from esphome.components.bk72xx import boards as bk72xx_boards
from esphome.components.esp32 import boards as esp32_boards from esphome.components.esp32 import boards as esp32_boards
from esphome.components.esp8266 import boards as esp8266_boards from esphome.components.esp8266 import boards as esp8266_boards
@@ -237,14 +270,14 @@ def wizard_write(path, **kwargs):
if get_bool_env(ENV_QUICKWIZARD): if get_bool_env(ENV_QUICKWIZARD):
def sleep(time): def sleep(time: float) -> None:
pass pass
else: else:
from time import sleep from time import sleep
def safe_print_step(step, big): def safe_print_step(step: int, big: str) -> None:
safe_print() safe_print()
safe_print() safe_print()
safe_print(f"============= STEP {step} =============") safe_print(f"============= STEP {step} =============")
@@ -253,14 +286,14 @@ def safe_print_step(step, big):
sleep(0.25) sleep(0.25)
def default_input(text, default): def default_input(text: str, default: str) -> str:
safe_print() safe_print()
safe_print(f"Press ENTER for default ({default})") safe_print(f"Press ENTER for default ({default})")
return safe_input(text.format(default)) or default return safe_input(text.format(default)) or default
# From https://stackoverflow.com/a/518232/8924614 # From https://stackoverflow.com/a/518232/8924614
def strip_accents(value): def strip_accents(value: str) -> str:
return "".join( return "".join(
c c
for c in unicodedata.normalize("NFD", str(value)) for c in unicodedata.normalize("NFD", str(value))
@@ -268,7 +301,7 @@ def strip_accents(value):
) )
def wizard(path): def wizard(path: str) -> int:
from esphome.components.bk72xx import boards as bk72xx_boards from esphome.components.bk72xx import boards as bk72xx_boards
from esphome.components.esp32 import boards as esp32_boards from esphome.components.esp32 import boards as esp32_boards
from esphome.components.esp8266 import boards as esp8266_boards from esphome.components.esp8266 import boards as esp8266_boards
@@ -509,6 +542,7 @@ def wizard(path):
ssid=ssid, ssid=ssid,
psk=psk, psk=psk,
password=password, password=password,
type="basic",
): ):
return 1 return 1

View File

@@ -1,9 +1,12 @@
"""Tests for the wizard.py file.""" """Tests for the wizard.py file."""
import os import os
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from pytest import MonkeyPatch
from esphome.components.bk72xx.boards import BK72XX_BOARD_PINS from esphome.components.bk72xx.boards import BK72XX_BOARD_PINS
from esphome.components.esp32.boards import ESP32_BOARD_PINS from esphome.components.esp32.boards import ESP32_BOARD_PINS
@@ -15,7 +18,7 @@ import esphome.wizard as wz
@pytest.fixture @pytest.fixture
def default_config(): def default_config() -> dict[str, Any]:
return { return {
"type": "basic", "type": "basic",
"name": "test-name", "name": "test-name",
@@ -28,7 +31,7 @@ def default_config():
@pytest.fixture @pytest.fixture
def wizard_answers(): def wizard_answers() -> list[str]:
return [ return [
"test-node", # Name of the node "test-node", # Name of the node
"ESP8266", # platform "ESP8266", # platform
@@ -53,7 +56,9 @@ def test_sanitize_quotes_replaces_with_escaped_char():
assert output_str == '\\"key\\": \\"value\\"' assert output_str == '\\"key\\": \\"value\\"'
def test_config_file_fallback_ap_includes_descriptive_name(default_config): def test_config_file_fallback_ap_includes_descriptive_name(
default_config: dict[str, Any],
):
""" """
The fallback AP should include the node and a descriptive name The fallback AP should include the node and a descriptive name
""" """
@@ -67,7 +72,9 @@ def test_config_file_fallback_ap_includes_descriptive_name(default_config):
assert 'ssid: "Test Node Fallback Hotspot"' in config assert 'ssid: "Test Node Fallback Hotspot"' in config
def test_config_file_fallback_ap_name_less_than_32_chars(default_config): def test_config_file_fallback_ap_name_less_than_32_chars(
default_config: dict[str, Any],
):
""" """
The fallback AP name must be less than 32 chars. The fallback AP name must be less than 32 chars.
Since it is composed of the node name and "Fallback Hotspot" this can be too long and needs truncating Since it is composed of the node name and "Fallback Hotspot" this can be too long and needs truncating
@@ -82,7 +89,7 @@ def test_config_file_fallback_ap_name_less_than_32_chars(default_config):
assert 'ssid: "A Very Long Name For This Node"' in config assert 'ssid: "A Very Long Name For This Node"' in config
def test_config_file_should_include_ota(default_config): def test_config_file_should_include_ota(default_config: dict[str, Any]):
""" """
The Over-The-Air update should be enabled by default The Over-The-Air update should be enabled by default
""" """
@@ -95,7 +102,9 @@ def test_config_file_should_include_ota(default_config):
assert "ota:" in config assert "ota:" in config
def test_config_file_should_include_ota_when_password_set(default_config): def test_config_file_should_include_ota_when_password_set(
default_config: dict[str, Any],
):
""" """
The Over-The-Air update should be enabled when a password is set The Over-The-Air update should be enabled when a password is set
""" """
@@ -109,7 +118,9 @@ def test_config_file_should_include_ota_when_password_set(default_config):
assert "ota:" in config assert "ota:" in config
def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch): def test_wizard_write_sets_platform(
default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
):
""" """
If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards
""" """
@@ -126,7 +137,7 @@ def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch):
assert "esp8266:" in generated_config assert "esp8266:" in generated_config
def test_wizard_empty_config(tmp_path, monkeypatch): def test_wizard_empty_config(tmp_path: Path, monkeypatch: MonkeyPatch):
""" """
The wizard should be able to create an empty configuration The wizard should be able to create an empty configuration
""" """
@@ -146,7 +157,7 @@ def test_wizard_empty_config(tmp_path, monkeypatch):
assert generated_config == "" assert generated_config == ""
def test_wizard_upload_config(tmp_path, monkeypatch): def test_wizard_upload_config(tmp_path: Path, monkeypatch: MonkeyPatch):
""" """
The wizard should be able to import an base64 encoded configuration The wizard should be able to import an base64 encoded configuration
""" """
@@ -168,7 +179,7 @@ def test_wizard_upload_config(tmp_path, monkeypatch):
def test_wizard_write_defaults_platform_from_board_esp8266( def test_wizard_write_defaults_platform_from_board_esp8266(
default_config, tmp_path, monkeypatch default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
): ):
""" """
If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards If the platform is not explicitly set, use "ESP8266" if the board is one of the ESP8266 boards
@@ -189,7 +200,7 @@ def test_wizard_write_defaults_platform_from_board_esp8266(
def test_wizard_write_defaults_platform_from_board_esp32( def test_wizard_write_defaults_platform_from_board_esp32(
default_config, tmp_path, monkeypatch default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
): ):
""" """
If the platform is not explicitly set, use "ESP32" if the board is one of the ESP32 boards If the platform is not explicitly set, use "ESP32" if the board is one of the ESP32 boards
@@ -210,7 +221,7 @@ def test_wizard_write_defaults_platform_from_board_esp32(
def test_wizard_write_defaults_platform_from_board_bk72xx( def test_wizard_write_defaults_platform_from_board_bk72xx(
default_config, tmp_path, monkeypatch default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
): ):
""" """
If the platform is not explicitly set, use "BK72XX" if the board is one of BK72XX boards If the platform is not explicitly set, use "BK72XX" if the board is one of BK72XX boards
@@ -231,7 +242,7 @@ def test_wizard_write_defaults_platform_from_board_bk72xx(
def test_wizard_write_defaults_platform_from_board_ln882x( def test_wizard_write_defaults_platform_from_board_ln882x(
default_config, tmp_path, monkeypatch default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
): ):
""" """
If the platform is not explicitly set, use "LN882X" if the board is one of LN882X boards If the platform is not explicitly set, use "LN882X" if the board is one of LN882X boards
@@ -252,7 +263,7 @@ def test_wizard_write_defaults_platform_from_board_ln882x(
def test_wizard_write_defaults_platform_from_board_rtl87xx( def test_wizard_write_defaults_platform_from_board_rtl87xx(
default_config, tmp_path, monkeypatch default_config: dict[str, Any], tmp_path: Path, monkeypatch: MonkeyPatch
): ):
""" """
If the platform is not explicitly set, use "RTL87XX" if the board is one of RTL87XX boards If the platform is not explicitly set, use "RTL87XX" if the board is one of RTL87XX boards
@@ -272,7 +283,7 @@ def test_wizard_write_defaults_platform_from_board_rtl87xx(
assert "rtl87xx:" in generated_config assert "rtl87xx:" in generated_config
def test_safe_print_step_prints_step_number_and_description(monkeypatch): def test_safe_print_step_prints_step_number_and_description(monkeypatch: MonkeyPatch):
""" """
The safe_print_step function prints the step number and the passed description The safe_print_step function prints the step number and the passed description
""" """
@@ -296,7 +307,7 @@ def test_safe_print_step_prints_step_number_and_description(monkeypatch):
assert any(f"STEP {step_num}" in arg for arg in all_args) assert any(f"STEP {step_num}" in arg for arg in all_args)
def test_default_input_uses_default_if_no_input_supplied(monkeypatch): def test_default_input_uses_default_if_no_input_supplied(monkeypatch: MonkeyPatch):
""" """
The default_input() function should return the supplied default value if the user doesn't enter anything The default_input() function should return the supplied default value if the user doesn't enter anything
""" """
@@ -312,7 +323,7 @@ def test_default_input_uses_default_if_no_input_supplied(monkeypatch):
assert retval == default_string assert retval == default_string
def test_default_input_uses_user_supplied_value(monkeypatch): def test_default_input_uses_user_supplied_value(monkeypatch: MonkeyPatch):
""" """
The default_input() function should return the value that the user enters The default_input() function should return the value that the user enters
""" """
@@ -376,7 +387,9 @@ def test_wizard_rejects_existing_files(tmpdir):
assert retval == 2 assert retval == 2
def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answers): def test_wizard_accepts_default_answers_esp8266(
tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
):
""" """
The wizard should accept the given default answers for esp8266 The wizard should accept the given default answers for esp8266
""" """
@@ -396,7 +409,9 @@ def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answ
assert retval == 0 assert retval == 0
def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answers): def test_wizard_accepts_default_answers_esp32(
tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
):
""" """
The wizard should accept the given default answers for esp32 The wizard should accept the given default answers for esp32
""" """
@@ -418,7 +433,9 @@ def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answer
assert retval == 0 assert retval == 0
def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers): def test_wizard_offers_better_node_name(
tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
):
""" """
When the node name does not conform, a better alternative is offered When the node name does not conform, a better alternative is offered
* Removes special chars * Removes special chars
@@ -449,7 +466,9 @@ def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers):
assert wz.default_input.call_args.args[1] == expected_name assert wz.default_input.call_args.args[1] == expected_name
def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers): def test_wizard_requires_correct_platform(
tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
):
""" """
When the platform is not either esp32 or esp8266, the wizard should reject it When the platform is not either esp32 or esp8266, the wizard should reject it
""" """
@@ -471,7 +490,9 @@ def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers):
assert retval == 0 assert retval == 0
def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers): def test_wizard_requires_correct_board(
tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
):
""" """
When the board is not a valid esp8266 board, the wizard should reject it When the board is not a valid esp8266 board, the wizard should reject it
""" """
@@ -493,7 +514,9 @@ def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers):
assert retval == 0 assert retval == 0
def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers): def test_wizard_requires_valid_ssid(
tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str]
):
""" """
When the board is not a valid esp8266 board, the wizard should reject it When the board is not a valid esp8266 board, the wizard should reject it
""" """
@@ -515,7 +538,9 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers):
assert retval == 0 assert retval == 0
def test_wizard_write_protects_existing_config(tmpdir, default_config, monkeypatch): def test_wizard_write_protects_existing_config(
tmpdir, default_config: dict[str, Any], monkeypatch: MonkeyPatch
):
""" """
The wizard_write function should not overwrite existing config files and return False The wizard_write function should not overwrite existing config files and return False
""" """