1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-07 05:42:20 +01:00

[esphome] Fix OTA watchdog reset when port scanning

This commit is contained in:
J. Nick Koston
2025-08-09 14:01:49 -05:00
parent ff9ddb9d68
commit ea74a9ec8f
2 changed files with 67 additions and 40 deletions

View File

@@ -20,6 +20,7 @@ namespace esphome {
static const char *const TAG = "esphome.ota"; static const char *const TAG = "esphome.ota";
static constexpr u_int16_t OTA_BLOCK_SIZE = 8192; static constexpr u_int16_t OTA_BLOCK_SIZE = 8192;
static constexpr uint16_t OTA_SOCKET_TIMEOUT_HANDSHAKE = 1000; // milliseconds for initial handshake
void ESPHomeOTAComponent::setup() { void ESPHomeOTAComponent::setup() {
#ifdef USE_OTA_STATE_CALLBACK #ifdef USE_OTA_STATE_CALLBACK
@@ -28,19 +29,19 @@ void ESPHomeOTAComponent::setup() {
this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections
if (this->server_ == nullptr) { if (this->server_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket"); this->log_socket_error_("creation");
this->mark_failed(); this->mark_failed();
return; return;
} }
int enable = 1; int enable = 1;
int err = this->server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); int err = this->server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err); this->log_socket_error_("reuseaddr");
// we can still continue // we can still continue
} }
err = this->server_->setblocking(false); err = this->server_->setblocking(false);
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err); this->log_socket_error_("non-blocking");
this->mark_failed(); this->mark_failed();
return; return;
} }
@@ -49,21 +50,21 @@ void ESPHomeOTAComponent::setup() {
socklen_t sl = socket::set_sockaddr_any((struct sockaddr *) &server, sizeof(server), this->port_); socklen_t sl = socket::set_sockaddr_any((struct sockaddr *) &server, sizeof(server), this->port_);
if (sl == 0) { if (sl == 0) {
ESP_LOGW(TAG, "Socket unable to set sockaddr: errno %d", errno); this->log_socket_error_("set sockaddr");
this->mark_failed(); this->mark_failed();
return; return;
} }
err = this->server_->bind((struct sockaddr *) &server, sizeof(server)); err = this->server_->bind((struct sockaddr *) &server, sizeof(server));
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno); this->log_socket_error_("bind");
this->mark_failed(); this->mark_failed();
return; return;
} }
err = this->server_->listen(4); err = this->server_->listen(4);
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno); this->log_socket_error_("listen");
this->mark_failed(); this->mark_failed();
return; return;
} }
@@ -120,26 +121,26 @@ void ESPHomeOTAComponent::handle_() {
int enable = 1; int enable = 1;
int err = this->client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); int err = this->client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) { if (err != 0) {
ESP_LOGW(TAG, "Socket could not enable TCP nodelay, errno %d", errno); this->log_socket_error_("nodelay");
this->client_->close(); this->cleanup_connection_();
this->client_ = nullptr; return;
}
err = this->client_->setblocking(false);
if (err != 0) {
this->log_socket_error_("non-blocking");
this->cleanup_connection_();
return; return;
} }
ESP_LOGD(TAG, "Starting update from %s", this->client_->getpeername().c_str()); this->log_start_("handshake");
this->status_set_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
#endif
if (!this->readall_(buf, 5)) { if (!this->readall_(buf, 5, OTA_SOCKET_TIMEOUT_HANDSHAKE)) {
ESP_LOGW(TAG, "Reading magic bytes failed"); ESP_LOGW(TAG, "Read magic bytes failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
// 0x6C, 0x26, 0xF7, 0x5C, 0x45 // 0x6C, 0x26, 0xF7, 0x5C, 0x45
if (buf[0] != 0x6C || buf[1] != 0x26 || buf[2] != 0xF7 || buf[3] != 0x5C || buf[4] != 0x45) { if (buf[0] != 0x6C || buf[1] != 0x26 || buf[2] != 0xF7 || buf[3] != 0x5C || buf[4] != 0x45) {
ESP_LOGW(TAG, "Magic bytes do not match! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", buf[0], buf[1], buf[2], buf[3], ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", buf[0], buf[1], buf[2], buf[3], buf[4]);
buf[4]);
error_code = ota::OTA_RESPONSE_ERROR_MAGIC; error_code = ota::OTA_RESPONSE_ERROR_MAGIC;
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
@@ -153,7 +154,7 @@ void ESPHomeOTAComponent::handle_() {
// Read features - 1 byte // Read features - 1 byte
if (!this->readall_(buf, 1)) { if (!this->readall_(buf, 1)) {
ESP_LOGW(TAG, "Reading features failed"); ESP_LOGW(TAG, "Read features failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
ota_features = buf[0]; // NOLINT ota_features = buf[0]; // NOLINT
@@ -232,7 +233,7 @@ void ESPHomeOTAComponent::handle_() {
// Read size, 4 bytes MSB first // Read size, 4 bytes MSB first
if (!this->readall_(buf, 4)) { if (!this->readall_(buf, 4)) {
ESP_LOGW(TAG, "Reading size failed"); ESP_LOGW(TAG, "Read size failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
ota_size = 0; ota_size = 0;
@@ -242,6 +243,17 @@ void ESPHomeOTAComponent::handle_() {
} }
ESP_LOGV(TAG, "Size is %u bytes", ota_size); ESP_LOGV(TAG, "Size is %u bytes", ota_size);
// Now that we've passed authentication and are actually
// starting the update, set the warning status and notify
// listeners. This ensures that port scanners do not
// accidentally trigger the update process.
this->log_start_("update");
this->status_set_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
#endif
// This will block for a few seconds as it locks flash
error_code = backend->begin(ota_size); error_code = backend->begin(ota_size);
if (error_code != ota::OTA_RESPONSE_OK) if (error_code != ota::OTA_RESPONSE_OK)
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
@@ -253,7 +265,7 @@ void ESPHomeOTAComponent::handle_() {
// Read binary MD5, 32 bytes // Read binary MD5, 32 bytes
if (!this->readall_(buf, 32)) { if (!this->readall_(buf, 32)) {
ESP_LOGW(TAG, "Reading binary MD5 checksum failed"); ESP_LOGW(TAG, "Read MD5 checksum failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
sbuf[32] = '\0'; sbuf[32] = '\0';
@@ -274,19 +286,19 @@ void ESPHomeOTAComponent::handle_() {
delay(1); delay(1);
continue; continue;
} }
ESP_LOGW(TAG, "Error receiving data for update, errno %d", errno); ESP_LOGW(TAG, "Read error, errno %d", errno);
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} else if (read == 0) { } else if (read == 0) {
// $ man recv // $ man recv
// "When a stream socket peer has performed an orderly shutdown, the return value will // "When a stream socket peer has performed an orderly shutdown, the return value will
// be 0 (the traditional "end-of-file" return)." // be 0 (the traditional "end-of-file" return)."
ESP_LOGW(TAG, "Remote end closed connection"); ESP_LOGW(TAG, "Remote closed connection");
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
error_code = backend->write(buf, read); error_code = backend->write(buf, read);
if (error_code != ota::OTA_RESPONSE_OK) { if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Error writing binary data to flash!, error_code: %d", error_code); ESP_LOGW(TAG, "Flash write error, code: %d", error_code);
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
total += read; total += read;
@@ -318,7 +330,7 @@ void ESPHomeOTAComponent::handle_() {
error_code = backend->end(); error_code = backend->end();
if (error_code != ota::OTA_RESPONSE_OK) { if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Error ending update! error_code: %d", error_code); ESP_LOGW(TAG, "Error ending update! code: %d", error_code);
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
@@ -328,12 +340,11 @@ void ESPHomeOTAComponent::handle_() {
// Read ACK // Read ACK
if (!this->readall_(buf, 1) || buf[0] != ota::OTA_RESPONSE_OK) { if (!this->readall_(buf, 1) || buf[0] != ota::OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Reading back acknowledgement failed"); ESP_LOGW(TAG, "Read ack failed");
// do not go to error, this is not fatal // do not go to error, this is not fatal
} }
this->client_->close(); this->cleanup_connection_();
this->client_ = nullptr;
delay(10); delay(10);
ESP_LOGI(TAG, "Update complete"); ESP_LOGI(TAG, "Update complete");
this->status_clear_warning(); this->status_clear_warning();
@@ -346,8 +357,7 @@ void ESPHomeOTAComponent::handle_() {
error: error:
buf[0] = static_cast<uint8_t>(error_code); buf[0] = static_cast<uint8_t>(error_code);
this->writeall_(buf, 1); this->writeall_(buf, 1);
this->client_->close(); this->cleanup_connection_();
this->client_ = nullptr;
if (backend != nullptr && update_started) { if (backend != nullptr && update_started) {
backend->abort(); backend->abort();
@@ -359,13 +369,13 @@ error:
#endif #endif
} }
bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) { bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len, uint16_t timeout) {
uint32_t start = millis(); uint32_t start = millis();
uint32_t at = 0; uint32_t at = 0;
while (len - at > 0) { while (len - at > 0) {
uint32_t now = millis(); uint32_t now = millis();
if (now - start > 1000) { if (now - start > timeout) {
ESP_LOGW(TAG, "Timed out reading %d bytes of data", len); ESP_LOGW(TAG, "Timeout reading %d bytes", len);
return false; return false;
} }
@@ -376,7 +386,7 @@ bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) {
delay(1); delay(1);
continue; continue;
} }
ESP_LOGW(TAG, "Failed to read %d bytes of data, errno %d", len, errno); ESP_LOGW(TAG, "Error reading %d bytes, errno %d", len, errno);
return false; return false;
} else if (read == 0) { } else if (read == 0) {
ESP_LOGW(TAG, "Remote closed connection"); ESP_LOGW(TAG, "Remote closed connection");
@@ -390,13 +400,13 @@ bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) {
return true; return true;
} }
bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) { bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len, uint16_t timeout) {
uint32_t start = millis(); uint32_t start = millis();
uint32_t at = 0; uint32_t at = 0;
while (len - at > 0) { while (len - at > 0) {
uint32_t now = millis(); uint32_t now = millis();
if (now - start > 1000) { if (now - start > timeout) {
ESP_LOGW(TAG, "Timed out writing %d bytes of data", len); ESP_LOGW(TAG, "Timeout writing %d bytes", len);
return false; return false;
} }
@@ -407,7 +417,7 @@ bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) {
delay(1); delay(1);
continue; continue;
} }
ESP_LOGW(TAG, "Failed to write %d bytes of data, errno %d", len, errno); ESP_LOGW(TAG, "Error writing %d bytes, errno %d", len, errno);
return false; return false;
} else { } else {
at += written; at += written;
@@ -421,5 +431,17 @@ bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) {
float ESPHomeOTAComponent::get_setup_priority() const { return setup_priority::AFTER_WIFI; } float ESPHomeOTAComponent::get_setup_priority() const { return setup_priority::AFTER_WIFI; }
uint16_t ESPHomeOTAComponent::get_port() const { return this->port_; } uint16_t ESPHomeOTAComponent::get_port() const { return this->port_; }
void ESPHomeOTAComponent::set_port(uint16_t port) { this->port_ = port; } void ESPHomeOTAComponent::set_port(uint16_t port) { this->port_ = port; }
void ESPHomeOTAComponent::log_socket_error_(const char *msg) { ESP_LOGW(TAG, "Socket %s: errno %d", msg, errno); }
void ESPHomeOTAComponent::log_start_(const char *phase) {
ESP_LOGD(TAG, "Starting %s from %s", phase, this->client_->getpeername().c_str());
}
void ESPHomeOTAComponent::cleanup_connection_() {
this->client_->close();
this->client_ = nullptr;
}
} // namespace esphome } // namespace esphome
#endif #endif

View File

@@ -9,6 +9,8 @@
namespace esphome { namespace esphome {
static constexpr uint16_t OTA_SOCKET_TIMEOUT_DATA = 2500; // milliseconds for data transfer
/// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. /// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA.
class ESPHomeOTAComponent : public ota::OTAComponent { class ESPHomeOTAComponent : public ota::OTAComponent {
public: public:
@@ -28,8 +30,11 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
protected: protected:
void handle_(); void handle_();
bool readall_(uint8_t *buf, size_t len); bool readall_(uint8_t *buf, size_t len, uint16_t timeout = OTA_SOCKET_TIMEOUT_DATA);
bool writeall_(const uint8_t *buf, size_t len); bool writeall_(const uint8_t *buf, size_t len, uint16_t timeout = OTA_SOCKET_TIMEOUT_DATA);
void log_socket_error_(const char *msg);
void log_start_(const char *phase);
void cleanup_connection_();
#ifdef USE_OTA_PASSWORD #ifdef USE_OTA_PASSWORD
std::string password_; std::string password_;