From cc4c059429e5e2e5884915aeb149319af7ec970e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 26 Sep 2025 21:52:00 -0500 Subject: [PATCH] optimize --- .../components/esphome/ota/ota_esphome.cpp | 221 ++++++++++++------ esphome/components/esphome/ota/ota_esphome.h | 15 +- 2 files changed, 164 insertions(+), 72 deletions(-) diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index 11795aaf2f..445167f13e 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -141,7 +141,8 @@ void ESPHomeOTAComponent::handle_handshake_() { } this->log_start_(LOG_STR("handshake")); this->client_connect_time_ = App.get_loop_component_start_time(); - this->magic_buf_pos_ = 0; // Reset magic buffer position + this->handshake_buf_pos_ = 0; // Reset handshake buffer position + this->ota_state_ = OTAState::MAGIC_READ; } // Check for handshake timeout @@ -152,46 +153,143 @@ void ESPHomeOTAComponent::handle_handshake_() { return; } - // Try to read remaining magic bytes - if (this->magic_buf_pos_ < 5) { - // Read as many bytes as available - uint8_t bytes_to_read = 5 - this->magic_buf_pos_; - ssize_t read = this->client_->read(this->magic_buf_ + this->magic_buf_pos_, bytes_to_read); + while (true) { + switch (this->ota_state_) { + case OTAState::MAGIC_READ: { + // Try to read remaining magic bytes + if (this->handshake_buf_pos_ < 5) { + // Read as many bytes as available + uint8_t bytes_to_read = 5 - this->handshake_buf_pos_; + ssize_t read = this->client_->read(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_read); - if (read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { - return; // No data yet, try again next loop - } + if (read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + return; // No data yet, try again next loop + } - if (read <= 0) { - // Error or connection closed - if (read == -1) { - this->log_socket_error_(LOG_STR("reading magic bytes")); - } else { - ESP_LOGW(TAG, "Remote closed during handshake"); + if (read <= 0) { + // Error or connection closed + if (read == -1) { + this->log_socket_error_(LOG_STR("reading magic bytes")); + } else { + ESP_LOGW(TAG, "Remote closed during handshake"); + } + this->cleanup_connection_(); + return; + } + + this->handshake_buf_pos_ += read; + } + + // Check if we have all 5 magic bytes + if (this->handshake_buf_pos_ != 5) { + break; + } + + // Validate magic bytes + static const uint8_t MAGIC_BYTES[5] = {0x6C, 0x26, 0xF7, 0x5C, 0x45}; + if (memcmp(this->handshake_buf_, MAGIC_BYTES, 5) != 0) { + ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", this->handshake_buf_[0], + this->handshake_buf_[1], this->handshake_buf_[2], this->handshake_buf_[3], this->handshake_buf_[4]); + // Send error response (non-blocking, best effort) + uint8_t error = static_cast(ota::OTA_RESPONSE_ERROR_MAGIC); + this->client_->write(&error, 1); + this->cleanup_connection_(); + return; + } + + // Magic bytes valid, move to next state + this->ota_state_ = OTAState::MAGIC_ACK; + this->handshake_buf_pos_ = 0; // Reset for reuse + continue; } - this->cleanup_connection_(); - return; + + case OTAState::MAGIC_ACK: { + // Send OK and version - 2 bytes + // Prepare response in handshake buffer if not already done + if (this->handshake_buf_pos_ == 0) { + this->handshake_buf_[0] = ota::OTA_RESPONSE_OK; + this->handshake_buf_[1] = USE_OTA_VERSION; + } + + // Write remaining bytes (2 total) + size_t bytes_to_write = 2 - this->handshake_buf_pos_; + ssize_t written = this->client_->write(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_write); + + if (written == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; // Try again next loop + } + this->log_socket_error_(LOG_STR("writing magic ack")); + this->cleanup_connection_(); + return; + } + + this->handshake_buf_pos_ += written; + if (this->handshake_buf_pos_ != 2) { + return; + } + // All bytes sent, create backend and move to next state + this->backend_ = ota::make_ota_backend(); + this->ota_state_ = OTAState::FEATURE_READ; + this->handshake_buf_pos_ = 0; // Reset for reuse + continue; + } + + case OTAState::FEATURE_READ: { + // Read features - 1 byte + ssize_t read = this->client_->read(this->handshake_buf_, 1); + + if (read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + return; // No data yet, try again next loop + } + + if (read <= 0) { + if (read == -1) { + this->log_socket_error_(LOG_STR("reading features")); + } else { + ESP_LOGW(TAG, "Remote closed during feature read"); + } + this->cleanup_connection_(); + return; + } + + this->ota_features_ = this->handshake_buf_[0]; + ESP_LOGV(TAG, "Features: 0x%02X", this->ota_features_); + this->ota_state_ = OTAState::FEATURE_ACK; + this->handshake_buf_pos_ = 0; // Reset for reuse + continue; + } + + case OTAState::FEATURE_ACK: { + // Acknowledge header - 1 byte + uint8_t ack = + ((this->ota_features_ & FEATURE_SUPPORTS_COMPRESSION) != 0 && this->backend_->supports_compression()) + ? ota::OTA_RESPONSE_SUPPORTS_COMPRESSION + : ota::OTA_RESPONSE_HEADER_OK; + + ssize_t written = this->client_->write(&ack, 1); + if (written == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; // Try again next loop + } + this->log_socket_error_(LOG_STR("writing feature ack")); + this->cleanup_connection_(); + return; + } + + // Handshake complete, move to data phase + this->ota_state_ = OTAState::DATA; + continue; + } + + case OTAState::DATA: + this->handle_data_(); + return; + + case OTAState::IDLE: + // This shouldn't happen + return; } - - this->magic_buf_pos_ += read; - } - - // Check if we have all 5 magic bytes - if (this->magic_buf_pos_ == 5) { - // Validate magic bytes - static const uint8_t MAGIC_BYTES[5] = {0x6C, 0x26, 0xF7, 0x5C, 0x45}; - if (memcmp(this->magic_buf_, MAGIC_BYTES, 5) != 0) { - ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", this->magic_buf_[0], - this->magic_buf_[1], this->magic_buf_[2], this->magic_buf_[3], this->magic_buf_[4]); - // Send error response (non-blocking, best effort) - uint8_t error = static_cast(ota::OTA_RESPONSE_ERROR_MAGIC); - this->client_->write(&error, 1); - this->cleanup_connection_(); - return; - } - - // All 5 magic bytes are valid, continue with data handling - this->handle_data_(); } } @@ -208,35 +306,15 @@ void ESPHomeOTAComponent::handle_data_() { uint8_t buf[1024]; char *sbuf = reinterpret_cast(buf); size_t ota_size; - uint8_t ota_features; - std::unique_ptr backend; - (void) ota_features; #if USE_OTA_VERSION == 2 size_t size_acknowledged = 0; #endif - // Send OK and version - 2 bytes - buf[0] = ota::OTA_RESPONSE_OK; - buf[1] = USE_OTA_VERSION; - this->writeall_(buf, 2); - - backend = ota::make_ota_backend(); - - // Read features - 1 byte - if (!this->readall_(buf, 1)) { - this->log_read_error_(LOG_STR("features")); - goto error; // NOLINT(cppcoreguidelines-avoid-goto) - } - ota_features = buf[0]; // NOLINT - ESP_LOGV(TAG, "Features: 0x%02X", ota_features); - - // Acknowledge header - 1 byte - buf[0] = ota::OTA_RESPONSE_HEADER_OK; - if ((ota_features & FEATURE_SUPPORTS_COMPRESSION) != 0 && backend->supports_compression()) { - buf[0] = ota::OTA_RESPONSE_SUPPORTS_COMPRESSION; - } - - this->writeall_(buf, 1); + // The handshake has already been completed in handle_handshake_() + // We already have: + // - this->backend_ created + // - this->ota_features_ set + // - Feature acknowledgment sent #ifdef USE_OTA_PASSWORD if (!this->password_.empty()) { @@ -261,7 +339,7 @@ void ESPHomeOTAComponent::handle_data_() { // Devices that don't support SHA256 (due to platform limitations) will // continue to use MD5 as their only option (see #else branch below). - bool client_supports_sha256 = (ota_features & FEATURE_SUPPORTS_SHA256_AUTH) != 0; + bool client_supports_sha256 = (this->ota_features_ & FEATURE_SUPPORTS_SHA256_AUTH) != 0; #ifdef ALLOW_OTA_DOWNGRADE_MD5 // Temporary compatibility mode: Allow MD5 for ~3 versions to enable OTA downgrades @@ -334,7 +412,7 @@ void ESPHomeOTAComponent::handle_data_() { #endif // This will block for a few seconds as it locks flash - error_code = backend->begin(ota_size); + error_code = this->backend_->begin(ota_size); if (error_code != ota::OTA_RESPONSE_OK) goto error; // NOLINT(cppcoreguidelines-avoid-goto) update_started = true; @@ -350,7 +428,7 @@ void ESPHomeOTAComponent::handle_data_() { } sbuf[32] = '\0'; ESP_LOGV(TAG, "Update: Binary MD5 is %s", sbuf); - backend->set_update_md5(sbuf); + this->backend_->set_update_md5(sbuf); // Acknowledge MD5 OK - 1 byte buf[0] = ota::OTA_RESPONSE_BIN_MD5_OK; @@ -375,7 +453,7 @@ void ESPHomeOTAComponent::handle_data_() { goto error; // NOLINT(cppcoreguidelines-avoid-goto) } - error_code = backend->write(buf, read); + error_code = this->backend_->write(buf, read); if (error_code != ota::OTA_RESPONSE_OK) { ESP_LOGW(TAG, "Flash write error, code: %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) @@ -406,7 +484,7 @@ void ESPHomeOTAComponent::handle_data_() { buf[0] = ota::OTA_RESPONSE_RECEIVE_OK; this->writeall_(buf, 1); - error_code = backend->end(); + error_code = this->backend_->end(); if (error_code != ota::OTA_RESPONSE_OK) { ESP_LOGW(TAG, "Error ending update! code: %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) @@ -437,8 +515,8 @@ error: this->writeall_(buf, 1); this->cleanup_connection_(); - if (backend != nullptr && update_started) { - backend->abort(); + if (this->backend_ != nullptr && update_started) { + this->backend_->abort(); } this->status_momentary_error("onerror", 5000); @@ -516,7 +594,10 @@ void ESPHomeOTAComponent::cleanup_connection_() { this->client_->close(); this->client_ = nullptr; this->client_connect_time_ = 0; - this->magic_buf_pos_ = 0; + this->handshake_buf_pos_ = 0; + this->ota_state_ = OTAState::IDLE; + this->ota_features_ = 0; + this->backend_ = nullptr; } void ESPHomeOTAComponent::yield_and_feed_watchdog_() { diff --git a/esphome/components/esphome/ota/ota_esphome.h b/esphome/components/esphome/ota/ota_esphome.h index 5bacb60706..02e759c2ba 100644 --- a/esphome/components/esphome/ota/ota_esphome.h +++ b/esphome/components/esphome/ota/ota_esphome.h @@ -14,6 +14,14 @@ namespace esphome { /// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. class ESPHomeOTAComponent : public ota::OTAComponent { public: + enum class OTAState : uint8_t { + IDLE, + MAGIC_READ, // Reading magic bytes + MAGIC_ACK, // Sending OK and version after magic bytes + FEATURE_READ, // Reading feature flags from client + FEATURE_ACK, // Sending feature acknowledgment + DATA, // Processing OTA data (authentication, update, etc.) + }; #ifdef USE_OTA_PASSWORD void set_auth_password(const std::string &password) { password_ = password; } #endif // USE_OTA_PASSWORD @@ -51,10 +59,13 @@ class ESPHomeOTAComponent : public ota::OTAComponent { std::unique_ptr server_; std::unique_ptr client_; + OTAState ota_state_{OTAState::IDLE}; uint32_t client_connect_time_{0}; uint16_t port_; - uint8_t magic_buf_[5]; - uint8_t magic_buf_pos_{0}; + uint8_t handshake_buf_[5]; + uint8_t handshake_buf_pos_{0}; + uint8_t ota_features_{0}; + std::unique_ptr backend_; }; } // namespace esphome