mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	Consolidate write_raw_ implementation to reduce code duplication (#8717)
This commit is contained in:
		| @@ -73,6 +73,91 @@ const char *api_error_to_str(APIError err) { | ||||
|   return "UNKNOWN"; | ||||
| } | ||||
|  | ||||
| // Common implementation for writing raw data to socket | ||||
| template<typename StateEnum> | ||||
| APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, | ||||
|                                     std::vector<uint8_t> &tx_buf, const std::string &info, StateEnum &state, | ||||
|                                     StateEnum failed_state) { | ||||
|   // This method writes data to socket or buffers it | ||||
|   // Returns APIError::OK if successful (or would block, but data has been buffered) | ||||
|   // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to failed_state | ||||
|  | ||||
|   if (iovcnt == 0) | ||||
|     return APIError::OK;  // Nothing to do, success | ||||
|  | ||||
|   size_t total_write_len = 0; | ||||
|   for (int i = 0; i < iovcnt; i++) { | ||||
| #ifdef HELPER_LOG_PACKETS | ||||
|     ESP_LOGVV(TAG, "Sending raw: %s", | ||||
|               format_hex_pretty(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len).c_str()); | ||||
| #endif | ||||
|     total_write_len += iov[i].iov_len; | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf.empty()) { | ||||
|     // try to empty tx_buf first | ||||
|     while (!tx_buf.empty()) { | ||||
|       ssize_t sent = socket->write(tx_buf.data(), tx_buf.size()); | ||||
|       if (is_would_block(sent)) { | ||||
|         break; | ||||
|       } else if (sent == -1) { | ||||
|         ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); | ||||
|         state = failed_state; | ||||
|         return APIError::SOCKET_WRITE_FAILED;  // Socket write failed | ||||
|       } | ||||
|       // TODO: inefficient if multiple packets in txbuf | ||||
|       // replace with deque of buffers | ||||
|       tx_buf.erase(tx_buf.begin(), tx_buf.begin() + sent); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf.empty()) { | ||||
|     // tx buf not empty, can't write now because then stream would be inconsistent | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf.reserve(tx_buf.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf.insert(tx_buf.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                     reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|     } | ||||
|     return APIError::OK;  // Success, data buffered | ||||
|   } | ||||
|  | ||||
|   ssize_t sent = socket->writev(iov, iovcnt); | ||||
|   if (is_would_block(sent)) { | ||||
|     // operation would block, add buffer to tx_buf | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf.reserve(tx_buf.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf.insert(tx_buf.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                     reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|     } | ||||
|     return APIError::OK;  // Success, data buffered | ||||
|   } else if (sent == -1) { | ||||
|     // an error occurred | ||||
|     ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); | ||||
|     state = failed_state; | ||||
|     return APIError::SOCKET_WRITE_FAILED;  // Socket write failed | ||||
|   } else if ((size_t) sent != total_write_len) { | ||||
|     // partially sent, add end to tx_buf | ||||
|     size_t remaining = total_write_len - sent; | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf.reserve(tx_buf.size() + remaining); | ||||
|  | ||||
|     size_t to_consume = sent; | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       if (to_consume >= iov[i].iov_len) { | ||||
|         to_consume -= iov[i].iov_len; | ||||
|       } else { | ||||
|         tx_buf.insert(tx_buf.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base) + to_consume, | ||||
|                       reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|         to_consume = 0; | ||||
|       } | ||||
|     } | ||||
|     return APIError::OK;  // Success, data buffered | ||||
|   } | ||||
|   return APIError::OK;  // Success, all data sent | ||||
| } | ||||
|  | ||||
| #define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) | ||||
| // uncomment to log raw packets | ||||
| //#define HELPER_LOG_PACKETS | ||||
| @@ -547,79 +632,6 @@ APIError APINoiseFrameHelper::try_send_tx_buf_() { | ||||
|  | ||||
|   return APIError::OK; | ||||
| } | ||||
| /** Write the data to the socket, or buffer it a write would block | ||||
|  * | ||||
|  * @param data The data to write | ||||
|  * @param len The length of data | ||||
|  */ | ||||
| APIError APINoiseFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { | ||||
|   if (iovcnt == 0) | ||||
|     return APIError::OK; | ||||
|   APIError aerr; | ||||
|  | ||||
|   size_t total_write_len = 0; | ||||
|   for (int i = 0; i < iovcnt; i++) { | ||||
| #ifdef HELPER_LOG_PACKETS | ||||
|     ESP_LOGVV(TAG, "Sending raw: %s", | ||||
|               format_hex_pretty(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len).c_str()); | ||||
| #endif | ||||
|     total_write_len += iov[i].iov_len; | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf_.empty()) { | ||||
|     // try to empty tx_buf_ first | ||||
|     aerr = try_send_tx_buf_(); | ||||
|     if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) | ||||
|       return aerr; | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf_.empty()) { | ||||
|     // tx buf not empty, can't write now because then stream would be inconsistent | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf_.reserve(tx_buf_.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf_.insert(tx_buf_.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                      reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|     } | ||||
|     return APIError::OK; | ||||
|   } | ||||
|  | ||||
|   ssize_t sent = socket_->writev(iov, iovcnt); | ||||
|   if (is_would_block(sent)) { | ||||
|     // operation would block, add buffer to tx_buf | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf_.reserve(tx_buf_.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf_.insert(tx_buf_.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                      reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|     } | ||||
|     return APIError::OK; | ||||
|   } else if (sent == -1) { | ||||
|     // an error occurred | ||||
|     state_ = State::FAILED; | ||||
|     HELPER_LOG("Socket write failed with errno %d", errno); | ||||
|     return APIError::SOCKET_WRITE_FAILED; | ||||
|   } else if ((size_t) sent != total_write_len) { | ||||
|     // partially sent, add end to tx_buf | ||||
|     size_t remaining = total_write_len - sent; | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf_.reserve(tx_buf_.size() + remaining); | ||||
|  | ||||
|     size_t to_consume = sent; | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       if (to_consume >= iov[i].iov_len) { | ||||
|         to_consume -= iov[i].iov_len; | ||||
|       } else { | ||||
|         tx_buf_.insert(tx_buf_.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base) + to_consume, | ||||
|                        reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|         to_consume = 0; | ||||
|       } | ||||
|     } | ||||
|     return APIError::OK; | ||||
|   } | ||||
|   // fully sent | ||||
|   return APIError::OK; | ||||
| } | ||||
| APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { | ||||
|   uint8_t header[3]; | ||||
|   header[0] = 0x01;  // indicator | ||||
| @@ -753,6 +765,11 @@ void noise_rand_bytes(void *output, size_t len) { | ||||
|   } | ||||
| } | ||||
| } | ||||
|  | ||||
| // Explicit template instantiation for Noise | ||||
| template APIError APIFrameHelper::write_raw_<APINoiseFrameHelper::State>( | ||||
|     const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector<uint8_t> &tx_buf_, const std::string &info, | ||||
|     APINoiseFrameHelper::State &state, APINoiseFrameHelper::State failed_state); | ||||
| #endif  // USE_API_NOISE | ||||
|  | ||||
| #ifdef USE_API_PLAINTEXT | ||||
| @@ -977,79 +994,6 @@ APIError APIPlaintextFrameHelper::try_send_tx_buf_() { | ||||
|  | ||||
|   return APIError::OK; | ||||
| } | ||||
| /** Write the data to the socket, or buffer it a write would block | ||||
|  * | ||||
|  * @param data The data to write | ||||
|  * @param len The length of data | ||||
|  */ | ||||
| APIError APIPlaintextFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { | ||||
|   if (iovcnt == 0) | ||||
|     return APIError::OK; | ||||
|   APIError aerr; | ||||
|  | ||||
|   size_t total_write_len = 0; | ||||
|   for (int i = 0; i < iovcnt; i++) { | ||||
| #ifdef HELPER_LOG_PACKETS | ||||
|     ESP_LOGVV(TAG, "Sending raw: %s", | ||||
|               format_hex_pretty(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len).c_str()); | ||||
| #endif | ||||
|     total_write_len += iov[i].iov_len; | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf_.empty()) { | ||||
|     // try to empty tx_buf_ first | ||||
|     aerr = try_send_tx_buf_(); | ||||
|     if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) | ||||
|       return aerr; | ||||
|   } | ||||
|  | ||||
|   if (!tx_buf_.empty()) { | ||||
|     // tx buf not empty, can't write now because then stream would be inconsistent | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf_.reserve(tx_buf_.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf_.insert(tx_buf_.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                      reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|     } | ||||
|     return APIError::OK; | ||||
|   } | ||||
|  | ||||
|   ssize_t sent = socket_->writev(iov, iovcnt); | ||||
|   if (is_would_block(sent)) { | ||||
|     // operation would block, add buffer to tx_buf | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf_.reserve(tx_buf_.size() + total_write_len); | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       tx_buf_.insert(tx_buf_.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base), | ||||
|                      reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|     } | ||||
|     return APIError::OK; | ||||
|   } else if (sent == -1) { | ||||
|     // an error occurred | ||||
|     state_ = State::FAILED; | ||||
|     HELPER_LOG("Socket write failed with errno %d", errno); | ||||
|     return APIError::SOCKET_WRITE_FAILED; | ||||
|   } else if ((size_t) sent != total_write_len) { | ||||
|     // partially sent, add end to tx_buf | ||||
|     size_t remaining = total_write_len - sent; | ||||
|     // Reserve space upfront to avoid multiple reallocations | ||||
|     tx_buf_.reserve(tx_buf_.size() + remaining); | ||||
|  | ||||
|     size_t to_consume = sent; | ||||
|     for (int i = 0; i < iovcnt; i++) { | ||||
|       if (to_consume >= iov[i].iov_len) { | ||||
|         to_consume -= iov[i].iov_len; | ||||
|       } else { | ||||
|         tx_buf_.insert(tx_buf_.end(), reinterpret_cast<uint8_t *>(iov[i].iov_base) + to_consume, | ||||
|                        reinterpret_cast<uint8_t *>(iov[i].iov_base) + iov[i].iov_len); | ||||
|         to_consume = 0; | ||||
|       } | ||||
|     } | ||||
|     return APIError::OK; | ||||
|   } | ||||
|   // fully sent | ||||
|   return APIError::OK; | ||||
| } | ||||
|  | ||||
| APIError APIPlaintextFrameHelper::close() { | ||||
|   state_ = State::CLOSED; | ||||
| @@ -1067,6 +1011,11 @@ APIError APIPlaintextFrameHelper::shutdown(int how) { | ||||
|   } | ||||
|   return APIError::OK; | ||||
| } | ||||
|  | ||||
| // Explicit template instantiation for Plaintext | ||||
| template APIError APIFrameHelper::write_raw_<APIPlaintextFrameHelper::State>( | ||||
|     const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector<uint8_t> &tx_buf_, const std::string &info, | ||||
|     APIPlaintextFrameHelper::State &state, APIPlaintextFrameHelper::State failed_state); | ||||
| #endif  // USE_API_PLAINTEXT | ||||
|  | ||||
| }  // namespace api | ||||
|   | ||||
| @@ -72,6 +72,12 @@ class APIFrameHelper { | ||||
|   virtual APIError shutdown(int how) = 0; | ||||
|   // Give this helper a name for logging | ||||
|   virtual void set_log_info(std::string info) = 0; | ||||
|  | ||||
|  protected: | ||||
|   // Common implementation for writing raw data to socket | ||||
|   template<typename StateEnum> | ||||
|   APIError write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector<uint8_t> &tx_buf, | ||||
|                       const std::string &info, StateEnum &state, StateEnum failed_state); | ||||
| }; | ||||
|  | ||||
| #ifdef USE_API_NOISE | ||||
| @@ -103,7 +109,9 @@ class APINoiseFrameHelper : public APIFrameHelper { | ||||
|   APIError try_read_frame_(ParsedFrame *frame); | ||||
|   APIError try_send_tx_buf_(); | ||||
|   APIError write_frame_(const uint8_t *data, size_t len); | ||||
|   APIError write_raw_(const struct iovec *iov, int iovcnt); | ||||
|   inline APIError write_raw_(const struct iovec *iov, int iovcnt) { | ||||
|     return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); | ||||
|   } | ||||
|   APIError init_handshake_(); | ||||
|   APIError check_handshake_finished_(); | ||||
|   void send_explicit_handshake_reject_(const std::string &reason); | ||||
| @@ -164,7 +172,9 @@ class APIPlaintextFrameHelper : public APIFrameHelper { | ||||
|  | ||||
|   APIError try_read_frame_(ParsedFrame *frame); | ||||
|   APIError try_send_tx_buf_(); | ||||
|   APIError write_raw_(const struct iovec *iov, int iovcnt); | ||||
|   inline APIError write_raw_(const struct iovec *iov, int iovcnt) { | ||||
|     return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); | ||||
|   } | ||||
|  | ||||
|   std::unique_ptr<socket::Socket> socket_; | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user