1
0
mirror of https://github.com/esphome/esphome.git synced 2025-10-30 14:43:51 +00:00
This commit is contained in:
J. Nick Koston
2025-09-27 01:08:10 -05:00
parent 9cdd4bc555
commit 5abde23432
2 changed files with 50 additions and 78 deletions

View File

@@ -124,11 +124,11 @@ static const uint8_t FEATURE_SUPPORTS_SHA256_AUTH = 0x02;
#define ALLOW_OTA_DOWNGRADE_MD5
void ESPHomeOTAComponent::handle_handshake_() {
/// Handle the initial OTA handshake.
/// Handle the OTA handshake and authentication.
///
/// This method is non-blocking and will return immediately if no data is available.
/// It reads all 5 magic bytes (0x6C, 0x26, 0xF7, 0x5C, 0x45) non-blocking
/// before proceeding to handle_data_(). A 10-second timeout is enforced from initial connection.
/// It manages the state machine through connection, magic bytes validation, feature
/// negotiation, and authentication before entering the blocking data transfer phase.
if (this->client_ == nullptr) {
// We already checked server_->ready() in loop(), so we can accept directly
@@ -168,7 +168,7 @@ void ESPHomeOTAComponent::handle_handshake_() {
switch (this->ota_state_) {
case OTAState::MAGIC_READ: {
// Try to read remaining magic bytes (5 total)
if (!this->try_read_(5, LOG_STR("reading magic bytes"), LOG_STR("handshake"))) {
if (!this->try_read_(5, LOG_STR("read magic"))) {
return;
}
@@ -183,21 +183,16 @@ void ESPHomeOTAComponent::handle_handshake_() {
// Magic bytes valid, move to next state
this->transition_ota_state_(OTAState::MAGIC_ACK);
this->handshake_buf_[0] = ota::OTA_RESPONSE_OK;
this->handshake_buf_[1] = USE_OTA_VERSION;
[[fallthrough]];
}
case OTAState::MAGIC_ACK: {
// Send OK and version - 2 bytes
// Prepare response in handshake buffer if not already done
if (this->handshake_buf_pos_ == 0) {
this->handshake_buf_[0] = ota::OTA_RESPONSE_OK;
this->handshake_buf_[1] = USE_OTA_VERSION;
}
if (!this->try_write_(2, LOG_STR("writing magic ack"))) {
if (!this->try_write_(2, LOG_STR("ack magic"))) {
return;
}
// All bytes sent, create backend and move to next state
this->backend_ = ota::make_ota_backend();
this->transition_ota_state_(OTAState::FEATURE_READ);
@@ -206,30 +201,24 @@ void ESPHomeOTAComponent::handle_handshake_() {
case OTAState::FEATURE_READ: {
// Read features - 1 byte
if (!this->try_read_(1, LOG_STR("reading features"), LOG_STR("feature read"))) {
if (!this->try_read_(1, LOG_STR("read feature"))) {
return;
}
this->ota_features_ = this->handshake_buf_[0];
ESP_LOGV(TAG, "Features: 0x%02X", this->ota_features_);
this->transition_ota_state_(OTAState::FEATURE_ACK);
this->handshake_buf_[0] =
((this->ota_features_ & FEATURE_SUPPORTS_COMPRESSION) != 0 && this->backend_->supports_compression())
? ota::OTA_RESPONSE_SUPPORTS_COMPRESSION
: ota::OTA_RESPONSE_HEADER_OK;
[[fallthrough]];
}
case OTAState::FEATURE_ACK: {
// Acknowledge header - 1 byte
// Prepare response in handshake buffer if not already done
if (this->handshake_buf_pos_ == 0) {
this->handshake_buf_[0] =
((this->ota_features_ & FEATURE_SUPPORTS_COMPRESSION) != 0 && this->backend_->supports_compression())
? ota::OTA_RESPONSE_SUPPORTS_COMPRESSION
: ota::OTA_RESPONSE_HEADER_OK;
}
if (!this->try_write_(1, LOG_STR("writing feature ack"))) {
if (!this->try_write_(1, LOG_STR("ack feature"))) {
return;
}
#ifdef USE_OTA_PASSWORD
// If password is set, move to auth phase
if (!this->password_.empty()) {
@@ -266,11 +255,10 @@ void ESPHomeOTAComponent::handle_handshake_() {
case OTAState::DATA:
this->handle_data_();
return;
[[fallthrough]];
case OTAState::IDLE:
// This shouldn't happen
return;
default:
break;
}
}
@@ -287,20 +275,12 @@ void ESPHomeOTAComponent::handle_data_() {
size_t total = 0;
uint32_t last_progress = 0;
uint8_t buf[OTA_BUFFER_SIZE];
const size_t buf_size = sizeof(buf);
char *sbuf = reinterpret_cast<char *>(buf);
size_t ota_size;
#if USE_OTA_VERSION == 2
size_t size_acknowledged = 0;
#endif
// The handshake and auth have already been completed
// We already have:
// - this->backend_ created
// - this->ota_features_ set
// - Feature acknowledgment sent
// - Authentication completed (if password was set)
// Acknowledge auth OK - 1 byte
buf[0] = ota::OTA_RESPONSE_AUTH_OK;
this->writeall_(buf, 1);
@@ -353,26 +333,23 @@ void ESPHomeOTAComponent::handle_data_() {
while (total < ota_size) {
// TODO: timeout check
size_t remaining = ota_size - total;
size_t requested = remaining < buf_size ? remaining : buf_size;
size_t requested = remaining < OTA_BUFFER_SIZE ? remaining : OTA_BUFFER_SIZE;
ssize_t read = this->client_->read(buf, requested);
if (read == -1) {
if (this->would_block_(errno)) {
this->yield_and_feed_watchdog_();
continue;
}
ESP_LOGW(TAG, "Read error, errno %d", errno);
ESP_LOGW(TAG, "Read err %d", errno);
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} else if (read == 0) {
// $ man recv
// "When a stream socket peer has performed an orderly shutdown, the return value will
// be 0 (the traditional "end-of-file" return)."
ESP_LOGW(TAG, "Remote closed connection");
ESP_LOGW(TAG, "Remote closed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
error_code = this->backend_->write(buf, read);
if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Flash write error, code: %d", error_code);
ESP_LOGW(TAG, "Flash write err %d", error_code);
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
total += read;
@@ -403,7 +380,7 @@ void ESPHomeOTAComponent::handle_data_() {
error_code = this->backend_->end();
if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Error ending update! code: %d", error_code);
ESP_LOGW(TAG, "End update err %d", error_code);
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
@@ -455,11 +432,11 @@ bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) {
ssize_t read = this->client_->read(buf + at, len - at);
if (read == -1) {
if (!this->would_block_(errno)) {
ESP_LOGW(TAG, "Error reading %d bytes, errno %d", len, errno);
ESP_LOGW(TAG, "Read err %d bytes, errno %d", len, errno);
return false;
}
} else if (read == 0) {
ESP_LOGW(TAG, "Remote closed connection");
ESP_LOGW(TAG, "Remote closed");
return false;
} else {
at += read;
@@ -482,7 +459,7 @@ bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) {
ssize_t written = this->client_->write(buf + at, len - at);
if (written == -1) {
if (!this->would_block_(errno)) {
ESP_LOGW(TAG, "Error writing %d bytes, errno %d", len, errno);
ESP_LOGW(TAG, "Write err %d bytes, errno %d", len, errno);
return false;
}
} else {
@@ -508,40 +485,40 @@ void ESPHomeOTAComponent::log_start_(const LogString *phase) {
}
void ESPHomeOTAComponent::log_remote_closed_(const LogString *during) {
ESP_LOGW(TAG, "Remote closed during %s", LOG_STR_ARG(during));
ESP_LOGW(TAG, "Remote closed at %s", LOG_STR_ARG(during));
}
bool ESPHomeOTAComponent::handle_read_error_(ssize_t read, const LogString *error_desc, const LogString *close_desc) {
bool ESPHomeOTAComponent::handle_read_error_(ssize_t read, const LogString *desc) {
if (read == -1 && this->would_block_(errno)) {
return false; // No data yet, try again next loop
}
if (read <= 0) {
read == 0 ? this->log_remote_closed_(close_desc) : this->log_socket_error_(error_desc);
read == 0 ? this->log_remote_closed_(desc) : this->log_socket_error_(desc);
this->cleanup_connection_();
return false;
}
return true;
}
bool ESPHomeOTAComponent::handle_write_error_(ssize_t written, const LogString *error_desc) {
bool ESPHomeOTAComponent::handle_write_error_(ssize_t written, const LogString *desc) {
if (written == -1) {
if (this->would_block_(errno)) {
return false; // Try again next loop
}
this->log_socket_error_(error_desc);
this->log_socket_error_(desc);
this->cleanup_connection_();
return false;
}
return true;
}
bool ESPHomeOTAComponent::try_read_(size_t to_read, const LogString *error_desc, const LogString *close_desc) {
bool ESPHomeOTAComponent::try_read_(size_t to_read, const LogString *desc) {
// Read bytes into handshake buffer, starting at handshake_buf_pos_
size_t bytes_to_read = to_read - this->handshake_buf_pos_;
ssize_t read = this->client_->read(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_read);
if (!this->handle_read_error_(read, error_desc, close_desc)) {
if (!this->handle_read_error_(read, desc)) {
return false;
}
@@ -550,12 +527,12 @@ bool ESPHomeOTAComponent::try_read_(size_t to_read, const LogString *error_desc,
return this->handshake_buf_pos_ >= to_read;
}
bool ESPHomeOTAComponent::try_write_(size_t to_write, const LogString *error_desc) {
bool ESPHomeOTAComponent::try_write_(size_t to_write, const LogString *desc) {
// Write bytes from handshake buffer, starting at handshake_buf_pos_
size_t bytes_to_write = to_write - this->handshake_buf_pos_;
ssize_t written = this->client_->write(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_write);
if (!this->handle_write_error_(written, error_desc)) {
if (!this->handle_write_error_(written, desc)) {
return false;
}
@@ -564,11 +541,6 @@ bool ESPHomeOTAComponent::try_write_(size_t to_write, const LogString *error_des
return this->handshake_buf_pos_ >= to_write;
}
void ESPHomeOTAComponent::transition_ota_state_(OTAState next_state) {
this->ota_state_ = next_state;
this->handshake_buf_pos_ = 0; // Reset buffer position for next state
}
void ESPHomeOTAComponent::cleanup_connection_() {
this->client_->close();
this->client_ = nullptr;
@@ -582,12 +554,6 @@ void ESPHomeOTAComponent::cleanup_connection_() {
#endif
}
void ESPHomeOTAComponent::send_error_and_cleanup_(ota::OTAResponseTypes error) {
uint8_t error_byte = static_cast<uint8_t>(error);
this->client_->write(&error_byte, 1); // Best effort, non-blocking
this->cleanup_connection_();
}
void ESPHomeOTAComponent::yield_and_feed_watchdog_() {
App.feed_wdt();
delay(1);
@@ -675,7 +641,7 @@ bool ESPHomeOTAComponent::handle_auth_send_() {
size_t remaining = to_write - this->auth_buf_pos_;
ssize_t written = this->client_->write(this->auth_buf_.get() + this->auth_buf_pos_, remaining);
if (!this->handle_write_error_(written, LOG_STR("auth write"))) {
if (!this->handle_write_error_(written, LOG_STR("ack auth"))) {
return false;
}
@@ -701,8 +667,7 @@ bool ESPHomeOTAComponent::handle_auth_read_() {
size_t remaining = to_read - this->auth_buf_pos_;
ssize_t read = this->client_->read(this->auth_buf_.get() + cnonce_offset + this->auth_buf_pos_, remaining);
auto *auth_read_desc = LOG_STR("auth read");
if (!this->handle_read_error_(read, auth_read_desc, auth_read_desc)) {
if (!this->handle_read_error_(read, LOG_STR("read auth"))) {
return false;
}
@@ -760,7 +725,7 @@ bool ESPHomeOTAComponent::prepare_auth_nonce_(HashBase *hasher) {
char *buf = reinterpret_cast<char *>(this->auth_buf_.get() + 1);
if (!random_bytes(reinterpret_cast<uint8_t *>(buf), nonce_len)) {
this->log_auth_warning_(LOG_STR("Random failed"));
this->cleanup_connection_();
this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_UNKNOWN);
return false;
}

View File

@@ -56,20 +56,27 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
bool readall_(uint8_t *buf, size_t len);
bool writeall_(const uint8_t *buf, size_t len);
bool try_read_(size_t to_read, const LogString *error_desc, const LogString *close_desc);
bool try_write_(size_t to_write, const LogString *error_desc);
bool try_read_(size_t to_read, const LogString *desc);
bool try_write_(size_t to_write, const LogString *desc);
bool would_block_(int error_code) const { return error_code == EAGAIN || error_code == EWOULDBLOCK; }
bool handle_read_error_(ssize_t read, const LogString *error_desc, const LogString *close_desc);
bool handle_write_error_(ssize_t written, const LogString *error_desc);
void transition_ota_state_(OTAState next_state);
inline bool would_block_(int error_code) const { return error_code == EAGAIN || error_code == EWOULDBLOCK; }
bool handle_read_error_(ssize_t read, const LogString *desc);
bool handle_write_error_(ssize_t written, const LogString *desc);
inline void transition_ota_state_(OTAState next_state) {
this->ota_state_ = next_state;
this->handshake_buf_pos_ = 0; // Reset buffer position for next state
}
void log_socket_error_(const LogString *msg);
void log_read_error_(const LogString *what);
void log_start_(const LogString *phase);
void log_remote_closed_(const LogString *during);
void cleanup_connection_();
void send_error_and_cleanup_(ota::OTAResponseTypes error);
inline void send_error_and_cleanup_(ota::OTAResponseTypes error) {
uint8_t error_byte = static_cast<uint8_t>(error);
this->client_->write(&error_byte, 1); // Best effort, non-blocking
this->cleanup_connection_();
}
void yield_and_feed_watchdog_();
#ifdef USE_OTA_PASSWORD