From f46c499c4e4598b95d24f546a2d2cbaa13206850 Mon Sep 17 00:00:00 2001 From: Keith Burzinski Date: Wed, 15 May 2024 21:01:09 -0500 Subject: [PATCH] Separate `OTABackend` from OTA component (#6459) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/__main__.py | 25 +- .../esp32_ble_tracker/esp32_ble_tracker.cpp | 13 +- esphome/components/esphome/ota/__init__.py | 72 ++++++ .../ota/ota_esphome.cpp} | 215 ++++++++---------- .../ota/ota_esphome.h} | 62 +---- esphome/components/ota/__init__.py | 106 ++++----- esphome/components/ota/automation.h | 19 +- esphome/components/ota/ota_backend.cpp | 20 ++ esphome/components/ota/ota_backend.h | 79 ++++++- .../ota/ota_backend_arduino_esp32.cpp | 5 +- .../ota/ota_backend_arduino_esp32.h | 6 +- .../ota/ota_backend_arduino_esp8266.cpp | 9 +- .../ota/ota_backend_arduino_esp8266.h | 5 +- .../ota/ota_backend_arduino_libretiny.cpp | 9 +- .../ota/ota_backend_arduino_libretiny.h | 5 +- .../ota/ota_backend_arduino_rp2040.cpp | 9 +- .../ota/ota_backend_arduino_rp2040.h | 7 +- .../components/ota/ota_backend_esp_idf.cpp | 13 +- esphome/components/ota/ota_backend_esp_idf.h | 8 +- .../components/safe_mode/button/__init__.py | 16 +- .../safe_mode/button/safe_mode_button.cpp | 2 +- .../safe_mode/button/safe_mode_button.h | 8 +- .../components/safe_mode/switch/__init__.py | 14 +- .../safe_mode/switch/safe_mode_switch.cpp | 4 +- .../safe_mode/switch/safe_mode_switch.h | 8 +- esphome/cpp_helpers.py | 21 +- esphome/wizard.py | 5 +- tests/components/ota/common.yaml | 51 +++-- tests/components/safe_mode/common.yaml | 2 + tests/dummy_main.cpp | 4 +- tests/test1.yaml | 47 ++-- tests/test11.5.yaml | 1 + tests/test2.yaml | 7 +- tests/test3.1.yaml | 3 +- tests/test3.yaml | 7 +- tests/test4.yaml | 5 +- tests/test5.yaml | 1 + tests/test6.yaml | 1 + tests/test9.1.yaml | 1 + tests/test9.yaml | 1 + 40 files changed, 505 insertions(+), 391 deletions(-) create mode 100644 esphome/components/esphome/ota/__init__.py rename esphome/components/{ota/ota_component.cpp => esphome/ota/ota_esphome.cpp} (66%) rename esphome/components/{ota/ota_component.h => esphome/ota/ota_esphome.h} (50%) create mode 100644 esphome/components/ota/ota_backend.cpp diff --git a/esphome/__main__.py b/esphome/__main__.py index daf74eebb0..9930119c86 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -18,22 +18,23 @@ from esphome.const import ( CONF_BAUD_RATE, CONF_BROKER, CONF_DEASSERT_RTS_DTR, + CONF_DISABLED, + CONF_ESPHOME, CONF_LOGGER, + CONF_MDNS, + CONF_MQTT, CONF_NAME, CONF_OTA, - CONF_MQTT, - CONF_MDNS, - CONF_DISABLED, CONF_PASSWORD, - CONF_PORT, - CONF_ESPHOME, + CONF_PLATFORM, CONF_PLATFORMIO_OPTIONS, + CONF_PORT, CONF_SUBSTITUTIONS, PLATFORM_BK72XX, - PLATFORM_RTL87XX, PLATFORM_ESP32, PLATFORM_ESP8266, PLATFORM_RP2040, + PLATFORM_RTL87XX, SECRETS_FILES, ) from esphome.core import CORE, EsphomeError, coroutine @@ -330,15 +331,19 @@ def upload_program(config, args, host): return 1 # Unknown target platform - if CONF_OTA not in config: + ota_conf = {} + for ota_item in config.get(CONF_OTA, []): + if ota_item[CONF_PLATFORM] == CONF_ESPHOME: + ota_conf = ota_item + break + + if not ota_conf: raise EsphomeError( - "Cannot upload Over the Air as the config does not include the ota: " - "component" + f"Cannot upload Over the Air as the {CONF_OTA} configuration is not present or does not include {CONF_PLATFORM}: {CONF_ESPHOME}" ) from esphome import espota2 - ota_conf = config[CONF_OTA] remote_port = ota_conf[CONF_PORT] password = ota_conf.get(CONF_PASSWORD, "") diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp index 4ae7929ded..d154d4e519 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp @@ -18,7 +18,7 @@ #include #ifdef USE_OTA -#include "esphome/components/ota/ota_component.h" +#include "esphome/components/ota/ota_backend.h" #endif #ifdef USE_ARDUINO @@ -61,11 +61,12 @@ void ESP32BLETracker::setup() { this->scanner_idle_ = true; #ifdef USE_OTA - ota::global_ota_component->add_on_state_callback([this](ota::OTAState state, float progress, uint8_t error) { - if (state == ota::OTA_STARTED) { - this->stop_scan(); - } - }); + ota::get_global_ota_callback()->add_on_state_callback( + [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { + if (state == ota::OTA_STARTED) { + this->stop_scan(); + } + }); #endif } diff --git a/esphome/components/esphome/ota/__init__.py b/esphome/components/esphome/ota/__init__.py new file mode 100644 index 0000000000..abe9323b53 --- /dev/null +++ b/esphome/components/esphome/ota/__init__.py @@ -0,0 +1,72 @@ +from esphome.cpp_generator import RawExpression +import esphome.codegen as cg +import esphome.config_validation as cv +from esphome.components.ota import BASE_OTA_SCHEMA, ota_to_code, OTAComponent +from esphome.const import ( + CONF_ID, + CONF_NUM_ATTEMPTS, + CONF_OTA, + CONF_PASSWORD, + CONF_PORT, + CONF_REBOOT_TIMEOUT, + CONF_SAFE_MODE, + CONF_VERSION, + KEY_PAST_SAFE_MODE, +) +from esphome.core import CORE, coroutine_with_priority + + +CODEOWNERS = ["@esphome/core"] +AUTO_LOAD = ["md5", "socket"] +DEPENDENCIES = ["network"] + +esphome = cg.esphome_ns.namespace("esphome") +ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", OTAComponent) + + +CONFIG_SCHEMA = ( + cv.Schema( + { + cv.GenerateID(): cv.declare_id(ESPHomeOTAComponent), + cv.Optional(CONF_SAFE_MODE, default=True): cv.boolean, + cv.Optional(CONF_VERSION, default=2): cv.one_of(1, 2, int=True), + cv.SplitDefault( + CONF_PORT, + esp8266=8266, + esp32=3232, + rp2040=2040, + bk72xx=8892, + rtl87xx=8892, + ): cv.port, + cv.Optional(CONF_PASSWORD): cv.string, + cv.Optional( + CONF_REBOOT_TIMEOUT, default="5min" + ): cv.positive_time_period_milliseconds, + cv.Optional(CONF_NUM_ATTEMPTS, default="10"): cv.positive_not_null_int, + } + ) + .extend(BASE_OTA_SCHEMA) + .extend(cv.COMPONENT_SCHEMA) +) + + +@coroutine_with_priority(50.0) +async def to_code(config): + CORE.data[CONF_OTA] = {} + + var = cg.new_Pvariable(config[CONF_ID]) + await ota_to_code(var, config) + cg.add(var.set_port(config[CONF_PORT])) + if CONF_PASSWORD in config: + cg.add(var.set_auth_password(config[CONF_PASSWORD])) + cg.add_define("USE_OTA_PASSWORD") + cg.add_define("USE_OTA_VERSION", config[CONF_VERSION]) + + await cg.register_component(var, config) + + if config[CONF_SAFE_MODE]: + condition = var.should_enter_safe_mode( + config[CONF_NUM_ATTEMPTS], config[CONF_REBOOT_TIMEOUT] + ) + cg.add(RawExpression(f"if ({condition}) return")) + CORE.data[CONF_OTA][KEY_PAST_SAFE_MODE] = True diff --git a/esphome/components/ota/ota_component.cpp b/esphome/components/esphome/ota/ota_esphome.cpp similarity index 66% rename from esphome/components/ota/ota_component.cpp rename to esphome/components/esphome/ota/ota_esphome.cpp index 15af14ff1a..f2f1cfc6a8 100644 --- a/esphome/components/ota/ota_component.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -1,55 +1,34 @@ -#include "ota_component.h" -#include "ota_backend.h" -#include "ota_backend_arduino_esp32.h" -#include "ota_backend_arduino_esp8266.h" -#include "ota_backend_arduino_rp2040.h" -#include "ota_backend_arduino_libretiny.h" -#include "ota_backend_esp_idf.h" +#include "ota_esphome.h" -#include "esphome/core/log.h" -#include "esphome/core/application.h" -#include "esphome/core/hal.h" -#include "esphome/core/util.h" #include "esphome/components/md5/md5.h" #include "esphome/components/network/util.h" +#include "esphome/components/ota/ota_backend.h" +#include "esphome/components/ota/ota_backend_arduino_esp32.h" +#include "esphome/components/ota/ota_backend_arduino_esp8266.h" +#include "esphome/components/ota/ota_backend_arduino_libretiny.h" +#include "esphome/components/ota/ota_backend_arduino_rp2040.h" +#include "esphome/components/ota/ota_backend_esp_idf.h" +#include "esphome/core/application.h" +#include "esphome/core/hal.h" +#include "esphome/core/log.h" +#include "esphome/core/util.h" #include #include namespace esphome { -namespace ota { -static const char *const TAG = "ota"; +static const char *const TAG = "esphome.ota"; static constexpr u_int16_t OTA_BLOCK_SIZE = 8192; -OTAComponent *global_ota_component = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) - -std::unique_ptr make_ota_backend() { -#ifdef USE_ARDUINO -#ifdef USE_ESP8266 - return make_unique(); -#endif // USE_ESP8266 -#ifdef USE_ESP32 - return make_unique(); -#endif // USE_ESP32 -#endif // USE_ARDUINO -#ifdef USE_ESP_IDF - return make_unique(); -#endif // USE_ESP_IDF -#ifdef USE_RP2040 - return make_unique(); -#endif // USE_RP2040 -#ifdef USE_LIBRETINY - return make_unique(); +void ESPHomeOTAComponent::setup() { +#ifdef USE_OTA_STATE_CALLBACK + ota::register_ota_platform(this); #endif -} -OTAComponent::OTAComponent() { global_ota_component = this; } - -void OTAComponent::setup() { server_ = socket::socket_ip(SOCK_STREAM, 0); if (server_ == nullptr) { - ESP_LOGW(TAG, "Could not create socket."); + ESP_LOGW(TAG, "Could not create socket"); this->mark_failed(); return; } @@ -88,41 +67,39 @@ void OTAComponent::setup() { this->mark_failed(); return; } - - this->dump_config(); } -void OTAComponent::dump_config() { - ESP_LOGCONFIG(TAG, "Over-The-Air Updates:"); +void ESPHomeOTAComponent::dump_config() { + ESP_LOGCONFIG(TAG, "Over-The-Air updates:"); ESP_LOGCONFIG(TAG, " Address: %s:%u", network::get_use_address().c_str(), this->port_); + ESP_LOGCONFIG(TAG, " Version: %d", USE_OTA_VERSION); #ifdef USE_OTA_PASSWORD if (!this->password_.empty()) { - ESP_LOGCONFIG(TAG, " Using Password."); + ESP_LOGCONFIG(TAG, " Password configured"); } #endif - ESP_LOGCONFIG(TAG, " OTA version: %d.", USE_OTA_VERSION); if (this->has_safe_mode_ && this->safe_mode_rtc_value_ > 1 && - this->safe_mode_rtc_value_ != esphome::ota::OTAComponent::ENTER_SAFE_MODE_MAGIC) { - ESP_LOGW(TAG, "Last Boot was an unhandled reset, will proceed to safe mode in %" PRIu32 " restarts", + this->safe_mode_rtc_value_ != ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC) { + ESP_LOGW(TAG, "Last reset occurred too quickly; safe mode will be invoked in %" PRIu32 " restarts", this->safe_mode_num_attempts_ - this->safe_mode_rtc_value_); } } -void OTAComponent::loop() { +void ESPHomeOTAComponent::loop() { this->handle_(); if (this->has_safe_mode_ && (millis() - this->safe_mode_start_time_) > this->safe_mode_enable_time_) { this->has_safe_mode_ = false; // successful boot, reset counter - ESP_LOGI(TAG, "Boot seems successful, resetting boot loop counter."); + ESP_LOGI(TAG, "Boot seems successful; resetting boot loop counter"); this->clean_rtc(); } } static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01; -void OTAComponent::handle_() { - OTAResponseTypes error_code = OTA_RESPONSE_ERROR_UNKNOWN; +void ESPHomeOTAComponent::handle_() { + ota::OTAResponseTypes error_code = ota::OTA_RESPONSE_ERROR_UNKNOWN; bool update_started = false; size_t total = 0; uint32_t last_progress = 0; @@ -130,7 +107,7 @@ void OTAComponent::handle_() { char *sbuf = reinterpret_cast(buf); size_t ota_size; uint8_t ota_features; - std::unique_ptr backend; + std::unique_ptr backend; (void) ota_features; #if USE_OTA_VERSION == 2 size_t size_acknowledged = 0; @@ -147,54 +124,54 @@ void OTAComponent::handle_() { int enable = 1; int err = client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); if (err != 0) { - ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno); + ESP_LOGW(TAG, "Socket could not enable TCP nodelay, errno %d", errno); return; } - ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_->getpeername().c_str()); + ESP_LOGD(TAG, "Starting update from %s...", this->client_->getpeername().c_str()); this->status_set_warning(); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_STARTED, 0.0f, 0); + this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); #endif if (!this->readall_(buf, 5)) { - ESP_LOGW(TAG, "Reading magic bytes failed!"); + ESP_LOGW(TAG, "Reading magic bytes failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } // 0x6C, 0x26, 0xF7, 0x5C, 0x45 if (buf[0] != 0x6C || buf[1] != 0x26 || buf[2] != 0xF7 || buf[3] != 0x5C || buf[4] != 0x45) { ESP_LOGW(TAG, "Magic bytes do not match! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", buf[0], buf[1], buf[2], buf[3], buf[4]); - error_code = OTA_RESPONSE_ERROR_MAGIC; + error_code = ota::OTA_RESPONSE_ERROR_MAGIC; goto error; // NOLINT(cppcoreguidelines-avoid-goto) } // Send OK and version - 2 bytes - buf[0] = OTA_RESPONSE_OK; + buf[0] = ota::OTA_RESPONSE_OK; buf[1] = USE_OTA_VERSION; this->writeall_(buf, 2); - backend = make_ota_backend(); + backend = ota::make_ota_backend(); // Read features - 1 byte if (!this->readall_(buf, 1)) { - ESP_LOGW(TAG, "Reading features failed!"); + ESP_LOGW(TAG, "Reading features failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } ota_features = buf[0]; // NOLINT - ESP_LOGV(TAG, "OTA features is 0x%02X", ota_features); + ESP_LOGV(TAG, "Features: 0x%02X", ota_features); // Acknowledge header - 1 byte - buf[0] = OTA_RESPONSE_HEADER_OK; + buf[0] = ota::OTA_RESPONSE_HEADER_OK; if ((ota_features & FEATURE_SUPPORTS_COMPRESSION) != 0 && backend->supports_compression()) { - buf[0] = OTA_RESPONSE_SUPPORTS_COMPRESSION; + buf[0] = ota::OTA_RESPONSE_SUPPORTS_COMPRESSION; } this->writeall_(buf, 1); #ifdef USE_OTA_PASSWORD if (!this->password_.empty()) { - buf[0] = OTA_RESPONSE_REQUEST_AUTH; + buf[0] = ota::OTA_RESPONSE_REQUEST_AUTH; this->writeall_(buf, 1); md5::MD5Digest md5{}; md5.init(); @@ -206,7 +183,7 @@ void OTAComponent::handle_() { // Send nonce, 32 bytes hex MD5 if (!this->writeall_(reinterpret_cast(sbuf), 32)) { - ESP_LOGW(TAG, "Auth: Writing nonce failed!"); + ESP_LOGW(TAG, "Auth: Writing nonce failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } @@ -218,7 +195,7 @@ void OTAComponent::handle_() { // Receive cnonce, 32 bytes hex MD5 if (!this->readall_(buf, 32)) { - ESP_LOGW(TAG, "Auth: Reading cnonce failed!"); + ESP_LOGW(TAG, "Auth: Reading cnonce failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } sbuf[32] = '\0'; @@ -233,7 +210,7 @@ void OTAComponent::handle_() { // Receive result, 32 bytes hex MD5 if (!this->readall_(buf + 64, 32)) { - ESP_LOGW(TAG, "Auth: Reading response failed!"); + ESP_LOGW(TAG, "Auth: Reading response failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } sbuf[64 + 32] = '\0'; @@ -244,20 +221,20 @@ void OTAComponent::handle_() { matches = matches && buf[i] == buf[64 + i]; if (!matches) { - ESP_LOGW(TAG, "Auth failed! Passwords do not match!"); - error_code = OTA_RESPONSE_ERROR_AUTH_INVALID; + ESP_LOGW(TAG, "Auth failed! Passwords do not match"); + error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID; goto error; // NOLINT(cppcoreguidelines-avoid-goto) } } #endif // USE_OTA_PASSWORD // Acknowledge auth OK - 1 byte - buf[0] = OTA_RESPONSE_AUTH_OK; + buf[0] = ota::OTA_RESPONSE_AUTH_OK; this->writeall_(buf, 1); // Read size, 4 bytes MSB first if (!this->readall_(buf, 4)) { - ESP_LOGW(TAG, "Reading size failed!"); + ESP_LOGW(TAG, "Reading size failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } ota_size = 0; @@ -265,20 +242,20 @@ void OTAComponent::handle_() { ota_size <<= 8; ota_size |= buf[i]; } - ESP_LOGV(TAG, "OTA size is %u bytes", ota_size); + ESP_LOGV(TAG, "Size is %u bytes", ota_size); error_code = backend->begin(ota_size); - if (error_code != OTA_RESPONSE_OK) + if (error_code != ota::OTA_RESPONSE_OK) goto error; // NOLINT(cppcoreguidelines-avoid-goto) update_started = true; // Acknowledge prepare OK - 1 byte - buf[0] = OTA_RESPONSE_UPDATE_PREPARE_OK; + buf[0] = ota::OTA_RESPONSE_UPDATE_PREPARE_OK; this->writeall_(buf, 1); // Read binary MD5, 32 bytes if (!this->readall_(buf, 32)) { - ESP_LOGW(TAG, "Reading binary MD5 checksum failed!"); + ESP_LOGW(TAG, "Reading binary MD5 checksum failed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } sbuf[32] = '\0'; @@ -286,7 +263,7 @@ void OTAComponent::handle_() { backend->set_update_md5(sbuf); // Acknowledge MD5 OK - 1 byte - buf[0] = OTA_RESPONSE_BIN_MD5_OK; + buf[0] = ota::OTA_RESPONSE_BIN_MD5_OK; this->writeall_(buf, 1); while (total < ota_size) { @@ -299,7 +276,7 @@ void OTAComponent::handle_() { delay(1); continue; } - ESP_LOGW(TAG, "Error receiving data for update, errno: %d", errno); + ESP_LOGW(TAG, "Error receiving data for update, errno %d", errno); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } else if (read == 0) { // $ man recv @@ -310,14 +287,14 @@ void OTAComponent::handle_() { } error_code = backend->write(buf, read); - if (error_code != OTA_RESPONSE_OK) { + if (error_code != ota::OTA_RESPONSE_OK) { ESP_LOGW(TAG, "Error writing binary data to flash!, error_code: %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } total += read; #if USE_OTA_VERSION == 2 while (size_acknowledged + OTA_BLOCK_SIZE <= total || (total == ota_size && size_acknowledged < ota_size)) { - buf[0] = OTA_RESPONSE_CHUNK_OK; + buf[0] = ota::OTA_RESPONSE_CHUNK_OK; this->writeall_(buf, 1); size_acknowledged += OTA_BLOCK_SIZE; } @@ -327,9 +304,9 @@ void OTAComponent::handle_() { if (now - last_progress > 1000) { last_progress = now; float percentage = (total * 100.0f) / ota_size; - ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); + ESP_LOGD(TAG, "Progress: %0.1f%%", percentage); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_IN_PROGRESS, percentage, 0); + this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0); #endif // feed watchdog and give other tasks a chance to run App.feed_wdt(); @@ -338,32 +315,32 @@ void OTAComponent::handle_() { } // Acknowledge receive OK - 1 byte - buf[0] = OTA_RESPONSE_RECEIVE_OK; + buf[0] = ota::OTA_RESPONSE_RECEIVE_OK; this->writeall_(buf, 1); error_code = backend->end(); - if (error_code != OTA_RESPONSE_OK) { - ESP_LOGW(TAG, "Error ending OTA!, error_code: %d", error_code); + if (error_code != ota::OTA_RESPONSE_OK) { + ESP_LOGW(TAG, "Error ending update! error_code: %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } // Acknowledge Update end OK - 1 byte - buf[0] = OTA_RESPONSE_UPDATE_END_OK; + buf[0] = ota::OTA_RESPONSE_UPDATE_END_OK; this->writeall_(buf, 1); // Read ACK - if (!this->readall_(buf, 1) || buf[0] != OTA_RESPONSE_OK) { - ESP_LOGW(TAG, "Reading back acknowledgement failed!"); + if (!this->readall_(buf, 1) || buf[0] != ota::OTA_RESPONSE_OK) { + ESP_LOGW(TAG, "Reading back acknowledgement failed"); // do not go to error, this is not fatal } this->client_->close(); this->client_ = nullptr; delay(10); - ESP_LOGI(TAG, "OTA update finished!"); + ESP_LOGI(TAG, "Update complete"); this->status_clear_warning(); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_COMPLETED, 100.0f, 0); + this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0); #endif delay(100); // NOLINT App.safe_reboot(); @@ -380,11 +357,11 @@ error: this->status_momentary_error("onerror", 5000); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_ERROR, 0.0f, static_cast(error_code)); + this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast(error_code)); #endif } -bool OTAComponent::readall_(uint8_t *buf, size_t len) { +bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) { uint32_t start = millis(); uint32_t at = 0; while (len - at > 0) { @@ -401,7 +378,7 @@ bool OTAComponent::readall_(uint8_t *buf, size_t len) { delay(1); continue; } - ESP_LOGW(TAG, "Failed to read %d bytes of data, errno: %d", len, errno); + ESP_LOGW(TAG, "Failed to read %d bytes of data, errno %d", len, errno); return false; } else if (read == 0) { ESP_LOGW(TAG, "Remote closed connection"); @@ -415,7 +392,7 @@ bool OTAComponent::readall_(uint8_t *buf, size_t len) { return true; } -bool OTAComponent::writeall_(const uint8_t *buf, size_t len) { +bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) { uint32_t start = millis(); uint32_t at = 0; while (len - at > 0) { @@ -432,7 +409,7 @@ bool OTAComponent::writeall_(const uint8_t *buf, size_t len) { delay(1); continue; } - ESP_LOGW(TAG, "Failed to write %d bytes of data, errno: %d", len, errno); + ESP_LOGW(TAG, "Failed to write %d bytes of data, errno %d", len, errno); return false; } else { at += written; @@ -443,31 +420,31 @@ bool OTAComponent::writeall_(const uint8_t *buf, size_t len) { return true; } -float OTAComponent::get_setup_priority() const { return setup_priority::AFTER_WIFI; } -uint16_t OTAComponent::get_port() const { return this->port_; } -void OTAComponent::set_port(uint16_t port) { this->port_ = port; } +float ESPHomeOTAComponent::get_setup_priority() const { return setup_priority::AFTER_WIFI; } +uint16_t ESPHomeOTAComponent::get_port() const { return this->port_; } +void ESPHomeOTAComponent::set_port(uint16_t port) { this->port_ = port; } -void OTAComponent::set_safe_mode_pending(const bool &pending) { +void ESPHomeOTAComponent::set_safe_mode_pending(const bool &pending) { if (!this->has_safe_mode_) return; uint32_t current_rtc = this->read_rtc_(); - if (pending && current_rtc != esphome::ota::OTAComponent::ENTER_SAFE_MODE_MAGIC) { - ESP_LOGI(TAG, "Device will enter safe mode on next boot."); - this->write_rtc_(esphome::ota::OTAComponent::ENTER_SAFE_MODE_MAGIC); + if (pending && current_rtc != ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC) { + ESP_LOGI(TAG, "Device will enter safe mode on next boot"); + this->write_rtc_(ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC); } - if (!pending && current_rtc == esphome::ota::OTAComponent::ENTER_SAFE_MODE_MAGIC) { + if (!pending && current_rtc == ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC) { ESP_LOGI(TAG, "Safe mode pending has been cleared"); this->clean_rtc(); } } -bool OTAComponent::get_safe_mode_pending() { - return this->has_safe_mode_ && this->read_rtc_() == esphome::ota::OTAComponent::ENTER_SAFE_MODE_MAGIC; +bool ESPHomeOTAComponent::get_safe_mode_pending() { + return this->has_safe_mode_ && this->read_rtc_() == ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC; } -bool OTAComponent::should_enter_safe_mode(uint8_t num_attempts, uint32_t enable_time) { +bool ESPHomeOTAComponent::should_enter_safe_mode(uint8_t num_attempts, uint32_t enable_time) { this->has_safe_mode_ = true; this->safe_mode_start_time_ = millis(); this->safe_mode_enable_time_ = enable_time; @@ -475,24 +452,24 @@ bool OTAComponent::should_enter_safe_mode(uint8_t num_attempts, uint32_t enable_ this->rtc_ = global_preferences->make_preference(233825507UL, false); this->safe_mode_rtc_value_ = this->read_rtc_(); - bool is_manual_safe_mode = this->safe_mode_rtc_value_ == esphome::ota::OTAComponent::ENTER_SAFE_MODE_MAGIC; + bool is_manual_safe_mode = this->safe_mode_rtc_value_ == ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC; if (is_manual_safe_mode) { ESP_LOGI(TAG, "Safe mode has been entered manually"); } else { - ESP_LOGCONFIG(TAG, "There have been %" PRIu32 " suspected unsuccessful boot attempts.", this->safe_mode_rtc_value_); + ESP_LOGCONFIG(TAG, "There have been %" PRIu32 " suspected unsuccessful boot attempts", this->safe_mode_rtc_value_); } if (this->safe_mode_rtc_value_ >= num_attempts || is_manual_safe_mode) { this->clean_rtc(); if (!is_manual_safe_mode) { - ESP_LOGE(TAG, "Boot loop detected. Proceeding to safe mode."); + ESP_LOGE(TAG, "Boot loop detected. Proceeding to safe mode"); } this->status_set_error(); this->set_timeout(enable_time, []() { - ESP_LOGE(TAG, "No OTA attempt made, restarting."); + ESP_LOGE(TAG, "No OTA attempt made, restarting"); App.reboot(); }); @@ -500,7 +477,7 @@ bool OTAComponent::should_enter_safe_mode(uint8_t num_attempts, uint32_t enable_ delay(300); // NOLINT App.setup(); - ESP_LOGI(TAG, "Waiting for OTA attempt."); + ESP_LOGI(TAG, "Waiting for OTA attempt"); return true; } else { @@ -509,27 +486,23 @@ bool OTAComponent::should_enter_safe_mode(uint8_t num_attempts, uint32_t enable_ return false; } } -void OTAComponent::write_rtc_(uint32_t val) { + +void ESPHomeOTAComponent::write_rtc_(uint32_t val) { this->rtc_.save(&val); global_preferences->sync(); } -uint32_t OTAComponent::read_rtc_() { + +uint32_t ESPHomeOTAComponent::read_rtc_() { uint32_t val; if (!this->rtc_.load(&val)) return 0; return val; } -void OTAComponent::clean_rtc() { this->write_rtc_(0); } -void OTAComponent::on_safe_shutdown() { - if (this->has_safe_mode_ && this->read_rtc_() != esphome::ota::OTAComponent::ENTER_SAFE_MODE_MAGIC) + +void ESPHomeOTAComponent::clean_rtc() { this->write_rtc_(0); } + +void ESPHomeOTAComponent::on_safe_shutdown() { + if (this->has_safe_mode_ && this->read_rtc_() != ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC) this->clean_rtc(); } - -#ifdef USE_OTA_STATE_CALLBACK -void OTAComponent::add_on_state_callback(std::function &&callback) { - this->state_callback_.add(std::move(callback)); -} -#endif - -} // namespace ota } // namespace esphome diff --git a/esphome/components/ota/ota_component.h b/esphome/components/esphome/ota/ota_esphome.h similarity index 50% rename from esphome/components/ota/ota_component.h rename to esphome/components/esphome/ota/ota_esphome.h index c20f4f0709..e8f36f05ca 100644 --- a/esphome/components/ota/ota_component.h +++ b/esphome/components/esphome/ota/ota_esphome.h @@ -1,49 +1,16 @@ #pragma once -#include "esphome/components/socket/socket.h" -#include "esphome/core/component.h" -#include "esphome/core/preferences.h" -#include "esphome/core/helpers.h" #include "esphome/core/defines.h" +#include "esphome/core/helpers.h" +#include "esphome/core/preferences.h" +#include "esphome/components/ota/ota_backend.h" +#include "esphome/components/socket/socket.h" namespace esphome { -namespace ota { -enum OTAResponseTypes { - OTA_RESPONSE_OK = 0x00, - OTA_RESPONSE_REQUEST_AUTH = 0x01, - - OTA_RESPONSE_HEADER_OK = 0x40, - OTA_RESPONSE_AUTH_OK = 0x41, - OTA_RESPONSE_UPDATE_PREPARE_OK = 0x42, - OTA_RESPONSE_BIN_MD5_OK = 0x43, - OTA_RESPONSE_RECEIVE_OK = 0x44, - OTA_RESPONSE_UPDATE_END_OK = 0x45, - OTA_RESPONSE_SUPPORTS_COMPRESSION = 0x46, - OTA_RESPONSE_CHUNK_OK = 0x47, - - OTA_RESPONSE_ERROR_MAGIC = 0x80, - OTA_RESPONSE_ERROR_UPDATE_PREPARE = 0x81, - OTA_RESPONSE_ERROR_AUTH_INVALID = 0x82, - OTA_RESPONSE_ERROR_WRITING_FLASH = 0x83, - OTA_RESPONSE_ERROR_UPDATE_END = 0x84, - OTA_RESPONSE_ERROR_INVALID_BOOTSTRAPPING = 0x85, - OTA_RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG = 0x86, - OTA_RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG = 0x87, - OTA_RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE = 0x88, - OTA_RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE = 0x89, - OTA_RESPONSE_ERROR_NO_UPDATE_PARTITION = 0x8A, - OTA_RESPONSE_ERROR_MD5_MISMATCH = 0x8B, - OTA_RESPONSE_ERROR_RP2040_NOT_ENOUGH_SPACE = 0x8C, - OTA_RESPONSE_ERROR_UNKNOWN = 0xFF, -}; - -enum OTAState { OTA_COMPLETED = 0, OTA_STARTED, OTA_IN_PROGRESS, OTA_ERROR }; - -/// OTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. -class OTAComponent : public Component { +/// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. +class ESPHomeOTAComponent : public ota::OTAComponent { public: - OTAComponent(); #ifdef USE_OTA_PASSWORD void set_auth_password(const std::string &password) { password_ = password; } #endif // USE_OTA_PASSWORD @@ -57,10 +24,6 @@ class OTAComponent : public Component { void set_safe_mode_pending(const bool &pending); bool get_safe_mode_pending(); -#ifdef USE_OTA_STATE_CALLBACK - void add_on_state_callback(std::function &&callback); -#endif - // ========== INTERNAL METHODS ========== // (In most use cases you won't need these) void setup() override; @@ -91,22 +54,15 @@ class OTAComponent : public Component { std::unique_ptr server_; std::unique_ptr client_; - bool has_safe_mode_{false}; ///< stores whether safe mode can be enabled. - uint32_t safe_mode_start_time_; ///< stores when safe mode was enabled. - uint32_t safe_mode_enable_time_{60000}; ///< The time safe mode should be on for. + bool has_safe_mode_{false}; ///< stores whether safe mode can be enabled + uint32_t safe_mode_start_time_; ///< stores when safe mode was enabled + uint32_t safe_mode_enable_time_{60000}; ///< The time safe mode should be on for uint32_t safe_mode_rtc_value_; uint8_t safe_mode_num_attempts_; ESPPreferenceObject rtc_; static const uint32_t ENTER_SAFE_MODE_MAGIC = 0x5afe5afe; ///< a magic number to indicate that safe mode should be entered on next boot - -#ifdef USE_OTA_STATE_CALLBACK - CallbackManager state_callback_{}; -#endif }; -extern OTAComponent *global_ota_component; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) - -} // namespace ota } // namespace esphome diff --git a/esphome/components/ota/__init__.py b/esphome/components/ota/__init__.py index 3c845490dc..728d36f3fa 100644 --- a/esphome/components/ota/__init__.py +++ b/esphome/components/ota/__init__.py @@ -1,71 +1,67 @@ -from esphome.cpp_generator import RawExpression import esphome.codegen as cg import esphome.config_validation as cv from esphome import automation -from esphome.const import ( - CONF_ID, - CONF_NUM_ATTEMPTS, - CONF_PASSWORD, - CONF_PORT, - CONF_REBOOT_TIMEOUT, - CONF_SAFE_MODE, - CONF_TRIGGER_ID, - CONF_OTA, - KEY_PAST_SAFE_MODE, - CONF_VERSION, -) from esphome.core import CORE, coroutine_with_priority -CODEOWNERS = ["@esphome/core"] -DEPENDENCIES = ["network"] -AUTO_LOAD = ["socket", "md5"] +from esphome.const import CONF_ESPHOME, CONF_OTA, CONF_PLATFORM, CONF_TRIGGER_ID -CONF_ON_STATE_CHANGE = "on_state_change" +CODEOWNERS = ["@esphome/core"] +AUTO_LOAD = ["md5"] + +IS_PLATFORM_COMPONENT = True + +CONF_ON_ABORT = "on_abort" CONF_ON_BEGIN = "on_begin" -CONF_ON_PROGRESS = "on_progress" CONF_ON_END = "on_end" CONF_ON_ERROR = "on_error" +CONF_ON_PROGRESS = "on_progress" +CONF_ON_STATE_CHANGE = "on_state_change" + ota_ns = cg.esphome_ns.namespace("ota") -OTAState = ota_ns.enum("OTAState") OTAComponent = ota_ns.class_("OTAComponent", cg.Component) +OTAState = ota_ns.enum("OTAState") +OTAAbortTrigger = ota_ns.class_("OTAAbortTrigger", automation.Trigger.template()) +OTAEndTrigger = ota_ns.class_("OTAEndTrigger", automation.Trigger.template()) +OTAErrorTrigger = ota_ns.class_("OTAErrorTrigger", automation.Trigger.template()) +OTAProgressTrigger = ota_ns.class_("OTAProgressTrigger", automation.Trigger.template()) +OTAStartTrigger = ota_ns.class_("OTAStartTrigger", automation.Trigger.template()) OTAStateChangeTrigger = ota_ns.class_( "OTAStateChangeTrigger", automation.Trigger.template() ) -OTAStartTrigger = ota_ns.class_("OTAStartTrigger", automation.Trigger.template()) -OTAProgressTrigger = ota_ns.class_("OTAProgressTrigger", automation.Trigger.template()) -OTAEndTrigger = ota_ns.class_("OTAEndTrigger", automation.Trigger.template()) -OTAErrorTrigger = ota_ns.class_("OTAErrorTrigger", automation.Trigger.template()) -CONFIG_SCHEMA = cv.Schema( +def _ota_final_validate(config): + if len(config) < 1: + raise cv.Invalid( + f"At least one platform must be specified for '{CONF_OTA}'; add '{CONF_PLATFORM}: {CONF_ESPHOME}' for original OTA functionality" + ) + + +FINAL_VALIDATE_SCHEMA = _ota_final_validate + +BASE_OTA_SCHEMA = cv.Schema( { - cv.GenerateID(): cv.declare_id(OTAComponent), - cv.Optional(CONF_SAFE_MODE, default=True): cv.boolean, - cv.Optional(CONF_VERSION, default=2): cv.one_of(1, 2, int=True), - cv.SplitDefault( - CONF_PORT, - esp8266=8266, - esp32=3232, - rp2040=2040, - bk72xx=8892, - rtl87xx=8892, - ): cv.port, - cv.Optional(CONF_PASSWORD): cv.string, - cv.Optional( - CONF_REBOOT_TIMEOUT, default="5min" - ): cv.positive_time_period_milliseconds, - cv.Optional(CONF_NUM_ATTEMPTS, default="10"): cv.positive_not_null_int, cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation( { cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStateChangeTrigger), } ), + cv.Optional(CONF_ON_ABORT): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAAbortTrigger), + } + ), cv.Optional(CONF_ON_BEGIN): automation.validate_automation( { cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStartTrigger), } ), + cv.Optional(CONF_ON_END): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAEndTrigger), + } + ), cv.Optional(CONF_ON_ERROR): automation.validate_automation( { cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAErrorTrigger), @@ -76,35 +72,13 @@ CONFIG_SCHEMA = cv.Schema( cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAProgressTrigger), } ), - cv.Optional(CONF_ON_END): automation.validate_automation( - { - cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAEndTrigger), - } - ), } -).extend(cv.COMPONENT_SCHEMA) +) @coroutine_with_priority(50.0) async def to_code(config): - CORE.data[CONF_OTA] = {} - - var = cg.new_Pvariable(config[CONF_ID]) - cg.add(var.set_port(config[CONF_PORT])) cg.add_define("USE_OTA") - if CONF_PASSWORD in config: - cg.add(var.set_auth_password(config[CONF_PASSWORD])) - cg.add_define("USE_OTA_PASSWORD") - cg.add_define("USE_OTA_VERSION", config[CONF_VERSION]) - - await cg.register_component(var, config) - - if config[CONF_SAFE_MODE]: - condition = var.should_enter_safe_mode( - config[CONF_NUM_ATTEMPTS], config[CONF_REBOOT_TIMEOUT] - ) - cg.add(RawExpression(f"if ({condition}) return")) - CORE.data[CONF_OTA][KEY_PAST_SAFE_MODE] = True if CORE.is_esp32 and CORE.using_arduino: cg.add_library("Update", None) @@ -112,11 +86,17 @@ async def to_code(config): if CORE.is_rp2040 and CORE.using_arduino: cg.add_library("Updater", None) + +async def ota_to_code(var, config): use_state_callback = False for conf in config.get(CONF_ON_STATE_CHANGE, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation(trigger, [(OTAState, "state")], conf) use_state_callback = True + for conf in config.get(CONF_ON_ABORT, []): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + await automation.build_automation(trigger, [], conf) + use_state_callback = True for conf in config.get(CONF_ON_BEGIN, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation(trigger, [], conf) diff --git a/esphome/components/ota/automation.h b/esphome/components/ota/automation.h index 0c77a18ce1..4605193480 100644 --- a/esphome/components/ota/automation.h +++ b/esphome/components/ota/automation.h @@ -1,11 +1,8 @@ #pragma once - -#include "esphome/core/defines.h" #ifdef USE_OTA_STATE_CALLBACK +#include "ota_backend.h" -#include "esphome/core/component.h" #include "esphome/core/automation.h" -#include "esphome/components/ota/ota_component.h" namespace esphome { namespace ota { @@ -54,6 +51,17 @@ class OTAEndTrigger : public Trigger<> { } }; +class OTAAbortTrigger : public Trigger<> { + public: + explicit OTAAbortTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_ABORT && !parent->is_failed()) { + trigger(); + } + }); + } +}; + class OTAErrorTrigger : public Trigger { public: explicit OTAErrorTrigger(OTAComponent *parent) { @@ -67,5 +75,4 @@ class OTAErrorTrigger : public Trigger { } // namespace ota } // namespace esphome - -#endif // USE_OTA_STATE_CALLBACK +#endif diff --git a/esphome/components/ota/ota_backend.cpp b/esphome/components/ota/ota_backend.cpp new file mode 100644 index 0000000000..30de4ec4b3 --- /dev/null +++ b/esphome/components/ota/ota_backend.cpp @@ -0,0 +1,20 @@ +#include "ota_backend.h" + +namespace esphome { +namespace ota { + +#ifdef USE_OTA_STATE_CALLBACK +OTAGlobalCallback *global_ota_callback{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +OTAGlobalCallback *get_global_ota_callback() { + if (global_ota_callback == nullptr) { + global_ota_callback = new OTAGlobalCallback(); // NOLINT(cppcoreguidelines-owning-memory) + } + return global_ota_callback; +} + +void register_ota_platform(OTAComponent *ota_caller) { get_global_ota_callback()->register_ota(ota_caller); } +#endif + +} // namespace ota +} // namespace esphome diff --git a/esphome/components/ota/ota_backend.h b/esphome/components/ota/ota_backend.h index 5c5b61a278..bc8ab46643 100644 --- a/esphome/components/ota/ota_backend.h +++ b/esphome/components/ota/ota_backend.h @@ -1,9 +1,53 @@ #pragma once -#include "ota_component.h" + +#include "esphome/core/component.h" +#include "esphome/core/defines.h" +#include "esphome/core/helpers.h" + +#ifdef USE_OTA_STATE_CALLBACK +#include "esphome/core/automation.h" +#endif namespace esphome { namespace ota { +enum OTAResponseTypes { + OTA_RESPONSE_OK = 0x00, + OTA_RESPONSE_REQUEST_AUTH = 0x01, + + OTA_RESPONSE_HEADER_OK = 0x40, + OTA_RESPONSE_AUTH_OK = 0x41, + OTA_RESPONSE_UPDATE_PREPARE_OK = 0x42, + OTA_RESPONSE_BIN_MD5_OK = 0x43, + OTA_RESPONSE_RECEIVE_OK = 0x44, + OTA_RESPONSE_UPDATE_END_OK = 0x45, + OTA_RESPONSE_SUPPORTS_COMPRESSION = 0x46, + OTA_RESPONSE_CHUNK_OK = 0x47, + + OTA_RESPONSE_ERROR_MAGIC = 0x80, + OTA_RESPONSE_ERROR_UPDATE_PREPARE = 0x81, + OTA_RESPONSE_ERROR_AUTH_INVALID = 0x82, + OTA_RESPONSE_ERROR_WRITING_FLASH = 0x83, + OTA_RESPONSE_ERROR_UPDATE_END = 0x84, + OTA_RESPONSE_ERROR_INVALID_BOOTSTRAPPING = 0x85, + OTA_RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG = 0x86, + OTA_RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG = 0x87, + OTA_RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE = 0x88, + OTA_RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE = 0x89, + OTA_RESPONSE_ERROR_NO_UPDATE_PARTITION = 0x8A, + OTA_RESPONSE_ERROR_MD5_MISMATCH = 0x8B, + OTA_RESPONSE_ERROR_RP2040_NOT_ENOUGH_SPACE = 0x8C, + OTA_RESPONSE_ERROR_UNKNOWN = 0xFF, +}; + +enum OTAState { + OTA_COMPLETED = 0, + OTA_STARTED, + OTA_IN_PROGRESS, + OTA_ABORT, + OTA_ERROR, +}; + class OTABackend { public: virtual ~OTABackend() = default; @@ -15,5 +59,38 @@ class OTABackend { virtual bool supports_compression() = 0; }; +class OTAComponent : public Component { +#ifdef USE_OTA_STATE_CALLBACK + public: + void add_on_state_callback(std::function &&callback) { + this->state_callback_.add(std::move(callback)); + } + + protected: + CallbackManager state_callback_{}; +#endif +}; + +#ifdef USE_OTA_STATE_CALLBACK +class OTAGlobalCallback { + public: + void register_ota(OTAComponent *ota_caller) { + ota_caller->add_on_state_callback([this, ota_caller](OTAState state, float progress, uint8_t error) { + this->state_callback_.call(state, progress, error, ota_caller); + }); + } + void add_on_state_callback(std::function &&callback) { + this->state_callback_.add(std::move(callback)); + } + + protected: + CallbackManager state_callback_{}; +}; + +OTAGlobalCallback *get_global_ota_callback(); +void register_ota_platform(OTAComponent *ota_caller); +#endif +std::unique_ptr make_ota_backend(); + } // namespace ota } // namespace esphome diff --git a/esphome/components/ota/ota_backend_arduino_esp32.cpp b/esphome/components/ota/ota_backend_arduino_esp32.cpp index 4759737dbd..62c6a72388 100644 --- a/esphome/components/ota/ota_backend_arduino_esp32.cpp +++ b/esphome/components/ota/ota_backend_arduino_esp32.cpp @@ -1,8 +1,7 @@ -#include "esphome/core/defines.h" #ifdef USE_ESP32_FRAMEWORK_ARDUINO +#include "esphome/core/defines.h" #include "ota_backend_arduino_esp32.h" -#include "ota_component.h" #include "ota_backend.h" #include @@ -10,6 +9,8 @@ namespace esphome { namespace ota { +std::unique_ptr make_ota_backend() { return make_unique(); } + OTAResponseTypes ArduinoESP32OTABackend::begin(size_t image_size) { bool ret = Update.begin(image_size, U_FLASH); if (ret) { diff --git a/esphome/components/ota/ota_backend_arduino_esp32.h b/esphome/components/ota/ota_backend_arduino_esp32.h index f86a70d678..ac7fe9f14f 100644 --- a/esphome/components/ota/ota_backend_arduino_esp32.h +++ b/esphome/components/ota/ota_backend_arduino_esp32.h @@ -1,10 +1,10 @@ #pragma once -#include "esphome/core/defines.h" #ifdef USE_ESP32_FRAMEWORK_ARDUINO - -#include "ota_component.h" #include "ota_backend.h" +#include "esphome/core/defines.h" +#include "esphome/core/helpers.h" + namespace esphome { namespace ota { diff --git a/esphome/components/ota/ota_backend_arduino_esp8266.cpp b/esphome/components/ota/ota_backend_arduino_esp8266.cpp index 23dc0d4e21..b317075bd0 100644 --- a/esphome/components/ota/ota_backend_arduino_esp8266.cpp +++ b/esphome/components/ota/ota_backend_arduino_esp8266.cpp @@ -1,10 +1,9 @@ -#include "esphome/core/defines.h" #ifdef USE_ARDUINO #ifdef USE_ESP8266 - -#include "ota_backend_arduino_esp8266.h" -#include "ota_component.h" #include "ota_backend.h" +#include "ota_backend_arduino_esp8266.h" + +#include "esphome/core/defines.h" #include "esphome/components/esp8266/preferences.h" #include @@ -12,6 +11,8 @@ namespace esphome { namespace ota { +std::unique_ptr make_ota_backend() { return make_unique(); } + OTAResponseTypes ArduinoESP8266OTABackend::begin(size_t image_size) { bool ret = Update.begin(image_size, U_FLASH); if (ret) { diff --git a/esphome/components/ota/ota_backend_arduino_esp8266.h b/esphome/components/ota/ota_backend_arduino_esp8266.h index 7937c665b0..7f44d7c965 100644 --- a/esphome/components/ota/ota_backend_arduino_esp8266.h +++ b/esphome/components/ota/ota_backend_arduino_esp8266.h @@ -1,10 +1,9 @@ #pragma once -#include "esphome/core/defines.h" #ifdef USE_ARDUINO #ifdef USE_ESP8266 - -#include "ota_component.h" #include "ota_backend.h" + +#include "esphome/core/defines.h" #include "esphome/core/macros.h" namespace esphome { diff --git a/esphome/components/ota/ota_backend_arduino_libretiny.cpp b/esphome/components/ota/ota_backend_arduino_libretiny.cpp index dbf6c97988..df4e774ebc 100644 --- a/esphome/components/ota/ota_backend_arduino_libretiny.cpp +++ b/esphome/components/ota/ota_backend_arduino_libretiny.cpp @@ -1,15 +1,16 @@ -#include "esphome/core/defines.h" #ifdef USE_LIBRETINY - -#include "ota_backend_arduino_libretiny.h" -#include "ota_component.h" #include "ota_backend.h" +#include "ota_backend_arduino_libretiny.h" + +#include "esphome/core/defines.h" #include namespace esphome { namespace ota { +std::unique_ptr make_ota_backend() { return make_unique(); } + OTAResponseTypes ArduinoLibreTinyOTABackend::begin(size_t image_size) { bool ret = Update.begin(image_size, U_FLASH); if (ret) { diff --git a/esphome/components/ota/ota_backend_arduino_libretiny.h b/esphome/components/ota/ota_backend_arduino_libretiny.h index 79656bb353..11deb6e2f2 100644 --- a/esphome/components/ota/ota_backend_arduino_libretiny.h +++ b/esphome/components/ota/ota_backend_arduino_libretiny.h @@ -1,10 +1,9 @@ #pragma once -#include "esphome/core/defines.h" #ifdef USE_LIBRETINY - -#include "ota_component.h" #include "ota_backend.h" +#include "esphome/core/defines.h" + namespace esphome { namespace ota { diff --git a/esphome/components/ota/ota_backend_arduino_rp2040.cpp b/esphome/components/ota/ota_backend_arduino_rp2040.cpp index 260387cec1..4448b0c95e 100644 --- a/esphome/components/ota/ota_backend_arduino_rp2040.cpp +++ b/esphome/components/ota/ota_backend_arduino_rp2040.cpp @@ -1,17 +1,18 @@ -#include "esphome/core/defines.h" #ifdef USE_ARDUINO #ifdef USE_RP2040 - -#include "esphome/components/rp2040/preferences.h" #include "ota_backend.h" #include "ota_backend_arduino_rp2040.h" -#include "ota_component.h" + +#include "esphome/components/rp2040/preferences.h" +#include "esphome/core/defines.h" #include namespace esphome { namespace ota { +std::unique_ptr make_ota_backend() { return make_unique(); } + OTAResponseTypes ArduinoRP2040OTABackend::begin(size_t image_size) { bool ret = Update.begin(image_size, U_FLASH); if (ret) { diff --git a/esphome/components/ota/ota_backend_arduino_rp2040.h b/esphome/components/ota/ota_backend_arduino_rp2040.h index 5aa2ec9435..b189964ab3 100644 --- a/esphome/components/ota/ota_backend_arduino_rp2040.h +++ b/esphome/components/ota/ota_backend_arduino_rp2040.h @@ -1,11 +1,10 @@ #pragma once -#include "esphome/core/defines.h" #ifdef USE_ARDUINO #ifdef USE_RP2040 - -#include "esphome/core/macros.h" #include "ota_backend.h" -#include "ota_component.h" + +#include "esphome/core/defines.h" +#include "esphome/core/macros.h" namespace esphome { namespace ota { diff --git a/esphome/components/ota/ota_backend_esp_idf.cpp b/esphome/components/ota/ota_backend_esp_idf.cpp index 319a1482f1..6f45fb75e4 100644 --- a/esphome/components/ota/ota_backend_esp_idf.cpp +++ b/esphome/components/ota/ota_backend_esp_idf.cpp @@ -1,12 +1,11 @@ -#include "esphome/core/defines.h" #ifdef USE_ESP_IDF - -#include - #include "ota_backend_esp_idf.h" -#include "ota_component.h" -#include + #include "esphome/components/md5/md5.h" +#include "esphome/core/defines.h" + +#include +#include #if ESP_IDF_VERSION_MAJOR >= 5 #include @@ -15,6 +14,8 @@ namespace esphome { namespace ota { +std::unique_ptr make_ota_backend() { return make_unique(); } + OTAResponseTypes IDFOTABackend::begin(size_t image_size) { this->partition_ = esp_ota_get_next_update_partition(nullptr); if (this->partition_ == nullptr) { diff --git a/esphome/components/ota/ota_backend_esp_idf.h b/esphome/components/ota/ota_backend_esp_idf.h index af09d0d693..ed66d9b970 100644 --- a/esphome/components/ota/ota_backend_esp_idf.h +++ b/esphome/components/ota/ota_backend_esp_idf.h @@ -1,11 +1,11 @@ #pragma once -#include "esphome/core/defines.h" #ifdef USE_ESP_IDF - -#include "ota_component.h" #include "ota_backend.h" -#include + #include "esphome/components/md5/md5.h" +#include "esphome/core/defines.h" + +#include namespace esphome { namespace ota { diff --git a/esphome/components/safe_mode/button/__init__.py b/esphome/components/safe_mode/button/__init__.py index 307e4e372e..bd51d2e038 100644 --- a/esphome/components/safe_mode/button/__init__.py +++ b/esphome/components/safe_mode/button/__init__.py @@ -1,18 +1,17 @@ import esphome.codegen as cg import esphome.config_validation as cv from esphome.components import button -from esphome.components.ota import OTAComponent +from esphome.components.esphome.ota import ESPHomeOTAComponent from esphome.const import ( - CONF_ID, - CONF_OTA, + CONF_ESPHOME, DEVICE_CLASS_RESTART, ENTITY_CATEGORY_CONFIG, ICON_RESTART_ALERT, ) +from .. import safe_mode_ns -DEPENDENCIES = ["ota"] +DEPENDENCIES = ["ota.esphome"] -safe_mode_ns = cg.esphome_ns.namespace("safe_mode") SafeModeButton = safe_mode_ns.class_("SafeModeButton", button.Button, cg.Component) CONFIG_SCHEMA = ( @@ -22,15 +21,14 @@ CONFIG_SCHEMA = ( entity_category=ENTITY_CATEGORY_CONFIG, icon=ICON_RESTART_ALERT, ) - .extend({cv.GenerateID(CONF_OTA): cv.use_id(OTAComponent)}) + .extend({cv.GenerateID(CONF_ESPHOME): cv.use_id(ESPHomeOTAComponent)}) .extend(cv.COMPONENT_SCHEMA) ) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await button.new_button(config) await cg.register_component(var, config) - await button.register_button(var, config) - ota = await cg.get_variable(config[CONF_OTA]) + ota = await cg.get_variable(config[CONF_ESPHOME]) cg.add(var.set_ota(ota)) diff --git a/esphome/components/safe_mode/button/safe_mode_button.cpp b/esphome/components/safe_mode/button/safe_mode_button.cpp index 2b8654de46..d513b79c12 100644 --- a/esphome/components/safe_mode/button/safe_mode_button.cpp +++ b/esphome/components/safe_mode/button/safe_mode_button.cpp @@ -8,7 +8,7 @@ namespace safe_mode { static const char *const TAG = "safe_mode.button"; -void SafeModeButton::set_ota(ota::OTAComponent *ota) { this->ota_ = ota; } +void SafeModeButton::set_ota(esphome::ESPHomeOTAComponent *ota) { this->ota_ = ota; } void SafeModeButton::press_action() { ESP_LOGI(TAG, "Restarting device in safe mode..."); diff --git a/esphome/components/safe_mode/button/safe_mode_button.h b/esphome/components/safe_mode/button/safe_mode_button.h index 63e0d1755e..a306735b7f 100644 --- a/esphome/components/safe_mode/button/safe_mode_button.h +++ b/esphome/components/safe_mode/button/safe_mode_button.h @@ -1,8 +1,8 @@ #pragma once -#include "esphome/core/component.h" -#include "esphome/components/ota/ota_component.h" #include "esphome/components/button/button.h" +#include "esphome/components/esphome/ota/ota_esphome.h" +#include "esphome/core/component.h" namespace esphome { namespace safe_mode { @@ -10,10 +10,10 @@ namespace safe_mode { class SafeModeButton : public button::Button, public Component { public: void dump_config() override; - void set_ota(ota::OTAComponent *ota); + void set_ota(esphome::ESPHomeOTAComponent *ota); protected: - ota::OTAComponent *ota_; + esphome::ESPHomeOTAComponent *ota_; void press_action() override; }; diff --git a/esphome/components/safe_mode/switch/__init__.py b/esphome/components/safe_mode/switch/__init__.py index a6fcdfbece..0f8e500482 100644 --- a/esphome/components/safe_mode/switch/__init__.py +++ b/esphome/components/safe_mode/switch/__init__.py @@ -1,26 +1,26 @@ import esphome.codegen as cg import esphome.config_validation as cv from esphome.components import switch -from esphome.components.ota import OTAComponent +from esphome.components.esphome.ota import ESPHomeOTAComponent from esphome.const import ( - CONF_OTA, + CONF_ESPHOME, ENTITY_CATEGORY_CONFIG, ICON_RESTART_ALERT, ) from .. import safe_mode_ns -DEPENDENCIES = ["ota"] +DEPENDENCIES = ["ota.esphome"] SafeModeSwitch = safe_mode_ns.class_("SafeModeSwitch", switch.Switch, cg.Component) CONFIG_SCHEMA = ( switch.switch_schema( SafeModeSwitch, - icon=ICON_RESTART_ALERT, - entity_category=ENTITY_CATEGORY_CONFIG, block_inverted=True, + entity_category=ENTITY_CATEGORY_CONFIG, + icon=ICON_RESTART_ALERT, ) - .extend({cv.GenerateID(CONF_OTA): cv.use_id(OTAComponent)}) + .extend({cv.GenerateID(CONF_ESPHOME): cv.use_id(ESPHomeOTAComponent)}) .extend(cv.COMPONENT_SCHEMA) ) @@ -29,5 +29,5 @@ async def to_code(config): var = await switch.new_switch(config) await cg.register_component(var, config) - ota = await cg.get_variable(config[CONF_OTA]) + ota = await cg.get_variable(config[CONF_ESPHOME]) cg.add(var.set_ota(ota)) diff --git a/esphome/components/safe_mode/switch/safe_mode_switch.cpp b/esphome/components/safe_mode/switch/safe_mode_switch.cpp index a3979eec06..71408df140 100644 --- a/esphome/components/safe_mode/switch/safe_mode_switch.cpp +++ b/esphome/components/safe_mode/switch/safe_mode_switch.cpp @@ -1,14 +1,14 @@ #include "safe_mode_switch.h" +#include "esphome/core/application.h" #include "esphome/core/hal.h" #include "esphome/core/log.h" -#include "esphome/core/application.h" namespace esphome { namespace safe_mode { static const char *const TAG = "safe_mode_switch"; -void SafeModeSwitch::set_ota(ota::OTAComponent *ota) { this->ota_ = ota; } +void SafeModeSwitch::set_ota(esphome::ESPHomeOTAComponent *ota) { this->ota_ = ota; } void SafeModeSwitch::write_state(bool state) { // Acknowledge diff --git a/esphome/components/safe_mode/switch/safe_mode_switch.h b/esphome/components/safe_mode/switch/safe_mode_switch.h index 2772db3d84..5bd15a44de 100644 --- a/esphome/components/safe_mode/switch/safe_mode_switch.h +++ b/esphome/components/safe_mode/switch/safe_mode_switch.h @@ -1,8 +1,8 @@ #pragma once -#include "esphome/core/component.h" -#include "esphome/components/ota/ota_component.h" +#include "esphome/components/esphome/ota/ota_esphome.h" #include "esphome/components/switch/switch.h" +#include "esphome/core/component.h" namespace esphome { namespace safe_mode { @@ -10,10 +10,10 @@ namespace safe_mode { class SafeModeSwitch : public switch_::Switch, public Component { public: void dump_config() override; - void set_ota(ota::OTAComponent *ota); + void set_ota(esphome::ESPHomeOTAComponent *ota); protected: - ota::OTAComponent *ota_; + esphome::ESPHomeOTAComponent *ota_; void write_state(bool state) override; }; diff --git a/esphome/cpp_helpers.py b/esphome/cpp_helpers.py index 4b3716e223..ce494e5d9d 100644 --- a/esphome/cpp_helpers.py +++ b/esphome/cpp_helpers.py @@ -3,14 +3,16 @@ import logging from esphome.const import ( CONF_DISABLED_BY_DEFAULT, CONF_ENTITY_CATEGORY, + CONF_ESPHOME, CONF_ICON, CONF_INTERNAL, CONF_NAME, - CONF_SETUP_PRIORITY, - CONF_UPDATE_INTERVAL, - CONF_TYPE_ID, CONF_OTA, + CONF_PLATFORM, CONF_SAFE_MODE, + CONF_SETUP_PRIORITY, + CONF_TYPE_ID, + CONF_UPDATE_INTERVAL, KEY_PAST_SAFE_MODE, ) @@ -139,9 +141,16 @@ async def build_registry_list(registry, config): async def past_safe_mode(): - safe_mode_enabled = ( - CONF_OTA in CORE.config and CORE.config[CONF_OTA][CONF_SAFE_MODE] - ) + ota_conf = {} + for ota_item in CORE.config.get(CONF_OTA, []): + if ota_item[CONF_PLATFORM] == CONF_ESPHOME: + ota_conf = ota_item + break + + if not ota_conf: + return + + safe_mode_enabled = ota_conf[CONF_SAFE_MODE] if not safe_mode_enabled: return diff --git a/esphome/wizard.py b/esphome/wizard.py index 4ec366bbb9..9680ade044 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -153,10 +153,11 @@ def wizard_file(**kwargs): # Configure OTA config += "\nota:\n" + config += " - platform: esphome\n" if "ota_password" in kwargs: - config += f" password: \"{kwargs['ota_password']}\"" + config += f" password: \"{kwargs['ota_password']}\"" elif "password" in kwargs: - config += f" password: \"{kwargs['password']}\"" + config += f" password: \"{kwargs['password']}\"" # Configuring wifi config += "\n\nwifi:\n" diff --git a/tests/components/ota/common.yaml b/tests/components/ota/common.yaml index 367454995f..4910e2d891 100644 --- a/tests/components/ota/common.yaml +++ b/tests/components/ota/common.yaml @@ -3,28 +3,29 @@ wifi: password: password1 ota: - safe_mode: true - password: "superlongpasswordthatnoonewillknow" - port: 3286 - reboot_timeout: 2min - num_attempts: 5 - on_begin: - then: - - logger.log: "OTA start" - on_progress: - then: - - logger.log: - format: "OTA progress %0.1f%%" - args: ["x"] - on_end: - then: - - logger.log: "OTA end" - on_error: - then: - - logger.log: - format: "OTA update error %d" - args: ["x"] - on_state_change: - then: - lambda: >- - ESP_LOGD("ota", "State %d", state); + - platform: esphome + safe_mode: true + password: "superlongpasswordthatnoonewillknow" + port: 3286 + reboot_timeout: 2min + num_attempts: 5 + on_begin: + then: + - logger.log: "OTA start" + on_progress: + then: + - logger.log: + format: "OTA progress %0.1f%%" + args: ["x"] + on_end: + then: + - logger.log: "OTA end" + on_error: + then: + - logger.log: + format: "OTA update error %d" + args: ["x"] + on_state_change: + then: + lambda: >- + ESP_LOGD("ota", "State %d", state); diff --git a/tests/components/safe_mode/common.yaml b/tests/components/safe_mode/common.yaml index df0abd9aec..1dfc516254 100644 --- a/tests/components/safe_mode/common.yaml +++ b/tests/components/safe_mode/common.yaml @@ -3,6 +3,8 @@ wifi: password: password1 ota: + - platform: esphome + safe_mode: true button: - platform: safe_mode diff --git a/tests/dummy_main.cpp b/tests/dummy_main.cpp index da5c6d10d0..3ba4c8bd07 100644 --- a/tests/dummy_main.cpp +++ b/tests/dummy_main.cpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include @@ -25,7 +25,7 @@ void setup() { ap.set_password("password1"); wifi->add_sta(ap); - auto *ota = new ota::OTAComponent(); // NOLINT + auto *ota = new esphome::ESPHomeOTAComponent(); // NOLINT ota->set_port(8266); App.setup(); diff --git a/tests/test1.yaml b/tests/test1.yaml index 79b836da4a..dc46b55c44 100644 --- a/tests/test1.yaml +++ b/tests/test1.yaml @@ -265,29 +265,30 @@ uart: baud_rate: 9600 ota: - safe_mode: true - password: "superlongpasswordthatnoonewillknow" - port: 3286 - reboot_timeout: 2min - num_attempts: 5 - on_state_change: - then: - lambda: >- - ESP_LOGD("ota", "State %d", state); - on_begin: - then: - logger.log: OTA begin - on_progress: - then: - lambda: >- - ESP_LOGD("ota", "Got progress %f", x); - on_end: - then: - logger.log: OTA end - on_error: - then: - lambda: >- - ESP_LOGD("ota", "Got error code %d", x); + - platform: esphome + safe_mode: true + password: "superlongpasswordthatnoonewillknow" + port: 3286 + reboot_timeout: 2min + num_attempts: 5 + on_state_change: + then: + lambda: >- + ESP_LOGD("ota", "State %d", state); + on_begin: + then: + logger.log: OTA begin + on_progress: + then: + lambda: >- + ESP_LOGD("ota", "Got progress %f", x); + on_end: + then: + logger.log: OTA end + on_error: + then: + lambda: >- + ESP_LOGD("ota", "Got error code %d", x); logger: baud_rate: 0 diff --git a/tests/test11.5.yaml b/tests/test11.5.yaml index 13de7f1929..758f295a6c 100644 --- a/tests/test11.5.yaml +++ b/tests/test11.5.yaml @@ -31,6 +31,7 @@ network: api: ota: + - platform: esphome logger: diff --git a/tests/test2.yaml b/tests/test2.yaml index 970076e78b..54ff4807a3 100644 --- a/tests/test2.yaml +++ b/tests/test2.yaml @@ -80,9 +80,10 @@ uart: - lambda: UARTDebug::log_hex(direction, bytes, ':'); ota: - safe_mode: true - port: 3286 - num_attempts: 15 + - platform: esphome + safe_mode: true + port: 3286 + num_attempts: 15 logger: level: DEBUG diff --git a/tests/test3.1.yaml b/tests/test3.1.yaml index 2bddd6f4d7..18d92289cd 100644 --- a/tests/test3.1.yaml +++ b/tests/test3.1.yaml @@ -49,7 +49,8 @@ spi: number: GPIO14 ota: - version: 2 + - platform: esphome + version: 2 logger: diff --git a/tests/test3.yaml b/tests/test3.yaml index 61d814385b..7554d4bcb2 100644 --- a/tests/test3.yaml +++ b/tests/test3.yaml @@ -328,9 +328,10 @@ vbus: uart_id: uart_4 ota: - safe_mode: true - port: 3286 - reboot_timeout: 15min + - platform: esphome + safe_mode: true + port: 3286 + reboot_timeout: 15min logger: hardware_uart: UART1 diff --git a/tests/test4.yaml b/tests/test4.yaml index 993ce126a8..86beee81c6 100644 --- a/tests/test4.yaml +++ b/tests/test4.yaml @@ -103,8 +103,9 @@ uart: parity: EVEN ota: - safe_mode: true - port: 3286 + - platform: esphome + safe_mode: true + port: 3286 logger: level: DEBUG diff --git a/tests/test5.yaml b/tests/test5.yaml index afd3359098..f7a34d5a1b 100644 --- a/tests/test5.yaml +++ b/tests/test5.yaml @@ -28,6 +28,7 @@ network: api: ota: + - platform: esphome logger: diff --git a/tests/test6.yaml b/tests/test6.yaml index 2c5aa30aad..b1103eb126 100644 --- a/tests/test6.yaml +++ b/tests/test6.yaml @@ -22,6 +22,7 @@ network: api: ota: + - platform: esphome logger: diff --git a/tests/test9.1.yaml b/tests/test9.1.yaml index f7455b7668..2d205ef4e6 100644 --- a/tests/test9.1.yaml +++ b/tests/test9.1.yaml @@ -12,6 +12,7 @@ esphome: logger: ota: + - platform: esphome captive_portal: diff --git a/tests/test9.yaml b/tests/test9.yaml index d660b4f24a..5017ccc5ed 100644 --- a/tests/test9.yaml +++ b/tests/test9.yaml @@ -12,6 +12,7 @@ esphome: logger: ota: + - platform: esphome captive_portal: