mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-30 22:53:59 +00:00 
			
		
		
		
	Optimize API frame helper buffer management (#8805)
Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
		| @@ -7,20 +7,13 @@ | ||||
| #include "proto.h" | ||||
| #include "api_pb2_size.h" | ||||
| #include <cstring> | ||||
| #include <cinttypes> | ||||
|  | ||||
| namespace esphome { | ||||
| namespace api { | ||||
|  | ||||
| static const char *const TAG = "api.socket"; | ||||
|  | ||||
| /// Is the given return value (from write syscalls) a wouldblock error? | ||||
| bool is_would_block(ssize_t ret) { | ||||
|   if (ret == -1) { | ||||
|     return errno == EWOULDBLOCK || errno == EAGAIN; | ||||
|   } | ||||
|   return ret == 0; | ||||
| } | ||||
|  | ||||
| const char *api_error_to_str(APIError err) { | ||||
|   // not using switch to ensure compiler doesn't try to build a big table out of it | ||||
|   if (err == APIError::OK) { | ||||
| @@ -73,92 +66,154 @@ const char *api_error_to_str(APIError err) { | ||||
|   return "UNKNOWN"; | ||||
| } | ||||
|  | ||||
| // Common implementation for writing raw data to socket | ||||
| template<typename StateEnum> | ||||
| APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, | ||||
|                                     std::vector<uint8_t> &tx_buf, const std::string &info, StateEnum &state, | ||||
|                                     StateEnum failed_state) { | ||||
|   // This method writes data to socket or buffers it | ||||
| // Helper method to buffer data from IOVs | ||||
| void APIFrameHelper::buffer_data_from_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len) { | ||||
|   SendBuffer buffer; | ||||
|   buffer.data.reserve(total_write_len); | ||||
|   for (int i = 0; i < iovcnt; i++) { | ||||
|     const uint8_t *data = reinterpret_cast<uint8_t *>(iov[i].iov_base); | ||||
|     buffer.data.insert(buffer.data.end(), data, data + iov[i].iov_len); | ||||
|   } | ||||
|   this->tx_buf_.push_back(std::move(buffer)); | ||||
| } | ||||
|  | ||||
| // This method writes data to socket or buffers it | ||||
| APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { | ||||
|   // Returns APIError::OK if successful (or would block, but data has been buffered) | ||||
|   // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to failed_state | ||||
|   // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to FAILED | ||||
|  | ||||
|   if (iovcnt == 0) | ||||
|     return APIError::OK;  // Nothing to do, success | ||||
|  | ||||
|   size_t total_write_len = 0; | ||||
|   uint16_t total_write_len = 0; | ||||
|   for (int i = 0; i < iovcnt; i++) { | ||||
| #ifdef HELPER_LOG_PACKETS | ||||
|     ESP_LOGVV(TAG, "Sending raw: %s", | ||||
|               format_hex_pretty(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len).c_str()); | ||||
| #endif | ||||
|     total_write_len += iov[i].iov_len; | ||||
|     total_write_len += static_cast<uint16_t>(iov[i].iov_len); | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf.empty()) { | ||||
|     // try to empty tx_buf first | ||||
|     while (!tx_buf.empty()) { | ||||
|       ssize_t sent = socket->write(tx_buf.data(), tx_buf.size()); | ||||
|       if (is_would_block(sent)) { | ||||
|         break; | ||||
|       } else if (sent == -1) { | ||||
|         ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); | ||||
|         state = failed_state; | ||||
|         return APIError::SOCKET_WRITE_FAILED;  // Socket write failed | ||||
|       } | ||||
|       // TODO: inefficient if multiple packets in txbuf | ||||
|       // replace with deque of buffers | ||||
|       tx_buf.erase(tx_buf.begin(), tx_buf.begin() + sent); | ||||
|   // Try to send any existing buffered data first if there is any | ||||
|   if (!this->tx_buf_.empty()) { | ||||
|     APIError send_result = try_send_tx_buf_(); | ||||
|     // If real error occurred (not just WOULD_BLOCK), return it | ||||
|     if (send_result != APIError::OK && send_result != APIError::WOULD_BLOCK) { | ||||
|       return send_result; | ||||
|     } | ||||
|  | ||||
|     // If there is still data in the buffer, we can't send, buffer | ||||
|     // the new data and return | ||||
|     if (!this->tx_buf_.empty()) { | ||||
|       this->buffer_data_from_iov_(iov, iovcnt, total_write_len); | ||||
|       return APIError::OK;  // Success, data buffered | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf.empty()) { | ||||
|     // tx buf not empty, can't write now because then stream would be inconsistent | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf.reserve(tx_buf.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf.insert(tx_buf.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                     reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|     } | ||||
|     return APIError::OK;  // Success, data buffered | ||||
|   } | ||||
|   // Try to send directly if no buffered data | ||||
|   ssize_t sent = this->socket_->writev(iov, iovcnt); | ||||
|  | ||||
|   ssize_t sent = socket->writev(iov, iovcnt); | ||||
|   if (is_would_block(sent)) { | ||||
|     // operation would block, add buffer to tx_buf | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf.reserve(tx_buf.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf.insert(tx_buf.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                     reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|   if (sent == -1) { | ||||
|     if (errno == EWOULDBLOCK || errno == EAGAIN) { | ||||
|       // Socket would block, buffer the data | ||||
|       this->buffer_data_from_iov_(iov, iovcnt, total_write_len); | ||||
|       return APIError::OK;  // Success, data buffered | ||||
|     } | ||||
|     return APIError::OK;  // Success, data buffered | ||||
|   } else if (sent == -1) { | ||||
|     // an error occurred | ||||
|     ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); | ||||
|     state = failed_state; | ||||
|     // Socket error | ||||
|     ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno); | ||||
|     this->state_ = State::FAILED; | ||||
|     return APIError::SOCKET_WRITE_FAILED;  // Socket write failed | ||||
|   } else if ((size_t) sent != total_write_len) { | ||||
|     // partially sent, add end to tx_buf | ||||
|     size_t remaining = total_write_len - sent; | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf.reserve(tx_buf.size() + remaining); | ||||
|   } else if (static_cast<uint16_t>(sent) < total_write_len) { | ||||
|     // Partially sent, buffer the remaining data | ||||
|     SendBuffer buffer; | ||||
|     uint16_t to_consume = static_cast<uint16_t>(sent); | ||||
|     uint16_t remaining = total_write_len - static_cast<uint16_t>(sent); | ||||
|  | ||||
|     buffer.data.reserve(remaining); | ||||
|  | ||||
|     size_t to_consume = sent; | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       if (to_consume >= iov[i].iov_len) { | ||||
|         to_consume -= iov[i].iov_len; | ||||
|         // This segment was fully sent | ||||
|         to_consume -= static_cast<uint16_t>(iov[i].iov_len); | ||||
|       } else { | ||||
|         tx_buf.insert(tx_buf.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base) + to_consume, | ||||
|                       reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|         // This segment was partially sent or not sent at all | ||||
|         const uint8_t *data = reinterpret_cast<uint8_t *>(iov[i].iov_base) + to_consume; | ||||
|         uint16_t len = static_cast<uint16_t>(iov[i].iov_len) - to_consume; | ||||
|         buffer.data.insert(buffer.data.end(), data, data + len); | ||||
|         to_consume = 0; | ||||
|       } | ||||
|     } | ||||
|     return APIError::OK;  // Success, data buffered | ||||
|  | ||||
|     this->tx_buf_.push_back(std::move(buffer)); | ||||
|   } | ||||
|   return APIError::OK;  // Success, all data sent | ||||
|  | ||||
|   return APIError::OK;  // Success, all data sent or buffered | ||||
| } | ||||
|  | ||||
| #define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) | ||||
| // Common implementation for trying to send buffered data | ||||
| // IMPORTANT: Caller MUST ensure tx_buf_ is not empty before calling this method | ||||
| APIError APIFrameHelper::try_send_tx_buf_() { | ||||
|   // Try to send from tx_buf - we assume it's not empty as it's the caller's responsibility to check | ||||
|   bool tx_buf_empty = false; | ||||
|   while (!tx_buf_empty) { | ||||
|     // Get the first buffer in the queue | ||||
|     SendBuffer &front_buffer = this->tx_buf_.front(); | ||||
|  | ||||
|     // Try to send the remaining data in this buffer | ||||
|     ssize_t sent = this->socket_->write(front_buffer.current_data(), front_buffer.remaining()); | ||||
|  | ||||
|     if (sent == -1) { | ||||
|       if (errno != EWOULDBLOCK && errno != EAGAIN) { | ||||
|         // Real socket error (not just would block) | ||||
|         ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno); | ||||
|         this->state_ = State::FAILED; | ||||
|         return APIError::SOCKET_WRITE_FAILED;  // Socket write failed | ||||
|       } | ||||
|       // Socket would block, we'll try again later | ||||
|       return APIError::WOULD_BLOCK; | ||||
|     } else if (sent == 0) { | ||||
|       // Nothing sent but not an error | ||||
|       return APIError::WOULD_BLOCK; | ||||
|     } else if (static_cast<uint16_t>(sent) < front_buffer.remaining()) { | ||||
|       // Partially sent, update offset | ||||
|       // Cast to ensure no overflow issues with uint16_t | ||||
|       front_buffer.offset += static_cast<uint16_t>(sent); | ||||
|       return APIError::WOULD_BLOCK;  // Stop processing more buffers if we couldn't send a complete buffer | ||||
|     } else { | ||||
|       // Buffer completely sent, remove it from the queue | ||||
|       this->tx_buf_.pop_front(); | ||||
|       // Update empty status for the loop condition | ||||
|       tx_buf_empty = this->tx_buf_.empty(); | ||||
|       // Continue loop to try sending the next buffer | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return APIError::OK;  // All buffers sent successfully | ||||
| } | ||||
|  | ||||
| APIError APIFrameHelper::init_common_() { | ||||
|   if (state_ != State::INITIALIZE || this->socket_ == nullptr) { | ||||
|     ESP_LOGVV(TAG, "%s: Bad state for init %d", this->info_.c_str(), (int) state_); | ||||
|     return APIError::BAD_STATE; | ||||
|   } | ||||
|   int err = this->socket_->setblocking(false); | ||||
|   if (err != 0) { | ||||
|     state_ = State::FAILED; | ||||
|     ESP_LOGVV(TAG, "%s: Setting nonblocking failed with errno %d", this->info_.c_str(), errno); | ||||
|     return APIError::TCP_NONBLOCKING_FAILED; | ||||
|   } | ||||
|  | ||||
|   int enable = 1; | ||||
|   err = this->socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); | ||||
|   if (err != 0) { | ||||
|     state_ = State::FAILED; | ||||
|     ESP_LOGVV(TAG, "%s: Setting nodelay failed with errno %d", this->info_.c_str(), errno); | ||||
|     return APIError::TCP_NODELAY_FAILED; | ||||
|   } | ||||
|   return APIError::OK; | ||||
| } | ||||
|  | ||||
| #define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->info_.c_str(), ##__VA_ARGS__) | ||||
| // uncomment to log raw packets | ||||
| //#define HELPER_LOG_PACKETS | ||||
|  | ||||
| @@ -206,23 +261,9 @@ std::string noise_err_to_str(int err) { | ||||
|  | ||||
| /// Initialize the frame helper, returns OK if successful. | ||||
| APIError APINoiseFrameHelper::init() { | ||||
|   if (state_ != State::INITIALIZE || socket_ == nullptr) { | ||||
|     HELPER_LOG("Bad state for init %d", (int) state_); | ||||
|     return APIError::BAD_STATE; | ||||
|   } | ||||
|   int err = socket_->setblocking(false); | ||||
|   if (err != 0) { | ||||
|     state_ = State::FAILED; | ||||
|     HELPER_LOG("Setting nonblocking failed with errno %d", errno); | ||||
|     return APIError::TCP_NONBLOCKING_FAILED; | ||||
|   } | ||||
|  | ||||
|   int enable = 1; | ||||
|   err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); | ||||
|   if (err != 0) { | ||||
|     state_ = State::FAILED; | ||||
|     HELPER_LOG("Setting nodelay failed with errno %d", errno); | ||||
|     return APIError::TCP_NODELAY_FAILED; | ||||
|   APIError err = init_common_(); | ||||
|   if (err != APIError::OK) { | ||||
|     return err; | ||||
|   } | ||||
|  | ||||
|   // init prologue | ||||
| @@ -234,17 +275,16 @@ APIError APINoiseFrameHelper::init() { | ||||
| /// Run through handshake messages (if in that phase) | ||||
| APIError APINoiseFrameHelper::loop() { | ||||
|   APIError err = state_action_(); | ||||
|   if (err == APIError::WOULD_BLOCK) | ||||
|     return APIError::OK; | ||||
|   if (err != APIError::OK) | ||||
|   if (err != APIError::OK && err != APIError::WOULD_BLOCK) { | ||||
|     return err; | ||||
|   if (!tx_buf_.empty()) { | ||||
|   } | ||||
|   if (!this->tx_buf_.empty()) { | ||||
|     err = try_send_tx_buf_(); | ||||
|     if (err != APIError::OK) { | ||||
|     if (err != APIError::OK && err != APIError::WOULD_BLOCK) { | ||||
|       return err; | ||||
|     } | ||||
|   } | ||||
|   return APIError::OK; | ||||
|   return APIError::OK;  // Convert WOULD_BLOCK to OK to avoid connection termination | ||||
| } | ||||
|  | ||||
| /** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter | ||||
| @@ -270,8 +310,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|   // read header | ||||
|   if (rx_header_buf_len_ < 3) { | ||||
|     // no header information yet | ||||
|     size_t to_read = 3 - rx_header_buf_len_; | ||||
|     ssize_t received = socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); | ||||
|     uint8_t to_read = 3 - rx_header_buf_len_; | ||||
|     ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); | ||||
|     if (received == -1) { | ||||
|       if (errno == EWOULDBLOCK || errno == EAGAIN) { | ||||
|         return APIError::WOULD_BLOCK; | ||||
| @@ -284,8 +324,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|       HELPER_LOG("Connection closed"); | ||||
|       return APIError::CONNECTION_CLOSED; | ||||
|     } | ||||
|     rx_header_buf_len_ += received; | ||||
|     if ((size_t) received != to_read) { | ||||
|     rx_header_buf_len_ += static_cast<uint8_t>(received); | ||||
|     if (static_cast<uint8_t>(received) != to_read) { | ||||
|       // not a full read | ||||
|       return APIError::WOULD_BLOCK; | ||||
|     } | ||||
| @@ -317,8 +357,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|  | ||||
|   if (rx_buf_len_ < msg_size) { | ||||
|     // more data to read | ||||
|     size_t to_read = msg_size - rx_buf_len_; | ||||
|     ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read); | ||||
|     uint16_t to_read = msg_size - rx_buf_len_; | ||||
|     ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); | ||||
|     if (received == -1) { | ||||
|       if (errno == EWOULDBLOCK || errno == EAGAIN) { | ||||
|         return APIError::WOULD_BLOCK; | ||||
| @@ -331,8 +371,8 @@ APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|       HELPER_LOG("Connection closed"); | ||||
|       return APIError::CONNECTION_CLOSED; | ||||
|     } | ||||
|     rx_buf_len_ += received; | ||||
|     if ((size_t) received != to_read) { | ||||
|     rx_buf_len_ += static_cast<uint16_t>(received); | ||||
|     if (static_cast<uint16_t>(received) != to_read) { | ||||
|       // not all read | ||||
|       return APIError::WOULD_BLOCK; | ||||
|     } | ||||
| @@ -381,6 +421,8 @@ APIError APINoiseFrameHelper::state_action_() { | ||||
|     if (aerr != APIError::OK) | ||||
|       return aerr; | ||||
|     // ignore contents, may be used in future for flags | ||||
|     // Reserve space for: existing prologue + 2 size bytes + frame data | ||||
|     prologue_.reserve(prologue_.size() + 2 + frame.msg.size()); | ||||
|     prologue_.push_back((uint8_t) (frame.msg.size() >> 8)); | ||||
|     prologue_.push_back((uint8_t) frame.msg.size()); | ||||
|     prologue_.insert(prologue_.end(), frame.msg.begin(), frame.msg.end()); | ||||
| @@ -389,16 +431,20 @@ APIError APINoiseFrameHelper::state_action_() { | ||||
|   } | ||||
|   if (state_ == State::SERVER_HELLO) { | ||||
|     // send server hello | ||||
|     const std::string &name = App.get_name(); | ||||
|     const std::string &mac = get_mac_address(); | ||||
|  | ||||
|     std::vector<uint8_t> msg; | ||||
|     // Reserve space for: 1 byte proto + name + null + mac + null | ||||
|     msg.reserve(1 + name.size() + 1 + mac.size() + 1); | ||||
|  | ||||
|     // chosen proto | ||||
|     msg.push_back(0x01); | ||||
|  | ||||
|     // node name, terminated by null byte | ||||
|     const std::string &name = App.get_name(); | ||||
|     const uint8_t *name_ptr = reinterpret_cast<const uint8_t *>(name.c_str()); | ||||
|     msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1); | ||||
|     // node mac, terminated by null byte | ||||
|     const std::string &mac = get_mac_address(); | ||||
|     const uint8_t *mac_ptr = reinterpret_cast<const uint8_t *>(mac.c_str()); | ||||
|     msg.insert(msg.end(), mac_ptr, mac_ptr + mac.size() + 1); | ||||
|  | ||||
| @@ -505,7 +551,6 @@ void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &rea | ||||
|   write_frame_(data.data(), data.size()); | ||||
|   state_ = orig_state; | ||||
| } | ||||
|  | ||||
| APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { | ||||
|   int err; | ||||
|   APIError aerr; | ||||
| @@ -533,7 +578,7 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { | ||||
|     return APIError::CIPHERSTATE_DECRYPT_FAILED; | ||||
|   } | ||||
|  | ||||
|   size_t msg_size = mbuf.size; | ||||
|   uint16_t msg_size = mbuf.size; | ||||
|   uint8_t *msg_data = frame.msg.data(); | ||||
|   if (msg_size < 4) { | ||||
|     state_ = State::FAILED; | ||||
| @@ -559,7 +604,6 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { | ||||
|   buffer->type = type; | ||||
|   return APIError::OK; | ||||
| } | ||||
| bool APINoiseFrameHelper::can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } | ||||
| APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) { | ||||
|   int err; | ||||
|   APIError aerr; | ||||
| @@ -574,9 +618,9 @@ APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuf | ||||
|  | ||||
|   std::vector<uint8_t> *raw_buffer = buffer.get_buffer(); | ||||
|   // Message data starts after padding | ||||
|   size_t payload_len = raw_buffer->size() - frame_header_padding_; | ||||
|   size_t padding = 0; | ||||
|   size_t msg_len = 4 + payload_len + padding; | ||||
|   uint16_t payload_len = raw_buffer->size() - frame_header_padding_; | ||||
|   uint16_t padding = 0; | ||||
|   uint16_t msg_len = 4 + payload_len + padding; | ||||
|  | ||||
|   // We need to resize to include MAC space, but we already reserved it in create_buffer | ||||
|   raw_buffer->resize(raw_buffer->size() + frame_footer_size_); | ||||
| @@ -609,7 +653,7 @@ APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuf | ||||
|     return APIError::CIPHERSTATE_ENCRYPT_FAILED; | ||||
|   } | ||||
|  | ||||
|   size_t total_len = 3 + mbuf.size; | ||||
|   uint16_t total_len = 3 + mbuf.size; | ||||
|   buf_start[1] = (uint8_t) (mbuf.size >> 8); | ||||
|   buf_start[2] = (uint8_t) mbuf.size; | ||||
|  | ||||
| @@ -620,29 +664,9 @@ APIError APINoiseFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuf | ||||
|   iov.iov_len = total_len; | ||||
|  | ||||
|   // write raw to not have two packets sent if NAGLE disabled | ||||
|   return write_raw_(&iov, 1); | ||||
|   return this->write_raw_(&iov, 1); | ||||
| } | ||||
| APIError APINoiseFrameHelper::try_send_tx_buf_() { | ||||
|   // try send from tx_buf | ||||
|   while (state_ != State::CLOSED && !tx_buf_.empty()) { | ||||
|     ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size()); | ||||
|     if (sent == -1) { | ||||
|       if (errno == EWOULDBLOCK || errno == EAGAIN) | ||||
|         break; | ||||
|       state_ = State::FAILED; | ||||
|       HELPER_LOG("Socket write failed with errno %d", errno); | ||||
|       return APIError::SOCKET_WRITE_FAILED; | ||||
|     } else if (sent == 0) { | ||||
|       break; | ||||
|     } | ||||
|     // TODO: inefficient if multiple packets in txbuf | ||||
|     // replace with deque of buffers | ||||
|     tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent); | ||||
|   } | ||||
|  | ||||
|   return APIError::OK; | ||||
| } | ||||
| APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { | ||||
| APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) { | ||||
|   uint8_t header[3]; | ||||
|   header[0] = 0x01;  // indicator | ||||
|   header[1] = (uint8_t) (len >> 8); | ||||
| @@ -652,12 +676,12 @@ APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { | ||||
|   iov[0].iov_base = header; | ||||
|   iov[0].iov_len = 3; | ||||
|   if (len == 0) { | ||||
|     return write_raw_(iov, 1); | ||||
|     return this->write_raw_(iov, 1); | ||||
|   } | ||||
|   iov[1].iov_base = const_cast<uint8_t *>(data); | ||||
|   iov[1].iov_len = len; | ||||
|  | ||||
|   return write_raw_(iov, 2); | ||||
|   return this->write_raw_(iov, 2); | ||||
| } | ||||
|  | ||||
| /** Initiate the data structures for the handshake. | ||||
| @@ -752,22 +776,6 @@ APINoiseFrameHelper::~APINoiseFrameHelper() { | ||||
|   } | ||||
| } | ||||
|  | ||||
| APIError APINoiseFrameHelper::close() { | ||||
|   state_ = State::CLOSED; | ||||
|   int err = socket_->close(); | ||||
|   if (err == -1) | ||||
|     return APIError::CLOSE_FAILED; | ||||
|   return APIError::OK; | ||||
| } | ||||
| APIError APINoiseFrameHelper::shutdown(int how) { | ||||
|   int err = socket_->shutdown(how); | ||||
|   if (err == -1) | ||||
|     return APIError::SHUTDOWN_FAILED; | ||||
|   if (how == SHUT_RDWR) { | ||||
|     state_ = State::CLOSED; | ||||
|   } | ||||
|   return APIError::OK; | ||||
| } | ||||
| extern "C" { | ||||
| // declare how noise generates random bytes (here with a good HWRNG based on the RF system) | ||||
| void noise_rand_bytes(void *output, size_t len) { | ||||
| @@ -778,32 +786,15 @@ void noise_rand_bytes(void *output, size_t len) { | ||||
| } | ||||
| } | ||||
|  | ||||
| // Explicit template instantiation for Noise | ||||
| template APIError APIFrameHelper::write_raw_<APINoiseFrameHelper::State>( | ||||
|     const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector<uint8_t> &tx_buf_, const std::string &info, | ||||
|     APINoiseFrameHelper::State &state, APINoiseFrameHelper::State failed_state); | ||||
| #endif  // USE_API_NOISE | ||||
|  | ||||
| #ifdef USE_API_PLAINTEXT | ||||
|  | ||||
| /// Initialize the frame helper, returns OK if successful. | ||||
| APIError APIPlaintextFrameHelper::init() { | ||||
|   if (state_ != State::INITIALIZE || socket_ == nullptr) { | ||||
|     HELPER_LOG("Bad state for init %d", (int) state_); | ||||
|     return APIError::BAD_STATE; | ||||
|   } | ||||
|   int err = socket_->setblocking(false); | ||||
|   if (err != 0) { | ||||
|     state_ = State::FAILED; | ||||
|     HELPER_LOG("Setting nonblocking failed with errno %d", errno); | ||||
|     return APIError::TCP_NONBLOCKING_FAILED; | ||||
|   } | ||||
|   int enable = 1; | ||||
|   err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); | ||||
|   if (err != 0) { | ||||
|     state_ = State::FAILED; | ||||
|     HELPER_LOG("Setting nodelay failed with errno %d", errno); | ||||
|     return APIError::TCP_NODELAY_FAILED; | ||||
|   APIError err = init_common_(); | ||||
|   if (err != APIError::OK) { | ||||
|     return err; | ||||
|   } | ||||
|  | ||||
|   state_ = State::DATA; | ||||
| @@ -814,14 +805,13 @@ APIError APIPlaintextFrameHelper::loop() { | ||||
|   if (state_ != State::DATA) { | ||||
|     return APIError::BAD_STATE; | ||||
|   } | ||||
|   // try send pending TX data | ||||
|   if (!tx_buf_.empty()) { | ||||
|   if (!this->tx_buf_.empty()) { | ||||
|     APIError err = try_send_tx_buf_(); | ||||
|     if (err != APIError::OK) { | ||||
|     if (err != APIError::OK && err != APIError::WOULD_BLOCK) { | ||||
|       return err; | ||||
|     } | ||||
|   } | ||||
|   return APIError::OK; | ||||
|   return APIError::OK;  // Convert WOULD_BLOCK to OK to avoid connection termination | ||||
| } | ||||
|  | ||||
| /** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter | ||||
| @@ -846,7 +836,7 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|     // there is no data on the wire (which is the common case). | ||||
|     // This results in faster failure detection compared to | ||||
|     // attempting to read multiple bytes at once. | ||||
|     ssize_t received = socket_->read(&data, 1); | ||||
|     ssize_t received = this->socket_->read(&data, 1); | ||||
|     if (received == -1) { | ||||
|       if (errno == EWOULDBLOCK || errno == EAGAIN) { | ||||
|         return APIError::WOULD_BLOCK; | ||||
| @@ -910,14 +900,24 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|       continue; | ||||
|     } | ||||
|  | ||||
|     rx_header_parsed_len_ = msg_size_varint->as_uint32(); | ||||
|     if (msg_size_varint->as_uint32() > 65535) { | ||||
|       state_ = State::FAILED; | ||||
|       HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum 65535", msg_size_varint->as_uint32()); | ||||
|       return APIError::BAD_DATA_PACKET; | ||||
|     } | ||||
|     rx_header_parsed_len_ = msg_size_varint->as_uint16(); | ||||
|  | ||||
|     auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[consumed], rx_header_buf_pos_ - 1 - consumed, &consumed); | ||||
|     if (!msg_type_varint.has_value()) { | ||||
|       // not enough data there yet | ||||
|       continue; | ||||
|     } | ||||
|     rx_header_parsed_type_ = msg_type_varint->as_uint32(); | ||||
|     if (msg_type_varint->as_uint32() > 65535) { | ||||
|       state_ = State::FAILED; | ||||
|       HELPER_LOG("Bad packet: message type %" PRIu32 " exceeds maximum 65535", msg_type_varint->as_uint32()); | ||||
|       return APIError::BAD_DATA_PACKET; | ||||
|     } | ||||
|     rx_header_parsed_type_ = msg_type_varint->as_uint16(); | ||||
|     rx_header_parsed_ = true; | ||||
|   } | ||||
|   // header reading done | ||||
| @@ -929,8 +929,8 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|  | ||||
|   if (rx_buf_len_ < rx_header_parsed_len_) { | ||||
|     // more data to read | ||||
|     size_t to_read = rx_header_parsed_len_ - rx_buf_len_; | ||||
|     ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read); | ||||
|     uint16_t to_read = rx_header_parsed_len_ - rx_buf_len_; | ||||
|     ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read); | ||||
|     if (received == -1) { | ||||
|       if (errno == EWOULDBLOCK || errno == EAGAIN) { | ||||
|         return APIError::WOULD_BLOCK; | ||||
| @@ -943,8 +943,8 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|       HELPER_LOG("Connection closed"); | ||||
|       return APIError::CONNECTION_CLOSED; | ||||
|     } | ||||
|     rx_buf_len_ += received; | ||||
|     if ((size_t) received != to_read) { | ||||
|     rx_buf_len_ += static_cast<uint16_t>(received); | ||||
|     if (static_cast<uint16_t>(received) != to_read) { | ||||
|       // not all read | ||||
|       return APIError::WOULD_BLOCK; | ||||
|     } | ||||
| @@ -962,7 +962,6 @@ APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) { | ||||
|   rx_header_parsed_ = false; | ||||
|   return APIError::OK; | ||||
| } | ||||
|  | ||||
| APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { | ||||
|   APIError aerr; | ||||
|  | ||||
| @@ -990,7 +989,7 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { | ||||
|                          "Bad indicator byte"; | ||||
|       iov[0].iov_base = (void *) msg; | ||||
|       iov[0].iov_len = 19; | ||||
|       write_raw_(iov, 1); | ||||
|       this->write_raw_(iov, 1); | ||||
|     } | ||||
|     return aerr; | ||||
|   } | ||||
| @@ -1001,7 +1000,6 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { | ||||
|   buffer->type = rx_header_parsed_type_; | ||||
|   return APIError::OK; | ||||
| } | ||||
| bool APIPlaintextFrameHelper::can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } | ||||
| APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) { | ||||
|   if (state_ != State::DATA) { | ||||
|     return APIError::BAD_STATE; | ||||
| @@ -1009,12 +1007,12 @@ APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWrit | ||||
|  | ||||
|   std::vector<uint8_t> *raw_buffer = buffer.get_buffer(); | ||||
|   // Message data starts after padding (frame_header_padding_ = 6) | ||||
|   size_t payload_len = raw_buffer->size() - frame_header_padding_; | ||||
|   uint16_t payload_len = static_cast<uint16_t>(raw_buffer->size() - frame_header_padding_); | ||||
|  | ||||
|   // Calculate varint sizes for header components | ||||
|   size_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(payload_len)); | ||||
|   size_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(type)); | ||||
|   size_t total_header_len = 1 + size_varint_len + type_varint_len; | ||||
|   uint8_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(payload_len)); | ||||
|   uint8_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(type)); | ||||
|   uint8_t total_header_len = 1 + size_varint_len + type_varint_len; | ||||
|  | ||||
|   if (total_header_len > frame_header_padding_) { | ||||
|     // Header is too large to fit in the padding | ||||
| @@ -1044,7 +1042,7 @@ APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWrit | ||||
|   // [4-5]  - Message type varint (2 bytes, for types 128-32767) | ||||
|   // [6...] - Actual payload data | ||||
|   uint8_t *buf_start = raw_buffer->data(); | ||||
|   size_t header_offset = frame_header_padding_ - total_header_len; | ||||
|   uint8_t header_offset = frame_header_padding_ - total_header_len; | ||||
|  | ||||
|   // Write the plaintext header | ||||
|   buf_start[header_offset] = 0x00;  // indicator | ||||
| @@ -1063,46 +1061,7 @@ APIError APIPlaintextFrameHelper::write_protobuf_packet(uint16_t type, ProtoWrit | ||||
|  | ||||
|   return write_raw_(&iov, 1); | ||||
| } | ||||
| APIError APIPlaintextFrameHelper::try_send_tx_buf_() { | ||||
|   // try send from tx_buf | ||||
|   while (state_ != State::CLOSED && !tx_buf_.empty()) { | ||||
|     ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size()); | ||||
|     if (is_would_block(sent)) { | ||||
|       break; | ||||
|     } else if (sent == -1) { | ||||
|       state_ = State::FAILED; | ||||
|       HELPER_LOG("Socket write failed with errno %d", errno); | ||||
|       return APIError::SOCKET_WRITE_FAILED; | ||||
|     } | ||||
|     // TODO: inefficient if multiple packets in txbuf | ||||
|     // replace with deque of buffers | ||||
|     tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent); | ||||
|   } | ||||
|  | ||||
|   return APIError::OK; | ||||
| } | ||||
|  | ||||
| APIError APIPlaintextFrameHelper::close() { | ||||
|   state_ = State::CLOSED; | ||||
|   int err = socket_->close(); | ||||
|   if (err == -1) | ||||
|     return APIError::CLOSE_FAILED; | ||||
|   return APIError::OK; | ||||
| } | ||||
| APIError APIPlaintextFrameHelper::shutdown(int how) { | ||||
|   int err = socket_->shutdown(how); | ||||
|   if (err == -1) | ||||
|     return APIError::SHUTDOWN_FAILED; | ||||
|   if (how == SHUT_RDWR) { | ||||
|     state_ = State::CLOSED; | ||||
|   } | ||||
|   return APIError::OK; | ||||
| } | ||||
|  | ||||
| // Explicit template instantiation for Plaintext | ||||
| template APIError APIFrameHelper::write_raw_<APIPlaintextFrameHelper::State>( | ||||
|     const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector<uint8_t> &tx_buf_, const std::string &info, | ||||
|     APIPlaintextFrameHelper::State &state, APIPlaintextFrameHelper::State failed_state); | ||||
| #endif  // USE_API_PLAINTEXT | ||||
|  | ||||
| }  // namespace api | ||||
|   | ||||
| @@ -21,15 +21,8 @@ class ProtoWriteBuffer; | ||||
| struct ReadPacketBuffer { | ||||
|   std::vector<uint8_t> container; | ||||
|   uint16_t type; | ||||
|   size_t data_offset; | ||||
|   size_t data_len; | ||||
| }; | ||||
|  | ||||
| struct PacketBuffer { | ||||
|   const std::vector<uint8_t> container; | ||||
|   uint16_t type; | ||||
|   uint8_t data_offset; | ||||
|   uint8_t data_len; | ||||
|   uint16_t data_offset; | ||||
|   uint16_t data_len; | ||||
| }; | ||||
|  | ||||
| enum class APIError : int { | ||||
| @@ -62,38 +55,117 @@ const char *api_error_to_str(APIError err); | ||||
|  | ||||
| class APIFrameHelper { | ||||
|  public: | ||||
|   APIFrameHelper() = default; | ||||
|   explicit APIFrameHelper(std::unique_ptr<socket::Socket> socket) : socket_owned_(std::move(socket)) { | ||||
|     socket_ = socket_owned_.get(); | ||||
|   } | ||||
|   virtual ~APIFrameHelper() = default; | ||||
|   virtual APIError init() = 0; | ||||
|   virtual APIError loop() = 0; | ||||
|   virtual APIError read_packet(ReadPacketBuffer *buffer) = 0; | ||||
|   virtual bool can_write_without_blocking() = 0; | ||||
|   virtual APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) = 0; | ||||
|   virtual std::string getpeername() = 0; | ||||
|   virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0; | ||||
|   virtual APIError close() = 0; | ||||
|   virtual APIError shutdown(int how) = 0; | ||||
|   bool can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } | ||||
|   std::string getpeername() { return socket_->getpeername(); } | ||||
|   int getpeername(struct sockaddr *addr, socklen_t *addrlen) { return socket_->getpeername(addr, addrlen); } | ||||
|   APIError close() { | ||||
|     state_ = State::CLOSED; | ||||
|     int err = this->socket_->close(); | ||||
|     if (err == -1) | ||||
|       return APIError::CLOSE_FAILED; | ||||
|     return APIError::OK; | ||||
|   } | ||||
|   APIError shutdown(int how) { | ||||
|     int err = this->socket_->shutdown(how); | ||||
|     if (err == -1) | ||||
|       return APIError::SHUTDOWN_FAILED; | ||||
|     if (how == SHUT_RDWR) { | ||||
|       state_ = State::CLOSED; | ||||
|     } | ||||
|     return APIError::OK; | ||||
|   } | ||||
|   // Give this helper a name for logging | ||||
|   virtual void set_log_info(std::string info) = 0; | ||||
|   void set_log_info(std::string info) { info_ = std::move(info); } | ||||
|   virtual APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) = 0; | ||||
|   // Get the frame header padding required by this protocol | ||||
|   virtual uint8_t frame_header_padding() = 0; | ||||
|   // Get the frame footer size required by this protocol | ||||
|   virtual uint8_t frame_footer_size() = 0; | ||||
|  | ||||
|  protected: | ||||
|   // Struct for holding parsed frame data | ||||
|   struct ParsedFrame { | ||||
|     std::vector<uint8_t> msg; | ||||
|   }; | ||||
|  | ||||
|   // Buffer containing data to be sent | ||||
|   struct SendBuffer { | ||||
|     std::vector<uint8_t> data; | ||||
|     uint16_t offset{0};  // Current offset within the buffer (uint16_t to reduce memory usage) | ||||
|  | ||||
|     // Using uint16_t reduces memory usage since ESPHome API messages are limited to 64KB max | ||||
|     uint16_t remaining() const { return static_cast<uint16_t>(data.size()) - offset; } | ||||
|     const uint8_t *current_data() const { return data.data() + offset; } | ||||
|   }; | ||||
|  | ||||
|   // Queue of data buffers to be sent | ||||
|   std::deque<SendBuffer> tx_buf_; | ||||
|  | ||||
|   // Common state enum for all frame helpers | ||||
|   // Note: Not all states are used by all implementations | ||||
|   // - INITIALIZE: Used by both Noise and Plaintext | ||||
|   // - CLIENT_HELLO, SERVER_HELLO, HANDSHAKE: Only used by Noise protocol | ||||
|   // - DATA: Used by both Noise and Plaintext | ||||
|   // - CLOSED: Used by both Noise and Plaintext | ||||
|   // - FAILED: Used by both Noise and Plaintext | ||||
|   // - EXPLICIT_REJECT: Only used by Noise protocol | ||||
|   enum class State { | ||||
|     INITIALIZE = 1, | ||||
|     CLIENT_HELLO = 2,  // Noise only | ||||
|     SERVER_HELLO = 3,  // Noise only | ||||
|     HANDSHAKE = 4,     // Noise only | ||||
|     DATA = 5, | ||||
|     CLOSED = 6, | ||||
|     FAILED = 7, | ||||
|     EXPLICIT_REJECT = 8,  // Noise only | ||||
|   }; | ||||
|  | ||||
|   // Current state of the frame helper | ||||
|   State state_{State::INITIALIZE}; | ||||
|  | ||||
|   // Helper name for logging | ||||
|   std::string info_; | ||||
|  | ||||
|   // Socket for communication | ||||
|   socket::Socket *socket_{nullptr}; | ||||
|   std::unique_ptr<socket::Socket> socket_owned_; | ||||
|  | ||||
|   // Common implementation for writing raw data to socket | ||||
|   APIError write_raw_(const struct iovec *iov, int iovcnt); | ||||
|  | ||||
|   // Try to send data from the tx buffer | ||||
|   APIError try_send_tx_buf_(); | ||||
|  | ||||
|   // Helper method to buffer data from IOVs | ||||
|   void buffer_data_from_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len); | ||||
|   template<typename StateEnum> | ||||
|   APIError write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector<uint8_t> &tx_buf, | ||||
|                       const std::string &info, StateEnum &state, StateEnum failed_state); | ||||
|  | ||||
|   uint8_t frame_header_padding_{0}; | ||||
|   uint8_t frame_footer_size_{0}; | ||||
|  | ||||
|   // Receive buffer for reading frame data | ||||
|   std::vector<uint8_t> rx_buf_; | ||||
|   uint16_t rx_buf_len_ = 0; | ||||
|  | ||||
|   // Common initialization for both plaintext and noise protocols | ||||
|   APIError init_common_(); | ||||
| }; | ||||
|  | ||||
| #ifdef USE_API_NOISE | ||||
| class APINoiseFrameHelper : public APIFrameHelper { | ||||
|  public: | ||||
|   APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, std::shared_ptr<APINoiseContext> ctx) | ||||
|       : socket_(std::move(socket)), ctx_(std::move(ctx)) { | ||||
|       : APIFrameHelper(std::move(socket)), ctx_(std::move(ctx)) { | ||||
|     // Noise header structure: | ||||
|     // Pos 0: indicator (0x01) | ||||
|     // Pos 1-2: encrypted payload size (16-bit big-endian) | ||||
| @@ -105,49 +177,25 @@ class APINoiseFrameHelper : public APIFrameHelper { | ||||
|   APIError init() override; | ||||
|   APIError loop() override; | ||||
|   APIError read_packet(ReadPacketBuffer *buffer) override; | ||||
|   bool can_write_without_blocking() override; | ||||
|   APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override; | ||||
|   std::string getpeername() override { return this->socket_->getpeername(); } | ||||
|   int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { | ||||
|     return this->socket_->getpeername(addr, addrlen); | ||||
|   } | ||||
|   APIError close() override; | ||||
|   APIError shutdown(int how) override; | ||||
|   // Give this helper a name for logging | ||||
|   void set_log_info(std::string info) override { info_ = std::move(info); } | ||||
|   // Get the frame header padding required by this protocol | ||||
|   uint8_t frame_header_padding() override { return frame_header_padding_; } | ||||
|   // Get the frame footer size required by this protocol | ||||
|   uint8_t frame_footer_size() override { return frame_footer_size_; } | ||||
|  | ||||
|  protected: | ||||
|   struct ParsedFrame { | ||||
|     std::vector<uint8_t> msg; | ||||
|   }; | ||||
|  | ||||
|   APIError state_action_(); | ||||
|   APIError try_read_frame_(ParsedFrame *frame); | ||||
|   APIError try_send_tx_buf_(); | ||||
|   APIError write_frame_(const uint8_t *data, size_t len); | ||||
|   inline APIError write_raw_(const struct iovec *iov, int iovcnt) { | ||||
|     return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); | ||||
|   } | ||||
|   APIError write_frame_(const uint8_t *data, uint16_t len); | ||||
|   APIError init_handshake_(); | ||||
|   APIError check_handshake_finished_(); | ||||
|   void send_explicit_handshake_reject_(const std::string &reason); | ||||
|  | ||||
|   std::unique_ptr<socket::Socket> socket_; | ||||
|  | ||||
|   std::string info_; | ||||
|   // Fixed-size header buffer for noise protocol: | ||||
|   // 1 byte for indicator + 2 bytes for message size (16-bit value, not varint) | ||||
|   // Note: Maximum message size is 65535, with a limit of 128 bytes during handshake phase | ||||
|   uint8_t rx_header_buf_[3]; | ||||
|   size_t rx_header_buf_len_ = 0; | ||||
|   std::vector<uint8_t> rx_buf_; | ||||
|   size_t rx_buf_len_ = 0; | ||||
|   uint8_t rx_header_buf_len_ = 0; | ||||
|  | ||||
|   std::vector<uint8_t> tx_buf_; | ||||
|   std::vector<uint8_t> prologue_; | ||||
|  | ||||
|   std::shared_ptr<APINoiseContext> ctx_; | ||||
| @@ -155,24 +203,13 @@ class APINoiseFrameHelper : public APIFrameHelper { | ||||
|   NoiseCipherState *send_cipher_{nullptr}; | ||||
|   NoiseCipherState *recv_cipher_{nullptr}; | ||||
|   NoiseProtocolId nid_; | ||||
|  | ||||
|   enum class State { | ||||
|     INITIALIZE = 1, | ||||
|     CLIENT_HELLO = 2, | ||||
|     SERVER_HELLO = 3, | ||||
|     HANDSHAKE = 4, | ||||
|     DATA = 5, | ||||
|     CLOSED = 6, | ||||
|     FAILED = 7, | ||||
|     EXPLICIT_REJECT = 8, | ||||
|   } state_ = State::INITIALIZE; | ||||
| }; | ||||
| #endif  // USE_API_NOISE | ||||
|  | ||||
| #ifdef USE_API_PLAINTEXT | ||||
| class APIPlaintextFrameHelper : public APIFrameHelper { | ||||
|  public: | ||||
|   APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) { | ||||
|   APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket) : APIFrameHelper(std::move(socket)) { | ||||
|     // Plaintext header structure (worst case): | ||||
|     // Pos 0: indicator (0x00) | ||||
|     // Pos 1-3: payload size varint (up to 3 bytes) | ||||
| @@ -184,35 +221,13 @@ class APIPlaintextFrameHelper : public APIFrameHelper { | ||||
|   APIError init() override; | ||||
|   APIError loop() override; | ||||
|   APIError read_packet(ReadPacketBuffer *buffer) override; | ||||
|   bool can_write_without_blocking() override; | ||||
|   APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override; | ||||
|   std::string getpeername() override { return this->socket_->getpeername(); } | ||||
|   int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { | ||||
|     return this->socket_->getpeername(addr, addrlen); | ||||
|   } | ||||
|   APIError close() override; | ||||
|   APIError shutdown(int how) override; | ||||
|   // Give this helper a name for logging | ||||
|   void set_log_info(std::string info) override { info_ = std::move(info); } | ||||
|   // Get the frame header padding required by this protocol | ||||
|   uint8_t frame_header_padding() override { return frame_header_padding_; } | ||||
|   // Get the frame footer size required by this protocol | ||||
|   uint8_t frame_footer_size() override { return frame_footer_size_; } | ||||
|  | ||||
|  protected: | ||||
|   struct ParsedFrame { | ||||
|     std::vector<uint8_t> msg; | ||||
|   }; | ||||
|  | ||||
|   APIError try_read_frame_(ParsedFrame *frame); | ||||
|   APIError try_send_tx_buf_(); | ||||
|   inline APIError write_raw_(const struct iovec *iov, int iovcnt) { | ||||
|     return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); | ||||
|   } | ||||
|  | ||||
|   std::unique_ptr<socket::Socket> socket_; | ||||
|  | ||||
|   std::string info_; | ||||
|   // Fixed-size header buffer for plaintext protocol: | ||||
|   // We only need space for the two varints since we validate the indicator byte separately. | ||||
|   // To match noise protocol's maximum message size (65535), we need: | ||||
| @@ -224,20 +239,8 @@ class APIPlaintextFrameHelper : public APIFrameHelper { | ||||
|   uint8_t rx_header_buf_[5];  // 5 bytes for varints (3 for size + 2 for type) | ||||
|   uint8_t rx_header_buf_pos_ = 0; | ||||
|   bool rx_header_parsed_ = false; | ||||
|   uint32_t rx_header_parsed_type_ = 0; | ||||
|   uint32_t rx_header_parsed_len_ = 0; | ||||
|  | ||||
|   std::vector<uint8_t> rx_buf_; | ||||
|   size_t rx_buf_len_ = 0; | ||||
|  | ||||
|   std::vector<uint8_t> tx_buf_; | ||||
|  | ||||
|   enum class State { | ||||
|     INITIALIZE = 1, | ||||
|     DATA = 2, | ||||
|     CLOSED = 3, | ||||
|     FAILED = 4, | ||||
|   } state_ = State::INITIALIZE; | ||||
|   uint16_t rx_header_parsed_type_ = 0; | ||||
|   uint16_t rx_header_parsed_len_ = 0; | ||||
| }; | ||||
| #endif | ||||
|  | ||||
|   | ||||
| @@ -55,6 +55,7 @@ class ProtoVarInt { | ||||
|     return {};  // Incomplete or invalid varint | ||||
|   } | ||||
|  | ||||
|   uint16_t as_uint16() const { return this->value_; } | ||||
|   uint32_t as_uint32() const { return this->value_; } | ||||
|   uint64_t as_uint64() const { return this->value_; } | ||||
|   bool as_bool() const { return this->value_; } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user