diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 2b34b2eb69..5a90ae86e9 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -332,12 +332,8 @@ uint16_t APIConnection::encode_message_to_buffer(ProtoMessage &msg, uint8_t mess std::vector &shared_buf = conn->parent_->get_shared_buffer_ref(); if (conn->flags_.batch_first_message) { - // First message - clear flag + // First message - buffer already prepared by caller, just clear flag conn->flags_.batch_first_message = false; - // If buffer not prepped by caller (batch pre-reserves with size == header_padding), prep now - if (shared_buf.size() != header_padding) { - conn->prepare_first_message_buffer(shared_buf, header_padding, total_calculated_size); - } } else { // Batch message second or later // Add padding for previous message footer + this message header @@ -1040,7 +1036,7 @@ void APIConnection::try_send_camera_image_() { msg.device_id = camera::Camera::instance()->get_device_id(); #endif - if (!this->send_message_(msg, CameraImageResponse::MESSAGE_TYPE)) { + if (!this->send_message_impl(msg, CameraImageResponse::MESSAGE_TYPE)) { return; // Send failed, try again later } this->image_reader_->consume_data(to_send); @@ -1448,7 +1444,7 @@ bool APIConnection::try_send_log_message(int level, const char *tag, const char SubscribeLogsResponse msg; msg.level = static_cast(level); msg.set_message(reinterpret_cast(line), message_len); - return this->send_message_(msg, SubscribeLogsResponse::MESSAGE_TYPE); + return this->send_message_impl(msg, SubscribeLogsResponse::MESSAGE_TYPE); } void APIConnection::complete_authentication_() { @@ -1777,6 +1773,14 @@ bool APIConnection::try_to_clear_buffer(bool log_out_of_space) { } return false; } +bool APIConnection::send_message_impl(const ProtoMessage &msg, uint8_t message_type) { + ProtoSize size; + msg.calculate_size(size); + std::vector &shared_buf = this->parent_->get_shared_buffer_ref(); + this->prepare_first_message_buffer(shared_buf, size.get_size()); + msg.encode({&shared_buf}); + return this->send_buffer({&shared_buf}, message_type); +} bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) { const bool is_log_message = (message_type == SubscribeLogsResponse::MESSAGE_TYPE); @@ -1838,6 +1842,23 @@ void APIConnection::DeferredBatch::add_item_front(EntityBase *entity, uint8_t me } } +bool APIConnection::send_message_smart_(EntityBase *entity, uint8_t message_type, uint8_t estimated_size, + uint8_t aux_data_index) { + if (this->should_send_immediately_(message_type) && this->helper_->can_write_without_blocking()) { + auto &shared_buf = this->parent_->get_shared_buffer_ref(); + this->prepare_first_message_buffer(shared_buf, estimated_size); + DeferredBatch::BatchItem item{entity, message_type, estimated_size, aux_data_index}; + if (this->dispatch_message_(item, MAX_BATCH_PACKET_SIZE, true) && + this->send_buffer(ProtoWriteBuffer{&shared_buf}, message_type)) { +#ifdef HAS_PROTO_MESSAGE_DUMP + this->log_batch_item_(item); +#endif + return true; + } + } + return this->schedule_message_(entity, message_type, estimated_size, aux_data_index); +} + bool APIConnection::schedule_batch_() { if (!this->flags_.batch_scheduled) { this->flags_.batch_scheduled = true; @@ -1866,10 +1887,21 @@ void APIConnection::process_batch_() { auto &shared_buf = this->parent_->get_shared_buffer_ref(); size_t num_items = this->deferred_batch_.size(); - // Fast path for single message - allocate exact size needed + // Cache these values to avoid repeated virtual calls + const uint8_t header_padding = this->helper_->frame_header_padding(); + const uint8_t footer_size = this->helper_->frame_footer_size(); + + // Pre-calculate exact buffer size needed based on message types + uint32_t total_estimated_size = num_items * (header_padding + footer_size); + for (size_t i = 0; i < num_items; i++) { + total_estimated_size += this->deferred_batch_[i].estimated_size; + } + + this->prepare_first_message_buffer(shared_buf, header_padding, total_estimated_size); + + // Fast path for single message - buffer already allocated above if (num_items == 1) { const auto &item = this->deferred_batch_[0]; - // Let dispatch_message_ calculate size and encode if it fits uint16_t payload_size = this->dispatch_message_(item, std::numeric_limits::max(), true); @@ -1892,29 +1924,8 @@ void APIConnection::process_batch_() { // Stack-allocated array for message info alignas(MessageInfo) char message_info_storage[MAX_MESSAGES_PER_BATCH * sizeof(MessageInfo)]; MessageInfo *message_info = reinterpret_cast(message_info_storage); - size_t message_count = 0; - - // Cache these values to avoid repeated virtual calls - const uint8_t header_padding = this->helper_->frame_header_padding(); - const uint8_t footer_size = this->helper_->frame_footer_size(); - - // Initialize buffer and tracking variables - shared_buf.clear(); - - // Pre-calculate exact buffer size needed based on message types - uint32_t total_estimated_size = num_items * (header_padding + footer_size); - for (size_t i = 0; i < this->deferred_batch_.size(); i++) { - const auto &item = this->deferred_batch_[i]; - total_estimated_size += item.estimated_size; - } - - // Calculate total overhead for all messages - // Reserve based on estimated size (much more accurate than 24-byte worst-case) - shared_buf.reserve(total_estimated_size); - size_t items_processed = 0; uint16_t remaining_size = std::numeric_limits::max(); - // Track where each message's header padding begins in the buffer // For plaintext: this is where the 6-byte header padding starts // For noise: this is where the 7-byte header padding starts @@ -1940,10 +1951,7 @@ void APIConnection::process_batch_() { // This avoids default-constructing all MAX_MESSAGES_PER_BATCH elements // Explicit destruction is not needed because MessageInfo is trivially destructible, // as ensured by the static_assert in its definition. - new (&message_info[message_count++]) MessageInfo(item.message_type, current_offset, proto_payload_size); - - // Update tracking variables - items_processed++; + new (&message_info[items_processed++]) MessageInfo(item.message_type, current_offset, proto_payload_size); // After first message, set remaining size to MAX_BATCH_PACKET_SIZE to avoid fragmentation if (items_processed == 1) { remaining_size = MAX_BATCH_PACKET_SIZE; @@ -1966,7 +1974,7 @@ void APIConnection::process_batch_() { // Send all collected messages APIError err = this->helper_->write_protobuf_messages(ProtoWriteBuffer{&shared_buf}, - std::span(message_info, message_count)); + std::span(message_info, items_processed)); if (err != APIError::OK && err != APIError::WOULD_BLOCK) { this->fatal_error_with_log_(LOG_STR("Batch write failed"), err); } diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 9e2ad0e946..c564814c95 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -255,17 +255,7 @@ class APIConnection final : public APIServerConnection { void on_fatal_error() override; void on_no_setup_connection() override; - ProtoWriteBuffer create_buffer(uint32_t reserve_size) override { - // FIXME: ensure no recursive writes can happen - - // Get header padding size - used for both reserve and insert - uint8_t header_padding = this->helper_->frame_header_padding(); - // Get shared buffer from parent server - std::vector &shared_buf = this->parent_->get_shared_buffer_ref(); - this->prepare_first_message_buffer(shared_buf, header_padding, - reserve_size + header_padding + this->helper_->frame_footer_size()); - return {&shared_buf}; - } + bool send_message_impl(const ProtoMessage &msg, uint8_t message_type) override; void prepare_first_message_buffer(std::vector &shared_buf, size_t header_padding, size_t total_size) { shared_buf.clear(); @@ -277,6 +267,13 @@ class APIConnection final : public APIServerConnection { shared_buf.resize(header_padding); } + // Convenience overload - computes frame overhead internally + void prepare_first_message_buffer(std::vector &shared_buf, size_t payload_size) { + const uint8_t header_padding = this->helper_->frame_header_padding(); + const uint8_t footer_size = this->helper_->frame_footer_size(); + this->prepare_first_message_buffer(shared_buf, header_padding, payload_size + header_padding + footer_size); + } + bool try_to_clear_buffer(bool log_out_of_space); bool send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) override; @@ -653,19 +650,7 @@ class APIConnection final : public APIServerConnection { // Tries immediate send if should_send_immediately_() returns true and buffer has space // Falls back to batching if immediate send fails or isn't applicable bool send_message_smart_(EntityBase *entity, uint8_t message_type, uint8_t estimated_size, - uint8_t aux_data_index = DeferredBatch::AUX_DATA_UNUSED) { - if (this->should_send_immediately_(message_type) && this->helper_->can_write_without_blocking()) { - DeferredBatch::BatchItem item{entity, message_type, estimated_size, aux_data_index}; - if (this->dispatch_message_(item, MAX_BATCH_PACKET_SIZE, true) && - this->send_buffer(ProtoWriteBuffer{&this->parent_->get_shared_buffer_ref()}, message_type)) { -#ifdef HAS_PROTO_MESSAGE_DUMP - this->log_batch_item_(item); -#endif - return true; - } - } - return this->schedule_message_(entity, message_type, estimated_size, aux_data_index); - } + uint8_t aux_data_index = DeferredBatch::AUX_DATA_UNUSED); // Helper function to schedule a deferred message with known message type bool schedule_message_(EntityBase *entity, uint8_t message_type, uint8_t estimated_size, diff --git a/esphome/components/api/api_pb2_service.h b/esphome/components/api/api_pb2_service.h index 80a61c1041..e2bc1609ed 100644 --- a/esphome/components/api/api_pb2_service.h +++ b/esphome/components/api/api_pb2_service.h @@ -23,7 +23,7 @@ class APIServerConnectionBase : public ProtoService { DumpBuffer dump_buf; this->log_send_message_(msg.message_name(), msg.dump_to(dump_buf)); #endif - return this->send_message_(msg, message_type); + return this->send_message_impl(msg, message_type); } virtual void on_hello_request(const HelloRequest &value){}; diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index 552b4a4625..92978f765f 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -957,32 +957,16 @@ class ProtoService { virtual bool is_connection_setup() = 0; virtual void on_fatal_error() = 0; virtual void on_no_setup_connection() = 0; - /** - * Create a buffer with a reserved size. - * @param reserve_size The number of bytes to pre-allocate in the buffer. This is a hint - * to optimize memory usage and avoid reallocations during encoding. - * Implementations should aim to allocate at least this size. - * @return A ProtoWriteBuffer object with the reserved size. - */ - virtual ProtoWriteBuffer create_buffer(uint32_t reserve_size) = 0; virtual bool send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) = 0; virtual void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) = 0; - - // Optimized method that pre-allocates buffer based on message size - bool send_message_(const ProtoMessage &msg, uint8_t message_type) { - ProtoSize size; - msg.calculate_size(size); - uint32_t msg_size = size.get_size(); - - // Create a pre-sized buffer - auto buffer = this->create_buffer(msg_size); - - // Encode message into the buffer - msg.encode(buffer); - - // Send the buffer - return this->send_buffer(buffer, message_type); - } + /** + * Send a protobuf message by calculating its size, allocating a buffer, encoding, and sending. + * This is the implementation method - callers should use send_message() which adds logging. + * @param msg The protobuf message to send. + * @param message_type The message type identifier. + * @return True if the message was sent successfully, false otherwise. + */ + virtual bool send_message_impl(const ProtoMessage &msg, uint8_t message_type) = 0; // Authentication helper methods inline bool check_connection_setup_() { diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 8baf6acf11..4021a062ca 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -2848,7 +2848,7 @@ static const char *const TAG = "api.service"; hpp += " DumpBuffer dump_buf;\n" hpp += " this->log_send_message_(msg.message_name(), msg.dump_to(dump_buf));\n" hpp += "#endif\n" - hpp += " return this->send_message_(msg, message_type);\n" + hpp += " return this->send_message_impl(msg, message_type);\n" hpp += " }\n\n" # Add logging helper method implementations to cpp