From 8426084ebfe523066855c20dcd636b0dfb04bc5d Mon Sep 17 00:00:00 2001 From: Keith Burzinski Date: Sun, 21 Apr 2024 04:14:07 -0500 Subject: [PATCH] Refactor more bits into the base `ota` component --- esphome/components/esphome/ota/__init__.py | 112 +++++------------- esphome/components/esphome/ota/automation.h | 69 ----------- .../components/esphome/ota/ota_esphome.cpp | 19 ++- esphome/components/esphome/ota/ota_esphome.h | 16 +-- esphome/components/ota/__init__.py | 52 +++++++- esphome/components/ota/ota_backend.h | 81 +++++++++++++ 6 files changed, 175 insertions(+), 174 deletions(-) delete mode 100644 esphome/components/esphome/ota/automation.h diff --git a/esphome/components/esphome/ota/__init__.py b/esphome/components/esphome/ota/__init__.py index d94e2ab921..e6f99ba18a 100644 --- a/esphome/components/esphome/ota/__init__.py +++ b/esphome/components/esphome/ota/__init__.py @@ -2,6 +2,7 @@ from esphome.cpp_generator import RawExpression import esphome.codegen as cg import esphome.config_validation as cv from esphome import automation +from esphome.components import ota from esphome.const import ( CONF_ID, CONF_NUM_ATTEMPTS, @@ -21,83 +22,34 @@ CODEOWNERS = ["@esphome/core"] AUTO_LOAD = ["md5", "socket"] DEPENDENCIES = ["network"] -CONF_ON_BEGIN = "on_begin" -CONF_ON_END = "on_end" -CONF_ON_ERROR = "on_error" -CONF_ON_PROGRESS = "on_progress" -CONF_ON_STATE_CHANGE = "on_state_change" - esphome = cg.esphome_ns.namespace("esphome") - -ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", cg.Component) -ESPHomeOTAEndTrigger = esphome.class_( - "ESPHomeOTAEndTrigger", automation.Trigger.template() -) -ESPHomeOTAErrorTrigger = esphome.class_( - "ESPHomeOTAErrorTrigger", automation.Trigger.template() -) -ESPHomeOTAProgressTrigger = esphome.class_( - "ESPHomeOTAProgressTrigger", automation.Trigger.template() -) -ESPHomeOTAStartTrigger = esphome.class_( - "ESPHomeOTAStartTrigger", automation.Trigger.template() -) -ESPHomeOTAStateChangeTrigger = esphome.class_( - "ESPHomeOTAStateChangeTrigger", automation.Trigger.template() -) - -ESPHomeOTAState = esphome.enum("ESPHomeOTAState") +ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", ota.OTAComponent) -CONFIG_SCHEMA = cv.Schema( - { - cv.GenerateID(): cv.declare_id(ESPHomeOTAComponent), - cv.Optional(CONF_SAFE_MODE, default=True): cv.boolean, - cv.Optional(CONF_VERSION, default=2): cv.one_of(1, 2, int=True), - cv.SplitDefault( - CONF_PORT, - esp8266=8266, - esp32=3232, - rp2040=2040, - bk72xx=8892, - rtl87xx=8892, - ): cv.port, - cv.Optional(CONF_PASSWORD): cv.string, - cv.Optional( - CONF_REBOOT_TIMEOUT, default="5min" - ): cv.positive_time_period_milliseconds, - cv.Optional(CONF_NUM_ATTEMPTS, default="10"): cv.positive_not_null_int, - cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation( - { - cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id( - ESPHomeOTAStateChangeTrigger - ), - } - ), - cv.Optional(CONF_ON_BEGIN): automation.validate_automation( - { - cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ESPHomeOTAStartTrigger), - } - ), - cv.Optional(CONF_ON_END): automation.validate_automation( - { - cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ESPHomeOTAEndTrigger), - } - ), - cv.Optional(CONF_ON_ERROR): automation.validate_automation( - { - cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ESPHomeOTAErrorTrigger), - } - ), - cv.Optional(CONF_ON_PROGRESS): automation.validate_automation( - { - cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id( - ESPHomeOTAProgressTrigger - ), - } - ), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + cv.Schema( + { + cv.GenerateID(): cv.declare_id(ESPHomeOTAComponent), + cv.Optional(CONF_SAFE_MODE, default=True): cv.boolean, + cv.Optional(CONF_VERSION, default=2): cv.one_of(1, 2, int=True), + cv.SplitDefault( + CONF_PORT, + esp8266=8266, + esp32=3232, + rp2040=2040, + bk72xx=8892, + rtl87xx=8892, + ): cv.port, + cv.Optional(CONF_PASSWORD): cv.string, + cv.Optional( + CONF_REBOOT_TIMEOUT, default="5min" + ): cv.positive_time_period_milliseconds, + cv.Optional(CONF_NUM_ATTEMPTS, default="10"): cv.positive_not_null_int, + } + ) + .extend(ota.BASE_OTA_SCHEMA) + .extend(cv.COMPONENT_SCHEMA) +) @coroutine_with_priority(50.0) @@ -128,23 +80,23 @@ async def to_code(config): cg.add_library("Updater", None) use_state_callback = False - for conf in config.get(CONF_ON_STATE_CHANGE, []): + for conf in config.get(ota.CONF_ON_STATE_CHANGE, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) - await automation.build_automation(trigger, [(ESPHomeOTAState, "state")], conf) + await automation.build_automation(trigger, [(ota.OTAState, "state")], conf) use_state_callback = True - for conf in config.get(CONF_ON_BEGIN, []): + for conf in config.get(ota.CONF_ON_BEGIN, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation(trigger, [], conf) use_state_callback = True - for conf in config.get(CONF_ON_PROGRESS, []): + for conf in config.get(ota.CONF_ON_PROGRESS, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation(trigger, [(float, "x")], conf) use_state_callback = True - for conf in config.get(CONF_ON_END, []): + for conf in config.get(ota.CONF_ON_END, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation(trigger, [], conf) use_state_callback = True - for conf in config.get(CONF_ON_ERROR, []): + for conf in config.get(ota.CONF_ON_ERROR, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation(trigger, [(cg.uint8, "x")], conf) use_state_callback = True diff --git a/esphome/components/esphome/ota/automation.h b/esphome/components/esphome/ota/automation.h deleted file mode 100644 index b922b6d77b..0000000000 --- a/esphome/components/esphome/ota/automation.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -#ifdef USE_OTA_STATE_CALLBACK -#include "ota_esphome.h" - -#include "esphome/core/component.h" -#include "esphome/core/defines.h" -#include "esphome/core/automation.h" - -namespace esphome { - -class OTAESPHomeStateChangeTrigger : public Trigger { - public: - explicit OTAESPHomeStateChangeTrigger(OTAESPHomeComponent *parent) { - parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) { - if (!parent->is_failed()) { - return trigger(state); - } - }); - } -}; - -class OTAESPHomeStartTrigger : public Trigger<> { - public: - explicit OTAESPHomeStartTrigger(OTAESPHomeComponent *parent) { - parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) { - if (state == OTA_STARTED && !parent->is_failed()) { - trigger(); - } - }); - } -}; - -class OTAESPHomeProgressTrigger : public Trigger { - public: - explicit OTAESPHomeProgressTrigger(OTAESPHomeComponent *parent) { - parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) { - if (state == OTA_IN_PROGRESS && !parent->is_failed()) { - trigger(progress); - } - }); - } -}; - -class OTAESPHomeEndTrigger : public Trigger<> { - public: - explicit OTAESPHomeEndTrigger(OTAESPHomeComponent *parent) { - parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) { - if (state == OTA_COMPLETED && !parent->is_failed()) { - trigger(); - } - }); - } -}; - -class OTAESPHomeErrorTrigger : public Trigger { - public: - explicit OTAESPHomeErrorTrigger(OTAESPHomeComponent *parent) { - parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) { - if (state == OTA_ERROR && !parent->is_failed()) { - trigger(error); - } - }); - } -}; - -} // namespace esphome - -#endif // USE_OTA_STATE_CALLBACK diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index 1fa3052e78..7b76380fb4 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -131,7 +131,7 @@ void ESPHomeOTAComponent::handle_() { ESP_LOGD(TAG, "Starting OTA update from %s...", this->client_->getpeername().c_str()); this->status_set_warning(); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_STARTED, 0.0f, 0); + this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0); #endif if (!this->readall_(buf, 5)) { @@ -306,7 +306,7 @@ void ESPHomeOTAComponent::handle_() { float percentage = (total * 100.0f) / ota_size; ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_IN_PROGRESS, percentage, 0); + this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0); #endif // feed watchdog and give other tasks a chance to run App.feed_wdt(); @@ -340,7 +340,7 @@ void ESPHomeOTAComponent::handle_() { ESP_LOGI(TAG, "OTA update finished"); this->status_clear_warning(); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_COMPLETED, 100.0f, 0); + this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0); #endif delay(100); // NOLINT App.safe_reboot(); @@ -357,7 +357,7 @@ error: this->status_momentary_error("onerror", 5000); #ifdef USE_OTA_STATE_CALLBACK - this->state_callback_.call(OTA_ERROR, 0.0f, static_cast(error_code)); + this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast(error_code)); #endif } @@ -486,26 +486,23 @@ bool ESPHomeOTAComponent::should_enter_safe_mode(uint8_t num_attempts, uint32_t return false; } } + void ESPHomeOTAComponent::write_rtc_(uint32_t val) { this->rtc_.save(&val); global_preferences->sync(); } + uint32_t ESPHomeOTAComponent::read_rtc_() { uint32_t val; if (!this->rtc_.load(&val)) return 0; return val; } + void ESPHomeOTAComponent::clean_rtc() { this->write_rtc_(0); } + void ESPHomeOTAComponent::on_safe_shutdown() { if (this->has_safe_mode_ && this->read_rtc_() != ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC) this->clean_rtc(); } - -#ifdef USE_OTA_STATE_CALLBACK -void ESPHomeOTAComponent::add_on_state_callback(std::function &&callback) { - this->state_callback_.add(std::move(callback)); -} -#endif - } // namespace esphome diff --git a/esphome/components/esphome/ota/ota_esphome.h b/esphome/components/esphome/ota/ota_esphome.h index bbe3173ade..f230f2b465 100644 --- a/esphome/components/esphome/ota/ota_esphome.h +++ b/esphome/components/esphome/ota/ota_esphome.h @@ -1,17 +1,15 @@ #pragma once -#include "esphome/components/socket/socket.h" -#include "esphome/core/component.h" #include "esphome/core/defines.h" #include "esphome/core/helpers.h" #include "esphome/core/preferences.h" +#include "esphome/components/ota/ota_backend.h" +#include "esphome/components/socket/socket.h" namespace esphome { -enum OTAESPHomeState { OTA_COMPLETED = 0, OTA_STARTED, OTA_IN_PROGRESS, OTA_ERROR }; - /// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. -class ESPHomeOTAComponent : public Component { +class ESPHomeOTAComponent : public ota::OTAComponent { public: ESPHomeOTAComponent(); #ifdef USE_OTA_PASSWORD @@ -27,10 +25,6 @@ class ESPHomeOTAComponent : public Component { void set_safe_mode_pending(const bool &pending); bool get_safe_mode_pending(); -#ifdef USE_OTA_STATE_CALLBACK - void add_on_state_callback(std::function &&callback); -#endif - // ========== INTERNAL METHODS ========== // (In most use cases you won't need these) void setup() override; @@ -70,10 +64,6 @@ class ESPHomeOTAComponent : public Component { static const uint32_t ENTER_SAFE_MODE_MAGIC = 0x5afe5afe; ///< a magic number to indicate that safe mode should be entered on next boot - -#ifdef USE_OTA_STATE_CALLBACK - CallbackManager state_callback_{}; -#endif }; extern ESPHomeOTAComponent *global_ota_component; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/components/ota/__init__.py b/esphome/components/ota/__init__.py index aed0b87f85..719f25082f 100644 --- a/esphome/components/ota/__init__.py +++ b/esphome/components/ota/__init__.py @@ -1,6 +1,8 @@ +import esphome.codegen as cg import esphome.config_validation as cv +from esphome import automation -from esphome.const import CONF_ESPHOME, CONF_OTA, CONF_PLATFORM +from esphome.const import CONF_ESPHOME, CONF_OTA, CONF_PLATFORM, CONF_TRIGGER_ID CODEOWNERS = ["@esphome/core"] AUTO_LOAD = ["md5"] @@ -8,6 +10,24 @@ DEPENDENCIES = ["network"] IS_PLATFORM_COMPONENT = True +CONF_ON_BEGIN = "on_begin" +CONF_ON_END = "on_end" +CONF_ON_ERROR = "on_error" +CONF_ON_PROGRESS = "on_progress" +CONF_ON_STATE_CHANGE = "on_state_change" + + +ota = cg.esphome_ns.namespace("ota") +OTAComponent = ota.class_("OTAComponent", cg.Component) +OTAState = ota.enum("OTAState") +OTAEndTrigger = ota.class_("OTAEndTrigger", automation.Trigger.template()) +OTAErrorTrigger = ota.class_("OTAErrorTrigger", automation.Trigger.template()) +OTAProgressTrigger = ota.class_("OTAProgressTrigger", automation.Trigger.template()) +OTAStartTrigger = ota.class_("OTAStartTrigger", automation.Trigger.template()) +OTAStateChangeTrigger = ota.class_( + "OTAStateChangeTrigger", automation.Trigger.template() +) + def _ota_final_validate(config): if len(config) < 1: @@ -17,3 +37,33 @@ def _ota_final_validate(config): FINAL_VALIDATE_SCHEMA = _ota_final_validate + +BASE_OTA_SCHEMA = cv.Schema( + { + cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStateChangeTrigger), + } + ), + cv.Optional(CONF_ON_BEGIN): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStartTrigger), + } + ), + cv.Optional(CONF_ON_END): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAEndTrigger), + } + ), + cv.Optional(CONF_ON_ERROR): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAErrorTrigger), + } + ), + cv.Optional(CONF_ON_PROGRESS): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAProgressTrigger), + } + ), + } +) diff --git a/esphome/components/ota/ota_backend.h b/esphome/components/ota/ota_backend.h index 471eebb390..9f2a851f5a 100644 --- a/esphome/components/ota/ota_backend.h +++ b/esphome/components/ota/ota_backend.h @@ -1,4 +1,10 @@ #pragma once +#ifdef USE_OTA_STATE_CALLBACK +#include "esphome/core/automation.h" +#include "esphome/core/defines.h" +#endif + +#include "esphome/core/component.h" #include "esphome/core/helpers.h" namespace esphome { @@ -33,6 +39,8 @@ enum OTAResponseTypes { OTA_RESPONSE_ERROR_UNKNOWN = 0xFF, }; +enum OTAState { OTA_COMPLETED = 0, OTA_STARTED, OTA_IN_PROGRESS, OTA_ERROR }; + class OTABackend { public: virtual ~OTABackend() = default; @@ -44,7 +52,80 @@ class OTABackend { virtual bool supports_compression() = 0; }; +class OTAComponent : public Component { +#ifdef USE_OTA_STATE_CALLBACK + public: + void add_on_state_callback(std::function &&callback) { + this->state_callback_.add(std::move(callback)); + } + + protected: + CallbackManager state_callback_{}; +#endif +}; + std::unique_ptr make_ota_backend(); +/// +/// Automations +/// + +#ifdef USE_OTA_STATE_CALLBACK +class OTAStateChangeTrigger : public Trigger { + public: + explicit OTAStateChangeTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (!parent->is_failed()) { + return trigger(state); + } + }); + } +}; + +class OTAStartTrigger : public Trigger<> { + public: + explicit OTAStartTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_STARTED && !parent->is_failed()) { + trigger(); + } + }); + } +}; + +class OTAProgressTrigger : public Trigger { + public: + explicit OTAProgressTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_IN_PROGRESS && !parent->is_failed()) { + trigger(progress); + } + }); + } +}; + +class OTAEndTrigger : public Trigger<> { + public: + explicit OTAEndTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_COMPLETED && !parent->is_failed()) { + trigger(); + } + }); + } +}; + +class OTAErrorTrigger : public Trigger { + public: + explicit OTAErrorTrigger(OTAComponent *parent) { + parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) { + if (state == OTA_ERROR && !parent->is_failed()) { + trigger(error); + } + }); + } +}; +#endif + } // namespace ota } // namespace esphome