From 0475dd8af4efefb0b6382c25f1e3fa03341f2b09 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:28:30 +1300 Subject: [PATCH] Multi source ota update --- .../http_request/update/__init__.py | 14 ++++----- .../update/http_request_update.cpp | 30 ++++++++++++++----- .../http_request/update/http_request_update.h | 6 ++-- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/esphome/components/http_request/update/__init__.py b/esphome/components/http_request/update/__init__.py index 356afa1432..74b1ca2803 100644 --- a/esphome/components/http_request/update/__init__.py +++ b/esphome/components/http_request/update/__init__.py @@ -1,15 +1,11 @@ -import esphome.config_validation as cv import esphome.codegen as cg - from esphome.components import update -from esphome.const import ( - CONF_SOURCE, -) +import esphome.config_validation as cv +from esphome.const import CONF_SOURCE -from .. import http_request_ns, CONF_HTTP_REQUEST_ID, HttpRequestComponent +from .. import CONF_HTTP_REQUEST_ID, HttpRequestComponent, http_request_ns from ..ota import OtaHttpRequestComponent - AUTO_LOAD = ["json"] CODEOWNERS = ["@jesserockz"] DEPENDENCIES = ["ota.http_request"] @@ -25,7 +21,7 @@ CONFIG_SCHEMA = update.UPDATE_SCHEMA.extend( cv.GenerateID(): cv.declare_id(HttpRequestUpdate), cv.GenerateID(CONF_OTA_ID): cv.use_id(OtaHttpRequestComponent), cv.GenerateID(CONF_HTTP_REQUEST_ID): cv.use_id(HttpRequestComponent), - cv.Required(CONF_SOURCE): cv.url, + cv.Required(CONF_SOURCE): cv.ensure_list(cv.url), } ).extend(cv.polling_component_schema("6h")) @@ -37,7 +33,7 @@ async def to_code(config): request_parent = await cg.get_variable(config[CONF_HTTP_REQUEST_ID]) cg.add(var.set_request_parent(request_parent)) - cg.add(var.set_source_url(config[CONF_SOURCE])) + cg.add(var.set_source_urls(config[CONF_SOURCE])) cg.add_define("USE_OTA_STATE_CALLBACK") diff --git a/esphome/components/http_request/update/http_request_update.cpp b/esphome/components/http_request/update/http_request_update.cpp index 0e0966c22b..17a38e143b 100644 --- a/esphome/components/http_request/update/http_request_update.cpp +++ b/esphome/components/http_request/update/http_request_update.cpp @@ -21,19 +21,33 @@ void HttpRequestUpdate::setup() { this->update_info_.progress = progress; this->publish_state(); } else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) { - this->state_ = update::UPDATE_STATE_AVAILABLE; - this->status_set_error("Failed to install firmware"); - this->publish_state(); + if (this->current_source_ + 1 < this->source_urls_.size()) { + this->current_source_++; + this->defer("update", [this]() { + this->update(); + this->perform(true); + }); + } else { + this->current_source_ = 0; + this->state_ = update::UPDATE_STATE_AVAILABLE; + this->status_set_error("Failed to install firmware"); + this->publish_state(); + } } }); } void HttpRequestUpdate::update() { - auto container = this->request_parent_->get(this->source_url_); + std::string current_source = this->source_urls_[this->current_source_]; + auto container = this->request_parent_->get(current_source); if (container == nullptr || container->status_code != HTTP_STATUS_OK) { - std::string msg = str_sprintf("Failed to fetch manifest from %s", this->source_url_.c_str()); + std::string msg = str_sprintf("Failed to fetch manifest from %s", current_source.c_str()); this->status_set_error(msg.c_str()); + if (this->current_source_ + 1 < this->source_urls_.size()) { + this->current_source_++; + this->defer("update", [this]() { this->update(); }); + } return; } @@ -99,7 +113,7 @@ void HttpRequestUpdate::update() { }); if (!valid) { - std::string msg = str_sprintf("Failed to parse JSON from %s", this->source_url_.c_str()); + std::string msg = str_sprintf("Failed to parse JSON from %s", current_source.c_str()); this->status_set_error(msg.c_str()); return; } @@ -108,10 +122,10 @@ void HttpRequestUpdate::update() { if (this->update_info_.firmware_url.find("http") == std::string::npos) { std::string path = this->update_info_.firmware_url; if (path[0] == '/') { - std::string domain = this->source_url_.substr(0, this->source_url_.find('/', 8)); + std::string domain = current_source.substr(0, current_source.find('/', 8)); this->update_info_.firmware_url = domain + path; } else { - std::string domain = this->source_url_.substr(0, this->source_url_.rfind('/') + 1); + std::string domain = current_source.substr(0, current_source.rfind('/') + 1); this->update_info_.firmware_url = domain + path; } } diff --git a/esphome/components/http_request/update/http_request_update.h b/esphome/components/http_request/update/http_request_update.h index 45c7e6a447..30e3db545d 100644 --- a/esphome/components/http_request/update/http_request_update.h +++ b/esphome/components/http_request/update/http_request_update.h @@ -18,7 +18,8 @@ class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent { void perform(bool force) override; void check() override { this->update(); } - void set_source_url(const std::string &source_url) { this->source_url_ = source_url; } + void set_source_url(const std::string &source_urls) { this->source_urls_ = {source_urls}; } + void set_source_urls(const std::vector &source_urls) { this->source_urls_ = source_urls; } void set_request_parent(HttpRequestComponent *request_parent) { this->request_parent_ = request_parent; } void set_ota_parent(OtaHttpRequestComponent *ota_parent) { this->ota_parent_ = ota_parent; } @@ -28,7 +29,8 @@ class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent { protected: HttpRequestComponent *request_parent_; OtaHttpRequestComponent *ota_parent_; - std::string source_url_; + std::vector source_urls_; + size_t current_source_{0}; }; } // namespace http_request