1
0
mirror of https://github.com/esphome/esphome.git synced 2025-04-05 18:30:28 +01:00

Refactor more bits into the base ota component

This commit is contained in:
Keith Burzinski 2024-04-21 04:14:07 -05:00
parent af6dffd150
commit 8426084ebf
No known key found for this signature in database
GPG Key ID: 802564C5F0EEFFBE
6 changed files with 175 additions and 174 deletions

View File

@ -2,6 +2,7 @@ from esphome.cpp_generator import RawExpression
import esphome.codegen as cg
import esphome.config_validation as cv
from esphome import automation
from esphome.components import ota
from esphome.const import (
CONF_ID,
CONF_NUM_ATTEMPTS,
@ -21,83 +22,34 @@ CODEOWNERS = ["@esphome/core"]
AUTO_LOAD = ["md5", "socket"]
DEPENDENCIES = ["network"]
CONF_ON_BEGIN = "on_begin"
CONF_ON_END = "on_end"
CONF_ON_ERROR = "on_error"
CONF_ON_PROGRESS = "on_progress"
CONF_ON_STATE_CHANGE = "on_state_change"
esphome = cg.esphome_ns.namespace("esphome")
ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", cg.Component)
ESPHomeOTAEndTrigger = esphome.class_(
"ESPHomeOTAEndTrigger", automation.Trigger.template()
)
ESPHomeOTAErrorTrigger = esphome.class_(
"ESPHomeOTAErrorTrigger", automation.Trigger.template()
)
ESPHomeOTAProgressTrigger = esphome.class_(
"ESPHomeOTAProgressTrigger", automation.Trigger.template()
)
ESPHomeOTAStartTrigger = esphome.class_(
"ESPHomeOTAStartTrigger", automation.Trigger.template()
)
ESPHomeOTAStateChangeTrigger = esphome.class_(
"ESPHomeOTAStateChangeTrigger", automation.Trigger.template()
)
ESPHomeOTAState = esphome.enum("ESPHomeOTAState")
ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", ota.OTAComponent)
CONFIG_SCHEMA = cv.Schema(
{
cv.GenerateID(): cv.declare_id(ESPHomeOTAComponent),
cv.Optional(CONF_SAFE_MODE, default=True): cv.boolean,
cv.Optional(CONF_VERSION, default=2): cv.one_of(1, 2, int=True),
cv.SplitDefault(
CONF_PORT,
esp8266=8266,
esp32=3232,
rp2040=2040,
bk72xx=8892,
rtl87xx=8892,
): cv.port,
cv.Optional(CONF_PASSWORD): cv.string,
cv.Optional(
CONF_REBOOT_TIMEOUT, default="5min"
): cv.positive_time_period_milliseconds,
cv.Optional(CONF_NUM_ATTEMPTS, default="10"): cv.positive_not_null_int,
cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(
ESPHomeOTAStateChangeTrigger
),
}
),
cv.Optional(CONF_ON_BEGIN): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ESPHomeOTAStartTrigger),
}
),
cv.Optional(CONF_ON_END): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ESPHomeOTAEndTrigger),
}
),
cv.Optional(CONF_ON_ERROR): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ESPHomeOTAErrorTrigger),
}
),
cv.Optional(CONF_ON_PROGRESS): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(
ESPHomeOTAProgressTrigger
),
}
),
}
).extend(cv.COMPONENT_SCHEMA)
CONFIG_SCHEMA = (
cv.Schema(
{
cv.GenerateID(): cv.declare_id(ESPHomeOTAComponent),
cv.Optional(CONF_SAFE_MODE, default=True): cv.boolean,
cv.Optional(CONF_VERSION, default=2): cv.one_of(1, 2, int=True),
cv.SplitDefault(
CONF_PORT,
esp8266=8266,
esp32=3232,
rp2040=2040,
bk72xx=8892,
rtl87xx=8892,
): cv.port,
cv.Optional(CONF_PASSWORD): cv.string,
cv.Optional(
CONF_REBOOT_TIMEOUT, default="5min"
): cv.positive_time_period_milliseconds,
cv.Optional(CONF_NUM_ATTEMPTS, default="10"): cv.positive_not_null_int,
}
)
.extend(ota.BASE_OTA_SCHEMA)
.extend(cv.COMPONENT_SCHEMA)
)
@coroutine_with_priority(50.0)
@ -128,23 +80,23 @@ async def to_code(config):
cg.add_library("Updater", None)
use_state_callback = False
for conf in config.get(CONF_ON_STATE_CHANGE, []):
for conf in config.get(ota.CONF_ON_STATE_CHANGE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(ESPHomeOTAState, "state")], conf)
await automation.build_automation(trigger, [(ota.OTAState, "state")], conf)
use_state_callback = True
for conf in config.get(CONF_ON_BEGIN, []):
for conf in config.get(ota.CONF_ON_BEGIN, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf)
use_state_callback = True
for conf in config.get(CONF_ON_PROGRESS, []):
for conf in config.get(ota.CONF_ON_PROGRESS, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(float, "x")], conf)
use_state_callback = True
for conf in config.get(CONF_ON_END, []):
for conf in config.get(ota.CONF_ON_END, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf)
use_state_callback = True
for conf in config.get(CONF_ON_ERROR, []):
for conf in config.get(ota.CONF_ON_ERROR, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(cg.uint8, "x")], conf)
use_state_callback = True

View File

@ -1,69 +0,0 @@
#pragma once
#ifdef USE_OTA_STATE_CALLBACK
#include "ota_esphome.h"
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/automation.h"
namespace esphome {
class OTAESPHomeStateChangeTrigger : public Trigger<OTAESPHomeState> {
public:
explicit OTAESPHomeStateChangeTrigger(OTAESPHomeComponent *parent) {
parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) {
if (!parent->is_failed()) {
return trigger(state);
}
});
}
};
class OTAESPHomeStartTrigger : public Trigger<> {
public:
explicit OTAESPHomeStartTrigger(OTAESPHomeComponent *parent) {
parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) {
if (state == OTA_STARTED && !parent->is_failed()) {
trigger();
}
});
}
};
class OTAESPHomeProgressTrigger : public Trigger<float> {
public:
explicit OTAESPHomeProgressTrigger(OTAESPHomeComponent *parent) {
parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) {
if (state == OTA_IN_PROGRESS && !parent->is_failed()) {
trigger(progress);
}
});
}
};
class OTAESPHomeEndTrigger : public Trigger<> {
public:
explicit OTAESPHomeEndTrigger(OTAESPHomeComponent *parent) {
parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) {
if (state == OTA_COMPLETED && !parent->is_failed()) {
trigger();
}
});
}
};
class OTAESPHomeErrorTrigger : public Trigger<uint8_t> {
public:
explicit OTAESPHomeErrorTrigger(OTAESPHomeComponent *parent) {
parent->add_on_state_callback([this, parent](OTAESPHomeState state, float progress, uint8_t error) {
if (state == OTA_ERROR && !parent->is_failed()) {
trigger(error);
}
});
}
};
} // namespace esphome
#endif // USE_OTA_STATE_CALLBACK

View File

@ -131,7 +131,7 @@ void ESPHomeOTAComponent::handle_() {
ESP_LOGD(TAG, "Starting OTA update from %s...", this->client_->getpeername().c_str());
this->status_set_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(OTA_STARTED, 0.0f, 0);
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
#endif
if (!this->readall_(buf, 5)) {
@ -306,7 +306,7 @@ void ESPHomeOTAComponent::handle_() {
float percentage = (total * 100.0f) / ota_size;
ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage);
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(OTA_IN_PROGRESS, percentage, 0);
this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0);
#endif
// feed watchdog and give other tasks a chance to run
App.feed_wdt();
@ -340,7 +340,7 @@ void ESPHomeOTAComponent::handle_() {
ESP_LOGI(TAG, "OTA update finished");
this->status_clear_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(OTA_COMPLETED, 100.0f, 0);
this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0);
#endif
delay(100); // NOLINT
App.safe_reboot();
@ -357,7 +357,7 @@ error:
this->status_momentary_error("onerror", 5000);
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif
}
@ -486,26 +486,23 @@ bool ESPHomeOTAComponent::should_enter_safe_mode(uint8_t num_attempts, uint32_t
return false;
}
}
void ESPHomeOTAComponent::write_rtc_(uint32_t val) {
this->rtc_.save(&val);
global_preferences->sync();
}
uint32_t ESPHomeOTAComponent::read_rtc_() {
uint32_t val;
if (!this->rtc_.load(&val))
return 0;
return val;
}
void ESPHomeOTAComponent::clean_rtc() { this->write_rtc_(0); }
void ESPHomeOTAComponent::on_safe_shutdown() {
if (this->has_safe_mode_ && this->read_rtc_() != ESPHomeOTAComponent::ENTER_SAFE_MODE_MAGIC)
this->clean_rtc();
}
#ifdef USE_OTA_STATE_CALLBACK
void ESPHomeOTAComponent::add_on_state_callback(std::function<void(OTAESPHomeState, float, uint8_t)> &&callback) {
this->state_callback_.add(std::move(callback));
}
#endif
} // namespace esphome

View File

@ -1,17 +1,15 @@
#pragma once
#include "esphome/components/socket/socket.h"
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/helpers.h"
#include "esphome/core/preferences.h"
#include "esphome/components/ota/ota_backend.h"
#include "esphome/components/socket/socket.h"
namespace esphome {
enum OTAESPHomeState { OTA_COMPLETED = 0, OTA_STARTED, OTA_IN_PROGRESS, OTA_ERROR };
/// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA.
class ESPHomeOTAComponent : public Component {
class ESPHomeOTAComponent : public ota::OTAComponent {
public:
ESPHomeOTAComponent();
#ifdef USE_OTA_PASSWORD
@ -27,10 +25,6 @@ class ESPHomeOTAComponent : public Component {
void set_safe_mode_pending(const bool &pending);
bool get_safe_mode_pending();
#ifdef USE_OTA_STATE_CALLBACK
void add_on_state_callback(std::function<void(OTAESPHomeState, float, uint8_t)> &&callback);
#endif
// ========== INTERNAL METHODS ==========
// (In most use cases you won't need these)
void setup() override;
@ -70,10 +64,6 @@ class ESPHomeOTAComponent : public Component {
static const uint32_t ENTER_SAFE_MODE_MAGIC =
0x5afe5afe; ///< a magic number to indicate that safe mode should be entered on next boot
#ifdef USE_OTA_STATE_CALLBACK
CallbackManager<void(OTAESPHomeState, float, uint8_t)> state_callback_{};
#endif
};
extern ESPHomeOTAComponent *global_ota_component; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -1,6 +1,8 @@
import esphome.codegen as cg
import esphome.config_validation as cv
from esphome import automation
from esphome.const import CONF_ESPHOME, CONF_OTA, CONF_PLATFORM
from esphome.const import CONF_ESPHOME, CONF_OTA, CONF_PLATFORM, CONF_TRIGGER_ID
CODEOWNERS = ["@esphome/core"]
AUTO_LOAD = ["md5"]
@ -8,6 +10,24 @@ DEPENDENCIES = ["network"]
IS_PLATFORM_COMPONENT = True
CONF_ON_BEGIN = "on_begin"
CONF_ON_END = "on_end"
CONF_ON_ERROR = "on_error"
CONF_ON_PROGRESS = "on_progress"
CONF_ON_STATE_CHANGE = "on_state_change"
ota = cg.esphome_ns.namespace("ota")
OTAComponent = ota.class_("OTAComponent", cg.Component)
OTAState = ota.enum("OTAState")
OTAEndTrigger = ota.class_("OTAEndTrigger", automation.Trigger.template())
OTAErrorTrigger = ota.class_("OTAErrorTrigger", automation.Trigger.template())
OTAProgressTrigger = ota.class_("OTAProgressTrigger", automation.Trigger.template())
OTAStartTrigger = ota.class_("OTAStartTrigger", automation.Trigger.template())
OTAStateChangeTrigger = ota.class_(
"OTAStateChangeTrigger", automation.Trigger.template()
)
def _ota_final_validate(config):
if len(config) < 1:
@ -17,3 +37,33 @@ def _ota_final_validate(config):
FINAL_VALIDATE_SCHEMA = _ota_final_validate
BASE_OTA_SCHEMA = cv.Schema(
{
cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStateChangeTrigger),
}
),
cv.Optional(CONF_ON_BEGIN): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAStartTrigger),
}
),
cv.Optional(CONF_ON_END): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAEndTrigger),
}
),
cv.Optional(CONF_ON_ERROR): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAErrorTrigger),
}
),
cv.Optional(CONF_ON_PROGRESS): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(OTAProgressTrigger),
}
),
}
)

View File

@ -1,4 +1,10 @@
#pragma once
#ifdef USE_OTA_STATE_CALLBACK
#include "esphome/core/automation.h"
#include "esphome/core/defines.h"
#endif
#include "esphome/core/component.h"
#include "esphome/core/helpers.h"
namespace esphome {
@ -33,6 +39,8 @@ enum OTAResponseTypes {
OTA_RESPONSE_ERROR_UNKNOWN = 0xFF,
};
enum OTAState { OTA_COMPLETED = 0, OTA_STARTED, OTA_IN_PROGRESS, OTA_ERROR };
class OTABackend {
public:
virtual ~OTABackend() = default;
@ -44,7 +52,80 @@ class OTABackend {
virtual bool supports_compression() = 0;
};
class OTAComponent : public Component {
#ifdef USE_OTA_STATE_CALLBACK
public:
void add_on_state_callback(std::function<void(ota::OTAState, float, uint8_t)> &&callback) {
this->state_callback_.add(std::move(callback));
}
protected:
CallbackManager<void(ota::OTAState, float, uint8_t)> state_callback_{};
#endif
};
std::unique_ptr<ota::OTABackend> make_ota_backend();
///
/// Automations
///
#ifdef USE_OTA_STATE_CALLBACK
class OTAStateChangeTrigger : public Trigger<OTAState> {
public:
explicit OTAStateChangeTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (!parent->is_failed()) {
return trigger(state);
}
});
}
};
class OTAStartTrigger : public Trigger<> {
public:
explicit OTAStartTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_STARTED && !parent->is_failed()) {
trigger();
}
});
}
};
class OTAProgressTrigger : public Trigger<float> {
public:
explicit OTAProgressTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_IN_PROGRESS && !parent->is_failed()) {
trigger(progress);
}
});
}
};
class OTAEndTrigger : public Trigger<> {
public:
explicit OTAEndTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_COMPLETED && !parent->is_failed()) {
trigger();
}
});
}
};
class OTAErrorTrigger : public Trigger<uint8_t> {
public:
explicit OTAErrorTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_ERROR && !parent->is_failed()) {
trigger(error);
}
});
}
};
#endif
} // namespace ota
} // namespace esphome