mirror of https://github.com/esphome/esphome.git synced 2025-10-15 08:13:51 +01:00

Merge branch 'dev' into jesserockz-2025-457

Jesse Hills
2025-10-01 11:12:55 +13:00
committed by GitHub
53 changed files with 1896 additions and 414 deletions


@@ -102,12 +102,12 @@ jobs:
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
- name: Log in to docker hub
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Log in to the GitHub container registry
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -182,13 +182,13 @@ jobs:
- name: Log in to docker hub
if: matrix.registry == 'dockerhub'
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Log in to the GitHub container registry
if: matrix.registry == 'ghcr'
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: ghcr.io
username: ${{ github.actor }}


@@ -66,6 +66,8 @@ CONF_BATCH_DELAY = "batch_delay"
CONF_CUSTOM_SERVICES = "custom_services"
CONF_HOMEASSISTANT_SERVICES = "homeassistant_services"
CONF_HOMEASSISTANT_STATES = "homeassistant_states"
CONF_LISTEN_BACKLOG = "listen_backlog"
CONF_MAX_CONNECTIONS = "max_connections"
def validate_encryption_key(value):
@@ -165,6 +167,29 @@ CONFIG_SCHEMA = cv.All(
cv.Optional(CONF_ON_CLIENT_DISCONNECTED): automation.validate_automation(
single=True
),
# Connection limits to prevent memory exhaustion on resource-constrained devices
# Each connection uses ~500-1000 bytes of RAM plus system resources
# Platform defaults based on available RAM and network stack implementation:
cv.SplitDefault(
CONF_LISTEN_BACKLOG,
esp8266=1, # Limited RAM (~40KB free), LWIP raw sockets
esp32=4, # More RAM (520KB), BSD sockets
rp2040=1, # Limited RAM (264KB), LWIP raw sockets like ESP8266
bk72xx=4, # Moderate RAM, BSD-style sockets
rtl87xx=4, # Moderate RAM, BSD-style sockets
host=4, # Abundant resources
ln882x=4, # Moderate RAM
): cv.int_range(min=1, max=10),
cv.SplitDefault(
CONF_MAX_CONNECTIONS,
esp8266=4, # ~40KB free RAM, each connection uses ~500-1000 bytes
esp32=8, # 520KB RAM available
rp2040=4, # 264KB RAM but LWIP constraints
bk72xx=8, # Moderate RAM
rtl87xx=8, # Moderate RAM
host=8, # Abundant resources
ln882x=8, # Moderate RAM
): cv.int_range(min=1, max=20),
}
).extend(cv.COMPONENT_SCHEMA),
cv.rename_key(CONF_SERVICES, CONF_ACTIONS),
@@ -183,6 +208,10 @@ async def to_code(config):
cg.add(var.set_password(config[CONF_PASSWORD]))
cg.add(var.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT]))
cg.add(var.set_batch_delay(config[CONF_BATCH_DELAY]))
if CONF_LISTEN_BACKLOG in config:
cg.add(var.set_listen_backlog(config[CONF_LISTEN_BACKLOG]))
if CONF_MAX_CONNECTIONS in config:
cg.add(var.set_max_connections(config[CONF_MAX_CONNECTIONS]))
# Set USE_API_SERVICES if any services are enabled
if config.get(CONF_ACTIONS) or config[CONF_CUSTOM_SERVICES]:


@@ -90,7 +90,7 @@ void APIServer::setup() {
return;
}
err = this->socket_->listen(4);
err = this->socket_->listen(this->listen_backlog_);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed();
@@ -143,9 +143,19 @@ void APIServer::loop() {
while (true) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
auto sock = this->socket_->accept_loop_monitored((struct sockaddr *) &source_addr, &addr_len);
if (!sock)
break;
// Check if we're at the connection limit
if (this->clients_.size() >= this->max_connections_) {
ESP_LOGW(TAG, "Max connections (%d), rejecting %s", this->max_connections_, sock->getpeername().c_str());
// Immediately close - socket destructor will handle cleanup
sock.reset();
continue;
}
ESP_LOGD(TAG, "Accept %s", sock->getpeername().c_str());
auto *conn = new APIConnection(std::move(sock), this);
@@ -209,8 +219,10 @@ void APIServer::loop() {
void APIServer::dump_config() {
ESP_LOGCONFIG(TAG,
"Server:\n"
" Address: %s:%u",
network::get_use_address().c_str(), this->port_);
" Address: %s:%u\n"
" Listen backlog: %u\n"
" Max connections: %u",
network::get_use_address().c_str(), this->port_, this->listen_backlog_, this->max_connections_);
#ifdef USE_API_NOISE
ESP_LOGCONFIG(TAG, " Noise encryption: %s", YESNO(this->noise_ctx_->has_psk()));
if (!this->noise_ctx_->has_psk()) {
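As an aside for readers tuning these options: listen_backlog bounds how many connections the TCP stack will queue before the event loop gets around to accept(), while max_connections bounds how many clients the server keeps after accepting; anything above the cap is accepted and immediately closed, as in the loop above. A minimal sketch of the distinction with plain BSD sockets (a hypothetical standalone program, not part of this commit):

// Sketch only: backlog limits pending, un-accepted connections;
// the application-level cap limits clients kept after accept().
#include <cstddef>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <vector>

int main() {
  int srv = socket(AF_INET, SOCK_STREAM, 0);
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_port = htons(6053);  // ESPHome native API port
  addr.sin_addr.s_addr = INADDR_ANY;
  bind(srv, reinterpret_cast<sockaddr *>(&addr), sizeof(addr));
  listen(srv, 4);  // backlog: at most ~4 connections queued before accept()

  std::vector<int> clients;
  const std::size_t max_connections = 8;  // application-level cap
  while (true) {
    int c = accept(srv, nullptr, nullptr);
    if (c < 0)
      continue;
    if (clients.size() >= max_connections) {
      close(c);  // over the cap: accept, then drop immediately
      continue;
    }
    clients.push_back(c);
  }
}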


@@ -45,6 +45,8 @@ class APIServer : public Component, public Controller {
void set_reboot_timeout(uint32_t reboot_timeout);
void set_batch_delay(uint16_t batch_delay);
uint16_t get_batch_delay() const { return batch_delay_; }
void set_listen_backlog(uint8_t listen_backlog) { this->listen_backlog_ = listen_backlog; }
void set_max_connections(uint8_t max_connections) { this->max_connections_ = max_connections; }
// Get reference to shared buffer for API connections
std::vector<uint8_t> &get_shared_buffer_ref() { return shared_write_buffer_; }
@@ -198,8 +200,12 @@ class APIServer : public Component, public Controller {
// Group smaller types together
uint16_t port_{6053};
uint16_t batch_delay_{100};
// Connection limits - these defaults will be overridden by config values
// from cv.SplitDefault in __init__.py which sets platform-specific defaults
uint8_t listen_backlog_{4};
uint8_t max_connections_{8};
bool shutting_down_ = false;
// 5 bytes used, 3 bytes padding
// 7 bytes used, 1 byte padding
#ifdef USE_API_NOISE
std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>();


@@ -21,8 +21,8 @@ void Canbus::dump_config() {
}
}
void Canbus::send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request,
const std::vector<uint8_t> &data) {
canbus::Error Canbus::send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request,
const std::vector<uint8_t> &data) {
struct CanFrame can_message;
uint8_t size = static_cast<uint8_t>(data.size());
@@ -45,13 +45,15 @@ void Canbus::send_data(uint32_t can_id, bool use_extended_id, bool remote_transm
ESP_LOGVV(TAG, " data[%d]=%02x", i, can_message.data[i]);
}
if (this->send_message(&can_message) != canbus::ERROR_OK) {
canbus::Error error = this->send_message(&can_message);
if (error != canbus::ERROR_OK) {
if (use_extended_id) {
ESP_LOGW(TAG, "send to extended id=0x%08" PRIx32 " failed!", can_id);
ESP_LOGW(TAG, "send to extended id=0x%08" PRIx32 " failed with error %d!", can_id, error);
} else {
ESP_LOGW(TAG, "send to standard id=0x%03" PRIx32 " failed!", can_id);
ESP_LOGW(TAG, "send to standard id=0x%03" PRIx32 " failed with error %d!", can_id, error);
}
}
return error;
}
void Canbus::add_trigger(CanbusTrigger *trigger) {


@@ -70,11 +70,11 @@ class Canbus : public Component {
float get_setup_priority() const override { return setup_priority::HARDWARE; }
void loop() override;
void send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request,
const std::vector<uint8_t> &data);
void send_data(uint32_t can_id, bool use_extended_id, const std::vector<uint8_t> &data) {
canbus::Error send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request,
const std::vector<uint8_t> &data);
canbus::Error send_data(uint32_t can_id, bool use_extended_id, const std::vector<uint8_t> &data) {
// for backwards compatibility only
this->send_data(can_id, use_extended_id, false, data);
return this->send_data(can_id, use_extended_id, false, data);
}
void set_can_id(uint32_t can_id) { this->can_id_ = can_id; }
void set_use_extended_id(bool use_extended_id) { this->use_extended_id_ = use_extended_id; }
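Since send_data() now propagates canbus::Error, callers can react to transmit failures instead of relying on the component's warning log alone. A caller-side sketch; the error values mirror those used in this diff, while the stub transmitter and retry helper are purely illustrative:

#include <cstdint>
#include <vector>

// Values mirroring canbus::Error as used in this diff; the transmit stub is a stand-in.
enum Error : uint8_t { ERROR_OK = 0, ERROR_FAILTX, ERROR_ALLTXBUSY };

Error send_data(uint32_t can_id, bool use_extended_id, const std::vector<uint8_t> &data) {
  (void) can_id;
  (void) use_extended_id;
  (void) data;
  return ERROR_ALLTXBUSY;  // pretend every MCP2515 TX buffer is busy
}

bool send_with_retry(uint32_t can_id, const std::vector<uint8_t> &data, int attempts) {
  while (attempts-- > 0) {
    if (send_data(can_id, false, data) == ERROR_OK)
      return true;  // frame handed to the controller
  }
  return false;  // still busy after retries; the caller decides how to escalate
}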


@@ -197,7 +197,8 @@ CONFIG_SCHEMA = cv.All(
cv.Optional(CONF_ESP32_EXT1_WAKEUP): cv.All(
cv.only_on_esp32,
esp32.only_on_variant(
unsupported=[VARIANT_ESP32C3], msg_prefix="Wakeup from ext1"
unsupported=[VARIANT_ESP32C2, VARIANT_ESP32C3],
msg_prefix="Wakeup from ext1",
),
cv.Schema(
{
@@ -214,7 +215,13 @@ CONFIG_SCHEMA = cv.All(
cv.Optional(CONF_TOUCH_WAKEUP): cv.All(
cv.only_on_esp32,
esp32.only_on_variant(
unsupported=[VARIANT_ESP32C3], msg_prefix="Wakeup from touch"
unsupported=[
VARIANT_ESP32C2,
VARIANT_ESP32C3,
VARIANT_ESP32C6,
VARIANT_ESP32H2,
],
msg_prefix="Wakeup from touch",
),
cv.boolean,
),


@@ -34,7 +34,7 @@ enum WakeupPinMode {
WAKEUP_PIN_MODE_INVERT_WAKEUP,
};
#if defined(USE_ESP32) && !defined(USE_ESP32_VARIANT_ESP32C3)
#if defined(USE_ESP32) && !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3)
struct Ext1Wakeup {
uint64_t mask;
esp_sleep_ext1_wakeup_mode_t wakeup_mode;
@@ -50,7 +50,7 @@ struct WakeupCauseToRunDuration {
uint32_t gpio_cause;
};
#endif
#endif // USE_ESP32
template<typename... Ts> class EnterDeepSleepAction;
@@ -73,20 +73,22 @@ class DeepSleepComponent : public Component {
void set_wakeup_pin(InternalGPIOPin *pin) { this->wakeup_pin_ = pin; }
void set_wakeup_pin_mode(WakeupPinMode wakeup_pin_mode);
#endif
#endif // USE_ESP32
#if defined(USE_ESP32)
#if !defined(USE_ESP32_VARIANT_ESP32C3)
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3)
void set_ext1_wakeup(Ext1Wakeup ext1_wakeup);
void set_touch_wakeup(bool touch_wakeup);
#endif
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \
!defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2)
void set_touch_wakeup(bool touch_wakeup);
#endif
// Set the duration in ms for how long the code should run before entering
// deep sleep mode, according to the cause the ESP32 has woken.
void set_run_duration(WakeupCauseToRunDuration wakeup_cause_to_run_duration);
#endif
#endif // USE_ESP32
/// Set a duration in ms for how long the code should run before entering deep sleep mode.
void set_run_duration(uint32_t time_ms);
@@ -117,13 +119,13 @@ class DeepSleepComponent : public Component {
InternalGPIOPin *wakeup_pin_;
WakeupPinMode wakeup_pin_mode_{WAKEUP_PIN_MODE_IGNORE};
#if !defined(USE_ESP32_VARIANT_ESP32C3)
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3)
optional<Ext1Wakeup> ext1_wakeup_;
#endif
optional<bool> touch_wakeup_;
optional<WakeupCauseToRunDuration> wakeup_cause_to_run_duration_;
#endif
#endif // USE_ESP32
optional<uint32_t> run_duration_;
bool next_enter_deep_sleep_{false};
bool prevent_{false};


@@ -7,6 +7,26 @@
namespace esphome {
namespace deep_sleep {
// Deep Sleep feature support matrix for ESP32 variants:
//
// | Variant | ext0 | ext1 | Touch | GPIO wakeup |
// |-----------|------|------|-------|-------------|
// | ESP32 | ✓ | ✓ | ✓ | |
// | ESP32-S2 | ✓ | ✓ | ✓ | |
// | ESP32-S3 | ✓ | ✓ | ✓ | |
// | ESP32-C2 | | | | ✓ |
// | ESP32-C3 | | | | ✓ |
// | ESP32-C5 | | (✓) | | (✓) |
// | ESP32-C6 | | ✓ | | ✓ |
// | ESP32-H2 | | ✓ | | |
//
// Notes:
// - (✓) = Supported by hardware but not yet implemented in ESPHome
// - ext0: Single pin wakeup using RTC GPIO (esp_sleep_enable_ext0_wakeup)
// - ext1: Multiple pin wakeup (esp_sleep_enable_ext1_wakeup)
// - Touch: Touch pad wakeup (esp_sleep_enable_touchpad_wakeup)
// - GPIO wakeup: GPIO wakeup for non-RTC pins (esp_deep_sleep_enable_gpio_wakeup)
static const char *const TAG = "deep_sleep";
optional<uint32_t> DeepSleepComponent::get_run_duration_() const {
@@ -30,13 +50,13 @@ void DeepSleepComponent::set_wakeup_pin_mode(WakeupPinMode wakeup_pin_mode) {
this->wakeup_pin_mode_ = wakeup_pin_mode;
}
#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6)
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3)
void DeepSleepComponent::set_ext1_wakeup(Ext1Wakeup ext1_wakeup) { this->ext1_wakeup_ = ext1_wakeup; }
#if !defined(USE_ESP32_VARIANT_ESP32H2)
void DeepSleepComponent::set_touch_wakeup(bool touch_wakeup) { this->touch_wakeup_ = touch_wakeup; }
#endif
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \
!defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2)
void DeepSleepComponent::set_touch_wakeup(bool touch_wakeup) { this->touch_wakeup_ = touch_wakeup; }
#endif
void DeepSleepComponent::set_run_duration(WakeupCauseToRunDuration wakeup_cause_to_run_duration) {
@@ -72,9 +92,13 @@ bool DeepSleepComponent::prepare_to_sleep_() {
}
void DeepSleepComponent::deep_sleep_() {
#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2)
// Timer wakeup - all variants support this
if (this->sleep_duration_.has_value())
esp_sleep_enable_timer_wakeup(*this->sleep_duration_);
// Single pin wakeup (ext0) - ESP32, S2, S3 only
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \
!defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2)
if (this->wakeup_pin_ != nullptr) {
const auto gpio_pin = gpio_num_t(this->wakeup_pin_->get_pin());
if (this->wakeup_pin_->get_flags() & gpio::FLAG_PULLUP) {
@@ -95,32 +119,15 @@ void DeepSleepComponent::deep_sleep_() {
}
esp_sleep_enable_ext0_wakeup(gpio_pin, level);
}
if (this->ext1_wakeup_.has_value()) {
esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode);
}
if (this->touch_wakeup_.has_value() && *(this->touch_wakeup_)) {
esp_sleep_enable_touchpad_wakeup();
esp_sleep_pd_config(ESP_PD_DOMAIN_RTC_PERIPH, ESP_PD_OPTION_ON);
}
#endif
#if defined(USE_ESP32_VARIANT_ESP32H2)
if (this->sleep_duration_.has_value())
esp_sleep_enable_timer_wakeup(*this->sleep_duration_);
if (this->ext1_wakeup_.has_value()) {
esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode);
}
#endif
#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6)
if (this->sleep_duration_.has_value())
esp_sleep_enable_timer_wakeup(*this->sleep_duration_);
// GPIO wakeup - C2, C3, C6 only
#if defined(USE_ESP32_VARIANT_ESP32C2) || defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6)
if (this->wakeup_pin_ != nullptr) {
const auto gpio_pin = gpio_num_t(this->wakeup_pin_->get_pin());
if (this->wakeup_pin_->get_flags() && gpio::FLAG_PULLUP) {
if (this->wakeup_pin_->get_flags() & gpio::FLAG_PULLUP) {
gpio_sleep_set_pull_mode(gpio_pin, GPIO_PULLUP_ONLY);
} else if (this->wakeup_pin_->get_flags() && gpio::FLAG_PULLDOWN) {
} else if (this->wakeup_pin_->get_flags() & gpio::FLAG_PULLDOWN) {
gpio_sleep_set_pull_mode(gpio_pin, GPIO_PULLDOWN_ONLY);
}
gpio_sleep_set_direction(gpio_pin, GPIO_MODE_INPUT);
@@ -138,9 +145,26 @@ void DeepSleepComponent::deep_sleep_() {
static_cast<esp_deepsleep_gpio_wake_up_mode_t>(level));
}
#endif
// Multiple pin wakeup (ext1) - All except C2, C3
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3)
if (this->ext1_wakeup_.has_value()) {
esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode);
}
#endif
// Touch wakeup - ESP32, S2, S3 only
#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \
!defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2)
if (this->touch_wakeup_.has_value() && *(this->touch_wakeup_)) {
esp_sleep_enable_touchpad_wakeup();
esp_sleep_pd_config(ESP_PD_DOMAIN_RTC_PERIPH, ESP_PD_OPTION_ON);
}
#endif
esp_deep_sleep_start();
}
} // namespace deep_sleep
} // namespace esphome
#endif
#endif // USE_ESP32


@@ -1,11 +1,13 @@
#include "ota_esphome.h"
#ifdef USE_OTA
#ifdef USE_OTA_PASSWORD
#ifdef USE_OTA_MD5
#include "esphome/components/md5/md5.h"
#endif
#ifdef USE_OTA_SHA256
#include "esphome/components/sha256/sha256.h"
#endif
#endif
#include "esphome/components/network/util.h"
#include "esphome/components/ota/ota_backend.h"
#include "esphome/components/ota/ota_backend_arduino_esp32.h"
@@ -26,9 +28,19 @@ namespace esphome {
static const char *const TAG = "esphome.ota";
static constexpr uint16_t OTA_BLOCK_SIZE = 8192;
static constexpr size_t OTA_BUFFER_SIZE = 1024; // buffer size for OTA data transfer
static constexpr uint32_t OTA_SOCKET_TIMEOUT_HANDSHAKE = 10000; // milliseconds for initial handshake
static constexpr uint32_t OTA_SOCKET_TIMEOUT_DATA = 90000; // milliseconds for data transfer
#ifdef USE_OTA_PASSWORD
#ifdef USE_OTA_MD5
static constexpr size_t MD5_HEX_SIZE = 32; // MD5 hash as hex string (16 bytes * 2)
#endif
#ifdef USE_OTA_SHA256
static constexpr size_t SHA256_HEX_SIZE = 64; // SHA256 hash as hex string (32 bytes * 2)
#endif
#endif // USE_OTA_PASSWORD
void ESPHomeOTAComponent::setup() {
#ifdef USE_OTA_STATE_CALLBACK
ota::register_ota_platform(this);
@@ -69,7 +81,7 @@ void ESPHomeOTAComponent::setup() {
return;
}
err = this->server_->listen(4);
err = this->server_->listen(1); // Only one client at a time
if (err != 0) {
this->log_socket_error_(LOG_STR("listen"));
this->mark_failed();
@@ -112,11 +124,11 @@ static const uint8_t FEATURE_SUPPORTS_SHA256_AUTH = 0x02;
#define ALLOW_OTA_DOWNGRADE_MD5
void ESPHomeOTAComponent::handle_handshake_() {
/// Handle the initial OTA handshake.
/// Handle the OTA handshake and authentication.
///
/// This method is non-blocking and will return immediately if no data is available.
/// It reads all 5 magic bytes (0x6C, 0x26, 0xF7, 0x5C, 0x45) non-blocking
/// before proceeding to handle_data_(). A 10-second timeout is enforced from initial connection.
/// It manages the state machine through connection, magic bytes validation, feature
/// negotiation, and authentication before entering the blocking data transfer phase.
if (this->client_ == nullptr) {
// We already checked server_->ready() in loop(), so we can accept directly
@@ -141,7 +153,8 @@ void ESPHomeOTAComponent::handle_handshake_() {
}
this->log_start_(LOG_STR("handshake"));
this->client_connect_time_ = App.get_loop_component_start_time();
this->magic_buf_pos_ = 0; // Reset magic buffer position
this->handshake_buf_pos_ = 0; // Reset handshake buffer position
this->ota_state_ = OTAState::MAGIC_READ;
}
// Check for handshake timeout
@@ -152,46 +165,99 @@ void ESPHomeOTAComponent::handle_handshake_() {
return;
}
// Try to read remaining magic bytes
if (this->magic_buf_pos_ < 5) {
// Read as many bytes as available
uint8_t bytes_to_read = 5 - this->magic_buf_pos_;
ssize_t read = this->client_->read(this->magic_buf_ + this->magic_buf_pos_, bytes_to_read);
if (read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
return; // No data yet, try again next loop
}
if (read <= 0) {
// Error or connection closed
if (read == -1) {
this->log_socket_error_(LOG_STR("reading magic bytes"));
} else {
ESP_LOGW(TAG, "Remote closed during handshake");
switch (this->ota_state_) {
case OTAState::MAGIC_READ: {
// Try to read remaining magic bytes (5 total)
if (!this->try_read_(5, LOG_STR("read magic"))) {
return;
}
this->cleanup_connection_();
return;
// Validate magic bytes
static const uint8_t MAGIC_BYTES[5] = {0x6C, 0x26, 0xF7, 0x5C, 0x45};
if (memcmp(this->handshake_buf_, MAGIC_BYTES, 5) != 0) {
ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", this->handshake_buf_[0],
this->handshake_buf_[1], this->handshake_buf_[2], this->handshake_buf_[3], this->handshake_buf_[4]);
this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_MAGIC);
return;
}
// Magic bytes valid, move to next state
this->transition_ota_state_(OTAState::MAGIC_ACK);
this->handshake_buf_[0] = ota::OTA_RESPONSE_OK;
this->handshake_buf_[1] = USE_OTA_VERSION;
[[fallthrough]];
}
this->magic_buf_pos_ += read;
}
// Check if we have all 5 magic bytes
if (this->magic_buf_pos_ == 5) {
// Validate magic bytes
static const uint8_t MAGIC_BYTES[5] = {0x6C, 0x26, 0xF7, 0x5C, 0x45};
if (memcmp(this->magic_buf_, MAGIC_BYTES, 5) != 0) {
ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", this->magic_buf_[0],
this->magic_buf_[1], this->magic_buf_[2], this->magic_buf_[3], this->magic_buf_[4]);
// Send error response (non-blocking, best effort)
uint8_t error = static_cast<uint8_t>(ota::OTA_RESPONSE_ERROR_MAGIC);
this->client_->write(&error, 1);
this->cleanup_connection_();
return;
case OTAState::MAGIC_ACK: {
// Send OK and version - 2 bytes
if (!this->try_write_(2, LOG_STR("ack magic"))) {
return;
}
// All bytes sent, create backend and move to next state
this->backend_ = ota::make_ota_backend();
this->transition_ota_state_(OTAState::FEATURE_READ);
[[fallthrough]];
}
// All 5 magic bytes are valid, continue with data handling
this->handle_data_();
case OTAState::FEATURE_READ: {
// Read features - 1 byte
if (!this->try_read_(1, LOG_STR("read feature"))) {
return;
}
this->ota_features_ = this->handshake_buf_[0];
ESP_LOGV(TAG, "Features: 0x%02X", this->ota_features_);
this->transition_ota_state_(OTAState::FEATURE_ACK);
this->handshake_buf_[0] =
((this->ota_features_ & FEATURE_SUPPORTS_COMPRESSION) != 0 && this->backend_->supports_compression())
? ota::OTA_RESPONSE_SUPPORTS_COMPRESSION
: ota::OTA_RESPONSE_HEADER_OK;
[[fallthrough]];
}
case OTAState::FEATURE_ACK: {
// Acknowledge header - 1 byte
if (!this->try_write_(1, LOG_STR("ack feature"))) {
return;
}
#ifdef USE_OTA_PASSWORD
// If password is set, move to auth phase
if (!this->password_.empty()) {
this->transition_ota_state_(OTAState::AUTH_SEND);
} else
#endif
{
// No password, move directly to data phase
this->transition_ota_state_(OTAState::DATA);
}
[[fallthrough]];
}
#ifdef USE_OTA_PASSWORD
case OTAState::AUTH_SEND: {
// Non-blocking authentication send
if (!this->handle_auth_send_()) {
return;
}
this->transition_ota_state_(OTAState::AUTH_READ);
[[fallthrough]];
}
case OTAState::AUTH_READ: {
// Non-blocking authentication read & verify
if (!this->handle_auth_read_()) {
return;
}
this->transition_ota_state_(OTAState::DATA);
[[fallthrough]];
}
#endif
case OTAState::DATA:
this->handle_data_();
return;
default:
break;
}
}
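The pattern behind try_read_() and try_write_() is to accumulate into a fixed buffer across loop() iterations and only advance the state machine once the full byte count has been transferred. A simplified model of the read side, using a raw non-blocking file descriptor rather than ESPHome's socket wrapper:

#include <cstddef>
#include <cstdint>
#include <sys/types.h>
#include <unistd.h>

// Simplified model of try_read_(): buffer up to `want` bytes across multiple
// calls; true means "all bytes present", false means "call again next loop".
// A real implementation also distinguishes EAGAIN from hard errors and remote
// close, tearing the connection down for the latter two.
struct NonBlockingReader {
  int fd;  // socket already switched to non-blocking mode
  uint8_t buf[64];
  std::size_t pos = 0;

  bool try_read(std::size_t want) {
    ssize_t n = ::read(fd, buf + pos, want - pos);
    if (n <= 0)
      return false;  // EAGAIN/EWOULDBLOCK (or error): keep `pos` for the next attempt
    pos += static_cast<std::size_t>(n);
    return pos >= want;
  }
};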
@@ -199,114 +265,21 @@ void ESPHomeOTAComponent::handle_data_() {
/// Handle the OTA data transfer and update process.
///
/// This method is blocking and will not return until the OTA update completes,
/// fails, or times out. It handles authentication, receives the firmware data,
/// writes it to flash, and reboots on success.
/// fails, or times out. It receives the firmware data, writes it to flash,
/// and reboots on success.
///
/// Authentication has already been handled in the non-blocking states AUTH_SEND/AUTH_READ.
ota::OTAResponseTypes error_code = ota::OTA_RESPONSE_ERROR_UNKNOWN;
bool update_started = false;
size_t total = 0;
uint32_t last_progress = 0;
uint8_t buf[1024];
uint8_t buf[OTA_BUFFER_SIZE];
char *sbuf = reinterpret_cast<char *>(buf);
size_t ota_size;
uint8_t ota_features;
std::unique_ptr<ota::OTABackend> backend;
(void) ota_features;
#if USE_OTA_VERSION == 2
size_t size_acknowledged = 0;
#endif
// Send OK and version - 2 bytes
buf[0] = ota::OTA_RESPONSE_OK;
buf[1] = USE_OTA_VERSION;
this->writeall_(buf, 2);
backend = ota::make_ota_backend();
// Read features - 1 byte
if (!this->readall_(buf, 1)) {
this->log_read_error_(LOG_STR("features"));
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
ota_features = buf[0]; // NOLINT
ESP_LOGV(TAG, "Features: 0x%02X", ota_features);
// Acknowledge header - 1 byte
buf[0] = ota::OTA_RESPONSE_HEADER_OK;
if ((ota_features & FEATURE_SUPPORTS_COMPRESSION) != 0 && backend->supports_compression()) {
buf[0] = ota::OTA_RESPONSE_SUPPORTS_COMPRESSION;
}
this->writeall_(buf, 1);
#ifdef USE_OTA_PASSWORD
if (!this->password_.empty()) {
bool auth_success = false;
#ifdef USE_OTA_SHA256
// SECURITY HARDENING: Prefer SHA256 authentication on platforms that support it.
//
// This is a hardening measure to prevent future downgrade attacks where an attacker
// could force the use of MD5 authentication by manipulating the feature flags.
//
// While MD5 is currently still acceptable for our OTA authentication use case
// (where the password is a shared secret and we're only authenticating, not
// encrypting), at some point in the future MD5 will likely become so weak that
// it could be practically attacked.
//
// We enforce SHA256 now on capable platforms because:
// 1. We can't retroactively update device firmware in the field
// 2. Clients (like esphome CLI) can always be updated to support SHA256
// 3. This prevents any possibility of downgrade attacks in the future
//
// Devices that don't support SHA256 (due to platform limitations) will
// continue to use MD5 as their only option (see #else branch below).
bool client_supports_sha256 = (ota_features & FEATURE_SUPPORTS_SHA256_AUTH) != 0;
#ifdef ALLOW_OTA_DOWNGRADE_MD5
// Temporary compatibility mode: Allow MD5 for ~3 versions to enable OTA downgrades
// This prevents users from being locked out if they need to downgrade after updating
// TODO: Remove this entire ifdef block in 2026.1.0
if (client_supports_sha256) {
sha256::SHA256 sha_hasher;
auth_success = this->perform_hash_auth_(&sha_hasher, this->password_, ota::OTA_RESPONSE_REQUEST_SHA256_AUTH,
LOG_STR("SHA256"), sbuf);
} else {
#ifdef USE_OTA_MD5
ESP_LOGW(TAG, "Using MD5 auth for compatibility (deprecated)");
md5::MD5Digest md5_hasher;
auth_success =
this->perform_hash_auth_(&md5_hasher, this->password_, ota::OTA_RESPONSE_REQUEST_AUTH, LOG_STR("MD5"), sbuf);
#endif // USE_OTA_MD5
}
#else
// Strict mode: SHA256 required on capable platforms (future default)
if (!client_supports_sha256) {
ESP_LOGW(TAG, "Client requires SHA256");
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
sha256::SHA256 sha_hasher;
auth_success = this->perform_hash_auth_(&sha_hasher, this->password_, ota::OTA_RESPONSE_REQUEST_SHA256_AUTH,
LOG_STR("SHA256"), sbuf);
#endif // ALLOW_OTA_DOWNGRADE_MD5
#else
// Platform only supports MD5 - use it as the only available option
// This is not a security downgrade as the platform cannot support SHA256
#ifdef USE_OTA_MD5
md5::MD5Digest md5_hasher;
auth_success =
this->perform_hash_auth_(&md5_hasher, this->password_, ota::OTA_RESPONSE_REQUEST_AUTH, LOG_STR("MD5"), sbuf);
#endif // USE_OTA_MD5
#endif // USE_OTA_SHA256
if (!auth_success) {
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
}
#endif // USE_OTA_PASSWORD
// Acknowledge auth OK - 1 byte
buf[0] = ota::OTA_RESPONSE_AUTH_OK;
this->writeall_(buf, 1);
@@ -334,7 +307,7 @@ void ESPHomeOTAComponent::handle_data_() {
#endif
// This will block for a few seconds as it locks flash
error_code = backend->begin(ota_size);
error_code = this->backend_->begin(ota_size);
if (error_code != ota::OTA_RESPONSE_OK)
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
update_started = true;
@@ -350,7 +323,7 @@ void ESPHomeOTAComponent::handle_data_() {
}
sbuf[32] = '\0';
ESP_LOGV(TAG, "Update: Binary MD5 is %s", sbuf);
backend->set_update_md5(sbuf);
this->backend_->set_update_md5(sbuf);
// Acknowledge MD5 OK - 1 byte
buf[0] = ota::OTA_RESPONSE_BIN_MD5_OK;
@@ -358,26 +331,24 @@ void ESPHomeOTAComponent::handle_data_() {
while (total < ota_size) {
// TODO: timeout check
size_t requested = std::min(sizeof(buf), ota_size - total);
size_t remaining = ota_size - total;
size_t requested = remaining < OTA_BUFFER_SIZE ? remaining : OTA_BUFFER_SIZE;
ssize_t read = this->client_->read(buf, requested);
if (read == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
if (this->would_block_(errno)) {
this->yield_and_feed_watchdog_();
continue;
}
ESP_LOGW(TAG, "Read error, errno %d", errno);
ESP_LOGW(TAG, "Read err %d", errno);
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} else if (read == 0) {
// $ man recv
// "When a stream socket peer has performed an orderly shutdown, the return value will
// be 0 (the traditional "end-of-file" return)."
ESP_LOGW(TAG, "Remote closed connection");
ESP_LOGW(TAG, "Remote closed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
error_code = backend->write(buf, read);
error_code = this->backend_->write(buf, read);
if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Flash write error, code: %d", error_code);
ESP_LOGW(TAG, "Flash write err %d", error_code);
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
total += read;
@@ -406,9 +377,9 @@ void ESPHomeOTAComponent::handle_data_() {
buf[0] = ota::OTA_RESPONSE_RECEIVE_OK;
this->writeall_(buf, 1);
error_code = backend->end();
error_code = this->backend_->end();
if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Error ending update! code: %d", error_code);
ESP_LOGW(TAG, "End update err %d", error_code);
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
@@ -437,8 +408,8 @@ error:
this->writeall_(buf, 1);
this->cleanup_connection_();
if (backend != nullptr && update_started) {
backend->abort();
if (this->backend_ != nullptr && update_started) {
this->backend_->abort();
}
this->status_momentary_error("onerror", 5000);
@@ -459,12 +430,12 @@ bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) {
ssize_t read = this->client_->read(buf + at, len - at);
if (read == -1) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ESP_LOGW(TAG, "Error reading %d bytes, errno %d", len, errno);
if (!this->would_block_(errno)) {
ESP_LOGW(TAG, "Read err %d bytes, errno %d", len, errno);
return false;
}
} else if (read == 0) {
ESP_LOGW(TAG, "Remote closed connection");
ESP_LOGW(TAG, "Remote closed");
return false;
} else {
at += read;
@@ -486,8 +457,8 @@ bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) {
ssize_t written = this->client_->write(buf + at, len - at);
if (written == -1) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ESP_LOGW(TAG, "Error writing %d bytes, errno %d", len, errno);
if (!this->would_block_(errno)) {
ESP_LOGW(TAG, "Write err %d bytes, errno %d", len, errno);
return false;
}
} else {
@@ -512,11 +483,74 @@ void ESPHomeOTAComponent::log_start_(const LogString *phase) {
ESP_LOGD(TAG, "Starting %s from %s", LOG_STR_ARG(phase), this->client_->getpeername().c_str());
}
void ESPHomeOTAComponent::log_remote_closed_(const LogString *during) {
ESP_LOGW(TAG, "Remote closed at %s", LOG_STR_ARG(during));
}
bool ESPHomeOTAComponent::handle_read_error_(ssize_t read, const LogString *desc) {
if (read == -1 && this->would_block_(errno)) {
return false; // No data yet, try again next loop
}
if (read <= 0) {
read == 0 ? this->log_remote_closed_(desc) : this->log_socket_error_(desc);
this->cleanup_connection_();
return false;
}
return true;
}
bool ESPHomeOTAComponent::handle_write_error_(ssize_t written, const LogString *desc) {
if (written == -1) {
if (this->would_block_(errno)) {
return false; // Try again next loop
}
this->log_socket_error_(desc);
this->cleanup_connection_();
return false;
}
return true;
}
bool ESPHomeOTAComponent::try_read_(size_t to_read, const LogString *desc) {
// Read bytes into handshake buffer, starting at handshake_buf_pos_
size_t bytes_to_read = to_read - this->handshake_buf_pos_;
ssize_t read = this->client_->read(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_read);
if (!this->handle_read_error_(read, desc)) {
return false;
}
this->handshake_buf_pos_ += read;
// Return true only if we have all the requested bytes
return this->handshake_buf_pos_ >= to_read;
}
bool ESPHomeOTAComponent::try_write_(size_t to_write, const LogString *desc) {
// Write bytes from handshake buffer, starting at handshake_buf_pos_
size_t bytes_to_write = to_write - this->handshake_buf_pos_;
ssize_t written = this->client_->write(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_write);
if (!this->handle_write_error_(written, desc)) {
return false;
}
this->handshake_buf_pos_ += written;
// Return true only if we have written all the requested bytes
return this->handshake_buf_pos_ >= to_write;
}
void ESPHomeOTAComponent::cleanup_connection_() {
this->client_->close();
this->client_ = nullptr;
this->client_connect_time_ = 0;
this->magic_buf_pos_ = 0;
this->handshake_buf_pos_ = 0;
this->ota_state_ = OTAState::IDLE;
this->ota_features_ = 0;
this->backend_ = nullptr;
#ifdef USE_OTA_PASSWORD
this->cleanup_auth_();
#endif
}
void ESPHomeOTAComponent::yield_and_feed_watchdog_() {
@@ -525,82 +559,247 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() {
}
#ifdef USE_OTA_PASSWORD
void ESPHomeOTAComponent::log_auth_warning_(const LogString *action, const LogString *hash_name) {
ESP_LOGW(TAG, "Auth: %s %s failed", LOG_STR_ARG(action), LOG_STR_ARG(hash_name));
void ESPHomeOTAComponent::log_auth_warning_(const LogString *msg) { ESP_LOGW(TAG, "Auth: %s", LOG_STR_ARG(msg)); }
bool ESPHomeOTAComponent::select_auth_type_() {
#ifdef USE_OTA_SHA256
bool client_supports_sha256 = (this->ota_features_ & FEATURE_SUPPORTS_SHA256_AUTH) != 0;
#ifdef ALLOW_OTA_DOWNGRADE_MD5
// Allow fallback to MD5 if client doesn't support SHA256
if (client_supports_sha256) {
this->auth_type_ = ota::OTA_RESPONSE_REQUEST_SHA256_AUTH;
return true;
}
#ifdef USE_OTA_MD5
this->log_auth_warning_(LOG_STR("Using deprecated MD5"));
this->auth_type_ = ota::OTA_RESPONSE_REQUEST_AUTH;
return true;
#else
this->log_auth_warning_(LOG_STR("SHA256 required"));
this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID);
return false;
#endif // USE_OTA_MD5
#else // !ALLOW_OTA_DOWNGRADE_MD5
// Require SHA256
if (!client_supports_sha256) {
this->log_auth_warning_(LOG_STR("SHA256 required"));
this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID);
return false;
}
this->auth_type_ = ota::OTA_RESPONSE_REQUEST_SHA256_AUTH;
return true;
#endif // ALLOW_OTA_DOWNGRADE_MD5
#else // !USE_OTA_SHA256
#ifdef USE_OTA_MD5
// Only MD5 available
this->auth_type_ = ota::OTA_RESPONSE_REQUEST_AUTH;
return true;
#else
// No auth methods available
this->log_auth_warning_(LOG_STR("No auth methods available"));
this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID);
return false;
#endif // USE_OTA_MD5
#endif // USE_OTA_SHA256
}
// Non-template function definition to reduce binary size
bool ESPHomeOTAComponent::perform_hash_auth_(HashBase *hasher, const std::string &password, uint8_t auth_request,
const LogString *name, char *buf) {
// Get sizes from the hasher
const size_t hex_size = hasher->get_size() * 2; // Hex is twice the byte size
const size_t nonce_len = hasher->get_size() / 4; // Nonce is 1/4 of hash size in bytes
bool ESPHomeOTAComponent::handle_auth_send_() {
// Initialize auth buffer if not already done
if (!this->auth_buf_) {
// Select auth type based on client capabilities and configuration
if (!this->select_auth_type_()) {
return false;
}
// Use the provided buffer for all hex operations
// Generate nonce with appropriate hasher
bool success = false;
#ifdef USE_OTA_SHA256
if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_SHA256_AUTH) {
sha256::SHA256 sha_hasher;
success = this->prepare_auth_nonce_(&sha_hasher);
}
#endif
#ifdef USE_OTA_MD5
if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_AUTH) {
md5::MD5Digest md5_hasher;
success = this->prepare_auth_nonce_(&md5_hasher);
}
#endif
// Small stack buffer for nonce seed bytes
uint8_t nonce_bytes[8]; // Max 8 bytes (2 x uint32_t for SHA256)
// Send auth request type
this->writeall_(&auth_request, 1);
hasher->init();
// Generate nonce seed bytes using random_bytes
if (!random_bytes(nonce_bytes, nonce_len)) {
this->log_auth_warning_(LOG_STR("Random bytes generation failed"), name);
return false;
if (!success) {
return false;
}
}
hasher->add(nonce_bytes, nonce_len);
hasher->calculate();
// Generate and send nonce
hasher->get_hex(buf);
buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: %s Nonce is %s", LOG_STR_ARG(name), buf);
// Try to write auth_type + nonce
size_t hex_size = this->get_auth_hex_size_();
const size_t to_write = 1 + hex_size;
size_t remaining = to_write - this->auth_buf_pos_;
if (!this->writeall_(reinterpret_cast<uint8_t *>(buf), hex_size)) {
this->log_auth_warning_(LOG_STR("Writing nonce"), name);
ssize_t written = this->client_->write(this->auth_buf_.get() + this->auth_buf_pos_, remaining);
if (!this->handle_write_error_(written, LOG_STR("ack auth"))) {
return false;
}
// Start challenge: password + nonce
hasher->init();
hasher->add(password.c_str(), password.length());
hasher->add(buf, hex_size);
this->auth_buf_pos_ += written;
// Read cnonce and add to hash
if (!this->readall_(reinterpret_cast<uint8_t *>(buf), hex_size)) {
this->log_auth_warning_(LOG_STR("Reading cnonce"), name);
// Check if we still have more to write
if (this->auth_buf_pos_ < to_write) {
return false; // More to write, try again next loop
}
// All written, prepare for reading phase
this->auth_buf_pos_ = 0;
return true;
}
bool ESPHomeOTAComponent::handle_auth_read_() {
size_t hex_size = this->get_auth_hex_size_();
const size_t to_read = hex_size * 2; // CNonce + Response
// Try to read remaining bytes (CNonce + Response)
// We read cnonce+response starting at offset 1+hex_size (after auth_type and our nonce)
size_t cnonce_offset = 1 + hex_size; // Offset where cnonce should be stored in buffer
size_t remaining = to_read - this->auth_buf_pos_;
ssize_t read = this->client_->read(this->auth_buf_.get() + cnonce_offset + this->auth_buf_pos_, remaining);
if (!this->handle_read_error_(read, LOG_STR("read auth"))) {
return false;
}
buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: %s CNonce is %s", LOG_STR_ARG(name), buf);
hasher->add(buf, hex_size);
hasher->calculate();
this->auth_buf_pos_ += read;
// Log expected result (digest is already in hasher)
hasher->get_hex(buf);
buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: %s Result is %s", LOG_STR_ARG(name), buf);
// Read response into the buffer
if (!this->readall_(reinterpret_cast<uint8_t *>(buf), hex_size)) {
this->log_auth_warning_(LOG_STR("Reading response"), name);
return false;
// Check if we still need more data
if (this->auth_buf_pos_ < to_read) {
return false; // More to read, try again next loop
}
buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: %s Response is %s", LOG_STR_ARG(name), buf);
// Compare response directly with digest in hasher
bool matches = hasher->equals_hex(buf);
// We have all the data, verify it
bool matches = false;
#ifdef USE_OTA_SHA256
if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_SHA256_AUTH) {
sha256::SHA256 sha_hasher;
matches = this->verify_hash_auth_(&sha_hasher, hex_size);
}
#endif
#ifdef USE_OTA_MD5
if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_AUTH) {
md5::MD5Digest md5_hasher;
matches = this->verify_hash_auth_(&md5_hasher, hex_size);
}
#endif
if (!matches) {
this->log_auth_warning_(LOG_STR("Password mismatch"), name);
this->log_auth_warning_(LOG_STR("Password mismatch"));
this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID);
return false;
}
return matches;
// Authentication successful - clean up auth state
this->cleanup_auth_();
return true;
}
bool ESPHomeOTAComponent::prepare_auth_nonce_(HashBase *hasher) {
// Calculate required buffer size using the hasher
const size_t hex_size = hasher->get_size() * 2;
const size_t nonce_len = hasher->get_size() / 4;
// Buffer layout after AUTH_READ completes:
// [0]: auth_type (1 byte)
// [1...hex_size]: nonce (hex_size bytes) - our random nonce sent in AUTH_SEND
// [1+hex_size...1+2*hex_size-1]: cnonce (hex_size bytes) - client's nonce
// [1+2*hex_size...1+3*hex_size-1]: response (hex_size bytes) - client's hash
// Total: 1 + 3*hex_size
const size_t auth_buf_size = 1 + 3 * hex_size;
this->auth_buf_ = std::make_unique<uint8_t[]>(auth_buf_size);
this->auth_buf_pos_ = 0;
// Generate nonce
char *buf = reinterpret_cast<char *>(this->auth_buf_.get() + 1);
if (!random_bytes(reinterpret_cast<uint8_t *>(buf), nonce_len)) {
this->log_auth_warning_(LOG_STR("Random failed"));
this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_UNKNOWN);
return false;
}
hasher->init();
hasher->add(buf, nonce_len);
hasher->calculate();
// Prepare buffer: auth_type (1 byte) + nonce (hex_size bytes)
this->auth_buf_[0] = this->auth_type_;
hasher->get_hex(buf);
#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE
char log_buf[hex_size + 1];
// Log nonce for debugging
memcpy(log_buf, buf, hex_size);
log_buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: Nonce is %s", log_buf);
#endif
return true;
}
bool ESPHomeOTAComponent::verify_hash_auth_(HashBase *hasher, size_t hex_size) {
// Get pointers to the data in the buffer (see prepare_auth_nonce_ for buffer layout)
const char *nonce = reinterpret_cast<char *>(this->auth_buf_.get() + 1); // Skip auth_type byte
const char *cnonce = nonce + hex_size; // CNonce immediately follows nonce
const char *response = cnonce + hex_size; // Response immediately follows cnonce
// Calculate expected hash: password + nonce + cnonce
hasher->init();
hasher->add(this->password_.c_str(), this->password_.length());
hasher->add(nonce, hex_size * 2); // Add both nonce and cnonce (contiguous in buffer)
hasher->calculate();
#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE
char log_buf[hex_size + 1];
// Log CNonce
memcpy(log_buf, cnonce, hex_size);
log_buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: CNonce is %s", log_buf);
// Log computed hash
hasher->get_hex(log_buf);
log_buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: Result is %s", log_buf);
// Log received response
memcpy(log_buf, response, hex_size);
log_buf[hex_size] = '\0';
ESP_LOGV(TAG, "Auth: Response is %s", log_buf);
#endif
// Compare response
return hasher->equals_hex(response);
}
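The buffer layout set up in prepare_auth_nonce_() is plain contiguous offsets, which is also why verify_hash_auth_() can hash nonce and cnonce with a single add(nonce, hex_size * 2). Spelled out for the SHA256 case (hex_size = 64) as a compile-time sketch:

#include <cstddef>

// Offsets into auth_buf_ for SHA256 auth, following the layout comment above.
constexpr std::size_t HEX = 64;                    // SHA256_HEX_SIZE
constexpr std::size_t OFF_AUTH_TYPE = 0;           // 1 byte: requested auth type
constexpr std::size_t OFF_NONCE = 1;               // our nonce, HEX bytes
constexpr std::size_t OFF_CNONCE = 1 + HEX;        // client's nonce, HEX bytes
constexpr std::size_t OFF_RESPONSE = 1 + 2 * HEX;  // client's hash, HEX bytes
constexpr std::size_t BUF_SIZE = 1 + 3 * HEX;      // 193 bytes total

static_assert(OFF_RESPONSE + HEX == BUF_SIZE, "nonce, cnonce and response are contiguous");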
size_t ESPHomeOTAComponent::get_auth_hex_size_() const {
#ifdef USE_OTA_SHA256
if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_SHA256_AUTH) {
return SHA256_HEX_SIZE;
}
#endif
#ifdef USE_OTA_MD5
return MD5_HEX_SIZE;
#else
#ifndef USE_OTA_SHA256
#error "Either USE_OTA_MD5 or USE_OTA_SHA256 must be defined when USE_OTA_PASSWORD is enabled"
#endif
#endif
}
void ESPHomeOTAComponent::cleanup_auth_() {
this->auth_buf_ = nullptr;
this->auth_buf_pos_ = 0;
this->auth_type_ = 0;
}
#endif // USE_OTA_PASSWORD


@@ -14,6 +14,18 @@ namespace esphome {
/// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA.
class ESPHomeOTAComponent : public ota::OTAComponent {
public:
enum class OTAState : uint8_t {
IDLE,
MAGIC_READ, // Reading magic bytes
MAGIC_ACK, // Sending OK and version after magic bytes
FEATURE_READ, // Reading feature flags from client
FEATURE_ACK, // Sending feature acknowledgment
#ifdef USE_OTA_PASSWORD
AUTH_SEND, // Sending authentication request
AUTH_READ, // Reading authentication data
#endif // USE_OTA_PASSWORD
DATA, // BLOCKING! Processing OTA data (update, etc.)
};
#ifdef USE_OTA_PASSWORD
void set_auth_password(const std::string &password) { password_ = password; }
#endif // USE_OTA_PASSWORD
@@ -32,16 +44,39 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
void handle_handshake_();
void handle_data_();
#ifdef USE_OTA_PASSWORD
bool perform_hash_auth_(HashBase *hasher, const std::string &password, uint8_t auth_request, const LogString *name,
char *buf);
void log_auth_warning_(const LogString *action, const LogString *hash_name);
bool handle_auth_send_();
bool handle_auth_read_();
bool select_auth_type_();
bool prepare_auth_nonce_(HashBase *hasher);
bool verify_hash_auth_(HashBase *hasher, size_t hex_size);
size_t get_auth_hex_size_() const;
void cleanup_auth_();
void log_auth_warning_(const LogString *msg);
#endif // USE_OTA_PASSWORD
bool readall_(uint8_t *buf, size_t len);
bool writeall_(const uint8_t *buf, size_t len);
bool try_read_(size_t to_read, const LogString *desc);
bool try_write_(size_t to_write, const LogString *desc);
inline bool would_block_(int error_code) const { return error_code == EAGAIN || error_code == EWOULDBLOCK; }
bool handle_read_error_(ssize_t read, const LogString *desc);
bool handle_write_error_(ssize_t written, const LogString *desc);
inline void transition_ota_state_(OTAState next_state) {
this->ota_state_ = next_state;
this->handshake_buf_pos_ = 0; // Reset buffer position for next state
}
void log_socket_error_(const LogString *msg);
void log_read_error_(const LogString *what);
void log_start_(const LogString *phase);
void log_remote_closed_(const LogString *during);
void cleanup_connection_();
inline void send_error_and_cleanup_(ota::OTAResponseTypes error) {
uint8_t error_byte = static_cast<uint8_t>(error);
this->client_->write(&error_byte, 1); // Best effort, non-blocking
this->cleanup_connection_();
}
void yield_and_feed_watchdog_();
#ifdef USE_OTA_PASSWORD
@@ -50,11 +85,19 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
std::unique_ptr<socket::Socket> server_;
std::unique_ptr<socket::Socket> client_;
std::unique_ptr<ota::OTABackend> backend_;
uint32_t client_connect_time_{0};
uint16_t port_;
uint8_t magic_buf_[5];
uint8_t magic_buf_pos_{0};
uint8_t handshake_buf_[5];
OTAState ota_state_{OTAState::IDLE};
uint8_t handshake_buf_pos_{0};
uint8_t ota_features_{0};
#ifdef USE_OTA_PASSWORD
std::unique_ptr<uint8_t[]> auth_buf_;
uint8_t auth_buf_pos_{0};
uint8_t auth_type_{0}; // Store auth type to know which hasher to use
#endif // USE_OTA_PASSWORD
};
} // namespace esphome
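Read together with handle_handshake_() in the .cpp file, the enum encodes a linear progression; summarised here as a comment (the auth states exist only when USE_OTA_PASSWORD is defined):

// IDLE -> MAGIC_READ -> MAGIC_ACK -> FEATURE_READ -> FEATURE_ACK
//      -> [AUTH_SEND -> AUTH_READ]  (only with USE_OTA_PASSWORD)
//      -> DATA (blocking transfer in handle_data_())
// Any socket error or handshake timeout returns to IDLE via cleanup_connection_().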


@@ -27,6 +27,7 @@ from esphome.const import (
CONF_GATEWAY,
CONF_ID,
CONF_INTERRUPT_PIN,
CONF_MAC_ADDRESS,
CONF_MANUAL_IP,
CONF_MISO_PIN,
CONF_MODE,
@@ -197,6 +198,7 @@ BASE_SCHEMA = cv.Schema(
"This option has been removed. Please use the [disabled] option under the "
"new mdns component instead."
),
cv.Optional(CONF_MAC_ADDRESS): cv.mac_address,
}
).extend(cv.COMPONENT_SCHEMA)
@@ -365,6 +367,9 @@ async def to_code(config):
if phy_define := _PHY_TYPE_TO_DEFINE.get(config[CONF_TYPE]):
cg.add_define(phy_define)
if mac_address := config.get(CONF_MAC_ADDRESS):
cg.add(var.set_fixed_mac(mac_address.parts))
cg.add_define("USE_ETHERNET")
# Disable WiFi when using Ethernet to save memory


@@ -253,7 +253,11 @@ void EthernetComponent::setup() {
// use ESP internal eth mac
uint8_t mac_addr[6];
esp_read_mac(mac_addr, ESP_MAC_ETH);
if (this->fixed_mac_.has_value()) {
memcpy(mac_addr, this->fixed_mac_->data(), 6);
} else {
esp_read_mac(mac_addr, ESP_MAC_ETH);
}
err = esp_eth_ioctl(this->eth_handle_, ETH_CMD_S_MAC_ADDR, mac_addr);
ESPHL_ERROR_CHECK(err, "set mac address error");


@@ -84,6 +84,7 @@ class EthernetComponent : public Component {
#endif
void set_type(EthernetType type);
void set_manual_ip(const ManualIP &manual_ip);
void set_fixed_mac(const std::array<uint8_t, 6> &mac) { this->fixed_mac_ = mac; }
network::IPAddresses get_ip_addresses();
network::IPAddress get_dns_address(uint8_t num);
@@ -155,6 +156,7 @@ class EthernetComponent : public Component {
esp_netif_t *eth_netif_{nullptr};
esp_eth_handle_t eth_handle_;
esp_eth_phy_t *phy_{nullptr};
optional<std::array<uint8_t, 6>> fixed_mac_;
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)


@@ -9,6 +9,7 @@ from esphome.const import (
CONF_ID,
CONF_METHOD,
CONF_ON_ERROR,
CONF_ON_RESPONSE,
CONF_TIMEOUT,
CONF_TRIGGER_ID,
CONF_URL,
@@ -52,7 +53,6 @@ CONF_BUFFER_SIZE_TX = "buffer_size_tx"
CONF_CA_CERTIFICATE_PATH = "ca_certificate_path"
CONF_MAX_RESPONSE_BUFFER_SIZE = "max_response_buffer_size"
CONF_ON_RESPONSE = "on_response"
CONF_HEADERS = "headers"
CONF_COLLECT_HEADERS = "collect_headers"
CONF_BODY = "body"


@@ -155,7 +155,7 @@ void MCP2515::prepare_id_(uint8_t *buffer, const bool extended, const uint32_t i
canid = (uint16_t) (id >> 16);
buffer[MCP_SIDL] = (uint8_t) (canid & 0x03);
buffer[MCP_SIDL] += (uint8_t) ((canid & 0x1C) << 3);
buffer[MCP_SIDL] |= TXB_EXIDE_MASK;
buffer[MCP_SIDL] |= SIDL_EXIDE_MASK;
buffer[MCP_SIDH] = (uint8_t) (canid >> 5);
} else {
buffer[MCP_SIDH] = (uint8_t) (canid >> 3);
@@ -258,7 +258,7 @@ canbus::Error MCP2515::send_message(struct canbus::CanFrame *frame) {
}
}
return canbus::ERROR_FAILTX;
return canbus::ERROR_ALLTXBUSY;
}
canbus::Error MCP2515::read_message_(RXBn rxbn, struct canbus::CanFrame *frame) {
@@ -272,7 +272,7 @@ canbus::Error MCP2515::read_message_(RXBn rxbn, struct canbus::CanFrame *frame)
bool use_extended_id = false;
bool remote_transmission_request = false;
if ((tbufdata[MCP_SIDL] & TXB_EXIDE_MASK) == TXB_EXIDE_MASK) {
if ((tbufdata[MCP_SIDL] & SIDL_EXIDE_MASK) == SIDL_EXIDE_MASK) {
id = (id << 2) + (tbufdata[MCP_SIDL] & 0x03);
id = (id << 8) + tbufdata[MCP_EID8];
id = (id << 8) + tbufdata[MCP_EID0];
@@ -315,6 +315,17 @@ canbus::Error MCP2515::read_message(struct canbus::CanFrame *frame) {
rc = canbus::ERROR_NOMSG;
}
#ifdef ESPHOME_LOG_HAS_DEBUG
uint8_t err = get_error_flags_();
// The receive flowchart in the datasheet says that if rollover is set (BUKT), RX1OVR flag will be set
// once both buffers are full. However, the RX0OVR flag is actually set instead.
// We can just check for both though because it doesn't break anything.
if (err & (EFLG_RX0OVR | EFLG_RX1OVR)) {
ESP_LOGD(TAG, "receive buffer overrun");
clear_rx_n_ovr_flags_();
}
#endif
return rc;
}


@@ -130,7 +130,9 @@ static const uint8_t CANSTAT_ICOD = 0x0E;
static const uint8_t CNF3_SOF = 0x80;
static const uint8_t TXB_EXIDE_MASK = 0x08;
// applies to RXBn_SIDL, TXBn_SIDL and RXFn_SIDL
static const uint8_t SIDL_EXIDE_MASK = 0x08;
static const uint8_t DLC_MASK = 0x0F;
static const uint8_t RTR_MASK = 0x40;


@@ -9,7 +9,7 @@
#include "lwip/tcp.h"
#include <cerrno>
#include <cstring>
#include <queue>
#include <array>
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
@@ -50,12 +50,18 @@ class LWIPRawImpl : public Socket {
errno = EBADF;
return nullptr;
}
if (accepted_sockets_.empty()) {
if (this->accepted_socket_count_ == 0) {
errno = EWOULDBLOCK;
return nullptr;
}
std::unique_ptr<LWIPRawImpl> sock = std::move(accepted_sockets_.front());
accepted_sockets_.pop();
// Take from front for FIFO ordering
std::unique_ptr<LWIPRawImpl> sock = std::move(this->accepted_sockets_[0]);
// Shift remaining sockets forward
for (uint8_t i = 1; i < this->accepted_socket_count_; i++) {
this->accepted_sockets_[i - 1] = std::move(this->accepted_sockets_[i]);
}
this->accepted_socket_count_--;
LWIP_LOG("Connection accepted by application, queue size: %d", this->accepted_socket_count_);
if (addr != nullptr) {
sock->getpeername(addr, addrlen);
}
@@ -494,9 +500,18 @@ class LWIPRawImpl : public Socket {
// nothing to do here, we just don't push it to the queue
return ERR_OK;
}
// Check if we've reached the maximum accept queue size
if (this->accepted_socket_count_ >= MAX_ACCEPTED_SOCKETS) {
LWIP_LOG("Rejecting connection, queue full (%d)", this->accepted_socket_count_);
// Abort the connection when queue is full
tcp_abort(newpcb);
// Must return ERR_ABRT since we called tcp_abort()
return ERR_ABRT;
}
auto sock = make_unique<LWIPRawImpl>(family_, newpcb);
sock->init();
accepted_sockets_.push(std::move(sock));
this->accepted_sockets_[this->accepted_socket_count_++] = std::move(sock);
LWIP_LOG("Accepted connection, queue size: %d", this->accepted_socket_count_);
return ERR_OK;
}
void err_fn(err_t err) {
@@ -587,7 +602,20 @@ class LWIPRawImpl : public Socket {
}
struct tcp_pcb *pcb_;
std::queue<std::unique_ptr<LWIPRawImpl>> accepted_sockets_;
// Accept queue - holds incoming connections briefly until the event loop calls accept()
// This is NOT a connection pool - just a temporary queue between LWIP callbacks and the main loop
// 3 slots is plenty since connections are pulled out quickly by the event loop
//
// Memory analysis: std::array<3> vs original std::queue implementation:
// - std::queue uses std::deque internally which on 32-bit systems needs:
// 24 bytes (deque object) + 32+ bytes (map array) + heap allocations
// Total: ~56+ bytes minimum, plus heap fragmentation
// - std::array<3>: 12 bytes fixed (3 pointers × 4 bytes)
// Saves ~44+ bytes RAM per listening socket + avoids ALL heap allocations
// Used on ESP8266 and RP2040 (platforms using LWIP_TCP implementation)
static constexpr size_t MAX_ACCEPTED_SOCKETS = 3;
std::array<std::unique_ptr<LWIPRawImpl>, MAX_ACCEPTED_SOCKETS> accepted_sockets_;
uint8_t accepted_socket_count_ = 0; // Number of sockets currently in queue
bool rx_closed_ = false;
pbuf *rx_buf_ = nullptr;
size_t rx_buf_offset_ = 0;
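The memory comparison in the comment can be sanity-checked on a desktop build; a small sketch (exact numbers depend on the toolchain, and the ~12-byte figure assumes 4-byte pointers as on the LWIP targets):

#include <array>
#include <cstdio>
#include <memory>
#include <queue>

// Stand-in element type; only the pointer size matters for this comparison.
using Slot = std::unique_ptr<int>;

int main() {
  // On a 32-bit target: 3 * 4 = 12 bytes for the array object, versus a
  // considerably larger deque-backed queue object before any heap allocation.
  std::printf("std::array<Slot, 3>: %zu bytes\n", sizeof(std::array<Slot, 3>));
  std::printf("std::queue<Slot>   : %zu bytes (object only, heap excluded)\n", sizeof(std::queue<Slot>));
  return 0;
}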


@@ -15,6 +15,10 @@ CONF_BANDWIDTH = "bandwidth"
CONF_BITRATE = "bitrate"
CONF_CODING_RATE = "coding_rate"
CONF_CRC_ENABLE = "crc_enable"
CONF_CRC_INVERTED = "crc_inverted"
CONF_CRC_SIZE = "crc_size"
CONF_CRC_POLYNOMIAL = "crc_polynomial"
CONF_CRC_INITIAL = "crc_initial"
CONF_DEVIATION = "deviation"
CONF_DIO1_PIN = "dio1_pin"
CONF_HW_VERSION = "hw_version"
@@ -188,6 +192,14 @@ CONFIG_SCHEMA = (
cv.Required(CONF_BUSY_PIN): pins.internal_gpio_input_pin_schema,
cv.Optional(CONF_CODING_RATE, default="CR_4_5"): cv.enum(CODING_RATE),
cv.Optional(CONF_CRC_ENABLE, default=False): cv.boolean,
cv.Optional(CONF_CRC_INVERTED, default=True): cv.boolean,
cv.Optional(CONF_CRC_SIZE, default=2): cv.int_range(min=1, max=2),
cv.Optional(CONF_CRC_POLYNOMIAL, default=0x1021): cv.All(
cv.hex_int, cv.Range(min=0, max=0xFFFF)
),
cv.Optional(CONF_CRC_INITIAL, default=0x1D0F): cv.All(
cv.hex_int, cv.Range(min=0, max=0xFFFF)
),
cv.Optional(CONF_DEVIATION, default=5000): cv.int_range(min=0, max=100000),
cv.Required(CONF_DIO1_PIN): pins.internal_gpio_input_pin_schema,
cv.Required(CONF_FREQUENCY): cv.int_range(min=137000000, max=1020000000),
@@ -251,6 +263,10 @@ async def to_code(config):
cg.add(var.set_shaping(config[CONF_SHAPING]))
cg.add(var.set_bitrate(config[CONF_BITRATE]))
cg.add(var.set_crc_enable(config[CONF_CRC_ENABLE]))
cg.add(var.set_crc_inverted(config[CONF_CRC_INVERTED]))
cg.add(var.set_crc_size(config[CONF_CRC_SIZE]))
cg.add(var.set_crc_polynomial(config[CONF_CRC_POLYNOMIAL]))
cg.add(var.set_crc_initial(config[CONF_CRC_INITIAL]))
cg.add(var.set_payload_length(config[CONF_PAYLOAD_LENGTH]))
cg.add(var.set_preamble_size(config[CONF_PREAMBLE_SIZE]))
cg.add(var.set_preamble_detect(config[CONF_PREAMBLE_DETECT]))


@@ -235,6 +235,16 @@ void SX126x::configure() {
buf[7] = (fdev >> 0) & 0xFF;
this->write_opcode_(RADIO_SET_MODULATIONPARAMS, buf, 8);
// set crc params
if (this->crc_enable_) {
buf[0] = this->crc_initial_ >> 8;
buf[1] = this->crc_initial_ & 0xFF;
this->write_register_(REG_CRC_INITIAL, buf, 2);
buf[0] = this->crc_polynomial_ >> 8;
buf[1] = this->crc_polynomial_ & 0xFF;
this->write_register_(REG_CRC_POLYNOMIAL, buf, 2);
}
// set packet params and sync word
this->set_packet_params_(this->get_max_packet_size());
if (!this->sync_value_.empty()) {
@@ -276,7 +286,11 @@ void SX126x::set_packet_params_(uint8_t payload_length) {
buf[4] = 0x00;
buf[5] = (this->payload_length_ > 0) ? 0x00 : 0x01;
buf[6] = payload_length;
buf[7] = this->crc_enable_ ? 0x06 : 0x01;
if (this->crc_enable_) {
buf[7] = (this->crc_inverted_ ? 0x04 : 0x00) + (this->crc_size_ & 0x02);
} else {
buf[7] = 0x01;
}
buf[8] = 0x00;
this->write_opcode_(RADIO_SET_PACKETPARAMS, buf, 9);
}
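For reference, the packed CRC configuration byte written to buf[7] enumerates as follows; a small sketch of the mapping implied by the expression above (the resulting values are the SX126x GFSK CRC types: 0x01 = off, 0x00/0x02 = 1/2 bytes, +0x04 when inverted, so the previously hard-coded 0x06 corresponds to the 2-byte inverted default):

#include <cstdint>
#include <cstdio>

// buf[7] as computed above: 0x01 disables CRC; otherwise a 2-byte CRC adds 0x02
// and an inverted CRC adds 0x04.
constexpr uint8_t crc_config(bool enabled, bool inverted, uint8_t size) {
  return enabled ? static_cast<uint8_t>((inverted ? 0x04 : 0x00) + (size & 0x02)) : 0x01;
}

int main() {
  std::printf("off              -> 0x%02X\n", crc_config(false, false, 0));  // 0x01
  std::printf("1 byte           -> 0x%02X\n", crc_config(true, false, 1));   // 0x00
  std::printf("2 byte           -> 0x%02X\n", crc_config(true, false, 2));   // 0x02
  std::printf("1 byte, inverted -> 0x%02X\n", crc_config(true, true, 1));    // 0x04
  std::printf("2 byte, inverted -> 0x%02X\n", crc_config(true, true, 2));    // 0x06
  return 0;
}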


@@ -67,6 +67,10 @@ class SX126x : public Component,
void set_busy_pin(InternalGPIOPin *busy_pin) { this->busy_pin_ = busy_pin; }
void set_coding_rate(uint8_t coding_rate) { this->coding_rate_ = coding_rate; }
void set_crc_enable(bool crc_enable) { this->crc_enable_ = crc_enable; }
void set_crc_inverted(bool crc_inverted) { this->crc_inverted_ = crc_inverted; }
void set_crc_size(uint8_t crc_size) { this->crc_size_ = crc_size; }
void set_crc_polynomial(uint16_t crc_polynomial) { this->crc_polynomial_ = crc_polynomial; }
void set_crc_initial(uint16_t crc_initial) { this->crc_initial_ = crc_initial; }
void set_deviation(uint32_t deviation) { this->deviation_ = deviation; }
void set_dio1_pin(InternalGPIOPin *dio1_pin) { this->dio1_pin_ = dio1_pin; }
void set_frequency(uint32_t frequency) { this->frequency_ = frequency; }
@@ -118,6 +122,11 @@ class SX126x : public Component,
char version_[16];
SX126xBw bandwidth_{SX126X_BW_125000};
uint32_t bitrate_{0};
bool crc_enable_{false};
bool crc_inverted_{false};
uint8_t crc_size_{0};
uint16_t crc_polynomial_{0};
uint16_t crc_initial_{0};
uint32_t deviation_{0};
uint32_t frequency_{0};
uint32_t payload_length_{0};
@@ -131,7 +140,6 @@ class SX126x : public Component,
uint8_t shaping_{0};
uint8_t spreading_factor_{0};
int8_t pa_power_{0};
bool crc_enable_{false};
bool rx_start_{false};
bool rf_switch_{false};
};


@@ -53,6 +53,8 @@ enum SX126xOpCode : uint8_t {
enum SX126xRegister : uint16_t {
REG_VERSION_STRING = 0x0320,
REG_CRC_INITIAL = 0x06BC,
REG_CRC_POLYNOMIAL = 0x06BE,
REG_GFSK_SYNCWORD = 0x06C0,
REG_LORA_SYNCWORD = 0x0740,
REG_OCP = 0x08E7,


@@ -242,7 +242,6 @@ void VoiceAssistant::loop() {
msg.flags = flags;
msg.audio_settings = audio_settings;
msg.set_wake_word_phrase(StringRef(this->wake_word_));
this->wake_word_ = "";
// Reset media player state tracking
#ifdef USE_MEDIA_PLAYER


@@ -27,6 +27,10 @@
#include "dhcpserver/dhcpserver.h"
#endif // USE_WIFI_AP
#ifdef USE_CAPTIVE_PORTAL
#include "esphome/components/captive_portal/captive_portal.h"
#endif
#include "lwip/apps/sntp.h"
#include "lwip/dns.h"
#include "lwip/err.h"
@@ -918,6 +922,22 @@ bool WiFiComponent::wifi_ap_ip_config_(optional<ManualIP> manual_ip) {
return false;
}
#if defined(USE_CAPTIVE_PORTAL) && ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 4, 0)
// Configure DHCP Option 114 (Captive Portal URI) if captive portal is enabled
// This provides a standards-compliant way for clients to discover the captive portal
if (captive_portal::global_captive_portal != nullptr) {
static char captive_portal_uri[32];
snprintf(captive_portal_uri, sizeof(captive_portal_uri), "http://%s", network::IPAddress(&info.ip).str().c_str());
err = esp_netif_dhcps_option(s_ap_netif, ESP_NETIF_OP_SET, ESP_NETIF_CAPTIVEPORTAL_URI, captive_portal_uri,
strlen(captive_portal_uri));
if (err != ESP_OK) {
ESP_LOGV(TAG, "Failed to set DHCP captive portal URI: %s", esp_err_to_name(err));
} else {
ESP_LOGV(TAG, "DHCP Captive Portal URI set to: %s", captive_portal_uri);
}
}
#endif
err = esp_netif_dhcps_start(s_ap_netif);
if (err != ESP_OK) {

View File

@@ -126,6 +126,7 @@
#define USE_OTA_MD5
#define USE_OTA_PASSWORD
#define USE_OTA_SHA256
#define ALLOW_OTA_DOWNGRADE_MD5
#define USE_OTA_STATE_CALLBACK
#define USE_OTA_VERSION 2
#define USE_TIME_TIMEZONE

View File

@@ -1,9 +1,26 @@
from __future__ import annotations
EVENT_ENTRY_ADDED = "entry_added"
EVENT_ENTRY_REMOVED = "entry_removed"
EVENT_ENTRY_UPDATED = "entry_updated"
EVENT_ENTRY_STATE_CHANGED = "entry_state_changed"
from esphome.enum import StrEnum
class DashboardEvent(StrEnum):
"""Dashboard WebSocket event types."""
# Server -> Client events (backend sends to frontend)
ENTRY_ADDED = "entry_added"
ENTRY_REMOVED = "entry_removed"
ENTRY_UPDATED = "entry_updated"
ENTRY_STATE_CHANGED = "entry_state_changed"
IMPORTABLE_DEVICE_ADDED = "importable_device_added"
IMPORTABLE_DEVICE_REMOVED = "importable_device_removed"
INITIAL_STATE = "initial_state" # Sent on WebSocket connection
PONG = "pong" # Response to client ping
# Client -> Server events (frontend sends to backend)
PING = "ping" # WebSocket keepalive from client
REFRESH = "refresh" # Force backend to poll for changes
MAX_EXECUTOR_WORKERS = 48
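
These enum values are what travels over the /events WebSocket added later in this change. A rough sketch of the wire format (illustrative only; the field names follow the handlers further down, and "pico.yaml" is just a placeholder from the test fixtures):

    import json

    # Client -> server keepalive; the handler answers with {"event": "pong"}
    ping = json.dumps({"event": "ping"})

    # Server -> client messages carry the event name plus a "data" payload, e.g.
    state_changed = {
        "event": "entry_state_changed",
        "data": {"filename": "pico.yaml", "name": "pico", "state": True},
    }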

View File

@@ -13,6 +13,7 @@ from typing import Any
from esphome.storage_json import ignored_devices_storage_path
from ..zeroconf import DiscoveredImport
from .const import DashboardEvent
from .dns import DNSCache
from .entries import DashboardEntries
from .settings import DashboardSettings
@@ -30,7 +31,7 @@ MDNS_BOOTSTRAP_TIME = 7.5
class Event:
"""Dashboard Event."""
event_type: str
event_type: DashboardEvent
data: dict[str, Any]
@@ -39,22 +40,24 @@ class EventBus:
def __init__(self) -> None:
"""Initialize the Dashboard event bus."""
self._listeners: dict[str, set[Callable[[Event], None]]] = {}
self._listeners: dict[DashboardEvent, set[Callable[[Event], None]]] = {}
def async_add_listener(
self, event_type: str, listener: Callable[[Event], None]
self, event_type: DashboardEvent, listener: Callable[[Event], None]
) -> Callable[[], None]:
"""Add a listener to the event bus."""
self._listeners.setdefault(event_type, set()).add(listener)
return partial(self._async_remove_listener, event_type, listener)
def _async_remove_listener(
self, event_type: str, listener: Callable[[Event], None]
self, event_type: DashboardEvent, listener: Callable[[Event], None]
) -> None:
"""Remove a listener from the event bus."""
self._listeners[event_type].discard(listener)
def async_fire(self, event_type: str, event_data: dict[str, Any]) -> None:
def async_fire(
self, event_type: DashboardEvent, event_data: dict[str, Any]
) -> None:
"""Fire an event."""
event = Event(event_type, event_data)
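
Taken together, the bus works as before, just keyed by DashboardEvent instead of bare strings. A hypothetical usage sketch (dashboard, entry and state are placeholders, not part of this diff):

    def on_state(event: Event) -> None:
        entry = event.data["entry"]
        print(event.event_type, entry.name)

    # async_add_listener returns a callable that removes the listener again
    remove = dashboard.bus.async_add_listener(DashboardEvent.ENTRY_STATE_CHANGED, on_state)
    dashboard.bus.async_fire(DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state})
    remove()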

View File

@@ -12,13 +12,7 @@ from esphome import const, util
from esphome.enum import StrEnum
from esphome.storage_json import StorageJSON, ext_storage_path
from .const import (
DASHBOARD_COMMAND,
EVENT_ENTRY_ADDED,
EVENT_ENTRY_REMOVED,
EVENT_ENTRY_STATE_CHANGED,
EVENT_ENTRY_UPDATED,
)
from .const import DASHBOARD_COMMAND, DashboardEvent
from .util.subprocess import async_run_system_command
if TYPE_CHECKING:
@@ -102,12 +96,12 @@ class DashboardEntries:
# "path/to/file.yaml": DashboardEntry,
# ...
# }
self._entries: dict[str, DashboardEntry] = {}
self._entries: dict[Path, DashboardEntry] = {}
self._loaded_entries = False
self._update_lock = asyncio.Lock()
self._name_to_entry: dict[str, set[DashboardEntry]] = defaultdict(set)
def get(self, path: str) -> DashboardEntry | None:
def get(self, path: Path) -> DashboardEntry | None:
"""Get an entry by path."""
return self._entries.get(path)
@@ -192,7 +186,7 @@ class DashboardEntries:
return
entry.state = state
self._dashboard.bus.async_fire(
EVENT_ENTRY_STATE_CHANGED, {"entry": entry, "state": state}
DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state}
)
async def async_request_update_entries(self) -> None:
@@ -260,22 +254,22 @@ class DashboardEntries:
for entry in added:
entries[entry.path] = entry
name_to_entry[entry.name].add(entry)
bus.async_fire(EVENT_ENTRY_ADDED, {"entry": entry})
bus.async_fire(DashboardEvent.ENTRY_ADDED, {"entry": entry})
for entry in removed:
del entries[entry.path]
name_to_entry[entry.name].discard(entry)
bus.async_fire(EVENT_ENTRY_REMOVED, {"entry": entry})
bus.async_fire(DashboardEvent.ENTRY_REMOVED, {"entry": entry})
for entry in updated:
if (original_name := original_names[entry]) != (current_name := entry.name):
name_to_entry[original_name].discard(entry)
name_to_entry[current_name].add(entry)
bus.async_fire(EVENT_ENTRY_UPDATED, {"entry": entry})
bus.async_fire(DashboardEvent.ENTRY_UPDATED, {"entry": entry})
def _get_path_to_cache_key(self) -> dict[str, DashboardCacheKeyType]:
def _get_path_to_cache_key(self) -> dict[Path, DashboardCacheKeyType]:
"""Return a dict of path to cache key."""
path_to_cache_key: dict[str, DashboardCacheKeyType] = {}
path_to_cache_key: dict[Path, DashboardCacheKeyType] = {}
#
# The cache key is (inode, device, mtime, size)
# which allows us to avoid locking since it ensures

View File

@@ -0,0 +1,76 @@
"""Data models and builders for the dashboard."""
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from esphome.zeroconf import DiscoveredImport
from .core import ESPHomeDashboard
from .entries import DashboardEntry
class ImportableDeviceDict(TypedDict):
"""Dictionary representation of an importable device."""
name: str
friendly_name: str | None
package_import_url: str
project_name: str
project_version: str
network: str
ignored: bool
class ConfiguredDeviceDict(TypedDict, total=False):
"""Dictionary representation of a configured device."""
name: str
friendly_name: str | None
configuration: str
loaded_integrations: list[str] | None
deployed_version: str | None
current_version: str | None
path: str
comment: str | None
address: str | None
web_port: int | None
target_platform: str | None
class DeviceListResponse(TypedDict):
"""Response for device list API."""
configured: list[ConfiguredDeviceDict]
importable: list[ImportableDeviceDict]
def build_importable_device_dict(
dashboard: ESPHomeDashboard, discovered: DiscoveredImport
) -> ImportableDeviceDict:
"""Build the importable device dictionary."""
return ImportableDeviceDict(
name=discovered.device_name,
friendly_name=discovered.friendly_name,
package_import_url=discovered.package_import_url,
project_name=discovered.project_name,
project_version=discovered.project_version,
network=discovered.network,
ignored=discovered.device_name in dashboard.ignored_devices,
)
def build_device_list_response(
dashboard: ESPHomeDashboard, entries: list[DashboardEntry]
) -> DeviceListResponse:
"""Build the device list response data."""
configured = {entry.name for entry in entries}
return DeviceListResponse(
configured=[entry.to_dict() for entry in entries],
importable=[
build_importable_device_dict(dashboard, res)
for res in dashboard.import_result.values()
if res.device_name not in configured
],
)
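
A short usage sketch, mirroring how ListDevicesHandler below serializes this response (illustrative only; dashboard is a placeholder for the global instance):

    import json

    entries = dashboard.entries.async_all()
    payload = build_device_list_response(dashboard, entries)
    body = json.dumps(payload)  # -> {"configured": [...], "importable": [...]}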

View File

@@ -13,10 +13,12 @@ from esphome.zeroconf import (
DashboardBrowser,
DashboardImportDiscovery,
DashboardStatus,
DiscoveredImport,
)
from ..const import SENTINEL
from ..const import SENTINEL, DashboardEvent
from ..entries import DashboardEntry, EntryStateSource, bool_to_entry_state
from ..models import build_importable_device_dict
if typing.TYPE_CHECKING:
from ..core import ESPHomeDashboard
@@ -77,6 +79,20 @@ class MDNSStatus:
_LOGGER.debug("Not found in zeroconf cache: %s", resolver_name)
return None
def _on_import_update(self, name: str, discovered: DiscoveredImport | None) -> None:
"""Handle importable device updates."""
if discovered is None:
# Device removed
self.dashboard.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED, {"name": name}
)
else:
# Device added
self.dashboard.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED,
{"device": build_importable_device_dict(self.dashboard, discovered)},
)
async def async_refresh_hosts(self) -> None:
"""Refresh the hosts to track."""
dashboard = self.dashboard
@@ -133,7 +149,8 @@ class MDNSStatus:
self._async_set_state(entry, result)
stat = DashboardStatus(on_update)
imports = DashboardImportDiscovery()
imports = DashboardImportDiscovery(self._on_import_update)
dashboard.import_result = imports.import_state
browser = DashboardBrowser(

View File

@@ -4,8 +4,10 @@ import asyncio
import base64
import binascii
from collections.abc import Callable, Iterable
import contextlib
import datetime
import functools
from functools import partial
import gzip
import hashlib
import importlib
@@ -50,9 +52,10 @@ from esphome.util import get_serial_ports, shlex_quote
from esphome.yaml_util import FastestAvailableSafeLoader
from ..helpers import write_file
from .const import DASHBOARD_COMMAND
from .core import DASHBOARD, ESPHomeDashboard
from .const import DASHBOARD_COMMAND, DashboardEvent
from .core import DASHBOARD, ESPHomeDashboard, Event
from .entries import UNKNOWN_STATE, DashboardEntry, entry_state_to_bool
from .models import build_device_list_response
from .util.subprocess import async_run_system_command
from .util.text import friendly_name_slugify
@@ -520,6 +523,243 @@ class EsphomeUpdateAllHandler(EsphomeCommandWebSocket):
return [*DASHBOARD_COMMAND, "update-all", settings.config_dir]
# Dashboard polling constants
DASHBOARD_POLL_INTERVAL = 2 # seconds
DASHBOARD_ENTRIES_UPDATE_INTERVAL = 10 # seconds
DASHBOARD_ENTRIES_UPDATE_ITERATIONS = (
DASHBOARD_ENTRIES_UPDATE_INTERVAL // DASHBOARD_POLL_INTERVAL
)
class DashboardSubscriber:
"""Manages dashboard event polling task lifecycle based on active subscribers."""
def __init__(self) -> None:
"""Initialize the dashboard subscriber."""
self._subscribers: set[DashboardEventsWebSocket] = set()
self._event_loop_task: asyncio.Task | None = None
self._refresh_event: asyncio.Event = asyncio.Event()
def subscribe(self, subscriber: DashboardEventsWebSocket) -> Callable[[], None]:
"""Subscribe to dashboard updates and start event loop if needed."""
self._subscribers.add(subscriber)
if not self._event_loop_task or self._event_loop_task.done():
self._event_loop_task = asyncio.create_task(self._event_loop())
_LOGGER.info("Started dashboard event loop")
return partial(self._unsubscribe, subscriber)
def _unsubscribe(self, subscriber: DashboardEventsWebSocket) -> None:
"""Unsubscribe from dashboard updates and stop event loop if no subscribers."""
self._subscribers.discard(subscriber)
if (
not self._subscribers
and self._event_loop_task
and not self._event_loop_task.done()
):
self._event_loop_task.cancel()
self._event_loop_task = None
_LOGGER.info("Stopped dashboard event loop - no subscribers")
def request_refresh(self) -> None:
"""Signal the polling loop to refresh immediately."""
self._refresh_event.set()
async def _event_loop(self) -> None:
"""Run the event polling loop while there are subscribers."""
dashboard = DASHBOARD
entries_update_counter = 0
while self._subscribers:
# Signal that we need ping updates (non-blocking)
dashboard.ping_request.set()
if settings.status_use_mqtt:
dashboard.mqtt_ping_request.set()
# Check if it's time to update entries or if refresh was requested
entries_update_counter += 1
if (
entries_update_counter >= DASHBOARD_ENTRIES_UPDATE_ITERATIONS
or self._refresh_event.is_set()
):
entries_update_counter = 0
await dashboard.entries.async_request_update_entries()
# Clear the refresh event if it was set
self._refresh_event.clear()
# Wait for either timeout or refresh event
try:
async with asyncio.timeout(DASHBOARD_POLL_INTERVAL):
await self._refresh_event.wait()
# If we get here, refresh was requested - continue loop immediately
except TimeoutError:
# Normal timeout - continue with regular polling
pass
# Global dashboard subscriber instance
DASHBOARD_SUBSCRIBER = DashboardSubscriber()
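
The intended lifecycle, as the WebSocket handler below uses it (sketch only; handler is a placeholder):

    # Subscribing starts the polling task if it is not already running
    unsubscribe = DASHBOARD_SUBSCRIBER.subscribe(handler)
    # A client-driven refresh wakes the loop immediately instead of waiting for the poll interval
    DASHBOARD_SUBSCRIBER.request_refresh()
    # Dropping the last subscriber cancels the polling task again
    unsubscribe()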
@websocket_class
class DashboardEventsWebSocket(tornado.websocket.WebSocketHandler):
"""WebSocket handler for real-time dashboard events."""
_event_listeners: list[Callable[[], None]] | None = None
_dashboard_unsubscribe: Callable[[], None] | None = None
async def get(self, *args: str, **kwargs: str) -> None:
"""Handle WebSocket upgrade request."""
if not is_authenticated(self):
self.set_status(401)
self.finish("Unauthorized")
return
await super().get(*args, **kwargs)
async def open(self, *args: str, **kwargs: str) -> None: # pylint: disable=invalid-overridden-method
"""Handle new WebSocket connection."""
# Ensure messages are sent immediately to avoid
# a 200-500ms delay when nodelay is not set.
self.set_nodelay(True)
# Update entries first
await DASHBOARD.entries.async_request_update_entries()
# Send initial state
self._send_initial_state()
# Subscribe to events
self._subscribe_to_events()
# Subscribe to dashboard updates
self._dashboard_unsubscribe = DASHBOARD_SUBSCRIBER.subscribe(self)
_LOGGER.debug("Dashboard status WebSocket opened")
def _send_initial_state(self) -> None:
"""Send initial device list and ping status."""
entries = DASHBOARD.entries.async_all()
# Send initial state
self._safe_send_message(
{
"event": DashboardEvent.INITIAL_STATE,
"data": {
"devices": build_device_list_response(DASHBOARD, entries),
"ping": {
entry.filename: entry_state_to_bool(entry.state)
for entry in entries
},
},
}
)
def _subscribe_to_events(self) -> None:
"""Subscribe to dashboard events."""
async_add_listener = DASHBOARD.bus.async_add_listener
# Subscribe to all events
self._event_listeners = [
async_add_listener(
DashboardEvent.ENTRY_STATE_CHANGED, self._on_entry_state_changed
),
async_add_listener(
DashboardEvent.ENTRY_ADDED,
self._make_entry_handler(DashboardEvent.ENTRY_ADDED),
),
async_add_listener(
DashboardEvent.ENTRY_REMOVED,
self._make_entry_handler(DashboardEvent.ENTRY_REMOVED),
),
async_add_listener(
DashboardEvent.ENTRY_UPDATED,
self._make_entry_handler(DashboardEvent.ENTRY_UPDATED),
),
async_add_listener(
DashboardEvent.IMPORTABLE_DEVICE_ADDED, self._on_importable_added
),
async_add_listener(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED,
self._on_importable_removed,
),
]
def _on_entry_state_changed(self, event: Event) -> None:
"""Handle entry state change event."""
entry = event.data["entry"]
state = event.data["state"]
self._safe_send_message(
{
"event": DashboardEvent.ENTRY_STATE_CHANGED,
"data": {
"filename": entry.filename,
"name": entry.name,
"state": entry_state_to_bool(state),
},
}
)
def _make_entry_handler(
self, event_type: DashboardEvent
) -> Callable[[Event], None]:
"""Create an entry event handler."""
def handler(event: Event) -> None:
self._safe_send_message(
{"event": event_type, "data": {"device": event.data["entry"].to_dict()}}
)
return handler
def _on_importable_added(self, event: Event) -> None:
"""Handle importable device added event."""
# Don't send if device is already configured
device_name = event.data.get("device", {}).get("name")
if device_name and DASHBOARD.entries.get_by_name(device_name):
return
self._safe_send_message(
{"event": DashboardEvent.IMPORTABLE_DEVICE_ADDED, "data": event.data}
)
def _on_importable_removed(self, event: Event) -> None:
"""Handle importable device removed event."""
self._safe_send_message(
{"event": DashboardEvent.IMPORTABLE_DEVICE_REMOVED, "data": event.data}
)
def _safe_send_message(self, message: dict[str, Any]) -> None:
"""Send a message to the WebSocket client, ignoring closed errors."""
with contextlib.suppress(tornado.websocket.WebSocketClosedError):
self.write_message(json.dumps(message))
def on_message(self, message: str) -> None:
"""Handle incoming WebSocket messages."""
_LOGGER.debug("WebSocket received message: %s", message)
try:
data = json.loads(message)
except json.JSONDecodeError as err:
_LOGGER.debug("Failed to parse WebSocket message: %s", err)
return
event = data.get("event")
_LOGGER.debug("WebSocket message event: %s", event)
if event == DashboardEvent.PING:
# Send pong response for client ping
_LOGGER.debug("Received client ping, sending pong")
self._safe_send_message({"event": DashboardEvent.PONG})
elif event == DashboardEvent.REFRESH:
# Signal the polling loop to refresh immediately
_LOGGER.debug("Received refresh request, signaling polling loop")
DASHBOARD_SUBSCRIBER.request_refresh()
def on_close(self) -> None:
"""Handle WebSocket close."""
# Unsubscribe from dashboard updates
if self._dashboard_unsubscribe:
self._dashboard_unsubscribe()
self._dashboard_unsubscribe = None
# Unsubscribe from events
for remove_listener in self._event_listeners or []:
remove_listener()
_LOGGER.debug("Dashboard status WebSocket closed")
class SerialPortRequestHandler(BaseHandler):
@authenticated
async def get(self) -> None:
@@ -874,28 +1114,7 @@ class ListDevicesHandler(BaseHandler):
await dashboard.entries.async_request_update_entries()
entries = dashboard.entries.async_all()
self.set_header("content-type", "application/json")
configured = {entry.name for entry in entries}
self.write(
json.dumps(
{
"configured": [entry.to_dict() for entry in entries],
"importable": [
{
"name": res.device_name,
"friendly_name": res.friendly_name,
"package_import_url": res.package_import_url,
"project_name": res.project_name,
"project_version": res.project_version,
"network": res.network,
"ignored": res.device_name in dashboard.ignored_devices,
}
for res in dashboard.import_result.values()
if res.device_name not in configured
],
}
)
)
self.write(json.dumps(build_device_list_response(dashboard, entries)))
class MainRequestHandler(BaseHandler):
@@ -1351,6 +1570,7 @@ def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application:
(f"{rel}wizard", WizardRequestHandler),
(f"{rel}static/(.*)", StaticFileHandler, {"path": get_static_path()}),
(f"{rel}devices", ListDevicesHandler),
(f"{rel}events", DashboardEventsWebSocket),
(f"{rel}import", ImportRequestHandler),
(f"{rel}secret_keys", SecretKeysRequestHandler),
(f"{rel}json-config", JsonConfigRequestHandler),

View File

@@ -68,8 +68,11 @@ class DashboardBrowser(AsyncServiceBrowser):
class DashboardImportDiscovery:
def __init__(self) -> None:
def __init__(
self, on_update: Callable[[str, DiscoveredImport | None], None] | None = None
) -> None:
self.import_state: dict[str, DiscoveredImport] = {}
self.on_update = on_update
def browser_callback(
self,
@@ -85,7 +88,9 @@ class DashboardImportDiscovery:
state_change,
)
if state_change == ServiceStateChange.Removed:
self.import_state.pop(name, None)
removed = self.import_state.pop(name, None)
if removed and self.on_update:
self.on_update(name, None)
return
if state_change == ServiceStateChange.Updated and name not in self.import_state:
@@ -139,7 +144,7 @@ class DashboardImportDiscovery:
if friendly_name is not None:
friendly_name = friendly_name.decode()
self.import_state[name] = DiscoveredImport(
discovered = DiscoveredImport(
friendly_name=friendly_name,
device_name=node_name,
package_import_url=import_url,
@@ -147,6 +152,10 @@ class DashboardImportDiscovery:
project_version=project_version,
network=network,
)
is_new = name not in self.import_state
self.import_state[name] = discovered
if is_new and self.on_update:
self.on_update(name, discovered)
def update_device_mdns(self, node_name: str, version: str):
storage_path = ext_storage_path(node_name + ".yaml")

View File

@@ -12,7 +12,7 @@ platformio==6.1.18 # When updating platformio, also update /docker/Dockerfile
esptool==5.1.0
click==8.1.7
esphome-dashboard==20250904.0
aioesphomeapi==41.10.0
aioesphomeapi==41.11.0
zeroconf==0.147.2
puremagic==1.30
ruamel.yaml==0.18.15 # dashboard_import

View File

@@ -0,0 +1,14 @@
deep_sleep:
run_duration:
default: 10s
gpio_wakeup_reason: 30s
touch_wakeup_reason: 15s
sleep_duration: 50s
wakeup_pin: ${wakeup_pin}
wakeup_pin_mode: INVERT_WAKEUP
esp32_ext1_wakeup:
pins:
- number: GPIO2
- number: GPIO13
mode: ANY_HIGH
touch_wakeup: true

View File

@@ -0,0 +1,12 @@
deep_sleep:
run_duration:
default: 10s
gpio_wakeup_reason: 30s
sleep_duration: 50s
wakeup_pin: ${wakeup_pin}
wakeup_pin_mode: INVERT_WAKEUP
esp32_ext1_wakeup:
pins:
- number: GPIO2
- number: GPIO5
mode: ANY_HIGH

View File

@@ -2,4 +2,4 @@ substitutions:
wakeup_pin: GPIO4
<<: !include common.yaml
<<: !include common-esp32.yaml
<<: !include common-esp32-ext1.yaml

View File

@@ -2,4 +2,4 @@ substitutions:
wakeup_pin: GPIO4
<<: !include common.yaml
<<: !include common-esp32.yaml
<<: !include common-esp32-all.yaml

View File

@@ -2,4 +2,4 @@ substitutions:
wakeup_pin: GPIO4
<<: !include common.yaml
<<: !include common-esp32.yaml
<<: !include common-esp32-all.yaml

View File

@@ -2,4 +2,4 @@ substitutions:
wakeup_pin: GPIO4
<<: !include common.yaml
<<: !include common-esp32.yaml
<<: !include common-esp32-all.yaml

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -12,3 +12,4 @@ ethernet:
gateway: 192.168.178.1
subnet: 255.255.255.0
domain: .local
mac_address: "02:AA:BB:CC:DD:01"

View File

@@ -11,6 +11,10 @@ sx126x:
pa_power: 3
bandwidth: 125_0kHz
crc_enable: true
crc_initial: 0x1D0F
crc_polynomial: 0x1021
crc_size: 2
crc_inverted: true
frequency: 433920000
modulation: LORA
rx_start: true
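
The crc_polynomial / crc_initial pair in this test is the classic CCITT polynomial with a 0x1D0F seed. For reference, a generic MSB-first CRC-16 with these parameters looks like the sketch below (not part of the component; hardware details such as the inverted option are not modeled here):

    def crc16_msb_first(data: bytes, poly: int = 0x1021, init: int = 0x1D0F) -> int:
        crc = init
        for byte in data:
            crc ^= byte << 8
            for _ in range(8):
                crc = ((crc << 1) ^ poly) if (crc & 0x8000) else (crc << 1)
                crc &= 0xFFFF
        return crc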

View File

@@ -2,20 +2,42 @@
from __future__ import annotations
from unittest.mock import Mock
from pathlib import Path
from unittest.mock import MagicMock, Mock
import pytest
import pytest_asyncio
from esphome.dashboard.core import ESPHomeDashboard
from esphome.dashboard.entries import DashboardEntries
@pytest.fixture
def mock_dashboard() -> Mock:
def mock_settings(tmp_path: Path) -> MagicMock:
"""Create mock dashboard settings."""
settings = MagicMock()
settings.config_dir = str(tmp_path)
settings.absolute_config_dir = tmp_path
return settings
@pytest.fixture
def mock_dashboard(mock_settings: MagicMock) -> Mock:
"""Create a mock dashboard."""
dashboard = Mock(spec=ESPHomeDashboard)
dashboard.settings = mock_settings
dashboard.entries = Mock()
dashboard.entries.async_all.return_value = []
dashboard.stop_event = Mock()
dashboard.stop_event.is_set.return_value = True
dashboard.ping_request = Mock()
dashboard.ignored_devices = set()
dashboard.bus = Mock()
dashboard.bus.async_fire = Mock()
return dashboard
@pytest_asyncio.fixture
async def dashboard_entries(mock_dashboard: Mock) -> DashboardEntries:
"""Create a DashboardEntries instance for testing."""
return DashboardEntries(mock_dashboard)

View File

@@ -8,7 +8,9 @@ import pytest
import pytest_asyncio
from zeroconf import AddressResolver, IPVersion
from esphome.dashboard.const import DashboardEvent
from esphome.dashboard.status.mdns import MDNSStatus
from esphome.zeroconf import DiscoveredImport
@pytest_asyncio.fixture
@@ -166,3 +168,73 @@ async def test_async_setup_failure(mock_dashboard: Mock) -> None:
result = mdns_status.async_setup()
assert result is False
assert mdns_status.aiozc is None
@pytest.mark.asyncio
async def test_on_import_update_device_added(mdns_status: MDNSStatus) -> None:
"""Test _on_import_update when a device is added."""
# Create a DiscoveredImport object
discovered = DiscoveredImport(
device_name="test_device",
friendly_name="Test Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="wifi",
)
# Call _on_import_update with a device
mdns_status._on_import_update("test_device", discovered)
# Should fire IMPORTABLE_DEVICE_ADDED event
mock_dashboard = mdns_status.dashboard
mock_dashboard.bus.async_fire.assert_called_once()
call_args = mock_dashboard.bus.async_fire.call_args
assert call_args[0][0] == DashboardEvent.IMPORTABLE_DEVICE_ADDED
assert "device" in call_args[0][1]
device_data = call_args[0][1]["device"]
assert device_data["name"] == "test_device"
assert device_data["friendly_name"] == "Test Device"
assert device_data["project_name"] == "test_project"
assert device_data["ignored"] is False
@pytest.mark.asyncio
async def test_on_import_update_device_ignored(mdns_status: MDNSStatus) -> None:
"""Test _on_import_update when a device is ignored."""
# Add device to ignored list
mdns_status.dashboard.ignored_devices.add("ignored_device")
# Create a DiscoveredImport object for ignored device
discovered = DiscoveredImport(
device_name="ignored_device",
friendly_name="Ignored Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="ethernet",
)
# Call _on_import_update with an ignored device
mdns_status._on_import_update("ignored_device", discovered)
# Should fire IMPORTABLE_DEVICE_ADDED event with ignored=True
mock_dashboard = mdns_status.dashboard
mock_dashboard.bus.async_fire.assert_called_once()
call_args = mock_dashboard.bus.async_fire.call_args
assert call_args[0][0] == DashboardEvent.IMPORTABLE_DEVICE_ADDED
device_data = call_args[0][1]["device"]
assert device_data["name"] == "ignored_device"
assert device_data["ignored"] is True
@pytest.mark.asyncio
async def test_on_import_update_device_removed(mdns_status: MDNSStatus) -> None:
"""Test _on_import_update when a device is removed."""
# Call _on_import_update with None (device removed)
mdns_status._on_import_update("removed_device", None)
# Should fire IMPORTABLE_DEVICE_REMOVED event
mdns_status.dashboard.bus.async_fire.assert_called_once_with(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED, {"name": "removed_device"}
)

View File

@@ -2,14 +2,15 @@
from __future__ import annotations
import os
from pathlib import Path
import tempfile
from unittest.mock import MagicMock
from unittest.mock import Mock
import pytest
import pytest_asyncio
from esphome.core import CORE
from esphome.dashboard.const import DashboardEvent
from esphome.dashboard.entries import DashboardEntries, DashboardEntry
@@ -27,21 +28,6 @@ def setup_core():
CORE.reset()
@pytest.fixture
def mock_settings() -> MagicMock:
"""Create mock dashboard settings."""
settings = MagicMock()
settings.config_dir = "/test/config"
settings.absolute_config_dir = Path("/test/config")
return settings
@pytest_asyncio.fixture
async def dashboard_entries(mock_settings: MagicMock) -> DashboardEntries:
"""Create a DashboardEntries instance for testing."""
return DashboardEntries(mock_settings)
def test_dashboard_entry_path_initialization() -> None:
"""Test DashboardEntry initializes with path correctly."""
test_path = Path("/test/config/device.yaml")
@@ -78,15 +64,24 @@ def test_dashboard_entry_path_with_relative_path() -> None:
@pytest.mark.asyncio
async def test_dashboard_entries_get_by_path(
dashboard_entries: DashboardEntries,
dashboard_entries: DashboardEntries, tmp_path: Path
) -> None:
"""Test getting entry by path."""
test_path = Path("/test/config/device.yaml")
entry = DashboardEntry(test_path, create_cache_key())
# Create a test file
test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
dashboard_entries._entries[str(test_path)] = entry
# Update entries to load the file
await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(test_path))
# Verify the entry was loaded
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
entry = all_entries[0]
assert entry.path == test_file
# Also verify get() works with Path
result = dashboard_entries.get(test_file)
assert result == entry
@@ -101,45 +96,54 @@ async def test_dashboard_entries_get_nonexistent_path(
@pytest.mark.asyncio
async def test_dashboard_entries_path_normalization(
dashboard_entries: DashboardEntries,
dashboard_entries: DashboardEntries, tmp_path: Path
) -> None:
"""Test that paths are handled consistently."""
path1 = Path("/test/config/device.yaml")
# Create a test file
test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
entry = DashboardEntry(path1, create_cache_key())
dashboard_entries._entries[str(path1)] = entry
# Update entries to load the file
await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(path1))
assert result == entry
# Get the entry by path
result = dashboard_entries.get(test_file)
assert result is not None
@pytest.mark.asyncio
async def test_dashboard_entries_path_with_spaces(
dashboard_entries: DashboardEntries,
dashboard_entries: DashboardEntries, tmp_path: Path
) -> None:
"""Test handling paths with spaces."""
test_path = Path("/test/config/my device.yaml")
entry = DashboardEntry(test_path, create_cache_key())
# Create a test file with spaces in name
test_file = tmp_path / "my device.yaml"
test_file.write_text("test config")
dashboard_entries._entries[str(test_path)] = entry
# Update entries to load the file
await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(test_path))
assert result == entry
assert result.path == test_path
# Get the entry by path
result = dashboard_entries.get(test_file)
assert result is not None
assert result.path == test_file
@pytest.mark.asyncio
async def test_dashboard_entries_path_with_special_chars(
dashboard_entries: DashboardEntries,
dashboard_entries: DashboardEntries, tmp_path: Path
) -> None:
"""Test handling paths with special characters."""
test_path = Path("/test/config/device-01_test.yaml")
entry = DashboardEntry(test_path, create_cache_key())
# Create a test file with special characters
test_file = tmp_path / "device-01_test.yaml"
test_file.write_text("test config")
dashboard_entries._entries[str(test_path)] = entry
# Update entries to load the file
await dashboard_entries.async_update_entries()
result = dashboard_entries.get(str(test_path))
assert result == entry
# Get the entry by path
result = dashboard_entries.get(test_file)
assert result is not None
def test_dashboard_entries_windows_path() -> None:
@@ -154,22 +158,25 @@ def test_dashboard_entries_windows_path() -> None:
@pytest.mark.asyncio
async def test_dashboard_entries_path_to_cache_key_mapping(
dashboard_entries: DashboardEntries,
dashboard_entries: DashboardEntries, tmp_path: Path
) -> None:
"""Test internal entries storage with paths and cache keys."""
path1 = Path("/test/config/device1.yaml")
path2 = Path("/test/config/device2.yaml")
# Create test files
file1 = tmp_path / "device1.yaml"
file2 = tmp_path / "device2.yaml"
file1.write_text("test config 1")
file2.write_text("test config 2")
entry1 = DashboardEntry(path1, create_cache_key())
entry2 = DashboardEntry(path2, (1, 1, 1.0, 1))
# Update entries to load the files
await dashboard_entries.async_update_entries()
dashboard_entries._entries[str(path1)] = entry1
dashboard_entries._entries[str(path2)] = entry2
# Get entries and verify they have different cache keys
entry1 = dashboard_entries.get(file1)
entry2 = dashboard_entries.get(file2)
assert str(path1) in dashboard_entries._entries
assert str(path2) in dashboard_entries._entries
assert dashboard_entries._entries[str(path1)].cache_key == create_cache_key()
assert dashboard_entries._entries[str(path2)].cache_key == (1, 1, 1.0, 1)
assert entry1 is not None
assert entry2 is not None
assert entry1.cache_key != entry2.cache_key
def test_dashboard_entry_path_property() -> None:
@@ -183,21 +190,99 @@ def test_dashboard_entry_path_property() -> None:
@pytest.mark.asyncio
async def test_dashboard_entries_all_returns_entries_with_paths(
dashboard_entries: DashboardEntries,
dashboard_entries: DashboardEntries, tmp_path: Path
) -> None:
"""Test that all() returns entries with their paths intact."""
paths = [
Path("/test/config/device1.yaml"),
Path("/test/config/device2.yaml"),
Path("/test/config/subfolder/device3.yaml"),
# Create test files
files = [
tmp_path / "device1.yaml",
tmp_path / "device2.yaml",
tmp_path / "device3.yaml",
]
for path in paths:
entry = DashboardEntry(path, create_cache_key())
dashboard_entries._entries[str(path)] = entry
for file in files:
file.write_text("test config")
# Update entries to load the files
await dashboard_entries.async_update_entries()
all_entries = dashboard_entries.async_all()
assert len(all_entries) == len(paths)
assert len(all_entries) == len(files)
retrieved_paths = [entry.path for entry in all_entries]
assert set(retrieved_paths) == set(paths)
assert set(retrieved_paths) == set(files)
@pytest.mark.asyncio
async def test_async_update_entries_removed_path(
dashboard_entries: DashboardEntries, mock_dashboard: Mock, tmp_path: Path
) -> None:
"""Test that removed files trigger ENTRY_REMOVED event."""
# Create a test file
test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
# First update to add the entry
await dashboard_entries.async_update_entries()
# Verify entry was added
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
entry = all_entries[0]
# Delete the file
test_file.unlink()
# Second update to detect removal
await dashboard_entries.async_update_entries()
# Verify entry was removed
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 0
# Verify ENTRY_REMOVED event was fired
mock_dashboard.bus.async_fire.assert_any_call(
DashboardEvent.ENTRY_REMOVED, {"entry": entry}
)
@pytest.mark.asyncio
async def test_async_update_entries_updated_path(
dashboard_entries: DashboardEntries, mock_dashboard: Mock, tmp_path: Path
) -> None:
"""Test that modified files trigger ENTRY_UPDATED event."""
# Create a test file
test_file = tmp_path / "device.yaml"
test_file.write_text("test config")
# First update to add the entry
await dashboard_entries.async_update_entries()
# Verify entry was added
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
entry = all_entries[0]
original_cache_key = entry.cache_key
# Modify the file to change its mtime
test_file.write_text("updated config")
# Explicitly change the mtime to ensure it's different
stat = test_file.stat()
os.utime(test_file, (stat.st_atime, stat.st_mtime + 1))
# Second update to detect modification
await dashboard_entries.async_update_entries()
# Verify entry is still there with updated cache key
all_entries = dashboard_entries.async_all()
assert len(all_entries) == 1
updated_entry = all_entries[0]
assert updated_entry == entry # Same entry object
assert updated_entry.cache_key != original_cache_key # But cache key updated
# Verify ENTRY_UPDATED event was fired
mock_dashboard.bus.async_fire.assert_any_call(
DashboardEvent.ENTRY_UPDATED, {"entry": entry}
)

View File

@@ -2,11 +2,12 @@ from __future__ import annotations
import asyncio
from collections.abc import Generator
from contextlib import asynccontextmanager
import gzip
import json
import os
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
import pytest_asyncio
@@ -14,9 +15,19 @@ from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.testing import bind_unused_port
from tornado.websocket import WebSocketClientConnection, websocket_connect
from esphome.dashboard import web_server
from esphome.dashboard.const import DashboardEvent
from esphome.dashboard.core import DASHBOARD
from esphome.dashboard.entries import (
DashboardEntry,
EntryStateSource,
bool_to_entry_state,
)
from esphome.dashboard.models import build_importable_device_dict
from esphome.dashboard.web_server import DashboardSubscriber
from esphome.zeroconf import DiscoveredImport
from .common import get_fixture_path
@@ -126,6 +137,33 @@ async def dashboard() -> DashboardTestHelper:
io_loop.close()
@asynccontextmanager
async def websocket_connection(dashboard: DashboardTestHelper):
"""Async context manager for WebSocket connections."""
url = f"ws://127.0.0.1:{dashboard.port}/events"
ws = await websocket_connect(url)
try:
yield ws
finally:
if ws:
ws.close()
@pytest_asyncio.fixture
async def websocket_client(dashboard: DashboardTestHelper) -> WebSocketClientConnection:
"""Create a WebSocket connection for testing."""
url = f"ws://127.0.0.1:{dashboard.port}/events"
ws = await websocket_connect(url)
# Read and discard initial state message
await ws.read_message()
yield ws
if ws:
ws.close()
@pytest.mark.asyncio
async def test_main_page(dashboard: DashboardTestHelper) -> None:
response = await dashboard.fetch("/")
@@ -810,3 +848,457 @@ def test_build_cache_arguments_name_without_address(mock_dashboard: Mock) -> Non
mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with(
"my-device.local"
)
@pytest.mark.asyncio
async def test_websocket_connection_initial_state(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket connection and initial state."""
async with websocket_connection(dashboard) as ws:
# Should receive initial state with configured and importable devices
msg = await ws.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "initial_state"
assert "devices" in data["data"]
assert "configured" in data["data"]["devices"]
assert "importable" in data["data"]["devices"]
# Check configured devices
configured = data["data"]["devices"]["configured"]
assert len(configured) > 0
assert configured[0]["name"] == "pico" # From test fixtures
@pytest.mark.asyncio
async def test_websocket_ping_pong(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket ping/pong mechanism."""
# Send ping
await websocket_client.write_message(json.dumps({"event": "ping"}))
# Should receive pong
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "pong"
@pytest.mark.asyncio
async def test_websocket_invalid_json(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket handling of invalid JSON."""
# Send invalid JSON
await websocket_client.write_message("not valid json {]")
# Send a valid ping to verify connection is still alive
await websocket_client.write_message(json.dumps({"event": "ping"}))
# Should receive pong, confirming the connection wasn't closed by invalid JSON
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "pong"
@pytest.mark.asyncio
async def test_websocket_authentication_required(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket authentication when auth is required."""
with patch(
"esphome.dashboard.web_server.is_authenticated"
) as mock_is_authenticated:
mock_is_authenticated.return_value = False
# Try to connect - should be rejected with 401
url = f"ws://127.0.0.1:{dashboard.port}/events"
with pytest.raises(HTTPClientError) as exc_info:
await websocket_connect(url)
# Should get HTTP 401 Unauthorized
assert exc_info.value.code == 401
@pytest.mark.asyncio
async def test_websocket_authentication_not_required(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket connection when no auth is required."""
with patch(
"esphome.dashboard.web_server.is_authenticated"
) as mock_is_authenticated:
mock_is_authenticated.return_value = True
# Should be able to connect successfully
async with websocket_connection(dashboard) as ws:
msg = await ws.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "initial_state"
@pytest.mark.asyncio
async def test_websocket_entry_state_changed(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket entry state changed event."""
# Simulate entry state change
entry = DASHBOARD.entries.async_all()[0]
state = bool_to_entry_state(True, EntryStateSource.MDNS)
DASHBOARD.bus.async_fire(
DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state}
)
# Should receive state change event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "entry_state_changed"
assert data["data"]["filename"] == entry.filename
assert data["data"]["name"] == entry.name
assert data["data"]["state"] is True
@pytest.mark.asyncio
async def test_websocket_entry_added(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket entry added event."""
# Create a mock entry
mock_entry = Mock(spec=DashboardEntry)
mock_entry.filename = "test.yaml"
mock_entry.name = "test_device"
mock_entry.to_dict.return_value = {
"name": "test_device",
"filename": "test.yaml",
"configuration": "test.yaml",
}
# Simulate entry added
DASHBOARD.bus.async_fire(DashboardEvent.ENTRY_ADDED, {"entry": mock_entry})
# Should receive entry added event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "entry_added"
assert data["data"]["device"]["name"] == "test_device"
assert data["data"]["device"]["filename"] == "test.yaml"
@pytest.mark.asyncio
async def test_websocket_entry_removed(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket entry removed event."""
# Create a mock entry
mock_entry = Mock(spec=DashboardEntry)
mock_entry.filename = "removed.yaml"
mock_entry.name = "removed_device"
mock_entry.to_dict.return_value = {
"name": "removed_device",
"filename": "removed.yaml",
"configuration": "removed.yaml",
}
# Simulate entry removed
DASHBOARD.bus.async_fire(DashboardEvent.ENTRY_REMOVED, {"entry": mock_entry})
# Should receive entry removed event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "entry_removed"
assert data["data"]["device"]["name"] == "removed_device"
assert data["data"]["device"]["filename"] == "removed.yaml"
@pytest.mark.asyncio
async def test_websocket_importable_device_added(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket importable device added event with real DiscoveredImport."""
# Create a real DiscoveredImport object
discovered = DiscoveredImport(
device_name="new_import_device",
friendly_name="New Import Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="wifi",
)
# Directly fire the event as the mDNS system would
device_dict = build_importable_device_dict(DASHBOARD, discovered)
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED, {"device": device_dict}
)
# Should receive importable device added event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "importable_device_added"
assert data["data"]["device"]["name"] == "new_import_device"
assert data["data"]["device"]["friendly_name"] == "New Import Device"
assert data["data"]["device"]["project_name"] == "test_project"
assert data["data"]["device"]["network"] == "wifi"
assert data["data"]["device"]["ignored"] is False
@pytest.mark.asyncio
async def test_websocket_importable_device_added_ignored(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket importable device added event for ignored device."""
# Add device to ignored list
DASHBOARD.ignored_devices.add("ignored_device")
# Create a real DiscoveredImport object
discovered = DiscoveredImport(
device_name="ignored_device",
friendly_name="Ignored Device",
package_import_url="https://example.com/package",
project_name="test_project",
project_version="1.0.0",
network="ethernet",
)
# Directly fire the event as the mDNS system would
device_dict = build_importable_device_dict(DASHBOARD, discovered)
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED, {"device": device_dict}
)
# Should receive importable device added event with ignored=True
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "importable_device_added"
assert data["data"]["device"]["name"] == "ignored_device"
assert data["data"]["device"]["friendly_name"] == "Ignored Device"
assert data["data"]["device"]["network"] == "ethernet"
assert data["data"]["device"]["ignored"] is True
@pytest.mark.asyncio
async def test_websocket_importable_device_removed(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket importable device removed event."""
# Simulate importable device removed
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_REMOVED,
{"name": "removed_import_device"},
)
# Should receive importable device removed event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "importable_device_removed"
assert data["data"]["name"] == "removed_import_device"
@pytest.mark.asyncio
async def test_websocket_importable_device_already_configured(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test that importable device event is not sent if device is already configured."""
# Get an existing configured device name
existing_entry = DASHBOARD.entries.async_all()[0]
# Simulate importable device added with same name as configured device
DASHBOARD.bus.async_fire(
DashboardEvent.IMPORTABLE_DEVICE_ADDED,
{
"device": {
"name": existing_entry.name,
"friendly_name": "Should Not Be Sent",
"package_import_url": "https://example.com/package",
"project_name": "test_project",
"project_version": "1.0.0",
"network": "wifi",
}
},
)
# Send a ping to ensure connection is still alive
await websocket_client.write_message(json.dumps({"event": "ping"}))
# Should only receive pong, not the importable device event
msg = await websocket_client.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "pong"
@pytest.mark.asyncio
async def test_websocket_multiple_connections(dashboard: DashboardTestHelper) -> None:
"""Test multiple WebSocket connections."""
async with (
websocket_connection(dashboard) as ws1,
websocket_connection(dashboard) as ws2,
):
# Both should receive initial state
msg1 = await ws1.read_message()
assert msg1 is not None
data1 = json.loads(msg1)
assert data1["event"] == "initial_state"
msg2 = await ws2.read_message()
assert msg2 is not None
data2 = json.loads(msg2)
assert data2["event"] == "initial_state"
# Fire an event - both should receive it
entry = DASHBOARD.entries.async_all()[0]
state = bool_to_entry_state(False, EntryStateSource.MDNS)
DASHBOARD.bus.async_fire(
DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state}
)
msg1 = await ws1.read_message()
assert msg1 is not None
data1 = json.loads(msg1)
assert data1["event"] == "entry_state_changed"
msg2 = await ws2.read_message()
assert msg2 is not None
data2 = json.loads(msg2)
assert data2["event"] == "entry_state_changed"
@pytest.mark.asyncio
async def test_dashboard_subscriber_lifecycle(dashboard: DashboardTestHelper) -> None:
"""Test DashboardSubscriber lifecycle."""
subscriber = DashboardSubscriber()
# Initially no subscribers
assert len(subscriber._subscribers) == 0
assert subscriber._event_loop_task is None
# Add a subscriber
mock_websocket = Mock()
unsubscribe = subscriber.subscribe(mock_websocket)
# Should have started the event loop task
assert len(subscriber._subscribers) == 1
assert subscriber._event_loop_task is not None
# Unsubscribe
unsubscribe()
# Should have stopped the task
assert len(subscriber._subscribers) == 0
@pytest.mark.asyncio
async def test_dashboard_subscriber_entries_update_interval(
dashboard: DashboardTestHelper,
) -> None:
"""Test DashboardSubscriber entries update interval."""
# Patch the constants to make the test run faster
with (
patch("esphome.dashboard.web_server.DASHBOARD_POLL_INTERVAL", 0.01),
patch("esphome.dashboard.web_server.DASHBOARD_ENTRIES_UPDATE_ITERATIONS", 2),
patch("esphome.dashboard.web_server.settings") as mock_settings,
patch("esphome.dashboard.web_server.DASHBOARD") as mock_dashboard,
):
mock_settings.status_use_mqtt = False
# Mock dashboard dependencies
mock_dashboard.ping_request = Mock()
mock_dashboard.ping_request.set = Mock()
mock_dashboard.entries = Mock()
mock_dashboard.entries.async_request_update_entries = Mock()
subscriber = DashboardSubscriber()
mock_websocket = Mock()
# Subscribe to start the event loop
unsubscribe = subscriber.subscribe(mock_websocket)
# Wait for a few iterations to ensure entries update is called
await asyncio.sleep(0.05) # Should be enough for 2+ iterations
# Unsubscribe to stop the task
unsubscribe()
# Verify entries update was called
assert mock_dashboard.entries.async_request_update_entries.call_count >= 1
# Verify ping request was set multiple times
assert mock_dashboard.ping_request.set.call_count >= 2
@pytest.mark.asyncio
async def test_websocket_refresh_command(
dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection
) -> None:
"""Test WebSocket refresh command triggers dashboard update."""
with patch("esphome.dashboard.web_server.DASHBOARD_SUBSCRIBER") as mock_subscriber:
mock_subscriber.request_refresh = Mock()
# Send refresh command
await websocket_client.write_message(json.dumps({"event": "refresh"}))
# Give it a moment to process
await asyncio.sleep(0.01)
# Verify request_refresh was called
mock_subscriber.request_refresh.assert_called_once()
@pytest.mark.asyncio
async def test_dashboard_subscriber_refresh_event(
dashboard: DashboardTestHelper,
) -> None:
"""Test DashboardSubscriber refresh event triggers immediate update."""
# Patch the constants to make the test run faster
with (
patch(
"esphome.dashboard.web_server.DASHBOARD_POLL_INTERVAL", 1.0
), # Long timeout
patch(
"esphome.dashboard.web_server.DASHBOARD_ENTRIES_UPDATE_ITERATIONS", 100
), # Won't reach naturally
patch("esphome.dashboard.web_server.settings") as mock_settings,
patch("esphome.dashboard.web_server.DASHBOARD") as mock_dashboard,
):
mock_settings.status_use_mqtt = False
# Mock dashboard dependencies
mock_dashboard.ping_request = Mock()
mock_dashboard.ping_request.set = Mock()
mock_dashboard.entries = Mock()
mock_dashboard.entries.async_request_update_entries = AsyncMock()
subscriber = DashboardSubscriber()
mock_websocket = Mock()
# Subscribe to start the event loop
unsubscribe = subscriber.subscribe(mock_websocket)
# Wait a bit to ensure loop is running
await asyncio.sleep(0.01)
# Verify entries update hasn't been called yet (iterations not reached)
assert mock_dashboard.entries.async_request_update_entries.call_count == 0
# Request refresh
subscriber.request_refresh()
# Wait for the refresh to be processed
await asyncio.sleep(0.01)
# Now entries update should have been called
assert mock_dashboard.entries.async_request_update_entries.call_count == 1
# Unsubscribe to stop the task
unsubscribe()
# Give it a moment to clean up
await asyncio.sleep(0.01)