diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 4b3a3e2fc8..02b1d61368 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -112,7 +112,7 @@ void APIConnection::start() { APIError err = this->helper_->init(); if (err != APIError::OK) { on_fatal_error(); - this->log_warning_("Helper init failed", err); + this->log_warning_(LOG_STR("Helper init failed"), err); return; } this->client_info_.peername = helper_->getpeername(); @@ -159,7 +159,7 @@ void APIConnection::loop() { break; } else if (err != APIError::OK) { on_fatal_error(); - this->log_warning_("Reading failed", err); + this->log_warning_(LOG_STR("Reading failed"), err); return; } else { this->last_traffic_ = now; @@ -1565,7 +1565,7 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) { return false; if (err != APIError::OK) { on_fatal_error(); - this->log_warning_("Packet write failed", err); + this->log_warning_(LOG_STR("Packet write failed"), err); return false; } // Do not set last_traffic_ on send @@ -1752,7 +1752,7 @@ void APIConnection::process_batch_() { std::span(packet_info, packet_count)); if (err != APIError::OK && err != APIError::WOULD_BLOCK) { on_fatal_error(); - this->log_warning_("Batch write failed", err); + this->log_warning_(LOG_STR("Batch write failed"), err); } #ifdef HAS_PROTO_MESSAGE_DUMP @@ -1830,11 +1830,14 @@ void APIConnection::process_state_subscriptions_() { } #endif // USE_API_HOMEASSISTANT_STATES -void APIConnection::log_warning_(const char *message, APIError err) { - ESP_LOGW(TAG, "%s: %s %s errno=%d", this->get_client_combined_info().c_str(), message, api_error_to_str(err), errno); +void APIConnection::log_warning_(const LogString *message, APIError err) { + ESP_LOGW(TAG, "%s: %s %s errno=%d", this->get_client_combined_info().c_str(), LOG_STR_ARG(message), + LOG_STR_ARG(api_error_to_logstr(err)), errno); } -void APIConnection::log_socket_operation_failed_(APIError err) { this->log_warning_("Socket operation failed", err); } +void APIConnection::log_socket_operation_failed_(APIError err) { + this->log_warning_(LOG_STR("Socket operation failed"), err); +} } // namespace esphome::api #endif diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 72254d1536..7ee82e0c68 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -732,7 +732,7 @@ class APIConnection final : public APIServerConnection { } // Helper function to log API errors with errno - void log_warning_(const char *message, APIError err); + void log_warning_(const LogString *message, APIError err); // Specific helper for duplicated error message void log_socket_operation_failed_(APIError err); }; diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index dee3af2ac3..a284e09c4a 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -23,59 +23,59 @@ static const char *const TAG = "api.frame_helper"; #define LOG_PACKET_SENDING(data, len) ((void) 0) #endif -const char *api_error_to_str(APIError err) { +const LogString *api_error_to_logstr(APIError err) { // not using switch to ensure compiler doesn't try to build a big table out of it if (err == APIError::OK) { - return "OK"; + return LOG_STR("OK"); } else if (err == APIError::WOULD_BLOCK) { - return "WOULD_BLOCK"; + return LOG_STR("WOULD_BLOCK"); } else if (err == APIError::BAD_INDICATOR) { - return "BAD_INDICATOR"; + return LOG_STR("BAD_INDICATOR"); } else if (err == APIError::BAD_DATA_PACKET) { - return "BAD_DATA_PACKET"; + return LOG_STR("BAD_DATA_PACKET"); } else if (err == APIError::TCP_NODELAY_FAILED) { - return "TCP_NODELAY_FAILED"; + return LOG_STR("TCP_NODELAY_FAILED"); } else if (err == APIError::TCP_NONBLOCKING_FAILED) { - return "TCP_NONBLOCKING_FAILED"; + return LOG_STR("TCP_NONBLOCKING_FAILED"); } else if (err == APIError::CLOSE_FAILED) { - return "CLOSE_FAILED"; + return LOG_STR("CLOSE_FAILED"); } else if (err == APIError::SHUTDOWN_FAILED) { - return "SHUTDOWN_FAILED"; + return LOG_STR("SHUTDOWN_FAILED"); } else if (err == APIError::BAD_STATE) { - return "BAD_STATE"; + return LOG_STR("BAD_STATE"); } else if (err == APIError::BAD_ARG) { - return "BAD_ARG"; + return LOG_STR("BAD_ARG"); } else if (err == APIError::SOCKET_READ_FAILED) { - return "SOCKET_READ_FAILED"; + return LOG_STR("SOCKET_READ_FAILED"); } else if (err == APIError::SOCKET_WRITE_FAILED) { - return "SOCKET_WRITE_FAILED"; + return LOG_STR("SOCKET_WRITE_FAILED"); } else if (err == APIError::OUT_OF_MEMORY) { - return "OUT_OF_MEMORY"; + return LOG_STR("OUT_OF_MEMORY"); } else if (err == APIError::CONNECTION_CLOSED) { - return "CONNECTION_CLOSED"; + return LOG_STR("CONNECTION_CLOSED"); } #ifdef USE_API_NOISE else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) { - return "BAD_HANDSHAKE_PACKET_LEN"; + return LOG_STR("BAD_HANDSHAKE_PACKET_LEN"); } else if (err == APIError::HANDSHAKESTATE_READ_FAILED) { - return "HANDSHAKESTATE_READ_FAILED"; + return LOG_STR("HANDSHAKESTATE_READ_FAILED"); } else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) { - return "HANDSHAKESTATE_WRITE_FAILED"; + return LOG_STR("HANDSHAKESTATE_WRITE_FAILED"); } else if (err == APIError::HANDSHAKESTATE_BAD_STATE) { - return "HANDSHAKESTATE_BAD_STATE"; + return LOG_STR("HANDSHAKESTATE_BAD_STATE"); } else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) { - return "CIPHERSTATE_DECRYPT_FAILED"; + return LOG_STR("CIPHERSTATE_DECRYPT_FAILED"); } else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) { - return "CIPHERSTATE_ENCRYPT_FAILED"; + return LOG_STR("CIPHERSTATE_ENCRYPT_FAILED"); } else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) { - return "HANDSHAKESTATE_SETUP_FAILED"; + return LOG_STR("HANDSHAKESTATE_SETUP_FAILED"); } else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) { - return "HANDSHAKESTATE_SPLIT_FAILED"; + return LOG_STR("HANDSHAKESTATE_SPLIT_FAILED"); } else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) { - return "BAD_HANDSHAKE_ERROR_BYTE"; + return LOG_STR("BAD_HANDSHAKE_ERROR_BYTE"); } #endif - return "UNKNOWN"; + return LOG_STR("UNKNOWN"); } // Default implementation for loop - handles sending buffered data diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index 43e9d95fbe..c11d701ffe 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -66,7 +66,7 @@ enum class APIError : uint16_t { #endif }; -const char *api_error_to_str(APIError err); +const LogString *api_error_to_logstr(APIError err); class APIFrameHelper { public: diff --git a/esphome/components/api/api_frame_helper_noise.cpp b/esphome/components/api/api_frame_helper_noise.cpp index 35d1715931..394a8baa35 100644 --- a/esphome/components/api/api_frame_helper_noise.cpp +++ b/esphome/components/api/api_frame_helper_noise.cpp @@ -27,42 +27,42 @@ static constexpr size_t PROLOGUE_INIT_LEN = 12; // strlen("NoiseAPIInit") #endif /// Convert a noise error code to a readable error -std::string noise_err_to_str(int err) { +const LogString *noise_err_to_logstr(int err) { if (err == NOISE_ERROR_NO_MEMORY) - return "NO_MEMORY"; + return LOG_STR("NO_MEMORY"); if (err == NOISE_ERROR_UNKNOWN_ID) - return "UNKNOWN_ID"; + return LOG_STR("UNKNOWN_ID"); if (err == NOISE_ERROR_UNKNOWN_NAME) - return "UNKNOWN_NAME"; + return LOG_STR("UNKNOWN_NAME"); if (err == NOISE_ERROR_MAC_FAILURE) - return "MAC_FAILURE"; + return LOG_STR("MAC_FAILURE"); if (err == NOISE_ERROR_NOT_APPLICABLE) - return "NOT_APPLICABLE"; + return LOG_STR("NOT_APPLICABLE"); if (err == NOISE_ERROR_SYSTEM) - return "SYSTEM"; + return LOG_STR("SYSTEM"); if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED) - return "REMOTE_KEY_REQUIRED"; + return LOG_STR("REMOTE_KEY_REQUIRED"); if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED) - return "LOCAL_KEY_REQUIRED"; + return LOG_STR("LOCAL_KEY_REQUIRED"); if (err == NOISE_ERROR_PSK_REQUIRED) - return "PSK_REQUIRED"; + return LOG_STR("PSK_REQUIRED"); if (err == NOISE_ERROR_INVALID_LENGTH) - return "INVALID_LENGTH"; + return LOG_STR("INVALID_LENGTH"); if (err == NOISE_ERROR_INVALID_PARAM) - return "INVALID_PARAM"; + return LOG_STR("INVALID_PARAM"); if (err == NOISE_ERROR_INVALID_STATE) - return "INVALID_STATE"; + return LOG_STR("INVALID_STATE"); if (err == NOISE_ERROR_INVALID_NONCE) - return "INVALID_NONCE"; + return LOG_STR("INVALID_NONCE"); if (err == NOISE_ERROR_INVALID_PRIVATE_KEY) - return "INVALID_PRIVATE_KEY"; + return LOG_STR("INVALID_PRIVATE_KEY"); if (err == NOISE_ERROR_INVALID_PUBLIC_KEY) - return "INVALID_PUBLIC_KEY"; + return LOG_STR("INVALID_PUBLIC_KEY"); if (err == NOISE_ERROR_INVALID_FORMAT) - return "INVALID_FORMAT"; + return LOG_STR("INVALID_FORMAT"); if (err == NOISE_ERROR_INVALID_SIGNATURE) - return "INVALID_SIGNATURE"; - return to_string(err); + return LOG_STR("INVALID_SIGNATURE"); + return LOG_STR("UNKNOWN"); } /// Initialize the frame helper, returns OK if successful. @@ -83,18 +83,18 @@ APIError APINoiseFrameHelper::init() { // Helper for handling handshake frame errors APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) { if (aerr == APIError::BAD_INDICATOR) { - send_explicit_handshake_reject_("Bad indicator byte"); + send_explicit_handshake_reject_(LOG_STR("Bad indicator byte")); } else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { - send_explicit_handshake_reject_("Bad handshake packet len"); + send_explicit_handshake_reject_(LOG_STR("Bad handshake packet len")); } return aerr; } // Helper for handling noise library errors -APIError APINoiseFrameHelper::handle_noise_error_(int err, const char *func_name, APIError api_err) { +APIError APINoiseFrameHelper::handle_noise_error_(int err, const LogString *func_name, APIError api_err) { if (err != 0) { state_ = State::FAILED; - HELPER_LOG("%s failed: %s", func_name, noise_err_to_str(err).c_str()); + HELPER_LOG("%s failed: %s", LOG_STR_ARG(func_name), LOG_STR_ARG(noise_err_to_logstr(err))); return api_err; } return APIError::OK; @@ -279,11 +279,11 @@ APIError APINoiseFrameHelper::state_action_() { } if (frame.empty()) { - send_explicit_handshake_reject_("Empty handshake message"); + send_explicit_handshake_reject_(LOG_STR("Empty handshake message")); return APIError::BAD_HANDSHAKE_ERROR_BYTE; } else if (frame[0] != 0x00) { HELPER_LOG("Bad handshake error byte: %u", frame[0]); - send_explicit_handshake_reject_("Bad handshake error byte"); + send_explicit_handshake_reject_(LOG_STR("Bad handshake error byte")); return APIError::BAD_HANDSHAKE_ERROR_BYTE; } @@ -293,8 +293,10 @@ APIError APINoiseFrameHelper::state_action_() { err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); if (err != 0) { // Special handling for MAC failure - send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? "Handshake MAC failure" : "Handshake error"); - return handle_noise_error_(err, "noise_handshakestate_read_message", APIError::HANDSHAKESTATE_READ_FAILED); + send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? LOG_STR("Handshake MAC failure") + : LOG_STR("Handshake error")); + return handle_noise_error_(err, LOG_STR("noise_handshakestate_read_message"), + APIError::HANDSHAKESTATE_READ_FAILED); } aerr = check_handshake_finished_(); @@ -307,8 +309,8 @@ APIError APINoiseFrameHelper::state_action_() { noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); - APIError aerr_write = - handle_noise_error_(err, "noise_handshakestate_write_message", APIError::HANDSHAKESTATE_WRITE_FAILED); + APIError aerr_write = handle_noise_error_(err, LOG_STR("noise_handshakestate_write_message"), + APIError::HANDSHAKESTATE_WRITE_FAILED); if (aerr_write != APIError::OK) return aerr_write; buffer[0] = 0x00; // success @@ -331,15 +333,31 @@ APIError APINoiseFrameHelper::state_action_() { } return APIError::OK; } -void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { +void APINoiseFrameHelper::send_explicit_handshake_reject_(const LogString *reason) { +#ifdef USE_STORE_LOG_STR_IN_FLASH + // On ESP8266 with flash strings, we need to use PROGMEM-aware functions + size_t reason_len = strlen_P(reinterpret_cast(reason)); std::vector data; - data.resize(reason.length() + 1); + data.resize(reason_len + 1); + data[0] = 0x01; // failure + + // Copy error message from PROGMEM + if (reason_len > 0) { + memcpy_P(data.data() + 1, reinterpret_cast(reason), reason_len); + } +#else + // Normal memory access + const char *reason_str = LOG_STR_ARG(reason); + size_t reason_len = strlen(reason_str); + std::vector data; + data.resize(reason_len + 1); data[0] = 0x01; // failure // Copy error message in bulk - if (!reason.empty()) { - std::memcpy(data.data() + 1, reason.c_str(), reason.length()); + if (reason_len > 0) { + std::memcpy(data.data() + 1, reason_str, reason_len); } +#endif // temporarily remove failed state auto orig_state = state_; @@ -368,7 +386,8 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { noise_buffer_init(mbuf); noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size()); err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); - APIError decrypt_err = handle_noise_error_(err, "noise_cipherstate_decrypt", APIError::CIPHERSTATE_DECRYPT_FAILED); + APIError decrypt_err = + handle_noise_error_(err, LOG_STR("noise_cipherstate_decrypt"), APIError::CIPHERSTATE_DECRYPT_FAILED); if (decrypt_err != APIError::OK) return decrypt_err; @@ -450,7 +469,8 @@ APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, st 4 + packet.payload_size + frame_footer_size_); int err = noise_cipherstate_encrypt(send_cipher_, &mbuf); - APIError aerr = handle_noise_error_(err, "noise_cipherstate_encrypt", APIError::CIPHERSTATE_ENCRYPT_FAILED); + APIError aerr = + handle_noise_error_(err, LOG_STR("noise_cipherstate_encrypt"), APIError::CIPHERSTATE_ENCRYPT_FAILED); if (aerr != APIError::OK) return aerr; @@ -504,25 +524,27 @@ APIError APINoiseFrameHelper::init_handshake_() { nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); - APIError aerr = handle_noise_error_(err, "noise_handshakestate_new_by_id", APIError::HANDSHAKESTATE_SETUP_FAILED); + APIError aerr = + handle_noise_error_(err, LOG_STR("noise_handshakestate_new_by_id"), APIError::HANDSHAKESTATE_SETUP_FAILED); if (aerr != APIError::OK) return aerr; const auto &psk = ctx_->get_psk(); err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); - aerr = handle_noise_error_(err, "noise_handshakestate_set_pre_shared_key", APIError::HANDSHAKESTATE_SETUP_FAILED); + aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_set_pre_shared_key"), + APIError::HANDSHAKESTATE_SETUP_FAILED); if (aerr != APIError::OK) return aerr; err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); - aerr = handle_noise_error_(err, "noise_handshakestate_set_prologue", APIError::HANDSHAKESTATE_SETUP_FAILED); + aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_set_prologue"), APIError::HANDSHAKESTATE_SETUP_FAILED); if (aerr != APIError::OK) return aerr; // set_prologue copies it into handshakestate, so we can get rid of it now prologue_ = {}; err = noise_handshakestate_start(handshake_); - aerr = handle_noise_error_(err, "noise_handshakestate_start", APIError::HANDSHAKESTATE_SETUP_FAILED); + aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_start"), APIError::HANDSHAKESTATE_SETUP_FAILED); if (aerr != APIError::OK) return aerr; return APIError::OK; @@ -540,7 +562,8 @@ APIError APINoiseFrameHelper::check_handshake_finished_() { return APIError::HANDSHAKESTATE_BAD_STATE; } int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); - APIError aerr = handle_noise_error_(err, "noise_handshakestate_split", APIError::HANDSHAKESTATE_SPLIT_FAILED); + APIError aerr = + handle_noise_error_(err, LOG_STR("noise_handshakestate_split"), APIError::HANDSHAKESTATE_SPLIT_FAILED); if (aerr != APIError::OK) return aerr; diff --git a/esphome/components/api/api_frame_helper_noise.h b/esphome/components/api/api_frame_helper_noise.h index 49bc6f8854..71a217c4ca 100644 --- a/esphome/components/api/api_frame_helper_noise.h +++ b/esphome/components/api/api_frame_helper_noise.h @@ -32,9 +32,9 @@ class APINoiseFrameHelper final : public APIFrameHelper { APIError write_frame_(const uint8_t *data, uint16_t len); APIError init_handshake_(); APIError check_handshake_finished_(); - void send_explicit_handshake_reject_(const std::string &reason); + void send_explicit_handshake_reject_(const LogString *reason); APIError handle_handshake_frame_error_(APIError aerr); - APIError handle_noise_error_(int err, const char *func_name, APIError api_err); + APIError handle_noise_error_(int err, const LogString *func_name, APIError api_err); // Pointers first (4 bytes each) NoiseHandshakeState *handshake_{nullptr}; diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index fd16667d8a..294a180794 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import base64 +import binascii from collections.abc import Callable, Iterable import datetime import functools @@ -490,7 +491,17 @@ class WizardRequestHandler(BaseHandler): kwargs = { k: v for k, v in json.loads(self.request.body.decode()).items() - if k in ("name", "platform", "board", "ssid", "psk", "password") + if k + in ( + "type", + "name", + "platform", + "board", + "ssid", + "psk", + "password", + "file_content", + ) } if not kwargs["name"]: self.set_status(422) @@ -498,19 +509,65 @@ class WizardRequestHandler(BaseHandler): self.write(json.dumps({"error": "Name is required"})) return + if "type" not in kwargs: + # Default to basic wizard type for backwards compatibility + kwargs["type"] = "basic" + kwargs["friendly_name"] = kwargs["name"] kwargs["name"] = friendly_name_slugify(kwargs["friendly_name"]) - - kwargs["ota_password"] = secrets.token_hex(16) - noise_psk = secrets.token_bytes(32) - kwargs["api_encryption_key"] = base64.b64encode(noise_psk).decode() + if kwargs["type"] == "basic": + kwargs["ota_password"] = secrets.token_hex(16) + noise_psk = secrets.token_bytes(32) + kwargs["api_encryption_key"] = base64.b64encode(noise_psk).decode() + elif kwargs["type"] == "upload": + try: + kwargs["file_text"] = base64.b64decode(kwargs["file_content"]).decode( + "utf-8" + ) + except (binascii.Error, UnicodeDecodeError): + self.set_status(422) + self.set_header("content-type", "application/json") + self.write( + json.dumps({"error": "The uploaded file is not correctly encoded."}) + ) + return + elif kwargs["type"] != "empty": + self.set_status(422) + self.set_header("content-type", "application/json") + self.write( + json.dumps( + {"error": f"Invalid wizard type specified: {kwargs['type']}"} + ) + ) + return filename = f"{kwargs['name']}.yaml" destination = settings.rel_path(filename) - wizard.wizard_write(path=destination, **kwargs) - self.set_status(200) - self.set_header("content-type", "application/json") - self.write(json.dumps({"configuration": filename})) - self.finish() + + # Check if destination file already exists + if os.path.exists(destination): + self.set_status(409) # Conflict status code + self.set_header("content-type", "application/json") + self.write( + json.dumps({"error": f"Configuration file '{filename}' already exists"}) + ) + self.finish() + return + + success = wizard.wizard_write(path=destination, **kwargs) + if success: + self.set_status(200) + self.set_header("content-type", "application/json") + self.write(json.dumps({"configuration": filename})) + self.finish() + else: + self.set_status(500) + self.set_header("content-type", "application/json") + self.write( + json.dumps( + {"error": "Failed to write configuration, see logs for details"} + ) + ) + self.finish() class ImportRequestHandler(BaseHandler): diff --git a/esphome/wizard.py b/esphome/wizard.py index 8602e90222..cb599df59a 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -189,32 +189,45 @@ def wizard_write(path, **kwargs): from esphome.components.rtl87xx import boards as rtl87xx_boards name = kwargs["name"] - board = kwargs["board"] + if kwargs["type"] == "empty": + file_text = "" + # Will be updated later after editing the file + hardware = "UNKNOWN" + elif kwargs["type"] == "upload": + file_text = kwargs["file_text"] + hardware = "UNKNOWN" + else: # "basic" + board = kwargs["board"] - for key in ("ssid", "psk", "password", "ota_password"): - if key in kwargs: - kwargs[key] = sanitize_double_quotes(kwargs[key]) + for key in ("ssid", "psk", "password", "ota_password"): + if key in kwargs: + kwargs[key] = sanitize_double_quotes(kwargs[key]) + if "platform" not in kwargs: + if board in esp8266_boards.BOARDS: + platform = "ESP8266" + elif board in esp32_boards.BOARDS: + platform = "ESP32" + elif board in rp2040_boards.BOARDS: + platform = "RP2040" + elif board in bk72xx_boards.BOARDS: + platform = "BK72XX" + elif board in ln882x_boards.BOARDS: + platform = "LN882X" + elif board in rtl87xx_boards.BOARDS: + platform = "RTL87XX" + else: + safe_print(color(AnsiFore.RED, f'The board "{board}" is unknown.')) + return False + kwargs["platform"] = platform + hardware = kwargs["platform"] + file_text = wizard_file(**kwargs) - if "platform" not in kwargs: - if board in esp8266_boards.BOARDS: - platform = "ESP8266" - elif board in esp32_boards.BOARDS: - platform = "ESP32" - elif board in rp2040_boards.BOARDS: - platform = "RP2040" - elif board in bk72xx_boards.BOARDS: - platform = "BK72XX" - elif board in ln882x_boards.BOARDS: - platform = "LN882X" - elif board in rtl87xx_boards.BOARDS: - platform = "RTL87XX" - else: - safe_print(color(AnsiFore.RED, f'The board "{board}" is unknown.')) - return False - kwargs["platform"] = platform - hardware = kwargs["platform"] + # Check if file already exists to prevent overwriting + if os.path.exists(path) and os.path.isfile(path): + safe_print(color(AnsiFore.RED, f'The file "{path}" already exists.')) + return False - write_file(path, wizard_file(**kwargs)) + write_file(path, file_text) storage = StorageJSON.from_wizard(name, name, f"{name}.local", hardware) storage_path = ext_storage_path(os.path.basename(path)) storage.save(storage_path) diff --git a/tests/unit_tests/test_wizard.py b/tests/unit_tests/test_wizard.py index ab20b2abb5..fea2fb5558 100644 --- a/tests/unit_tests/test_wizard.py +++ b/tests/unit_tests/test_wizard.py @@ -17,6 +17,7 @@ import esphome.wizard as wz @pytest.fixture def default_config(): return { + "type": "basic", "name": "test-name", "platform": "ESP8266", "board": "esp01_1m", @@ -125,6 +126,47 @@ def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch): assert "esp8266:" in generated_config +def test_wizard_empty_config(tmp_path, monkeypatch): + """ + The wizard should be able to create an empty configuration + """ + # Given + empty_config = { + "type": "empty", + "name": "test-empty", + } + monkeypatch.setattr(wz, "write_file", MagicMock()) + monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + + # When + wz.wizard_write(tmp_path, **empty_config) + + # Then + generated_config = wz.write_file.call_args.args[1] + assert generated_config == "" + + +def test_wizard_upload_config(tmp_path, monkeypatch): + """ + The wizard should be able to import an base64 encoded configuration + """ + # Given + empty_config = { + "type": "upload", + "name": "test-upload", + "file_text": "# imported file 📁\n\n", + } + monkeypatch.setattr(wz, "write_file", MagicMock()) + monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + + # When + wz.wizard_write(tmp_path, **empty_config) + + # Then + generated_config = wz.write_file.call_args.args[1] + assert generated_config == "# imported file 📁\n\n" + + def test_wizard_write_defaults_platform_from_board_esp8266( default_config, tmp_path, monkeypatch ): @@ -471,3 +513,22 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers): # Then assert retval == 0 + + +def test_wizard_write_protects_existing_config(tmpdir, default_config, monkeypatch): + """ + The wizard_write function should not overwrite existing config files and return False + """ + # Given + config_file = tmpdir.join("test.yaml") + original_content = "# Original config content\n" + config_file.write(original_content) + + monkeypatch.setattr(CORE, "config_path", str(tmpdir)) + + # When + result = wz.wizard_write(str(config_file), **default_config) + + # Then + assert result is False # Should return False when file exists + assert config_file.read() == original_content