From 61a558a062e9df0c9d412ab1fa7a5c4f0c34bba5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 15:53:45 -0500 Subject: [PATCH 01/14] Implement a lock free ring buffer for BLEScanResult to avoid drops (#9087) --- .../esp32_ble_tracker/esp32_ble_tracker.cpp | 88 ++++++++++++------- .../esp32_ble_tracker/esp32_ble_tracker.h | 14 ++- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp index ab3efc3ad3..c5906779f1 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp @@ -51,15 +51,14 @@ void ESP32BLETracker::setup() { return; } RAMAllocator allocator; - this->scan_result_buffer_ = allocator.allocate(SCAN_RESULT_BUFFER_SIZE); + this->scan_ring_buffer_ = allocator.allocate(SCAN_RESULT_BUFFER_SIZE); - if (this->scan_result_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate buffer for BLE Tracker!"); + if (this->scan_ring_buffer_ == nullptr) { + ESP_LOGE(TAG, "Could not allocate ring buffer for BLE Tracker!"); this->mark_failed(); } global_esp32_ble_tracker = this; - this->scan_result_lock_ = xSemaphoreCreateMutex(); #ifdef USE_OTA ota::get_global_ota_callback()->add_on_state_callback( @@ -119,27 +118,31 @@ void ESP32BLETracker::loop() { } bool promote_to_connecting = discovered && !searching && !connecting; - if (this->scanner_state_ == ScannerState::RUNNING && - this->scan_result_index_ && // if it looks like we have a scan result we will take the lock - xSemaphoreTake(this->scan_result_lock_, 0)) { - uint32_t index = this->scan_result_index_; - if (index >= SCAN_RESULT_BUFFER_SIZE) { - ESP_LOGW(TAG, "Too many BLE events to process. Some devices may not show up."); - } + // Process scan results from lock-free SPSC ring buffer + // Consumer side: This runs in the main loop thread + if (this->scanner_state_ == ScannerState::RUNNING) { + // Load our own index with relaxed ordering (we're the only writer) + size_t read_idx = this->ring_read_index_.load(std::memory_order_relaxed); - if (this->raw_advertisements_) { - for (auto *listener : this->listeners_) { - listener->parse_devices(this->scan_result_buffer_, this->scan_result_index_); - } - for (auto *client : this->clients_) { - client->parse_devices(this->scan_result_buffer_, this->scan_result_index_); - } - } + // Load producer's index with acquire to see their latest writes + size_t write_idx = this->ring_write_index_.load(std::memory_order_acquire); - if (this->parse_advertisements_) { - for (size_t i = 0; i < index; i++) { + while (read_idx != write_idx) { + // Process one result at a time directly from ring buffer + BLEScanResult &scan_result = this->scan_ring_buffer_[read_idx]; + + if (this->raw_advertisements_) { + for (auto *listener : this->listeners_) { + listener->parse_devices(&scan_result, 1); + } + for (auto *client : this->clients_) { + client->parse_devices(&scan_result, 1); + } + } + + if (this->parse_advertisements_) { ESPBTDevice device; - device.parse_scan_rst(this->scan_result_buffer_[i]); + device.parse_scan_rst(scan_result); bool found = false; for (auto *listener : this->listeners_) { @@ -160,9 +163,19 @@ void ESP32BLETracker::loop() { this->print_bt_device_info(device); } } + + // Move to next entry in ring buffer + read_idx = (read_idx + 1) % SCAN_RESULT_BUFFER_SIZE; + + // Store with release to ensure reads complete before index update + this->ring_read_index_.store(read_idx, std::memory_order_release); + } + + // Log dropped results periodically + size_t dropped = this->scan_results_dropped_.exchange(0, std::memory_order_relaxed); + if (dropped > 0) { + ESP_LOGW(TAG, "Dropped %zu BLE scan results due to buffer overflow", dropped); } - this->scan_result_index_ = 0; - xSemaphoreGive(this->scan_result_lock_); } if (this->scanner_state_ == ScannerState::STOPPED) { this->end_of_scan_(); // Change state to IDLE @@ -391,12 +404,27 @@ void ESP32BLETracker::gap_scan_event_handler(const BLEScanResult &scan_result) { ESP_LOGV(TAG, "gap_scan_result - event %d", scan_result.search_evt); if (scan_result.search_evt == ESP_GAP_SEARCH_INQ_RES_EVT) { - if (xSemaphoreTake(this->scan_result_lock_, 0)) { - if (this->scan_result_index_ < SCAN_RESULT_BUFFER_SIZE) { - // Store BLEScanResult directly in our buffer - this->scan_result_buffer_[this->scan_result_index_++] = scan_result; - } - xSemaphoreGive(this->scan_result_lock_); + // Lock-free SPSC ring buffer write (Producer side) + // This runs in the ESP-IDF Bluetooth stack callback thread + // IMPORTANT: Only this thread writes to ring_write_index_ + + // Load our own index with relaxed ordering (we're the only writer) + size_t write_idx = this->ring_write_index_.load(std::memory_order_relaxed); + size_t next_write_idx = (write_idx + 1) % SCAN_RESULT_BUFFER_SIZE; + + // Load consumer's index with acquire to see their latest updates + size_t read_idx = this->ring_read_index_.load(std::memory_order_acquire); + + // Check if buffer is full + if (next_write_idx != read_idx) { + // Write to ring buffer + this->scan_ring_buffer_[write_idx] = scan_result; + + // Store with release to ensure the write is visible before index update + this->ring_write_index_.store(next_write_idx, std::memory_order_release); + } else { + // Buffer full, track dropped results + this->scan_results_dropped_.fetch_add(1, std::memory_order_relaxed); } } else if (scan_result.search_evt == ESP_GAP_SEARCH_INQ_CMPL_EVT) { // Scan finished on its own diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h index 33c0caaa87..16a100fb47 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h @@ -6,6 +6,7 @@ #include "esphome/core/helpers.h" #include +#include #include #include @@ -282,9 +283,16 @@ class ESP32BLETracker : public Component, bool ble_was_disabled_{true}; bool raw_advertisements_{false}; bool parse_advertisements_{false}; - SemaphoreHandle_t scan_result_lock_; - size_t scan_result_index_{0}; - BLEScanResult *scan_result_buffer_; + + // Lock-free Single-Producer Single-Consumer (SPSC) ring buffer for scan results + // Producer: ESP-IDF Bluetooth stack callback (gap_scan_event_handler) + // Consumer: ESPHome main loop (loop() method) + // This design ensures zero blocking in the BT callback and prevents scan result loss + BLEScanResult *scan_ring_buffer_; + std::atomic ring_write_index_{0}; // Written only by BT callback (producer) + std::atomic ring_read_index_{0}; // Written only by main loop (consumer) + std::atomic scan_results_dropped_{0}; // Tracks buffer overflow events + esp_bt_status_t scan_start_failed_{ESP_BT_STATUS_SUCCESS}; esp_bt_status_t scan_set_param_failed_{ESP_BT_STATUS_SUCCESS}; int connecting_{0}; From fcce4a8be6dbfd01a84c1827a45ecda3a339ce15 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 16:16:46 -0500 Subject: [PATCH 02/14] Make BLE queue lock free (#9088) --- esphome/components/esp32_ble/ble.cpp | 26 +++++++-- esphome/components/esp32_ble/ble.h | 5 +- esphome/components/esp32_ble/queue.h | 83 ++++++++++++++++------------ 3 files changed, 73 insertions(+), 41 deletions(-) diff --git a/esphome/components/esp32_ble/ble.cpp b/esphome/components/esp32_ble/ble.cpp index ed74d59ef2..8adef79d2f 100644 --- a/esphome/components/esp32_ble/ble.cpp +++ b/esphome/components/esp32_ble/ble.cpp @@ -23,9 +23,6 @@ namespace esp32_ble { static const char *const TAG = "esp32_ble"; -// Maximum size of the BLE event queue -static constexpr size_t MAX_BLE_QUEUE_SIZE = SCAN_RESULT_BUFFER_SIZE * 2; - static RAMAllocator EVENT_ALLOCATOR( // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) RAMAllocator::ALLOW_FAILURE | RAMAllocator::ALLOC_INTERNAL); @@ -360,21 +357,38 @@ void ESP32BLE::loop() { if (this->advertising_ != nullptr) { this->advertising_->loop(); } + + // Log dropped events periodically + size_t dropped = this->ble_events_.get_and_reset_dropped_count(); + if (dropped > 0) { + ESP_LOGW(TAG, "Dropped %zu BLE events due to buffer overflow", dropped); + } } template void enqueue_ble_event(Args... args) { - if (global_ble->ble_events_.size() >= MAX_BLE_QUEUE_SIZE) { - ESP_LOGD(TAG, "Event queue full (%zu), dropping event", MAX_BLE_QUEUE_SIZE); + // Check if queue is full before allocating + if (global_ble->ble_events_.full()) { + // Queue is full, drop the event + global_ble->ble_events_.increment_dropped_count(); return; } BLEEvent *new_event = EVENT_ALLOCATOR.allocate(1); if (new_event == nullptr) { // Memory too fragmented to allocate new event. Can only drop it until memory comes back + global_ble->ble_events_.increment_dropped_count(); return; } new (new_event) BLEEvent(args...); - global_ble->ble_events_.push(new_event); + + // Push the event - since we're the only producer and we checked full() above, + // this should always succeed unless we have a bug + if (!global_ble->ble_events_.push(new_event)) { + // This should not happen in SPSC queue with single producer + ESP_LOGE(TAG, "BLE queue push failed unexpectedly"); + new_event->~BLEEvent(); + EVENT_ALLOCATOR.deallocate(new_event, 1); + } } // NOLINT(clang-analyzer-unix.Malloc) // Explicit template instantiations for the friend function diff --git a/esphome/components/esp32_ble/ble.h b/esphome/components/esp32_ble/ble.h index 6508db1a00..58c064a2ef 100644 --- a/esphome/components/esp32_ble/ble.h +++ b/esphome/components/esp32_ble/ble.h @@ -30,6 +30,9 @@ static constexpr uint8_t SCAN_RESULT_BUFFER_SIZE = 32; static constexpr uint8_t SCAN_RESULT_BUFFER_SIZE = 20; #endif +// Maximum size of the BLE event queue - must be power of 2 for lock-free queue +static constexpr size_t MAX_BLE_QUEUE_SIZE = 64; + uint64_t ble_addr_to_uint64(const esp_bd_addr_t address); // NOLINTNEXTLINE(modernize-use-using) @@ -144,7 +147,7 @@ class ESP32BLE : public Component { std::vector ble_status_event_handlers_; BLEComponentState state_{BLE_COMPONENT_STATE_OFF}; - Queue ble_events_; + LockFreeQueue ble_events_; BLEAdvertising *advertising_{}; esp_ble_io_cap_t io_cap_{ESP_IO_CAP_NONE}; uint32_t advertising_cycle_time_{}; diff --git a/esphome/components/esp32_ble/queue.h b/esphome/components/esp32_ble/queue.h index f69878bf6e..56d2efd18b 100644 --- a/esphome/components/esp32_ble/queue.h +++ b/esphome/components/esp32_ble/queue.h @@ -2,63 +2,78 @@ #ifdef USE_ESP32 -#include -#include - -#include -#include +#include +#include /* * BLE events come in from a separate Task (thread) in the ESP32 stack. Rather - * than trying to deal with various locking strategies, all incoming GAP and GATT - * events will simply be placed on a semaphore guarded queue. The next time the - * component runs loop(), these events are popped off the queue and handed at - * this safer time. + * than using mutex-based locking, this lock-free queue allows the BLE + * task to enqueue events without blocking. The main loop() then processes + * these events at a safer time. + * + * This is a Single-Producer Single-Consumer (SPSC) lock-free ring buffer. + * The BLE task is the only producer, and the main loop() is the only consumer. */ namespace esphome { namespace esp32_ble { -template class Queue { +template class LockFreeQueue { public: - Queue() { m_ = xSemaphoreCreateMutex(); } + LockFreeQueue() : head_(0), tail_(0), dropped_count_(0) {} - void push(T *element) { + bool push(T *element) { if (element == nullptr) - return; - // It is not called from main loop. Thus it won't block main thread. - xSemaphoreTake(m_, portMAX_DELAY); - q_.push(element); - xSemaphoreGive(m_); + return false; + + size_t current_tail = tail_.load(std::memory_order_relaxed); + size_t next_tail = (current_tail + 1) % SIZE; + + if (next_tail == head_.load(std::memory_order_acquire)) { + // Buffer full + dropped_count_.fetch_add(1, std::memory_order_relaxed); + return false; + } + + buffer_[current_tail] = element; + tail_.store(next_tail, std::memory_order_release); + return true; } T *pop() { - T *element = nullptr; + size_t current_head = head_.load(std::memory_order_relaxed); - if (xSemaphoreTake(m_, 5L / portTICK_PERIOD_MS)) { - if (!q_.empty()) { - element = q_.front(); - q_.pop(); - } - xSemaphoreGive(m_); + if (current_head == tail_.load(std::memory_order_acquire)) { + return nullptr; // Empty } + + T *element = buffer_[current_head]; + head_.store((current_head + 1) % SIZE, std::memory_order_release); return element; } size_t size() const { - // Lock-free size check. While std::queue::size() is not thread-safe, we intentionally - // avoid locking here to prevent blocking the BLE callback thread. The size is only - // used to decide whether to drop incoming events when the queue is near capacity. - // With a queue limit of 40-64 events and normal processing, dropping events should - // be extremely rare. When it does approach capacity, being off by 1-2 events is - // acceptable to avoid blocking the BLE stack's time-sensitive callbacks. - // Trade-off: We prefer occasional dropped events over potential BLE stack delays. - return q_.size(); + size_t tail = tail_.load(std::memory_order_acquire); + size_t head = head_.load(std::memory_order_acquire); + return (tail - head + SIZE) % SIZE; + } + + size_t get_and_reset_dropped_count() { return dropped_count_.exchange(0, std::memory_order_relaxed); } + + void increment_dropped_count() { dropped_count_.fetch_add(1, std::memory_order_relaxed); } + + bool empty() const { return head_.load(std::memory_order_acquire) == tail_.load(std::memory_order_acquire); } + + bool full() const { + size_t next_tail = (tail_.load(std::memory_order_relaxed) + 1) % SIZE; + return next_tail == head_.load(std::memory_order_acquire); } protected: - std::queue q_; - SemaphoreHandle_t m_; + T *buffer_[SIZE]; + std::atomic head_; + std::atomic tail_; + std::atomic dropped_count_; }; } // namespace esp32_ble From be58cdda3ba00a8b0943271d0935294e209ba1c3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 16:19:04 -0500 Subject: [PATCH 03/14] Fix protobuf encoding size mismatch by passing force parameter in encode_string (#9074) --- esphome/components/api/proto.h | 2 +- .../host_mode_empty_string_options.yaml | 58 +++++++++ .../test_host_mode_empty_string_options.py | 110 ++++++++++++++++++ 3 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 tests/integration/fixtures/host_mode_empty_string_options.yaml create mode 100644 tests/integration/test_host_mode_empty_string_options.py diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index 5265c4520d..eb0dbc151b 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -216,7 +216,7 @@ class ProtoWriteBuffer { this->buffer_->insert(this->buffer_->end(), data, data + len); } void encode_string(uint32_t field_id, const std::string &value, bool force = false) { - this->encode_string(field_id, value.data(), value.size()); + this->encode_string(field_id, value.data(), value.size(), force); } void encode_bytes(uint32_t field_id, const uint8_t *data, size_t len, bool force = false) { this->encode_string(field_id, reinterpret_cast(data), len, force); diff --git a/tests/integration/fixtures/host_mode_empty_string_options.yaml b/tests/integration/fixtures/host_mode_empty_string_options.yaml new file mode 100644 index 0000000000..ab8e6cd005 --- /dev/null +++ b/tests/integration/fixtures/host_mode_empty_string_options.yaml @@ -0,0 +1,58 @@ +esphome: + name: host-empty-string-test + +host: + +api: + batch_delay: 50ms + +select: + - platform: template + name: "Select Empty First" + id: select_empty_first + optimistic: true + options: + - "" # Empty string at the beginning + - "Option A" + - "Option B" + - "Option C" + initial_option: "Option A" + + - platform: template + name: "Select Empty Middle" + id: select_empty_middle + optimistic: true + options: + - "Option 1" + - "Option 2" + - "" # Empty string in the middle + - "Option 3" + - "Option 4" + initial_option: "Option 1" + + - platform: template + name: "Select Empty Last" + id: select_empty_last + optimistic: true + options: + - "Choice X" + - "Choice Y" + - "Choice Z" + - "" # Empty string at the end + initial_option: "Choice X" + +# Add a sensor to ensure we have other entities in the list +sensor: + - platform: template + name: "Test Sensor" + id: test_sensor + lambda: |- + return 42.0; + update_interval: 60s + +binary_sensor: + - platform: template + name: "Test Binary Sensor" + id: test_binary_sensor + lambda: |- + return true; diff --git a/tests/integration/test_host_mode_empty_string_options.py b/tests/integration/test_host_mode_empty_string_options.py new file mode 100644 index 0000000000..d2df839a75 --- /dev/null +++ b/tests/integration/test_host_mode_empty_string_options.py @@ -0,0 +1,110 @@ +"""Integration test for protobuf encoding of empty string options in select entities.""" + +from __future__ import annotations + +import asyncio + +from aioesphomeapi import EntityState, SelectInfo +import pytest + +from .types import APIClientConnectedFactory, RunCompiledFunction + + +@pytest.mark.asyncio +async def test_host_mode_empty_string_options( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected: APIClientConnectedFactory, +) -> None: + """Test that select entities with empty string options are correctly encoded in protobuf messages. + + This tests the fix for the bug where the force parameter was not passed in encode_string, + causing empty strings in repeated fields to be skipped during encoding but included in + size calculation, leading to protobuf decoding errors. + """ + # Write, compile and run the ESPHome device, then connect to API + loop = asyncio.get_running_loop() + async with run_compiled(yaml_config), api_client_connected() as client: + # Verify we can get device info + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "host-empty-string-test" + + # Get list of entities - this will encode ListEntitiesSelectResponse messages + # with empty string options that would trigger the bug + entity_info, services = await client.list_entities_services() + + # Find our select entities + select_entities = [e for e in entity_info if isinstance(e, SelectInfo)] + assert len(select_entities) == 3, ( + f"Expected 3 select entities, got {len(select_entities)}" + ) + + # Verify each select entity by name and check their options + selects_by_name = {e.name: e for e in select_entities} + + # Check "Select Empty First" - empty string at beginning + assert "Select Empty First" in selects_by_name + empty_first = selects_by_name["Select Empty First"] + assert len(empty_first.options) == 4 + assert empty_first.options[0] == "" # Empty string at beginning + assert empty_first.options[1] == "Option A" + assert empty_first.options[2] == "Option B" + assert empty_first.options[3] == "Option C" + + # Check "Select Empty Middle" - empty string in middle + assert "Select Empty Middle" in selects_by_name + empty_middle = selects_by_name["Select Empty Middle"] + assert len(empty_middle.options) == 5 + assert empty_middle.options[0] == "Option 1" + assert empty_middle.options[1] == "Option 2" + assert empty_middle.options[2] == "" # Empty string in middle + assert empty_middle.options[3] == "Option 3" + assert empty_middle.options[4] == "Option 4" + + # Check "Select Empty Last" - empty string at end + assert "Select Empty Last" in selects_by_name + empty_last = selects_by_name["Select Empty Last"] + assert len(empty_last.options) == 4 + assert empty_last.options[0] == "Choice X" + assert empty_last.options[1] == "Choice Y" + assert empty_last.options[2] == "Choice Z" + assert empty_last.options[3] == "" # Empty string at end + + # If we got here without protobuf decoding errors, the fix is working + # The bug would have caused "Invalid protobuf message" errors with trailing bytes + + # Also verify we can interact with the select entities + # Subscribe to state changes + states: dict[int, EntityState] = {} + state_change_future: asyncio.Future[None] = loop.create_future() + + def on_state(state: EntityState) -> None: + """Track state changes.""" + states[state.key] = state + # When we receive the state change for our select, resolve the future + if state.key == empty_first.key and not state_change_future.done(): + state_change_future.set_result(None) + + client.subscribe_states(on_state) + + # Try setting a select to an empty string option + # This further tests that empty strings are handled correctly + client.select_command(empty_first.key, "") + + # Wait for state update with timeout + try: + await asyncio.wait_for(state_change_future, timeout=5.0) + except asyncio.TimeoutError: + pytest.fail( + "Did not receive state update after setting select to empty string" + ) + + # Verify the state was set to empty string + assert empty_first.key in states + select_state = states[empty_first.key] + assert hasattr(select_state, "state") + assert select_state.state == "" + + # The test passes if no protobuf decoding errors occurred + # With the bug, we would have gotten "Invalid protobuf message" errors From bd85ba9b6a07d27ced7ed8c52215c2092d05f698 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Sun, 15 Jun 2025 22:19:50 +0100 Subject: [PATCH 04/14] [i2s_audio] Check for a nullptr before disabling and deleting channel (#9062) --- .../microphone/i2s_audio_microphone.cpp | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp index 1ce98d51d3..52d0ae34fb 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp @@ -317,15 +317,18 @@ void I2SAudioMicrophone::stop_driver_() { ESP_LOGW(TAG, "Error uninstalling I2S driver - it may not have started: %s", esp_err_to_name(err)); } #else - /* Have to stop the channel before deleting it */ - err = i2s_channel_disable(this->rx_handle_); - if (err != ESP_OK) { - ESP_LOGW(TAG, "Error stopping I2S microphone - it may not have started: %s", esp_err_to_name(err)); - } - /* If the handle is not needed any more, delete it to release the channel resources */ - err = i2s_del_channel(this->rx_handle_); - if (err != ESP_OK) { - ESP_LOGW(TAG, "Error deleting I2S channel - it may not have started: %s", esp_err_to_name(err)); + if (this->rx_handle_ != nullptr) { + /* Have to stop the channel before deleting it */ + err = i2s_channel_disable(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error stopping I2S microphone - it may not have started: %s", esp_err_to_name(err)); + } + /* If the handle is not needed any more, delete it to release the channel resources */ + err = i2s_del_channel(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error deleting I2S channel - it may not have started: %s", esp_err_to_name(err)); + } + this->rx_handle_ = nullptr; } #endif this->parent_->unlock(); From 06810e8e6a960221b484972e733ddb845c878222 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 16:22:14 -0500 Subject: [PATCH 05/14] Ensure we can send batches where the first message exceeds MAX_PACKET_SIZE (#9068) --- esphome/components/api/api_connection.cpp | 8 +- tests/integration/conftest.py | 16 +- .../fixtures/api_message_size_batching.yaml | 161 +++++++++++++++ .../fixtures/large_message_batching.yaml | 137 +++++++++++++ .../test_api_message_size_batching.py | 194 ++++++++++++++++++ .../test_large_message_batching.py | 59 ++++++ 6 files changed, 570 insertions(+), 5 deletions(-) create mode 100644 tests/integration/fixtures/api_message_size_batching.yaml create mode 100644 tests/integration/fixtures/large_message_batching.yaml create mode 100644 tests/integration/test_api_message_size_batching.py create mode 100644 tests/integration/test_large_message_batching.py diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index ca6e2a2d56..8328f5d2cd 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1807,7 +1807,7 @@ void APIConnection::process_batch_() { this->batch_first_message_ = true; size_t items_processed = 0; - uint32_t remaining_size = MAX_PACKET_SIZE; + uint16_t remaining_size = std::numeric_limits::max(); // Track where each message's header padding begins in the buffer // For plaintext: this is where the 6-byte header padding starts @@ -1832,11 +1832,15 @@ void APIConnection::process_batch_() { packet_info.emplace_back(item.message_type, current_offset, proto_payload_size); // Update tracking variables + items_processed++; + // After first message, set remaining size to MAX_PACKET_SIZE to avoid fragmentation + if (items_processed == 1) { + remaining_size = MAX_PACKET_SIZE; + } remaining_size -= payload_size; // Calculate where the next message's header padding will start // Current buffer size + footer space (that prepare_message_buffer will add for this message) current_offset = this->parent_->get_shared_buffer_ref().size() + footer_size; - items_processed++; } if (items_processed == 0) { diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4eb1584c27..90377300a6 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -15,7 +15,7 @@ import sys import tempfile from typing import TextIO -from aioesphomeapi import APIClient, APIConnectionError, ReconnectLogic +from aioesphomeapi import APIClient, APIConnectionError, LogParser, ReconnectLogic import pytest import pytest_asyncio @@ -365,11 +365,21 @@ async def _read_stream_lines( stream: asyncio.StreamReader, lines: list[str], output_stream: TextIO ) -> None: """Read lines from a stream, append to list, and echo to output stream.""" + log_parser = LogParser() while line := await stream.readline(): - decoded_line = line.decode("utf-8", errors="replace") + decoded_line = ( + line.replace(b"\r", b"") + .replace(b"\n", b"") + .decode("utf8", "backslashreplace") + ) lines.append(decoded_line.rstrip()) # Echo to stdout/stderr in real-time - print(decoded_line.rstrip(), file=output_stream, flush=True) + # Print without newline to avoid double newlines + print( + log_parser.parse_line(decoded_line, timestamp=""), + file=output_stream, + flush=True, + ) @asynccontextmanager diff --git a/tests/integration/fixtures/api_message_size_batching.yaml b/tests/integration/fixtures/api_message_size_batching.yaml new file mode 100644 index 0000000000..c730dc1aa3 --- /dev/null +++ b/tests/integration/fixtures/api_message_size_batching.yaml @@ -0,0 +1,161 @@ +esphome: + name: message-size-batching-test +host: +api: +# Default batch_delay to test batching +logger: + +# Create entities that will produce different protobuf header sizes +# Header size depends on: 1 byte indicator + varint(payload_size) + varint(message_type) +# 4-byte header: type < 128, payload < 128 +# 5-byte header: type < 128, payload 128-16383 OR type 128+, payload < 128 +# 6-byte header: type 128+, payload 128-16383 + +# Small select with few options - produces small message +select: + - platform: template + name: "Small Select" + id: small_select + optimistic: true + options: + - "Option A" + - "Option B" + initial_option: "Option A" + update_interval: 5.0s + + # Medium select with more options - produces medium message + - platform: template + name: "Medium Select" + id: medium_select + optimistic: true + options: + - "Option 001" + - "Option 002" + - "Option 003" + - "Option 004" + - "Option 005" + - "Option 006" + - "Option 007" + - "Option 008" + - "Option 009" + - "Option 010" + - "Option 011" + - "Option 012" + - "Option 013" + - "Option 014" + - "Option 015" + - "Option 016" + - "Option 017" + - "Option 018" + - "Option 019" + - "Option 020" + initial_option: "Option 001" + update_interval: 5.0s + + # Large select with many options - produces larger message + - platform: template + name: "Large Select with Many Options to Create Larger Payload" + id: large_select + optimistic: true + options: + - "Long Option Name 001 - This is a longer option name to increase message size" + - "Long Option Name 002 - This is a longer option name to increase message size" + - "Long Option Name 003 - This is a longer option name to increase message size" + - "Long Option Name 004 - This is a longer option name to increase message size" + - "Long Option Name 005 - This is a longer option name to increase message size" + - "Long Option Name 006 - This is a longer option name to increase message size" + - "Long Option Name 007 - This is a longer option name to increase message size" + - "Long Option Name 008 - This is a longer option name to increase message size" + - "Long Option Name 009 - This is a longer option name to increase message size" + - "Long Option Name 010 - This is a longer option name to increase message size" + - "Long Option Name 011 - This is a longer option name to increase message size" + - "Long Option Name 012 - This is a longer option name to increase message size" + - "Long Option Name 013 - This is a longer option name to increase message size" + - "Long Option Name 014 - This is a longer option name to increase message size" + - "Long Option Name 015 - This is a longer option name to increase message size" + - "Long Option Name 016 - This is a longer option name to increase message size" + - "Long Option Name 017 - This is a longer option name to increase message size" + - "Long Option Name 018 - This is a longer option name to increase message size" + - "Long Option Name 019 - This is a longer option name to increase message size" + - "Long Option Name 020 - This is a longer option name to increase message size" + - "Long Option Name 021 - This is a longer option name to increase message size" + - "Long Option Name 022 - This is a longer option name to increase message size" + - "Long Option Name 023 - This is a longer option name to increase message size" + - "Long Option Name 024 - This is a longer option name to increase message size" + - "Long Option Name 025 - This is a longer option name to increase message size" + - "Long Option Name 026 - This is a longer option name to increase message size" + - "Long Option Name 027 - This is a longer option name to increase message size" + - "Long Option Name 028 - This is a longer option name to increase message size" + - "Long Option Name 029 - This is a longer option name to increase message size" + - "Long Option Name 030 - This is a longer option name to increase message size" + - "Long Option Name 031 - This is a longer option name to increase message size" + - "Long Option Name 032 - This is a longer option name to increase message size" + - "Long Option Name 033 - This is a longer option name to increase message size" + - "Long Option Name 034 - This is a longer option name to increase message size" + - "Long Option Name 035 - This is a longer option name to increase message size" + - "Long Option Name 036 - This is a longer option name to increase message size" + - "Long Option Name 037 - This is a longer option name to increase message size" + - "Long Option Name 038 - This is a longer option name to increase message size" + - "Long Option Name 039 - This is a longer option name to increase message size" + - "Long Option Name 040 - This is a longer option name to increase message size" + - "Long Option Name 041 - This is a longer option name to increase message size" + - "Long Option Name 042 - This is a longer option name to increase message size" + - "Long Option Name 043 - This is a longer option name to increase message size" + - "Long Option Name 044 - This is a longer option name to increase message size" + - "Long Option Name 045 - This is a longer option name to increase message size" + - "Long Option Name 046 - This is a longer option name to increase message size" + - "Long Option Name 047 - This is a longer option name to increase message size" + - "Long Option Name 048 - This is a longer option name to increase message size" + - "Long Option Name 049 - This is a longer option name to increase message size" + - "Long Option Name 050 - This is a longer option name to increase message size" + initial_option: "Long Option Name 001 - This is a longer option name to increase message size" + update_interval: 5.0s + +# Text sensors with different value lengths +text_sensor: + - platform: template + name: "Short Text Sensor" + id: short_text_sensor + lambda: |- + return {"OK"}; + update_interval: 5.0s + + - platform: template + name: "Medium Text Sensor" + id: medium_text_sensor + lambda: |- + return {"This is a medium length text sensor value that should produce a medium sized message"}; + update_interval: 5.0s + + - platform: template + name: "Long Text Sensor with Very Long Value" + id: long_text_sensor + lambda: |- + return {"This is a very long text sensor value that contains a lot of text to ensure we get a larger protobuf message. The message should be long enough to require a 2-byte varint for the payload size, which happens when the payload exceeds 127 bytes. Let's add even more text here to make sure we exceed that threshold and test the batching of messages with different header sizes properly."}; + update_interval: 5.0s + +# Text input which can have various lengths +text: + - platform: template + name: "Test Text Input" + id: test_text_input + optimistic: true + mode: text + min_length: 0 + max_length: 255 + initial_value: "Initial value" + update_interval: 5.0s + +# Number entity to add variety (different message type number) +# The ListEntitiesNumberResponse has message type 49 +# The NumberStateResponse has message type 50 +number: + - platform: template + name: "Test Number with Long Name to Increase Message Size" + id: test_number + optimistic: true + min_value: 0 + max_value: 1000 + step: 0.1 + initial_value: 42.0 + update_interval: 5.0s diff --git a/tests/integration/fixtures/large_message_batching.yaml b/tests/integration/fixtures/large_message_batching.yaml new file mode 100644 index 0000000000..1b2d817cd4 --- /dev/null +++ b/tests/integration/fixtures/large_message_batching.yaml @@ -0,0 +1,137 @@ +esphome: + name: large-message-test +host: +api: +logger: + +# Create a select entity with many options to exceed 1390 bytes +select: + - platform: template + name: "Large Select" + id: large_select + optimistic: true + options: + - "Option 000 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 001 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 002 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 003 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 004 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 005 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 006 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 007 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 008 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 009 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 010 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 011 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 012 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 013 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 014 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 015 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 016 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 017 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 018 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 019 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 020 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 021 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 022 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 023 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 024 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 025 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 026 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 027 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 028 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 029 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 030 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 031 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 032 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 033 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 034 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 035 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 036 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 037 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 038 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 039 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 040 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 041 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 042 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 043 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 044 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 045 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 046 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 047 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 048 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 049 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 050 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 051 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 052 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 053 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 054 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 055 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 056 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 057 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 058 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 059 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 060 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 061 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 062 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 063 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 064 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 065 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 066 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 067 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 068 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 069 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 070 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 071 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 072 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 073 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 074 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 075 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 076 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 077 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 078 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 079 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 080 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 081 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 082 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 083 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 084 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 085 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 086 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 087 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 088 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 089 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 090 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 091 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 092 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 093 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 094 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 095 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 096 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 097 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 098 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + - "Option 099 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + initial_option: "Option 000 - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + +# Add some other entities to test batching with the large select +sensor: + - platform: template + name: "Test Sensor" + id: test_sensor + lambda: |- + return 42.0; + update_interval: 1s + +binary_sensor: + - platform: template + name: "Test Binary Sensor" + id: test_binary_sensor + lambda: |- + return true; + +switch: + - platform: template + name: "Test Switch" + id: test_switch + optimistic: true + diff --git a/tests/integration/test_api_message_size_batching.py b/tests/integration/test_api_message_size_batching.py new file mode 100644 index 0000000000..631e64825e --- /dev/null +++ b/tests/integration/test_api_message_size_batching.py @@ -0,0 +1,194 @@ +"""Integration test for API batching with various message sizes.""" + +from __future__ import annotations + +import asyncio + +from aioesphomeapi import EntityState, NumberInfo, SelectInfo, TextInfo, TextSensorInfo +import pytest + +from .types import APIClientConnectedFactory, RunCompiledFunction + + +@pytest.mark.asyncio +async def test_api_message_size_batching( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected: APIClientConnectedFactory, +) -> None: + """Test API can batch messages of various sizes correctly.""" + # Write, compile and run the ESPHome device, then connect to API + loop = asyncio.get_running_loop() + async with run_compiled(yaml_config), api_client_connected() as client: + # Verify we can get device info + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "message-size-batching-test" + + # List entities - this will batch various sized messages together + entity_info, services = await asyncio.wait_for( + client.list_entities_services(), timeout=5.0 + ) + + # Count different entity types + selects = [] + text_sensors = [] + text_inputs = [] + numbers = [] + other_entities = [] + + for entity in entity_info: + if isinstance(entity, SelectInfo): + selects.append(entity) + elif isinstance(entity, TextSensorInfo): + text_sensors.append(entity) + elif isinstance(entity, TextInfo): + text_inputs.append(entity) + elif isinstance(entity, NumberInfo): + numbers.append(entity) + else: + other_entities.append(entity) + + # Verify we have our test entities - exact counts + assert len(selects) == 3, ( + f"Expected exactly 3 select entities, got {len(selects)}" + ) + assert len(text_sensors) == 3, ( + f"Expected exactly 3 text sensor entities, got {len(text_sensors)}" + ) + assert len(text_inputs) == 1, ( + f"Expected exactly 1 text input entity, got {len(text_inputs)}" + ) + + # Collect all select entity object_ids for error messages + select_ids = [s.object_id for s in selects] + + # Find our specific test entities + small_select = None + medium_select = None + large_select = None + + for select in selects: + if select.object_id == "small_select": + small_select = select + elif select.object_id == "medium_select": + medium_select = select + elif ( + select.object_id + == "large_select_with_many_options_to_create_larger_payload" + ): + large_select = select + + assert small_select is not None, ( + f"Could not find small_select entity. Found: {select_ids}" + ) + assert medium_select is not None, ( + f"Could not find medium_select entity. Found: {select_ids}" + ) + assert large_select is not None, ( + f"Could not find large_select entity. Found: {select_ids}" + ) + + # Verify the selects have the expected number of options + assert len(small_select.options) == 2, ( + f"Expected 2 options for small_select, got {len(small_select.options)}" + ) + assert len(medium_select.options) == 20, ( + f"Expected 20 options for medium_select, got {len(medium_select.options)}" + ) + assert len(large_select.options) == 50, ( + f"Expected 50 options for large_select, got {len(large_select.options)}" + ) + + # Collect all text sensor object_ids for error messages + text_sensor_ids = [t.object_id for t in text_sensors] + + # Verify text sensors with different value lengths + short_text_sensor = None + medium_text_sensor = None + long_text_sensor = None + + for text_sensor in text_sensors: + if text_sensor.object_id == "short_text_sensor": + short_text_sensor = text_sensor + elif text_sensor.object_id == "medium_text_sensor": + medium_text_sensor = text_sensor + elif text_sensor.object_id == "long_text_sensor_with_very_long_value": + long_text_sensor = text_sensor + + assert short_text_sensor is not None, ( + f"Could not find short_text_sensor. Found: {text_sensor_ids}" + ) + assert medium_text_sensor is not None, ( + f"Could not find medium_text_sensor. Found: {text_sensor_ids}" + ) + assert long_text_sensor is not None, ( + f"Could not find long_text_sensor. Found: {text_sensor_ids}" + ) + + # Check text input which can have a long max_length + text_input = None + text_input_ids = [t.object_id for t in text_inputs] + + for ti in text_inputs: + if ti.object_id == "test_text_input": + text_input = ti + break + + assert text_input is not None, ( + f"Could not find test_text_input. Found: {text_input_ids}" + ) + assert text_input.max_length == 255, ( + f"Expected max_length 255, got {text_input.max_length}" + ) + + # Verify total entity count - messages of various sizes were batched successfully + # We have: 3 selects + 3 text sensors + 1 text input + 1 number = 8 total + total_entities = len(entity_info) + assert total_entities == 8, f"Expected exactly 8 entities, got {total_entities}" + + # Check we have the expected entity types + assert len(numbers) == 1, ( + f"Expected exactly 1 number entity, got {len(numbers)}" + ) + assert len(other_entities) == 0, ( + f"Unexpected entity types found: {[type(e).__name__ for e in other_entities]}" + ) + + # Subscribe to state changes to verify batching works + # Collect keys from entity info to know what states to expect + expected_keys = {entity.key for entity in entity_info} + assert len(expected_keys) == 8, ( + f"Expected 8 unique entity keys, got {len(expected_keys)}" + ) + + received_keys: set[int] = set() + states_future: asyncio.Future[None] = loop.create_future() + + def on_state(state: EntityState) -> None: + """Track when states are received.""" + received_keys.add(state.key) + # Check if we've received states from all expected entities + if expected_keys.issubset(received_keys) and not states_future.done(): + states_future.set_result(None) + + client.subscribe_states(on_state) + + # Wait for states with timeout + try: + await asyncio.wait_for(states_future, timeout=5.0) + except asyncio.TimeoutError: + missing_keys = expected_keys - received_keys + pytest.fail( + f"Did not receive states from all entities within 5 seconds. " + f"Missing keys: {missing_keys}, " + f"Received {len(received_keys)} of {len(expected_keys)} expected states" + ) + + # Verify we received states from all entities + assert expected_keys.issubset(received_keys) + + # Check that various message sizes were handled correctly + # Small messages (4-byte header): type < 128, payload < 128 + # Medium messages (5-byte header): type < 128, payload 128-16383 OR type 128+, payload < 128 + # Large messages (6-byte header): type 128+, payload 128-16383 diff --git a/tests/integration/test_large_message_batching.py b/tests/integration/test_large_message_batching.py new file mode 100644 index 0000000000..399fd39dd3 --- /dev/null +++ b/tests/integration/test_large_message_batching.py @@ -0,0 +1,59 @@ +"""Integration test for API handling of large messages exceeding batch size.""" + +from __future__ import annotations + +from aioesphomeapi import SelectInfo +import pytest + +from .types import APIClientConnectedFactory, RunCompiledFunction + + +@pytest.mark.asyncio +async def test_large_message_batching( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected: APIClientConnectedFactory, +) -> None: + """Test API can handle large messages (>1390 bytes) in batches.""" + # Write, compile and run the ESPHome device, then connect to API + async with run_compiled(yaml_config), api_client_connected() as client: + # Verify we can get device info + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "large-message-test" + + # List entities - this will include our select with many options + entity_info, services = await client.list_entities_services() + + # Find our large select entity + large_select = None + for entity in entity_info: + if isinstance(entity, SelectInfo) and entity.object_id == "large_select": + large_select = entity + break + + assert large_select is not None, "Could not find large_select entity" + + # Verify the select has all its options + # We created 100 options with long names + assert len(large_select.options) == 100, ( + f"Expected 100 options, got {len(large_select.options)}" + ) + + # Verify all options are present and correct + for i in range(100): + expected_option = f"Option {i:03d} - This is a very long option name to make the message larger than the typical batch size of 1390 bytes" + assert expected_option in large_select.options, ( + f"Missing option: {expected_option}" + ) + + # Also verify we can still receive other entities in the same batch + # Count total entities - should have at least our select plus some sensors + entity_count = len(entity_info) + assert entity_count >= 4, f"Expected at least 4 entities, got {entity_count}" + + # Verify we have different entity types (not just selects) + entity_types = {type(entity).__name__ for entity in entity_info} + assert len(entity_types) >= 2, ( + f"Expected multiple entity types, got {entity_types}" + ) From 1dbebe90bac2783618365948c9159933b5ee1e19 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 16:29:25 -0500 Subject: [PATCH 06/14] Add common base classes for entity protobuf messages to reduce duplicate code (#9090) --- esphome/components/api/api.proto | 44 ++++ esphome/components/api/api_connection.cpp | 44 ++-- esphome/components/api/api_connection.h | 9 +- esphome/components/api/api_options.proto | 1 + esphome/components/api/api_pb2.cpp | 1 + esphome/components/api/api_pb2.h | 291 +++++----------------- script/api_protobuf/api_protobuf.py | 171 ++++++++++++- 7 files changed, 306 insertions(+), 255 deletions(-) diff --git a/esphome/components/api/api.proto b/esphome/components/api/api.proto index c5c63b8dfc..843b72795a 100644 --- a/esphome/components/api/api.proto +++ b/esphome/components/api/api.proto @@ -266,6 +266,7 @@ enum EntityCategory { // ==================== BINARY SENSOR ==================== message ListEntitiesBinarySensorResponse { option (id) = 12; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_BINARY_SENSOR"; @@ -282,6 +283,7 @@ message ListEntitiesBinarySensorResponse { } message BinarySensorStateResponse { option (id) = 21; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_BINARY_SENSOR"; option (no_delay) = true; @@ -296,6 +298,7 @@ message BinarySensorStateResponse { // ==================== COVER ==================== message ListEntitiesCoverResponse { option (id) = 13; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_COVER"; @@ -325,6 +328,7 @@ enum CoverOperation { } message CoverStateResponse { option (id) = 22; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_COVER"; option (no_delay) = true; @@ -367,6 +371,7 @@ message CoverCommandRequest { // ==================== FAN ==================== message ListEntitiesFanResponse { option (id) = 14; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_FAN"; @@ -395,6 +400,7 @@ enum FanDirection { } message FanStateResponse { option (id) = 23; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_FAN"; option (no_delay) = true; @@ -444,6 +450,7 @@ enum ColorMode { } message ListEntitiesLightResponse { option (id) = 15; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_LIGHT"; @@ -467,6 +474,7 @@ message ListEntitiesLightResponse { } message LightStateResponse { option (id) = 24; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_LIGHT"; option (no_delay) = true; @@ -536,6 +544,7 @@ enum SensorLastResetType { message ListEntitiesSensorResponse { option (id) = 16; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SENSOR"; @@ -557,6 +566,7 @@ message ListEntitiesSensorResponse { } message SensorStateResponse { option (id) = 25; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SENSOR"; option (no_delay) = true; @@ -571,6 +581,7 @@ message SensorStateResponse { // ==================== SWITCH ==================== message ListEntitiesSwitchResponse { option (id) = 17; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SWITCH"; @@ -587,6 +598,7 @@ message ListEntitiesSwitchResponse { } message SwitchStateResponse { option (id) = 26; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SWITCH"; option (no_delay) = true; @@ -607,6 +619,7 @@ message SwitchCommandRequest { // ==================== TEXT SENSOR ==================== message ListEntitiesTextSensorResponse { option (id) = 18; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_TEXT_SENSOR"; @@ -622,6 +635,7 @@ message ListEntitiesTextSensorResponse { } message TextSensorStateResponse { option (id) = 27; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_TEXT_SENSOR"; option (no_delay) = true; @@ -789,6 +803,7 @@ message ExecuteServiceRequest { // ==================== CAMERA ==================== message ListEntitiesCameraResponse { option (id) = 43; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_ESP32_CAMERA"; @@ -869,6 +884,7 @@ enum ClimatePreset { } message ListEntitiesClimateResponse { option (id) = 46; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_CLIMATE"; @@ -903,6 +919,7 @@ message ListEntitiesClimateResponse { } message ClimateStateResponse { option (id) = 47; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_CLIMATE"; option (no_delay) = true; @@ -964,6 +981,7 @@ enum NumberMode { } message ListEntitiesNumberResponse { option (id) = 49; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_NUMBER"; @@ -984,6 +1002,7 @@ message ListEntitiesNumberResponse { } message NumberStateResponse { option (id) = 50; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_NUMBER"; option (no_delay) = true; @@ -1007,6 +1026,7 @@ message NumberCommandRequest { // ==================== SELECT ==================== message ListEntitiesSelectResponse { option (id) = 52; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SELECT"; @@ -1022,6 +1042,7 @@ message ListEntitiesSelectResponse { } message SelectStateResponse { option (id) = 53; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SELECT"; option (no_delay) = true; @@ -1045,6 +1066,7 @@ message SelectCommandRequest { // ==================== SIREN ==================== message ListEntitiesSirenResponse { option (id) = 55; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SIREN"; @@ -1062,6 +1084,7 @@ message ListEntitiesSirenResponse { } message SirenStateResponse { option (id) = 56; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_SIREN"; option (no_delay) = true; @@ -1102,6 +1125,7 @@ enum LockCommand { } message ListEntitiesLockResponse { option (id) = 58; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_LOCK"; @@ -1123,6 +1147,7 @@ message ListEntitiesLockResponse { } message LockStateResponse { option (id) = 59; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_LOCK"; option (no_delay) = true; @@ -1145,6 +1170,7 @@ message LockCommandRequest { // ==================== BUTTON ==================== message ListEntitiesButtonResponse { option (id) = 61; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_BUTTON"; @@ -1196,6 +1222,7 @@ message MediaPlayerSupportedFormat { } message ListEntitiesMediaPlayerResponse { option (id) = 63; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_MEDIA_PLAYER"; @@ -1214,6 +1241,7 @@ message ListEntitiesMediaPlayerResponse { } message MediaPlayerStateResponse { option (id) = 64; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_MEDIA_PLAYER"; option (no_delay) = true; @@ -1735,6 +1763,7 @@ enum AlarmControlPanelStateCommand { message ListEntitiesAlarmControlPanelResponse { option (id) = 94; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_ALARM_CONTROL_PANEL"; @@ -1752,6 +1781,7 @@ message ListEntitiesAlarmControlPanelResponse { message AlarmControlPanelStateResponse { option (id) = 95; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_ALARM_CONTROL_PANEL"; option (no_delay) = true; @@ -1776,6 +1806,7 @@ enum TextMode { } message ListEntitiesTextResponse { option (id) = 97; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_TEXT"; @@ -1794,6 +1825,7 @@ message ListEntitiesTextResponse { } message TextStateResponse { option (id) = 98; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_TEXT"; option (no_delay) = true; @@ -1818,6 +1850,7 @@ message TextCommandRequest { // ==================== DATETIME DATE ==================== message ListEntitiesDateResponse { option (id) = 100; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_DATETIME_DATE"; @@ -1832,6 +1865,7 @@ message ListEntitiesDateResponse { } message DateStateResponse { option (id) = 101; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_DATETIME_DATE"; option (no_delay) = true; @@ -1859,6 +1893,7 @@ message DateCommandRequest { // ==================== DATETIME TIME ==================== message ListEntitiesTimeResponse { option (id) = 103; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_DATETIME_TIME"; @@ -1873,6 +1908,7 @@ message ListEntitiesTimeResponse { } message TimeStateResponse { option (id) = 104; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_DATETIME_TIME"; option (no_delay) = true; @@ -1900,6 +1936,7 @@ message TimeCommandRequest { // ==================== EVENT ==================== message ListEntitiesEventResponse { option (id) = 107; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_EVENT"; @@ -1917,6 +1954,7 @@ message ListEntitiesEventResponse { } message EventResponse { option (id) = 108; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_EVENT"; @@ -1927,6 +1965,7 @@ message EventResponse { // ==================== VALVE ==================== message ListEntitiesValveResponse { option (id) = 109; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_VALVE"; @@ -1952,6 +1991,7 @@ enum ValveOperation { } message ValveStateResponse { option (id) = 110; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_VALVE"; option (no_delay) = true; @@ -1976,6 +2016,7 @@ message ValveCommandRequest { // ==================== DATETIME DATETIME ==================== message ListEntitiesDateTimeResponse { option (id) = 112; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_DATETIME_DATETIME"; @@ -1990,6 +2031,7 @@ message ListEntitiesDateTimeResponse { } message DateTimeStateResponse { option (id) = 113; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_DATETIME_DATETIME"; option (no_delay) = true; @@ -2013,6 +2055,7 @@ message DateTimeCommandRequest { // ==================== UPDATE ==================== message ListEntitiesUpdateResponse { option (id) = 116; + option (base_class) = "InfoResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_UPDATE"; @@ -2028,6 +2071,7 @@ message ListEntitiesUpdateResponse { } message UpdateStateResponse { option (id) = 117; + option (base_class) = "StateResponseProtoMessage"; option (source) = SOURCE_SERVER; option (ifdef) = "USE_UPDATE"; option (no_delay) = true; diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 8328f5d2cd..3e2b7c0154 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -301,7 +301,7 @@ uint16_t APIConnection::try_send_binary_sensor_state(EntityBase *entity, APIConn BinarySensorStateResponse resp; resp.state = binary_sensor->state; resp.missing_state = !binary_sensor->has_state(); - resp.key = binary_sensor->get_object_id_hash(); + fill_entity_state_base(binary_sensor, resp); return encode_message_to_buffer(resp, BinarySensorStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -335,7 +335,7 @@ uint16_t APIConnection::try_send_cover_state(EntityBase *entity, APIConnection * if (traits.get_supports_tilt()) msg.tilt = cover->tilt; msg.current_operation = static_cast(cover->current_operation); - msg.key = cover->get_object_id_hash(); + fill_entity_state_base(cover, msg); return encode_message_to_buffer(msg, CoverStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } uint16_t APIConnection::try_send_cover_info(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, @@ -403,7 +403,7 @@ uint16_t APIConnection::try_send_fan_state(EntityBase *entity, APIConnection *co msg.direction = static_cast(fan->direction); if (traits.supports_preset_modes()) msg.preset_mode = fan->preset_mode; - msg.key = fan->get_object_id_hash(); + fill_entity_state_base(fan, msg); return encode_message_to_buffer(msg, FanStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } uint16_t APIConnection::try_send_fan_info(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, @@ -470,7 +470,7 @@ uint16_t APIConnection::try_send_light_state(EntityBase *entity, APIConnection * resp.warm_white = values.get_warm_white(); if (light->supports_effects()) resp.effect = light->get_effect_name(); - resp.key = light->get_object_id_hash(); + fill_entity_state_base(light, resp); return encode_message_to_buffer(resp, LightStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } uint16_t APIConnection::try_send_light_info(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, @@ -552,7 +552,7 @@ uint16_t APIConnection::try_send_sensor_state(EntityBase *entity, APIConnection SensorStateResponse resp; resp.state = sensor->state; resp.missing_state = !sensor->has_state(); - resp.key = sensor->get_object_id_hash(); + fill_entity_state_base(sensor, resp); return encode_message_to_buffer(resp, SensorStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -586,7 +586,7 @@ uint16_t APIConnection::try_send_switch_state(EntityBase *entity, APIConnection auto *a_switch = static_cast(entity); SwitchStateResponse resp; resp.state = a_switch->state; - resp.key = a_switch->get_object_id_hash(); + fill_entity_state_base(a_switch, resp); return encode_message_to_buffer(resp, SwitchStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -629,7 +629,7 @@ uint16_t APIConnection::try_send_text_sensor_state(EntityBase *entity, APIConnec TextSensorStateResponse resp; resp.state = text_sensor->state; resp.missing_state = !text_sensor->has_state(); - resp.key = text_sensor->get_object_id_hash(); + fill_entity_state_base(text_sensor, resp); return encode_message_to_buffer(resp, TextSensorStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } uint16_t APIConnection::try_send_text_sensor_info(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, @@ -653,7 +653,7 @@ uint16_t APIConnection::try_send_climate_state(EntityBase *entity, APIConnection bool is_single) { auto *climate = static_cast(entity); ClimateStateResponse resp; - resp.key = climate->get_object_id_hash(); + fill_entity_state_base(climate, resp); auto traits = climate->get_traits(); resp.mode = static_cast(climate->mode); resp.action = static_cast(climate->action); @@ -762,7 +762,7 @@ uint16_t APIConnection::try_send_number_state(EntityBase *entity, APIConnection NumberStateResponse resp; resp.state = number->state; resp.missing_state = !number->has_state(); - resp.key = number->get_object_id_hash(); + fill_entity_state_base(number, resp); return encode_message_to_buffer(resp, NumberStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -803,7 +803,7 @@ uint16_t APIConnection::try_send_date_state(EntityBase *entity, APIConnection *c resp.year = date->year; resp.month = date->month; resp.day = date->day; - resp.key = date->get_object_id_hash(); + fill_entity_state_base(date, resp); return encode_message_to_buffer(resp, DateStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } void APIConnection::send_date_info(datetime::DateEntity *date) { @@ -840,7 +840,7 @@ uint16_t APIConnection::try_send_time_state(EntityBase *entity, APIConnection *c resp.hour = time->hour; resp.minute = time->minute; resp.second = time->second; - resp.key = time->get_object_id_hash(); + fill_entity_state_base(time, resp); return encode_message_to_buffer(resp, TimeStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } void APIConnection::send_time_info(datetime::TimeEntity *time) { @@ -879,7 +879,7 @@ uint16_t APIConnection::try_send_datetime_state(EntityBase *entity, APIConnectio ESPTime state = datetime->state_as_esptime(); resp.epoch_seconds = state.timestamp; } - resp.key = datetime->get_object_id_hash(); + fill_entity_state_base(datetime, resp); return encode_message_to_buffer(resp, DateTimeStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } void APIConnection::send_datetime_info(datetime::DateTimeEntity *datetime) { @@ -918,7 +918,7 @@ uint16_t APIConnection::try_send_text_state(EntityBase *entity, APIConnection *c TextStateResponse resp; resp.state = text->state; resp.missing_state = !text->has_state(); - resp.key = text->get_object_id_hash(); + fill_entity_state_base(text, resp); return encode_message_to_buffer(resp, TextStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -959,7 +959,7 @@ uint16_t APIConnection::try_send_select_state(EntityBase *entity, APIConnection SelectStateResponse resp; resp.state = select->state; resp.missing_state = !select->has_state(); - resp.key = select->get_object_id_hash(); + fill_entity_state_base(select, resp); return encode_message_to_buffer(resp, SelectStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -1019,7 +1019,7 @@ uint16_t APIConnection::try_send_lock_state(EntityBase *entity, APIConnection *c auto *a_lock = static_cast(entity); LockStateResponse resp; resp.state = static_cast(a_lock->state); - resp.key = a_lock->get_object_id_hash(); + fill_entity_state_base(a_lock, resp); return encode_message_to_buffer(resp, LockStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -1063,7 +1063,7 @@ uint16_t APIConnection::try_send_valve_state(EntityBase *entity, APIConnection * ValveStateResponse resp; resp.position = valve->position; resp.current_operation = static_cast(valve->current_operation); - resp.key = valve->get_object_id_hash(); + fill_entity_state_base(valve, resp); return encode_message_to_buffer(resp, ValveStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } void APIConnection::send_valve_info(valve::Valve *valve) { @@ -1111,7 +1111,7 @@ uint16_t APIConnection::try_send_media_player_state(EntityBase *entity, APIConne resp.state = static_cast(report_state); resp.volume = media_player->volume; resp.muted = media_player->is_muted(); - resp.key = media_player->get_object_id_hash(); + fill_entity_state_base(media_player, resp); return encode_message_to_buffer(resp, MediaPlayerStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } void APIConnection::send_media_player_info(media_player::MediaPlayer *media_player) { @@ -1375,7 +1375,7 @@ uint16_t APIConnection::try_send_alarm_control_panel_state(EntityBase *entity, A auto *a_alarm_control_panel = static_cast(entity); AlarmControlPanelStateResponse resp; resp.state = static_cast(a_alarm_control_panel->get_state()); - resp.key = a_alarm_control_panel->get_object_id_hash(); + fill_entity_state_base(a_alarm_control_panel, resp); return encode_message_to_buffer(resp, AlarmControlPanelStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } void APIConnection::send_alarm_control_panel_info(alarm_control_panel::AlarmControlPanel *a_alarm_control_panel) { @@ -1439,7 +1439,7 @@ uint16_t APIConnection::try_send_event_response(event::Event *event, const std:: uint32_t remaining_size, bool is_single) { EventResponse resp; resp.event_type = event_type; - resp.key = event->get_object_id_hash(); + fill_entity_state_base(event, resp); return encode_message_to_buffer(resp, EventResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } @@ -1477,7 +1477,7 @@ uint16_t APIConnection::try_send_update_state(EntityBase *entity, APIConnection resp.release_summary = update->update_info.summary; resp.release_url = update->update_info.release_url; } - resp.key = update->get_object_id_hash(); + fill_entity_state_base(update, resp); return encode_message_to_buffer(resp, UpdateStateResponse::MESSAGE_TYPE, conn, remaining_size, is_single); } void APIConnection::send_update_info(update::UpdateEntity *update) { @@ -1538,7 +1538,7 @@ bool APIConnection::try_send_log_message(int level, const char *tag, const char buffer.encode_string(3, line, line_length); // string message = 3 // SubscribeLogsResponse - 29 - return this->send_buffer(buffer, 29); + return this->send_buffer(buffer, SubscribeLogsResponse::MESSAGE_TYPE); } HelloResponse APIConnection::hello(const HelloRequest &msg) { @@ -1685,7 +1685,7 @@ bool APIConnection::try_to_clear_buffer(bool log_out_of_space) { return false; } bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint16_t message_type) { - if (!this->try_to_clear_buffer(message_type != 29)) { // SubscribeLogsResponse + if (!this->try_to_clear_buffer(message_type != SubscribeLogsResponse::MESSAGE_TYPE)) { // SubscribeLogsResponse return false; } diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 13e6066788..7cd41561d4 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -282,8 +282,8 @@ class APIConnection : public APIServerConnection { ProtoWriteBuffer allocate_batch_message_buffer(uint16_t size); protected: - // Helper function to fill common entity fields - template static void fill_entity_info_base(esphome::EntityBase *entity, ResponseT &response) { + // Helper function to fill common entity info fields + static void fill_entity_info_base(esphome::EntityBase *entity, InfoResponseProtoMessage &response) { // Set common fields that are shared by all entity types response.key = entity->get_object_id_hash(); response.object_id = entity->get_object_id(); @@ -297,6 +297,11 @@ class APIConnection : public APIServerConnection { response.entity_category = static_cast(entity->get_entity_category()); } + // Helper function to fill common entity state fields + static void fill_entity_state_base(esphome::EntityBase *entity, StateResponseProtoMessage &response) { + response.key = entity->get_object_id_hash(); + } + // Non-template helper to encode any ProtoMessage static uint16_t encode_message_to_buffer(ProtoMessage &msg, uint16_t message_type, APIConnection *conn, uint32_t remaining_size, bool is_single); diff --git a/esphome/components/api/api_options.proto b/esphome/components/api/api_options.proto index feaf39ba15..3a547b8688 100644 --- a/esphome/components/api/api_options.proto +++ b/esphome/components/api/api_options.proto @@ -21,4 +21,5 @@ extend google.protobuf.MessageOptions { optional string ifdef = 1038; optional bool log = 1039 [default=true]; optional bool no_delay = 1040 [default=false]; + optional string base_class = 1041; } diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp index 2d609f6dd6..415409f880 100644 --- a/esphome/components/api/api_pb2.cpp +++ b/esphome/components/api/api_pb2.cpp @@ -628,6 +628,7 @@ template<> const char *proto_enum_to_string(enums::UpdateC } } #endif + bool HelloRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { switch (field_id) { case 2: { diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h index 8b3f7a7b2a..14a1f3f353 100644 --- a/esphome/components/api/api_pb2.h +++ b/esphome/components/api/api_pb2.h @@ -253,6 +253,27 @@ enum UpdateCommand : uint32_t { } // namespace enums +class InfoResponseProtoMessage : public ProtoMessage { + public: + ~InfoResponseProtoMessage() override = default; + std::string object_id{}; + uint32_t key{0}; + std::string name{}; + std::string unique_id{}; + bool disabled_by_default{false}; + std::string icon{}; + enums::EntityCategory entity_category{}; + + protected: +}; + +class StateResponseProtoMessage : public ProtoMessage { + public: + ~StateResponseProtoMessage() override = default; + uint32_t key{0}; + + protected: +}; class HelloRequest : public ProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 1; @@ -484,22 +505,15 @@ class SubscribeStatesRequest : public ProtoMessage { protected: }; -class ListEntitiesBinarySensorResponse : public ProtoMessage { +class ListEntitiesBinarySensorResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 12; static constexpr uint16_t ESTIMATED_SIZE = 56; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_binary_sensor_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; std::string device_class{}; bool is_status_binary_sensor{false}; - bool disabled_by_default{false}; - std::string icon{}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -511,14 +525,13 @@ class ListEntitiesBinarySensorResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class BinarySensorStateResponse : public ProtoMessage { +class BinarySensorStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 21; static constexpr uint16_t ESTIMATED_SIZE = 9; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "binary_sensor_state_response"; } #endif - uint32_t key{0}; bool state{false}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; @@ -531,24 +544,17 @@ class BinarySensorStateResponse : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesCoverResponse : public ProtoMessage { +class ListEntitiesCoverResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 13; static constexpr uint16_t ESTIMATED_SIZE = 62; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_cover_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; bool assumed_state{false}; bool supports_position{false}; bool supports_tilt{false}; std::string device_class{}; - bool disabled_by_default{false}; - std::string icon{}; - enums::EntityCategory entity_category{}; bool supports_stop{false}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -561,14 +567,13 @@ class ListEntitiesCoverResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class CoverStateResponse : public ProtoMessage { +class CoverStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 22; static constexpr uint16_t ESTIMATED_SIZE = 19; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "cover_state_response"; } #endif - uint32_t key{0}; enums::LegacyCoverState legacy_state{}; float position{0.0f}; float tilt{0.0f}; @@ -608,24 +613,17 @@ class CoverCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesFanResponse : public ProtoMessage { +class ListEntitiesFanResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 14; static constexpr uint16_t ESTIMATED_SIZE = 73; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_fan_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; bool supports_oscillation{false}; bool supports_speed{false}; bool supports_direction{false}; int32_t supported_speed_count{0}; - bool disabled_by_default{false}; - std::string icon{}; - enums::EntityCategory entity_category{}; std::vector supported_preset_modes{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -638,14 +636,13 @@ class ListEntitiesFanResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class FanStateResponse : public ProtoMessage { +class FanStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 23; static constexpr uint16_t ESTIMATED_SIZE = 26; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "fan_state_response"; } #endif - uint32_t key{0}; bool state{false}; bool oscillating{false}; enums::FanSpeed speed{}; @@ -694,17 +691,13 @@ class FanCommandRequest : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesLightResponse : public ProtoMessage { +class ListEntitiesLightResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 15; static constexpr uint16_t ESTIMATED_SIZE = 85; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_light_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; std::vector supported_color_modes{}; bool legacy_supports_brightness{false}; bool legacy_supports_rgb{false}; @@ -713,9 +706,6 @@ class ListEntitiesLightResponse : public ProtoMessage { float min_mireds{0.0f}; float max_mireds{0.0f}; std::vector effects{}; - bool disabled_by_default{false}; - std::string icon{}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -727,14 +717,13 @@ class ListEntitiesLightResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class LightStateResponse : public ProtoMessage { +class LightStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 24; static constexpr uint16_t ESTIMATED_SIZE = 63; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "light_state_response"; } #endif - uint32_t key{0}; bool state{false}; float brightness{0.0f}; enums::ColorMode color_mode{}; @@ -803,26 +792,19 @@ class LightCommandRequest : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesSensorResponse : public ProtoMessage { +class ListEntitiesSensorResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 16; static constexpr uint16_t ESTIMATED_SIZE = 73; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_sensor_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; std::string unit_of_measurement{}; int32_t accuracy_decimals{0}; bool force_update{false}; std::string device_class{}; enums::SensorStateClass state_class{}; enums::SensorLastResetType legacy_last_reset_type{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -834,14 +816,13 @@ class ListEntitiesSensorResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class SensorStateResponse : public ProtoMessage { +class SensorStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 25; static constexpr uint16_t ESTIMATED_SIZE = 12; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "sensor_state_response"; } #endif - uint32_t key{0}; float state{0.0f}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; @@ -854,21 +835,14 @@ class SensorStateResponse : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesSwitchResponse : public ProtoMessage { +class ListEntitiesSwitchResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 17; static constexpr uint16_t ESTIMATED_SIZE = 56; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_switch_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; bool assumed_state{false}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -881,14 +855,13 @@ class ListEntitiesSwitchResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class SwitchStateResponse : public ProtoMessage { +class SwitchStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 26; static constexpr uint16_t ESTIMATED_SIZE = 7; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "switch_state_response"; } #endif - uint32_t key{0}; bool state{false}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -919,20 +892,13 @@ class SwitchCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesTextSensorResponse : public ProtoMessage { +class ListEntitiesTextSensorResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 18; static constexpr uint16_t ESTIMATED_SIZE = 54; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_text_sensor_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -945,14 +911,13 @@ class ListEntitiesTextSensorResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class TextSensorStateResponse : public ProtoMessage { +class TextSensorStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 27; static constexpr uint16_t ESTIMATED_SIZE = 16; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "text_sensor_state_response"; } #endif - uint32_t key{0}; std::string state{}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; @@ -1249,20 +1214,13 @@ class ExecuteServiceRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; -class ListEntitiesCameraResponse : public ProtoMessage { +class ListEntitiesCameraResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 43; static constexpr uint16_t ESTIMATED_SIZE = 45; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_camera_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - bool disabled_by_default{false}; - std::string icon{}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -1313,17 +1271,13 @@ class CameraImageRequest : public ProtoMessage { protected: bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesClimateResponse : public ProtoMessage { +class ListEntitiesClimateResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 46; static constexpr uint16_t ESTIMATED_SIZE = 151; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_climate_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; bool supports_current_temperature{false}; bool supports_two_point_target_temperature{false}; std::vector supported_modes{}; @@ -1337,9 +1291,6 @@ class ListEntitiesClimateResponse : public ProtoMessage { std::vector supported_custom_fan_modes{}; std::vector supported_presets{}; std::vector supported_custom_presets{}; - bool disabled_by_default{false}; - std::string icon{}; - enums::EntityCategory entity_category{}; float visual_current_temperature_step{0.0f}; bool supports_current_humidity{false}; bool supports_target_humidity{false}; @@ -1356,14 +1307,13 @@ class ListEntitiesClimateResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ClimateStateResponse : public ProtoMessage { +class ClimateStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 47; static constexpr uint16_t ESTIMATED_SIZE = 65; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "climate_state_response"; } #endif - uint32_t key{0}; enums::ClimateMode mode{}; float current_temperature{0.0f}; float target_temperature{0.0f}; @@ -1430,23 +1380,16 @@ class ClimateCommandRequest : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesNumberResponse : public ProtoMessage { +class ListEntitiesNumberResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 49; static constexpr uint16_t ESTIMATED_SIZE = 80; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_number_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; float min_value{0.0f}; float max_value{0.0f}; float step{0.0f}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; std::string unit_of_measurement{}; enums::NumberMode mode{}; std::string device_class{}; @@ -1461,14 +1404,13 @@ class ListEntitiesNumberResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class NumberStateResponse : public ProtoMessage { +class NumberStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 50; static constexpr uint16_t ESTIMATED_SIZE = 12; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "number_state_response"; } #endif - uint32_t key{0}; float state{0.0f}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; @@ -1499,21 +1441,14 @@ class NumberCommandRequest : public ProtoMessage { protected: bool decode_32bit(uint32_t field_id, Proto32Bit value) override; }; -class ListEntitiesSelectResponse : public ProtoMessage { +class ListEntitiesSelectResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 52; static constexpr uint16_t ESTIMATED_SIZE = 63; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_select_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; std::vector options{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -1525,14 +1460,13 @@ class ListEntitiesSelectResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class SelectStateResponse : public ProtoMessage { +class SelectStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 53; static constexpr uint16_t ESTIMATED_SIZE = 16; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "select_state_response"; } #endif - uint32_t key{0}; std::string state{}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; @@ -1565,23 +1499,16 @@ class SelectCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; -class ListEntitiesSirenResponse : public ProtoMessage { +class ListEntitiesSirenResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 55; static constexpr uint16_t ESTIMATED_SIZE = 67; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_siren_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; std::vector tones{}; bool supports_duration{false}; bool supports_volume{false}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -1593,14 +1520,13 @@ class ListEntitiesSirenResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class SirenStateResponse : public ProtoMessage { +class SirenStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 56; static constexpr uint16_t ESTIMATED_SIZE = 7; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "siren_state_response"; } #endif - uint32_t key{0}; bool state{false}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -1639,20 +1565,13 @@ class SirenCommandRequest : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesLockResponse : public ProtoMessage { +class ListEntitiesLockResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 58; static constexpr uint16_t ESTIMATED_SIZE = 60; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_lock_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; bool assumed_state{false}; bool supports_open{false}; bool requires_code{false}; @@ -1668,14 +1587,13 @@ class ListEntitiesLockResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class LockStateResponse : public ProtoMessage { +class LockStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 59; static constexpr uint16_t ESTIMATED_SIZE = 7; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "lock_state_response"; } #endif - uint32_t key{0}; enums::LockState state{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -1709,20 +1627,13 @@ class LockCommandRequest : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesButtonResponse : public ProtoMessage { +class ListEntitiesButtonResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 61; static constexpr uint16_t ESTIMATED_SIZE = 54; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_button_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -1769,20 +1680,13 @@ class MediaPlayerSupportedFormat : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesMediaPlayerResponse : public ProtoMessage { +class ListEntitiesMediaPlayerResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 63; static constexpr uint16_t ESTIMATED_SIZE = 81; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_media_player_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; bool supports_pause{false}; std::vector supported_formats{}; void encode(ProtoWriteBuffer buffer) const override; @@ -1796,14 +1700,13 @@ class ListEntitiesMediaPlayerResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class MediaPlayerStateResponse : public ProtoMessage { +class MediaPlayerStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 64; static constexpr uint16_t ESTIMATED_SIZE = 14; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "media_player_state_response"; } #endif - uint32_t key{0}; enums::MediaPlayerState state{}; float volume{0.0f}; bool muted{false}; @@ -2653,20 +2556,13 @@ class VoiceAssistantSetConfiguration : public ProtoMessage { protected: bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; -class ListEntitiesAlarmControlPanelResponse : public ProtoMessage { +class ListEntitiesAlarmControlPanelResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 94; static constexpr uint16_t ESTIMATED_SIZE = 53; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_alarm_control_panel_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; uint32_t supported_features{0}; bool requires_code{false}; bool requires_code_to_arm{false}; @@ -2681,14 +2577,13 @@ class ListEntitiesAlarmControlPanelResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class AlarmControlPanelStateResponse : public ProtoMessage { +class AlarmControlPanelStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 95; static constexpr uint16_t ESTIMATED_SIZE = 7; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "alarm_control_panel_state_response"; } #endif - uint32_t key{0}; enums::AlarmControlPanelState state{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -2721,20 +2616,13 @@ class AlarmControlPanelCommandRequest : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesTextResponse : public ProtoMessage { +class ListEntitiesTextResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 97; static constexpr uint16_t ESTIMATED_SIZE = 64; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_text_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; uint32_t min_length{0}; uint32_t max_length{0}; std::string pattern{}; @@ -2750,14 +2638,13 @@ class ListEntitiesTextResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class TextStateResponse : public ProtoMessage { +class TextStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 98; static constexpr uint16_t ESTIMATED_SIZE = 16; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "text_state_response"; } #endif - uint32_t key{0}; std::string state{}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; @@ -2790,20 +2677,13 @@ class TextCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; -class ListEntitiesDateResponse : public ProtoMessage { +class ListEntitiesDateResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 100; static constexpr uint16_t ESTIMATED_SIZE = 45; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_date_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -2815,14 +2695,13 @@ class ListEntitiesDateResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class DateStateResponse : public ProtoMessage { +class DateStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 101; static constexpr uint16_t ESTIMATED_SIZE = 19; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "date_state_response"; } #endif - uint32_t key{0}; bool missing_state{false}; uint32_t year{0}; uint32_t month{0}; @@ -2858,20 +2737,13 @@ class DateCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesTimeResponse : public ProtoMessage { +class ListEntitiesTimeResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 103; static constexpr uint16_t ESTIMATED_SIZE = 45; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_time_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -2883,14 +2755,13 @@ class ListEntitiesTimeResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class TimeStateResponse : public ProtoMessage { +class TimeStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 104; static constexpr uint16_t ESTIMATED_SIZE = 19; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "time_state_response"; } #endif - uint32_t key{0}; bool missing_state{false}; uint32_t hour{0}; uint32_t minute{0}; @@ -2926,20 +2797,13 @@ class TimeCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesEventResponse : public ProtoMessage { +class ListEntitiesEventResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 107; static constexpr uint16_t ESTIMATED_SIZE = 72; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_event_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; std::string device_class{}; std::vector event_types{}; void encode(ProtoWriteBuffer buffer) const override; @@ -2953,14 +2817,13 @@ class ListEntitiesEventResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class EventResponse : public ProtoMessage { +class EventResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 108; static constexpr uint16_t ESTIMATED_SIZE = 14; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "event_response"; } #endif - uint32_t key{0}; std::string event_type{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -2972,20 +2835,13 @@ class EventResponse : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; -class ListEntitiesValveResponse : public ProtoMessage { +class ListEntitiesValveResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 109; static constexpr uint16_t ESTIMATED_SIZE = 60; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_valve_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; std::string device_class{}; bool assumed_state{false}; bool supports_position{false}; @@ -3001,14 +2857,13 @@ class ListEntitiesValveResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ValveStateResponse : public ProtoMessage { +class ValveStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 110; static constexpr uint16_t ESTIMATED_SIZE = 12; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "valve_state_response"; } #endif - uint32_t key{0}; float position{0.0f}; enums::ValveOperation current_operation{}; void encode(ProtoWriteBuffer buffer) const override; @@ -3042,20 +2897,13 @@ class ValveCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class ListEntitiesDateTimeResponse : public ProtoMessage { +class ListEntitiesDateTimeResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 112; static constexpr uint16_t ESTIMATED_SIZE = 45; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_date_time_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -3067,14 +2915,13 @@ class ListEntitiesDateTimeResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class DateTimeStateResponse : public ProtoMessage { +class DateTimeStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 113; static constexpr uint16_t ESTIMATED_SIZE = 12; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "date_time_state_response"; } #endif - uint32_t key{0}; bool missing_state{false}; uint32_t epoch_seconds{0}; void encode(ProtoWriteBuffer buffer) const override; @@ -3105,20 +2952,13 @@ class DateTimeCommandRequest : public ProtoMessage { protected: bool decode_32bit(uint32_t field_id, Proto32Bit value) override; }; -class ListEntitiesUpdateResponse : public ProtoMessage { +class ListEntitiesUpdateResponse : public InfoResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 116; static constexpr uint16_t ESTIMATED_SIZE = 54; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "list_entities_update_response"; } #endif - std::string object_id{}; - uint32_t key{0}; - std::string name{}; - std::string unique_id{}; - std::string icon{}; - bool disabled_by_default{false}; - enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; void calculate_size(uint32_t &total_size) const override; @@ -3131,14 +2971,13 @@ class ListEntitiesUpdateResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; -class UpdateStateResponse : public ProtoMessage { +class UpdateStateResponse : public StateResponseProtoMessage { public: static constexpr uint16_t MESSAGE_TYPE = 117; static constexpr uint16_t ESTIMATED_SIZE = 61; #ifdef HAS_PROTO_MESSAGE_DUMP static constexpr const char *message_name() { return "update_state_response"; } #endif - uint32_t key{0}; bool missing_state{false}; bool in_progress{false}; bool has_progress{false}; diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index d634be98c4..24b6bef843 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -848,7 +848,10 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int: return total_size -def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: +def build_message_type( + desc: descriptor.DescriptorProto, + base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]] = None, +) -> tuple[str, str]: public_content: list[str] = [] protected_content: list[str] = [] decode_varint: list[str] = [] @@ -859,6 +862,12 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: dump: list[str] = [] size_calc: list[str] = [] + # Check if this message has a base class + base_class = get_base_class(desc) + common_field_names = set() + if base_class and base_class_fields and base_class in base_class_fields: + common_field_names = {f.name for f in base_class_fields[base_class]} + # Get message ID if it's a service message message_id: int | None = get_opt(desc, pb.id) @@ -886,8 +895,14 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: ti = RepeatedTypeInfo(field) else: ti = TYPE_INFO[field.type](field) - protected_content.extend(ti.protected_content) - public_content.extend(ti.public_content) + + # Skip field declarations for fields that are in the base class + # but include their encode/decode logic + if field.name not in common_field_names: + protected_content.extend(ti.protected_content) + public_content.extend(ti.public_content) + + # Always include encode/decode logic for all fields encode.append(ti.encode_content) size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}")) @@ -1001,7 +1016,10 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: prot += "#endif\n" public_content.append(prot) - out = f"class {desc.name} : public ProtoMessage {{\n" + if base_class: + out = f"class {desc.name} : public {base_class} {{\n" + else: + out = f"class {desc.name} : public ProtoMessage {{\n" out += " public:\n" out += indent("\n".join(public_content)) + "\n" out += "\n" @@ -1033,6 +1051,132 @@ def get_opt( return desc.options.Extensions[opt] +def get_base_class(desc: descriptor.DescriptorProto) -> str | None: + """Get the base_class option from a message descriptor.""" + if not desc.options.HasExtension(pb.base_class): + return None + return desc.options.Extensions[pb.base_class] + + +def collect_messages_by_base_class( + messages: list[descriptor.DescriptorProto], +) -> dict[str, list[descriptor.DescriptorProto]]: + """Group messages by their base_class option.""" + base_class_groups = {} + + for msg in messages: + base_class = get_base_class(msg) + if base_class: + if base_class not in base_class_groups: + base_class_groups[base_class] = [] + base_class_groups[base_class].append(msg) + + return base_class_groups + + +def find_common_fields( + messages: list[descriptor.DescriptorProto], +) -> list[descriptor.FieldDescriptorProto]: + """Find fields that are common to all messages in the list.""" + if not messages: + return [] + + # Start with fields from the first message + first_msg_fields = {field.name: field for field in messages[0].field} + common_fields = [] + + # Check each field to see if it exists in all messages with same type + # Field numbers can vary between messages - derived classes handle the mapping + for field_name, field in first_msg_fields.items(): + is_common = True + + for msg in messages[1:]: + found = False + for other_field in msg.field: + if ( + other_field.name == field_name + and other_field.type == field.type + and other_field.label == field.label + ): + found = True + break + + if not found: + is_common = False + break + + if is_common: + common_fields.append(field) + + # Sort by field number to maintain order + common_fields.sort(key=lambda f: f.number) + return common_fields + + +def build_base_class( + base_class_name: str, + common_fields: list[descriptor.FieldDescriptorProto], +) -> tuple[str, str]: + """Build the base class definition and implementation.""" + public_content = [] + protected_content = [] + + # For base classes, we only declare the fields but don't handle encode/decode + # The derived classes will handle encoding/decoding with their specific field numbers + for field in common_fields: + if field.label == 3: # repeated + ti = RepeatedTypeInfo(field) + else: + ti = TYPE_INFO[field.type](field) + + # Only add field declarations, not encode/decode logic + protected_content.extend(ti.protected_content) + public_content.extend(ti.public_content) + + # Build header + out = f"class {base_class_name} : public ProtoMessage {{\n" + out += " public:\n" + + # Add destructor with override + public_content.insert(0, f"~{base_class_name}() override = default;") + + # Base classes don't implement encode/decode/calculate_size + # Derived classes handle these with their specific field numbers + cpp = "" + + out += indent("\n".join(public_content)) + "\n" + out += "\n" + out += " protected:\n" + out += indent("\n".join(protected_content)) + if protected_content: + out += "\n" + out += "};\n" + + # No implementation needed for base classes + + return out, cpp + + +def generate_base_classes( + base_class_groups: dict[str, list[descriptor.DescriptorProto]], +) -> tuple[str, str]: + """Generate all base classes.""" + all_headers = [] + all_cpp = [] + + for base_class_name, messages in base_class_groups.items(): + # Find common fields + common_fields = find_common_fields(messages) + + if common_fields: + # Generate base class + header, cpp = build_base_class(base_class_name, common_fields) + all_headers.append(header) + all_cpp.append(cpp) + + return "\n".join(all_headers), "\n".join(all_cpp) + + def build_service_message_type( mt: descriptor.DescriptorProto, ) -> tuple[str, str] | None: @@ -1134,8 +1278,25 @@ def main() -> None: mt = file.message_type + # Collect messages by base class + base_class_groups = collect_messages_by_base_class(mt) + + # Find common fields for each base class + base_class_fields = {} + for base_class_name, messages in base_class_groups.items(): + common_fields = find_common_fields(messages) + if common_fields: + base_class_fields[base_class_name] = common_fields + + # Generate base classes + if base_class_fields: + base_headers, base_cpp = generate_base_classes(base_class_groups) + content += base_headers + cpp += base_cpp + + # Generate message types with base class information for m in mt: - s, c = build_message_type(m) + s, c = build_message_type(m, base_class_fields) content += s cpp += c From 28d11553e045854e05e2a0b61e35105e4f8ffddd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 16:33:38 -0500 Subject: [PATCH 07/14] Reduce Component blocking threshold memory usage by 2 bytes per component (#9081) --- esphome/core/component.cpp | 13 ++++++++++--- esphome/core/component.h | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/esphome/core/component.cpp b/esphome/core/component.cpp index dae99a0d22..03c44599e2 100644 --- a/esphome/core/component.cpp +++ b/esphome/core/component.cpp @@ -1,6 +1,7 @@ #include "esphome/core/component.h" #include +#include #include #include "esphome/core/application.h" #include "esphome/core/hal.h" @@ -41,8 +42,8 @@ const uint8_t STATUS_LED_OK = 0x00; const uint8_t STATUS_LED_WARNING = 0x04; // Bit 2 const uint8_t STATUS_LED_ERROR = 0x08; // Bit 3 -const uint32_t WARN_IF_BLOCKING_OVER_MS = 50U; ///< Initial blocking time allowed without warning -const uint32_t WARN_IF_BLOCKING_INCREMENT_MS = 10U; ///< How long the blocking time must be larger to warn again +const uint16_t WARN_IF_BLOCKING_OVER_MS = 50U; ///< Initial blocking time allowed without warning +const uint16_t WARN_IF_BLOCKING_INCREMENT_MS = 10U; ///< How long the blocking time must be larger to warn again uint32_t global_state = 0; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) @@ -122,7 +123,13 @@ const char *Component::get_component_source() const { } bool Component::should_warn_of_blocking(uint32_t blocking_time) { if (blocking_time > this->warn_if_blocking_over_) { - this->warn_if_blocking_over_ = blocking_time + WARN_IF_BLOCKING_INCREMENT_MS; + // Prevent overflow when adding increment - if we're about to overflow, just max out + if (blocking_time + WARN_IF_BLOCKING_INCREMENT_MS < blocking_time || + blocking_time + WARN_IF_BLOCKING_INCREMENT_MS > std::numeric_limits::max()) { + this->warn_if_blocking_over_ = std::numeric_limits::max(); + } else { + this->warn_if_blocking_over_ = static_cast(blocking_time + WARN_IF_BLOCKING_INCREMENT_MS); + } return true; } return false; diff --git a/esphome/core/component.h b/esphome/core/component.h index 7ad4a5e496..d05a965034 100644 --- a/esphome/core/component.h +++ b/esphome/core/component.h @@ -65,7 +65,7 @@ extern const uint8_t STATUS_LED_ERROR; enum class RetryResult { DONE, RETRY }; -extern const uint32_t WARN_IF_BLOCKING_OVER_MS; +extern const uint16_t WARN_IF_BLOCKING_OVER_MS; class Component { public: @@ -318,7 +318,7 @@ class Component { uint8_t component_state_{0x00}; float setup_priority_override_{NAN}; const char *component_source_{nullptr}; - uint32_t warn_if_blocking_over_{WARN_IF_BLOCKING_OVER_MS}; + uint16_t warn_if_blocking_over_{WARN_IF_BLOCKING_OVER_MS}; ///< Warn if blocked for this many ms (max 65.5s) std::string error_message_{}; }; From c17a3b6fccb5ea0b782ef06cc6c9508702ec471e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 16:34:37 -0500 Subject: [PATCH 08/14] Reduce Component memory usage by 20 bytes per component (#9080) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/core/component.cpp | 3 ++- esphome/core/component.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/esphome/core/component.cpp b/esphome/core/component.cpp index 03c44599e2..0a4606074a 100644 --- a/esphome/core/component.cpp +++ b/esphome/core/component.cpp @@ -85,7 +85,8 @@ void Component::call_setup() { this->setup(); } void Component::call_dump_config() { this->dump_config(); if (this->is_failed()) { - ESP_LOGE(TAG, " Component %s is marked FAILED: %s", this->get_component_source(), this->error_message_.c_str()); + ESP_LOGE(TAG, " Component %s is marked FAILED: %s", this->get_component_source(), + this->error_message_ ? this->error_message_ : "unspecified"); } } diff --git a/esphome/core/component.h b/esphome/core/component.h index d05a965034..f77d40ae35 100644 --- a/esphome/core/component.h +++ b/esphome/core/component.h @@ -319,7 +319,7 @@ class Component { float setup_priority_override_{NAN}; const char *component_source_{nullptr}; uint16_t warn_if_blocking_over_{WARN_IF_BLOCKING_OVER_MS}; ///< Warn if blocked for this many ms (max 65.5s) - std::string error_message_{}; + const char *error_message_{nullptr}; }; /** This class simplifies creating components that periodically check a state. From 8a06c4380db7cb13faa8230507b66863802747e5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 18:32:36 -0500 Subject: [PATCH 09/14] partition --- esphome/core/application.cpp | 59 +++++++++++++++++++++++++++++++++--- esphome/core/application.h | 25 +++++++++++++++ esphome/core/component.cpp | 8 ++--- esphome/core/component.h | 6 ---- 4 files changed, 83 insertions(+), 15 deletions(-) diff --git a/esphome/core/application.cpp b/esphome/core/application.cpp index 9dda32f0e6..f9d2cf72c6 100644 --- a/esphome/core/application.cpp +++ b/esphome/core/application.cpp @@ -97,11 +97,12 @@ void Application::loop() { // Feed WDT with time this->feed_wdt(last_op_end_time); - for (Component *component : this->looping_components_) { - // Skip components that are done or failed - if (component->should_skip_loop()) { - continue; - } + // Mark that we're in the loop for safe reentrant modifications + this->in_loop_ = true; + + for (this->current_loop_index_ = 0; this->current_loop_index_ < this->looping_components_active_end_; + this->current_loop_index_++) { + Component *component = this->looping_components_[this->current_loop_index_]; // Update the cached time before each component runs this->loop_component_start_time_ = last_op_end_time; @@ -117,6 +118,8 @@ void Application::loop() { this->app_state_ |= new_app_state; this->feed_wdt(last_op_end_time); } + + this->in_loop_ = false; this->app_state_ = new_app_state; // Use the last component's end time instead of calling millis() again @@ -244,6 +247,52 @@ void Application::calculate_looping_components_() { if (obj->has_overridden_loop()) this->looping_components_.push_back(obj); } + // Initially all components are active + this->looping_components_active_end_ = this->looping_components_.size(); +} + +void Application::disable_component_loop(Component *component) { + // Linear search to find component in active section + // Most configs have 10-30 looping components (30 is on the high end) + // O(n) is acceptable here as we optimize for memory, not complexity + for (uint16_t i = 0; i < this->looping_components_active_end_; i++) { + if (this->looping_components_[i] == component) { + // Move last active component to this position + this->looping_components_active_end_--; + if (i != this->looping_components_active_end_) { + this->looping_components_[i] = this->looping_components_[this->looping_components_active_end_]; + this->looping_components_[this->looping_components_active_end_] = component; + + // If we're currently iterating and just swapped the current position + if (this->in_loop_ && i == this->current_loop_index_) { + // Decrement so we'll process the swapped component next + this->current_loop_index_--; + } + } + return; + } + } +} + +void Application::enable_component_loop(Component *component) { + // Single pass through all components to find and move if needed + // With typical 10-30 components, O(n) is faster than maintaining a map + const uint16_t size = this->looping_components_.size(); + for (uint16_t i = 0; i < size; i++) { + if (this->looping_components_[i] == component) { + if (i < this->looping_components_active_end_) { + return; // Already active + } + // Found in inactive section - move to active + if (i != this->looping_components_active_end_) { + Component *temp = this->looping_components_[this->looping_components_active_end_]; + this->looping_components_[this->looping_components_active_end_] = component; + this->looping_components_[i] = temp; + } + this->looping_components_active_end_++; + return; + } + } } #ifdef USE_SOCKET_SELECT_SUPPORT diff --git a/esphome/core/application.h b/esphome/core/application.h index d9ef4fe036..8b2f78beaa 100644 --- a/esphome/core/application.h +++ b/esphome/core/application.h @@ -572,13 +572,38 @@ class Application { void calculate_looping_components_(); + void disable_component_loop(Component *component); + void enable_component_loop(Component *component); + void feed_wdt_arch_(); /// Perform a delay while also monitoring socket file descriptors for readiness void yield_with_select_(uint32_t delay_ms); std::vector components_{}; + + // Partitioned vector design for looping components + // ================================================= + // Components are partitioned into [active | inactive] sections: + // + // looping_components_: [A, B, C, D | E, F] + // ^ + // looping_components_active_end_ (4) + // + // - Components A,B,C,D are active and will be called in loop() + // - Components E,F are inactive (disabled/failed) and won't be called + // - No flag checking needed during iteration - just loop 0 to active_end_ + // - When a component is disabled, it's swapped with the last active component + // and active_end_ is decremented + // - When a component is enabled, it's swapped with the first inactive component + // and active_end_ is incremented + // - This eliminates branch mispredictions from flag checking in the hot loop std::vector looping_components_{}; + uint16_t looping_components_active_end_{0}; + + // For safe reentrant modifications during iteration + uint16_t current_loop_index_{0}; + bool in_loop_{false}; #ifdef USE_BINARY_SENSOR std::vector binary_sensors_{}; diff --git a/esphome/core/component.cpp b/esphome/core/component.cpp index 14deb9c1df..53e57cea6d 100644 --- a/esphome/core/component.cpp +++ b/esphome/core/component.cpp @@ -136,17 +136,21 @@ void Component::mark_failed() { this->component_state_ &= ~COMPONENT_STATE_MASK; this->component_state_ |= COMPONENT_STATE_FAILED; this->status_set_error(); + // Also remove from loop since failed components shouldn't loop + App.disable_component_loop(this); } void Component::disable_loop() { ESP_LOGD(TAG, "%s loop disabled", this->get_component_source()); this->component_state_ &= ~COMPONENT_STATE_MASK; this->component_state_ |= COMPONENT_STATE_LOOP_DONE; + App.disable_component_loop(this); } void Component::enable_loop() { if ((this->component_state_ & COMPONENT_STATE_MASK) == COMPONENT_STATE_LOOP_DONE) { ESP_LOGD(TAG, "%s loop enabled", this->get_component_source()); this->component_state_ &= ~COMPONENT_STATE_MASK; this->component_state_ |= COMPONENT_STATE_LOOP; + App.enable_component_loop(this); } } void Component::reset_to_construction_state() { @@ -185,10 +189,6 @@ bool Component::is_ready() const { return (this->component_state_ & COMPONENT_STATE_MASK) == COMPONENT_STATE_LOOP || (this->component_state_ & COMPONENT_STATE_MASK) == COMPONENT_STATE_SETUP; } -bool Component::should_skip_loop() const { - uint8_t state = this->component_state_ & COMPONENT_STATE_MASK; - return state == COMPONENT_STATE_FAILED || state == COMPONENT_STATE_LOOP_DONE; -} bool Component::can_proceed() { return true; } bool Component::status_has_warning() const { return this->component_state_ & STATUS_LED_WARNING; } bool Component::status_has_error() const { return this->component_state_ & STATUS_LED_ERROR; } diff --git a/esphome/core/component.h b/esphome/core/component.h index 8ce2e87049..f787520026 100644 --- a/esphome/core/component.h +++ b/esphome/core/component.h @@ -169,12 +169,6 @@ class Component { bool is_ready() const; - /** Check if this component should skip its loop execution. - * - * @return True if the component is in FAILED or LOOP_DONE state - */ - bool should_skip_loop() const; - virtual bool can_proceed(); bool status_has_warning() const; From cee7789ab64a90b978d0ac271f61fd4b9f04648f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 18:37:05 -0500 Subject: [PATCH 10/14] tweak --- esphome/core/application.h | 3 +++ esphome/core/component.h | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/esphome/core/application.h b/esphome/core/application.h index 8b2f78beaa..46330cb2ae 100644 --- a/esphome/core/application.h +++ b/esphome/core/application.h @@ -572,6 +572,9 @@ class Application { void calculate_looping_components_(); + // These methods are called by Component::disable_loop() and Component::enable_loop() + // Components should not call these directly - use this->disable_loop() or this->enable_loop() + // to ensure component state is properly updated along with the loop partition void disable_component_loop(Component *component); void enable_component_loop(Component *component); diff --git a/esphome/core/component.h b/esphome/core/component.h index f787520026..e2adb66c47 100644 --- a/esphome/core/component.h +++ b/esphome/core/component.h @@ -155,6 +155,9 @@ class Component { * * This is useful for components that only need to run for a certain period of time * or when inactive, saving CPU cycles. + * + * @note Components should call this->disable_loop() on themselves, not on other components. + * This ensures the component's state is properly updated along with the loop partition. */ void disable_loop(); @@ -162,6 +165,9 @@ class Component { * * This is useful for components that transition between active and inactive states * and need to re-enable their loop() method when becoming active again. + * + * @note Components should call this->enable_loop() on themselves, not on other components. + * This ensures the component's state is properly updated along with the loop partition. */ void enable_loop(); From f711706b1acf85559ae5723615b9b8a86862e91a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 18:40:08 -0500 Subject: [PATCH 11/14] Fix ESP32 Improv component to re-enable loop when service starts again --- esphome/components/esp32_improv/esp32_improv_component.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/esphome/components/esp32_improv/esp32_improv_component.cpp b/esphome/components/esp32_improv/esp32_improv_component.cpp index ff150a3d69..d41094fda1 100644 --- a/esphome/components/esp32_improv/esp32_improv_component.cpp +++ b/esphome/components/esp32_improv/esp32_improv_component.cpp @@ -256,6 +256,7 @@ void ESP32ImprovComponent::start() { ESP_LOGD(TAG, "Setting Improv to start"); this->should_start_ = true; + this->enable_loop(); } void ESP32ImprovComponent::stop() { From 975520949963e5565cb01272c56d0d74df0ce624 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 18:42:40 -0500 Subject: [PATCH 12/14] comments --- esphome/core/application.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/esphome/core/application.cpp b/esphome/core/application.cpp index f9d2cf72c6..e1432a1eba 100644 --- a/esphome/core/application.cpp +++ b/esphome/core/application.cpp @@ -252,6 +252,7 @@ void Application::calculate_looping_components_() { } void Application::disable_component_loop(Component *component) { + // This method must be reentrant - components can disable themselves during their own loop() call // Linear search to find component in active section // Most configs have 10-30 looping components (30 is on the high end) // O(n) is acceptable here as we optimize for memory, not complexity @@ -275,6 +276,7 @@ void Application::disable_component_loop(Component *component) { } void Application::enable_component_loop(Component *component) { + // This method must be reentrant - components can re-enable themselves during their own loop() call // Single pass through all components to find and move if needed // With typical 10-30 components, O(n) is faster than maintaining a map const uint16_t size = this->looping_components_.size(); From 711b0a291bd65b756384d786c1821bf795b669af Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 18:44:15 -0500 Subject: [PATCH 13/14] comments --- esphome/core/application.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/esphome/core/application.cpp b/esphome/core/application.cpp index e1432a1eba..a47bfdf484 100644 --- a/esphome/core/application.cpp +++ b/esphome/core/application.cpp @@ -261,8 +261,7 @@ void Application::disable_component_loop(Component *component) { // Move last active component to this position this->looping_components_active_end_--; if (i != this->looping_components_active_end_) { - this->looping_components_[i] = this->looping_components_[this->looping_components_active_end_]; - this->looping_components_[this->looping_components_active_end_] = component; + std::swap(this->looping_components_[i], this->looping_components_[this->looping_components_active_end_]); // If we're currently iterating and just swapped the current position if (this->in_loop_ && i == this->current_loop_index_) { @@ -287,9 +286,7 @@ void Application::enable_component_loop(Component *component) { } // Found in inactive section - move to active if (i != this->looping_components_active_end_) { - Component *temp = this->looping_components_[this->looping_components_active_end_]; - this->looping_components_[this->looping_components_active_end_] = component; - this->looping_components_[i] = temp; + std::swap(this->looping_components_[i], this->looping_components_[this->looping_components_active_end_]); } this->looping_components_active_end_++; return; From fd31afe09cfe93110a480fffbee97bdeeb8681a8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Jun 2025 18:58:32 -0500 Subject: [PATCH 14/14] tidy --- esphome/core/application.cpp | 4 ++-- esphome/core/application.h | 4 ++-- esphome/core/component.cpp | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/esphome/core/application.cpp b/esphome/core/application.cpp index a47bfdf484..74208bbe22 100644 --- a/esphome/core/application.cpp +++ b/esphome/core/application.cpp @@ -251,7 +251,7 @@ void Application::calculate_looping_components_() { this->looping_components_active_end_ = this->looping_components_.size(); } -void Application::disable_component_loop(Component *component) { +void Application::disable_component_loop_(Component *component) { // This method must be reentrant - components can disable themselves during their own loop() call // Linear search to find component in active section // Most configs have 10-30 looping components (30 is on the high end) @@ -274,7 +274,7 @@ void Application::disable_component_loop(Component *component) { } } -void Application::enable_component_loop(Component *component) { +void Application::enable_component_loop_(Component *component) { // This method must be reentrant - components can re-enable themselves during their own loop() call // Single pass through all components to find and move if needed // With typical 10-30 components, O(n) is faster than maintaining a map diff --git a/esphome/core/application.h b/esphome/core/application.h index fc6f53a7c8..ea298638d2 100644 --- a/esphome/core/application.h +++ b/esphome/core/application.h @@ -575,8 +575,8 @@ class Application { // These methods are called by Component::disable_loop() and Component::enable_loop() // Components should not call these directly - use this->disable_loop() or this->enable_loop() // to ensure component state is properly updated along with the loop partition - void disable_component_loop(Component *component); - void enable_component_loop(Component *component); + void disable_component_loop_(Component *component); + void enable_component_loop_(Component *component); void feed_wdt_arch_(); diff --git a/esphome/core/component.cpp b/esphome/core/component.cpp index 2284a53fcd..3117f49ac1 100644 --- a/esphome/core/component.cpp +++ b/esphome/core/component.cpp @@ -145,20 +145,20 @@ void Component::mark_failed() { this->component_state_ |= COMPONENT_STATE_FAILED; this->status_set_error(); // Also remove from loop since failed components shouldn't loop - App.disable_component_loop(this); + App.disable_component_loop_(this); } void Component::disable_loop() { ESP_LOGD(TAG, "%s loop disabled", this->get_component_source()); this->component_state_ &= ~COMPONENT_STATE_MASK; this->component_state_ |= COMPONENT_STATE_LOOP_DONE; - App.disable_component_loop(this); + App.disable_component_loop_(this); } void Component::enable_loop() { if ((this->component_state_ & COMPONENT_STATE_MASK) == COMPONENT_STATE_LOOP_DONE) { ESP_LOGD(TAG, "%s loop enabled", this->get_component_source()); this->component_state_ &= ~COMPONENT_STATE_MASK; this->component_state_ |= COMPONENT_STATE_LOOP; - App.enable_component_loop(this); + App.enable_component_loop_(this); } } void Component::reset_to_construction_state() {