1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-11 07:42:26 +01:00

Merge branch 'esp8266_api_progmem' into integration

This commit is contained in:
J. Nick Koston
2025-09-04 08:37:35 -05:00
9 changed files with 266 additions and 109 deletions

View File

@@ -112,7 +112,7 @@ void APIConnection::start() {
APIError err = this->helper_->init(); APIError err = this->helper_->init();
if (err != APIError::OK) { if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
this->log_warning_("Helper init failed", err); this->log_warning_(LOG_STR("Helper init failed"), err);
return; return;
} }
this->client_info_.peername = helper_->getpeername(); this->client_info_.peername = helper_->getpeername();
@@ -159,7 +159,7 @@ void APIConnection::loop() {
break; break;
} else if (err != APIError::OK) { } else if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
this->log_warning_("Reading failed", err); this->log_warning_(LOG_STR("Reading failed"), err);
return; return;
} else { } else {
this->last_traffic_ = now; this->last_traffic_ = now;
@@ -1565,7 +1565,7 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) {
return false; return false;
if (err != APIError::OK) { if (err != APIError::OK) {
on_fatal_error(); on_fatal_error();
this->log_warning_("Packet write failed", err); this->log_warning_(LOG_STR("Packet write failed"), err);
return false; return false;
} }
// Do not set last_traffic_ on send // Do not set last_traffic_ on send
@@ -1752,7 +1752,7 @@ void APIConnection::process_batch_() {
std::span<const PacketInfo>(packet_info, packet_count)); std::span<const PacketInfo>(packet_info, packet_count));
if (err != APIError::OK && err != APIError::WOULD_BLOCK) { if (err != APIError::OK && err != APIError::WOULD_BLOCK) {
on_fatal_error(); on_fatal_error();
this->log_warning_("Batch write failed", err); this->log_warning_(LOG_STR("Batch write failed"), err);
} }
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
@@ -1830,11 +1830,14 @@ void APIConnection::process_state_subscriptions_() {
} }
#endif // USE_API_HOMEASSISTANT_STATES #endif // USE_API_HOMEASSISTANT_STATES
void APIConnection::log_warning_(const char *message, APIError err) { void APIConnection::log_warning_(const LogString *message, APIError err) {
ESP_LOGW(TAG, "%s: %s %s errno=%d", this->get_client_combined_info().c_str(), message, api_error_to_str(err), errno); ESP_LOGW(TAG, "%s: %s %s errno=%d", this->get_client_combined_info().c_str(), LOG_STR_ARG(message),
LOG_STR_ARG(api_error_to_logstr(err)), errno);
} }
void APIConnection::log_socket_operation_failed_(APIError err) { this->log_warning_("Socket operation failed", err); } void APIConnection::log_socket_operation_failed_(APIError err) {
this->log_warning_(LOG_STR("Socket operation failed"), err);
}
} // namespace esphome::api } // namespace esphome::api
#endif #endif

View File

@@ -732,7 +732,7 @@ class APIConnection final : public APIServerConnection {
} }
// Helper function to log API errors with errno // Helper function to log API errors with errno
void log_warning_(const char *message, APIError err); void log_warning_(const LogString *message, APIError err);
// Specific helper for duplicated error message // Specific helper for duplicated error message
void log_socket_operation_failed_(APIError err); void log_socket_operation_failed_(APIError err);
}; };

View File

@@ -23,59 +23,59 @@ static const char *const TAG = "api.frame_helper";
#define LOG_PACKET_SENDING(data, len) ((void) 0) #define LOG_PACKET_SENDING(data, len) ((void) 0)
#endif #endif
const char *api_error_to_str(APIError err) { const LogString *api_error_to_logstr(APIError err) {
// not using switch to ensure compiler doesn't try to build a big table out of it // not using switch to ensure compiler doesn't try to build a big table out of it
if (err == APIError::OK) { if (err == APIError::OK) {
return "OK"; return LOG_STR("OK");
} else if (err == APIError::WOULD_BLOCK) { } else if (err == APIError::WOULD_BLOCK) {
return "WOULD_BLOCK"; return LOG_STR("WOULD_BLOCK");
} else if (err == APIError::BAD_INDICATOR) { } else if (err == APIError::BAD_INDICATOR) {
return "BAD_INDICATOR"; return LOG_STR("BAD_INDICATOR");
} else if (err == APIError::BAD_DATA_PACKET) { } else if (err == APIError::BAD_DATA_PACKET) {
return "BAD_DATA_PACKET"; return LOG_STR("BAD_DATA_PACKET");
} else if (err == APIError::TCP_NODELAY_FAILED) { } else if (err == APIError::TCP_NODELAY_FAILED) {
return "TCP_NODELAY_FAILED"; return LOG_STR("TCP_NODELAY_FAILED");
} else if (err == APIError::TCP_NONBLOCKING_FAILED) { } else if (err == APIError::TCP_NONBLOCKING_FAILED) {
return "TCP_NONBLOCKING_FAILED"; return LOG_STR("TCP_NONBLOCKING_FAILED");
} else if (err == APIError::CLOSE_FAILED) { } else if (err == APIError::CLOSE_FAILED) {
return "CLOSE_FAILED"; return LOG_STR("CLOSE_FAILED");
} else if (err == APIError::SHUTDOWN_FAILED) { } else if (err == APIError::SHUTDOWN_FAILED) {
return "SHUTDOWN_FAILED"; return LOG_STR("SHUTDOWN_FAILED");
} else if (err == APIError::BAD_STATE) { } else if (err == APIError::BAD_STATE) {
return "BAD_STATE"; return LOG_STR("BAD_STATE");
} else if (err == APIError::BAD_ARG) { } else if (err == APIError::BAD_ARG) {
return "BAD_ARG"; return LOG_STR("BAD_ARG");
} else if (err == APIError::SOCKET_READ_FAILED) { } else if (err == APIError::SOCKET_READ_FAILED) {
return "SOCKET_READ_FAILED"; return LOG_STR("SOCKET_READ_FAILED");
} else if (err == APIError::SOCKET_WRITE_FAILED) { } else if (err == APIError::SOCKET_WRITE_FAILED) {
return "SOCKET_WRITE_FAILED"; return LOG_STR("SOCKET_WRITE_FAILED");
} else if (err == APIError::OUT_OF_MEMORY) { } else if (err == APIError::OUT_OF_MEMORY) {
return "OUT_OF_MEMORY"; return LOG_STR("OUT_OF_MEMORY");
} else if (err == APIError::CONNECTION_CLOSED) { } else if (err == APIError::CONNECTION_CLOSED) {
return "CONNECTION_CLOSED"; return LOG_STR("CONNECTION_CLOSED");
} }
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) { else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) {
return "BAD_HANDSHAKE_PACKET_LEN"; return LOG_STR("BAD_HANDSHAKE_PACKET_LEN");
} else if (err == APIError::HANDSHAKESTATE_READ_FAILED) { } else if (err == APIError::HANDSHAKESTATE_READ_FAILED) {
return "HANDSHAKESTATE_READ_FAILED"; return LOG_STR("HANDSHAKESTATE_READ_FAILED");
} else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) { } else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) {
return "HANDSHAKESTATE_WRITE_FAILED"; return LOG_STR("HANDSHAKESTATE_WRITE_FAILED");
} else if (err == APIError::HANDSHAKESTATE_BAD_STATE) { } else if (err == APIError::HANDSHAKESTATE_BAD_STATE) {
return "HANDSHAKESTATE_BAD_STATE"; return LOG_STR("HANDSHAKESTATE_BAD_STATE");
} else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) { } else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) {
return "CIPHERSTATE_DECRYPT_FAILED"; return LOG_STR("CIPHERSTATE_DECRYPT_FAILED");
} else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) { } else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) {
return "CIPHERSTATE_ENCRYPT_FAILED"; return LOG_STR("CIPHERSTATE_ENCRYPT_FAILED");
} else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) { } else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) {
return "HANDSHAKESTATE_SETUP_FAILED"; return LOG_STR("HANDSHAKESTATE_SETUP_FAILED");
} else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) { } else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) {
return "HANDSHAKESTATE_SPLIT_FAILED"; return LOG_STR("HANDSHAKESTATE_SPLIT_FAILED");
} else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) { } else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) {
return "BAD_HANDSHAKE_ERROR_BYTE"; return LOG_STR("BAD_HANDSHAKE_ERROR_BYTE");
} }
#endif #endif
return "UNKNOWN"; return LOG_STR("UNKNOWN");
} }
// Default implementation for loop - handles sending buffered data // Default implementation for loop - handles sending buffered data

View File

@@ -66,7 +66,7 @@ enum class APIError : uint16_t {
#endif #endif
}; };
const char *api_error_to_str(APIError err); const LogString *api_error_to_logstr(APIError err);
class APIFrameHelper { class APIFrameHelper {
public: public:

View File

@@ -27,42 +27,42 @@ static constexpr size_t PROLOGUE_INIT_LEN = 12; // strlen("NoiseAPIInit")
#endif #endif
/// Convert a noise error code to a readable error /// Convert a noise error code to a readable error
std::string noise_err_to_str(int err) { const LogString *noise_err_to_logstr(int err) {
if (err == NOISE_ERROR_NO_MEMORY) if (err == NOISE_ERROR_NO_MEMORY)
return "NO_MEMORY"; return LOG_STR("NO_MEMORY");
if (err == NOISE_ERROR_UNKNOWN_ID) if (err == NOISE_ERROR_UNKNOWN_ID)
return "UNKNOWN_ID"; return LOG_STR("UNKNOWN_ID");
if (err == NOISE_ERROR_UNKNOWN_NAME) if (err == NOISE_ERROR_UNKNOWN_NAME)
return "UNKNOWN_NAME"; return LOG_STR("UNKNOWN_NAME");
if (err == NOISE_ERROR_MAC_FAILURE) if (err == NOISE_ERROR_MAC_FAILURE)
return "MAC_FAILURE"; return LOG_STR("MAC_FAILURE");
if (err == NOISE_ERROR_NOT_APPLICABLE) if (err == NOISE_ERROR_NOT_APPLICABLE)
return "NOT_APPLICABLE"; return LOG_STR("NOT_APPLICABLE");
if (err == NOISE_ERROR_SYSTEM) if (err == NOISE_ERROR_SYSTEM)
return "SYSTEM"; return LOG_STR("SYSTEM");
if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED) if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED)
return "REMOTE_KEY_REQUIRED"; return LOG_STR("REMOTE_KEY_REQUIRED");
if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED) if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED)
return "LOCAL_KEY_REQUIRED"; return LOG_STR("LOCAL_KEY_REQUIRED");
if (err == NOISE_ERROR_PSK_REQUIRED) if (err == NOISE_ERROR_PSK_REQUIRED)
return "PSK_REQUIRED"; return LOG_STR("PSK_REQUIRED");
if (err == NOISE_ERROR_INVALID_LENGTH) if (err == NOISE_ERROR_INVALID_LENGTH)
return "INVALID_LENGTH"; return LOG_STR("INVALID_LENGTH");
if (err == NOISE_ERROR_INVALID_PARAM) if (err == NOISE_ERROR_INVALID_PARAM)
return "INVALID_PARAM"; return LOG_STR("INVALID_PARAM");
if (err == NOISE_ERROR_INVALID_STATE) if (err == NOISE_ERROR_INVALID_STATE)
return "INVALID_STATE"; return LOG_STR("INVALID_STATE");
if (err == NOISE_ERROR_INVALID_NONCE) if (err == NOISE_ERROR_INVALID_NONCE)
return "INVALID_NONCE"; return LOG_STR("INVALID_NONCE");
if (err == NOISE_ERROR_INVALID_PRIVATE_KEY) if (err == NOISE_ERROR_INVALID_PRIVATE_KEY)
return "INVALID_PRIVATE_KEY"; return LOG_STR("INVALID_PRIVATE_KEY");
if (err == NOISE_ERROR_INVALID_PUBLIC_KEY) if (err == NOISE_ERROR_INVALID_PUBLIC_KEY)
return "INVALID_PUBLIC_KEY"; return LOG_STR("INVALID_PUBLIC_KEY");
if (err == NOISE_ERROR_INVALID_FORMAT) if (err == NOISE_ERROR_INVALID_FORMAT)
return "INVALID_FORMAT"; return LOG_STR("INVALID_FORMAT");
if (err == NOISE_ERROR_INVALID_SIGNATURE) if (err == NOISE_ERROR_INVALID_SIGNATURE)
return "INVALID_SIGNATURE"; return LOG_STR("INVALID_SIGNATURE");
return to_string(err); return LOG_STR("UNKNOWN");
} }
/// Initialize the frame helper, returns OK if successful. /// Initialize the frame helper, returns OK if successful.
@@ -83,18 +83,18 @@ APIError APINoiseFrameHelper::init() {
// Helper for handling handshake frame errors // Helper for handling handshake frame errors
APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) { APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) {
if (aerr == APIError::BAD_INDICATOR) { if (aerr == APIError::BAD_INDICATOR) {
send_explicit_handshake_reject_("Bad indicator byte"); send_explicit_handshake_reject_(LOG_STR("Bad indicator byte"));
} else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { } else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) {
send_explicit_handshake_reject_("Bad handshake packet len"); send_explicit_handshake_reject_(LOG_STR("Bad handshake packet len"));
} }
return aerr; return aerr;
} }
// Helper for handling noise library errors // Helper for handling noise library errors
APIError APINoiseFrameHelper::handle_noise_error_(int err, const char *func_name, APIError api_err) { APIError APINoiseFrameHelper::handle_noise_error_(int err, const LogString *func_name, APIError api_err) {
if (err != 0) { if (err != 0) {
state_ = State::FAILED; state_ = State::FAILED;
HELPER_LOG("%s failed: %s", func_name, noise_err_to_str(err).c_str()); HELPER_LOG("%s failed: %s", LOG_STR_ARG(func_name), LOG_STR_ARG(noise_err_to_logstr(err)));
return api_err; return api_err;
} }
return APIError::OK; return APIError::OK;
@@ -279,11 +279,11 @@ APIError APINoiseFrameHelper::state_action_() {
} }
if (frame.empty()) { if (frame.empty()) {
send_explicit_handshake_reject_("Empty handshake message"); send_explicit_handshake_reject_(LOG_STR("Empty handshake message"));
return APIError::BAD_HANDSHAKE_ERROR_BYTE; return APIError::BAD_HANDSHAKE_ERROR_BYTE;
} else if (frame[0] != 0x00) { } else if (frame[0] != 0x00) {
HELPER_LOG("Bad handshake error byte: %u", frame[0]); HELPER_LOG("Bad handshake error byte: %u", frame[0]);
send_explicit_handshake_reject_("Bad handshake error byte"); send_explicit_handshake_reject_(LOG_STR("Bad handshake error byte"));
return APIError::BAD_HANDSHAKE_ERROR_BYTE; return APIError::BAD_HANDSHAKE_ERROR_BYTE;
} }
@@ -293,8 +293,10 @@ APIError APINoiseFrameHelper::state_action_() {
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
if (err != 0) { if (err != 0) {
// Special handling for MAC failure // Special handling for MAC failure
send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? "Handshake MAC failure" : "Handshake error"); send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? LOG_STR("Handshake MAC failure")
return handle_noise_error_(err, "noise_handshakestate_read_message", APIError::HANDSHAKESTATE_READ_FAILED); : LOG_STR("Handshake error"));
return handle_noise_error_(err, LOG_STR("noise_handshakestate_read_message"),
APIError::HANDSHAKESTATE_READ_FAILED);
} }
aerr = check_handshake_finished_(); aerr = check_handshake_finished_();
@@ -307,8 +309,8 @@ APIError APINoiseFrameHelper::state_action_() {
noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
APIError aerr_write = APIError aerr_write = handle_noise_error_(err, LOG_STR("noise_handshakestate_write_message"),
handle_noise_error_(err, "noise_handshakestate_write_message", APIError::HANDSHAKESTATE_WRITE_FAILED); APIError::HANDSHAKESTATE_WRITE_FAILED);
if (aerr_write != APIError::OK) if (aerr_write != APIError::OK)
return aerr_write; return aerr_write;
buffer[0] = 0x00; // success buffer[0] = 0x00; // success
@@ -331,15 +333,31 @@ APIError APINoiseFrameHelper::state_action_() {
} }
return APIError::OK; return APIError::OK;
} }
void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { void APINoiseFrameHelper::send_explicit_handshake_reject_(const LogString *reason) {
#ifdef USE_STORE_LOG_STR_IN_FLASH
// On ESP8266 with flash strings, we need to use PROGMEM-aware functions
size_t reason_len = strlen_P(reinterpret_cast<PGM_P>(reason));
std::vector<uint8_t> data; std::vector<uint8_t> data;
data.resize(reason.length() + 1); data.resize(reason_len + 1);
data[0] = 0x01; // failure
// Copy error message from PROGMEM
if (reason_len > 0) {
memcpy_P(data.data() + 1, reinterpret_cast<PGM_P>(reason), reason_len);
}
#else
// Normal memory access
const char *reason_str = LOG_STR_ARG(reason);
size_t reason_len = strlen(reason_str);
std::vector<uint8_t> data;
data.resize(reason_len + 1);
data[0] = 0x01; // failure data[0] = 0x01; // failure
// Copy error message in bulk // Copy error message in bulk
if (!reason.empty()) { if (reason_len > 0) {
std::memcpy(data.data() + 1, reason.c_str(), reason.length()); std::memcpy(data.data() + 1, reason_str, reason_len);
} }
#endif
// temporarily remove failed state // temporarily remove failed state
auto orig_state = state_; auto orig_state = state_;
@@ -368,7 +386,8 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
noise_buffer_init(mbuf); noise_buffer_init(mbuf);
noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size()); noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size());
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
APIError decrypt_err = handle_noise_error_(err, "noise_cipherstate_decrypt", APIError::CIPHERSTATE_DECRYPT_FAILED); APIError decrypt_err =
handle_noise_error_(err, LOG_STR("noise_cipherstate_decrypt"), APIError::CIPHERSTATE_DECRYPT_FAILED);
if (decrypt_err != APIError::OK) if (decrypt_err != APIError::OK)
return decrypt_err; return decrypt_err;
@@ -450,7 +469,8 @@ APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, st
4 + packet.payload_size + frame_footer_size_); 4 + packet.payload_size + frame_footer_size_);
int err = noise_cipherstate_encrypt(send_cipher_, &mbuf); int err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
APIError aerr = handle_noise_error_(err, "noise_cipherstate_encrypt", APIError::CIPHERSTATE_ENCRYPT_FAILED); APIError aerr =
handle_noise_error_(err, LOG_STR("noise_cipherstate_encrypt"), APIError::CIPHERSTATE_ENCRYPT_FAILED);
if (aerr != APIError::OK) if (aerr != APIError::OK)
return aerr; return aerr;
@@ -504,25 +524,27 @@ APIError APINoiseFrameHelper::init_handshake_() {
nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;
err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
APIError aerr = handle_noise_error_(err, "noise_handshakestate_new_by_id", APIError::HANDSHAKESTATE_SETUP_FAILED); APIError aerr =
handle_noise_error_(err, LOG_STR("noise_handshakestate_new_by_id"), APIError::HANDSHAKESTATE_SETUP_FAILED);
if (aerr != APIError::OK) if (aerr != APIError::OK)
return aerr; return aerr;
const auto &psk = ctx_->get_psk(); const auto &psk = ctx_->get_psk();
err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size());
aerr = handle_noise_error_(err, "noise_handshakestate_set_pre_shared_key", APIError::HANDSHAKESTATE_SETUP_FAILED); aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_set_pre_shared_key"),
APIError::HANDSHAKESTATE_SETUP_FAILED);
if (aerr != APIError::OK) if (aerr != APIError::OK)
return aerr; return aerr;
err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size());
aerr = handle_noise_error_(err, "noise_handshakestate_set_prologue", APIError::HANDSHAKESTATE_SETUP_FAILED); aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_set_prologue"), APIError::HANDSHAKESTATE_SETUP_FAILED);
if (aerr != APIError::OK) if (aerr != APIError::OK)
return aerr; return aerr;
// set_prologue copies it into handshakestate, so we can get rid of it now // set_prologue copies it into handshakestate, so we can get rid of it now
prologue_ = {}; prologue_ = {};
err = noise_handshakestate_start(handshake_); err = noise_handshakestate_start(handshake_);
aerr = handle_noise_error_(err, "noise_handshakestate_start", APIError::HANDSHAKESTATE_SETUP_FAILED); aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_start"), APIError::HANDSHAKESTATE_SETUP_FAILED);
if (aerr != APIError::OK) if (aerr != APIError::OK)
return aerr; return aerr;
return APIError::OK; return APIError::OK;
@@ -540,7 +562,8 @@ APIError APINoiseFrameHelper::check_handshake_finished_() {
return APIError::HANDSHAKESTATE_BAD_STATE; return APIError::HANDSHAKESTATE_BAD_STATE;
} }
int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
APIError aerr = handle_noise_error_(err, "noise_handshakestate_split", APIError::HANDSHAKESTATE_SPLIT_FAILED); APIError aerr =
handle_noise_error_(err, LOG_STR("noise_handshakestate_split"), APIError::HANDSHAKESTATE_SPLIT_FAILED);
if (aerr != APIError::OK) if (aerr != APIError::OK)
return aerr; return aerr;

View File

@@ -32,9 +32,9 @@ class APINoiseFrameHelper final : public APIFrameHelper {
APIError write_frame_(const uint8_t *data, uint16_t len); APIError write_frame_(const uint8_t *data, uint16_t len);
APIError init_handshake_(); APIError init_handshake_();
APIError check_handshake_finished_(); APIError check_handshake_finished_();
void send_explicit_handshake_reject_(const std::string &reason); void send_explicit_handshake_reject_(const LogString *reason);
APIError handle_handshake_frame_error_(APIError aerr); APIError handle_handshake_frame_error_(APIError aerr);
APIError handle_noise_error_(int err, const char *func_name, APIError api_err); APIError handle_noise_error_(int err, const LogString *func_name, APIError api_err);
// Pointers first (4 bytes each) // Pointers first (4 bytes each)
NoiseHandshakeState *handshake_{nullptr}; NoiseHandshakeState *handshake_{nullptr};

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import binascii
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
import datetime import datetime
import functools import functools
@@ -490,7 +491,17 @@ class WizardRequestHandler(BaseHandler):
kwargs = { kwargs = {
k: v k: v
for k, v in json.loads(self.request.body.decode()).items() for k, v in json.loads(self.request.body.decode()).items()
if k in ("name", "platform", "board", "ssid", "psk", "password") if k
in (
"type",
"name",
"platform",
"board",
"ssid",
"psk",
"password",
"file_content",
)
} }
if not kwargs["name"]: if not kwargs["name"]:
self.set_status(422) self.set_status(422)
@@ -498,19 +509,65 @@ class WizardRequestHandler(BaseHandler):
self.write(json.dumps({"error": "Name is required"})) self.write(json.dumps({"error": "Name is required"}))
return return
if "type" not in kwargs:
# Default to basic wizard type for backwards compatibility
kwargs["type"] = "basic"
kwargs["friendly_name"] = kwargs["name"] kwargs["friendly_name"] = kwargs["name"]
kwargs["name"] = friendly_name_slugify(kwargs["friendly_name"]) kwargs["name"] = friendly_name_slugify(kwargs["friendly_name"])
if kwargs["type"] == "basic":
kwargs["ota_password"] = secrets.token_hex(16) kwargs["ota_password"] = secrets.token_hex(16)
noise_psk = secrets.token_bytes(32) noise_psk = secrets.token_bytes(32)
kwargs["api_encryption_key"] = base64.b64encode(noise_psk).decode() kwargs["api_encryption_key"] = base64.b64encode(noise_psk).decode()
elif kwargs["type"] == "upload":
try:
kwargs["file_text"] = base64.b64decode(kwargs["file_content"]).decode(
"utf-8"
)
except (binascii.Error, UnicodeDecodeError):
self.set_status(422)
self.set_header("content-type", "application/json")
self.write(
json.dumps({"error": "The uploaded file is not correctly encoded."})
)
return
elif kwargs["type"] != "empty":
self.set_status(422)
self.set_header("content-type", "application/json")
self.write(
json.dumps(
{"error": f"Invalid wizard type specified: {kwargs['type']}"}
)
)
return
filename = f"{kwargs['name']}.yaml" filename = f"{kwargs['name']}.yaml"
destination = settings.rel_path(filename) destination = settings.rel_path(filename)
wizard.wizard_write(path=destination, **kwargs)
self.set_status(200) # Check if destination file already exists
self.set_header("content-type", "application/json") if os.path.exists(destination):
self.write(json.dumps({"configuration": filename})) self.set_status(409) # Conflict status code
self.finish() self.set_header("content-type", "application/json")
self.write(
json.dumps({"error": f"Configuration file '{filename}' already exists"})
)
self.finish()
return
success = wizard.wizard_write(path=destination, **kwargs)
if success:
self.set_status(200)
self.set_header("content-type", "application/json")
self.write(json.dumps({"configuration": filename}))
self.finish()
else:
self.set_status(500)
self.set_header("content-type", "application/json")
self.write(
json.dumps(
{"error": "Failed to write configuration, see logs for details"}
)
)
self.finish()
class ImportRequestHandler(BaseHandler): class ImportRequestHandler(BaseHandler):

View File

@@ -189,32 +189,45 @@ def wizard_write(path, **kwargs):
from esphome.components.rtl87xx import boards as rtl87xx_boards from esphome.components.rtl87xx import boards as rtl87xx_boards
name = kwargs["name"] name = kwargs["name"]
board = kwargs["board"] if kwargs["type"] == "empty":
file_text = ""
# Will be updated later after editing the file
hardware = "UNKNOWN"
elif kwargs["type"] == "upload":
file_text = kwargs["file_text"]
hardware = "UNKNOWN"
else: # "basic"
board = kwargs["board"]
for key in ("ssid", "psk", "password", "ota_password"): for key in ("ssid", "psk", "password", "ota_password"):
if key in kwargs: if key in kwargs:
kwargs[key] = sanitize_double_quotes(kwargs[key]) kwargs[key] = sanitize_double_quotes(kwargs[key])
if "platform" not in kwargs:
if board in esp8266_boards.BOARDS:
platform = "ESP8266"
elif board in esp32_boards.BOARDS:
platform = "ESP32"
elif board in rp2040_boards.BOARDS:
platform = "RP2040"
elif board in bk72xx_boards.BOARDS:
platform = "BK72XX"
elif board in ln882x_boards.BOARDS:
platform = "LN882X"
elif board in rtl87xx_boards.BOARDS:
platform = "RTL87XX"
else:
safe_print(color(AnsiFore.RED, f'The board "{board}" is unknown.'))
return False
kwargs["platform"] = platform
hardware = kwargs["platform"]
file_text = wizard_file(**kwargs)
if "platform" not in kwargs: # Check if file already exists to prevent overwriting
if board in esp8266_boards.BOARDS: if os.path.exists(path) and os.path.isfile(path):
platform = "ESP8266" safe_print(color(AnsiFore.RED, f'The file "{path}" already exists.'))
elif board in esp32_boards.BOARDS: return False
platform = "ESP32"
elif board in rp2040_boards.BOARDS:
platform = "RP2040"
elif board in bk72xx_boards.BOARDS:
platform = "BK72XX"
elif board in ln882x_boards.BOARDS:
platform = "LN882X"
elif board in rtl87xx_boards.BOARDS:
platform = "RTL87XX"
else:
safe_print(color(AnsiFore.RED, f'The board "{board}" is unknown.'))
return False
kwargs["platform"] = platform
hardware = kwargs["platform"]
write_file(path, wizard_file(**kwargs)) write_file(path, file_text)
storage = StorageJSON.from_wizard(name, name, f"{name}.local", hardware) storage = StorageJSON.from_wizard(name, name, f"{name}.local", hardware)
storage_path = ext_storage_path(os.path.basename(path)) storage_path = ext_storage_path(os.path.basename(path))
storage.save(storage_path) storage.save(storage_path)

View File

@@ -17,6 +17,7 @@ import esphome.wizard as wz
@pytest.fixture @pytest.fixture
def default_config(): def default_config():
return { return {
"type": "basic",
"name": "test-name", "name": "test-name",
"platform": "ESP8266", "platform": "ESP8266",
"board": "esp01_1m", "board": "esp01_1m",
@@ -125,6 +126,47 @@ def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch):
assert "esp8266:" in generated_config assert "esp8266:" in generated_config
def test_wizard_empty_config(tmp_path, monkeypatch):
"""
The wizard should be able to create an empty configuration
"""
# Given
empty_config = {
"type": "empty",
"name": "test-empty",
}
monkeypatch.setattr(wz, "write_file", MagicMock())
monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path))
# When
wz.wizard_write(tmp_path, **empty_config)
# Then
generated_config = wz.write_file.call_args.args[1]
assert generated_config == ""
def test_wizard_upload_config(tmp_path, monkeypatch):
"""
The wizard should be able to import an base64 encoded configuration
"""
# Given
empty_config = {
"type": "upload",
"name": "test-upload",
"file_text": "# imported file 📁\n\n",
}
monkeypatch.setattr(wz, "write_file", MagicMock())
monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path))
# When
wz.wizard_write(tmp_path, **empty_config)
# Then
generated_config = wz.write_file.call_args.args[1]
assert generated_config == "# imported file 📁\n\n"
def test_wizard_write_defaults_platform_from_board_esp8266( def test_wizard_write_defaults_platform_from_board_esp8266(
default_config, tmp_path, monkeypatch default_config, tmp_path, monkeypatch
): ):
@@ -471,3 +513,22 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers):
# Then # Then
assert retval == 0 assert retval == 0
def test_wizard_write_protects_existing_config(tmpdir, default_config, monkeypatch):
"""
The wizard_write function should not overwrite existing config files and return False
"""
# Given
config_file = tmpdir.join("test.yaml")
original_content = "# Original config content\n"
config_file.write(original_content)
monkeypatch.setattr(CORE, "config_path", str(tmpdir))
# When
result = wz.wizard_write(str(config_file), **default_config)
# Then
assert result is False # Should return False when file exists
assert config_file.read() == original_content