diff --git a/esphome/wizard.py b/esphome/wizard.py index d77450b04d..f5e8a1e462 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -1,5 +1,7 @@ +import base64 from pathlib import Path import random +import secrets import string from typing import Literal, NotRequired, TypedDict, Unpack import unicodedata @@ -116,7 +118,6 @@ class WizardFileKwargs(TypedDict): board: str ssid: NotRequired[str] psk: NotRequired[str] - password: NotRequired[str] ota_password: NotRequired[str] api_encryption_key: NotRequired[str] friendly_name: NotRequired[str] @@ -144,9 +145,7 @@ def wizard_file(**kwargs: Unpack[WizardFileKwargs]) -> str: config += API_CONFIG - # Configure API - if "password" in kwargs: - config += f' password: "{kwargs["password"]}"\n' + # Configure API encryption if "api_encryption_key" in kwargs: config += f' encryption:\n key: "{kwargs["api_encryption_key"]}"\n' @@ -155,8 +154,6 @@ def wizard_file(**kwargs: Unpack[WizardFileKwargs]) -> str: config += " - platform: esphome\n" if "ota_password" in kwargs: config += f' password: "{kwargs["ota_password"]}"' - elif "password" in kwargs: - config += f' password: "{kwargs["password"]}"' # Configuring wifi config += "\n\nwifi:\n" @@ -205,7 +202,6 @@ class WizardWriteKwargs(TypedDict): platform: NotRequired[str] ssid: NotRequired[str] psk: NotRequired[str] - password: NotRequired[str] ota_password: NotRequired[str] api_encryption_key: NotRequired[str] friendly_name: NotRequired[str] @@ -232,7 +228,7 @@ def wizard_write(path: Path, **kwargs: Unpack[WizardWriteKwargs]) -> bool: else: # "basic" board = kwargs["board"] - for key in ("ssid", "psk", "password", "ota_password"): + for key in ("ssid", "psk", "ota_password"): if key in kwargs: kwargs[key] = sanitize_double_quotes(kwargs[key]) if "platform" not in kwargs: @@ -522,26 +518,54 @@ def wizard(path: Path) -> int: "Almost there! ESPHome can automatically upload custom firmwares over WiFi " "(over the air) and integrates into Home Assistant with a native API." ) + safe_print() + sleep(0.5) + + # Generate encryption key (32 bytes, base64 encoded) for secure API communication + noise_psk = secrets.token_bytes(32) + api_encryption_key = base64.b64encode(noise_psk).decode() + safe_print( - f"This can be insecure if you do not trust the WiFi network. Do you want to set a {color(AnsiFore.GREEN, 'password')} for connecting to this ESP?" + "For secure API communication, I've generated a random encryption key." + ) + safe_print() + safe_print( + f"Your {color(AnsiFore.GREEN, 'API encryption key')} is: " + f"{color(AnsiFore.BOLD_WHITE, api_encryption_key)}" + ) + safe_print() + safe_print("You'll need this key when adding the device to Home Assistant.") + sleep(1) + + safe_print() + safe_print( + f"Do you want to set a {color(AnsiFore.GREEN, 'password')} for OTA updates? " + "This can be insecure if you do not trust the WiFi network." ) safe_print() sleep(0.25) safe_print("Press ENTER for no password") - password = safe_input(color(AnsiFore.BOLD_WHITE, "(password): ")) + ota_password = safe_input(color(AnsiFore.BOLD_WHITE, "(password): ")) else: - ssid, password, psk = "", "", "" + ssid, psk = "", "" + api_encryption_key = None + ota_password = "" - if not wizard_write( - path=path, - name=name, - platform=platform, - board=board, - ssid=ssid, - psk=psk, - password=password, - type="basic", - ): + kwargs = { + "path": path, + "name": name, + "platform": platform, + "board": board, + "ssid": ssid, + "psk": psk, + "type": "basic", + } + if api_encryption_key: + kwargs["api_encryption_key"] = api_encryption_key + if ota_password: + kwargs["ota_password"] = ota_password + + if not wizard_write(**kwargs): return 1 safe_print() diff --git a/tests/unit_tests/test_wizard.py b/tests/unit_tests/test_wizard.py index fd53a0b0b7..eb44c1c20f 100644 --- a/tests/unit_tests/test_wizard.py +++ b/tests/unit_tests/test_wizard.py @@ -25,7 +25,6 @@ def default_config() -> dict[str, Any]: "board": "esp01_1m", "ssid": "test_ssid", "psk": "test_psk", - "password": "", } @@ -37,7 +36,7 @@ def wizard_answers() -> list[str]: "nodemcuv2", # board "SSID", # ssid "psk", # wifi password - "ota_pass", # ota password + "", # ota password (empty for no password) ] @@ -105,16 +104,35 @@ def test_config_file_should_include_ota_when_password_set( default_config: dict[str, Any], ): """ - The Over-The-Air update should be enabled when a password is set + The Over-The-Air update should be enabled when an OTA password is set """ # Given - default_config["password"] = "foo" + default_config["ota_password"] = "foo" # When config = wz.wizard_file(**default_config) # Then assert "ota:" in config + assert 'password: "foo"' in config + + +def test_config_file_should_include_api_encryption_key( + default_config: dict[str, Any], +): + """ + The API encryption key should be included when set + """ + # Given + default_config["api_encryption_key"] = "test_encryption_key_base64==" + + # When + config = wz.wizard_file(**default_config) + + # Then + assert "api:" in config + assert "encryption:" in config + assert 'key: "test_encryption_key_base64=="' in config def test_wizard_write_sets_platform( @@ -556,3 +574,61 @@ def test_wizard_write_protects_existing_config( # Then assert result is False # Should return False when file exists assert config_file.read_text() == original_content + + +def test_wizard_accepts_ota_password( + tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str] +): + """ + The wizard should pass ota_password to wizard_write when the user provides one + """ + + # Given + wizard_answers[5] = "my_ota_password" # Set OTA password + config_file = tmp_path / "test.yaml" + input_mock = MagicMock(side_effect=wizard_answers) + monkeypatch.setattr("builtins.input", input_mock) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) + monkeypatch.setattr(wz, "sleep", lambda _: 0) + wizard_write_mock = MagicMock(return_value=True) + monkeypatch.setattr(wz, "wizard_write", wizard_write_mock) + + # When + retval = wz.wizard(config_file) + + # Then + assert retval == 0 + call_kwargs = wizard_write_mock.call_args.kwargs + assert "ota_password" in call_kwargs + assert call_kwargs["ota_password"] == "my_ota_password" + + +def test_wizard_accepts_rpipico_board(tmp_path: Path, monkeypatch: MonkeyPatch): + """ + The wizard should handle rpipico board which doesn't support WiFi. + This tests the branch where api_encryption_key is None. + """ + + # Given + wizard_answers_rp2040 = [ + "test-node", # Name of the node + "RP2040", # platform + "rpipico", # board (no WiFi support) + ] + config_file = tmp_path / "test.yaml" + input_mock = MagicMock(side_effect=wizard_answers_rp2040) + monkeypatch.setattr("builtins.input", input_mock) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) + monkeypatch.setattr(wz, "sleep", lambda _: 0) + wizard_write_mock = MagicMock(return_value=True) + monkeypatch.setattr(wz, "wizard_write", wizard_write_mock) + + # When + retval = wz.wizard(config_file) + + # Then + assert retval == 0 + call_kwargs = wizard_write_mock.call_args.kwargs + # rpipico doesn't support WiFi, so no api_encryption_key or ota_password + assert "api_encryption_key" not in call_kwargs + assert "ota_password" not in call_kwargs