mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 07:03:55 +00:00 
			
		
		
		
	Merge branch 'esp8266_api_progmem' into integration
This commit is contained in:
		| @@ -112,7 +112,7 @@ void APIConnection::start() { | |||||||
|   APIError err = this->helper_->init(); |   APIError err = this->helper_->init(); | ||||||
|   if (err != APIError::OK) { |   if (err != APIError::OK) { | ||||||
|     on_fatal_error(); |     on_fatal_error(); | ||||||
|     this->log_warning_("Helper init failed", err); |     this->log_warning_(LOG_STR("Helper init failed"), err); | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|   this->client_info_.peername = helper_->getpeername(); |   this->client_info_.peername = helper_->getpeername(); | ||||||
| @@ -159,7 +159,7 @@ void APIConnection::loop() { | |||||||
|         break; |         break; | ||||||
|       } else if (err != APIError::OK) { |       } else if (err != APIError::OK) { | ||||||
|         on_fatal_error(); |         on_fatal_error(); | ||||||
|         this->log_warning_("Reading failed", err); |         this->log_warning_(LOG_STR("Reading failed"), err); | ||||||
|         return; |         return; | ||||||
|       } else { |       } else { | ||||||
|         this->last_traffic_ = now; |         this->last_traffic_ = now; | ||||||
| @@ -1565,7 +1565,7 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) { | |||||||
|     return false; |     return false; | ||||||
|   if (err != APIError::OK) { |   if (err != APIError::OK) { | ||||||
|     on_fatal_error(); |     on_fatal_error(); | ||||||
|     this->log_warning_("Packet write failed", err); |     this->log_warning_(LOG_STR("Packet write failed"), err); | ||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|   // Do not set last_traffic_ on send |   // Do not set last_traffic_ on send | ||||||
| @@ -1752,7 +1752,7 @@ void APIConnection::process_batch_() { | |||||||
|                                                        std::span<const PacketInfo>(packet_info, packet_count)); |                                                        std::span<const PacketInfo>(packet_info, packet_count)); | ||||||
|   if (err != APIError::OK && err != APIError::WOULD_BLOCK) { |   if (err != APIError::OK && err != APIError::WOULD_BLOCK) { | ||||||
|     on_fatal_error(); |     on_fatal_error(); | ||||||
|     this->log_warning_("Batch write failed", err); |     this->log_warning_(LOG_STR("Batch write failed"), err); | ||||||
|   } |   } | ||||||
|  |  | ||||||
| #ifdef HAS_PROTO_MESSAGE_DUMP | #ifdef HAS_PROTO_MESSAGE_DUMP | ||||||
| @@ -1830,11 +1830,14 @@ void APIConnection::process_state_subscriptions_() { | |||||||
| } | } | ||||||
| #endif  // USE_API_HOMEASSISTANT_STATES | #endif  // USE_API_HOMEASSISTANT_STATES | ||||||
|  |  | ||||||
| void APIConnection::log_warning_(const char *message, APIError err) { | void APIConnection::log_warning_(const LogString *message, APIError err) { | ||||||
|   ESP_LOGW(TAG, "%s: %s %s errno=%d", this->get_client_combined_info().c_str(), message, api_error_to_str(err), errno); |   ESP_LOGW(TAG, "%s: %s %s errno=%d", this->get_client_combined_info().c_str(), LOG_STR_ARG(message), | ||||||
|  |            LOG_STR_ARG(api_error_to_logstr(err)), errno); | ||||||
| } | } | ||||||
|  |  | ||||||
| void APIConnection::log_socket_operation_failed_(APIError err) { this->log_warning_("Socket operation failed", err); } | void APIConnection::log_socket_operation_failed_(APIError err) { | ||||||
|  |   this->log_warning_(LOG_STR("Socket operation failed"), err); | ||||||
|  | } | ||||||
|  |  | ||||||
| }  // namespace esphome::api | }  // namespace esphome::api | ||||||
| #endif | #endif | ||||||
|   | |||||||
| @@ -732,7 +732,7 @@ class APIConnection final : public APIServerConnection { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Helper function to log API errors with errno |   // Helper function to log API errors with errno | ||||||
|   void log_warning_(const char *message, APIError err); |   void log_warning_(const LogString *message, APIError err); | ||||||
|   // Specific helper for duplicated error message |   // Specific helper for duplicated error message | ||||||
|   void log_socket_operation_failed_(APIError err); |   void log_socket_operation_failed_(APIError err); | ||||||
| }; | }; | ||||||
|   | |||||||
| @@ -23,59 +23,59 @@ static const char *const TAG = "api.frame_helper"; | |||||||
| #define LOG_PACKET_SENDING(data, len) ((void) 0) | #define LOG_PACKET_SENDING(data, len) ((void) 0) | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| const char *api_error_to_str(APIError err) { | const LogString *api_error_to_logstr(APIError err) { | ||||||
|   // not using switch to ensure compiler doesn't try to build a big table out of it |   // not using switch to ensure compiler doesn't try to build a big table out of it | ||||||
|   if (err == APIError::OK) { |   if (err == APIError::OK) { | ||||||
|     return "OK"; |     return LOG_STR("OK"); | ||||||
|   } else if (err == APIError::WOULD_BLOCK) { |   } else if (err == APIError::WOULD_BLOCK) { | ||||||
|     return "WOULD_BLOCK"; |     return LOG_STR("WOULD_BLOCK"); | ||||||
|   } else if (err == APIError::BAD_INDICATOR) { |   } else if (err == APIError::BAD_INDICATOR) { | ||||||
|     return "BAD_INDICATOR"; |     return LOG_STR("BAD_INDICATOR"); | ||||||
|   } else if (err == APIError::BAD_DATA_PACKET) { |   } else if (err == APIError::BAD_DATA_PACKET) { | ||||||
|     return "BAD_DATA_PACKET"; |     return LOG_STR("BAD_DATA_PACKET"); | ||||||
|   } else if (err == APIError::TCP_NODELAY_FAILED) { |   } else if (err == APIError::TCP_NODELAY_FAILED) { | ||||||
|     return "TCP_NODELAY_FAILED"; |     return LOG_STR("TCP_NODELAY_FAILED"); | ||||||
|   } else if (err == APIError::TCP_NONBLOCKING_FAILED) { |   } else if (err == APIError::TCP_NONBLOCKING_FAILED) { | ||||||
|     return "TCP_NONBLOCKING_FAILED"; |     return LOG_STR("TCP_NONBLOCKING_FAILED"); | ||||||
|   } else if (err == APIError::CLOSE_FAILED) { |   } else if (err == APIError::CLOSE_FAILED) { | ||||||
|     return "CLOSE_FAILED"; |     return LOG_STR("CLOSE_FAILED"); | ||||||
|   } else if (err == APIError::SHUTDOWN_FAILED) { |   } else if (err == APIError::SHUTDOWN_FAILED) { | ||||||
|     return "SHUTDOWN_FAILED"; |     return LOG_STR("SHUTDOWN_FAILED"); | ||||||
|   } else if (err == APIError::BAD_STATE) { |   } else if (err == APIError::BAD_STATE) { | ||||||
|     return "BAD_STATE"; |     return LOG_STR("BAD_STATE"); | ||||||
|   } else if (err == APIError::BAD_ARG) { |   } else if (err == APIError::BAD_ARG) { | ||||||
|     return "BAD_ARG"; |     return LOG_STR("BAD_ARG"); | ||||||
|   } else if (err == APIError::SOCKET_READ_FAILED) { |   } else if (err == APIError::SOCKET_READ_FAILED) { | ||||||
|     return "SOCKET_READ_FAILED"; |     return LOG_STR("SOCKET_READ_FAILED"); | ||||||
|   } else if (err == APIError::SOCKET_WRITE_FAILED) { |   } else if (err == APIError::SOCKET_WRITE_FAILED) { | ||||||
|     return "SOCKET_WRITE_FAILED"; |     return LOG_STR("SOCKET_WRITE_FAILED"); | ||||||
|   } else if (err == APIError::OUT_OF_MEMORY) { |   } else if (err == APIError::OUT_OF_MEMORY) { | ||||||
|     return "OUT_OF_MEMORY"; |     return LOG_STR("OUT_OF_MEMORY"); | ||||||
|   } else if (err == APIError::CONNECTION_CLOSED) { |   } else if (err == APIError::CONNECTION_CLOSED) { | ||||||
|     return "CONNECTION_CLOSED"; |     return LOG_STR("CONNECTION_CLOSED"); | ||||||
|   } |   } | ||||||
| #ifdef USE_API_NOISE | #ifdef USE_API_NOISE | ||||||
|   else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) { |   else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) { | ||||||
|     return "BAD_HANDSHAKE_PACKET_LEN"; |     return LOG_STR("BAD_HANDSHAKE_PACKET_LEN"); | ||||||
|   } else if (err == APIError::HANDSHAKESTATE_READ_FAILED) { |   } else if (err == APIError::HANDSHAKESTATE_READ_FAILED) { | ||||||
|     return "HANDSHAKESTATE_READ_FAILED"; |     return LOG_STR("HANDSHAKESTATE_READ_FAILED"); | ||||||
|   } else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) { |   } else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) { | ||||||
|     return "HANDSHAKESTATE_WRITE_FAILED"; |     return LOG_STR("HANDSHAKESTATE_WRITE_FAILED"); | ||||||
|   } else if (err == APIError::HANDSHAKESTATE_BAD_STATE) { |   } else if (err == APIError::HANDSHAKESTATE_BAD_STATE) { | ||||||
|     return "HANDSHAKESTATE_BAD_STATE"; |     return LOG_STR("HANDSHAKESTATE_BAD_STATE"); | ||||||
|   } else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) { |   } else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) { | ||||||
|     return "CIPHERSTATE_DECRYPT_FAILED"; |     return LOG_STR("CIPHERSTATE_DECRYPT_FAILED"); | ||||||
|   } else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) { |   } else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) { | ||||||
|     return "CIPHERSTATE_ENCRYPT_FAILED"; |     return LOG_STR("CIPHERSTATE_ENCRYPT_FAILED"); | ||||||
|   } else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) { |   } else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) { | ||||||
|     return "HANDSHAKESTATE_SETUP_FAILED"; |     return LOG_STR("HANDSHAKESTATE_SETUP_FAILED"); | ||||||
|   } else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) { |   } else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) { | ||||||
|     return "HANDSHAKESTATE_SPLIT_FAILED"; |     return LOG_STR("HANDSHAKESTATE_SPLIT_FAILED"); | ||||||
|   } else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) { |   } else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) { | ||||||
|     return "BAD_HANDSHAKE_ERROR_BYTE"; |     return LOG_STR("BAD_HANDSHAKE_ERROR_BYTE"); | ||||||
|   } |   } | ||||||
| #endif | #endif | ||||||
|   return "UNKNOWN"; |   return LOG_STR("UNKNOWN"); | ||||||
| } | } | ||||||
|  |  | ||||||
| // Default implementation for loop - handles sending buffered data | // Default implementation for loop - handles sending buffered data | ||||||
|   | |||||||
| @@ -66,7 +66,7 @@ enum class APIError : uint16_t { | |||||||
| #endif | #endif | ||||||
| }; | }; | ||||||
|  |  | ||||||
| const char *api_error_to_str(APIError err); | const LogString *api_error_to_logstr(APIError err); | ||||||
|  |  | ||||||
| class APIFrameHelper { | class APIFrameHelper { | ||||||
|  public: |  public: | ||||||
|   | |||||||
| @@ -27,42 +27,42 @@ static constexpr size_t PROLOGUE_INIT_LEN = 12;  // strlen("NoiseAPIInit") | |||||||
| #endif | #endif | ||||||
|  |  | ||||||
| /// Convert a noise error code to a readable error | /// Convert a noise error code to a readable error | ||||||
| std::string noise_err_to_str(int err) { | const LogString *noise_err_to_logstr(int err) { | ||||||
|   if (err == NOISE_ERROR_NO_MEMORY) |   if (err == NOISE_ERROR_NO_MEMORY) | ||||||
|     return "NO_MEMORY"; |     return LOG_STR("NO_MEMORY"); | ||||||
|   if (err == NOISE_ERROR_UNKNOWN_ID) |   if (err == NOISE_ERROR_UNKNOWN_ID) | ||||||
|     return "UNKNOWN_ID"; |     return LOG_STR("UNKNOWN_ID"); | ||||||
|   if (err == NOISE_ERROR_UNKNOWN_NAME) |   if (err == NOISE_ERROR_UNKNOWN_NAME) | ||||||
|     return "UNKNOWN_NAME"; |     return LOG_STR("UNKNOWN_NAME"); | ||||||
|   if (err == NOISE_ERROR_MAC_FAILURE) |   if (err == NOISE_ERROR_MAC_FAILURE) | ||||||
|     return "MAC_FAILURE"; |     return LOG_STR("MAC_FAILURE"); | ||||||
|   if (err == NOISE_ERROR_NOT_APPLICABLE) |   if (err == NOISE_ERROR_NOT_APPLICABLE) | ||||||
|     return "NOT_APPLICABLE"; |     return LOG_STR("NOT_APPLICABLE"); | ||||||
|   if (err == NOISE_ERROR_SYSTEM) |   if (err == NOISE_ERROR_SYSTEM) | ||||||
|     return "SYSTEM"; |     return LOG_STR("SYSTEM"); | ||||||
|   if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED) |   if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED) | ||||||
|     return "REMOTE_KEY_REQUIRED"; |     return LOG_STR("REMOTE_KEY_REQUIRED"); | ||||||
|   if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED) |   if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED) | ||||||
|     return "LOCAL_KEY_REQUIRED"; |     return LOG_STR("LOCAL_KEY_REQUIRED"); | ||||||
|   if (err == NOISE_ERROR_PSK_REQUIRED) |   if (err == NOISE_ERROR_PSK_REQUIRED) | ||||||
|     return "PSK_REQUIRED"; |     return LOG_STR("PSK_REQUIRED"); | ||||||
|   if (err == NOISE_ERROR_INVALID_LENGTH) |   if (err == NOISE_ERROR_INVALID_LENGTH) | ||||||
|     return "INVALID_LENGTH"; |     return LOG_STR("INVALID_LENGTH"); | ||||||
|   if (err == NOISE_ERROR_INVALID_PARAM) |   if (err == NOISE_ERROR_INVALID_PARAM) | ||||||
|     return "INVALID_PARAM"; |     return LOG_STR("INVALID_PARAM"); | ||||||
|   if (err == NOISE_ERROR_INVALID_STATE) |   if (err == NOISE_ERROR_INVALID_STATE) | ||||||
|     return "INVALID_STATE"; |     return LOG_STR("INVALID_STATE"); | ||||||
|   if (err == NOISE_ERROR_INVALID_NONCE) |   if (err == NOISE_ERROR_INVALID_NONCE) | ||||||
|     return "INVALID_NONCE"; |     return LOG_STR("INVALID_NONCE"); | ||||||
|   if (err == NOISE_ERROR_INVALID_PRIVATE_KEY) |   if (err == NOISE_ERROR_INVALID_PRIVATE_KEY) | ||||||
|     return "INVALID_PRIVATE_KEY"; |     return LOG_STR("INVALID_PRIVATE_KEY"); | ||||||
|   if (err == NOISE_ERROR_INVALID_PUBLIC_KEY) |   if (err == NOISE_ERROR_INVALID_PUBLIC_KEY) | ||||||
|     return "INVALID_PUBLIC_KEY"; |     return LOG_STR("INVALID_PUBLIC_KEY"); | ||||||
|   if (err == NOISE_ERROR_INVALID_FORMAT) |   if (err == NOISE_ERROR_INVALID_FORMAT) | ||||||
|     return "INVALID_FORMAT"; |     return LOG_STR("INVALID_FORMAT"); | ||||||
|   if (err == NOISE_ERROR_INVALID_SIGNATURE) |   if (err == NOISE_ERROR_INVALID_SIGNATURE) | ||||||
|     return "INVALID_SIGNATURE"; |     return LOG_STR("INVALID_SIGNATURE"); | ||||||
|   return to_string(err); |   return LOG_STR("UNKNOWN"); | ||||||
| } | } | ||||||
|  |  | ||||||
| /// Initialize the frame helper, returns OK if successful. | /// Initialize the frame helper, returns OK if successful. | ||||||
| @@ -83,18 +83,18 @@ APIError APINoiseFrameHelper::init() { | |||||||
| // Helper for handling handshake frame errors | // Helper for handling handshake frame errors | ||||||
| APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) { | APIError APINoiseFrameHelper::handle_handshake_frame_error_(APIError aerr) { | ||||||
|   if (aerr == APIError::BAD_INDICATOR) { |   if (aerr == APIError::BAD_INDICATOR) { | ||||||
|     send_explicit_handshake_reject_("Bad indicator byte"); |     send_explicit_handshake_reject_(LOG_STR("Bad indicator byte")); | ||||||
|   } else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { |   } else if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { | ||||||
|     send_explicit_handshake_reject_("Bad handshake packet len"); |     send_explicit_handshake_reject_(LOG_STR("Bad handshake packet len")); | ||||||
|   } |   } | ||||||
|   return aerr; |   return aerr; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Helper for handling noise library errors | // Helper for handling noise library errors | ||||||
| APIError APINoiseFrameHelper::handle_noise_error_(int err, const char *func_name, APIError api_err) { | APIError APINoiseFrameHelper::handle_noise_error_(int err, const LogString *func_name, APIError api_err) { | ||||||
|   if (err != 0) { |   if (err != 0) { | ||||||
|     state_ = State::FAILED; |     state_ = State::FAILED; | ||||||
|     HELPER_LOG("%s failed: %s", func_name, noise_err_to_str(err).c_str()); |     HELPER_LOG("%s failed: %s", LOG_STR_ARG(func_name), LOG_STR_ARG(noise_err_to_logstr(err))); | ||||||
|     return api_err; |     return api_err; | ||||||
|   } |   } | ||||||
|   return APIError::OK; |   return APIError::OK; | ||||||
| @@ -279,11 +279,11 @@ APIError APINoiseFrameHelper::state_action_() { | |||||||
|       } |       } | ||||||
|  |  | ||||||
|       if (frame.empty()) { |       if (frame.empty()) { | ||||||
|         send_explicit_handshake_reject_("Empty handshake message"); |         send_explicit_handshake_reject_(LOG_STR("Empty handshake message")); | ||||||
|         return APIError::BAD_HANDSHAKE_ERROR_BYTE; |         return APIError::BAD_HANDSHAKE_ERROR_BYTE; | ||||||
|       } else if (frame[0] != 0x00) { |       } else if (frame[0] != 0x00) { | ||||||
|         HELPER_LOG("Bad handshake error byte: %u", frame[0]); |         HELPER_LOG("Bad handshake error byte: %u", frame[0]); | ||||||
|         send_explicit_handshake_reject_("Bad handshake error byte"); |         send_explicit_handshake_reject_(LOG_STR("Bad handshake error byte")); | ||||||
|         return APIError::BAD_HANDSHAKE_ERROR_BYTE; |         return APIError::BAD_HANDSHAKE_ERROR_BYTE; | ||||||
|       } |       } | ||||||
|  |  | ||||||
| @@ -293,8 +293,10 @@ APIError APINoiseFrameHelper::state_action_() { | |||||||
|       err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); |       err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); | ||||||
|       if (err != 0) { |       if (err != 0) { | ||||||
|         // Special handling for MAC failure |         // Special handling for MAC failure | ||||||
|         send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? "Handshake MAC failure" : "Handshake error"); |         send_explicit_handshake_reject_(err == NOISE_ERROR_MAC_FAILURE ? LOG_STR("Handshake MAC failure") | ||||||
|         return handle_noise_error_(err, "noise_handshakestate_read_message", APIError::HANDSHAKESTATE_READ_FAILED); |                                                                        : LOG_STR("Handshake error")); | ||||||
|  |         return handle_noise_error_(err, LOG_STR("noise_handshakestate_read_message"), | ||||||
|  |                                    APIError::HANDSHAKESTATE_READ_FAILED); | ||||||
|       } |       } | ||||||
|  |  | ||||||
|       aerr = check_handshake_finished_(); |       aerr = check_handshake_finished_(); | ||||||
| @@ -307,8 +309,8 @@ APIError APINoiseFrameHelper::state_action_() { | |||||||
|       noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); |       noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); | ||||||
|  |  | ||||||
|       err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); |       err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); | ||||||
|       APIError aerr_write = |       APIError aerr_write = handle_noise_error_(err, LOG_STR("noise_handshakestate_write_message"), | ||||||
|           handle_noise_error_(err, "noise_handshakestate_write_message", APIError::HANDSHAKESTATE_WRITE_FAILED); |                                                 APIError::HANDSHAKESTATE_WRITE_FAILED); | ||||||
|       if (aerr_write != APIError::OK) |       if (aerr_write != APIError::OK) | ||||||
|         return aerr_write; |         return aerr_write; | ||||||
|       buffer[0] = 0x00;  // success |       buffer[0] = 0x00;  // success | ||||||
| @@ -331,15 +333,31 @@ APIError APINoiseFrameHelper::state_action_() { | |||||||
|   } |   } | ||||||
|   return APIError::OK; |   return APIError::OK; | ||||||
| } | } | ||||||
| void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { | void APINoiseFrameHelper::send_explicit_handshake_reject_(const LogString *reason) { | ||||||
|  | #ifdef USE_STORE_LOG_STR_IN_FLASH | ||||||
|  |   // On ESP8266 with flash strings, we need to use PROGMEM-aware functions | ||||||
|  |   size_t reason_len = strlen_P(reinterpret_cast<PGM_P>(reason)); | ||||||
|   std::vector<uint8_t> data; |   std::vector<uint8_t> data; | ||||||
|   data.resize(reason.length() + 1); |   data.resize(reason_len + 1); | ||||||
|  |   data[0] = 0x01;  // failure | ||||||
|  |  | ||||||
|  |   // Copy error message from PROGMEM | ||||||
|  |   if (reason_len > 0) { | ||||||
|  |     memcpy_P(data.data() + 1, reinterpret_cast<PGM_P>(reason), reason_len); | ||||||
|  |   } | ||||||
|  | #else | ||||||
|  |   // Normal memory access | ||||||
|  |   const char *reason_str = LOG_STR_ARG(reason); | ||||||
|  |   size_t reason_len = strlen(reason_str); | ||||||
|  |   std::vector<uint8_t> data; | ||||||
|  |   data.resize(reason_len + 1); | ||||||
|   data[0] = 0x01;  // failure |   data[0] = 0x01;  // failure | ||||||
|  |  | ||||||
|   // Copy error message in bulk |   // Copy error message in bulk | ||||||
|   if (!reason.empty()) { |   if (reason_len > 0) { | ||||||
|     std::memcpy(data.data() + 1, reason.c_str(), reason.length()); |     std::memcpy(data.data() + 1, reason_str, reason_len); | ||||||
|   } |   } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|   // temporarily remove failed state |   // temporarily remove failed state | ||||||
|   auto orig_state = state_; |   auto orig_state = state_; | ||||||
| @@ -368,7 +386,8 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { | |||||||
|   noise_buffer_init(mbuf); |   noise_buffer_init(mbuf); | ||||||
|   noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size()); |   noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size()); | ||||||
|   err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); |   err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); | ||||||
|   APIError decrypt_err = handle_noise_error_(err, "noise_cipherstate_decrypt", APIError::CIPHERSTATE_DECRYPT_FAILED); |   APIError decrypt_err = | ||||||
|  |       handle_noise_error_(err, LOG_STR("noise_cipherstate_decrypt"), APIError::CIPHERSTATE_DECRYPT_FAILED); | ||||||
|   if (decrypt_err != APIError::OK) |   if (decrypt_err != APIError::OK) | ||||||
|     return decrypt_err; |     return decrypt_err; | ||||||
|  |  | ||||||
| @@ -450,7 +469,8 @@ APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, st | |||||||
|                            4 + packet.payload_size + frame_footer_size_); |                            4 + packet.payload_size + frame_footer_size_); | ||||||
|  |  | ||||||
|     int err = noise_cipherstate_encrypt(send_cipher_, &mbuf); |     int err = noise_cipherstate_encrypt(send_cipher_, &mbuf); | ||||||
|     APIError aerr = handle_noise_error_(err, "noise_cipherstate_encrypt", APIError::CIPHERSTATE_ENCRYPT_FAILED); |     APIError aerr = | ||||||
|  |         handle_noise_error_(err, LOG_STR("noise_cipherstate_encrypt"), APIError::CIPHERSTATE_ENCRYPT_FAILED); | ||||||
|     if (aerr != APIError::OK) |     if (aerr != APIError::OK) | ||||||
|       return aerr; |       return aerr; | ||||||
|  |  | ||||||
| @@ -504,25 +524,27 @@ APIError APINoiseFrameHelper::init_handshake_() { | |||||||
|   nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; |   nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; | ||||||
|  |  | ||||||
|   err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); |   err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); | ||||||
|   APIError aerr = handle_noise_error_(err, "noise_handshakestate_new_by_id", APIError::HANDSHAKESTATE_SETUP_FAILED); |   APIError aerr = | ||||||
|  |       handle_noise_error_(err, LOG_STR("noise_handshakestate_new_by_id"), APIError::HANDSHAKESTATE_SETUP_FAILED); | ||||||
|   if (aerr != APIError::OK) |   if (aerr != APIError::OK) | ||||||
|     return aerr; |     return aerr; | ||||||
|  |  | ||||||
|   const auto &psk = ctx_->get_psk(); |   const auto &psk = ctx_->get_psk(); | ||||||
|   err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); |   err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); | ||||||
|   aerr = handle_noise_error_(err, "noise_handshakestate_set_pre_shared_key", APIError::HANDSHAKESTATE_SETUP_FAILED); |   aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_set_pre_shared_key"), | ||||||
|  |                              APIError::HANDSHAKESTATE_SETUP_FAILED); | ||||||
|   if (aerr != APIError::OK) |   if (aerr != APIError::OK) | ||||||
|     return aerr; |     return aerr; | ||||||
|  |  | ||||||
|   err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); |   err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); | ||||||
|   aerr = handle_noise_error_(err, "noise_handshakestate_set_prologue", APIError::HANDSHAKESTATE_SETUP_FAILED); |   aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_set_prologue"), APIError::HANDSHAKESTATE_SETUP_FAILED); | ||||||
|   if (aerr != APIError::OK) |   if (aerr != APIError::OK) | ||||||
|     return aerr; |     return aerr; | ||||||
|   // set_prologue copies it into handshakestate, so we can get rid of it now |   // set_prologue copies it into handshakestate, so we can get rid of it now | ||||||
|   prologue_ = {}; |   prologue_ = {}; | ||||||
|  |  | ||||||
|   err = noise_handshakestate_start(handshake_); |   err = noise_handshakestate_start(handshake_); | ||||||
|   aerr = handle_noise_error_(err, "noise_handshakestate_start", APIError::HANDSHAKESTATE_SETUP_FAILED); |   aerr = handle_noise_error_(err, LOG_STR("noise_handshakestate_start"), APIError::HANDSHAKESTATE_SETUP_FAILED); | ||||||
|   if (aerr != APIError::OK) |   if (aerr != APIError::OK) | ||||||
|     return aerr; |     return aerr; | ||||||
|   return APIError::OK; |   return APIError::OK; | ||||||
| @@ -540,7 +562,8 @@ APIError APINoiseFrameHelper::check_handshake_finished_() { | |||||||
|     return APIError::HANDSHAKESTATE_BAD_STATE; |     return APIError::HANDSHAKESTATE_BAD_STATE; | ||||||
|   } |   } | ||||||
|   int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); |   int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); | ||||||
|   APIError aerr = handle_noise_error_(err, "noise_handshakestate_split", APIError::HANDSHAKESTATE_SPLIT_FAILED); |   APIError aerr = | ||||||
|  |       handle_noise_error_(err, LOG_STR("noise_handshakestate_split"), APIError::HANDSHAKESTATE_SPLIT_FAILED); | ||||||
|   if (aerr != APIError::OK) |   if (aerr != APIError::OK) | ||||||
|     return aerr; |     return aerr; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -32,9 +32,9 @@ class APINoiseFrameHelper final : public APIFrameHelper { | |||||||
|   APIError write_frame_(const uint8_t *data, uint16_t len); |   APIError write_frame_(const uint8_t *data, uint16_t len); | ||||||
|   APIError init_handshake_(); |   APIError init_handshake_(); | ||||||
|   APIError check_handshake_finished_(); |   APIError check_handshake_finished_(); | ||||||
|   void send_explicit_handshake_reject_(const std::string &reason); |   void send_explicit_handshake_reject_(const LogString *reason); | ||||||
|   APIError handle_handshake_frame_error_(APIError aerr); |   APIError handle_handshake_frame_error_(APIError aerr); | ||||||
|   APIError handle_noise_error_(int err, const char *func_name, APIError api_err); |   APIError handle_noise_error_(int err, const LogString *func_name, APIError api_err); | ||||||
|  |  | ||||||
|   // Pointers first (4 bytes each) |   // Pointers first (4 bytes each) | ||||||
|   NoiseHandshakeState *handshake_{nullptr}; |   NoiseHandshakeState *handshake_{nullptr}; | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ from __future__ import annotations | |||||||
|  |  | ||||||
| import asyncio | import asyncio | ||||||
| import base64 | import base64 | ||||||
|  | import binascii | ||||||
| from collections.abc import Callable, Iterable | from collections.abc import Callable, Iterable | ||||||
| import datetime | import datetime | ||||||
| import functools | import functools | ||||||
| @@ -490,7 +491,17 @@ class WizardRequestHandler(BaseHandler): | |||||||
|         kwargs = { |         kwargs = { | ||||||
|             k: v |             k: v | ||||||
|             for k, v in json.loads(self.request.body.decode()).items() |             for k, v in json.loads(self.request.body.decode()).items() | ||||||
|             if k in ("name", "platform", "board", "ssid", "psk", "password") |             if k | ||||||
|  |             in ( | ||||||
|  |                 "type", | ||||||
|  |                 "name", | ||||||
|  |                 "platform", | ||||||
|  |                 "board", | ||||||
|  |                 "ssid", | ||||||
|  |                 "psk", | ||||||
|  |                 "password", | ||||||
|  |                 "file_content", | ||||||
|  |             ) | ||||||
|         } |         } | ||||||
|         if not kwargs["name"]: |         if not kwargs["name"]: | ||||||
|             self.set_status(422) |             self.set_status(422) | ||||||
| @@ -498,19 +509,65 @@ class WizardRequestHandler(BaseHandler): | |||||||
|             self.write(json.dumps({"error": "Name is required"})) |             self.write(json.dumps({"error": "Name is required"})) | ||||||
|             return |             return | ||||||
|  |  | ||||||
|  |         if "type" not in kwargs: | ||||||
|  |             # Default to basic wizard type for backwards compatibility | ||||||
|  |             kwargs["type"] = "basic" | ||||||
|  |  | ||||||
|         kwargs["friendly_name"] = kwargs["name"] |         kwargs["friendly_name"] = kwargs["name"] | ||||||
|         kwargs["name"] = friendly_name_slugify(kwargs["friendly_name"]) |         kwargs["name"] = friendly_name_slugify(kwargs["friendly_name"]) | ||||||
|  |         if kwargs["type"] == "basic": | ||||||
|         kwargs["ota_password"] = secrets.token_hex(16) |             kwargs["ota_password"] = secrets.token_hex(16) | ||||||
|         noise_psk = secrets.token_bytes(32) |             noise_psk = secrets.token_bytes(32) | ||||||
|         kwargs["api_encryption_key"] = base64.b64encode(noise_psk).decode() |             kwargs["api_encryption_key"] = base64.b64encode(noise_psk).decode() | ||||||
|  |         elif kwargs["type"] == "upload": | ||||||
|  |             try: | ||||||
|  |                 kwargs["file_text"] = base64.b64decode(kwargs["file_content"]).decode( | ||||||
|  |                     "utf-8" | ||||||
|  |                 ) | ||||||
|  |             except (binascii.Error, UnicodeDecodeError): | ||||||
|  |                 self.set_status(422) | ||||||
|  |                 self.set_header("content-type", "application/json") | ||||||
|  |                 self.write( | ||||||
|  |                     json.dumps({"error": "The uploaded file is not correctly encoded."}) | ||||||
|  |                 ) | ||||||
|  |                 return | ||||||
|  |         elif kwargs["type"] != "empty": | ||||||
|  |             self.set_status(422) | ||||||
|  |             self.set_header("content-type", "application/json") | ||||||
|  |             self.write( | ||||||
|  |                 json.dumps( | ||||||
|  |                     {"error": f"Invalid wizard type specified: {kwargs['type']}"} | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |             return | ||||||
|         filename = f"{kwargs['name']}.yaml" |         filename = f"{kwargs['name']}.yaml" | ||||||
|         destination = settings.rel_path(filename) |         destination = settings.rel_path(filename) | ||||||
|         wizard.wizard_write(path=destination, **kwargs) |  | ||||||
|         self.set_status(200) |         # Check if destination file already exists | ||||||
|         self.set_header("content-type", "application/json") |         if os.path.exists(destination): | ||||||
|         self.write(json.dumps({"configuration": filename})) |             self.set_status(409)  # Conflict status code | ||||||
|         self.finish() |             self.set_header("content-type", "application/json") | ||||||
|  |             self.write( | ||||||
|  |                 json.dumps({"error": f"Configuration file '{filename}' already exists"}) | ||||||
|  |             ) | ||||||
|  |             self.finish() | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         success = wizard.wizard_write(path=destination, **kwargs) | ||||||
|  |         if success: | ||||||
|  |             self.set_status(200) | ||||||
|  |             self.set_header("content-type", "application/json") | ||||||
|  |             self.write(json.dumps({"configuration": filename})) | ||||||
|  |             self.finish() | ||||||
|  |         else: | ||||||
|  |             self.set_status(500) | ||||||
|  |             self.set_header("content-type", "application/json") | ||||||
|  |             self.write( | ||||||
|  |                 json.dumps( | ||||||
|  |                     {"error": "Failed to write configuration, see logs for details"} | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |             self.finish() | ||||||
|  |  | ||||||
|  |  | ||||||
| class ImportRequestHandler(BaseHandler): | class ImportRequestHandler(BaseHandler): | ||||||
|   | |||||||
| @@ -189,32 +189,45 @@ def wizard_write(path, **kwargs): | |||||||
|     from esphome.components.rtl87xx import boards as rtl87xx_boards |     from esphome.components.rtl87xx import boards as rtl87xx_boards | ||||||
|  |  | ||||||
|     name = kwargs["name"] |     name = kwargs["name"] | ||||||
|     board = kwargs["board"] |     if kwargs["type"] == "empty": | ||||||
|  |         file_text = "" | ||||||
|  |         # Will be updated later after editing the file | ||||||
|  |         hardware = "UNKNOWN" | ||||||
|  |     elif kwargs["type"] == "upload": | ||||||
|  |         file_text = kwargs["file_text"] | ||||||
|  |         hardware = "UNKNOWN" | ||||||
|  |     else:  # "basic" | ||||||
|  |         board = kwargs["board"] | ||||||
|  |  | ||||||
|     for key in ("ssid", "psk", "password", "ota_password"): |         for key in ("ssid", "psk", "password", "ota_password"): | ||||||
|         if key in kwargs: |             if key in kwargs: | ||||||
|             kwargs[key] = sanitize_double_quotes(kwargs[key]) |                 kwargs[key] = sanitize_double_quotes(kwargs[key]) | ||||||
|  |         if "platform" not in kwargs: | ||||||
|  |             if board in esp8266_boards.BOARDS: | ||||||
|  |                 platform = "ESP8266" | ||||||
|  |             elif board in esp32_boards.BOARDS: | ||||||
|  |                 platform = "ESP32" | ||||||
|  |             elif board in rp2040_boards.BOARDS: | ||||||
|  |                 platform = "RP2040" | ||||||
|  |             elif board in bk72xx_boards.BOARDS: | ||||||
|  |                 platform = "BK72XX" | ||||||
|  |             elif board in ln882x_boards.BOARDS: | ||||||
|  |                 platform = "LN882X" | ||||||
|  |             elif board in rtl87xx_boards.BOARDS: | ||||||
|  |                 platform = "RTL87XX" | ||||||
|  |             else: | ||||||
|  |                 safe_print(color(AnsiFore.RED, f'The board "{board}" is unknown.')) | ||||||
|  |                 return False | ||||||
|  |             kwargs["platform"] = platform | ||||||
|  |         hardware = kwargs["platform"] | ||||||
|  |         file_text = wizard_file(**kwargs) | ||||||
|  |  | ||||||
|     if "platform" not in kwargs: |     # Check if file already exists to prevent overwriting | ||||||
|         if board in esp8266_boards.BOARDS: |     if os.path.exists(path) and os.path.isfile(path): | ||||||
|             platform = "ESP8266" |         safe_print(color(AnsiFore.RED, f'The file "{path}" already exists.')) | ||||||
|         elif board in esp32_boards.BOARDS: |         return False | ||||||
|             platform = "ESP32" |  | ||||||
|         elif board in rp2040_boards.BOARDS: |  | ||||||
|             platform = "RP2040" |  | ||||||
|         elif board in bk72xx_boards.BOARDS: |  | ||||||
|             platform = "BK72XX" |  | ||||||
|         elif board in ln882x_boards.BOARDS: |  | ||||||
|             platform = "LN882X" |  | ||||||
|         elif board in rtl87xx_boards.BOARDS: |  | ||||||
|             platform = "RTL87XX" |  | ||||||
|         else: |  | ||||||
|             safe_print(color(AnsiFore.RED, f'The board "{board}" is unknown.')) |  | ||||||
|             return False |  | ||||||
|         kwargs["platform"] = platform |  | ||||||
|     hardware = kwargs["platform"] |  | ||||||
|  |  | ||||||
|     write_file(path, wizard_file(**kwargs)) |     write_file(path, file_text) | ||||||
|     storage = StorageJSON.from_wizard(name, name, f"{name}.local", hardware) |     storage = StorageJSON.from_wizard(name, name, f"{name}.local", hardware) | ||||||
|     storage_path = ext_storage_path(os.path.basename(path)) |     storage_path = ext_storage_path(os.path.basename(path)) | ||||||
|     storage.save(storage_path) |     storage.save(storage_path) | ||||||
|   | |||||||
| @@ -17,6 +17,7 @@ import esphome.wizard as wz | |||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def default_config(): | def default_config(): | ||||||
|     return { |     return { | ||||||
|  |         "type": "basic", | ||||||
|         "name": "test-name", |         "name": "test-name", | ||||||
|         "platform": "ESP8266", |         "platform": "ESP8266", | ||||||
|         "board": "esp01_1m", |         "board": "esp01_1m", | ||||||
| @@ -125,6 +126,47 @@ def test_wizard_write_sets_platform(default_config, tmp_path, monkeypatch): | |||||||
|     assert "esp8266:" in generated_config |     assert "esp8266:" in generated_config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_wizard_empty_config(tmp_path, monkeypatch): | ||||||
|  |     """ | ||||||
|  |     The wizard should be able to create an empty configuration | ||||||
|  |     """ | ||||||
|  |     # Given | ||||||
|  |     empty_config = { | ||||||
|  |         "type": "empty", | ||||||
|  |         "name": "test-empty", | ||||||
|  |     } | ||||||
|  |     monkeypatch.setattr(wz, "write_file", MagicMock()) | ||||||
|  |     monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) | ||||||
|  |  | ||||||
|  |     # When | ||||||
|  |     wz.wizard_write(tmp_path, **empty_config) | ||||||
|  |  | ||||||
|  |     # Then | ||||||
|  |     generated_config = wz.write_file.call_args.args[1] | ||||||
|  |     assert generated_config == "" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_wizard_upload_config(tmp_path, monkeypatch): | ||||||
|  |     """ | ||||||
|  |     The wizard should be able to import an base64 encoded configuration | ||||||
|  |     """ | ||||||
|  |     # Given | ||||||
|  |     empty_config = { | ||||||
|  |         "type": "upload", | ||||||
|  |         "name": "test-upload", | ||||||
|  |         "file_text": "# imported file 📁\n\n", | ||||||
|  |     } | ||||||
|  |     monkeypatch.setattr(wz, "write_file", MagicMock()) | ||||||
|  |     monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) | ||||||
|  |  | ||||||
|  |     # When | ||||||
|  |     wz.wizard_write(tmp_path, **empty_config) | ||||||
|  |  | ||||||
|  |     # Then | ||||||
|  |     generated_config = wz.write_file.call_args.args[1] | ||||||
|  |     assert generated_config == "# imported file 📁\n\n" | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_wizard_write_defaults_platform_from_board_esp8266( | def test_wizard_write_defaults_platform_from_board_esp8266( | ||||||
|     default_config, tmp_path, monkeypatch |     default_config, tmp_path, monkeypatch | ||||||
| ): | ): | ||||||
| @@ -471,3 +513,22 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers): | |||||||
|  |  | ||||||
|     # Then |     # Then | ||||||
|     assert retval == 0 |     assert retval == 0 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_wizard_write_protects_existing_config(tmpdir, default_config, monkeypatch): | ||||||
|  |     """ | ||||||
|  |     The wizard_write function should not overwrite existing config files and return False | ||||||
|  |     """ | ||||||
|  |     # Given | ||||||
|  |     config_file = tmpdir.join("test.yaml") | ||||||
|  |     original_content = "# Original config content\n" | ||||||
|  |     config_file.write(original_content) | ||||||
|  |  | ||||||
|  |     monkeypatch.setattr(CORE, "config_path", str(tmpdir)) | ||||||
|  |  | ||||||
|  |     # When | ||||||
|  |     result = wz.wizard_write(str(config_file), **default_config) | ||||||
|  |  | ||||||
|  |     # Then | ||||||
|  |     assert result is False  # Should return False when file exists | ||||||
|  |     assert config_file.read() == original_content | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user