diff --git a/esphome/components/improv_serial/improv_serial_component.cpp b/esphome/components/improv_serial/improv_serial_component.cpp index ce82504d3c..9d080ea98e 100644 --- a/esphome/components/improv_serial/improv_serial_component.cpp +++ b/esphome/components/improv_serial/improv_serial_component.cpp @@ -28,6 +28,38 @@ void ImprovSerialComponent::setup() { } } +void ImprovSerialComponent::loop() { + if (this->last_read_byte_ && (millis() - this->last_read_byte_ > IMPROV_SERIAL_TIMEOUT)) { + this->last_read_byte_ = 0; + this->rx_buffer_.clear(); + ESP_LOGV(TAG, "Timeout"); + } + + auto byte = this->read_byte_(); + while (byte.has_value()) { + if (this->parse_improv_serial_byte_(byte.value())) { + this->last_read_byte_ = millis(); + } else { + this->last_read_byte_ = 0; + this->rx_buffer_.clear(); + } + byte = this->read_byte_(); + } + + if (this->state_ == improv::STATE_PROVISIONING) { + if (wifi::global_wifi_component->is_connected()) { + wifi::global_wifi_component->save_wifi_sta(this->connecting_sta_.get_ssid(), + this->connecting_sta_.get_password()); + this->connecting_sta_ = {}; + this->cancel_timeout("wifi-connect-timeout"); + this->set_state_(improv::STATE_PROVISIONED); + + std::vector url = this->build_rpc_settings_response_(improv::WIFI_SETTINGS); + this->send_response_(url); + } + } +} + void ImprovSerialComponent::dump_config() { ESP_LOGCONFIG(TAG, "Improv Serial:"); } optional ImprovSerialComponent::read_byte_() { @@ -78,8 +110,28 @@ optional ImprovSerialComponent::read_byte_() { return byte; } -void ImprovSerialComponent::write_data_(std::vector &data) { - data.push_back('\n'); +void ImprovSerialComponent::write_data_(const uint8_t *data, const size_t size) { + // First, set length field + this->tx_header_[TX_LENGTH_IDX] = this->tx_header_[TX_TYPE_IDX] == TYPE_RPC_RESPONSE ? size : 1; + + const bool there_is_data = data != nullptr && size > 0; + // If there_is_data, checksum must not include our optional data byte + const uint8_t header_checksum_len = there_is_data ? TX_BUFFER_SIZE - 3 : TX_BUFFER_SIZE - 2; + // Only transmit the full buffer length if there is no data (only state/error byte is provided in this case) + const uint8_t header_tx_len = there_is_data ? TX_BUFFER_SIZE - 3 : TX_BUFFER_SIZE; + // Calculate checksum for message + uint8_t checksum = 0; + for (uint8_t i = 0; i < header_checksum_len; i++) { + checksum += this->tx_header_[i]; + } + if (there_is_data) { + // Include data in checksum + for (size_t i = 0; i < size; i++) { + checksum += data[i]; + } + } + this->tx_header_[TX_CHECKSUM_IDX] = checksum; + #ifdef USE_ESP32 switch (logger::global_logger->get_uart()) { case logger::UART_SELECTION_UART0: @@ -87,63 +139,45 @@ void ImprovSerialComponent::write_data_(std::vector &data) { #if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) && \ !defined(USE_ESP32_VARIANT_ESP32S2) && !defined(USE_ESP32_VARIANT_ESP32S3) case logger::UART_SELECTION_UART2: -#endif // !USE_ESP32_VARIANT_ESP32C3 && !USE_ESP32_VARIANT_ESP32S2 && !USE_ESP32_VARIANT_ESP32S3 - uart_write_bytes(this->uart_num_, data.data(), data.size()); +#endif + uart_write_bytes(this->uart_num_, this->tx_header_, header_tx_len); + if (there_is_data) { + uart_write_bytes(this->uart_num_, data, size); + uart_write_bytes(this->uart_num_, &this->tx_header_[TX_CHECKSUM_IDX], 2); // Footer: checksum and newline + } break; #if defined(USE_LOGGER_USB_CDC) && defined(CONFIG_ESP_CONSOLE_USB_CDC) - case logger::UART_SELECTION_USB_CDC: { - const char *msg = (char *) data.data(); - esp_usb_console_write_buf(msg, data.size()); + case logger::UART_SELECTION_USB_CDC: + esp_usb_console_write_buf((const char *) this->tx_header_, header_tx_len); + if (there_is_data) { + esp_usb_console_write_buf((const char *) data, size); + esp_usb_console_write_buf((const char *) &this->tx_header_[TX_CHECKSUM_IDX], + 2); // Footer: checksum and newline + } break; - } -#endif // USE_LOGGER_USB_CDC +#endif #ifdef USE_LOGGER_USB_SERIAL_JTAG case logger::UART_SELECTION_USB_SERIAL_JTAG: - usb_serial_jtag_write_bytes((char *) data.data(), data.size(), 20 / portTICK_PERIOD_MS); - delay(10); - usb_serial_jtag_ll_txfifo_flush(); // fixes for issue in IDF 4.4.7 + usb_serial_jtag_write_bytes((const char *) this->tx_header_, header_tx_len, 20 / portTICK_PERIOD_MS); + if (there_is_data) { + usb_serial_jtag_write_bytes((const char *) data, size, 20 / portTICK_PERIOD_MS); + usb_serial_jtag_write_bytes((const char *) &this->tx_header_[TX_CHECKSUM_IDX], 2, + 20 / portTICK_PERIOD_MS); // Footer: checksum and newline + } break; -#endif // USE_LOGGER_USB_SERIAL_JTAG +#endif default: break; } #elif defined(USE_ARDUINO) - this->hw_serial_->write(data.data(), data.size()); + this->hw_serial_->write(this->tx_header_, header_tx_len); + if (there_is_data) { + this->hw_serial_->write(data, size); + this->hw_serial_->write(&this->tx_header_[TX_CHECKSUM_IDX], 2); // Footer: checksum and newline + } #endif } -void ImprovSerialComponent::loop() { - if (this->last_read_byte_ && (millis() - this->last_read_byte_ > IMPROV_SERIAL_TIMEOUT)) { - this->last_read_byte_ = 0; - this->rx_buffer_.clear(); - ESP_LOGV(TAG, "Improv Serial timeout"); - } - - auto byte = this->read_byte_(); - while (byte.has_value()) { - if (this->parse_improv_serial_byte_(byte.value())) { - this->last_read_byte_ = millis(); - } else { - this->last_read_byte_ = 0; - this->rx_buffer_.clear(); - } - byte = this->read_byte_(); - } - - if (this->state_ == improv::STATE_PROVISIONING) { - if (wifi::global_wifi_component->is_connected()) { - wifi::global_wifi_component->save_wifi_sta(this->connecting_sta_.get_ssid(), - this->connecting_sta_.get_password()); - this->connecting_sta_ = {}; - this->cancel_timeout("wifi-connect-timeout"); - this->set_state_(improv::STATE_PROVISIONED); - - std::vector url = this->build_rpc_settings_response_(improv::WIFI_SETTINGS); - this->send_response_(url); - } - } -} - std::vector ImprovSerialComponent::build_rpc_settings_response_(improv::Command command) { std::vector urls; #ifdef USE_IMPROV_SERIAL_NEXT_URL @@ -177,13 +211,13 @@ std::vector ImprovSerialComponent::build_version_info_() { bool ImprovSerialComponent::parse_improv_serial_byte_(uint8_t byte) { size_t at = this->rx_buffer_.size(); this->rx_buffer_.push_back(byte); - ESP_LOGV(TAG, "Improv Serial byte: 0x%02X", byte); + ESP_LOGV(TAG, "Byte: 0x%02X", byte); const uint8_t *raw = &this->rx_buffer_[0]; return improv::parse_improv_serial_byte( at, byte, raw, [this](improv::ImprovCommand command) -> bool { return this->parse_improv_payload_(command); }, [this](improv::Error error) -> void { - ESP_LOGW(TAG, "Error decoding Improv payload"); + ESP_LOGW(TAG, "Error decoding payload"); this->set_error_(error); }); } @@ -199,7 +233,7 @@ bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command wifi::global_wifi_component->set_sta(sta); wifi::global_wifi_component->start_connecting(sta, false); this->set_state_(improv::STATE_PROVISIONING); - ESP_LOGD(TAG, "Received Improv wifi settings ssid=%s, password=" LOG_SECRET("%s"), command.ssid.c_str(), + ESP_LOGD(TAG, "Received settings: SSID=%s, password=" LOG_SECRET("%s"), command.ssid.c_str(), command.password.c_str()); auto f = std::bind(&ImprovSerialComponent::on_wifi_connect_timeout_, this); @@ -240,7 +274,7 @@ bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command return true; } default: { - ESP_LOGW(TAG, "Unknown Improv payload"); + ESP_LOGW(TAG, "Unknown payload"); this->set_error_(improv::ERROR_UNKNOWN_RPC); return false; } @@ -249,57 +283,26 @@ bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command void ImprovSerialComponent::set_state_(improv::State state) { this->state_ = state; - - std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; - data.resize(11); - data[6] = IMPROV_SERIAL_VERSION; - data[7] = TYPE_CURRENT_STATE; - data[8] = 1; - data[9] = state; - - uint8_t checksum = 0x00; - for (uint8_t d : data) - checksum += d; - data[10] = checksum; - - this->write_data_(data); + this->tx_header_[TX_TYPE_IDX] = TYPE_CURRENT_STATE; + this->tx_header_[TX_DATA_IDX] = state; + this->write_data_(); } void ImprovSerialComponent::set_error_(improv::Error error) { - std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; - data.resize(11); - data[6] = IMPROV_SERIAL_VERSION; - data[7] = TYPE_ERROR_STATE; - data[8] = 1; - data[9] = error; - - uint8_t checksum = 0x00; - for (uint8_t d : data) - checksum += d; - data[10] = checksum; - this->write_data_(data); + this->tx_header_[TX_TYPE_IDX] = TYPE_ERROR_STATE; + this->tx_header_[TX_DATA_IDX] = error; + this->write_data_(); } void ImprovSerialComponent::send_response_(std::vector &response) { - std::vector data = {'I', 'M', 'P', 'R', 'O', 'V'}; - data.resize(9); - data[6] = IMPROV_SERIAL_VERSION; - data[7] = TYPE_RPC_RESPONSE; - data[8] = response.size(); - data.insert(data.end(), response.begin(), response.end()); - - uint8_t checksum = 0x00; - for (uint8_t d : data) - checksum += d; - data.push_back(checksum); - - this->write_data_(data); + this->tx_header_[TX_TYPE_IDX] = TYPE_RPC_RESPONSE; + this->write_data_(response.data(), response.size()); } void ImprovSerialComponent::on_wifi_connect_timeout_() { this->set_error_(improv::ERROR_UNABLE_TO_CONNECT); this->set_state_(improv::STATE_AUTHORIZED); - ESP_LOGW(TAG, "Timed out trying to connect to given WiFi network"); + ESP_LOGW(TAG, "Timed out while connecting to Wi-Fi network"); wifi::global_wifi_component->clear_sta(); } diff --git a/esphome/components/improv_serial/improv_serial_component.h b/esphome/components/improv_serial/improv_serial_component.h index c3c9aee24e..057247f376 100644 --- a/esphome/components/improv_serial/improv_serial_component.h +++ b/esphome/components/improv_serial/improv_serial_component.h @@ -26,6 +26,16 @@ namespace esphome { namespace improv_serial { +// TX buffer layout constants +static constexpr uint8_t TX_HEADER_SIZE = 6; // Bytes 0-5 = "IMPROV" +static constexpr uint8_t TX_VERSION_IDX = 6; +static constexpr uint8_t TX_TYPE_IDX = 7; +static constexpr uint8_t TX_LENGTH_IDX = 8; +static constexpr uint8_t TX_DATA_IDX = 9; // For state/error messages only +static constexpr uint8_t TX_CHECKSUM_IDX = 10; +static constexpr uint8_t TX_NEWLINE_IDX = 11; +static constexpr uint8_t TX_BUFFER_SIZE = 12; + enum ImprovSerialType : uint8_t { TYPE_CURRENT_STATE = 0x01, TYPE_ERROR_STATE = 0x02, @@ -57,7 +67,22 @@ class ImprovSerialComponent : public Component, public improv_base::ImprovBase { std::vector build_version_info_(); optional read_byte_(); - void write_data_(std::vector &data); + void write_data_(const uint8_t *data = nullptr, size_t size = 0); + + uint8_t tx_header_[TX_BUFFER_SIZE] = { + 'I', // 0: Header + 'M', // 1: Header + 'P', // 2: Header + 'R', // 3: Header + 'O', // 4: Header + 'V', // 5: Header + IMPROV_SERIAL_VERSION, // 6: Version + 0, // 7: ImprovSerialType + 0, // 8: Length + 0, // 9...X: Data (here, one byte reserved for state/error) + 0, // X + 10: Checksum + '\n', + }; #ifdef USE_ESP32 uart_port_t uart_num_; diff --git a/tests/components/improv_base/common-uart0.yaml b/tests/components/improv_base/common-uart0.yaml new file mode 100644 index 0000000000..7b7730fd46 --- /dev/null +++ b/tests/components/improv_base/common-uart0.yaml @@ -0,0 +1,8 @@ +wifi: + ssid: MySSID + password: password1 + +logger: + hardware_uart: UART0 + +improv_serial: diff --git a/tests/components/improv_base/test-uart0.esp8266-ard.yaml b/tests/components/improv_base/test-uart0.esp8266-ard.yaml new file mode 100644 index 0000000000..ef8c799241 --- /dev/null +++ b/tests/components/improv_base/test-uart0.esp8266-ard.yaml @@ -0,0 +1 @@ +<<: !include common-uart0.yaml