diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index 4cc82b9094..58cfbfbcc3 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -20,6 +20,7 @@ namespace esphome { static const char *const TAG = "esphome.ota"; static constexpr u_int16_t OTA_BLOCK_SIZE = 8192; +static constexpr uint16_t OTA_SOCKET_TIMEOUT_HANDSHAKE = 1000; // milliseconds for initial handshake void ESPHomeOTAComponent::setup() { #ifdef USE_OTA_STATE_CALLBACK @@ -28,19 +29,19 @@ void ESPHomeOTAComponent::setup() { this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections if (this->server_ == nullptr) { - ESP_LOGW(TAG, "Could not create socket"); + this->log_socket_error_("creation"); this->mark_failed(); return; } int enable = 1; int err = this->server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); if (err != 0) { - ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err); + this->log_socket_error_("reuseaddr"); // we can still continue } err = this->server_->setblocking(false); if (err != 0) { - ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err); + this->log_socket_error_("non-blocking"); this->mark_failed(); return; } @@ -49,21 +50,21 @@ void ESPHomeOTAComponent::setup() { socklen_t sl = socket::set_sockaddr_any((struct sockaddr *) &server, sizeof(server), this->port_); if (sl == 0) { - ESP_LOGW(TAG, "Socket unable to set sockaddr: errno %d", errno); + this->log_socket_error_("set sockaddr"); this->mark_failed(); return; } err = this->server_->bind((struct sockaddr *) &server, sizeof(server)); if (err != 0) { - ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno); + this->log_socket_error_("bind"); this->mark_failed(); return; } err = this->server_->listen(4); if (err != 0) { - ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno); + this->log_socket_error_("listen"); this->mark_failed(); return; } @@ -120,26 +121,26 @@ void ESPHomeOTAComponent::handle_() { int enable = 1; int err = this->client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); if (err != 0) { - ESP_LOGW(TAG, "Socket could not enable TCP nodelay, errno %d", errno); - this->client_->close(); - this->client_ = nullptr; + this->log_socket_error_("nodelay"); + this->cleanup_connection_(); + return; + } + err = this->client_->setblocking(false); + if (err != 0) { + this->log_socket_error_("non-blocking"); + this->cleanup_connection_(); return; } - ESP_LOGD(TAG, "Starting update from %s", this->client_->getpeername().c_str()); - this->status_set_warning(); -#ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); -#endif + this->log_start_("handshake"); - if (!this->readall_(buf, 5)) { - ESP_LOGW(TAG, "Reading magic bytes failed"); + if (!this->readall_(buf, 5, OTA_SOCKET_TIMEOUT_HANDSHAKE)) { + ESP_LOGW(TAG, "Read magic bytes failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } // 0x6C, 0x26, 0xF7, 0x5C, 0x45 if (buf[0] != 0x6C || buf[1] != 0x26 || buf[2] != 0xF7 || buf[3] != 0x5C || buf[4] != 0x45) { - ESP_LOGW(TAG, "Magic bytes do not match! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", buf[0], buf[1], buf[2], buf[3], - buf[4]); + ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", buf[0], buf[1], buf[2], buf[3], buf[4]); error_code = ota::OTA_RESPONSE_ERROR_MAGIC; goto error; // NOLINT(cppcoreguidelines-avoid-goto) } @@ -153,7 +154,7 @@ void ESPHomeOTAComponent::handle_() { // Read features - 1 byte if (!this->readall_(buf, 1)) { - ESP_LOGW(TAG, "Reading features failed"); + ESP_LOGW(TAG, "Read features failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } ota_features = buf[0]; // NOLINT @@ -232,7 +233,7 @@ void ESPHomeOTAComponent::handle_() { // Read size, 4 bytes MSB first if (!this->readall_(buf, 4)) { - ESP_LOGW(TAG, "Reading size failed"); + ESP_LOGW(TAG, "Read size failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } ota_size = 0; @@ -242,6 +243,17 @@ void ESPHomeOTAComponent::handle_() { } ESP_LOGV(TAG, "Size is %u bytes", ota_size); + // Now that we've passed authentication and are actually + // starting the update, set the warning status and notify + // listeners. This ensures that port scanners do not + // accidentally trigger the update process. + this->log_start_("update"); + this->status_set_warning(); +#ifdef USE_OTA_STATE_CALLBACK + this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); +#endif + + // This will block for a few seconds as it locks flash error_code = backend->begin(ota_size); if (error_code != ota::OTA_RESPONSE_OK) goto error; // NOLINT(cppcoreguidelines-avoid-goto) @@ -253,7 +265,7 @@ void ESPHomeOTAComponent::handle_() { // Read binary MD5, 32 bytes if (!this->readall_(buf, 32)) { - ESP_LOGW(TAG, "Reading binary MD5 checksum failed"); + ESP_LOGW(TAG, "Read MD5 checksum failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } sbuf[32] = '\0'; @@ -274,19 +286,19 @@ void ESPHomeOTAComponent::handle_() { delay(1); continue; } - ESP_LOGW(TAG, "Error receiving data for update, errno %d", errno); + ESP_LOGW(TAG, "Read error, errno %d", errno); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } else if (read == 0) { // $ man recv // "When a stream socket peer has performed an orderly shutdown, the return value will // be 0 (the traditional "end-of-file" return)." - ESP_LOGW(TAG, "Remote end closed connection"); + ESP_LOGW(TAG, "Remote closed connection"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } error_code = backend->write(buf, read); if (error_code != ota::OTA_RESPONSE_OK) { - ESP_LOGW(TAG, "Error writing binary data to flash!, error_code: %d", error_code); + ESP_LOGW(TAG, "Flash write error, code: %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } total += read; @@ -318,7 +330,7 @@ void ESPHomeOTAComponent::handle_() { error_code = backend->end(); if (error_code != ota::OTA_RESPONSE_OK) { - ESP_LOGW(TAG, "Error ending update! error_code: %d", error_code); + ESP_LOGW(TAG, "Error ending update! code: %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } @@ -328,12 +340,11 @@ void ESPHomeOTAComponent::handle_() { // Read ACK if (!this->readall_(buf, 1) || buf[0] != ota::OTA_RESPONSE_OK) { - ESP_LOGW(TAG, "Reading back acknowledgement failed"); + ESP_LOGW(TAG, "Read ack failed"); // do not go to error, this is not fatal } - this->client_->close(); - this->client_ = nullptr; + this->cleanup_connection_(); delay(10); ESP_LOGI(TAG, "Update complete"); this->status_clear_warning(); @@ -346,8 +357,7 @@ void ESPHomeOTAComponent::handle_() { error: buf[0] = static_cast(error_code); this->writeall_(buf, 1); - this->client_->close(); - this->client_ = nullptr; + this->cleanup_connection_(); if (backend != nullptr && update_started) { backend->abort(); @@ -359,13 +369,13 @@ error: #endif } -bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) { +bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len, uint16_t timeout) { uint32_t start = millis(); uint32_t at = 0; while (len - at > 0) { uint32_t now = millis(); - if (now - start > 1000) { - ESP_LOGW(TAG, "Timed out reading %d bytes of data", len); + if (now - start > timeout) { + ESP_LOGW(TAG, "Timeout reading %d bytes", len); return false; } @@ -376,7 +386,7 @@ bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) { delay(1); continue; } - ESP_LOGW(TAG, "Failed to read %d bytes of data, errno %d", len, errno); + ESP_LOGW(TAG, "Error reading %d bytes, errno %d", len, errno); return false; } else if (read == 0) { ESP_LOGW(TAG, "Remote closed connection"); @@ -390,13 +400,13 @@ bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) { return true; } -bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) { +bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len, uint16_t timeout) { uint32_t start = millis(); uint32_t at = 0; while (len - at > 0) { uint32_t now = millis(); - if (now - start > 1000) { - ESP_LOGW(TAG, "Timed out writing %d bytes of data", len); + if (now - start > timeout) { + ESP_LOGW(TAG, "Timeout writing %d bytes", len); return false; } @@ -407,7 +417,7 @@ bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) { delay(1); continue; } - ESP_LOGW(TAG, "Failed to write %d bytes of data, errno %d", len, errno); + ESP_LOGW(TAG, "Error writing %d bytes, errno %d", len, errno); return false; } else { at += written; @@ -421,5 +431,17 @@ bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) { float ESPHomeOTAComponent::get_setup_priority() const { return setup_priority::AFTER_WIFI; } uint16_t ESPHomeOTAComponent::get_port() const { return this->port_; } void ESPHomeOTAComponent::set_port(uint16_t port) { this->port_ = port; } + +void ESPHomeOTAComponent::log_socket_error_(const char *msg) { ESP_LOGW(TAG, "Socket %s: errno %d", msg, errno); } + +void ESPHomeOTAComponent::log_start_(const char *phase) { + ESP_LOGD(TAG, "Starting %s from %s", phase, this->client_->getpeername().c_str()); +} + +void ESPHomeOTAComponent::cleanup_connection_() { + this->client_->close(); + this->client_ = nullptr; +} + } // namespace esphome #endif diff --git a/esphome/components/esphome/ota/ota_esphome.h b/esphome/components/esphome/ota/ota_esphome.h index e0d09ff37e..5d58eefd2f 100644 --- a/esphome/components/esphome/ota/ota_esphome.h +++ b/esphome/components/esphome/ota/ota_esphome.h @@ -9,6 +9,8 @@ namespace esphome { +static constexpr uint16_t OTA_SOCKET_TIMEOUT_DATA = 2500; // milliseconds for data transfer + /// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. class ESPHomeOTAComponent : public ota::OTAComponent { public: @@ -28,8 +30,11 @@ class ESPHomeOTAComponent : public ota::OTAComponent { protected: void handle_(); - bool readall_(uint8_t *buf, size_t len); - bool writeall_(const uint8_t *buf, size_t len); + bool readall_(uint8_t *buf, size_t len, uint16_t timeout = OTA_SOCKET_TIMEOUT_DATA); + bool writeall_(const uint8_t *buf, size_t len, uint16_t timeout = OTA_SOCKET_TIMEOUT_DATA); + void log_socket_error_(const char *msg); + void log_start_(const char *phase); + void cleanup_connection_(); #ifdef USE_OTA_PASSWORD std::string password_;