From 97823ddd16abfaae02535ca5b1dcb3f5ccda3901 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Apr 2025 08:09:28 -1000 Subject: [PATCH 001/102] Rewrite BLE scanner to use a state machine (#8601) --- .../esp32_ble_tracker/esp32_ble_tracker.cpp | 292 ++++++++++-------- .../esp32_ble_tracker/esp32_ble_tracker.h | 18 +- 2 files changed, 187 insertions(+), 123 deletions(-) diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp index 760aac628a..34d4e6727a 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp @@ -57,7 +57,6 @@ void ESP32BLETracker::setup() { global_esp32_ble_tracker = this; this->scan_result_lock_ = xSemaphoreCreateMutex(); - this->scan_end_lock_ = xSemaphoreCreateMutex(); #ifdef USE_OTA ota::get_global_ota_callback()->add_on_state_callback( @@ -117,119 +116,104 @@ void ESP32BLETracker::loop() { } bool promote_to_connecting = discovered && !searching && !connecting; - if (!this->scanner_idle_) { - if (this->scan_result_index_ && // if it looks like we have a scan result we will take the lock - xSemaphoreTake(this->scan_result_lock_, 5L / portTICK_PERIOD_MS)) { - uint32_t index = this->scan_result_index_; - if (index >= ESP32BLETracker::SCAN_RESULT_BUFFER_SIZE) { - ESP_LOGW(TAG, "Too many BLE events to process. Some devices may not show up."); - } + if (this->scanner_state_ == ScannerState::RUNNING && + this->scan_result_index_ && // if it looks like we have a scan result we will take the lock + xSemaphoreTake(this->scan_result_lock_, 5L / portTICK_PERIOD_MS)) { + uint32_t index = this->scan_result_index_; + if (index >= ESP32BLETracker::SCAN_RESULT_BUFFER_SIZE) { + ESP_LOGW(TAG, "Too many BLE events to process. Some devices may not show up."); + } - if (this->raw_advertisements_) { + if (this->raw_advertisements_) { + for (auto *listener : this->listeners_) { + listener->parse_devices(this->scan_result_buffer_, this->scan_result_index_); + } + for (auto *client : this->clients_) { + client->parse_devices(this->scan_result_buffer_, this->scan_result_index_); + } + } + + if (this->parse_advertisements_) { + for (size_t i = 0; i < index; i++) { + ESPBTDevice device; + device.parse_scan_rst(this->scan_result_buffer_[i]); + + bool found = false; for (auto *listener : this->listeners_) { - listener->parse_devices(this->scan_result_buffer_, this->scan_result_index_); + if (listener->parse_device(device)) + found = true; } + for (auto *client : this->clients_) { - client->parse_devices(this->scan_result_buffer_, this->scan_result_index_); - } - } - - if (this->parse_advertisements_) { - for (size_t i = 0; i < index; i++) { - ESPBTDevice device; - device.parse_scan_rst(this->scan_result_buffer_[i]); - - bool found = false; - for (auto *listener : this->listeners_) { - if (listener->parse_device(device)) - found = true; - } - - for (auto *client : this->clients_) { - if (client->parse_device(device)) { - found = true; - if (!connecting && client->state() == ClientState::DISCOVERED) { - promote_to_connecting = true; - } + if (client->parse_device(device)) { + found = true; + if (!connecting && client->state() == ClientState::DISCOVERED) { + promote_to_connecting = true; } } + } - if (!found && !this->scan_continuous_) { - this->print_bt_device_info(device); - } + if (!found && !this->scan_continuous_) { + this->print_bt_device_info(device); } } - this->scan_result_index_ = 0; - xSemaphoreGive(this->scan_result_lock_); } - - /* - - Avoid starting the scanner if: - - we are already scanning - - we are connecting to a device - - we are disconnecting from a device - - Otherwise the scanner could fail to ever start again - and our only way to recover is to reboot. - - https://github.com/espressif/esp-idf/issues/6688 - - */ - if (!connecting && xSemaphoreTake(this->scan_end_lock_, 0L)) { - if (this->scan_continuous_) { - if (!disconnecting && !promote_to_connecting && !this->scan_start_failed_ && !this->scan_set_param_failed_) { - this->start_scan_(false); - } else { - // We didn't start the scan, so we need to release the lock - xSemaphoreGive(this->scan_end_lock_); - } - } else if (!this->scanner_idle_) { - this->end_of_scan_(); - return; - } + this->scan_result_index_ = 0; + xSemaphoreGive(this->scan_result_lock_); + } + if (this->scanner_state_ == ScannerState::STOPPED) { + this->end_of_scan_(); // Change state to IDLE + } + if (this->scanner_state_ == ScannerState::FAILED || + (this->scan_set_param_failed_ && this->scanner_state_ == ScannerState::RUNNING)) { + this->stop_scan_(); + if (this->scan_start_fail_count_ == std::numeric_limits::max()) { + ESP_LOGE(TAG, "ESP-IDF BLE scan could not restart after %d attempts, rebooting to restore BLE stack...", + std::numeric_limits::max()); + App.reboot(); } - - if (this->scan_start_failed_ || this->scan_set_param_failed_) { - if (this->scan_start_fail_count_ == std::numeric_limits::max()) { - ESP_LOGE(TAG, "ESP-IDF BLE scan could not restart after %d attempts, rebooting to restore BLE stack...", - std::numeric_limits::max()); - App.reboot(); - } - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - xSemaphoreGive(this->scan_end_lock_); - } else { - ESP_LOGD(TAG, "Stopping scan after failure..."); - this->stop_scan_(); - } - if (this->scan_start_failed_) { - ESP_LOGE(TAG, "Scan start failed: %d", this->scan_start_failed_); - this->scan_start_failed_ = ESP_BT_STATUS_SUCCESS; - } - if (this->scan_set_param_failed_) { - ESP_LOGE(TAG, "Scan set param failed: %d", this->scan_set_param_failed_); - this->scan_set_param_failed_ = ESP_BT_STATUS_SUCCESS; - } + if (this->scan_start_failed_) { + ESP_LOGE(TAG, "Scan start failed: %d", this->scan_start_failed_); + this->scan_start_failed_ = ESP_BT_STATUS_SUCCESS; + } + if (this->scan_set_param_failed_) { + ESP_LOGE(TAG, "Scan set param failed: %d", this->scan_set_param_failed_); + this->scan_set_param_failed_ = ESP_BT_STATUS_SUCCESS; } } + /* + Avoid starting the scanner if: + - we are already scanning + - we are connecting to a device + - we are disconnecting from a device + + Otherwise the scanner could fail to ever start again + and our only way to recover is to reboot. + + https://github.com/espressif/esp-idf/issues/6688 + + */ + if (this->scanner_state_ == ScannerState::IDLE && this->scan_continuous_ && !connecting && !disconnecting && + !promote_to_connecting) { + this->start_scan_(false); // first = false + } // If there is a discovered client and no connecting // clients and no clients using the scanner to search for // devices, then stop scanning and promote the discovered // client to ready to connect. - if (promote_to_connecting) { + if (promote_to_connecting && + (this->scanner_state_ == ScannerState::RUNNING || this->scanner_state_ == ScannerState::IDLE)) { for (auto *client : this->clients_) { if (client->state() == ClientState::DISCOVERED) { - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - // Scanner is not running since we got the - // lock, so we can promote the client. - xSemaphoreGive(this->scan_end_lock_); + if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGD(TAG, "Stopping scan to make connection..."); + this->stop_scan_(); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGD(TAG, "Promoting client to connect..."); // We only want to promote one client at a time. // once the scanner is fully stopped. client->set_state(ClientState::READY_TO_CONNECT); - } else { - ESP_LOGD(TAG, "Pausing scan to make connection..."); - this->stop_scan_(); } break; } @@ -237,13 +221,7 @@ void ESP32BLETracker::loop() { } } -void ESP32BLETracker::start_scan() { - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - this->start_scan_(true); - } else { - ESP_LOGW(TAG, "Scan requested when a scan is already in progress. Ignoring."); - } -} +void ESP32BLETracker::start_scan() { this->start_scan_(true); } void ESP32BLETracker::stop_scan() { ESP_LOGD(TAG, "Stopping scan."); @@ -251,16 +229,23 @@ void ESP32BLETracker::stop_scan() { this->stop_scan_(); } -void ESP32BLETracker::ble_before_disabled_event_handler() { - this->stop_scan_(); - xSemaphoreGive(this->scan_end_lock_); -} +void ESP32BLETracker::ble_before_disabled_event_handler() { this->stop_scan_(); } void ESP32BLETracker::stop_scan_() { - this->cancel_timeout("scan"); - if (this->scanner_idle_) { + if (this->scanner_state_ != ScannerState::RUNNING && this->scanner_state_ != ScannerState::FAILED) { + if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan is already stopped while trying to stop."); + } else if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Scan is starting while trying to stop."); + } else if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Scan is already stopping while trying to stop."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan is already stopped while trying to stop."); + } return; } + this->cancel_timeout("scan"); + this->scanner_state_ = ScannerState::STOPPING; esp_err_t err = esp_ble_gap_stop_scanning(); if (err != ESP_OK) { ESP_LOGE(TAG, "esp_ble_gap_stop_scanning failed: %d", err); @@ -273,13 +258,22 @@ void ESP32BLETracker::start_scan_(bool first) { ESP_LOGW(TAG, "Cannot start scan while ESP32BLE is disabled."); return; } - // The lock must be held when calling this function. - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - ESP_LOGE(TAG, "start_scan called without holding scan_end_lock_"); + if (this->scanner_state_ != ScannerState::IDLE) { + if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Cannot start scan while already starting."); + } else if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGE(TAG, "Cannot start scan while already running."); + } else if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Cannot start scan while already stopping."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Cannot start scan while already failed."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Cannot start scan while already stopped."); + } return; } - - ESP_LOGD(TAG, "Starting scan..."); + this->scanner_state_ = ScannerState::STARTING; + ESP_LOGD(TAG, "Starting scan, set scanner state to STARTING."); if (!first) { for (auto *listener : this->listeners_) listener->on_scan_end(); @@ -307,24 +301,21 @@ void ESP32BLETracker::start_scan_(bool first) { ESP_LOGE(TAG, "esp_ble_gap_start_scanning failed: %d", err); return; } - this->scanner_idle_ = false; } void ESP32BLETracker::end_of_scan_() { // The lock must be held when calling this function. - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - ESP_LOGE(TAG, "end_of_scan_ called without holding the scan_end_lock_"); + if (this->scanner_state_ != ScannerState::STOPPED) { + ESP_LOGE(TAG, "end_of_scan_ called while scanner is not stopped."); return; } - - ESP_LOGD(TAG, "End of scan."); - this->scanner_idle_ = true; + ESP_LOGD(TAG, "End of scan, set scanner state to IDLE."); this->already_discovered_.clear(); - xSemaphoreGive(this->scan_end_lock_); this->cancel_timeout("scan"); for (auto *listener : this->listeners_) listener->on_scan_end(); + this->scanner_state_ = ScannerState::IDLE; } void ESP32BLETracker::register_client(ESPBTClient *client) { @@ -392,19 +383,46 @@ void ESP32BLETracker::gap_scan_set_param_complete_(const esp_ble_gap_cb_param_t: void ESP32BLETracker::gap_scan_start_complete_(const esp_ble_gap_cb_param_t::ble_scan_start_cmpl_evt_param ¶m) { ESP_LOGV(TAG, "gap_scan_start_complete - status %d", param.status); this->scan_start_failed_ = param.status; + if (this->scanner_state_ != ScannerState::STARTING) { + if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGE(TAG, "Scan was already running when start complete."); + } else if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Scan was stopping when start complete."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Scan was in failed state when start complete."); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan was idle when start complete."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan was stopped when start complete."); + } + } if (param.status == ESP_BT_STATUS_SUCCESS) { this->scan_start_fail_count_ = 0; + this->scanner_state_ = ScannerState::RUNNING; } else { + this->scanner_state_ = ScannerState::FAILED; if (this->scan_start_fail_count_ != std::numeric_limits::max()) { this->scan_start_fail_count_++; } - xSemaphoreGive(this->scan_end_lock_); } } void ESP32BLETracker::gap_scan_stop_complete_(const esp_ble_gap_cb_param_t::ble_scan_stop_cmpl_evt_param ¶m) { ESP_LOGV(TAG, "gap_scan_stop_complete - status %d", param.status); - xSemaphoreGive(this->scan_end_lock_); + if (this->scanner_state_ != ScannerState::STOPPING) { + if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGE(TAG, "Scan was not running when stop complete."); + } else if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Scan was not started when stop complete."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Scan was in failed state when stop complete."); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan was idle when stop complete."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan was stopped when stop complete."); + } + } + this->scanner_state_ = ScannerState::STOPPED; } void ESP32BLETracker::gap_scan_result_(const esp_ble_gap_cb_param_t::ble_scan_result_evt_param ¶m) { @@ -417,7 +435,21 @@ void ESP32BLETracker::gap_scan_result_(const esp_ble_gap_cb_param_t::ble_scan_re xSemaphoreGive(this->scan_result_lock_); } } else if (param.search_evt == ESP_GAP_SEARCH_INQ_CMPL_EVT) { - xSemaphoreGive(this->scan_end_lock_); + // Scan finished on its own + if (this->scanner_state_ != ScannerState::RUNNING) { + if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Scan was not running when scan completed."); + } else if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Scan was not started when scan completed."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Scan was in failed state when scan completed."); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan was idle when scan completed."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan was stopped when scan completed."); + } + } + this->scanner_state_ = ScannerState::STOPPED; } } @@ -680,8 +712,26 @@ void ESP32BLETracker::dump_config() { ESP_LOGCONFIG(TAG, " Scan Window: %.1f ms", this->scan_window_ * 0.625f); ESP_LOGCONFIG(TAG, " Scan Type: %s", this->scan_active_ ? "ACTIVE" : "PASSIVE"); ESP_LOGCONFIG(TAG, " Continuous Scanning: %s", YESNO(this->scan_continuous_)); - ESP_LOGCONFIG(TAG, " Scanner Idle: %s", YESNO(this->scanner_idle_)); - ESP_LOGCONFIG(TAG, " Scan End: %s", YESNO(xSemaphoreGetMutexHolder(this->scan_end_lock_) == nullptr)); + switch (this->scanner_state_) { + case ScannerState::IDLE: + ESP_LOGCONFIG(TAG, " Scanner State: IDLE"); + break; + case ScannerState::STARTING: + ESP_LOGCONFIG(TAG, " Scanner State: STARTING"); + break; + case ScannerState::RUNNING: + ESP_LOGCONFIG(TAG, " Scanner State: RUNNING"); + break; + case ScannerState::STOPPING: + ESP_LOGCONFIG(TAG, " Scanner State: STOPPING"); + break; + case ScannerState::STOPPED: + ESP_LOGCONFIG(TAG, " Scanner State: STOPPED"); + break; + case ScannerState::FAILED: + ESP_LOGCONFIG(TAG, " Scanner State: FAILED"); + break; + } ESP_LOGCONFIG(TAG, " Connecting: %d, discovered: %d, searching: %d, disconnecting: %d", connecting_, discovered_, searching_, disconnecting_); if (this->scan_start_fail_count_) { diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h index 8b712a01ea..6ca763db07 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h @@ -154,6 +154,21 @@ enum class ClientState { ESTABLISHED, }; +enum class ScannerState { + // Scanner is idle, init state, set from the main loop when processing STOPPED + IDLE, + // Scanner is starting, set from the main loop only + STARTING, + // Scanner is running, set from the ESP callback only + RUNNING, + // Scanner failed to start, set from the ESP callback only + FAILED, + // Scanner is stopping, set from the main loop only + STOPPING, + // Scanner is stopped, set from the ESP callback only + STOPPED, +}; + enum class ConnectionType { // The default connection type, we hold all the services in ram // for the duration of the connection. @@ -257,12 +272,11 @@ class ESP32BLETracker : public Component, uint8_t scan_start_fail_count_{0}; bool scan_continuous_; bool scan_active_; - bool scanner_idle_{true}; + ScannerState scanner_state_{ScannerState::IDLE}; bool ble_was_disabled_{true}; bool raw_advertisements_{false}; bool parse_advertisements_{false}; SemaphoreHandle_t scan_result_lock_; - SemaphoreHandle_t scan_end_lock_; size_t scan_result_index_{0}; #ifdef USE_PSRAM const static u_int8_t SCAN_RESULT_BUFFER_SIZE = 32; From 991f3d3a10e8f70453548224c20e0159305efff9 Mon Sep 17 00:00:00 2001 From: Craig Andrews Date: Wed, 23 Apr 2025 00:30:50 -0400 Subject: [PATCH 002/102] [http_request] Ability to get response headers (#8224) Co-authored-by: guillempages Co-authored-by: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> --- esphome/components/http_request/__init__.py | 19 +++-- .../components/http_request/http_request.cpp | 20 ++++++ .../components/http_request/http_request.h | 72 +++++++++++++++---- .../http_request/http_request_arduino.cpp | 28 ++++++-- .../http_request/http_request_arduino.h | 7 +- .../http_request/http_request_idf.cpp | 38 +++++++++- .../http_request/http_request_idf.h | 13 +++- tests/components/http_request/common.yaml | 11 +-- 8 files changed, 168 insertions(+), 40 deletions(-) diff --git a/esphome/components/http_request/__init__.py b/esphome/components/http_request/__init__.py index 78064fb4b4..2a999532f8 100644 --- a/esphome/components/http_request/__init__.py +++ b/esphome/components/http_request/__init__.py @@ -47,6 +47,8 @@ CONF_BUFFER_SIZE_TX = "buffer_size_tx" CONF_MAX_RESPONSE_BUFFER_SIZE = "max_response_buffer_size" CONF_ON_RESPONSE = "on_response" CONF_HEADERS = "headers" +CONF_REQUEST_HEADERS = "request_headers" +CONF_COLLECT_HEADERS = "collect_headers" CONF_BODY = "body" CONF_JSON = "json" CONF_CAPTURE_RESPONSE = "capture_response" @@ -176,9 +178,13 @@ HTTP_REQUEST_ACTION_SCHEMA = cv.Schema( { cv.GenerateID(): cv.use_id(HttpRequestComponent), cv.Required(CONF_URL): cv.templatable(validate_url), - cv.Optional(CONF_HEADERS): cv.All( + cv.Optional(CONF_HEADERS): cv.invalid( + "The 'headers' options has been renamed to 'request_headers'" + ), + cv.Optional(CONF_REQUEST_HEADERS): cv.All( cv.Schema({cv.string: cv.templatable(cv.string)}) ), + cv.Optional(CONF_COLLECT_HEADERS): cv.ensure_list(cv.string), cv.Optional(CONF_VERIFY_SSL): cv.invalid( f"{CONF_VERIFY_SSL} has moved to the base component configuration." ), @@ -263,11 +269,12 @@ async def http_request_action_to_code(config, action_id, template_arg, args): for key in json_: template_ = await cg.templatable(json_[key], args, cg.std_string) cg.add(var.add_json(key, template_)) - for key in config.get(CONF_HEADERS, []): - template_ = await cg.templatable( - config[CONF_HEADERS][key], args, cg.const_char_ptr - ) - cg.add(var.add_header(key, template_)) + for key in config.get(CONF_REQUEST_HEADERS, []): + template_ = await cg.templatable(key, args, cg.std_string) + cg.add(var.add_request_header(key, template_)) + + for value in config.get(CONF_COLLECT_HEADERS, []): + cg.add(var.add_collect_header(value)) for conf in config.get(CONF_ON_RESPONSE, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID]) diff --git a/esphome/components/http_request/http_request.cpp b/esphome/components/http_request/http_request.cpp index be8bef006e..ca9fd2c2dc 100644 --- a/esphome/components/http_request/http_request.cpp +++ b/esphome/components/http_request/http_request.cpp @@ -20,5 +20,25 @@ void HttpRequestComponent::dump_config() { } } +std::string HttpContainer::get_response_header(const std::string &header_name) { + auto response_headers = this->get_response_headers(); + auto header_name_lower_case = str_lower_case(header_name); + if (response_headers.count(header_name_lower_case) == 0) { + ESP_LOGW(TAG, "No header with name %s found", header_name_lower_case.c_str()); + return ""; + } else { + auto values = response_headers[header_name_lower_case]; + if (values.empty()) { + ESP_LOGE(TAG, "header with name %s returned an empty list, this shouldn't happen", + header_name_lower_case.c_str()); + return ""; + } else { + auto header_value = values.front(); + ESP_LOGD(TAG, "Header with name %s found with value %s", header_name_lower_case.c_str(), header_value.c_str()); + return header_value; + } + } +} + } // namespace http_request } // namespace esphome diff --git a/esphome/components/http_request/http_request.h b/esphome/components/http_request/http_request.h index e98fd1a475..a67b04eadc 100644 --- a/esphome/components/http_request/http_request.h +++ b/esphome/components/http_request/http_request.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -95,9 +96,19 @@ class HttpContainer : public Parented { size_t get_bytes_read() const { return this->bytes_read_; } + /** + * @brief Get response headers. + * + * @return The key is the lower case response header name, the value is the header value. + */ + std::map> get_response_headers() { return this->response_headers_; } + + std::string get_response_header(const std::string &header_name); + protected: size_t bytes_read_{0}; bool secure_{false}; + std::map> response_headers_{}; }; class HttpRequestResponseTrigger : public Trigger, std::string &> { @@ -119,21 +130,46 @@ class HttpRequestComponent : public Component { void set_follow_redirects(bool follow_redirects) { this->follow_redirects_ = follow_redirects; } void set_redirect_limit(uint16_t limit) { this->redirect_limit_ = limit; } - std::shared_ptr get(std::string url) { return this->start(std::move(url), "GET", "", {}); } - std::shared_ptr get(std::string url, std::list
headers) { - return this->start(std::move(url), "GET", "", std::move(headers)); + std::shared_ptr get(const std::string &url) { return this->start(url, "GET", "", {}); } + std::shared_ptr get(const std::string &url, const std::list
&request_headers) { + return this->start(url, "GET", "", request_headers); } - std::shared_ptr post(std::string url, std::string body) { - return this->start(std::move(url), "POST", std::move(body), {}); + std::shared_ptr get(const std::string &url, const std::list
&request_headers, + const std::set &collect_headers) { + return this->start(url, "GET", "", request_headers, collect_headers); } - std::shared_ptr post(std::string url, std::string body, std::list
headers) { - return this->start(std::move(url), "POST", std::move(body), std::move(headers)); + std::shared_ptr post(const std::string &url, const std::string &body) { + return this->start(url, "POST", body, {}); + } + std::shared_ptr post(const std::string &url, const std::string &body, + const std::list
&request_headers) { + return this->start(url, "POST", body, request_headers); + } + std::shared_ptr post(const std::string &url, const std::string &body, + const std::list
&request_headers, + const std::set &collect_headers) { + return this->start(url, "POST", body, request_headers, collect_headers); } - virtual std::shared_ptr start(std::string url, std::string method, std::string body, - std::list
headers) = 0; + std::shared_ptr start(const std::string &url, const std::string &method, const std::string &body, + const std::list
&request_headers) { + return this->start(url, method, body, request_headers, {}); + } + + std::shared_ptr start(const std::string &url, const std::string &method, const std::string &body, + const std::list
&request_headers, + const std::set &collect_headers) { + std::set lower_case_collect_headers; + for (const std::string &collect_header : collect_headers) { + lower_case_collect_headers.insert(str_lower_case(collect_header)); + } + return this->perform(url, method, body, request_headers, lower_case_collect_headers); + } protected: + virtual std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) = 0; const char *useragent_{nullptr}; bool follow_redirects_{}; uint16_t redirect_limit_{}; @@ -149,7 +185,11 @@ template class HttpRequestSendAction : public Action { TEMPLATABLE_VALUE(std::string, body) TEMPLATABLE_VALUE(bool, capture_response) - void add_header(const char *key, TemplatableValue value) { this->headers_.insert({key, value}); } + void add_request_header(const char *key, TemplatableValue value) { + this->request_headers_.insert({key, value}); + } + + void add_collect_header(const char *value) { this->collect_headers_.insert(value); } void add_json(const char *key, TemplatableValue value) { this->json_.insert({key, value}); } @@ -176,16 +216,17 @@ template class HttpRequestSendAction : public Action { auto f = std::bind(&HttpRequestSendAction::encode_json_func_, this, x..., std::placeholders::_1); body = json::build_json(f); } - std::list
headers; - for (const auto &item : this->headers_) { + std::list
request_headers; + for (const auto &item : this->request_headers_) { auto val = item.second; Header header; header.name = item.first; header.value = val.value(x...); - headers.push_back(header); + request_headers.push_back(header); } - auto container = this->parent_->start(this->url_.value(x...), this->method_.value(x...), body, headers); + auto container = this->parent_->start(this->url_.value(x...), this->method_.value(x...), body, request_headers, + this->collect_headers_); if (container == nullptr) { for (auto *trigger : this->error_triggers_) @@ -238,7 +279,8 @@ template class HttpRequestSendAction : public Action { } void encode_json_func_(Ts... x, JsonObject root) { this->json_func_(x..., root); } HttpRequestComponent *parent_; - std::map> headers_{}; + std::map> request_headers_{}; + std::set collect_headers_{"content-type", "content-length"}; std::map> json_{}; std::function json_func_{nullptr}; std::vector response_triggers_{}; diff --git a/esphome/components/http_request/http_request_arduino.cpp b/esphome/components/http_request/http_request_arduino.cpp index b0067e7839..b4378cdce6 100644 --- a/esphome/components/http_request/http_request_arduino.cpp +++ b/esphome/components/http_request/http_request_arduino.cpp @@ -14,8 +14,9 @@ namespace http_request { static const char *const TAG = "http_request.arduino"; -std::shared_ptr HttpRequestArduino::start(std::string url, std::string method, std::string body, - std::list
headers) { +std::shared_ptr HttpRequestArduino::perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) { if (!network::is_connected()) { this->status_momentary_error("failed", 1000); ESP_LOGW(TAG, "HTTP Request failed; Not connected to network"); @@ -95,14 +96,17 @@ std::shared_ptr HttpRequestArduino::start(std::string url, std::s if (this->useragent_ != nullptr) { container->client_.setUserAgent(this->useragent_); } - for (const auto &header : headers) { + for (const auto &header : request_headers) { container->client_.addHeader(header.name.c_str(), header.value.c_str(), false, true); } // returned needed headers must be collected before the requests - static const char *header_keys[] = {"Content-Length", "Content-Type"}; - static const size_t HEADER_COUNT = sizeof(header_keys) / sizeof(header_keys[0]); - container->client_.collectHeaders(header_keys, HEADER_COUNT); + const char *header_keys[collect_headers.size()]; + int index = 0; + for (auto const &header_name : collect_headers) { + header_keys[index++] = header_name.c_str(); + } + container->client_.collectHeaders(header_keys, index); App.feed_wdt(); container->status_code = container->client_.sendRequest(method.c_str(), body.c_str()); @@ -121,6 +125,18 @@ std::shared_ptr HttpRequestArduino::start(std::string url, std::s // Still return the container, so it can be used to get the status code and error message } + container->response_headers_ = {}; + auto header_count = container->client_.headers(); + for (int i = 0; i < header_count; i++) { + const std::string header_name = str_lower_case(container->client_.headerName(i).c_str()); + if (collect_headers.count(header_name) > 0) { + std::string header_value = container->client_.header(i).c_str(); + ESP_LOGD(TAG, "Received response header, name: %s, value: %s", header_name.c_str(), header_value.c_str()); + container->response_headers_[header_name].push_back(header_value); + break; + } + } + int content_length = container->client_.getSize(); ESP_LOGD(TAG, "Content-Length: %d", content_length); container->content_length = (size_t) content_length; diff --git a/esphome/components/http_request/http_request_arduino.h b/esphome/components/http_request/http_request_arduino.h index dfdf4a35e2..ac9ddffbb0 100644 --- a/esphome/components/http_request/http_request_arduino.h +++ b/esphome/components/http_request/http_request_arduino.h @@ -29,9 +29,10 @@ class HttpContainerArduino : public HttpContainer { }; class HttpRequestArduino : public HttpRequestComponent { - public: - std::shared_ptr start(std::string url, std::string method, std::string body, - std::list
headers) override; + protected: + std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) override; }; } // namespace http_request diff --git a/esphome/components/http_request/http_request_idf.cpp b/esphome/components/http_request/http_request_idf.cpp index 78c37403f5..0923062822 100644 --- a/esphome/components/http_request/http_request_idf.cpp +++ b/esphome/components/http_request/http_request_idf.cpp @@ -19,14 +19,41 @@ namespace http_request { static const char *const TAG = "http_request.idf"; +struct UserData { + const std::set &collect_headers; + std::map> response_headers; +}; + void HttpRequestIDF::dump_config() { HttpRequestComponent::dump_config(); ESP_LOGCONFIG(TAG, " Buffer Size RX: %u", this->buffer_size_rx_); ESP_LOGCONFIG(TAG, " Buffer Size TX: %u", this->buffer_size_tx_); } -std::shared_ptr HttpRequestIDF::start(std::string url, std::string method, std::string body, - std::list
headers) { +esp_err_t HttpRequestIDF::http_event_handler(esp_http_client_event_t *evt) { + UserData *user_data = (UserData *) evt->user_data; + + switch (evt->event_id) { + case HTTP_EVENT_ON_HEADER: { + const std::string header_name = str_lower_case(evt->header_key); + if (user_data->collect_headers.count(header_name)) { + const std::string header_value = evt->header_value; + ESP_LOGD(TAG, "Received response header, name: %s, value: %s", header_name.c_str(), header_value.c_str()); + user_data->response_headers[header_name].push_back(header_value); + break; + } + break; + } + default: { + break; + } + } + return ESP_OK; +} + +std::shared_ptr HttpRequestIDF::perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) { if (!network::is_connected()) { this->status_momentary_error("failed", 1000); ESP_LOGE(TAG, "HTTP Request failed; Not connected to network"); @@ -76,6 +103,10 @@ std::shared_ptr HttpRequestIDF::start(std::string url, std::strin const uint32_t start = millis(); watchdog::WatchdogManager wdm(this->get_watchdog_timeout()); + config.event_handler = http_event_handler; + auto user_data = UserData{collect_headers, {}}; + config.user_data = static_cast(&user_data); + esp_http_client_handle_t client = esp_http_client_init(&config); std::shared_ptr container = std::make_shared(client); @@ -83,7 +114,7 @@ std::shared_ptr HttpRequestIDF::start(std::string url, std::strin container->set_secure(secure); - for (const auto &header : headers) { + for (const auto &header : request_headers) { esp_http_client_set_header(client, header.name.c_str(), header.value.c_str()); } @@ -124,6 +155,7 @@ std::shared_ptr HttpRequestIDF::start(std::string url, std::strin container->feed_wdt(); container->status_code = esp_http_client_get_status_code(client); container->feed_wdt(); + container->set_response_headers(user_data.response_headers); if (is_success(container->status_code)) { container->duration_ms = millis() - start; return container; diff --git a/esphome/components/http_request/http_request_idf.h b/esphome/components/http_request/http_request_idf.h index 2ed50698b9..5c5b784853 100644 --- a/esphome/components/http_request/http_request_idf.h +++ b/esphome/components/http_request/http_request_idf.h @@ -21,6 +21,10 @@ class HttpContainerIDF : public HttpContainer { /// @brief Feeds the watchdog timer if the executing task has one attached void feed_wdt(); + void set_response_headers(std::map> &response_headers) { + this->response_headers_ = std::move(response_headers); + } + protected: esp_http_client_handle_t client_; }; @@ -29,16 +33,19 @@ class HttpRequestIDF : public HttpRequestComponent { public: void dump_config() override; - std::shared_ptr start(std::string url, std::string method, std::string body, - std::list
headers) override; - void set_buffer_size_rx(uint16_t buffer_size_rx) { this->buffer_size_rx_ = buffer_size_rx; } void set_buffer_size_tx(uint16_t buffer_size_tx) { this->buffer_size_tx_ = buffer_size_tx; } protected: + std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) override; // if zero ESP-IDF will use DEFAULT_HTTP_BUF_SIZE uint16_t buffer_size_rx_{}; uint16_t buffer_size_tx_{}; + + /// @brief Monitors the http client events to gather response headers + static esp_err_t http_event_handler(esp_http_client_event_t *evt); }; } // namespace http_request diff --git a/tests/components/http_request/common.yaml b/tests/components/http_request/common.yaml index 8408f27a05..4a9b8a0e62 100644 --- a/tests/components/http_request/common.yaml +++ b/tests/components/http_request/common.yaml @@ -10,27 +10,30 @@ esphome: then: - http_request.get: url: https://esphome.io - headers: + request_headers: Content-Type: application/json + collect_headers: + - age on_error: logger.log: "Request failed" on_response: then: - logger.log: - format: "Response status: %d, Duration: %lu ms" + format: "Response status: %d, Duration: %lu ms, age: %s" args: - response->status_code - (long) response->duration_ms + - response->get_response_header("age").c_str() - http_request.post: url: https://esphome.io - headers: + request_headers: Content-Type: application/json json: key: value - http_request.send: method: PUT url: https://esphome.io - headers: + request_headers: Content-Type: application/json body: "Some data" From 33d79e03d927f2b5abc4b7be3990880985d20a67 Mon Sep 17 00:00:00 2001 From: Djordje Mandic <6750655+DjordjeMandic@users.noreply.github.com> Date: Wed, 23 Apr 2025 10:45:29 +0200 Subject: [PATCH 003/102] [sht4x] Reduce warn spam, added communication check in setup (#8250) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/components/sht4x/sht4x.cpp | 65 ++++++++++++++++++++---------- esphome/components/sht4x/sht4x.h | 2 +- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/esphome/components/sht4x/sht4x.cpp b/esphome/components/sht4x/sht4x.cpp index dea542ea9e..e4fa16d87a 100644 --- a/esphome/components/sht4x/sht4x.cpp +++ b/esphome/components/sht4x/sht4x.cpp @@ -12,14 +12,22 @@ void SHT4XComponent::start_heater_() { uint8_t cmd[] = {MEASURECOMMANDS[this->heater_command_]}; ESP_LOGD(TAG, "Heater turning on"); - this->write(cmd, 1); + if (this->write(cmd, 1) != i2c::ERROR_OK) { + this->status_set_error("Failed to turn on heater"); + } } void SHT4XComponent::setup() { ESP_LOGCONFIG(TAG, "Setting up sht4x..."); - if (this->duty_cycle_ > 0.0) { - uint32_t heater_interval = (uint32_t) (this->heater_time_ / this->duty_cycle_); + auto err = this->write(nullptr, 0); + if (err != i2c::ERROR_OK) { + this->mark_failed(); + return; + } + + if (std::isfinite(this->duty_cycle_) && this->duty_cycle_ > 0.0f) { + uint32_t heater_interval = static_cast(static_cast(this->heater_time_) / this->duty_cycle_); ESP_LOGD(TAG, "Heater interval: %" PRIu32, heater_interval); if (this->heater_power_ == SHT4X_HEATERPOWER_HIGH) { @@ -47,37 +55,50 @@ void SHT4XComponent::setup() { } } -void SHT4XComponent::dump_config() { LOG_I2C_DEVICE(this); } +void SHT4XComponent::dump_config() { + ESP_LOGCONFIG(TAG, "SHT4x:"); + LOG_I2C_DEVICE(this); + if (this->is_failed()) { + ESP_LOGE(TAG, "Communication with SHT4x failed!"); + } +} void SHT4XComponent::update() { // Send command - this->write_command(MEASURECOMMANDS[this->precision_]); + if (!this->write_command(MEASURECOMMANDS[this->precision_])) { + // Warning will be printed only if warning status is not set yet + this->status_set_warning("Failed to send measurement command"); + return; + } this->set_timeout(10, [this]() { uint16_t buffer[2]; // Read measurement - bool read_status = this->read_data(buffer, 2); + if (!this->read_data(buffer, 2)) { + // Using ESP_LOGW to force the warning to be printed + ESP_LOGW(TAG, "Sensor read failed"); + this->status_set_warning(); + return; + } - if (read_status) { - // Evaluate and publish measurements - if (this->temp_sensor_ != nullptr) { - // Temp is contained in the first result word - float sensor_value_temp = buffer[0]; - float temp = -45 + 175 * sensor_value_temp / 65535; + this->status_clear_warning(); - this->temp_sensor_->publish_state(temp); - } + // Evaluate and publish measurements + if (this->temp_sensor_ != nullptr) { + // Temp is contained in the first result word + float sensor_value_temp = buffer[0]; + float temp = -45 + 175 * sensor_value_temp / 65535; - if (this->humidity_sensor_ != nullptr) { - // Relative humidity is in the second result word - float sensor_value_rh = buffer[1]; - float rh = -6 + 125 * sensor_value_rh / 65535; + this->temp_sensor_->publish_state(temp); + } - this->humidity_sensor_->publish_state(rh); - } - } else { - ESP_LOGD(TAG, "Sensor read failed"); + if (this->humidity_sensor_ != nullptr) { + // Relative humidity is in the second result word + float sensor_value_rh = buffer[1]; + float rh = -6 + 125 * sensor_value_rh / 65535; + + this->humidity_sensor_->publish_state(rh); } }); } diff --git a/esphome/components/sht4x/sht4x.h b/esphome/components/sht4x/sht4x.h index 46037bb0e8..98e0629b50 100644 --- a/esphome/components/sht4x/sht4x.h +++ b/esphome/components/sht4x/sht4x.h @@ -13,7 +13,7 @@ enum SHT4XPRECISION { SHT4X_PRECISION_HIGH = 0, SHT4X_PRECISION_MED, SHT4X_PRECI enum SHT4XHEATERPOWER { SHT4X_HEATERPOWER_HIGH, SHT4X_HEATERPOWER_MED, SHT4X_HEATERPOWER_LOW }; -enum SHT4XHEATERTIME { SHT4X_HEATERTIME_LONG = 1100, SHT4X_HEATERTIME_SHORT = 110 }; +enum SHT4XHEATERTIME : uint16_t { SHT4X_HEATERTIME_LONG = 1100, SHT4X_HEATERTIME_SHORT = 110 }; class SHT4XComponent : public PollingComponent, public sensirion_common::SensirionI2CDevice { public: From 89b1b12993f44b1c29cff0a08606a7f5cd1ce3d0 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Wed, 23 Apr 2025 18:47:15 +1000 Subject: [PATCH 004/102] [online_image] Fix printf format; comment fixes (#8607) --- esphome/components/online_image/online_image.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/esphome/components/online_image/online_image.cpp b/esphome/components/online_image/online_image.cpp index 3411018901..cb4a3be9e8 100644 --- a/esphome/components/online_image/online_image.cpp +++ b/esphome/components/online_image/online_image.cpp @@ -111,7 +111,7 @@ void OnlineImage::update() { case ImageFormat::BMP: accept_mime_type = "image/bmp"; break; -#endif // ONLINE_IMAGE_BMP_SUPPORT +#endif // USE_ONLINE_IMAGE_BMP_SUPPORT #ifdef USE_ONLINE_IMAGE_JPEG_SUPPORT case ImageFormat::JPEG: accept_mime_type = "image/jpeg"; @@ -121,7 +121,7 @@ void OnlineImage::update() { case ImageFormat::PNG: accept_mime_type = "image/png"; break; -#endif // ONLINE_IMAGE_PNG_SUPPORT +#endif // USE_ONLINE_IMAGE_PNG_SUPPORT default: accept_mime_type = "image/*"; } @@ -159,7 +159,7 @@ void OnlineImage::update() { ESP_LOGD(TAG, "Allocating BMP decoder"); this->decoder_ = make_unique(this); } -#endif // ONLINE_IMAGE_BMP_SUPPORT +#endif // USE_ONLINE_IMAGE_BMP_SUPPORT #ifdef USE_ONLINE_IMAGE_JPEG_SUPPORT if (this->format_ == ImageFormat::JPEG) { ESP_LOGD(TAG, "Allocating JPEG decoder"); @@ -171,7 +171,7 @@ void OnlineImage::update() { ESP_LOGD(TAG, "Allocating PNG decoder"); this->decoder_ = make_unique(this); } -#endif // ONLINE_IMAGE_PNG_SUPPORT +#endif // USE_ONLINE_IMAGE_PNG_SUPPORT if (!this->decoder_) { ESP_LOGE(TAG, "Could not instantiate decoder. Image format unsupported: %d", this->format_); @@ -185,7 +185,7 @@ void OnlineImage::update() { this->download_error_callback_.call(); return; } - ESP_LOGI(TAG, "Downloading image (Size: %d)", total_size); + ESP_LOGI(TAG, "Downloading image (Size: %zu)", total_size); this->start_time_ = ::time(nullptr); } From 911bd547651ff4f9919b9f03edf232f8bb1034d6 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Wed, 23 Apr 2025 20:49:33 +1200 Subject: [PATCH 005/102] [watchdog] Fix for variants with single core (#8602) --- esphome/components/watchdog/watchdog.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/esphome/components/watchdog/watchdog.cpp b/esphome/components/watchdog/watchdog.cpp index 3a94a658e8..f6f2992a11 100644 --- a/esphome/components/watchdog/watchdog.cpp +++ b/esphome/components/watchdog/watchdog.cpp @@ -6,6 +6,7 @@ #include #include #ifdef USE_ESP32 +#include #include "esp_idf_version.h" #include "esp_task_wdt.h" #endif @@ -40,7 +41,7 @@ void WatchdogManager::set_timeout_(uint32_t timeout_ms) { #if ESP_IDF_VERSION_MAJOR >= 5 esp_task_wdt_config_t wdt_config = { .timeout_ms = timeout_ms, - .idle_core_mask = 0x03, + .idle_core_mask = (1 << SOC_CPU_CORES_NUM) - 1, .trigger_panic = true, }; esp_task_wdt_reconfigure(&wdt_config); From f29ccb9e75adece3ade409c8195eddd7e2ed62e7 Mon Sep 17 00:00:00 2001 From: Guillermo Ruffino Date: Thu, 24 Apr 2025 00:43:37 -0300 Subject: [PATCH 006/102] Schema gen action (#8593) Co-authored-by: Jonathan Swoboda <154711427+swoboda1337@users.noreply.github.com> Co-authored-by: Keith Burzinski --- .github/workflows/ci.yml | 1 + esphome/schema_extractors.py | 1 - script/build_language_schema.py | 64 ++++++++++++++++++++------------- 3 files changed, 40 insertions(+), 26 deletions(-) mode change 100644 => 100755 script/build_language_schema.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 997b98eefa..9022da68ac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -165,6 +165,7 @@ jobs: . venv/bin/activate script/ci-custom.py script/build_codeowners.py --check + script/build_language_schema.py --check pytest: name: Run pytest diff --git a/esphome/schema_extractors.py b/esphome/schema_extractors.py index 5491bc88c4..a84e08a8d3 100644 --- a/esphome/schema_extractors.py +++ b/esphome/schema_extractors.py @@ -42,7 +42,6 @@ def schema_extractor_extended(func): def decorate(*args, **kwargs): ret = func(*args, **kwargs) - assert len(args) == 2 extended_schemas[repr(ret)] = args return ret diff --git a/script/build_language_schema.py b/script/build_language_schema.py old mode 100644 new mode 100755 index 7152e23e8f..a0edad150a --- a/script/build_language_schema.py +++ b/script/build_language_schema.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 import argparse import glob import inspect @@ -36,6 +37,7 @@ parser = argparse.ArgumentParser() parser.add_argument( "--output-path", default=".", help="Output path", type=os.path.abspath ) +parser.add_argument("--check", action="store_true", help="Check only for CI") args = parser.parse_args() @@ -66,12 +68,13 @@ def get_component_names(): from esphome.loader import CORE_COMPONENTS_PATH component_names = ["esphome", "sensor", "esp32", "esp8266"] + skip_components = [] for d in os.listdir(CORE_COMPONENTS_PATH): if not d.startswith("__") and os.path.isdir( os.path.join(CORE_COMPONENTS_PATH, d) ): - if d not in component_names: + if d not in component_names and d not in skip_components: component_names.append(d) return component_names @@ -81,16 +84,26 @@ def load_components(): from esphome.config import get_component for domain in get_component_names(): - components[domain] = get_component(domain) + components[domain] = get_component(domain, exception=True) + assert components[domain] is not None -# pylint: disable=wrong-import-position -from esphome.const import CONF_TYPE, KEY_CORE, KEY_TARGET_PLATFORM # noqa: E402 +from esphome.const import ( # noqa: E402 + CONF_TYPE, + KEY_CORE, + KEY_FRAMEWORK_VERSION, + KEY_TARGET_FRAMEWORK, + KEY_TARGET_PLATFORM, +) from esphome.core import CORE # noqa: E402 -# pylint: enable=wrong-import-position +CORE.data[KEY_CORE] = { + KEY_TARGET_PLATFORM: "esp8266", + KEY_TARGET_FRAMEWORK: "arduino", + KEY_FRAMEWORK_VERSION: "0", +} + -CORE.data[KEY_CORE] = {KEY_TARGET_PLATFORM: None} load_components() # Import esphome after loading components (so schema is tracked) @@ -98,7 +111,6 @@ load_components() from esphome import automation, pins # noqa: E402 from esphome.components import remote_base # noqa: E402 import esphome.config_validation as cv # noqa: E402 -import esphome.core as esphome_core # noqa: E402 from esphome.helpers import write_file_if_changed # noqa: E402 from esphome.loader import CORE_COMPONENTS_PATH, get_platform # noqa: E402 from esphome.util import Registry # noqa: E402 @@ -523,11 +535,14 @@ def shrink(): # then are all simple types, integer and strings for x, paths in referenced_schemas.items(): key_s = get_str_path_schema(x) - if key_s and key_s[S_TYPE] in ["enum", "registry", "integer", "string"]: + if key_s and key_s.get(S_TYPE) in ["enum", "registry", "integer", "string"]: if key_s[S_TYPE] == "registry": print("Spreading registry: " + x) for target in paths: target_s = get_arr_path_schema(target) + if S_SCHEMA not in target_s: + print("skipping simple spread for " + ".".join(target)) + continue assert target_s[S_SCHEMA][S_EXTENDS] == [x] target_s.pop(S_SCHEMA) target_s |= key_s @@ -542,14 +557,14 @@ def shrink(): # an empty schema like speaker.SPEAKER_SCHEMA target_s[S_EXTENDS].remove(x) continue - assert target_s[S_SCHEMA][S_EXTENDS] == [x] + assert x in target_s[S_SCHEMA][S_EXTENDS] target_s.pop(S_SCHEMA) target_s.pop(S_TYPE) # undefined target_s["data_type"] = x.split(".")[1] # remove this dangling again pop_str_path_schema(x) - # remove dangling items (unreachable schemas) + # remove unreachable schemas for domain, domain_schemas in output.items(): for schema_name in list(domain_schemas.get(S_SCHEMAS, {}).keys()): s = f"{domain}.{schema_name}" @@ -558,7 +573,6 @@ def shrink(): and s not in referenced_schemas and not is_platform_schema(s) ): - print(f"Removing {s}") domain_schemas[S_SCHEMAS].pop(schema_name) @@ -659,6 +673,9 @@ def build_schema(): # bundle core inside esphome data["esphome"]["core"] = data.pop("core")["core"] + if args.check: # do not gen files + return + for c, s in data.items(): write_file(c, s) delete_extra_files(data.keys()) @@ -721,15 +738,8 @@ def convert(schema, config_var, path): # Extended schemas are tracked when the .extend() is used in a schema if repr_schema in ejs.extended_schemas: extended = ejs.extended_schemas.get(repr_schema) - # The midea actions are extending an empty schema (resulted in the templatize not templatizing anything) - # this causes a recursion in that this extended looks the same in extended schema as the extended[1] - if repr_schema == repr(extended[1]): - assert path.startswith("midea_ac/") - return - - assert len(extended) == 2 - convert(extended[0], config_var, path + "/extL") - convert(extended[1], config_var, path + "/extR") + for idx, ext in enumerate(extended): + convert(ext, config_var, f"{path}/ext{idx}") return if isinstance(schema, cv.All): @@ -881,15 +891,22 @@ def convert(schema, config_var, path): "class": "i2c::I2CBus", "parents": ["Component"], } - elif path == "uart/CONFIG_SCHEMA/val 1/extL/all/id": + elif path == "uart/CONFIG_SCHEMA/val 1/ext0/all/id": config_var["id_type"] = { "class": "uart::UARTComponent", "parents": ["Component"], } + elif path == "http_request/CONFIG_SCHEMA/val 1/ext0/all/id": + config_var["id_type"] = { + "class": "http_request::HttpRequestComponent", + "parents": ["Component"], + } elif path == "pins/esp32/val 1/id": config_var["id_type"] = "pin" else: - raise TypeError("Cannot determine id_type for " + path) + print("Cannot determine id_type for " + path) + + # raise TypeError("Cannot determine id_type for " + path) elif repr_schema in ejs.registry_schemas: solve_registry.append((ejs.registry_schemas[repr_schema], config_var)) @@ -965,9 +982,6 @@ def convert_keys(converted, schema, path): else: converted["key_type"] = str(k) - esphome_core.CORE.data = { - esphome_core.KEY_CORE: {esphome_core.KEY_TARGET_PLATFORM: "esp8266"} - } if hasattr(k, "default") and str(k.default) != "...": default_value = k.default() if default_value is not None: From 6792ff6d58e62ba540cf7e92272d6bb1493ee5d8 Mon Sep 17 00:00:00 2001 From: luar123 <49960470+luar123@users.noreply.github.com> Date: Thu, 24 Apr 2025 22:33:58 +0200 Subject: [PATCH 007/102] [i2s_audio, i2s_audio_microphone, i2s_audio_speaker] Add basic support for new esp-idf 5.x.x i2s driver. (#8181) --- esphome/components/i2s_audio/__init__.py | 135 ++++++++++++--- esphome/components/i2s_audio/i2s_audio.h | 46 ++++- .../i2s_audio/media_player/__init__.py | 9 + .../i2s_audio/microphone/__init__.py | 19 +++ .../microphone/i2s_audio_microphone.cpp | 161 +++++++++++++++++- .../microphone/i2s_audio_microphone.h | 13 +- .../components/i2s_audio/speaker/__init__.py | 34 +++- .../i2s_audio/speaker/i2s_audio_speaker.cpp | 145 +++++++++++++++- .../i2s_audio/speaker/i2s_audio_speaker.h | 21 ++- esphome/core/defines.h | 1 + tests/components/micro_wake_word/common.yaml | 1 + .../components/microphone/test.esp32-idf.yaml | 11 +- 12 files changed, 552 insertions(+), 44 deletions(-) diff --git a/esphome/components/i2s_audio/__init__.py b/esphome/components/i2s_audio/__init__.py index fa515a585f..291ae4ba95 100644 --- a/esphome/components/i2s_audio/__init__.py +++ b/esphome/components/i2s_audio/__init__.py @@ -8,7 +8,15 @@ from esphome.components.esp32.const import ( VARIANT_ESP32S3, ) import esphome.config_validation as cv -from esphome.const import CONF_BITS_PER_SAMPLE, CONF_CHANNEL, CONF_ID, CONF_SAMPLE_RATE +from esphome.const import ( + CONF_BITS_PER_SAMPLE, + CONF_CHANNEL, + CONF_ID, + CONF_SAMPLE_RATE, + KEY_CORE, + KEY_FRAMEWORK_VERSION, +) +from esphome.core import CORE from esphome.cpp_generator import MockObjClass import esphome.final_validate as fv @@ -35,6 +43,9 @@ CONF_MONO = "mono" CONF_LEFT = "left" CONF_RIGHT = "right" CONF_STEREO = "stereo" +CONF_BOTH = "both" + +CONF_USE_LEGACY = "use_legacy" i2s_audio_ns = cg.esphome_ns.namespace("i2s_audio") I2SAudioComponent = i2s_audio_ns.class_("I2SAudioComponent", cg.Component) @@ -50,6 +61,12 @@ I2S_MODE_OPTIONS = { CONF_SECONDARY: i2s_mode_t.I2S_MODE_SLAVE, # NOLINT } +i2s_role_t = cg.global_ns.enum("i2s_role_t") +I2S_ROLE_OPTIONS = { + CONF_PRIMARY: i2s_role_t.I2S_ROLE_MASTER, # NOLINT + CONF_SECONDARY: i2s_role_t.I2S_ROLE_SLAVE, # NOLINT +} + # https://github.com/espressif/esp-idf/blob/master/components/soc/{variant}/include/soc/soc_caps.h I2S_PORTS = { VARIANT_ESP32: 2, @@ -60,10 +77,23 @@ I2S_PORTS = { i2s_channel_fmt_t = cg.global_ns.enum("i2s_channel_fmt_t") I2S_CHANNELS = { - CONF_MONO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ALL_LEFT, - CONF_LEFT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_LEFT, - CONF_RIGHT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_RIGHT, - CONF_STEREO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_RIGHT_LEFT, + CONF_MONO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ALL_LEFT, # left data to both channels + CONF_LEFT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_LEFT, # mono data + CONF_RIGHT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_RIGHT, # mono data + CONF_STEREO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_RIGHT_LEFT, # stereo data to both channels +} + +i2s_slot_mode_t = cg.global_ns.enum("i2s_slot_mode_t") +I2S_SLOT_MODE = { + CONF_MONO: i2s_slot_mode_t.I2S_SLOT_MODE_MONO, + CONF_STEREO: i2s_slot_mode_t.I2S_SLOT_MODE_STEREO, +} + +i2s_std_slot_mask_t = cg.global_ns.enum("i2s_std_slot_mask_t") +I2S_STD_SLOT_MASK = { + CONF_LEFT: i2s_std_slot_mask_t.I2S_STD_SLOT_LEFT, + CONF_RIGHT: i2s_std_slot_mask_t.I2S_STD_SLOT_RIGHT, + CONF_BOTH: i2s_std_slot_mask_t.I2S_STD_SLOT_BOTH, } i2s_bits_per_sample_t = cg.global_ns.enum("i2s_bits_per_sample_t") @@ -83,8 +113,19 @@ I2S_BITS_PER_CHANNEL = { 32: i2s_bits_per_chan_t.I2S_BITS_PER_CHAN_32BIT, } +i2s_slot_bit_width_t = cg.global_ns.enum("i2s_slot_bit_width_t") +I2S_SLOT_BIT_WIDTH = { + "default": i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_AUTO, + 8: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_8BIT, + 16: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_16BIT, + 24: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_24BIT, + 32: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_32BIT, +} + _validate_bits = cv.float_with_unit("bits", "bit") +_use_legacy_driver = None + def i2s_audio_component_schema( class_: MockObjClass, @@ -97,20 +138,22 @@ def i2s_audio_component_schema( { cv.GenerateID(): cv.declare_id(class_), cv.GenerateID(CONF_I2S_AUDIO_ID): cv.use_id(I2SAudioComponent), - cv.Optional(CONF_CHANNEL, default=default_channel): cv.enum(I2S_CHANNELS), + cv.Optional(CONF_CHANNEL, default=default_channel): cv.one_of( + *I2S_CHANNELS + ), cv.Optional(CONF_SAMPLE_RATE, default=default_sample_rate): cv.int_range( min=1 ), cv.Optional(CONF_BITS_PER_SAMPLE, default=default_bits_per_sample): cv.All( - _validate_bits, cv.enum(I2S_BITS_PER_SAMPLE) + _validate_bits, cv.one_of(*I2S_BITS_PER_SAMPLE) ), - cv.Optional(CONF_I2S_MODE, default=CONF_PRIMARY): cv.enum( - I2S_MODE_OPTIONS, lower=True + cv.Optional(CONF_I2S_MODE, default=CONF_PRIMARY): cv.one_of( + *I2S_MODE_OPTIONS, lower=True ), cv.Optional(CONF_USE_APLL, default=False): cv.boolean, cv.Optional(CONF_BITS_PER_CHANNEL, default="default"): cv.All( cv.Any(cv.float_with_unit("bits", "bit"), "default"), - cv.enum(I2S_BITS_PER_CHANNEL), + cv.one_of(*I2S_BITS_PER_CHANNEL), ), } ) @@ -118,22 +161,60 @@ def i2s_audio_component_schema( async def register_i2s_audio_component(var, config): await cg.register_parented(var, config[CONF_I2S_AUDIO_ID]) - - cg.add(var.set_i2s_mode(config[CONF_I2S_MODE])) - cg.add(var.set_channel(config[CONF_CHANNEL])) + if use_legacy(): + cg.add(var.set_i2s_mode(I2S_MODE_OPTIONS[config[CONF_I2S_MODE]])) + cg.add(var.set_channel(I2S_CHANNELS[config[CONF_CHANNEL]])) + cg.add( + var.set_bits_per_sample(I2S_BITS_PER_SAMPLE[config[CONF_BITS_PER_SAMPLE]]) + ) + cg.add( + var.set_bits_per_channel( + I2S_BITS_PER_CHANNEL[config[CONF_BITS_PER_CHANNEL]] + ) + ) + else: + cg.add(var.set_i2s_role(I2S_ROLE_OPTIONS[config[CONF_I2S_MODE]])) + slot_mode = config[CONF_CHANNEL] + if slot_mode != CONF_STEREO: + slot_mode = CONF_MONO + slot_mask = config[CONF_CHANNEL] + if slot_mask not in [CONF_LEFT, CONF_RIGHT]: + slot_mask = CONF_BOTH + cg.add(var.set_slot_mode(I2S_SLOT_MODE[slot_mode])) + cg.add(var.set_std_slot_mask(I2S_STD_SLOT_MASK[slot_mask])) + cg.add( + var.set_slot_bit_width(I2S_SLOT_BIT_WIDTH[config[CONF_BITS_PER_CHANNEL]]) + ) cg.add(var.set_sample_rate(config[CONF_SAMPLE_RATE])) - cg.add(var.set_bits_per_sample(config[CONF_BITS_PER_SAMPLE])) - cg.add(var.set_bits_per_channel(config[CONF_BITS_PER_CHANNEL])) cg.add(var.set_use_apll(config[CONF_USE_APLL])) -CONFIG_SCHEMA = cv.Schema( - { - cv.GenerateID(): cv.declare_id(I2SAudioComponent), - cv.Required(CONF_I2S_LRCLK_PIN): pins.internal_gpio_output_pin_number, - cv.Optional(CONF_I2S_BCLK_PIN): pins.internal_gpio_output_pin_number, - cv.Optional(CONF_I2S_MCLK_PIN): pins.internal_gpio_output_pin_number, - } +def validate_use_legacy(value): + global _use_legacy_driver # noqa: PLW0603 + if CONF_USE_LEGACY in value: + if (_use_legacy_driver is not None) and ( + _use_legacy_driver != value[CONF_USE_LEGACY] + ): + raise cv.Invalid( + f"All i2s_audio components must set {CONF_USE_LEGACY} to the same value." + ) + if (not value[CONF_USE_LEGACY]) and (CORE.using_arduino): + raise cv.Invalid("Arduino supports only the legacy i2s driver.") + _use_legacy_driver = value[CONF_USE_LEGACY] + return value + + +CONFIG_SCHEMA = cv.All( + cv.Schema( + { + cv.GenerateID(): cv.declare_id(I2SAudioComponent), + cv.Required(CONF_I2S_LRCLK_PIN): pins.internal_gpio_output_pin_number, + cv.Optional(CONF_I2S_BCLK_PIN): pins.internal_gpio_output_pin_number, + cv.Optional(CONF_I2S_MCLK_PIN): pins.internal_gpio_output_pin_number, + cv.Optional(CONF_USE_LEGACY): cv.boolean, + }, + ), + validate_use_legacy, ) @@ -148,12 +229,22 @@ def _final_validate(_): ) +def use_legacy(): + framework_version = CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] + if CORE.using_esp_idf and framework_version >= cv.Version(5, 0, 0): + if not _use_legacy_driver: + return False + return True + + FINAL_VALIDATE_SCHEMA = _final_validate async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) + if use_legacy(): + cg.add_define("USE_I2S_LEGACY") cg.add(var.set_lrclk_pin(config[CONF_I2S_LRCLK_PIN])) if CONF_I2S_BCLK_PIN in config: diff --git a/esphome/components/i2s_audio/i2s_audio.h b/esphome/components/i2s_audio/i2s_audio.h index 7e2798c33d..d8050665e9 100644 --- a/esphome/components/i2s_audio/i2s_audio.h +++ b/esphome/components/i2s_audio/i2s_audio.h @@ -2,9 +2,14 @@ #ifdef USE_ESP32 -#include #include "esphome/core/component.h" #include "esphome/core/helpers.h" +#include "esphome/core/defines.h" +#ifdef USE_I2S_LEGACY +#include +#else +#include +#endif namespace esphome { namespace i2s_audio { @@ -13,19 +18,33 @@ class I2SAudioComponent; class I2SAudioBase : public Parented { public: +#ifdef USE_I2S_LEGACY void set_i2s_mode(i2s_mode_t mode) { this->i2s_mode_ = mode; } void set_channel(i2s_channel_fmt_t channel) { this->channel_ = channel; } - void set_sample_rate(uint32_t sample_rate) { this->sample_rate_ = sample_rate; } void set_bits_per_sample(i2s_bits_per_sample_t bits_per_sample) { this->bits_per_sample_ = bits_per_sample; } void set_bits_per_channel(i2s_bits_per_chan_t bits_per_channel) { this->bits_per_channel_ = bits_per_channel; } +#else + void set_i2s_role(i2s_role_t role) { this->i2s_role_ = role; } + void set_slot_mode(i2s_slot_mode_t slot_mode) { this->slot_mode_ = slot_mode; } + void set_std_slot_mask(i2s_std_slot_mask_t std_slot_mask) { this->std_slot_mask_ = std_slot_mask; } + void set_slot_bit_width(i2s_slot_bit_width_t slot_bit_width) { this->slot_bit_width_ = slot_bit_width; } +#endif + void set_sample_rate(uint32_t sample_rate) { this->sample_rate_ = sample_rate; } void set_use_apll(uint32_t use_apll) { this->use_apll_ = use_apll; } protected: +#ifdef USE_I2S_LEGACY i2s_mode_t i2s_mode_{}; i2s_channel_fmt_t channel_; - uint32_t sample_rate_; i2s_bits_per_sample_t bits_per_sample_; i2s_bits_per_chan_t bits_per_channel_; +#else + i2s_role_t i2s_role_{}; + i2s_slot_mode_t slot_mode_; + i2s_std_slot_mask_t std_slot_mask_; + i2s_slot_bit_width_t slot_bit_width_; +#endif + uint32_t sample_rate_; bool use_apll_; }; @@ -37,6 +56,7 @@ class I2SAudioComponent : public Component { public: void setup() override; +#ifdef USE_I2S_LEGACY i2s_pin_config_t get_pin_config() const { return { .mck_io_num = this->mclk_pin_, @@ -46,6 +66,20 @@ class I2SAudioComponent : public Component { .data_in_num = I2S_PIN_NO_CHANGE, }; } +#else + i2s_std_gpio_config_t get_pin_config() const { + return {.mclk = (gpio_num_t) this->mclk_pin_, + .bclk = (gpio_num_t) this->bclk_pin_, + .ws = (gpio_num_t) this->lrclk_pin_, + .dout = I2S_GPIO_UNUSED, // add local ports + .din = I2S_GPIO_UNUSED, + .invert_flags = { + .mclk_inv = false, + .bclk_inv = false, + .ws_inv = false, + }}; + } +#endif void set_mclk_pin(int pin) { this->mclk_pin_ = pin; } void set_bclk_pin(int pin) { this->bclk_pin_ = pin; } @@ -62,9 +96,13 @@ class I2SAudioComponent : public Component { I2SAudioIn *audio_in_{nullptr}; I2SAudioOut *audio_out_{nullptr}; - +#ifdef USE_I2S_LEGACY int mclk_pin_{I2S_PIN_NO_CHANGE}; int bclk_pin_{I2S_PIN_NO_CHANGE}; +#else + int mclk_pin_{I2S_GPIO_UNUSED}; + int bclk_pin_{I2S_GPIO_UNUSED}; +#endif int lrclk_pin_; i2s_port_t port_{}; }; diff --git a/esphome/components/i2s_audio/media_player/__init__.py b/esphome/components/i2s_audio/media_player/__init__.py index 2882729b1e..bed25b011f 100644 --- a/esphome/components/i2s_audio/media_player/__init__.py +++ b/esphome/components/i2s_audio/media_player/__init__.py @@ -14,6 +14,7 @@ from .. import ( I2SAudioComponent, I2SAudioOut, i2s_audio_ns, + use_legacy, ) CODEOWNERS = ["@jesserockz"] @@ -87,6 +88,14 @@ CONFIG_SCHEMA = cv.All( ) +def _final_validate(_): + if not use_legacy(): + raise cv.Invalid("I2S media player is only compatible with legacy i2s driver.") + + +FINAL_VALIDATE_SCHEMA = _final_validate + + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) diff --git a/esphome/components/i2s_audio/microphone/__init__.py b/esphome/components/i2s_audio/microphone/__init__.py index 161046e962..4950a25751 100644 --- a/esphome/components/i2s_audio/microphone/__init__.py +++ b/esphome/components/i2s_audio/microphone/__init__.py @@ -6,12 +6,15 @@ import esphome.config_validation as cv from esphome.const import CONF_ID, CONF_NUMBER from .. import ( + CONF_CHANNEL, CONF_I2S_DIN_PIN, + CONF_MONO, CONF_RIGHT, I2SAudioIn, i2s_audio_component_schema, i2s_audio_ns, register_i2s_audio_component, + use_legacy, ) CODEOWNERS = ["@jesserockz"] @@ -43,6 +46,12 @@ def validate_esp32_variant(config): raise NotImplementedError +def validate_channel(config): + if config[CONF_CHANNEL] == CONF_MONO: + raise cv.Invalid(f"I2S microphone does not support {CONF_MONO}.") + return config + + BASE_SCHEMA = microphone.MICROPHONE_SCHEMA.extend( i2s_audio_component_schema( I2SAudioMicrophone, @@ -71,9 +80,19 @@ CONFIG_SCHEMA = cv.All( key=CONF_ADC_TYPE, ), validate_esp32_variant, + validate_channel, ) +def _final_validate(config): + if not use_legacy(): + if config[CONF_ADC_TYPE] == "internal": + raise cv.Invalid("Internal ADC is only compatible with legacy i2s driver.") + + +FINAL_VALIDATE_SCHEMA = _final_validate + + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp index 4dbc9dcdac..ef375954cd 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp @@ -2,7 +2,12 @@ #ifdef USE_ESP32 +#ifdef USE_I2S_LEGACY #include +#else +#include +#include +#endif #include "esphome/core/hal.h" #include "esphome/core/log.h" @@ -16,6 +21,7 @@ static const char *const TAG = "i2s_audio.microphone"; void I2SAudioMicrophone::setup() { ESP_LOGCONFIG(TAG, "Setting up I2S Audio Microphone..."); +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC if (this->adc_) { if (this->parent_->get_port() != I2S_NUM_0) { @@ -24,6 +30,7 @@ void I2SAudioMicrophone::setup() { return; } } else +#endif #endif { if (this->pdm_) { @@ -47,6 +54,9 @@ void I2SAudioMicrophone::start_() { if (!this->parent_->try_lock()) { return; // Waiting for another i2s to return lock } + esp_err_t err; + +#ifdef USE_I2S_LEGACY i2s_driver_config_t config = { .mode = (i2s_mode_t) (this->i2s_mode_ | I2S_MODE_RX), .sample_rate = this->sample_rate_, @@ -63,8 +73,6 @@ void I2SAudioMicrophone::start_() { .bits_per_chan = this->bits_per_channel_, }; - esp_err_t err; - #if SOC_I2S_SUPPORTS_ADC if (this->adc_) { config.mode = (i2s_mode_t) (config.mode | I2S_MODE_ADC_BUILT_IN); @@ -111,6 +119,109 @@ void I2SAudioMicrophone::start_() { return; } } +#else + i2s_chan_config_t chan_cfg = { + .id = this->parent_->get_port(), + .role = this->i2s_role_, + .dma_desc_num = 4, + .dma_frame_num = 256, + .auto_clear = false, + }; + /* Allocate a new RX channel and get the handle of this channel */ + err = i2s_new_channel(&chan_cfg, NULL, &this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error creating new I2S channel: %s", esp_err_to_name(err)); + this->status_set_error(); + return; + } + + i2s_clock_src_t clk_src = I2S_CLK_SRC_DEFAULT; +#ifdef I2S_CLK_SRC_APLL + if (this->use_apll_) { + clk_src = I2S_CLK_SRC_APLL; + } +#endif + i2s_std_gpio_config_t pin_config = this->parent_->get_pin_config(); +#if SOC_I2S_SUPPORTS_PDM_RX + if (this->pdm_) { + i2s_pdm_rx_clk_config_t clk_cfg = { + .sample_rate_hz = this->sample_rate_, + .clk_src = clk_src, + .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .dn_sample_mode = I2S_PDM_DSR_8S, + }; + + i2s_pdm_rx_slot_config_t slot_cfg = I2S_PDM_RX_SLOT_DEFAULT_CONFIG(I2S_DATA_BIT_WIDTH_16BIT, this->slot_mode_); + switch (this->std_slot_mask_) { + case I2S_STD_SLOT_LEFT: + slot_cfg.slot_mask = I2S_PDM_SLOT_LEFT; + break; + case I2S_STD_SLOT_RIGHT: + slot_cfg.slot_mask = I2S_PDM_SLOT_RIGHT; + break; + case I2S_STD_SLOT_BOTH: + slot_cfg.slot_mask = I2S_PDM_SLOT_BOTH; + break; + } + + /* Init the channel into PDM RX mode */ + i2s_pdm_rx_config_t pdm_rx_cfg = { + .clk_cfg = clk_cfg, + .slot_cfg = slot_cfg, + .gpio_cfg = + { + .clk = pin_config.ws, + .din = this->din_pin_, + .invert_flags = + { + .clk_inv = pin_config.invert_flags.ws_inv, + }, + }, + }; + err = i2s_channel_init_pdm_rx_mode(this->rx_handle_, &pdm_rx_cfg); + } else +#endif + { + i2s_std_clk_config_t clk_cfg = { + .sample_rate_hz = this->sample_rate_, + .clk_src = clk_src, + .mclk_multiple = I2S_MCLK_MULTIPLE_256, + }; + i2s_data_bit_width_t data_bit_width; + if (this->slot_bit_width_ != I2S_SLOT_BIT_WIDTH_8BIT) { + data_bit_width = I2S_DATA_BIT_WIDTH_16BIT; + } else { + data_bit_width = I2S_DATA_BIT_WIDTH_8BIT; + } + i2s_std_slot_config_t std_slot_cfg = I2S_STD_PHILIPS_SLOT_DEFAULT_CONFIG(data_bit_width, this->slot_mode_); + std_slot_cfg.slot_bit_width = this->slot_bit_width_; + std_slot_cfg.slot_mask = this->std_slot_mask_; + + pin_config.din = this->din_pin_; + + i2s_std_config_t std_cfg = { + .clk_cfg = clk_cfg, + .slot_cfg = std_slot_cfg, + .gpio_cfg = pin_config, + }; + /* Initialize the channel */ + err = i2s_channel_init_std_mode(this->rx_handle_, &std_cfg); + } + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error initializing I2S channel: %s", esp_err_to_name(err)); + this->status_set_error(); + return; + } + + /* Before reading data, start the RX channel first */ + i2s_channel_enable(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error enabling I2S Microphone: %s", esp_err_to_name(err)); + this->status_set_error(); + return; + } +#endif + this->state_ = microphone::STATE_RUNNING; this->high_freq_.start(); this->status_clear_error(); @@ -128,6 +239,7 @@ void I2SAudioMicrophone::stop() { void I2SAudioMicrophone::stop_() { esp_err_t err; +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC if (this->adc_) { err = i2s_adc_disable(this->parent_->get_port()); @@ -150,6 +262,22 @@ void I2SAudioMicrophone::stop_() { this->status_set_error(); return; } +#else + /* Have to stop the channel before deleting it */ + err = i2s_channel_disable(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error stopping I2S microphone: %s", esp_err_to_name(err)); + this->status_set_error(); + return; + } + /* If the handle is not needed any more, delete it to release the channel resources */ + err = i2s_del_channel(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error deleting I2S channel: %s", esp_err_to_name(err)); + this->status_set_error(); + return; + } +#endif this->parent_->unlock(); this->state_ = microphone::STATE_STOPPED; this->high_freq_.stop(); @@ -158,7 +286,11 @@ void I2SAudioMicrophone::stop_() { size_t I2SAudioMicrophone::read(int16_t *buf, size_t len) { size_t bytes_read = 0; +#ifdef USE_I2S_LEGACY esp_err_t err = i2s_read(this->parent_->get_port(), buf, len, &bytes_read, (100 / portTICK_PERIOD_MS)); +#else + esp_err_t err = i2s_channel_read(this->rx_handle_, buf, len, &bytes_read, (100 / portTICK_PERIOD_MS)); +#endif if (err != ESP_OK) { ESP_LOGW(TAG, "Error reading from I2S microphone: %s", esp_err_to_name(err)); this->status_set_warning(); @@ -171,6 +303,7 @@ size_t I2SAudioMicrophone::read(int16_t *buf, size_t len) { this->status_clear_warning(); // ESP-IDF I2S implementation right-extends 8-bit data to 16 bits, // and 24-bit data to 32 bits. +#ifdef USE_I2S_LEGACY switch (this->bits_per_sample_) { case I2S_BITS_PER_SAMPLE_8BIT: case I2S_BITS_PER_SAMPLE_16BIT: @@ -188,6 +321,30 @@ size_t I2SAudioMicrophone::read(int16_t *buf, size_t len) { ESP_LOGE(TAG, "Unsupported bits per sample: %d", this->bits_per_sample_); return 0; } +#else +#ifndef USE_ESP32_VARIANT_ESP32 + // For newer ESP32 variants 8 bit data needs to be extended to 16 bit. + if (this->slot_bit_width_ == I2S_SLOT_BIT_WIDTH_8BIT) { + size_t samples_read = bytes_read / sizeof(int8_t); + for (size_t i = samples_read - 1; i >= 0; i--) { + int16_t temp = static_cast(reinterpret_cast(buf)[i]) << 8; + buf[i] = temp; + } + return samples_read * sizeof(int16_t); + } +#else + // For ESP32 8/16 bit standard mono mode samples need to be switched. + if (this->slot_mode_ == I2S_SLOT_MODE_MONO && this->slot_bit_width_ <= 16 && !this->pdm_) { + size_t samples_read = bytes_read / sizeof(int16_t); + for (int i = 0; i < samples_read; i += 2) { + int16_t tmp = buf[i]; + buf[i] = buf[i + 1]; + buf[i + 1] = tmp; + } + } +#endif + return bytes_read; +#endif } void I2SAudioMicrophone::read_() { diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h index ea3f357624..2ff46fabab 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h @@ -17,17 +17,23 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub void stop() override; void loop() override; - +#ifdef USE_I2S_LEGACY void set_din_pin(int8_t pin) { this->din_pin_ = pin; } +#else + void set_din_pin(int8_t pin) { this->din_pin_ = (gpio_num_t) pin; } +#endif + void set_pdm(bool pdm) { this->pdm_ = pdm; } size_t read(int16_t *buf, size_t len) override; +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC void set_adc_channel(adc1_channel_t channel) { this->adc_channel_ = channel; this->adc_ = true; } +#endif #endif protected: @@ -35,10 +41,15 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub void stop_(); void read_(); +#ifdef USE_I2S_LEGACY int8_t din_pin_{I2S_PIN_NO_CHANGE}; #if SOC_I2S_SUPPORTS_ADC adc1_channel_t adc_channel_{ADC1_CHANNEL_MAX}; bool adc_{false}; +#endif +#else + gpio_num_t din_pin_{I2S_GPIO_UNUSED}; + i2s_chan_handle_t rx_handle_; #endif bool pdm_{false}; diff --git a/esphome/components/i2s_audio/speaker/__init__.py b/esphome/components/i2s_audio/speaker/__init__.py index aa3b50d336..7e41cd3991 100644 --- a/esphome/components/i2s_audio/speaker/__init__.py +++ b/esphome/components/i2s_audio/speaker/__init__.py @@ -26,6 +26,7 @@ from .. import ( i2s_audio_component_schema, i2s_audio_ns, register_i2s_audio_component, + use_legacy, ) AUTO_LOAD = ["audio"] @@ -60,7 +61,7 @@ I2C_COMM_FMT_OPTIONS = { "pcm_long": i2s_comm_format_t.I2S_COMM_FORMAT_PCM_LONG, } -NO_INTERNAL_DAC_VARIANTS = [esp32.const.VARIANT_ESP32S2] +INTERNAL_DAC_VARIANTS = [esp32.const.VARIANT_ESP32] def _set_num_channels_from_config(config): @@ -101,7 +102,7 @@ def _validate_esp32_variant(config): if config[CONF_DAC_TYPE] != "internal": return config variant = esp32.get_esp32_variant() - if variant in NO_INTERNAL_DAC_VARIANTS: + if variant not in INTERNAL_DAC_VARIANTS: raise cv.Invalid(f"{variant} does not have an internal DAC") return config @@ -143,8 +144,8 @@ CONFIG_SCHEMA = cv.All( cv.Required( CONF_I2S_DOUT_PIN ): pins.internal_gpio_output_pin_number, - cv.Optional(CONF_I2S_COMM_FMT, default="stand_i2s"): cv.enum( - I2C_COMM_FMT_OPTIONS, lower=True + cv.Optional(CONF_I2S_COMM_FMT, default="stand_i2s"): cv.one_of( + *I2C_COMM_FMT_OPTIONS, lower=True ), } ), @@ -157,6 +158,19 @@ CONFIG_SCHEMA = cv.All( ) +def _final_validate(config): + if not use_legacy(): + if config[CONF_DAC_TYPE] == "internal": + raise cv.Invalid("Internal DAC is only compatible with legacy i2s driver.") + if config[CONF_I2S_COMM_FMT] == "stand_max": + raise cv.Invalid( + "I2S standard max format only implemented with legacy i2s driver." + ) + + +FINAL_VALIDATE_SCHEMA = _final_validate + + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) @@ -167,7 +181,17 @@ async def to_code(config): cg.add(var.set_internal_dac_mode(config[CONF_CHANNEL])) else: cg.add(var.set_dout_pin(config[CONF_I2S_DOUT_PIN])) - cg.add(var.set_i2s_comm_fmt(config[CONF_I2S_COMM_FMT])) + if use_legacy(): + cg.add( + var.set_i2s_comm_fmt(I2C_COMM_FMT_OPTIONS[config[CONF_I2S_COMM_FMT]]) + ) + else: + fmt = "std" # equals stand_i2s, stand_pcm_long, i2s_msb, pcm_long + if config[CONF_I2S_COMM_FMT] in ["stand_msb", "i2s_lsb"]: + fmt = "msb" + elif config[CONF_I2S_COMM_FMT] in ["stand_pcm_short", "pcm_short", "pcm"]: + fmt = "pcm" + cg.add(var.set_i2s_comm_fmt(fmt)) if config[CONF_TIMEOUT] != CONF_NEVER: cg.add(var.set_timeout(config[CONF_TIMEOUT])) cg.add(var.set_buffer_duration(config[CONF_BUFFER_DURATION])) diff --git a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp index da25914c87..cb3bbc8cf2 100644 --- a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp +++ b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp @@ -2,7 +2,11 @@ #ifdef USE_ESP32 +#ifdef USE_I2S_LEGACY #include +#else +#include +#endif #include "esphome/components/audio/audio.h" @@ -294,13 +298,21 @@ void I2SAudioSpeaker::speaker_task(void *params) { // Audio stream info changed, stop the speaker task so it will restart with the proper settings. break; } - +#ifdef USE_I2S_LEGACY i2s_event_t i2s_event; while (xQueueReceive(this_speaker->i2s_event_queue_, &i2s_event, 0)) { if (i2s_event.type == I2S_EVENT_TX_Q_OVF) { tx_dma_underflow = true; } } +#else + bool overflow; + while (xQueueReceive(this_speaker->i2s_event_queue_, &overflow, 0)) { + if (overflow) { + tx_dma_underflow = true; + } + } +#endif if (this_speaker->pause_state_) { // Pause state is accessed atomically, so thread safe @@ -319,6 +331,18 @@ void I2SAudioSpeaker::speaker_task(void *params) { bytes_read / sizeof(int16_t), this_speaker->q15_volume_factor_); } +#ifdef USE_ESP32_VARIANT_ESP32 + // For ESP32 8/16 bit mono mode samples need to be switched. + if (audio_stream_info.get_channels() == 1 && audio_stream_info.get_bits_per_sample() <= 16) { + size_t len = bytes_read / sizeof(int16_t); + int16_t *tmp_buf = (int16_t *) this_speaker->data_buffer_; + for (int i = 0; i < len; i += 2) { + int16_t tmp = tmp_buf[i]; + tmp_buf[i] = tmp_buf[i + 1]; + tmp_buf[i + 1] = tmp; + } + } +#endif // Write the audio data to a single DMA buffer at a time to reduce latency for the audio duration played // callback. const uint32_t batches = (bytes_read + single_dma_buffer_input_size - 1) / single_dma_buffer_input_size; @@ -327,6 +351,7 @@ void I2SAudioSpeaker::speaker_task(void *params) { size_t bytes_written = 0; size_t bytes_to_write = std::min(single_dma_buffer_input_size, bytes_read); +#ifdef USE_I2S_LEGACY if (audio_stream_info.get_bits_per_sample() == (uint8_t) this_speaker->bits_per_sample_) { i2s_write(this_speaker->parent_->get_port(), this_speaker->data_buffer_ + i * single_dma_buffer_input_size, bytes_to_write, &bytes_written, pdMS_TO_TICKS(DMA_BUFFER_DURATION_MS * 5)); @@ -336,6 +361,10 @@ void I2SAudioSpeaker::speaker_task(void *params) { audio_stream_info.get_bits_per_sample(), this_speaker->bits_per_sample_, &bytes_written, pdMS_TO_TICKS(DMA_BUFFER_DURATION_MS * 5)); } +#else + i2s_channel_write(this_speaker->tx_handle_, this_speaker->data_buffer_ + i * single_dma_buffer_input_size, + bytes_to_write, &bytes_written, pdMS_TO_TICKS(DMA_BUFFER_DURATION_MS * 5)); +#endif uint32_t write_timestamp = micros(); @@ -369,8 +398,12 @@ void I2SAudioSpeaker::speaker_task(void *params) { } xEventGroupSetBits(this_speaker->event_group_, SpeakerEventGroupBits::STATE_STOPPING); - +#ifdef USE_I2S_LEGACY i2s_driver_uninstall(this_speaker->parent_->get_port()); +#else + i2s_channel_disable(this_speaker->tx_handle_); + i2s_del_channel(this_speaker->tx_handle_); +#endif this_speaker->parent_->unlock(); } @@ -462,12 +495,21 @@ esp_err_t I2SAudioSpeaker::allocate_buffers_(size_t data_buffer_size, size_t rin } esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_stream_info) { +#ifdef USE_I2S_LEGACY if ((this->i2s_mode_ & I2S_MODE_SLAVE) && (this->sample_rate_ != audio_stream_info.get_sample_rate())) { // NOLINT +#else + if ((this->i2s_role_ & I2S_ROLE_SLAVE) && (this->sample_rate_ != audio_stream_info.get_sample_rate())) { // NOLINT +#endif // Can't reconfigure I2S bus, so the sample rate must match the configured value return ESP_ERR_NOT_SUPPORTED; } +#ifdef USE_I2S_LEGACY if ((i2s_bits_per_sample_t) audio_stream_info.get_bits_per_sample() > this->bits_per_sample_) { +#else + if (this->slot_bit_width_ != I2S_SLOT_BIT_WIDTH_AUTO && + (i2s_slot_bit_width_t) audio_stream_info.get_bits_per_sample() > this->slot_bit_width_) { +#endif // Currently can't handle the case when the incoming audio has more bits per sample than the configured value return ESP_ERR_NOT_SUPPORTED; } @@ -476,6 +518,9 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea return ESP_ERR_INVALID_STATE; } + uint32_t dma_buffer_length = audio_stream_info.ms_to_frames(DMA_BUFFER_DURATION_MS); + +#ifdef USE_I2S_LEGACY i2s_channel_fmt_t channel = this->channel_; if (audio_stream_info.get_channels() == 1) { @@ -488,8 +533,6 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea channel = I2S_CHANNEL_FMT_RIGHT_LEFT; } - int dma_buffer_length = audio_stream_info.ms_to_frames(DMA_BUFFER_DURATION_MS); - i2s_driver_config_t config = { .mode = (i2s_mode_t) (this->i2s_mode_ | I2S_MODE_TX), .sample_rate = audio_stream_info.get_sample_rate(), @@ -498,7 +541,7 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea .communication_format = this->i2s_comm_fmt_, .intr_alloc_flags = ESP_INTR_FLAG_LEVEL1, .dma_buf_count = DMA_BUFFERS_COUNT, - .dma_buf_len = dma_buffer_length, + .dma_buf_len = (int) dma_buffer_length, .use_apll = this->use_apll_, .tx_desc_auto_clear = true, .fixed_mclk = I2S_PIN_NO_CHANGE, @@ -545,6 +588,89 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea i2s_driver_uninstall(this->parent_->get_port()); this->parent_->unlock(); } +#else + i2s_chan_config_t chan_cfg = { + .id = this->parent_->get_port(), + .role = this->i2s_role_, + .dma_desc_num = DMA_BUFFERS_COUNT, + .dma_frame_num = dma_buffer_length, + .auto_clear = true, + }; + /* Allocate a new TX channel and get the handle of this channel */ + esp_err_t err = i2s_new_channel(&chan_cfg, &this->tx_handle_, NULL); + if (err != ESP_OK) { + this->parent_->unlock(); + return err; + } + + i2s_clock_src_t clk_src = I2S_CLK_SRC_DEFAULT; +#ifdef I2S_CLK_SRC_APLL + if (this->use_apll_) { + clk_src = I2S_CLK_SRC_APLL; + } +#endif + i2s_std_gpio_config_t pin_config = this->parent_->get_pin_config(); + + i2s_std_clk_config_t clk_cfg = { + .sample_rate_hz = audio_stream_info.get_sample_rate(), + .clk_src = clk_src, + .mclk_multiple = I2S_MCLK_MULTIPLE_256, + }; + + i2s_slot_mode_t slot_mode = this->slot_mode_; + i2s_std_slot_mask_t slot_mask = this->std_slot_mask_; + if (audio_stream_info.get_channels() == 1) { + slot_mode = I2S_SLOT_MODE_MONO; + } else if (audio_stream_info.get_channels() == 2) { + slot_mode = I2S_SLOT_MODE_STEREO; + slot_mask = I2S_STD_SLOT_BOTH; + } + + i2s_std_slot_config_t std_slot_cfg; + if (this->i2s_comm_fmt_ == "std") { + std_slot_cfg = + I2S_STD_PHILIPS_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) audio_stream_info.get_bits_per_sample(), slot_mode); + } else if (this->i2s_comm_fmt_ == "pcm") { + std_slot_cfg = + I2S_STD_PCM_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) audio_stream_info.get_bits_per_sample(), slot_mode); + } else { + std_slot_cfg = + I2S_STD_MSB_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) audio_stream_info.get_bits_per_sample(), slot_mode); + } + std_slot_cfg.slot_bit_width = this->slot_bit_width_; + std_slot_cfg.slot_mask = slot_mask; + + pin_config.dout = this->dout_pin_; + + i2s_std_config_t std_cfg = { + .clk_cfg = clk_cfg, + .slot_cfg = std_slot_cfg, + .gpio_cfg = pin_config, + }; + /* Initialize the channel */ + err = i2s_channel_init_std_mode(this->tx_handle_, &std_cfg); + + if (err != ESP_OK) { + i2s_del_channel(this->tx_handle_); + this->parent_->unlock(); + return err; + } + if (this->i2s_event_queue_ == nullptr) { + this->i2s_event_queue_ = xQueueCreate(1, sizeof(bool)); + } + const i2s_event_callbacks_t callbacks = { + .on_send_q_ovf = i2s_overflow_cb, + }; + + i2s_channel_register_event_callback(this->tx_handle_, &callbacks, this); + + /* Before reading data, start the TX channel first */ + i2s_channel_enable(this->tx_handle_); + if (err != ESP_OK) { + i2s_del_channel(this->tx_handle_); + this->parent_->unlock(); + } +#endif return err; } @@ -564,6 +690,15 @@ void I2SAudioSpeaker::delete_task_(size_t buffer_size) { vTaskDelete(nullptr); } +#ifndef USE_I2S_LEGACY +bool IRAM_ATTR I2SAudioSpeaker::i2s_overflow_cb(i2s_chan_handle_t handle, i2s_event_data_t *event, void *user_ctx) { + I2SAudioSpeaker *this_speaker = (I2SAudioSpeaker *) user_ctx; + bool overflow = true; + xQueueOverwrite(this_speaker->i2s_event_queue_, &overflow); + return false; +} +#endif + } // namespace i2s_audio } // namespace esphome diff --git a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h index 7b14a57aac..b5e4b94bc4 100644 --- a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h +++ b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h @@ -4,8 +4,6 @@ #include "../i2s_audio.h" -#include - #include #include #include @@ -30,11 +28,16 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp void set_buffer_duration(uint32_t buffer_duration_ms) { this->buffer_duration_ms_ = buffer_duration_ms; } void set_timeout(uint32_t ms) { this->timeout_ = ms; } - void set_dout_pin(uint8_t pin) { this->dout_pin_ = pin; } +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_DAC void set_internal_dac_mode(i2s_dac_mode_t mode) { this->internal_dac_mode_ = mode; } #endif + void set_dout_pin(uint8_t pin) { this->dout_pin_ = pin; } void set_i2s_comm_fmt(i2s_comm_format_t mode) { this->i2s_comm_fmt_ = mode; } +#else + void set_dout_pin(uint8_t pin) { this->dout_pin_ = (gpio_num_t) pin; } + void set_i2s_comm_fmt(std::string mode) { this->i2s_comm_fmt_ = std::move(mode); } +#endif void start() override; void stop() override; @@ -86,6 +89,10 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp /// @return True if an ERR_ESP bit is set and false if err == ESP_OK bool send_esp_err_to_event_group_(esp_err_t err); +#ifndef USE_I2S_LEGACY + static bool i2s_overflow_cb(i2s_chan_handle_t handle, i2s_event_data_t *event, void *user_ctx); +#endif + /// @brief Allocates the data buffer and ring buffer /// @param data_buffer_size Number of bytes to allocate for the data buffer. /// @param ring_buffer_size Number of bytes to allocate for the ring buffer. @@ -121,7 +128,6 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp uint32_t buffer_duration_ms_; optional timeout_; - uint8_t dout_pin_; bool task_created_{false}; bool pause_state_{false}; @@ -130,10 +136,17 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp size_t bytes_written_{0}; +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_DAC i2s_dac_mode_t internal_dac_mode_{I2S_DAC_CHANNEL_DISABLE}; #endif + uint8_t dout_pin_; i2s_comm_format_t i2s_comm_fmt_; +#else + gpio_num_t dout_pin_; + std::string i2s_comm_fmt_; + i2s_chan_handle_t tx_handle_; +#endif uint32_t accumulated_frames_written_{0}; }; diff --git a/esphome/core/defines.h b/esphome/core/defines.h index d6c2bf25e6..81ff6999ba 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -115,6 +115,7 @@ #ifdef USE_ARDUINO #define USE_PROMETHEUS #define USE_WIFI_WPA2_EAP +#define USE_I2S_LEGACY #endif // IDF-specific feature flags diff --git a/tests/components/micro_wake_word/common.yaml b/tests/components/micro_wake_word/common.yaml index 8bd7345307..c5422baa67 100644 --- a/tests/components/micro_wake_word/common.yaml +++ b/tests/components/micro_wake_word/common.yaml @@ -8,6 +8,7 @@ microphone: i2s_din_pin: GPIO17 adc_type: external pdm: true + bits_per_sample: 16bit micro_wake_word: on_wake_word_detected: diff --git a/tests/components/microphone/test.esp32-idf.yaml b/tests/components/microphone/test.esp32-idf.yaml index 392df582cc..fe9feb9888 100644 --- a/tests/components/microphone/test.esp32-idf.yaml +++ b/tests/components/microphone/test.esp32-idf.yaml @@ -4,9 +4,18 @@ substitutions: i2s_mclk_pin: GPIO17 i2s_din_pin: GPIO33 -<<: !include common.yaml +i2s_audio: + i2s_bclk_pin: ${i2s_bclk_pin} + i2s_lrclk_pin: ${i2s_lrclk_pin} + i2s_mclk_pin: ${i2s_mclk_pin} + use_legacy: true microphone: + - platform: i2s_audio + id: mic_id_external + i2s_din_pin: ${i2s_din_pin} + adc_type: external + pdm: false - platform: i2s_audio id: mic_id_adc adc_pin: 32 From 666d5374ea59f12790a256196146b2d389a77d34 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:08:24 -1000 Subject: [PATCH 008/102] Bump actions/download-artifact from 4.2.1 to 4.3.0 (#8617) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7793c574fe..b36c1dd3f3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -176,7 +176,7 @@ jobs: - uses: actions/checkout@v4.1.7 - name: Download digests - uses: actions/download-artifact@v4.2.1 + uses: actions/download-artifact@v4.3.0 with: pattern: digests-* path: /tmp/digests From 3d24dea455b5813ee23dd822138c042e2ef324cd Mon Sep 17 00:00:00 2001 From: Guillermo Ruffino Date: Thu, 24 Apr 2025 22:30:22 -0300 Subject: [PATCH 009/102] fix schema-gen-ci failures (#8621) --- script/build_language_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script/build_language_schema.py b/script/build_language_schema.py index a0edad150a..4473ec1b5a 100755 --- a/script/build_language_schema.py +++ b/script/build_language_schema.py @@ -77,7 +77,7 @@ def get_component_names(): if d not in component_names and d not in skip_components: component_names.append(d) - return component_names + return sorted(component_names) def load_components(): From 8f9fbb15b8aed9523a1b5cf0238a64d90b95ea3e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 15:31:50 -1000 Subject: [PATCH 010/102] Bump docker/build-push-action from 6.15.0 to 6.16.0 in /.github/actions/build-image (#8619) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/actions/build-image/action.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/build-image/action.yaml b/.github/actions/build-image/action.yaml index 86a6b3f4d2..c171a0a13c 100644 --- a/.github/actions/build-image/action.yaml +++ b/.github/actions/build-image/action.yaml @@ -46,7 +46,7 @@ runs: - name: Build and push to ghcr by digest id: build-ghcr - uses: docker/build-push-action@v6.15.0 + uses: docker/build-push-action@v6.16.0 env: DOCKER_BUILD_SUMMARY: false DOCKER_BUILD_RECORD_UPLOAD: false @@ -72,7 +72,7 @@ runs: - name: Build and push to dockerhub by digest id: build-dockerhub - uses: docker/build-push-action@v6.15.0 + uses: docker/build-push-action@v6.16.0 env: DOCKER_BUILD_SUMMARY: false DOCKER_BUILD_RECORD_UPLOAD: false From 805a6d85a5380d2299290153a0d650c23c201521 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:12:13 -1000 Subject: [PATCH 011/102] Bump ruff from 0.11.6 to 0.11.7 (#8615) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_test.txt b/requirements_test.txt index 93352037af..96568f2e3a 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,6 +1,6 @@ pylint==3.3.6 flake8==7.2.0 # also change in .pre-commit-config.yaml when updating -ruff==0.11.6 # also change in .pre-commit-config.yaml when updating +ruff==0.11.7 # also change in .pre-commit-config.yaml when updating pyupgrade==3.19.1 # also change in .pre-commit-config.yaml when updating pre-commit From fb97ef33a84c748919b817998109de771eef996c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:17:39 -1000 Subject: [PATCH 012/102] Bump setuptools from 78.1.0 to 79.0.1 (#8614) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 77dcaf1fab..daf702d2e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools==78.1.0", "wheel>=0.43,<0.46"] +requires = ["setuptools==79.0.1", "wheel>=0.43,<0.46"] build-backend = "setuptools.build_meta" [project] From 8a3fe9ce4c0ed8db77468818b3cd4b0cc9a7ed73 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:18:13 -1000 Subject: [PATCH 013/102] Bump actions/setup-python from 5.5.0 to 5.6.0 (#8618) --- .github/workflows/ci-api-proto.yml | 2 +- .github/workflows/ci-docker.yml | 2 +- .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 4 ++-- .github/workflows/sync-device-classes.yml | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-api-proto.yml b/.github/workflows/ci-api-proto.yml index 233fb64693..d6469236d5 100644 --- a/.github/workflows/ci-api-proto.yml +++ b/.github/workflows/ci-api-proto.yml @@ -23,7 +23,7 @@ jobs: - name: Checkout uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.11" diff --git a/.github/workflows/ci-docker.yml b/.github/workflows/ci-docker.yml index 0a08e6ffad..168333f3ff 100644 --- a/.github/workflows/ci-docker.yml +++ b/.github/workflows/ci-docker.yml @@ -42,7 +42,7 @@ jobs: steps: - uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.9" - name: Set up Docker Buildx diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9022da68ac..0b01758323 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: run: echo key="${{ hashFiles('requirements.txt', 'requirements_optional.txt', 'requirements_test.txt') }}" >> $GITHUB_OUTPUT - name: Set up Python ${{ env.DEFAULT_PYTHON }} id: python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: ${{ env.DEFAULT_PYTHON }} - name: Restore Python virtual environment diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b36c1dd3f3..417212f40e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: steps: - uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.x" - name: Set up python environment @@ -84,7 +84,7 @@ jobs: steps: - uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.9" diff --git a/.github/workflows/sync-device-classes.yml b/.github/workflows/sync-device-classes.yml index 0a0c834a71..b262a9f9c1 100644 --- a/.github/workflows/sync-device-classes.yml +++ b/.github/workflows/sync-device-classes.yml @@ -22,7 +22,7 @@ jobs: path: lib/home-assistant - name: Setup Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: 3.12 From 526db0102cad87818e96d585584b55a9565bfb94 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:18:33 -1000 Subject: [PATCH 014/102] Bump actions/setup-python from 5.5.0 to 5.6.0 in /.github/actions/restore-python (#8616) --- .github/actions/restore-python/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/restore-python/action.yml b/.github/actions/restore-python/action.yml index 3ac91f8ea2..b9913605da 100644 --- a/.github/actions/restore-python/action.yml +++ b/.github/actions/restore-python/action.yml @@ -17,7 +17,7 @@ runs: steps: - name: Set up Python ${{ inputs.python-version }} id: python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: ${{ inputs.python-version }} - name: Restore Python virtual environment From 4c8f5275f9a82ad47c5687b9bc7393e6cd36e0ed Mon Sep 17 00:00:00 2001 From: Thomas Rupprecht Date: Fri, 25 Apr 2025 21:47:45 +0200 Subject: [PATCH 015/102] replace `http` with `https` (#8628) --- esphome/config_validation.py | 2 +- esphome/platformio_api.py | 2 +- esphome/wizard.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/esphome/config_validation.py b/esphome/config_validation.py index 7bd3f90adc..993fcfac5b 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -116,7 +116,7 @@ RequiredFieldInvalid = vol.RequiredFieldInvalid ROOT_CONFIG_PATH = object() RESERVED_IDS = [ - # C++ keywords http://en.cppreference.com/w/cpp/keyword + # C++ keywords https://en.cppreference.com/w/cpp/keyword "alarm", "alignas", "alignof", diff --git a/esphome/platformio_api.py b/esphome/platformio_api.py index b81ec4ab37..ed95fa125e 100644 --- a/esphome/platformio_api.py +++ b/esphome/platformio_api.py @@ -53,7 +53,7 @@ FILTER_PLATFORMIO_LINES = [ f"You can ignore this message, if `.*{IGNORE_LIB_WARNINGS}.*` is a built-in library.*", r"Scanning dependencies...", r"Found \d+ compatible libraries", - r"Memory Usage -> http://bit.ly/pio-memory-usage", + r"Memory Usage -> https://bit.ly/pio-memory-usage", r"Found: https://platformio.org/lib/show/.*", r"Using cache: .*", r"Installing dependencies", diff --git a/esphome/wizard.py b/esphome/wizard.py index 7fdf245c76..8c5bd07e1f 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -361,11 +361,11 @@ def wizard(path): if platform == "ESP32": board_link = ( - "http://docs.platformio.org/en/latest/platforms/espressif32.html#boards" + "https://docs.platformio.org/en/latest/platforms/espressif32.html#boards" ) elif platform == "ESP8266": board_link = ( - "http://docs.platformio.org/en/latest/platforms/espressif8266.html#boards" + "https://docs.platformio.org/en/latest/platforms/espressif8266.html#boards" ) elif platform == "RP2040": board_link = ( From adcd6517dba3f9c799ca36bd3fbc311bdeaf1ac1 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:14:50 +1200 Subject: [PATCH 016/102] [docker] Use new base container image (#8582) --- .dockerignore | 3 +- .github/actions/build-image/action.yaml | 29 ++-- .github/actions/restore-python/action.yml | 4 +- .github/dependabot.yml | 1 - .github/workflows/ci-docker.yml | 7 +- .github/workflows/ci.yml | 4 +- .github/workflows/release.yml | 64 +++---- docker/Dockerfile | 203 ++++------------------ docker/build.py | 15 +- pyproject.toml | 1 - requirements.txt | 1 + requirements_optional.txt | 1 - script/setup | 23 +-- script/setup.bat | 4 +- 14 files changed, 108 insertions(+), 252 deletions(-) delete mode 100644 requirements_optional.txt diff --git a/.dockerignore b/.dockerignore index 7998ff877f..ccd466d8cb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -114,4 +114,5 @@ config/ examples/ Dockerfile .git/ -tests/build/ +tests/ +.* diff --git a/.github/actions/build-image/action.yaml b/.github/actions/build-image/action.yaml index c171a0a13c..3d6de54f42 100644 --- a/.github/actions/build-image/action.yaml +++ b/.github/actions/build-image/action.yaml @@ -1,15 +1,11 @@ name: Build Image inputs: - platform: - description: "Platform to build for" - required: true - example: "linux/amd64" target: description: "Target to build" required: true example: "docker" - baseimg: - description: "Base image type" + build_type: + description: "Build type" required: true example: "docker" suffix: @@ -19,6 +15,11 @@ inputs: description: "Version to build" required: true example: "2023.12.0" + base_os: + description: "Base OS to use" + required: false + default: "debian" + example: "debian" runs: using: "composite" steps: @@ -53,22 +54,22 @@ runs: with: context: . file: ./docker/Dockerfile - platforms: ${{ inputs.platform }} target: ${{ inputs.target }} cache-from: type=gha cache-to: ${{ steps.cache-to.outputs.value }} build-args: | - BASEIMGTYPE=${{ inputs.baseimg }} + BUILD_TYPE=${{ inputs.build_type }} BUILD_VERSION=${{ inputs.version }} + BUILD_OS=${{ inputs.base_os }} outputs: | type=image,name=ghcr.io/${{ steps.tags.outputs.image_name }},push-by-digest=true,name-canonical=true,push=true - name: Export ghcr digests shell: bash run: | - mkdir -p /tmp/digests/${{ inputs.target }}/ghcr + mkdir -p /tmp/digests/${{ inputs.build_type }}/ghcr digest="${{ steps.build-ghcr.outputs.digest }}" - touch "/tmp/digests/${{ inputs.target }}/ghcr/${digest#sha256:}" + touch "/tmp/digests/${{ inputs.build_type }}/ghcr/${digest#sha256:}" - name: Build and push to dockerhub by digest id: build-dockerhub @@ -79,19 +80,19 @@ runs: with: context: . file: ./docker/Dockerfile - platforms: ${{ inputs.platform }} target: ${{ inputs.target }} cache-from: type=gha cache-to: ${{ steps.cache-to.outputs.value }} build-args: | - BASEIMGTYPE=${{ inputs.baseimg }} + BUILD_TYPE=${{ inputs.build_type }} BUILD_VERSION=${{ inputs.version }} + BUILD_OS=${{ inputs.base_os }} outputs: | type=image,name=docker.io/${{ steps.tags.outputs.image_name }},push-by-digest=true,name-canonical=true,push=true - name: Export dockerhub digests shell: bash run: | - mkdir -p /tmp/digests/${{ inputs.target }}/dockerhub + mkdir -p /tmp/digests/${{ inputs.build_type }}/dockerhub digest="${{ steps.build-dockerhub.outputs.digest }}" - touch "/tmp/digests/${{ inputs.target }}/dockerhub/${digest#sha256:}" + touch "/tmp/digests/${{ inputs.build_type }}/dockerhub/${digest#sha256:}" diff --git a/.github/actions/restore-python/action.yml b/.github/actions/restore-python/action.yml index b9913605da..082539adaa 100644 --- a/.github/actions/restore-python/action.yml +++ b/.github/actions/restore-python/action.yml @@ -34,7 +34,7 @@ runs: python -m venv venv source venv/bin/activate python --version - pip install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt + pip install -r requirements.txt -r requirements_test.txt pip install -e . - name: Create Python virtual environment if: steps.cache-venv.outputs.cache-hit != 'true' && runner.os == 'Windows' @@ -43,5 +43,5 @@ runs: python -m venv venv ./venv/Scripts/activate python --version - pip install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt + pip install -r requirements.txt -r requirements_test.txt pip install -e . diff --git a/.github/dependabot.yml b/.github/dependabot.yml index bb35f16048..cf507bbaa6 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -17,7 +17,6 @@ updates: docker-actions: applies-to: version-updates patterns: - - "docker/setup-qemu-action" - "docker/login-action" - "docker/setup-buildx-action" - package-ecosystem: github-actions diff --git a/.github/workflows/ci-docker.yml b/.github/workflows/ci-docker.yml index 168333f3ff..511ec55f3e 100644 --- a/.github/workflows/ci-docker.yml +++ b/.github/workflows/ci-docker.yml @@ -37,8 +37,11 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-latest", "ubuntu-24.04-arm"] - build_type: ["ha-addon", "docker", "lint"] + os: ["ubuntu-24.04", "ubuntu-24.04-arm"] + build_type: + - "ha-addon" + - "docker" + # - "lint" steps: - uses: actions/checkout@v4.1.7 - name: Set up Python diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b01758323..77fe79fd1d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: uses: actions/checkout@v4.1.7 - name: Generate cache-key id: cache-key - run: echo key="${{ hashFiles('requirements.txt', 'requirements_optional.txt', 'requirements_test.txt') }}" >> $GITHUB_OUTPUT + run: echo key="${{ hashFiles('requirements.txt', 'requirements_test.txt') }}" >> $GITHUB_OUTPUT - name: Set up Python ${{ env.DEFAULT_PYTHON }} id: python uses: actions/setup-python@v5.6.0 @@ -58,7 +58,7 @@ jobs: python -m venv venv . venv/bin/activate python --version - pip install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt + pip install -r requirements.txt -r requirements_test.txt pip install -e . ruff: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 417212f40e..359a9bcc53 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -68,19 +68,22 @@ jobs: uses: pypa/gh-action-pypi-publish@v1.12.4 deploy-docker: - name: Build ESPHome ${{ matrix.platform }} + name: Build ESPHome ${{ matrix.platform.arch }} if: github.repository == 'esphome/esphome' permissions: contents: read packages: write - runs-on: ubuntu-latest + runs-on: ${{ matrix.platform.os }} needs: [init] strategy: fail-fast: false matrix: platform: - - linux/amd64 - - linux/arm64 + - arch: amd64 + os: "ubuntu-24.04" + - arch: arm64 + os: "ubuntu-24.04-arm" + steps: - uses: actions/checkout@v4.1.7 - name: Set up Python @@ -90,9 +93,6 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3.10.0 - - name: Set up QEMU - if: matrix.platform != 'linux/amd64' - uses: docker/setup-qemu-action@v3.6.0 - name: Log in to docker hub uses: docker/login-action@v3.4.0 @@ -109,45 +109,36 @@ jobs: - name: Build docker uses: ./.github/actions/build-image with: - platform: ${{ matrix.platform }} - target: docker - baseimg: docker + target: final + build_type: docker suffix: "" version: ${{ needs.init.outputs.tag }} - name: Build ha-addon uses: ./.github/actions/build-image with: - platform: ${{ matrix.platform }} - target: hassio - baseimg: hassio + target: final + build_type: ha-addon suffix: "hassio" version: ${{ needs.init.outputs.tag }} - - name: Build lint - uses: ./.github/actions/build-image - with: - platform: ${{ matrix.platform }} - target: lint - baseimg: docker - suffix: lint - version: ${{ needs.init.outputs.tag }} - - - name: Sanitize platform name - id: sanitize - run: | - echo "${{ matrix.platform }}" | sed 's|/|-|g' > /tmp/platform - echo name=$(cat /tmp/platform) >> $GITHUB_OUTPUT + # - name: Build lint + # uses: ./.github/actions/build-image + # with: + # target: lint + # build_type: lint + # suffix: lint + # version: ${{ needs.init.outputs.tag }} - name: Upload digests uses: actions/upload-artifact@v4.6.2 with: - name: digests-${{ steps.sanitize.outputs.name }} + name: digests-${{ matrix.platform.arch }} path: /tmp/digests retention-days: 1 deploy-manifest: - name: Publish ESPHome ${{ matrix.image.title }} to ${{ matrix.registry }} + name: Publish ESPHome ${{ matrix.image.build_type }} to ${{ matrix.registry }} runs-on: ubuntu-latest needs: - init @@ -160,15 +151,12 @@ jobs: fail-fast: false matrix: image: - - title: "ha-addon" - target: "hassio" - suffix: "hassio" - - title: "docker" - target: "docker" + - build_type: "docker" suffix: "" - - title: "lint" - target: "lint" - suffix: "lint" + - build_type: "ha-addon" + suffix: "hassio" + # - build_type: "lint" + # suffix: "lint" registry: - ghcr - dockerhub @@ -212,7 +200,7 @@ jobs: done - name: Create manifest list and push - working-directory: /tmp/digests/${{ matrix.image.target }}/${{ matrix.registry }} + working-directory: /tmp/digests/${{ matrix.image.build_type }}/${{ matrix.registry }} run: | docker buildx imagetools create $(jq -Rcnr 'inputs | . / "," | map("-t " + .) | join(" ")' <<< "${{ steps.tags.outputs.tags}}") \ $(printf '${{ steps.tags.outputs.image }}@sha256:%s ' *) diff --git a/docker/Dockerfile b/docker/Dockerfile index 117ec17ae4..39dc1c7f28 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,131 +1,54 @@ -# Build these with the build.py script -# Example: -# python3 docker/build.py --tag dev --arch amd64 --build-type docker build +ARG BUILD_VERSION=dev +ARG BUILD_OS=alpine +ARG BUILD_BASE_VERSION=2025.04.0 +ARG BUILD_TYPE=docker -# One of "docker", "hassio" -ARG BASEIMGTYPE=docker +FROM ghcr.io/esphome/docker-base:${BUILD_OS}-${BUILD_BASE_VERSION} AS base-source-docker +FROM ghcr.io/esphome/docker-base:${BUILD_OS}-ha-addon-${BUILD_BASE_VERSION} AS base-source-ha-addon +ARG BUILD_TYPE +FROM base-source-${BUILD_TYPE} AS base -# https://github.com/hassio-addons/addon-debian-base/releases -FROM ghcr.io/hassio-addons/debian-base:7.2.0 AS base-hassio -# https://hub.docker.com/_/debian?tab=tags&page=1&name=bookworm -FROM debian:12.2-slim AS base-docker +RUN git config --system --add safe.directory "*" -FROM base-${BASEIMGTYPE} AS base +RUN pip install uv==0.6.14 - -ARG TARGETARCH -ARG TARGETVARIANT - - -# Note that --break-system-packages is used below because -# https://peps.python.org/pep-0668/ added a safety check that prevents -# installing packages with the same name as a system package. This is -# not a problem for us because we are not concerned about overwriting -# system packages because we are running in an isolated container. +COPY requirements.txt / RUN \ - apt-get update \ - # Use pinned versions so that we get updates with build caching - && apt-get install -y --no-install-recommends \ - python3-pip=23.0.1+dfsg-1 \ - python3-setuptools=66.1.1-1+deb12u1 \ - python3-venv=3.11.2-1+b1 \ - python3-wheel=0.38.4-2 \ - iputils-ping=3:20221126-1+deb12u1 \ - git=1:2.39.5-0+deb12u2 \ - curl=7.88.1-10+deb12u12 \ - openssh-client=1:9.2p1-2+deb12u5 \ - python3-cffi=1.15.1-5 \ - libcairo2=1.16.0-7 \ - libmagic1=1:5.44-3 \ - patch=2.7.6-7 \ - && rm -rf \ - /tmp/* \ - /var/{cache,log}/* \ - /var/lib/apt/lists/* - -ENV \ - # Fix click python3 lang warning https://click.palletsprojects.com/en/7.x/python3/ - LANG=C.UTF-8 LC_ALL=C.UTF-8 \ - # Store globally installed pio libs in /piolibs - PLATFORMIO_GLOBALLIB_DIR=/piolibs + uv pip install --no-cache-dir \ + -r /requirements.txt RUN \ - pip3 install \ - --break-system-packages --no-cache-dir \ - # Keep platformio version in sync with requirements.txt - platformio==6.1.18 \ - # Change some platformio settings - && platformio settings set enable_telemetry No \ + platformio settings set enable_telemetry No \ && platformio settings set check_platformio_interval 1000000 \ && mkdir -p /piolibs - -# First install requirements to leverage caching when requirements don't change -# tmpfs is for https://github.com/rust-lang/cargo/issues/8719 - -COPY requirements.txt requirements_optional.txt / -RUN --mount=type=tmpfs,target=/root/.cargo < /etc/apt/sources.list.d/llvm.sources.list \ - && apt-get update \ - # Use pinned versions so that we get updates with build caching - && apt-get install -y --no-install-recommends \ - clang-format-13=1:13.0.1-11+b2 \ - patch=2.7.6-7 \ - software-properties-common=0.99.30-4.1~deb12u1 \ - nano=7.2-1+deb12u1 \ - build-essential=12.9 \ - python3-dev=3.11.2-1+b1 \ - clang-tidy-18=1:18.1.8~++20240731024826+3b5b5c1ec4a3-1~exp1~20240731144843.145 \ - && rm -rf \ - /tmp/* \ - /var/{cache,log}/* \ - /var/lib/apt/lists/* - -COPY requirements_test.txt / -RUN pip3 install --break-system-packages --no-cache-dir -r /requirements_test.txt - -VOLUME ["/esphome"] -WORKDIR /esphome +# Copy esphome and install +COPY . /esphome +RUN uv pip install --no-cache-dir -e /esphome diff --git a/docker/build.py b/docker/build.py index cdc25df340..921adac7ab 100755 --- a/docker/build.py +++ b/docker/build.py @@ -54,7 +54,7 @@ manifest_parser = subparsers.add_parser( class DockerParams: build_to: str manifest_to: str - baseimgtype: str + build_type: str platform: str target: str @@ -66,24 +66,19 @@ class DockerParams: TYPE_LINT: "esphome/esphome-lint", }[build_type] build_to = f"{prefix}-{arch}" - baseimgtype = { - TYPE_DOCKER: "docker", - TYPE_HA_ADDON: "hassio", - TYPE_LINT: "docker", - }[build_type] platform = { ARCH_AMD64: "linux/amd64", ARCH_AARCH64: "linux/arm64", }[arch] target = { - TYPE_DOCKER: "docker", - TYPE_HA_ADDON: "hassio", + TYPE_DOCKER: "final", + TYPE_HA_ADDON: "final", TYPE_LINT: "lint", }[build_type] return cls( build_to=build_to, manifest_to=prefix, - baseimgtype=baseimgtype, + build_type=build_type, platform=platform, target=target, ) @@ -145,7 +140,7 @@ def main(): "buildx", "build", "--build-arg", - f"BASEIMGTYPE={params.baseimgtype}", + f"BUILD_TYPE={params.build_type}", "--build-arg", f"BUILD_VERSION={args.tag}", "--cache-from", diff --git a/pyproject.toml b/pyproject.toml index daf702d2e0..e3b10722c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ version = {attr = "esphome.const.__version__"} [tool.setuptools.dynamic.optional-dependencies] dev = { file = ["requirements_dev.txt"] } test = { file = ["requirements_test.txt"] } -displays = { file = ["requirements_optional.txt"] } [tool.setuptools.packages.find] include = ["esphome*"] diff --git a/requirements.txt b/requirements.txt index cb1f1da2f2..f09d7894dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ puremagic==1.28 ruamel.yaml==0.18.10 # dashboard_import esphome-glyphsets==0.2.0 pillow==10.4.0 +cairosvg==2.7.1 freetype-py==2.5.1 # esp-idf requires this, but doesn't bundle it by default diff --git a/requirements_optional.txt b/requirements_optional.txt deleted file mode 100644 index 7416753d55..0000000000 --- a/requirements_optional.txt +++ /dev/null @@ -1 +0,0 @@ -cairosvg==2.7.1 diff --git a/script/setup b/script/setup index 3ebf75387f..acc2ec58b4 100755 --- a/script/setup +++ b/script/setup @@ -4,25 +4,28 @@ set -e cd "$(dirname "$0")/.." -location="venv/bin/activate" if [ ! -n "$DEVCONTAINER" ] && [ ! -n "$VIRTUAL_ENV" ] && [ ! "$ESPHOME_NO_VENV" ]; then - python3 -m venv venv - if [ -f venv/Scripts/activate ]; then - location="venv/Scripts/activate" + if [ -x "$(command -v uv)" ]; then + uv venv venv + else + python3 -m venv venv fi - source $location + source venv/bin/activate fi -pip3 install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt -r requirements_dev.txt -pip3 install setuptools wheel -pip3 install -e ".[dev,test,displays]" --config-settings editable_mode=compat +if ! [ -x "$(command -v uv)" ]; then + python3 -m pip install uv +fi + +uv pip install setuptools wheel +uv pip install -e ".[dev,test]" --config-settings editable_mode=compat pre-commit install script/platformio_install_deps.py platformio.ini --libraries --tools --platforms -mkdir .temp +mkdir -p .temp echo echo -echo "Virtual environment created. Run 'source $location' to use it." +echo "Virtual environment created. Run 'source venv/bin/activate' to use it." diff --git a/script/setup.bat b/script/setup.bat index 0b49768139..ea2591bb71 100644 --- a/script/setup.bat +++ b/script/setup.bat @@ -15,9 +15,9 @@ echo Installing required packages... python.exe -m pip install --upgrade pip -pip3 install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt -r requirements_dev.txt +pip3 install -r requirements.txt -r requirements_test.txt -r requirements_dev.txt pip3 install setuptools wheel -pip3 install -e ".[dev,test,displays]" --config-settings editable_mode=compat +pip3 install -e ".[dev,test]" --config-settings editable_mode=compat pre-commit install From e557bca4207f52a9957d8172c04c1eed65d09b38 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Sun, 27 Apr 2025 18:19:01 -0500 Subject: [PATCH 017/102] [i2s_audio] Microphone reads in loop for callbacks shouldn't ever delay (#8625) --- .../i2s_audio/microphone/i2s_audio_microphone.cpp | 14 ++++++++------ .../i2s_audio/microphone/i2s_audio_microphone.h | 3 ++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp index ef375954cd..3ab3c88142 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp @@ -284,19 +284,21 @@ void I2SAudioMicrophone::stop_() { this->status_clear_error(); } -size_t I2SAudioMicrophone::read(int16_t *buf, size_t len) { +size_t I2SAudioMicrophone::read(int16_t *buf, size_t len, TickType_t ticks_to_wait) { size_t bytes_read = 0; #ifdef USE_I2S_LEGACY - esp_err_t err = i2s_read(this->parent_->get_port(), buf, len, &bytes_read, (100 / portTICK_PERIOD_MS)); + esp_err_t err = i2s_read(this->parent_->get_port(), buf, len, &bytes_read, ticks_to_wait); #else - esp_err_t err = i2s_channel_read(this->rx_handle_, buf, len, &bytes_read, (100 / portTICK_PERIOD_MS)); + // i2s_channel_read expects the timeout value in ms, not ticks + esp_err_t err = i2s_channel_read(this->rx_handle_, buf, len, &bytes_read, pdTICKS_TO_MS(ticks_to_wait)); #endif - if (err != ESP_OK) { + if ((err != ESP_OK) && ((err != ESP_ERR_TIMEOUT) || (ticks_to_wait != 0))) { + // Ignore ESP_ERR_TIMEOUT if ticks_to_wait = 0, as it will read the data on the next call ESP_LOGW(TAG, "Error reading from I2S microphone: %s", esp_err_to_name(err)); this->status_set_warning(); return 0; } - if (bytes_read == 0) { + if ((bytes_read == 0) && (ticks_to_wait > 0)) { this->status_set_warning(); return 0; } @@ -350,7 +352,7 @@ size_t I2SAudioMicrophone::read(int16_t *buf, size_t len) { void I2SAudioMicrophone::read_() { std::vector samples; samples.resize(BUFFER_SIZE); - size_t bytes_read = this->read(samples.data(), BUFFER_SIZE / sizeof(int16_t)); + size_t bytes_read = this->read(samples.data(), BUFFER_SIZE * sizeof(int16_t), 0); samples.resize(bytes_read / sizeof(int16_t)); this->data_callbacks_.call(samples); } diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h index 2ff46fabab..2dbacb447e 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h @@ -25,7 +25,8 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub void set_pdm(bool pdm) { this->pdm_ = pdm; } - size_t read(int16_t *buf, size_t len) override; + size_t read(int16_t *buf, size_t len, TickType_t ticks_to_wait); + size_t read(int16_t *buf, size_t len) override { return this->read(buf, len, pdMS_TO_TICKS(100)); } #ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC From ee646d73247713fe45c539c8e22bf617ee96f4dd Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Sun, 27 Apr 2025 18:23:25 -0500 Subject: [PATCH 018/102] [micro_wake_word] Use microphone callback and avoid unnecessary allocation attempts (#8626) --- .../micro_wake_word/micro_wake_word.cpp | 89 ++++++++++--------- .../micro_wake_word/micro_wake_word.h | 12 +-- 2 files changed, 49 insertions(+), 52 deletions(-) diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index b58c7ec434..533aa9fb75 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -61,6 +61,29 @@ void MicroWakeWord::dump_config() { void MicroWakeWord::setup() { ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); + this->microphone_->add_data_callback([this](const std::vector &data) { + if (this->state_ != State::DETECTING_WAKE_WORD) { + return; + } + std::shared_ptr temp_ring_buffer = this->ring_buffer_; + if (this->ring_buffer_.use_count() == 2) { + // mWW still owns the ring buffer and temp_ring_buffer does as well, proceed to copy audio into ring buffer + + size_t bytes_free = temp_ring_buffer->free(); + + if (bytes_free < data.size() * sizeof(int16_t)) { + ESP_LOGW( + TAG, + "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " + "Resetting the ring buffer. Wake word detection accuracy will be reduced.", + bytes_free, data.size()); + + temp_ring_buffer->reset(); + } + temp_ring_buffer->write((void *) data.data(), data.size() * sizeof(int16_t)); + } + }); + if (!this->register_streaming_ops_(this->streaming_op_resolver_)) { this->mark_failed(); return; @@ -107,7 +130,6 @@ void MicroWakeWord::loop() { ESP_LOGD(TAG, "Starting Microphone"); this->microphone_->start(); this->set_state_(State::STARTING_MICROPHONE); - this->high_freq_.start(); break; case State::STARTING_MICROPHONE: if (this->microphone_->is_running()) { @@ -115,21 +137,19 @@ void MicroWakeWord::loop() { } break; case State::DETECTING_WAKE_WORD: - while (!this->has_enough_samples_()) { - this->read_microphone_(); - } - this->update_model_probabilities_(); - if (this->detect_wake_words_()) { - ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str()); - this->detected_ = true; - this->set_state_(State::STOP_MICROPHONE); + while (this->has_enough_samples_()) { + this->update_model_probabilities_(); + if (this->detect_wake_words_()) { + ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str()); + this->detected_ = true; + this->set_state_(State::STOP_MICROPHONE); + } } break; case State::STOP_MICROPHONE: ESP_LOGD(TAG, "Stopping Microphone"); this->microphone_->stop(); this->set_state_(State::STOPPING_MICROPHONE); - this->high_freq_.stop(); this->unload_models_(); this->deallocate_buffers_(); break; @@ -157,6 +177,11 @@ void MicroWakeWord::start() { return; } + if (this->state_ != State::IDLE) { + ESP_LOGW(TAG, "Wake word is already running"); + return; + } + if (!this->load_models_() || !this->allocate_buffers_()) { ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers"); this->status_set_error(); @@ -169,11 +194,6 @@ void MicroWakeWord::start() { return; } - if (this->state_ != State::IDLE) { - ESP_LOGW(TAG, "Wake word is already running"); - return; - } - this->reset_states_(); this->set_state_(State::START_MICROPHONE); } @@ -196,26 +216,6 @@ void MicroWakeWord::set_state_(State state) { this->state_ = state; } -size_t MicroWakeWord::read_microphone_() { - size_t bytes_read = this->microphone_->read(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - if (bytes_read == 0) { - return 0; - } - - size_t bytes_free = this->ring_buffer_->free(); - - if (bytes_free < bytes_read) { - ESP_LOGW(TAG, - "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " - "Resetting the ring buffer. Wake word detection accuracy will be reduced.", - bytes_free, bytes_read); - - this->ring_buffer_->reset(); - } - - return this->ring_buffer_->write((void *) this->input_buffer_, bytes_read); -} - bool MicroWakeWord::allocate_buffers_() { ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); @@ -235,9 +235,9 @@ bool MicroWakeWord::allocate_buffers_() { } } - if (this->ring_buffer_ == nullptr) { + if (this->ring_buffer_.use_count() == 0) { this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); - if (this->ring_buffer_ == nullptr) { + if (this->ring_buffer_.use_count() == 0) { ESP_LOGE(TAG, "Could not allocate ring buffer"); return false; } @@ -248,10 +248,17 @@ bool MicroWakeWord::allocate_buffers_() { void MicroWakeWord::deallocate_buffers_() { ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - this->input_buffer_ = nullptr; - audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_()); - this->preprocessor_audio_buffer_ = nullptr; + if (this->input_buffer_ != nullptr) { + audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); + this->input_buffer_ = nullptr; + } + + if (this->preprocessor_audio_buffer_ != nullptr) { + audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_()); + this->preprocessor_audio_buffer_ = nullptr; + } + + this->ring_buffer_.reset(); } bool MicroWakeWord::load_models_() { diff --git a/esphome/components/micro_wake_word/micro_wake_word.h b/esphome/components/micro_wake_word/micro_wake_word.h index 0c805b75fc..443911b1e4 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.h +++ b/esphome/components/micro_wake_word/micro_wake_word.h @@ -62,9 +62,8 @@ class MicroWakeWord : public Component { microphone::Microphone *microphone_{nullptr}; Trigger *wake_word_detected_trigger_ = new Trigger(); State state_{State::IDLE}; - HighFrequencyLoopRequester high_freq_; - std::unique_ptr ring_buffer_; + std::shared_ptr ring_buffer_; std::vector wake_word_models_; @@ -98,15 +97,6 @@ class MicroWakeWord : public Component { /// @return True if enough samples, false otherwise. bool has_enough_samples_(); - /** Reads audio from microphone into the ring buffer - * - * Audio data (16000 kHz with int16 samples) is read into the input_buffer_. - * Verifies the ring buffer has enough space for all audio data. If not, it logs - * a warning and resets the ring buffer entirely. - * @return Number of bytes written to the ring buffer - */ - size_t read_microphone_(); - /// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_ /// @return True if successful, false otherwise bool allocate_buffers_(); From c9d1476ae003f0c7be73cfff958125b3e5ab4793 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Sun, 27 Apr 2025 18:30:21 -0500 Subject: [PATCH 019/102] [voice_assisant] support start/continue conversation and deallocate buffers (#8610) --- .../voice_assistant/voice_assistant.cpp | 129 +++++++++++++----- .../voice_assistant/voice_assistant.h | 8 +- 2 files changed, 104 insertions(+), 33 deletions(-) diff --git a/esphome/components/voice_assistant/voice_assistant.cpp b/esphome/components/voice_assistant/voice_assistant.cpp index 4b02867967..a38ae2d12b 100644 --- a/esphome/components/voice_assistant/voice_assistant.cpp +++ b/esphome/components/voice_assistant/voice_assistant.cpp @@ -72,12 +72,8 @@ bool VoiceAssistant::start_udp_socket_() { } bool VoiceAssistant::allocate_buffers_() { - if (this->send_buffer_ != nullptr) { - return true; // Already allocated - } - #ifdef USE_SPEAKER - if (this->speaker_ != nullptr) { + if ((this->speaker_ != nullptr) && (this->speaker_buffer_ == nullptr)) { ExternalRAMAllocator speaker_allocator(ExternalRAMAllocator::ALLOW_FAILURE); this->speaker_buffer_ = speaker_allocator.allocate(SPEAKER_BUFFER_SIZE); if (this->speaker_buffer_ == nullptr) { @@ -87,28 +83,34 @@ bool VoiceAssistant::allocate_buffers_() { } #endif - ExternalRAMAllocator allocator(ExternalRAMAllocator::ALLOW_FAILURE); - this->input_buffer_ = allocator.allocate(INPUT_BUFFER_SIZE); if (this->input_buffer_ == nullptr) { - ESP_LOGW(TAG, "Could not allocate input buffer"); - return false; + ExternalRAMAllocator allocator(ExternalRAMAllocator::ALLOW_FAILURE); + this->input_buffer_ = allocator.allocate(INPUT_BUFFER_SIZE); + if (this->input_buffer_ == nullptr) { + ESP_LOGW(TAG, "Could not allocate input buffer"); + return false; + } } #ifdef USE_ESP_ADF this->vad_instance_ = vad_create(VAD_MODE_4); #endif - this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); - if (this->ring_buffer_ == nullptr) { - ESP_LOGW(TAG, "Could not allocate ring buffer"); - return false; + if (this->ring_buffer_.use_count() == 0) { + this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); + if (this->ring_buffer_.use_count() == 0) { + ESP_LOGE(TAG, "Could not allocate ring buffer"); + return false; + } } - ExternalRAMAllocator send_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - this->send_buffer_ = send_allocator.allocate(SEND_BUFFER_SIZE); - if (send_buffer_ == nullptr) { - ESP_LOGW(TAG, "Could not allocate send buffer"); - return false; + if (this->send_buffer_ == nullptr) { + ExternalRAMAllocator send_allocator(ExternalRAMAllocator::ALLOW_FAILURE); + this->send_buffer_ = send_allocator.allocate(SEND_BUFFER_SIZE); + if (send_buffer_ == nullptr) { + ESP_LOGW(TAG, "Could not allocate send buffer"); + return false; + } } return true; @@ -139,13 +141,14 @@ void VoiceAssistant::clear_buffers_() { } void VoiceAssistant::deallocate_buffers_() { - ExternalRAMAllocator send_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); - send_deallocator.deallocate(this->send_buffer_, SEND_BUFFER_SIZE); - this->send_buffer_ = nullptr; + if (this->send_buffer_ != nullptr) { + ExternalRAMAllocator send_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); + send_deallocator.deallocate(this->send_buffer_, SEND_BUFFER_SIZE); + this->send_buffer_ = nullptr; + } - if (this->ring_buffer_ != nullptr) { + if (this->ring_buffer_.use_count() > 0) { this->ring_buffer_.reset(); - this->ring_buffer_ = nullptr; } #ifdef USE_ESP_ADF @@ -155,9 +158,11 @@ void VoiceAssistant::deallocate_buffers_() { } #endif - ExternalRAMAllocator input_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); - input_deallocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE); - this->input_buffer_ = nullptr; + if (this->input_buffer_ != nullptr) { + ExternalRAMAllocator input_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); + input_deallocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE); + this->input_buffer_ = nullptr; + } #ifdef USE_SPEAKER if ((this->speaker_ != nullptr) && (this->speaker_buffer_ != nullptr)) { @@ -216,6 +221,7 @@ void VoiceAssistant::loop() { } } else { this->high_freq_.stop(); + this->deallocate_buffers_(); } break; } @@ -276,7 +282,7 @@ void VoiceAssistant::loop() { this->read_microphone_(); ESP_LOGD(TAG, "Requesting start..."); uint32_t flags = 0; - if (this->use_wake_word_) + if (!this->continue_conversation_ && this->use_wake_word_) flags |= api::enums::VOICE_ASSISTANT_REQUEST_USE_WAKE_WORD; if (this->silence_detection_) flags |= api::enums::VOICE_ASSISTANT_REQUEST_USE_VAD; @@ -387,6 +393,25 @@ void VoiceAssistant::loop() { #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { playing = (this->media_player_->state == media_player::MediaPlayerState::MEDIA_PLAYER_STATE_ANNOUNCING); + + if (playing && this->media_player_wait_for_announcement_start_) { + // Announcement has started playing, wait for it to finish + this->media_player_wait_for_announcement_start_ = false; + this->media_player_wait_for_announcement_end_ = true; + } + + if (!playing && this->media_player_wait_for_announcement_end_) { + // Announcement has finished playing + this->media_player_wait_for_announcement_end_ = false; + this->cancel_timeout("playing"); + ESP_LOGD(TAG, "Announcement finished playing"); + this->set_state_(State::RESPONSE_FINISHED, State::RESPONSE_FINISHED); + + api::VoiceAssistantAnnounceFinished msg; + msg.success = true; + this->api_client_->send_voice_assistant_announce_finished(msg); + break; + } } #endif if (playing) { @@ -417,7 +442,11 @@ void VoiceAssistant::loop() { this->tts_stream_end_trigger_->trigger(); } #endif - this->set_state_(State::IDLE, State::IDLE); + if (this->continue_conversation_) { + this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); + } else { + this->set_state_(State::IDLE, State::IDLE); + } break; } default: @@ -587,6 +616,7 @@ void VoiceAssistant::request_start(bool continuous, bool silence_detection) { void VoiceAssistant::request_stop() { this->continuous_ = false; + this->continue_conversation_ = false; switch (this->state_) { case State::IDLE: @@ -611,6 +641,16 @@ void VoiceAssistant::request_stop() { this->signal_stop_(); break; case State::STREAMING_RESPONSE: +#ifdef USE_MEDIA_PLAYER + // Stop any ongoing media player announcement + if (this->media_player_ != nullptr) { + this->media_player_->make_call() + .set_command(media_player::MEDIA_PLAYER_COMMAND_STOP) + .set_announcement(true) + .perform(); + } +#endif + break; case State::RESPONSE_FINISHED: break; // Let the incoming audio stream finish then it will go to idle. } @@ -628,9 +668,9 @@ void VoiceAssistant::signal_stop_() { } void VoiceAssistant::start_playback_timeout_() { - this->set_timeout("playing", 100, [this]() { + this->set_timeout("playing", 2000, [this]() { this->cancel_timeout("speaker-timeout"); - this->set_state_(State::IDLE, State::IDLE); + this->set_state_(State::RESPONSE_FINISHED, State::RESPONSE_FINISHED); api::VoiceAssistantAnnounceFinished msg; msg.success = true; @@ -679,6 +719,8 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { for (auto arg : msg.data) { if (arg.name == "conversation_id") { this->conversation_id_ = std::move(arg.value); + } else if (arg.name == "continue_conversation") { + this->continue_conversation_ = (arg.value == "1"); } } this->defer([this]() { this->intent_end_trigger_->trigger(); }); @@ -722,6 +764,9 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { this->media_player_->make_call().set_media_url(url).set_announcement(true).perform(); + + this->media_player_wait_for_announcement_start_ = true; + this->media_player_wait_for_announcement_end_ = false; // Start the playback timeout, as the media player state isn't immediately updated this->start_playback_timeout_(); } @@ -888,8 +933,28 @@ void VoiceAssistant::on_announce(const api::VoiceAssistantAnnounceRequest &msg) #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { this->tts_start_trigger_->trigger(msg.text); - this->media_player_->make_call().set_media_url(msg.media_id).set_announcement(true).perform(); - this->set_state_(State::STREAMING_RESPONSE, State::STREAMING_RESPONSE); + if (!msg.preannounce_media_id.empty()) { + this->media_player_->make_call().set_media_url(msg.preannounce_media_id).set_announcement(true).perform(); + } + // Enqueueing a URL with an empty playlist will still play the file immediately + this->media_player_->make_call() + .set_command(media_player::MEDIA_PLAYER_COMMAND_ENQUEUE) + .set_media_url(msg.media_id) + .set_announcement(true) + .perform(); + this->continue_conversation_ = msg.start_conversation; + + this->media_player_wait_for_announcement_start_ = true; + this->media_player_wait_for_announcement_end_ = false; + // Start the playback timeout, as the media player state isn't immediately updated + this->start_playback_timeout_(); + + if (this->continuous_) { + this->set_state_(State::STOP_MICROPHONE, State::STREAMING_RESPONSE); + } else { + this->set_state_(State::STREAMING_RESPONSE, State::STREAMING_RESPONSE); + } + this->tts_end_trigger_->trigger(msg.media_id); this->end_trigger_->trigger(); } diff --git a/esphome/components/voice_assistant/voice_assistant.h b/esphome/components/voice_assistant/voice_assistant.h index 12124c1486..66531fcd94 100644 --- a/esphome/components/voice_assistant/voice_assistant.h +++ b/esphome/components/voice_assistant/voice_assistant.h @@ -41,6 +41,7 @@ enum VoiceAssistantFeature : uint32_t { FEATURE_API_AUDIO = 1 << 2, FEATURE_TIMERS = 1 << 3, FEATURE_ANNOUNCE = 1 << 4, + FEATURE_START_CONVERSATION = 1 << 5, }; enum class State { @@ -140,6 +141,7 @@ class VoiceAssistant : public Component { #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { flags |= VoiceAssistantFeature::FEATURE_ANNOUNCE; + flags |= VoiceAssistantFeature::FEATURE_START_CONVERSATION; } #endif @@ -267,6 +269,8 @@ class VoiceAssistant : public Component { #endif #ifdef USE_MEDIA_PLAYER media_player::MediaPlayer *media_player_{nullptr}; + bool media_player_wait_for_announcement_start_{false}; + bool media_player_wait_for_announcement_end_{false}; #endif bool local_output_{false}; @@ -282,7 +286,7 @@ class VoiceAssistant : public Component { uint8_t vad_threshold_{5}; uint8_t vad_counter_{0}; #endif - std::unique_ptr ring_buffer_; + std::shared_ptr ring_buffer_; bool use_wake_word_; uint8_t noise_suppression_level_; @@ -296,6 +300,8 @@ class VoiceAssistant : public Component { bool continuous_{false}; bool silence_detection_; + bool continue_conversation_{false}; + State state_{State::IDLE}; State desired_state_{State::IDLE}; From e49252ca3d0e6734adebb4fb8d30e9ea2becf47a Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Sun, 27 Apr 2025 19:15:28 -0500 Subject: [PATCH 020/102] [voice_assistant] Use mic callback and remove esp_adf code (#8627) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- .../components/voice_assistant/__init__.py | 4 +- .../voice_assistant/voice_assistant.cpp | 130 +++--------------- .../voice_assistant/voice_assistant.h | 17 +-- 3 files changed, 22 insertions(+), 129 deletions(-) diff --git a/esphome/components/voice_assistant/__init__.py b/esphome/components/voice_assistant/__init__.py index a4fb572208..e8cdca94b8 100644 --- a/esphome/components/voice_assistant/__init__.py +++ b/esphome/components/voice_assistant/__init__.py @@ -94,8 +94,8 @@ CONFIG_SCHEMA = cv.All( media_player.MediaPlayer ), cv.Optional(CONF_USE_WAKE_WORD, default=False): cv.boolean, - cv.Optional(CONF_VAD_THRESHOLD): cv.All( - cv.requires_component("esp_adf"), cv.only_with_esp_idf, cv.uint8_t + cv.Optional(CONF_VAD_THRESHOLD): cv.invalid( + "VAD threshold is no longer supported, as it requires the deprecated esp_adf external component. Use an i2s_audio microphone/speaker instead. Additionally, you may need to configure the audio_adc and audio_dac components depending on your hardware." ), cv.Optional(CONF_NOISE_SUPPRESSION_LEVEL, default=0): cv.int_range(0, 4), cv.Optional(CONF_AUTO_GAIN, default="0dBFS"): cv.All( diff --git a/esphome/components/voice_assistant/voice_assistant.cpp b/esphome/components/voice_assistant/voice_assistant.cpp index a38ae2d12b..c62767d7d5 100644 --- a/esphome/components/voice_assistant/voice_assistant.cpp +++ b/esphome/components/voice_assistant/voice_assistant.cpp @@ -18,14 +18,25 @@ static const char *const TAG = "voice_assistant"; #endif static const size_t SAMPLE_RATE_HZ = 16000; -static const size_t INPUT_BUFFER_SIZE = 32 * SAMPLE_RATE_HZ / 1000; // 32ms * 16kHz / 1000ms -static const size_t BUFFER_SIZE = 512 * SAMPLE_RATE_HZ / 1000; -static const size_t SEND_BUFFER_SIZE = INPUT_BUFFER_SIZE * sizeof(int16_t); + +static const size_t RING_BUFFER_SAMPLES = 512 * SAMPLE_RATE_HZ / 1000; // 512 ms * 16 kHz/ 1000 ms +static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t); +static const size_t SEND_BUFFER_SAMPLES = 32 * SAMPLE_RATE_HZ / 1000; // 32ms * 16kHz / 1000ms +static const size_t SEND_BUFFER_SIZE = SEND_BUFFER_SAMPLES * sizeof(int16_t); static const size_t RECEIVE_SIZE = 1024; static const size_t SPEAKER_BUFFER_SIZE = 16 * RECEIVE_SIZE; VoiceAssistant::VoiceAssistant() { global_voice_assistant = this; } +void VoiceAssistant::setup() { + this->mic_->add_data_callback([this](const std::vector &data) { + std::shared_ptr temp_ring_buffer = this->ring_buffer_; + if (this->ring_buffer_.use_count() > 1) { + temp_ring_buffer->write((void *) data.data(), data.size() * sizeof(int16_t)); + } + }); +} + float VoiceAssistant::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } bool VoiceAssistant::start_udp_socket_() { @@ -83,21 +94,8 @@ bool VoiceAssistant::allocate_buffers_() { } #endif - if (this->input_buffer_ == nullptr) { - ExternalRAMAllocator allocator(ExternalRAMAllocator::ALLOW_FAILURE); - this->input_buffer_ = allocator.allocate(INPUT_BUFFER_SIZE); - if (this->input_buffer_ == nullptr) { - ESP_LOGW(TAG, "Could not allocate input buffer"); - return false; - } - } - -#ifdef USE_ESP_ADF - this->vad_instance_ = vad_create(VAD_MODE_4); -#endif - if (this->ring_buffer_.use_count() == 0) { - this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); + this->ring_buffer_ = RingBuffer::create(RING_BUFFER_SIZE); if (this->ring_buffer_.use_count() == 0) { ESP_LOGE(TAG, "Could not allocate ring buffer"); return false; @@ -121,10 +119,6 @@ void VoiceAssistant::clear_buffers_() { memset(this->send_buffer_, 0, SEND_BUFFER_SIZE); } - if (this->input_buffer_ != nullptr) { - memset(this->input_buffer_, 0, INPUT_BUFFER_SIZE * sizeof(int16_t)); - } - if (this->ring_buffer_ != nullptr) { this->ring_buffer_->reset(); } @@ -151,19 +145,6 @@ void VoiceAssistant::deallocate_buffers_() { this->ring_buffer_.reset(); } -#ifdef USE_ESP_ADF - if (this->vad_instance_ != nullptr) { - vad_destroy(this->vad_instance_); - this->vad_instance_ = nullptr; - } -#endif - - if (this->input_buffer_ != nullptr) { - ExternalRAMAllocator input_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); - input_deallocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE); - this->input_buffer_ = nullptr; - } - #ifdef USE_SPEAKER if ((this->speaker_ != nullptr) && (this->speaker_buffer_ != nullptr)) { ExternalRAMAllocator speaker_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); @@ -178,22 +159,6 @@ void VoiceAssistant::reset_conversation_id() { ESP_LOGD(TAG, "reset conversation ID"); } -int VoiceAssistant::read_microphone_() { - size_t bytes_read = 0; - if (this->mic_->is_running()) { // Read audio into input buffer - bytes_read = this->mic_->read(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - if (bytes_read == 0) { - memset(this->input_buffer_, 0, INPUT_BUFFER_SIZE * sizeof(int16_t)); - return 0; - } - // Write audio into ring buffer - this->ring_buffer_->write((void *) this->input_buffer_, bytes_read); - } else { - ESP_LOGD(TAG, "microphone not running"); - } - return bytes_read; -} - void VoiceAssistant::loop() { if (this->api_client_ == nullptr && this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE && this->state_ != State::STOPPING_MICROPHONE) { @@ -211,16 +176,8 @@ void VoiceAssistant::loop() { case State::IDLE: { if (this->continuous_ && this->desired_state_ == State::IDLE) { this->idle_trigger_->trigger(); -#ifdef USE_ESP_ADF - if (this->use_wake_word_) { - this->set_state_(State::START_MICROPHONE, State::WAIT_FOR_VAD); - } else -#endif - { - this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); - } + this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); } else { - this->high_freq_.stop(); this->deallocate_buffers_(); } break; @@ -237,7 +194,6 @@ void VoiceAssistant::loop() { this->clear_buffers_(); this->mic_->start(); - this->high_freq_.start(); this->set_state_(State::STARTING_MICROPHONE); break; } @@ -247,39 +203,7 @@ void VoiceAssistant::loop() { } break; } -#ifdef USE_ESP_ADF - case State::WAIT_FOR_VAD: { - this->read_microphone_(); - ESP_LOGD(TAG, "Waiting for speech..."); - this->set_state_(State::WAITING_FOR_VAD); - break; - } - case State::WAITING_FOR_VAD: { - size_t bytes_read = this->read_microphone_(); - if (bytes_read > 0) { - vad_state_t vad_state = - vad_process(this->vad_instance_, this->input_buffer_, SAMPLE_RATE_HZ, VAD_FRAME_LENGTH_MS); - if (vad_state == VAD_SPEECH) { - if (this->vad_counter_ < this->vad_threshold_) { - this->vad_counter_++; - } else { - ESP_LOGD(TAG, "VAD detected speech"); - this->set_state_(State::START_PIPELINE, State::STREAMING_MICROPHONE); - - // Reset for next time - this->vad_counter_ = 0; - } - } else { - if (this->vad_counter_ > 0) { - this->vad_counter_--; - } - } - } - break; - } -#endif case State::START_PIPELINE: { - this->read_microphone_(); ESP_LOGD(TAG, "Requesting start..."); uint32_t flags = 0; if (!this->continue_conversation_ && this->use_wake_word_) @@ -312,11 +236,9 @@ void VoiceAssistant::loop() { break; } case State::STARTING_PIPELINE: { - this->read_microphone_(); break; // State changed when udp server port received } case State::STREAMING_MICROPHONE: { - this->read_microphone_(); size_t available = this->ring_buffer_->available(); while (available >= SEND_BUFFER_SIZE) { size_t read_bytes = this->ring_buffer_->read((void *) this->send_buffer_, SEND_BUFFER_SIZE, 0); @@ -603,14 +525,8 @@ void VoiceAssistant::request_start(bool continuous, bool silence_detection) { if (this->state_ == State::IDLE) { this->continuous_ = continuous; this->silence_detection_ = silence_detection; -#ifdef USE_ESP_ADF - if (this->use_wake_word_) { - this->set_state_(State::START_MICROPHONE, State::WAIT_FOR_VAD); - } else -#endif - { - this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); - } + + this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); } } @@ -785,15 +701,7 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { this->set_state_(State::IDLE, State::IDLE); } else if (this->state_ == State::STREAMING_MICROPHONE) { this->ring_buffer_->reset(); -#ifdef USE_ESP_ADF - if (this->use_wake_word_) { - // No need to stop the microphone since we didn't use the speaker - this->set_state_(State::WAIT_FOR_VAD, State::WAITING_FOR_VAD); - } else -#endif - { - this->set_state_(State::IDLE, State::IDLE); - } + this->set_state_(State::IDLE, State::IDLE); } this->defer([this]() { this->end_trigger_->trigger(); }); break; diff --git a/esphome/components/voice_assistant/voice_assistant.h b/esphome/components/voice_assistant/voice_assistant.h index 66531fcd94..cb57a6b05d 100644 --- a/esphome/components/voice_assistant/voice_assistant.h +++ b/esphome/components/voice_assistant/voice_assistant.h @@ -20,10 +20,6 @@ #endif #include "esphome/components/socket/socket.h" -#ifdef USE_ESP_ADF -#include -#endif - #include #include @@ -96,6 +92,7 @@ class VoiceAssistant : public Component { VoiceAssistant(); void loop() override; + void setup() override; float get_setup_priority() const override; void start_streaming(); void start_streaming(struct sockaddr_storage *addr, uint16_t port); @@ -163,9 +160,6 @@ class VoiceAssistant : public Component { bool is_continuous() const { return this->continuous_; } void set_use_wake_word(bool use_wake_word) { this->use_wake_word_ = use_wake_word; } -#ifdef USE_ESP_ADF - void set_vad_threshold(uint8_t vad_threshold) { this->vad_threshold_ = vad_threshold; } -#endif void set_noise_suppression_level(uint8_t noise_suppression_level) { this->noise_suppression_level_ = noise_suppression_level; @@ -214,7 +208,6 @@ class VoiceAssistant : public Component { void clear_buffers_(); void deallocate_buffers_(); - int read_microphone_(); void set_state_(State state); void set_state_(State state, State desired_state); void signal_stop_(); @@ -279,13 +272,6 @@ class VoiceAssistant : public Component { std::string wake_word_{""}; - HighFrequencyLoopRequester high_freq_; - -#ifdef USE_ESP_ADF - vad_handle_t vad_instance_; - uint8_t vad_threshold_{5}; - uint8_t vad_counter_{0}; -#endif std::shared_ptr ring_buffer_; bool use_wake_word_; @@ -295,7 +281,6 @@ class VoiceAssistant : public Component { uint32_t conversation_timeout_; uint8_t *send_buffer_{nullptr}; - int16_t *input_buffer_{nullptr}; bool continuous_{false}; bool silence_detection_; From 2d3f1411403f2ec1d1e6e65f34d2cc753a9ce235 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:19:50 +1000 Subject: [PATCH 021/102] [core] Fix setting of log level/verbose (#8600) --- esphome/log.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/esphome/log.py b/esphome/log.py index 835cd6b44d..516f27be45 100644 --- a/esphome/log.py +++ b/esphome/log.py @@ -74,13 +74,14 @@ def setup_log( colorama.init() - if log_level == logging.DEBUG: - CORE.verbose = True - elif log_level == logging.CRITICAL: - CORE.quiet = True - + # Setup logging - will map log level from string to constant logging.basicConfig(level=log_level) + if logging.root.level == logging.DEBUG: + CORE.verbose = True + elif logging.root.level == logging.CRITICAL: + CORE.quiet = True + logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger().handlers[0].setFormatter( From 22c0e1079e83e7e28416b2bc612d5d98b96f9f9e Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:23:18 +1000 Subject: [PATCH 022/102] [const] Create component-level const repository (#8385) --- CODEOWNERS | 1 + esphome/components/const/__init__.py | 5 +++ esphome/components/lvgl/__init__.py | 7 +-- esphome/components/lvgl/defines.py | 1 - esphome/components/qspi_dbi/__init__.py | 1 - esphome/components/qspi_dbi/display.py | 3 +- esphome/components/qspi_dbi/models.py | 3 +- tests/components/const/common.yaml | 44 +++++++++++++++++++ tests/components/const/test.esp32-s3-idf.yaml | 1 + 9 files changed, 58 insertions(+), 8 deletions(-) create mode 100644 esphome/components/const/__init__.py create mode 100644 tests/components/const/common.yaml create mode 100644 tests/components/const/test.esp32-s3-idf.yaml diff --git a/CODEOWNERS b/CODEOWNERS index d080563028..73973f420f 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -98,6 +98,7 @@ esphome/components/climate/* @esphome/core esphome/components/climate_ir/* @glmnet esphome/components/color_temperature/* @jesserockz esphome/components/combination/* @Cat-Ion @kahrendt +esphome/components/const/* @esphome/core esphome/components/coolix/* @glmnet esphome/components/copy/* @OttoWinter esphome/components/cover/* @esphome/core diff --git a/esphome/components/const/__init__.py b/esphome/components/const/__init__.py new file mode 100644 index 0000000000..6af357f23b --- /dev/null +++ b/esphome/components/const/__init__.py @@ -0,0 +1,5 @@ +"""Constants used by esphome components.""" + +CODEOWNERS = ["@esphome/core"] + +CONF_DRAW_ROUNDING = "draw_rounding" diff --git a/esphome/components/lvgl/__init__.py b/esphome/components/lvgl/__init__.py index 30fa58c380..69286ada88 100644 --- a/esphome/components/lvgl/__init__.py +++ b/esphome/components/lvgl/__init__.py @@ -2,6 +2,7 @@ import logging from esphome.automation import build_automation, register_action, validate_automation import esphome.codegen as cg +from esphome.components.const import CONF_DRAW_ROUNDING from esphome.components.display import Display import esphome.config_validation as cv from esphome.const import ( @@ -24,7 +25,7 @@ from esphome.helpers import write_file_if_changed from . import defines as df, helpers, lv_validation as lvalid from .automation import disp_update, focused_widgets, update_to_code -from .defines import CONF_DRAW_ROUNDING, add_define +from .defines import add_define from .encoders import ( ENCODERS_CONFIG, encoders_to_code, @@ -323,7 +324,7 @@ async def to_code(configs): displays, frac, config[df.CONF_FULL_REFRESH], - config[df.CONF_DRAW_ROUNDING], + config[CONF_DRAW_ROUNDING], config[df.CONF_RESUME_ON_INPUT], ) await cg.register_component(lv_component, config) @@ -413,7 +414,7 @@ LVGL_SCHEMA = cv.All( df.CONF_DEFAULT_FONT, default="montserrat_14" ): lvalid.lv_font, cv.Optional(df.CONF_FULL_REFRESH, default=False): cv.boolean, - cv.Optional(df.CONF_DRAW_ROUNDING, default=2): cv.positive_int, + cv.Optional(CONF_DRAW_ROUNDING, default=2): cv.positive_int, cv.Optional(CONF_BUFFER_SIZE, default="100%"): cv.percentage, cv.Optional(df.CONF_LOG_LEVEL, default="WARN"): cv.one_of( *df.LV_LOG_LEVELS, upper=True diff --git a/esphome/components/lvgl/defines.py b/esphome/components/lvgl/defines.py index 7dedb55418..7783fb2321 100644 --- a/esphome/components/lvgl/defines.py +++ b/esphome/components/lvgl/defines.py @@ -424,7 +424,6 @@ CONF_DEFAULT_FONT = "default_font" CONF_DEFAULT_GROUP = "default_group" CONF_DIR = "dir" CONF_DISPLAYS = "displays" -CONF_DRAW_ROUNDING = "draw_rounding" CONF_EDITING = "editing" CONF_ENCODERS = "encoders" CONF_END_ANGLE = "end_angle" diff --git a/esphome/components/qspi_dbi/__init__.py b/esphome/components/qspi_dbi/__init__.py index a4b833f6d7..290a864335 100644 --- a/esphome/components/qspi_dbi/__init__.py +++ b/esphome/components/qspi_dbi/__init__.py @@ -1,4 +1,3 @@ CODEOWNERS = ["@clydebarrow"] CONF_DRAW_FROM_ORIGIN = "draw_from_origin" -CONF_DRAW_ROUNDING = "draw_rounding" diff --git a/esphome/components/qspi_dbi/display.py b/esphome/components/qspi_dbi/display.py index 8c29991f37..5b01bcc6ca 100644 --- a/esphome/components/qspi_dbi/display.py +++ b/esphome/components/qspi_dbi/display.py @@ -1,6 +1,7 @@ from esphome import pins import esphome.codegen as cg from esphome.components import display, spi +from esphome.components.const import CONF_DRAW_ROUNDING import esphome.config_validation as cv from esphome.const import ( CONF_BRIGHTNESS, @@ -24,7 +25,7 @@ from esphome.const import ( ) from esphome.core import TimePeriod -from . import CONF_DRAW_FROM_ORIGIN, CONF_DRAW_ROUNDING +from . import CONF_DRAW_FROM_ORIGIN from .models import DriverChip DEPENDENCIES = ["spi"] diff --git a/esphome/components/qspi_dbi/models.py b/esphome/components/qspi_dbi/models.py index 7ae1a10ec0..8ce592e0cf 100644 --- a/esphome/components/qspi_dbi/models.py +++ b/esphome/components/qspi_dbi/models.py @@ -1,8 +1,7 @@ # Commands +from esphome.components.const import CONF_DRAW_ROUNDING from esphome.const import CONF_INVERT_COLORS, CONF_SWAP_XY -from . import CONF_DRAW_ROUNDING - SW_RESET_CMD = 0x01 SLEEP_IN = 0x10 SLEEP_OUT = 0x11 diff --git a/tests/components/const/common.yaml b/tests/components/const/common.yaml new file mode 100644 index 0000000000..655af304af --- /dev/null +++ b/tests/components/const/common.yaml @@ -0,0 +1,44 @@ +spi: + id: quad_spi + clk_pin: 15 + type: quad + data_pins: [14, 10, 16, 12] + +display: + - platform: qspi_dbi + model: RM690B0 + data_rate: 80MHz + spi_mode: mode0 + dimensions: + width: 450 + height: 600 + offset_width: 16 + color_order: rgb + invert_colors: false + brightness: 255 + cs_pin: 11 + reset_pin: 13 + enable_pin: 9 + + - platform: qspi_dbi + model: CUSTOM + id: main_lcd + draw_from_origin: true + dimensions: + height: 240 + width: 536 + transform: + mirror_x: true + swap_xy: true + color_order: rgb + brightness: 255 + cs_pin: 6 + reset_pin: 17 + enable_pin: 38 + init_sequence: + - [0x3A, 0x66] + - [0x11] + - delay 120ms + - [0x29] + - delay 20ms + diff --git a/tests/components/const/test.esp32-s3-idf.yaml b/tests/components/const/test.esp32-s3-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/const/test.esp32-s3-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml From 38dae8489e6d9636f31a9ed076923c6acad0bf07 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:45:28 +1000 Subject: [PATCH 023/102] [http_request] Implement for host platform (#8040) --- esphome/components/http_request/__init__.py | 26 + .../http_request/http_request_host.cpp | 141 + .../http_request/http_request_host.h | 37 + esphome/components/http_request/httplib.h | 9691 +++++++++++++++++ script/ci-custom.py | 2 + tests/components/http_request/common.yaml | 3 +- .../components/http_request/http_request.yaml | 46 + tests/components/http_request/test.host.yaml | 7 + 8 files changed, 9951 insertions(+), 2 deletions(-) create mode 100644 esphome/components/http_request/http_request_host.cpp create mode 100644 esphome/components/http_request/http_request_host.h create mode 100644 esphome/components/http_request/httplib.h create mode 100644 tests/components/http_request/http_request.yaml create mode 100644 tests/components/http_request/test.host.yaml diff --git a/esphome/components/http_request/__init__.py b/esphome/components/http_request/__init__.py index 2a999532f8..4da49ddde1 100644 --- a/esphome/components/http_request/__init__.py +++ b/esphome/components/http_request/__init__.py @@ -10,9 +10,11 @@ from esphome.const import ( CONF_TIMEOUT, CONF_TRIGGER_ID, CONF_URL, + PLATFORM_HOST, __version__, ) from esphome.core import CORE, Lambda +from esphome.helpers import IS_MACOS DEPENDENCIES = ["network"] AUTO_LOAD = ["json", "watchdog"] @@ -21,6 +23,7 @@ http_request_ns = cg.esphome_ns.namespace("http_request") HttpRequestComponent = http_request_ns.class_("HttpRequestComponent", cg.Component) HttpRequestArduino = http_request_ns.class_("HttpRequestArduino", HttpRequestComponent) HttpRequestIDF = http_request_ns.class_("HttpRequestIDF", HttpRequestComponent) +HttpRequestHost = http_request_ns.class_("HttpRequestHost", HttpRequestComponent) HttpContainer = http_request_ns.class_("HttpContainer") @@ -43,6 +46,7 @@ CONF_REDIRECT_LIMIT = "redirect_limit" CONF_WATCHDOG_TIMEOUT = "watchdog_timeout" CONF_BUFFER_SIZE_RX = "buffer_size_rx" CONF_BUFFER_SIZE_TX = "buffer_size_tx" +CONF_CA_CERTIFICATE_PATH = "ca_certificate_path" CONF_MAX_RESPONSE_BUFFER_SIZE = "max_response_buffer_size" CONF_ON_RESPONSE = "on_response" @@ -87,6 +91,8 @@ def validate_ssl_verification(config): def _declare_request_class(value): + if CORE.is_host: + return cv.declare_id(HttpRequestHost)(value) if CORE.using_esp_idf: return cv.declare_id(HttpRequestIDF)(value) if CORE.is_esp8266 or CORE.is_esp32 or CORE.is_rp2040: @@ -121,6 +127,10 @@ CONFIG_SCHEMA = cv.All( cv.SplitDefault(CONF_BUFFER_SIZE_TX, esp32_idf=512): cv.All( cv.uint16_t, cv.only_with_esp_idf ), + cv.Optional(CONF_CA_CERTIFICATE_PATH): cv.All( + cv.file_, + cv.only_on(PLATFORM_HOST), + ), } ).extend(cv.COMPONENT_SCHEMA), cv.require_framework_version( @@ -128,6 +138,7 @@ CONFIG_SCHEMA = cv.All( esp32_arduino=cv.Version(0, 0, 0), esp_idf=cv.Version(0, 0, 0), rp2040_arduino=cv.Version(0, 0, 0), + host=cv.Version(0, 0, 0), ), validate_ssl_verification, ) @@ -170,6 +181,21 @@ async def to_code(config): cg.add_library("ESP8266HTTPClient", None) if CORE.is_rp2040 and CORE.using_arduino: cg.add_library("HTTPClient", None) + if CORE.is_host: + if IS_MACOS: + cg.add_build_flag("-I/opt/homebrew/opt/openssl/include") + cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib") + cg.add_build_flag("-lssl") + cg.add_build_flag("-lcrypto") + cg.add_build_flag("-Wl,-framework,CoreFoundation") + cg.add_build_flag("-Wl,-framework,Security") + cg.add_define("CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN") + cg.add_define("CPPHTTPLIB_OPENSSL_SUPPORT") + elif path := config.get(CONF_CA_CERTIFICATE_PATH): + cg.add_define("CPPHTTPLIB_OPENSSL_SUPPORT") + cg.add(var.set_ca_path(path)) + cg.add_build_flag("-lssl") + cg.add_build_flag("-lcrypto") await cg.register_component(var, config) diff --git a/esphome/components/http_request/http_request_host.cpp b/esphome/components/http_request/http_request_host.cpp new file mode 100644 index 0000000000..192032c1ac --- /dev/null +++ b/esphome/components/http_request/http_request_host.cpp @@ -0,0 +1,141 @@ +#include "http_request_host.h" + +#ifdef USE_HOST + +#include +#include "esphome/components/network/util.h" +#include "esphome/components/watchdog/watchdog.h" + +#include "esphome/core/application.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace http_request { + +static const char *const TAG = "http_request.host"; + +std::shared_ptr HttpRequestHost::perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set response_headers) { + if (!network::is_connected()) { + this->status_momentary_error("failed", 1000); + ESP_LOGW(TAG, "HTTP Request failed; Not connected to network"); + return nullptr; + } + + std::regex url_regex(R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)", std::regex::extended); + std::smatch url_match_result; + + if (!std::regex_match(url, url_match_result, url_regex) || url_match_result.length() < 7) { + ESP_LOGE(TAG, "HTTP Request failed; Malformed URL: %s", url.c_str()); + return nullptr; + } + auto host = url_match_result[4].str(); + auto scheme_host = url_match_result[1].str() + url_match_result[3].str(); + auto path = url_match_result[5].str() + url_match_result[6].str(); + if (path.empty()) + path = "/"; + + std::shared_ptr container = std::make_shared(); + container->set_parent(this); + + const uint32_t start = millis(); + + watchdog::WatchdogManager wdm(this->get_watchdog_timeout()); + + httplib::Headers h_headers; + h_headers.emplace("Host", host.c_str()); + h_headers.emplace("User-Agent", this->useragent_); + for (const auto &[name, value] : request_headers) { + h_headers.emplace(name, value); + } + httplib::Client client(scheme_host.c_str()); + if (!client.is_valid()) { + ESP_LOGE(TAG, "HTTP Request failed; Invalid URL: %s", url.c_str()); + return nullptr; + } + client.set_follow_location(this->follow_redirects_); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (this->ca_path_ != nullptr) + client.set_ca_cert_path(this->ca_path_); +#endif + + httplib::Result result; + if (method == "GET") { + result = client.Get(path, h_headers, [&](const char *data, size_t data_length) { + ESP_LOGV(TAG, "Got data length: %zu", data_length); + container->response_body_.insert(container->response_body_.end(), (const uint8_t *) data, + (const uint8_t *) data + data_length); + return true; + }); + } else if (method == "HEAD") { + result = client.Head(path, h_headers); + } else if (method == "PUT") { + result = client.Put(path, h_headers, body, ""); + if (result) { + auto data = std::vector(result->body.begin(), result->body.end()); + container->response_body_.insert(container->response_body_.end(), data.begin(), data.end()); + } + } else if (method == "PATCH") { + result = client.Patch(path, h_headers, body, ""); + if (result) { + auto data = std::vector(result->body.begin(), result->body.end()); + container->response_body_.insert(container->response_body_.end(), data.begin(), data.end()); + } + } else if (method == "POST") { + result = client.Post(path, h_headers, body, ""); + if (result) { + auto data = std::vector(result->body.begin(), result->body.end()); + container->response_body_.insert(container->response_body_.end(), data.begin(), data.end()); + } + } else { + ESP_LOGW(TAG, "HTTP Request failed - unsupported method %s; URL: %s", method.c_str(), url.c_str()); + container->end(); + return nullptr; + } + App.feed_wdt(); + if (!result) { + ESP_LOGW(TAG, "HTTP Request failed; URL: %s, error code: %u", url.c_str(), (unsigned) result.error()); + container->end(); + this->status_momentary_error("failed", 1000); + return nullptr; + } + App.feed_wdt(); + auto response = *result; + container->status_code = response.status; + if (!is_success(response.status)) { + ESP_LOGE(TAG, "HTTP Request failed; URL: %s; Code: %d", url.c_str(), response.status); + this->status_momentary_error("failed", 1000); + // Still return the container, so it can be used to get the status code and error message + } + + container->content_length = container->response_body_.size(); + for (auto header : response.headers) { + ESP_LOGD(TAG, "Header: %s: %s", header.first.c_str(), header.second.c_str()); + auto lower_name = str_lower_case(header.first); + if (response_headers.find(lower_name) != response_headers.end()) { + container->response_headers_[lower_name].emplace_back(header.second); + } + } + container->duration_ms = millis() - start; + return container; +} + +int HttpContainerHost::read(uint8_t *buf, size_t max_len) { + auto bytes_remaining = this->response_body_.size() - this->bytes_read_; + auto read_len = std::min(max_len, bytes_remaining); + memcpy(buf, this->response_body_.data() + this->bytes_read_, read_len); + this->bytes_read_ += read_len; + return read_len; +} + +void HttpContainerHost::end() { + watchdog::WatchdogManager wdm(this->parent_->get_watchdog_timeout()); + this->response_body_ = std::vector(); + this->bytes_read_ = 0; +} + +} // namespace http_request +} // namespace esphome + +#endif // USE_HOST diff --git a/esphome/components/http_request/http_request_host.h b/esphome/components/http_request/http_request_host.h new file mode 100644 index 0000000000..49fd3b43fe --- /dev/null +++ b/esphome/components/http_request/http_request_host.h @@ -0,0 +1,37 @@ +#pragma once + +#include "http_request.h" + +#ifdef USE_HOST + +#define CPPHTTPLIB_NO_EXCEPTIONS +#include "httplib.h" +namespace esphome { +namespace http_request { + +class HttpRequestHost; +class HttpContainerHost : public HttpContainer { + public: + int read(uint8_t *buf, size_t max_len) override; + void end() override; + + protected: + friend class HttpRequestHost; + std::vector response_body_{}; +}; + +class HttpRequestHost : public HttpRequestComponent { + public: + std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set response_headers) override; + void set_ca_path(const char *ca_path) { this->ca_path_ = ca_path; } + + protected: + const char *ca_path_{}; +}; + +} // namespace http_request +} // namespace esphome + +#endif // USE_HOST diff --git a/esphome/components/http_request/httplib.h b/esphome/components/http_request/httplib.h new file mode 100644 index 0000000000..a2f4436ec7 --- /dev/null +++ b/esphome/components/http_request/httplib.h @@ -0,0 +1,9691 @@ +#pragma once + +/** + * NOTE: This is a copy of httplib.h from https://github.com/yhirose/cpp-httplib + * + * It has been modified only to add ifdefs for USE_HOST. While it contains many functions unused in ESPHome, + * it was considered preferable to use it with as few changes as possible, to facilitate future updates. + */ + +#include "esphome/core/defines.h" + +// +// httplib.h +// +// Copyright (c) 2024 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifdef USE_HOST +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.18.2" + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN32 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 10000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_RANGE_MAX_COUNT +#define CPPHTTPLIB_RANGE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() - 1 : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = long; +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m) &S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m) &S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#include +#if TARGET_OS_OSX +#include +#include +#endif // TARGET_OS_OSX +#endif // _WIN32 + +#include +#include +#include +#include + +#if defined(_WIN32) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L +#error Sorry, OpenSSL versions prior to 3.0.0 are not supported +#endif + +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 97, + 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, + 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, + 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, + 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, 227, 228, 229, + 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, + 252, 253, 254, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, + }; + return table[(unsigned char) (char) c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { return to_lower(ca) == to_lower(cb); }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { return equal(a, b); } +}; + +struct hash { + size_t operator()(const std::string &key) const { return hash_core(key.data(), key.size(), 0); } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +} // namespace case_ignore + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) : exit_function(std::move(f)), execute_on_destruction{true} {} + + scope_exit(scope_exit &&rhs) noexcept + : exit_function(std::move(rhs.exit_function)), execute_on_destruction{rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { + this->exit_function(); + } + } + + void release() { this->execute_on_destruction = false; } + + private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +enum StatusCode { + // Information responses + Continue_100 = 100, + SwitchingProtocol_101 = 101, + Processing_102 = 102, + EarlyHints_103 = 103, + + // Successful responses + OK_200 = 200, + Created_201 = 201, + Accepted_202 = 202, + NonAuthoritativeInformation_203 = 203, + NoContent_204 = 204, + ResetContent_205 = 205, + PartialContent_206 = 206, + MultiStatus_207 = 207, + AlreadyReported_208 = 208, + IMUsed_226 = 226, + + // Redirection messages + MultipleChoices_300 = 300, + MovedPermanently_301 = 301, + Found_302 = 302, + SeeOther_303 = 303, + NotModified_304 = 304, + UseProxy_305 = 305, + unused_306 = 306, + TemporaryRedirect_307 = 307, + PermanentRedirect_308 = 308, + + // Client error responses + BadRequest_400 = 400, + Unauthorized_401 = 401, + PaymentRequired_402 = 402, + Forbidden_403 = 403, + NotFound_404 = 404, + MethodNotAllowed_405 = 405, + NotAcceptable_406 = 406, + ProxyAuthenticationRequired_407 = 407, + RequestTimeout_408 = 408, + Conflict_409 = 409, + Gone_410 = 410, + LengthRequired_411 = 411, + PreconditionFailed_412 = 412, + PayloadTooLarge_413 = 413, + UriTooLong_414 = 414, + UnsupportedMediaType_415 = 415, + RangeNotSatisfiable_416 = 416, + ExpectationFailed_417 = 417, + ImATeapot_418 = 418, + MisdirectedRequest_421 = 421, + UnprocessableContent_422 = 422, + Locked_423 = 423, + FailedDependency_424 = 424, + TooEarly_425 = 425, + UpgradeRequired_426 = 426, + PreconditionRequired_428 = 428, + TooManyRequests_429 = 429, + RequestHeaderFieldsTooLarge_431 = 431, + UnavailableForLegalReasons_451 = 451, + + // Server error responses + InternalServerError_500 = 500, + NotImplemented_501 = 501, + BadGateway_502 = 502, + ServiceUnavailable_503 = 503, + GatewayTimeout_504 = 504, + HttpVersionNotSupported_505 = 505, + VariantAlsoNegotiates_506 = 506, + InsufficientStorage_507 = 507, + LoopDetected_508 = 508, + NotExtended_510 = 510, + NetworkAuthenticationRequired_511 = 511, +}; + +using Headers = + std::unordered_multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { + public: + DataSink() : os(&sb_), sb_(*this) {} + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function is_writable; + std::function done; + std::function done_with_trailer; + std::ostream os; + + private: + class data_sink_streambuf final : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) override { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = std::function; + +using ContentProviderWithoutLength = std::function; + +using ContentProviderResourceReleaser = std::function; + +struct MultipartFormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; +using MultipartFormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = + std::function; + +using ContentReceiver = std::function; + +using MultipartContentHeader = std::function; + +class ContentReader { + public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(std::move(reader)), multipart_reader_(std::move(multipart_reader)) {} + + bool operator()(MultipartContentHeader header, ContentReceiver receiver) const { + return multipart_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { return reader_(std::move(receiver)); } + + Reader reader_; + MultipartReader multipart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Params params; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + std::unordered_map path_params; + + // for client + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + bool has_file(const std::string &key) const; + MultipartFormData get_file_value(const std::string &key) const; + std::vector get_file_values(const std::string &key) const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + void set_redirect(const std::string &url, int status = StatusCode::Found_302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + void set_content(std::string &&s, const std::string &content_type); + + void set_content_provider(size_t length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider(const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider(const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_file_content(const std::string &path, const std::string &content_type); + void set_file_content(const std::string &path); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; +}; + +class Stream { + public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { + public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual bool enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() {} +}; + +class ThreadPool final : public TaskQueue { + public: + explicit ThreadPool(size_t n, size_t mqr = 0) : shutdown_(false), max_queued_requests_(mqr) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + bool enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + + private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait(lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { + break; + } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + size_t max_queued_requests_ = 0; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +void default_socket_options(socket_t sock); + +const char *status_message(int status); + +std::string get_bearer_token_auth(const Request &req); + +namespace detail { + +class MatcherBase { + public: + virtual ~MatcherBase() = default; + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched agains the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher final : public MatcherBase { + public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + + private: + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher final : public MatcherBase { + public: + RegexMatcher(const std::string &pattern) : regex_(pattern) {} + + bool match(Request &request) const override; + + private: + std::regex regex_; +}; + +ssize_t write_headers(Stream &strm, const Headers &headers); + +} // namespace detail + +class Server { + public: + using Handler = std::function; + + using ExceptionHandler = std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = std::function; + + using HandlerWithContentReader = + std::function; + + using Expect100ContinueHandler = std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, const std::string &mime); + Server &set_default_file_mimetype(const std::string &mime); + Server &set_file_request_handler(Handler handler); + + template Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core(std::forward(handler), + std::is_convertible{}); + } + + Server &set_exception_handler(ExceptionHandler handler); + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + Server &set_header_writer(std::function const &writer); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + void decommission(); + + std::function new_task_queue; + + protected: + bool process_request(Stream &strm, const std::string &remote_addr, int remote_port, const std::string &local_addr, + int local_port, bool close_connection, bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_{INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + + private: + using Handlers = std::vector, Handler>>; + using HandlersForContentReader = + std::vector, HandlerWithContentReader>>; + + static std::unique_ptr make_matcher(const std::string &pattern); + + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + + socket_t create_server_socket(const std::string &host, int port, int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, const Handlers &handlers) const; + bool dispatch_request_for_content_reader(Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const; + + bool parse_request_line(const char *s, Request &req) const; + void apply_ranges(const Request &req, Response &res, std::string &content_type, std::string &boundary) const; + bool write_response(Stream &strm, bool close_connection, Request &req, Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, const Request &req, Response &res); + bool write_response_core(Stream &strm, bool close_connection, const Request &req, Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool read_content_with_content_receiver(Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, ContentReceiver multipart_receiver) const; + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_{false}; + std::atomic is_decommisioned{false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + std::string default_file_mimetype_ = "application/octet-stream"; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; + std::function header_writer_ = detail::write_headers; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { + public: + Result() = default; + Result(std::unique_ptr &&res, Error err, Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), request_headers_(std::move(request_headers)) {} + // Response + operator bool() const { return res_ != nullptr; } + bool operator==(std::nullptr_t) const { return res_ == nullptr; } + bool operator!=(std::nullptr_t) const { return res_ != nullptr; } + const Response &value() const { return *res_; } + Response &value() { return *res_; } + const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } + const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } + + // Error + Error error() const { return err_; } + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, const char *def = "", size_t id = 0) const; + uint64_t get_request_header_value_u64(const std::string &key, uint64_t def = 0, size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + + private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +}; + +class ClientImpl { + public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template void set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); +#endif + + void set_logger(Logger logger); + + protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket) const; + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, Error &error) const; + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Header writer + std::function header_writer_ = detail::write_headers; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool url_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; +#endif + + Logger logger_; + + private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_peer_could_be_closed(SSL *ssl) const; +#endif + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, Response &res) const; + bool write_request(Stream &strm, Request &req, bool close_connection, Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); + std::unique_ptr send_with_content_provider(Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error); + Result send_with_content_provider(const std::string &method, const std::string &path, const Headers &headers, + const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress); + ContentProviderWithoutLength get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const; + + std::string adjust_host_string(const std::string &host) const; + + virtual bool process_socket(const Socket &socket, std::function callback); + virtual bool is_ssl() const; +}; + +class Client { + public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + Client &operator=(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template void set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + + private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { + public: + SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store = nullptr); + + SSLServer(const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + + void update_certs(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store = nullptr); + + private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient final : public ClientImpl { + public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path, const std::string &private_key_password = std::string()); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + + private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); + + bool process_socket(const Socket &socket, std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy(Socket &sock, Response &res, bool &success, Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast(duration - std::chrono::seconds(sec)).count(); + callback(static_cast(sec), static_cast(usec)); +} + +inline uint64_t get_header_value_u64(const Headers &headers, const std::string &key, uint64_t def, size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +} // namespace detail + +inline uint64_t Request::get_header_value_u64(const std::string &key, uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline uint64_t Response::get_header_value_u64(const std::string &key, uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline void default_socket_options(socket_t sock) { + int opt = 1; +#ifdef _WIN32 + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, reinterpret_cast(&opt), sizeof(opt)); +#else +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt)); +#endif +#endif +} + +inline const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: + return "Continue"; + case StatusCode::SwitchingProtocol_101: + return "Switching Protocol"; + case StatusCode::Processing_102: + return "Processing"; + case StatusCode::EarlyHints_103: + return "Early Hints"; + case StatusCode::OK_200: + return "OK"; + case StatusCode::Created_201: + return "Created"; + case StatusCode::Accepted_202: + return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: + return "No Content"; + case StatusCode::ResetContent_205: + return "Reset Content"; + case StatusCode::PartialContent_206: + return "Partial Content"; + case StatusCode::MultiStatus_207: + return "Multi-Status"; + case StatusCode::AlreadyReported_208: + return "Already Reported"; + case StatusCode::IMUsed_226: + return "IM Used"; + case StatusCode::MultipleChoices_300: + return "Multiple Choices"; + case StatusCode::MovedPermanently_301: + return "Moved Permanently"; + case StatusCode::Found_302: + return "Found"; + case StatusCode::SeeOther_303: + return "See Other"; + case StatusCode::NotModified_304: + return "Not Modified"; + case StatusCode::UseProxy_305: + return "Use Proxy"; + case StatusCode::unused_306: + return "unused"; + case StatusCode::TemporaryRedirect_307: + return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: + return "Permanent Redirect"; + case StatusCode::BadRequest_400: + return "Bad Request"; + case StatusCode::Unauthorized_401: + return "Unauthorized"; + case StatusCode::PaymentRequired_402: + return "Payment Required"; + case StatusCode::Forbidden_403: + return "Forbidden"; + case StatusCode::NotFound_404: + return "Not Found"; + case StatusCode::MethodNotAllowed_405: + return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: + return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: + return "Request Timeout"; + case StatusCode::Conflict_409: + return "Conflict"; + case StatusCode::Gone_410: + return "Gone"; + case StatusCode::LengthRequired_411: + return "Length Required"; + case StatusCode::PreconditionFailed_412: + return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: + return "Payload Too Large"; + case StatusCode::UriTooLong_414: + return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: + return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: + return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: + return "Expectation Failed"; + case StatusCode::ImATeapot_418: + return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: + return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: + return "Unprocessable Content"; + case StatusCode::Locked_423: + return "Locked"; + case StatusCode::FailedDependency_424: + return "Failed Dependency"; + case StatusCode::TooEarly_425: + return "Too Early"; + case StatusCode::UpgradeRequired_426: + return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: + return "Precondition Required"; + case StatusCode::TooManyRequests_429: + return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: + return "Not Implemented"; + case StatusCode::BadGateway_502: + return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: + return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: + return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: + return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: + return "Insufficient Storage"; + case StatusCode::LoopDetected_508: + return "Loop Detected"; + case StatusCode::NotExtended_510: + return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: + return "Internal Server Error"; + } +} + +inline std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + static std::string BearerHeaderPrefix = "Bearer "; + return req.get_header_value("Authorization").substr(BearerHeaderPrefix.length()); + } + return ""; +} + +template +inline Server &Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server &Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server &Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: + return "Success (no error)"; + case Error::Connection: + return "Could not establish connection"; + case Error::BindIPAddress: + return "Failed to bind IP address"; + case Error::Read: + return "Failed to read connection"; + case Error::Write: + return "Failed to write connection"; + case Error::ExceedRedirectCount: + return "Maximum redirect count exceeded"; + case Error::Canceled: + return "Connection handling canceled"; + case Error::SSLConnection: + return "SSL connection failed"; + case Error::SSLLoadingCerts: + return "SSL certificate loading failed"; + case Error::SSLServerVerification: + return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: + return "Compression failed"; + case Error::ConnectionTimeout: + return "Connection timed out"; + case Error::ProxyConnection: + return "Proxy connection failed"; + case Error::Unknown: + return "Unknown"; + default: + break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +inline uint64_t Result::get_request_header_value_u64(const std::string &key, uint64_t def, size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + +template +inline void ClientImpl::set_connection_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_connection_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void Client::set_connection_timeout(const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(const Ranges &ranges); + +std::pair make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +#if defined(_WIN32) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { + ws.clear(); + } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + + private: +#if defined(_WIN32) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + +std::string encode_query_param(const std::string &value); + +std::string decode_url(const std::string &s, bool convert_plus_to_space); + +void read_file(const std::string &path, std::string &out); + +std::string trim_copy(const std::string &s); + +void divide(const char *data, std::size_t size, char d, + std::function fn); + +void divide(const std::string &str, char d, + std::function fn); + +void split(const char *b, const char *e, char d, std::function fn); + +void split(const char *b, const char *e, char d, size_t m, std::function fn); + +bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, std::function callback); + +socket_t create_client_socket(const std::string &host, const std::string &ip, int port, int address_family, + bool tcp_nodelay, bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, + const std::string &intf, Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, const char *def, size_t id); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { None = 0, Gzip, Brotli }; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream final : public Stream { + public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + const std::string &get_buffer() const; + + private: + std::string buffer; + size_t position = 0; +}; + +class compressor { + public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, Callback callback) = 0; +}; + +class decompressor { + public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, Callback callback) = 0; +}; + +class nocompressor final : public compressor { + public: + ~nocompressor() override = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor final : public compressor { + public: + gzip_compressor(); + ~gzip_compressor() override; + + bool compress(const char *data, size_t data_length, bool last, Callback callback) override; + + private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor final : public decompressor { + public: + gzip_decompressor(); + ~gzip_decompressor() override; + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, Callback callback) override; + + private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor final : public compressor { + public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, Callback callback) override; + + private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor final : public decompressor { + public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, Callback callback) override; + + private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { + public: + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + + private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +class mmap { + public: + mmap(const char *path); + ~mmap(); + + bool open(const char *path); + void close(); + + bool is_open() const; + size_t size() const; + const char *data() const; + + private: +#if defined(_WIN32) + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; +#else + int fd_ = -1; +#endif + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; +}; + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) { + if (i >= s.size()) { + return false; + } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { + return false; + } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = static_cast(code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { + out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); + } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + if (path[i] == '\0') { + return false; + } else if (path[i] == '\\') { + return false; + } + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { + return false; + } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { return ret_ >= 0 && S_ISREG(st_.st_mode); } +inline bool FileStat::is_dir() const { return ret_ >= 0 && S_ISDIR(st_.st_mode); } + +inline std::string encode_query_param(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || c == '.' || c == '!' || c == '~' || c == '*' || + c == '\'' || c == '(' || c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_url(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': + result += "%20"; + break; + case '+': + result += "%2B"; + break; + case '\r': + result += "%0D"; + break; + case '\n': + result += "%0A"; + break; + case '\'': + result += "%27"; + break; + case ',': + result += "%2C"; + break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': + result += "%3B"; + break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { + result.append(buff, len); + } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { + return m[1].str(); + } + return std::string(); +} + +inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } + +inline std::pair trim(const char *b, const char *e, size_t left, size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline std::string trim_double_quotes_copy(const std::string &s) { + if (s.length() >= 2 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; +} + +inline void divide(const char *data, std::size_t size, char d, + std::function fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void divide(const std::string &str, char d, + std::function fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, std::function fn) { + return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, size_t m, std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), fixed_buffer_size_(fixed_buffer_size) {} + +inline const char *stream_line_reader::ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + if (byte == '\n') { + break; + } +#else + if (prev_byte == '\r' && byte == '\n') { + break; + } + prev_byte = byte; +#endif + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } +} + +inline mmap::mmap(const char *path) { open(path); } + +inline mmap::~mmap() { close(); } + +inline bool mmap::open(const char *path) { + close(); + +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { + return false; + } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL); +#else + hFile_ = + ::CreateFileW(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); +#endif + + if (hFile_ == INVALID_HANDLE_VALUE) { + return false; + } + + LARGE_INTEGER size{}; + if (!::GetFileSizeEx(hFile_, &size)) { + return false; + } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } + size_ = static_cast(size.QuadPart); + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hMapping_ = ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); +#else + hMapping_ = ::CreateFileMappingW(hFile_, NULL, PAGE_READONLY, 0, 0, NULL); +#endif + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } + + if (hMapping_ == NULL) { + close(); + return false; + } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); +#else + addr_ = ::MapViewOfFile(hMapping_, FILE_MAP_READ, 0, 0, 0); +#endif + + if (addr_ == nullptr) { + close(); + return false; + } +#else + fd_ = ::open(path, O_RDONLY); + if (fd_ == -1) { + return false; + } + + struct stat sb; + if (fstat(fd_, &sb) == -1) { + close(); + return false; + } + size_ = static_cast(sb.st_size); + + addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); + + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { + close(); + is_open_empty_file = true; + return false; + } +#endif + + return true; +} + +inline bool mmap::is_open() const { return is_open_empty_file ? true : addr_ != nullptr; } + +inline size_t mmap::size() const { return size_; } + +inline const char *mmap::data() const { return is_open_empty_file ? "" : static_cast(addr_); } + +inline void mmap::close() { +#if defined(_WIN32) + if (addr_) { + ::UnmapViewOfFile(addr_); + addr_ = nullptr; + } + + if (hMapping_) { + ::CloseHandle(hMapping_); + hMapping_ = NULL; + } + + if (hFile_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(hFile_); + hFile_ = INVALID_HANDLE_VALUE; + } + + is_open_empty_file = false; +#else + if (addr_ != nullptr) { + munmap(addr_, size_); + addr_ = nullptr; + } + + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } +#endif + size_ = 0; +} +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return -1; + } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); }); +#endif +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return -1; + } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); }); +#endif +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { + return Error::ConnectionTimeout; + } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return Error::Connection; + } +#endif + + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); }); + + if (ret == 0) { + return Error::ConnectionTimeout; + } + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + return Error::Connection; +#endif +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream final : public Stream { + public: + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024l * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream final : public Stream { + public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; +}; +#endif + +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, time_t keep_alive_timeout_sec) { + using namespace std::chrono; + + const auto interval_usec = CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { + return true; + } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + + while (true) { + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); + if (val < 0) { + break; // Ssocket error + } else if (val == 0) { + if (steady_clock::now() - start > timeout) { + break; // Timeout + } + } else { + return true; // Ready for read + } + } + + return false; +} + +template +inline bool process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { + break; + } + count--; + } + return ret; +} + +template +inline bool process_server_socket(const std::atomic &svr_sock, socket_t sock, size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, T callback) { + return process_server_socket_core(svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, + write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + +template +socket_t create_socket(const std::string &host, const std::string &ip, int port, int address_family, int socket_flags, + bool tcp_nodelay, bool ipv6_v6only, SocketOptions socket_options, + BindOrConnect bind_or_connect) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_IP; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { + node = host.c_str(); + } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + +#ifndef _WIN32 + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { + return INVALID_SOCKET; + } + +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, hints.ai_protocol); +#else + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + + if (sock != INVALID_SOCKET) { + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast(sizeof(addr) - sizeof(addr.sun_path) + addrlen); + +#ifndef SOCK_CLOEXEC + fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif + + if (socket_options) { + socket_options(sock); + } + + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + + auto service = std::to_string(port); + + if (getaddrinfo(node, service.c_str(), &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + +#ifdef SOCK_CLOEXEC + auto sock = socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + +#endif + if (sock == INVALID_SOCKET) { + continue; + } + +#if !defined _WIN32 && !defined SOCK_CLOEXEC + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { + auto opt = 1; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (rp->ai_family == AF_INET6) { + auto opt = ipv6_v6only ? 1 : 0; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (socket_options) { + socket_options(sock); + } + + // bind or connect + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { + return sock; + } + + close_socket(sock); + + if (quit) { + break; + } + } + + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host.c_str(), "0", &hints, &result)) { + return false; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + return ret; +} + +#if !defined _WIN32 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + return addr_candidate; +} +#endif + +inline socket_t create_client_socket(const std::string &host, const std::string &ip, int port, int address_family, + bool tcp_nodelay, bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error) { + auto sock = create_socket( + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { + ip_from_if = intf; + } + if (!bind_ip_address(sock2, ip_from_if)) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, connection_timeout_usec); + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { + quit = true; + } + return false; + } + } + + set_nonblocking(sock2, false); + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec * 1000 + read_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec * 1000 + write_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + error = Error::Success; + return true; + }); + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { + error = Error::Connection; + } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr{}; + if (getnameinfo(reinterpret_cast(&addr), addr_len, ipstr.data(), + static_cast(ipstr.size()), nullptr, 0, NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), &addr_len)) { +#ifndef _WIN32 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) // __APPLE__ + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, unsigned int h) { + return (l == 0) ? h + : str2tag_core( + s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & h * 33) ^ static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { return str2tag_core(s.data(), s.size(), 0); } + +namespace udl { + +inline constexpr unsigned int operator""_t(const char *s, size_t l) { return str2tag_core(s, l, 0); } + +} // namespace udl + +inline std::string find_content_type(const std::string &path, const std::map &user_data, + const std::string &default_content_type) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { + return it->second; + } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: + return default_content_type; + + case "css"_t: + return "text/css"; + case "csv"_t: + return "text/csv"; + case "htm"_t: + case "html"_t: + return "text/html"; + case "js"_t: + case "mjs"_t: + return "text/javascript"; + case "txt"_t: + return "text/plain"; + case "vtt"_t: + return "text/vtt"; + + case "apng"_t: + return "image/apng"; + case "avif"_t: + return "image/avif"; + case "bmp"_t: + return "image/bmp"; + case "gif"_t: + return "image/gif"; + case "png"_t: + return "image/png"; + case "svg"_t: + return "image/svg+xml"; + case "webp"_t: + return "image/webp"; + case "ico"_t: + return "image/x-icon"; + case "tif"_t: + return "image/tiff"; + case "tiff"_t: + return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: + return "image/jpeg"; + + case "mp4"_t: + return "video/mp4"; + case "mpeg"_t: + return "video/mpeg"; + case "webm"_t: + return "video/webm"; + + case "mp3"_t: + return "audio/mp3"; + case "mpga"_t: + return "audio/mpeg"; + case "weba"_t: + return "audio/webm"; + case "wav"_t: + return "audio/wave"; + + case "otf"_t: + return "font/otf"; + case "ttf"_t: + return "font/ttf"; + case "woff"_t: + return "font/woff"; + case "woff2"_t: + return "font/woff2"; + + case "7z"_t: + return "application/x-7z-compressed"; + case "atom"_t: + return "application/atom+xml"; + case "pdf"_t: + return "application/pdf"; + case "json"_t: + return "application/json"; + case "rss"_t: + return "application/rss+xml"; + case "tar"_t: + return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: + return "application/xhtml+xml"; + case "xslt"_t: + return "application/xslt+xml"; + case "xml"_t: + return "application/xml"; + case "gz"_t: + return "application/gzip"; + case "zip"_t: + return "application/zip"; + case "wasm"_t: + return "application/wasm"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: + return true; + + case "text/event-stream"_t: + return false; + + default: + return !content_type.rfind("text/", 0); + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { + return EncodingType::None; + } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void) (s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { + return EncodingType::Brotli; + } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { + return EncodingType::Gzip; + } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, bool /*last*/, Callback callback) { + if (!data_length) { + return true; + } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { deflateEnd(&strm_); } + +inline bool gzip_compressor::compress(const char *data, size_t data_length, bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = (std::numeric_limits::max)(); + + strm_.avail_in = static_cast((std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff{}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { + return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { inflateEnd(&strm_); } + +inline bool gzip_decompressor::is_valid() const { return is_valid_; } + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = (std::numeric_limits::max)(); + + strm_.avail_in = static_cast((std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff{}; + while (strm_.avail_in > 0 && ret == Z_OK) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm_, Z_NO_FLUSH); + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm_); + return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { + return false; + } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); } + +inline brotli_compressor::~brotli_compressor() { BrotliEncoderDestroyInstance(state_); } + +inline bool brotli_compressor::compress(const char *data, size_t data_length, bool last, Callback callback) { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { + break; + } + } else { + if (!available_in) { + break; + } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in, &available_out, &next_out, nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { + BrotliDecoderDestroyInstance(decoder_s); + } +} + +inline bool brotli_decompressor::is_valid() const { return decoder_s; } + +inline bool brotli_decompressor::decompress(const char *data, size_t data_length, Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream(decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return false; + } + + if (!callback(buff.data(), buff.size() - avail_out)) { + return false; + } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { return headers.find(key) != headers.end(); } + +inline const char *get_header_value(const Headers &headers, const std::string &key, const char *def, size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second.c_str(); + } + return def; +} + +template inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + if (p == end) { + return false; + } + + auto key_end = p; + + if (*p++ != ':') { + return false; + } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { + return false; + } + + auto key = std::string(beg, key_end); + auto val = case_ignore::equal(key, "Location") ? std::string(p, end) : decode_url(std::string(p, end), false); + + // NOTE: From RFC 9110: + // Field values containing CR, LF, or NUL characters are + // invalid and dangerous, due to the varying ways that + // implementations might parse and interpret those + // characters; a recipient of CR, LF, or NUL within a field + // value MUST either reject the message or replace each of + // those characters with SP before further processing or + // forwarding of that message. + static const std::string CR_LF_NUL("\r\n\0", 3); + if (val.find_first_of(CR_LF_NUL) != std::string::npos) { + return false; + } + + fn(key, val); + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { + return false; + } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { + break; + } + } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + // Blank line indicates end of headers. + if (line_reader.size() == 1) { + break; + } + line_terminator_len = 1; +#else + continue; // Skip invalid line. +#endif + } + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, std::string &val) { headers.emplace(key, val); })) { + return false; + } + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, Progress progress, ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return false; + } + + if (!out(buf, static_cast(n), r, len)) { + return false; + } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { + return false; + } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return; + } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n <= 0) { + return true; + } + + if (!out(buf, static_cast(n), r, 0)) { + return false; + } + r += static_cast(n); + } + + return true; +} + +template inline bool read_content_chunked(Stream &strm, T &x, ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { + return false; + } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { + return false; + } + if (chunk_len == ULONG_MAX) { + return false; + } + + if (chunk_len == 0) { + break; + } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { + return false; + } + + if (strcmp(line_reader.ptr(), "\r\n") != 0) { + return false; + } + + if (!line_reader.getline()) { + return false; + } + } + + assert(chunk_len == 0); + + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-htpplib now allows + // chuncked transfer coding data without the final CRLF. + if (!line_reader.getline()) { + return true; + } + + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { x.headers.emplace(key, val); }); + + if (!line_reader.getline()) { + return false; + } + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return case_ignore::equal(get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, ContentReceiverWithProgress receiver, bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf2, size_t n2) { return receiver(buf2, n2, off, len); }); + }; + return callback(std::move(out)); + } else { + status = StatusCode::InternalServerError_500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, Progress progress, + ContentReceiverWithProgress receiver, bool decompress) { + return prepare_content_receiver( + x, status, std::move(receiver), decompress, [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, x, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_u64(x.headers, "Content-Length", 0, 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413 : StatusCode::BadRequest_400; + } + return ret; + }); +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); + if (len < 0) { + return len; + } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { + return len; + } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { + return false; + } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, size_t offset, size_t length, + T is_shutting_down, Error &error) { + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (strm.is_writable() && write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, size_t offset, size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, error); +} + +template +inline bool write_content_without_length(Stream &strm, const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!strm.is_writable() || !write_data(strm, d, l)) { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool write_content_chunked(Stream &strm, const ContentProvider &content_provider, const T &is_shutting_down, + U &compressor, Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { + return; + } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + static const std::string done_marker("0\r\n"); + if (!write_data(strm, done_marker.data(), done_marker.size())) { + ok = false; + } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + static const std::string crlf("\r\n"); + if (!write_data(strm, crlf.data(), crlf.size())) { + ok = false; + } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { done_with_trailer(&trailer); }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, const ContentProvider &content_provider, const T &is_shutting_down, + U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, compressor, error); +} + +template +inline bool redirect(T &cli, Request &req, Response &res, const std::string &path, const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == StatusCode::SeeOther_303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { + res.location = location; + } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { + query += "&"; + } + query += it->first; + query += "="; + query += encode_query_param(it->second); + } + return query; +} + +inline void parse_query_text(const char *data, std::size_t size, Params ¶ms) { + std::set cache; + split(data, data + size, '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { + return; + } + cache.insert(std::move(kv)); + + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { parse_query_text(s.data(), s.size(), params); } + +inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { + return false; + } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + return !boundary.empty(); +} + +inline void parse_disposition_params(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { + return; + } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(trim_double_quotes_copy((key)), trim_double_quotes_copy((val))); + } + }); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); + auto all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { + return; + } + + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; + } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); + }); + return all_valid_ranges && !ranges.empty(); + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} catch (...) { + return false; +} +#endif + +class MultipartFormDataParser { + public: + MultipartFormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { return is_valid_; } + + bool parse(const char *buf, size_t n, const ContentReceiver &content_callback, + const MultipartContentHeader &header_callback) { + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + buf_erase(buf_find(dash_boundary_crlf_)); + if (dash_boundary_crlf_.size() > buf_size()) { + return true; + } + if (!buf_start_with(dash_boundary_crlf_)) { + return false; + } + buf_erase(dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + const auto header = buf_head(pos); + + if (!parse_header(header.data(), header.data() + header.size(), + [&](const std::string &, const std::string &) {})) { + is_valid_ = false; + return false; + } + + static const std::string header_content_type = "Content-Type:"; + + if (start_with_case_ignore(header, header_content_type)) { + file_.content_type = trim_copy(header.substr(header_content_type.size())); + } else { + static const std::regex re_content_disposition(R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", + std::regex_constants::icase); + + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + Params params; + parse_disposition_params(m[1], params); + + auto it = params.find("name"); + if (it != params.end()) { + file_.name = it->second; + } else { + is_valid_ = false; + return false; + } + + it = params.find("filename"); + if (it != params.end()) { + file_.filename = it->second; + } + + it = params.find("filename*"); + if (it != params.end()) { + // Only allow UTF-8 enconnding... + static const std::regex re_rfc5987_encoding(R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); + + std::smatch m2; + if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { + file_.filename = decode_url(m2[1], false); // override... + } else { + is_valid_ = false; + return false; + } + } + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { + return true; + } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { + return true; + } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { + return true; + } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_.size() > buf_size()) { + return true; + } + if (buf_start_with(dash_)) { + buf_erase(dash_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + + private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + bool start_with_case_ignore(const std::string &a, const std::string &b) const { + if (a.size() < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + MultipartFormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, const std::string &b) const { + if (epos - spos < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { + return false; + } + } + return true; + } + + size_t buf_size() const { return buf_epos_ - buf_spos_; } + + const char *buf_data() const { return &buf_[buf_spos_]; } + + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } + + bool buf_start_with(const std::string &s) const { return start_with(buf_, buf_spos_, buf_epos_, s); } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { + return buf_size(); + } + if (buf_[pos] == c) { + break; + } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { + return buf_size(); + } + + if (start_with(buf_, pos, buf_epos_, s)) { + return pos - buf_spos_; + } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { + buf_.resize(remaining_size + n); + } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string random_string(size_t length) { + static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + static std::random_device seed_gen; + + // Request 128 bits of entropy for initialization + static std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + + static std::mt19937 engine(seed_sequence); + + std::string result; + for (size_t i = 0; i < length; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + return result; +} + +inline std::string make_multipart_data_boundary() { + return "--cpp-httplib-multipart-data-" + detail::random_string(16); +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string serialize_multipart_formdata_item_begin(const T &item, const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { return "\r\n"; } + +inline std::string serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string serialize_multipart_formdata(const MultipartFormDataItems &items, const std::string &boundary, + bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { + body += serialize_multipart_formdata_finish(boundary); + } + + return body; +} + +inline bool range_error(Request &req, Response &res) { + if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { + ssize_t contant_len = static_cast(res.content_length_ ? res.content_length_ : res.body.size()); + + ssize_t prev_first_pos = -1; + ssize_t prev_last_pos = -1; + size_t overwrapping_count = 0; + + // NOTE: The following Range check is based on '14.2. Range' in RFC 9110 + // 'HTTP Semantics' to avoid potential denial-of-service attacks. + // https://www.rfc-editor.org/rfc/rfc9110#section-14.2 + + // Too many ranges + if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { + return true; + } + + for (auto &r : req.ranges) { + auto &first_pos = r.first; + auto &last_pos = r.second; + + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = contant_len; + } + + if (first_pos == -1) { + first_pos = contant_len - last_pos; + last_pos = contant_len - 1; + } + + if (last_pos == -1) { + last_pos = contant_len - 1; + } + + // Range must be within content length + if (!(0 <= first_pos && first_pos <= last_pos && last_pos <= contant_len - 1)) { + return true; + } + + // Ranges must be in ascending order + if (first_pos <= prev_first_pos) { + return true; + } + + // Request must not have more than two overlapping ranges + if (first_pos <= prev_last_pos) { + overwrapping_count++; + if (overwrapping_count > 2) { + return true; + } + } + + prev_first_pos = (std::max)(prev_first_pos, first_pos); + prev_last_pos = (std::max)(prev_last_pos, last_pos); + } + } + + return false; +} + +inline std::pair get_range_offset_and_length(Range r, size_t content_length) { + assert(r.first != -1 && r.second != -1); + assert(0 <= r.first && r.first < static_cast(content_length)); + assert(r.first <= r.second && r.second < static_cast(content_length)); + (void) (content_length); + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field(const std::pair &offset_and_length, + size_t content_length) { + auto st = offset_and_length.first; + auto ed = st + offset_and_length.second - 1; + + std::string field = "bytes "; + field += std::to_string(st); + field += "-"; + field += std::to_string(ed); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, const std::string &boundary, const std::string &content_type, + size_t content_length, SToken stoken, CToken ctoken, Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offset_and_length = get_range_offset_and_length(req.ranges[i], content_length); + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset_and_length, content_length)); + ctoken("\r\n"); + ctoken("\r\n"); + + if (!content(offset_and_length.first, offset_and_length.second)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--"); + + return true; +} + +inline void make_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, + const std::string &content_type, size_t content_length, std::string &data) { + process_multipart_ranges_data( + req, boundary, content_type, content_length, [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + assert(offset + length <= content_length); + data += res.body.substr(offset, length); + return true; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, const std::string &boundary, + const std::string &content_type, size_t content_length) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, boundary, content_type, content_length, [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, const std::string &boundary, + const std::string &content_type, size_t content_length, + const T &is_shutting_down) { + return process_multipart_ranges_data( + req, boundary, content_type, content_length, [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, is_shutting_down); + }); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || req.method == "PRI" || + req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { + return true; + } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr(EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { return message_digest(s, EVP_md5()); } + +inline std::string SHA_256(const std::string &s) { return message_digest(s, EVP_sha256()); } + +inline std::string SHA_512(const std::string &s) { return message_digest(s, EVP_sha512()); } +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY) NULL, L"ROOT"); + if (!hStore) { + return false; + } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != nullptr) { + auto encoded_cert = static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX +template using CFObjectPtr = std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { + CFRelease(obj); + } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, kCFBooleanTrue}; + + CFObjectPtr query( + CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { + return false; + } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast(CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { + continue; + } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast(CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // TARGET_OS_OSX +#endif // _WIN32 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN32 +class WSInit { + public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) + is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) + WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, size_t cnonce_count, const std::string &cnonce, + const std::string &username, const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { + algo = auth.at("algorithm"); + } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { + A2 += ":" + H(req.body); + } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = + "Digest username=\"" + username + "\", realm=\"" + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + cnonce + "\", response=\"") + + response + "\"" + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), static_cast(m.length(1))); + auto val = m.length(2) > 0 ? s.substr(static_cast(m.position(2)), static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +class ContentProviderAdapter { + public: + explicit ContentProviderAdapter(ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { return content_provider_(offset, sink); } + + private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { + return std::string(); + } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(hostname.c_str(), nullptr, &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, dummy)) { + addrs.push_back(ip); + } + } +} + +inline std::string append_query_params(const std::string &path, const Params ¶ms) { + std::string path_with_query = path; + const static std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair make_range_header(const Ranges &ranges) { + std::string field = "bytes="; + auto i = 0; + for (const auto &r : ranges) { + if (i != 0) { + field += ", "; + } + if (r.first != -1) { + field += std::to_string(r.first); + } + field += '-'; + if (r.second != -1) { + field += std::to_string(r.second); + } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { return detail::has_header(headers, key); } + +inline std::string Request::get_header_value(const std::string &key, const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const std::string &key) const { return params.find(key) != params.end(); } + +inline std::string Request::get_param_value(const std::string &key, size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second; + } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +inline bool Request::has_file(const std::string &key) const { return files.find(key) != files.end(); } + +inline MultipartFormData Request::get_file_value(const std::string &key) const { + auto it = files.find(key); + if (it != files.end()) { + return it->second; + } + return MultipartFormData(); +} + +inline std::vector Request::get_file_values(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { return headers.find(key) != headers.end(); } + +inline std::string Response::get_header_value(const std::string &key, const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = StatusCode::Found_302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content(std::string &&s, const std::string &content_type) { + body = std::move(s); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider(size_t in_length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { + content_provider_ = std::move(provider); + } + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider(const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider(const std::string &content_type, + ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = true; +} + +inline void Response::set_file_content(const std::string &path, const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { file_content_path_ = path; } + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, const char *def, size_t id) const { + return detail::get_header_value(request_headers_, key, def, id); +} + +inline size_t Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { return write(ptr, strlen(ptr)); } + +inline ssize_t Stream::write(const std::string &s) { return write(s.data(), s.size()); } + +namespace detail { + +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec) + : sock_(sock), + read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + read_buff_(read_buff_size_, 0) {} + +inline SocketStream::~SocketStream() = default; + +inline bool SocketStream::is_readable() const { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } + +inline bool SocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN32 + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!is_readable()) { + return -1; + } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!is_writable()) { + return -1; + } + +#if defined(_WIN32) && !defined(_WIN64) + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { return sock_; } + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, int & /*port*/) const {} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, int & /*port*/) const {} + +inline socket_t BufferStream::socket() const { return 0; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + static constexpr char marker[] = "/:"; + + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find(marker, last_param_end == 0 ? last_param_end : last_param_end - 1); + if (marker_pos == std::string::npos) { + break; + } + + static_fragments_.push_back(pattern.substr(last_param_end, marker_pos - last_param_end + 1)); + + const auto param_name_start = marker_pos + 2; + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { + sep_pos = pattern.length(); + } + + auto param_name = pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = + "Encountered path parameter '" + param_name + "' multiple times in route pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = std::smatch(); + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { + continue; + } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { + sep_pos = request.path.length(); + } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace(param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everything up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() : new_task_queue([] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() = default; + +inline std::unique_ptr Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, HandlerWithContentReader handler) { + post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, HandlerWithContentReader handler) { + put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers) { + detail::FileStat stat(dir); + if (stat.is_dir()) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server &Server::set_file_extension_and_mimetype_mapping(const std::string &ext, const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_default_file_mimetype(const std::string &mime) { + default_file_mimetype_ = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, std::true_type) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(Handler handler, std::false_type) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server &Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_header_writer(std::function const &writer) { + header_writer_ = writer; + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { + is_decommisioned = true; + } + return ret >= 0; +} +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { + is_decommisioned = true; + } + return ret; +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const std::string &host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::wait_until_ready() const { + while (!is_running_ && !is_decommisioned) { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + is_decommisioned = false; +} + +inline void Server::decommission() { is_decommisioned = true; } + +inline bool Server::parse_request_line(const char *s, Request &req) const { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { + return false; + } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: + req.method = std::string(b, e); + break; + case 1: + req.target = std::string(b, e); + break; + case 2: + req.version = std::string(b, e); + break; + default: + break; + } + count++; + }); + + if (count != 3) { + return false; + } + } + + static const std::set methods{"GET", "HEAD", "POST", "PUT", "DELETE", + "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { + return false; + } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { + return false; + } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_url(std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, Request &req, Response &res) { + // NOTE: `req.ranges` should be empty, otherwise it will be applied + // incorrectly to the error content. + req.ranges.clear(); + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, bool close_connection, const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, bool close_connection, const Request &req, Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { + apply_ranges(req, res, content_type, boundary); + } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); + } + + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && !res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && !res.has_header("Content-Length")) { + res.set_header("Content-Length", "0"); + } + + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { + post_routing_handler_(req, res); + } + + // Response line and headers + { + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { + return false; + } + if (!header_writer_(bstrm, res.headers)) { + return false; + } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } + } + } + + // Log + if (logger_) { + logger_(req, res); + } + + return ret; +} + +inline bool Server::write_content_with_provider(Stream &strm, const Request &req, Response &res, + const std::string &boundary, const std::string &content_type) { + auto is_shutting_down = [this]() { return this->svr_sock_ == INVALID_SOCKET; }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length(req.ranges[0], res.content_length_); + + return detail::write_content(strm, res.content_provider_, offset_and_length.first, offset_and_length.second, + is_shutting_down); + } else { + return detail::write_multipart_ranges_data(strm, req, res, boundary, content_type, res.content_length_, + is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto file_count = 0; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + if (file_count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { + return false; + } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, std::move(receiver), std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = StatusCode::BadRequest_400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = (std::min)(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, multipart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, multipart_header); + }; + } else { + out = [receiver](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = StatusCode::BadRequest_400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res, bool head) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { + path += "index.html"; + } + + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { + for (const auto &kv : entry.headers) { + res.set_header(kv.first, kv.second); + } + + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + return false; + } + + res.set_content_provider( + mm->size(), detail::find_content_type(path, file_extension_and_mimetype_map_, default_file_mimetype_), + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + + return true; + } + } + } + } + return false; +} + +inline socket_t Server::create_server_socket(const std::string &host, int port, int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket(host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, ipv6_v6only_, + std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { + return false; + } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { + if (is_decommisioned) { + return -1; + } + + if (!is_valid()) { + return -1; + } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { + return -1; + } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + if (is_decommisioned) { + return false; + } + + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN32 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN32 + } +#endif + +#if defined _WIN32 + // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else + socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec_ * 1000 + read_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec_); + tv.tv_usec = static_cast(read_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec_ * 1000 + write_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec_); + tv.tv_usec = static_cast(write_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + if (!task_queue->enqueue([this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + task_queue->shutdown(); + } + + is_decommisioned = !ret; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + auto is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, std::move(header), std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { + return false; + } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = StatusCode::BadRequest_400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res); + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, std::string &content_type, + std::string &boundary) const { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + boundary = detail::make_multipart_data_boundary(); + + res.set_header("Content-Type", "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length(req.ranges[0], res.content_length_); + + length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field(offset_and_length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, boundary, content_type, res.content_length_); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } + } + } else { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + ; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length(req.ranges[0], res.body.size()); + auto offset = offset_and_length.first; + auto length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field(offset_and_length, res.body.size()); + res.set_header("Content-Range", content_range); + + assert(offset + length <= res.body.size()); + res.body = res.body.substr(offset, length); + } else { + std::string data; + detail::make_multipart_ranges_data(req, res, boundary, content_type, res.body.size(), data); + res.body.swap(data); + } + + if (type != detail::EncodingType::None) { + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), res.body.size(), true, [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader(Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool Server::process_request(Stream &strm, const std::string &remote_addr, int remote_port, + const std::string &local_addr, int local_port, bool close_connection, + bool &connection_closed, const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { + return false; + } + + Request req; + + Response res; + res.version = "HTTP/1.1"; + res.headers = default_headers_; + +#ifdef _WIN32 + // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). +#else +#ifndef CPPHTTPLIB_USE_POLL + // Socket file descriptor exceeded FD_SETSIZE... + if (strm.socket() >= FD_SETSIZE) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::InternalServerError_500; + return write_response(strm, close_connection, req, res); + } +#endif +#endif + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + req.local_addr = local_addr; + req.local_port = local_port; + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { + setup_request(req); + } + + if (req.get_header_value("Expect") == "100-continue") { + int status = StatusCode::Continue_100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case StatusCode::Continue_100: + case StatusCode::ExpectationFailed_417: + detail::write_response_line(strm, status); + strm.write("\r\n"); + break; + default: + connection_closed = true; + return write_response(strm, true, req, res); + } + } + + // Routing + auto routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': + val += "\\r"; + break; + case '\n': + val += "\\n"; + break; + default: + val += s[i]; + break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? StatusCode::OK_200 : StatusCode::PartialContent_206; + } + + if (detail::range_error(req, res)) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type(path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider(mm->size(), content_type, [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { + res.status = StatusCode::NotFound_404; + } + + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + auto ret = detail::process_server_socket( + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, local_port, close_connection, + connection_closed, nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) : ClientImpl(host, 80, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path) + : host_(detail::escape_abstract_namespace_unix_domain(host)), + port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), + client_key_path_(client_key_path) {} + +inline ClientImpl::~ClientImpl() { + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { return true; } + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + url_encode_ = rhs.url_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket(proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { + ip = it->second; + } + + return detail::create_client_socket(host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, socket_options_, + connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_, error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { + return false; + } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) const { + if (socket.sock == INVALID_SOCKET) { + return; + } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { + return; + } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, Response &res) const { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { + return false; + } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#else + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == StatusCode::Continue_100) { + if (!line_reader.getline()) { + return false; + } // CRLF + if (!line_reader.getline()) { + return false; + } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool ClientImpl::is_ssl_peer_could_be_closed(SSL *ssl) const { + char buf[1]; + return !SSL_peek(ssl, buf, 1) && SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} +#endif + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { + is_alive = false; + } + } +#endif + + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { + return false; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, res, success, error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { + return false; + } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, [&](Stream &strm) { return handle_request(strm, req, res, close_connection, error); }); + + if (!ret) { + if (error == Error::Success) { + error = Error::Unknown; + } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { + return false; + } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. + + // This is safe to call because handle_request is only called by send_ + // which locks the request mutex during the process. It would be a bug + // to call it from a different thread since it's a thread-safety issue + // to do these things to the socket if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == StatusCode::Unauthorized_401 || res.status == StatusCode::ProxyAuthenticationRequired_407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + const auto &username = is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { + res = new_res; + } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { + return false; + } + + const static std::regex re( + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { + return false; + } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { + next_host = m[3].str(); + } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { + next_scheme = scheme; + } + if (next_host.empty()) { + next_host = host_; + } + if (next_path.empty()) { + next_path = "/"; + } + + auto path = detail::decode_url(next_path, true) + next_query; + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host, next_port); + cli.copy_settings(*this); + if (ca_cert_store_) { + cli.set_ca_cert_store(ca_cert_store_); + } + return detail::redirect(cli, req, res, path, location, error); +#else + return false; +#endif + } else { + ClientImpl cli(next_host, next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, path, location, error); + } + } +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, const Request &req, Error &error) const { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, is_shutting_down, *compressor, error); + } else { + return detail::write_content(strm, req.content_provider_, 0, req.content_length_, is_shutting_down, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { + req.set_header("Accept", "*/*"); + } + + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { + accept_encoding += ", "; + } + accept_encoding += "gzip, deflate"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header(basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert( + make_basic_authentication_header(proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header(bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header(proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path_with_query = req.params.empty() ? req.path : append_query_params(req.path, req.params); + + const auto &path = url_encode_ ? detail::encode_url(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); + + header_writer_(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, const std::string &content_type, Error &error) { + if (!content_type.empty()) { + req.set_header("Content-Type", content_type); + } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + req.set_header("Content-Encoding", "gzip"); + } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = + compressor.compress(data, data_len, last, [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter(std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider(const std::string &method, const std::string &path, + const Headers &headers, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + req.progress = progress; + + auto error = Error::Success; + + auto res = send_with_content_provider(req, body, content_length, std::move(content_provider), + std::move(content_provider_without_length), content_type, error); + + return Result{std::move(res), error, std::move(req.headers)}; +} + +inline std::string ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { + return "[" + host + "]"; + } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, Request &req, Response &res, bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { + return false; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = + req.content_receiver + ? static_cast([&](const char *buf, size_t n, uint64_t off, uint64_t len) { + if (redirect) { + return true; + } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { + error = Error::Canceled; + } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + assert(res.body.size() + n <= res.body.max_size()); + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress || redirect) { + return true; + } + auto ret = req.progress(current, total); + if (!ret) { + error = Error::Canceled; + } + return ret; + }; + + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(len); + } + } + + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { + error = Error::Read; + } + return false; + } + } + + // Log + if (logger_) { + logger_(req, res); + } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const { + size_t cur_item = 0; + size_t cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, DataSink &sink) mutable -> bool { + if (!offset && !items.empty()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin(provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + auto has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool ClientImpl::process_socket(const Socket &socket, std::function callback) { + return detail::process_client_socket(socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { return false; } + +inline Result ClientImpl::Get(const std::string &path) { return Get(path, Headers(), Progress()); } + +inline Result ClientImpl::Get(const std::string &path, Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, ContentReceiver content_receiver, Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, Headers(), std::move(response_handler), std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return Get(path, Headers(), std::move(response_handler), std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = [content_receiver](const char *data, size_t data_length, uint64_t /*offset*/, + uint64_t /*total_length*/) { return content_receiver(data, data_length); }; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, const Headers &headers, + Progress progress) { + if (params.empty()) { + return Get(path, headers); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return Get(path, params, headers, nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { + if (params.empty()) { + return Get(path, headers, std::move(response_handler), std::move(content_receiver), std::move(progress)); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Head(const std::string &path) { return Head(path, Headers()); } + +inline Result ClientImpl::Head(const std::string &path, const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { return Post(path, std::string(), std::string()); } + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Post(path, Headers(), body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body, content_length, nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, const std::string &content_type) { + return Post(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Post(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return send_with_content_provider("POST", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { return Post(path, Headers(), params); } + +inline Result ClientImpl::Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return Post(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Post(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms, + Progress progress) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded", progress); +} + +inline Result ClientImpl::Post(const std::string &path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path) { return Put(path, std::string(), std::string()); } + +inline Result ClientImpl::Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Put(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body, content_length, nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, const std::string &content_type) { + return Put(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Put(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return Put(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Put(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, content_length, std::move(content_provider), nullptr, + content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { return Put(path, Headers(), params); } + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms, + Progress progress) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded", progress); +} + +inline Result ClientImpl::Put(const std::string &path, const MultipartFormDataItems &items) { + return Put(path, Headers(), items); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), content_type, + nullptr); +} +inline Result ClientImpl::Patch(const std::string &path) { return Patch(path, std::string(), std::string()); } + +inline Result ClientImpl::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Patch(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type) { + return Patch(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, Progress progress) { + return send_with_content_provider("PATCH", path, headers, body, content_length, nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const std::string &body, const std::string &content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return Patch(path, headers, body, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return send_with_content_provider("PATCH", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return Patch(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Patch(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path) { + return Delete(path, Headers(), std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers) { + return Delete(path, headers, std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Delete(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type) { + return Delete(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, Progress progress) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + req.progress = progress; + + if (!content_type.empty()) { + req.set_header("Content-Type", content_type); + } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, const std::string &body, const std::string &content_type) { + return Delete(path, Headers(), body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return Delete(path, headers, body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, progress); +} + +inline Result ClientImpl::Options(const std::string &path) { return Options(path, Headers()); } + +inline Result ClientImpl::Options(const std::string &path, const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { return host_; } + +inline int ClientImpl::port() const { return port_; } + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { return socket_.sock; } + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { bearer_token_auth_token_ = token; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } + +inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } + +inline void ClientImpl::set_url_encode(bool on) { url_encode_ = on; } + +inline void ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { default_headers_ = std::move(headers); } + +inline void ClientImpl::set_header_writer(std::function const &writer) { + header_writer_ = writer; +} + +inline void ClientImpl::set_address_family(int family) { address_family_ = family; } + +inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { compress_ = on; } + +inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } + +inline void ClientImpl::set_interface(const std::string &intf) { interface_ = intf; } + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, std::size_t size) const { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); + if (!mem) { + return nullptr; + } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { + return nullptr; + } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { + continue; + } + + if (itmp->x509) { + X509_STORE_add_cert(cts, itmp->x509); + } + if (itmp->crl) { + X509_STORE_add_crl(cts, itmp->crl); + } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { server_hostname_verification_ = enabled; } + +inline void ClientImpl::set_server_certificate_verifier(std::function verifier) { + server_certificate_verifier_ = verifier; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { logger_ = std::move(logger); } + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { +#ifdef _WIN32 + SSL_shutdown(ssl); +#else + timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); + + auto ret = SSL_shutdown(ssl); + while (ret == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + ret = SSL_shutdown(ssl); + } +#endif + } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, U ssl_connect_or_accept, time_t timeout_sec, + time_t timeout_usec) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: + break; + } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl(const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core(svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +class SSLInit { + public: + SSLInit() { OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); } +}; + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec) + : sock_(sock), + ssl_(ssl), + read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() = default; + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_READ || (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { + return ret; + } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { + auto handle_size = static_cast(std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_WRITE || (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (is_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { + return ret; + } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { return sock_; } + +static SSLInit sslinit_; + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata(ctx_, reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != 1 || + SSL_CTX_check_private_key(ctx_) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, client_ca_cert_dir_path); + + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { + SSL_CTX_free(ctx_); + } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking(sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, local_port, close_connection, + connection_closed, [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) : SSLClient(host, 443, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port) : SSLClient(host, port, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path, const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { host_components_.emplace_back(b, e); }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast(const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { host_components_.emplace_back(b, e); }); + + if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast(const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { + SSL_CTX_free(ctx_); + } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { return verify_result_; } + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket(socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + proxy_res = Response(); + if (!detail::process_client_socket(socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), proxy_digest_auth_username_, + proxy_digest_auth_password_, true)); + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +inline bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, ca_cert_dir_path_.c_str())) { + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN32 + loaded = detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // TARGET_OS_OSX +#endif // _WIN32 + if (!loaded) { + SSL_CTX_set_default_verify_paths(ctx_); + } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking(socket.sock, ssl2, SSL_connect, connection_timeout_sec_, + connection_timeout_usec_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + if (server_certificate_verifier_) { + if (!server_certificate_verifier_(ssl2)) { + error = Error::SSLServerVerification; + return false; + } + } else { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + error = Error::SSLServerHostnameVerification; + return false; + } + } + } + } + + return true; + }, + [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else + // NOTE: Direct call instead of using the OpenSSL macro to suppress + // -Wold-style-cast warning + SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(host_.c_str()))); +#endif + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool SSLClient::process_socket(const Socket &socket, std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl(socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || verify_host_with_common_name(server_cert); +} + +inline bool SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 {}; + struct in_addr addr {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: + dsn_matched = check_host_name(name, name_len); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { + ret = true; + } + } + + GENERAL_NAMES_free(const_cast(reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { + return true; + } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { pattern_components.emplace_back(b, e); }); + + if (host_components_.size() != pattern_components.size()) { + return false; + } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && !p.compare(0, p.size() - 1, h)); + if (!partial_match) { + return false; + } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) : Client(scheme_host_port, std::string(), std::string()) {} + +inline Client::Client(const std::string &scheme_host_port, const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re(R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { + host = m[3].str(); + } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); + } + } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. + cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); + } +} // namespace detail + +inline Client::Client(const std::string &host, int port) : cli_(detail::make_unique(host, port)) {} + +inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, client_key_path)) {} + +inline Client::~Client() = default; + +inline bool Client::is_valid() const { return cli_ != nullptr && cli_->is_valid(); } + +inline Result Client::Get(const std::string &path) { return cli_->Get(path); } +inline Result Client::Get(const std::string &path, const Headers &headers) { return cli_->Get(path, headers); } +inline Result Client::Get(const std::string &path, Progress progress) { return cli_->Get(path, std::move(progress)); } +inline Result Client::Get(const std::string &path, const Headers &headers, Progress progress) { + return cli_->Get(path, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, ContentReceiver content_receiver) { + return cli_->Get(path, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, headers, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver) { + return cli_->Get(path, std::move(response_handler), std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(response_handler), std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress) { + return cli_->Get(path, params, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(response_handler), std::move(content_receiver), + std::move(progress)); +} + +inline Result Client::Head(const std::string &path) { return cli_->Head(path); } +inline Result Client::Head(const std::string &path, const Headers &headers) { return cli_->Head(path, headers); } + +inline Result Client::Post(const std::string &path) { return cli_->Post(path); } +inline Result Client::Post(const std::string &path, const Headers &headers) { return cli_->Post(path, headers); } +inline Result Client::Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Post(path, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Post(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Post(path, body, content_type); +} +inline Result Client::Post(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Post(path, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, content_length, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return cli_->Post(path, headers, content_length, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return cli_->Post(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Params ¶ms) { return cli_->Post(path, params); } +inline Result Client::Post(const std::string &path, const Headers &headers, const Params ¶ms) { + return cli_->Post(path, headers, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress) { + return cli_->Post(path, headers, params, progress); +} +inline Result Client::Post(const std::string &path, const MultipartFormDataItems &items) { + return cli_->Post(path, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + return cli_->Post(path, headers, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Post(path, headers, items, boundary); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Post(path, headers, items, provider_items); +} +inline Result Client::Put(const std::string &path) { return cli_->Put(path); } +inline Result Client::Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Put(path, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Put(path, body, content_type); +} +inline Result Client::Put(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Put(path, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, content_length, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return cli_->Put(path, headers, content_length, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return cli_->Put(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Params ¶ms) { return cli_->Put(path, params); } +inline Result Client::Put(const std::string &path, const Headers &headers, const Params ¶ms) { + return cli_->Put(path, headers, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress) { + return cli_->Put(path, headers, params, progress); +} +inline Result Client::Put(const std::string &path, const MultipartFormDataItems &items) { + return cli_->Put(path, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + return cli_->Put(path, headers, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Put(path, headers, items, boundary); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Put(path, headers, items, provider_items); +} +inline Result Client::Patch(const std::string &path) { return cli_->Patch(path); } +inline Result Client::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Patch(path, body, content_type); +} +inline Result Client::Patch(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, content_length, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return cli_->Patch(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Delete(const std::string &path) { return cli_->Delete(path); } +inline Result Client::Delete(const std::string &path, const Headers &headers) { return cli_->Delete(path, headers); } +inline Result Client::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Delete(path, body, content_type); +} +inline Result Client::Delete(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} +inline Result Client::Options(const std::string &path) { return cli_->Options(path); } +inline Result Client::Options(const std::string &path, const Headers &headers) { return cli_->Options(path, headers); } + +inline bool Client::send(Request &req, Response &res, Error &error) { return cli_->send(req, res, error); } + +inline Result Client::send(const Request &req) { return cli_->send(req); } + +inline void Client::stop() { cli_->stop(); } + +inline std::string Client::host() const { return cli_->host(); } + +inline int Client::port() const { return cli_->port(); } + +inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } + +inline socket_t Client::socket() const { return cli_->socket(); } + +inline void Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { cli_->set_default_headers(std::move(headers)); } + +inline void Client::set_header_writer(std::function const &writer) { + cli_->set_header_writer(writer); +} + +inline void Client::set_address_family(int family) { cli_->set_address_family(family); } + +inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { cli_->set_connection_timeout(sec, usec); } + +inline void Client::set_read_timeout(time_t sec, time_t usec) { cli_->set_read_timeout(sec, usec); } + +inline void Client::set_write_timeout(time_t sec, time_t usec) { cli_->set_write_timeout(sec, usec); } + +inline void Client::set_basic_auth(const std::string &username, const std::string &password) { + cli_->set_basic_auth(username, password); +} +inline void Client::set_bearer_token_auth(const std::string &token) { cli_->set_bearer_token_auth(token); } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_follow_location(bool on) { cli_->set_follow_location(on); } + +inline void Client::set_url_encode(bool on) { cli_->set_url_encode(on); } + +inline void Client::set_compress(bool on) { cli_->set_compress(on); } + +inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } + +inline void Client::set_interface(const std::string &intf) { cli_->set_interface(intf); } + +inline void Client::set_proxy(const std::string &host, int port) { cli_->set_proxy(host, port); } +inline void Client::set_proxy_basic_auth(const std::string &username, const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { cli_->set_proxy_bearer_token_auth(token); } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier(std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} +#endif + +inline void Client::set_logger(Logger logger) { cli_->set_logger(std::move(logger)); } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { + return static_cast(*cli_).ssl_context(); + } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) +#undef poll +#endif + +#endif // CPPHTTPLIB_HTTPLIB_H + +#endif diff --git a/script/ci-custom.py b/script/ci-custom.py index d5d3ab88c8..6a5fb32180 100755 --- a/script/ci-custom.py +++ b/script/ci-custom.py @@ -292,6 +292,7 @@ def highlight(s): "esphome/core/log.h", "esphome/components/socket/headers.h", "esphome/core/defines.h", + "esphome/components/http_request/httplib.h", ], ) def lint_no_defines(fname, match): @@ -552,6 +553,7 @@ def lint_relative_py_import(fname): "esphome/components/rp2040/core.cpp", "esphome/components/libretiny/core.cpp", "esphome/components/host/core.cpp", + "esphome/components/http_request/httplib.h", ], ) def lint_namespace(fname, content): diff --git a/tests/components/http_request/common.yaml b/tests/components/http_request/common.yaml index 4a9b8a0e62..af4852901f 100644 --- a/tests/components/http_request/common.yaml +++ b/tests/components/http_request/common.yaml @@ -1,5 +1,4 @@ -substitutions: - verify_ssl: "true" +<<: !include http_request.yaml wifi: ssid: MySSID diff --git a/tests/components/http_request/http_request.yaml b/tests/components/http_request/http_request.yaml new file mode 100644 index 0000000000..ea7f6bf5a7 --- /dev/null +++ b/tests/components/http_request/http_request.yaml @@ -0,0 +1,46 @@ +substitutions: + verify_ssl: "true" + +network: + +esphome: + on_boot: + then: + - http_request.get: + url: https://esphome.io + request_headers: + Content-Type: application/json + on_error: + logger.log: "Request failed" + on_response: + then: + - logger.log: + format: "Response status: %d, Duration: %lu ms" + args: + - response->status_code + - (long) response->duration_ms + - http_request.post: + url: https://esphome.io + request_headers: + Content-Type: application/json + json: + key: value + - http_request.send: + method: PUT + url: https://esphome.io + request_headers: + Content-Type: application/json + body: "Some data" + +http_request: + useragent: esphome/tagreader + timeout: 10s + verify_ssl: ${verify_ssl} + +script: + - id: does_not_compile + parameters: + api_url: string + then: + - http_request.get: + url: "http://google.com" diff --git a/tests/components/http_request/test.host.yaml b/tests/components/http_request/test.host.yaml new file mode 100644 index 0000000000..e91445fb2d --- /dev/null +++ b/tests/components/http_request/test.host.yaml @@ -0,0 +1,7 @@ +substitutions: + verify_ssl: "true" +http_request: + # Just a file we can be sure exists + ca_certificate_path: /etc/passwd + +<<: !include http_request.yaml From 1da0dff8b1ce02db7fcb5f9396eb9699527ea01d Mon Sep 17 00:00:00 2001 From: Lucas Hartmann Date: Sun, 27 Apr 2025 23:18:47 -0300 Subject: [PATCH 024/102] Take advantage of clipping to speed image drawing. (#8630) --- esphome/components/image/image.cpp | 33 ++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/esphome/components/image/image.cpp b/esphome/components/image/image.cpp index f05f4af711..82e46e3460 100644 --- a/esphome/components/image/image.cpp +++ b/esphome/components/image/image.cpp @@ -6,10 +6,27 @@ namespace esphome { namespace image { void Image::draw(int x, int y, display::Display *display, Color color_on, Color color_off) { + int img_x0 = 0; + int img_y0 = 0; + int w = width_; + int h = height_; + + auto clipping = display->get_clipping(); + if (clipping.is_set()) { + if (clipping.x > x) + img_x0 += clipping.x - x; + if (clipping.y > y) + img_y0 += clipping.y - y; + if (w > clipping.x2() - x) + w = clipping.x2() - x; + if (h > clipping.y2() - y) + h = clipping.y2() - y; + } + switch (type_) { case IMAGE_TYPE_BINARY: { - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { if (this->get_binary_pixel_(img_x, img_y)) { display->draw_pixel_at(x + img_x, y + img_y, color_on); } else if (!this->transparency_) { @@ -20,8 +37,8 @@ void Image::draw(int x, int y, display::Display *display, Color color_on, Color break; } case IMAGE_TYPE_GRAYSCALE: - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { const uint32_t pos = (img_x + img_y * this->width_); const uint8_t gray = progmem_read_byte(this->data_start_ + pos); Color color = Color(gray, gray, gray, 0xFF); @@ -47,8 +64,8 @@ void Image::draw(int x, int y, display::Display *display, Color color_on, Color } break; case IMAGE_TYPE_RGB565: - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { auto color = this->get_rgb565_pixel_(img_x, img_y); if (color.w >= 0x80) { display->draw_pixel_at(x + img_x, y + img_y, color); @@ -57,8 +74,8 @@ void Image::draw(int x, int y, display::Display *display, Color color_on, Color } break; case IMAGE_TYPE_RGB: - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { auto color = this->get_rgb_pixel_(img_x, img_y); if (color.w >= 0x80) { display->draw_pixel_at(x + img_x, y + img_y, color); From fdc4ec8a5784ebc39c7b9b029b28456aef1da8b2 Mon Sep 17 00:00:00 2001 From: Ben Winslow Date: Sun, 27 Apr 2025 22:29:47 -0400 Subject: [PATCH 025/102] [touchscreen] Clear interrupt flag before reading touch data. (#8632) --- esphome/components/touchscreen/touchscreen.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/esphome/components/touchscreen/touchscreen.cpp b/esphome/components/touchscreen/touchscreen.cpp index 11207908fa..dcf3209752 100644 --- a/esphome/components/touchscreen/touchscreen.cpp +++ b/esphome/components/touchscreen/touchscreen.cpp @@ -50,13 +50,15 @@ void Touchscreen::loop() { tp.second.x_prev = tp.second.x; tp.second.y_prev = tp.second.y; } + // The interrupt flag must be reset BEFORE calling update_touches, otherwise we might miss an interrupt that was + // triggered while we were reading touch data. + this->store_.touched = false; this->update_touches(); if (this->skip_update_) { for (auto &tp : this->touches_) { tp.second.state &= ~STATE_RELEASING; } } else { - this->store_.touched = false; this->defer([this]() { this->send_touches_(); }); if (this->touch_timeout_ > 0) { // Simulate a touch after touch_timeout_> ms. This will reset any existing timeout operation. From 253e3ec6f63b16ca3ee4030e77a0236557360efe Mon Sep 17 00:00:00 2001 From: Nate Clark Date: Mon, 28 Apr 2025 00:27:39 -0400 Subject: [PATCH 026/102] [mdns] Support templatable config options for MDNS extra services (#8606) --- esphome/components/mdns/__init__.py | 16 +++++++++++----- esphome/components/mdns/mdns_component.cpp | 6 ++++-- esphome/components/mdns/mdns_component.h | 5 +++-- esphome/components/mdns/mdns_esp32.cpp | 7 ++++--- esphome/components/mdns/mdns_esp8266.cpp | 6 ++++-- esphome/components/mdns/mdns_libretiny.cpp | 6 ++++-- esphome/components/mdns/mdns_rp2040.cpp | 6 ++++-- esphome/core/automation.h | 7 ++++--- tests/components/mdns/common-enabled.yaml | 7 +++++++ 9 files changed, 45 insertions(+), 21 deletions(-) diff --git a/esphome/components/mdns/__init__.py b/esphome/components/mdns/__init__.py index e8902d5222..4b5e40dfea 100644 --- a/esphome/components/mdns/__init__.py +++ b/esphome/components/mdns/__init__.py @@ -35,8 +35,8 @@ SERVICE_SCHEMA = cv.Schema( { cv.Required(CONF_SERVICE): cv.string, cv.Required(CONF_PROTOCOL): cv.string, - cv.Optional(CONF_PORT, default=0): cv.Any(0, cv.port), - cv.Optional(CONF_TXT, default={}): {cv.string: cv.string}, + cv.Optional(CONF_PORT, default=0): cv.templatable(cv.Any(0, cv.port)), + cv.Optional(CONF_TXT, default={}): {cv.string: cv.templatable(cv.string)}, } ) @@ -102,12 +102,18 @@ async def to_code(config): for service in config[CONF_SERVICES]: txt = [ - mdns_txt_record(txt_key, txt_value) + cg.StructInitializer( + MDNSTXTRecord, + ("key", txt_key), + ("value", await cg.templatable(txt_value, [], cg.std_string)), + ) for txt_key, txt_value in service[CONF_TXT].items() ] - exp = mdns_service( - service[CONF_SERVICE], service[CONF_PROTOCOL], service[CONF_PORT], txt + service[CONF_SERVICE], + service[CONF_PROTOCOL], + await cg.templatable(service[CONF_PORT], [], cg.uint16), + txt, ) cg.add(var.add_extra_service(exp)) diff --git a/esphome/components/mdns/mdns_component.cpp b/esphome/components/mdns/mdns_component.cpp index 7f4b749456..ffc668e218 100644 --- a/esphome/components/mdns/mdns_component.cpp +++ b/esphome/components/mdns/mdns_component.cpp @@ -121,9 +121,11 @@ void MDNSComponent::dump_config() { ESP_LOGCONFIG(TAG, " Hostname: %s", this->hostname_.c_str()); ESP_LOGV(TAG, " Services:"); for (const auto &service : this->services_) { - ESP_LOGV(TAG, " - %s, %s, %d", service.service_type.c_str(), service.proto.c_str(), service.port); + ESP_LOGV(TAG, " - %s, %s, %d", service.service_type.c_str(), service.proto.c_str(), + const_cast &>(service.port).value()); for (const auto &record : service.txt_records) { - ESP_LOGV(TAG, " TXT: %s = %s", record.key.c_str(), record.value.c_str()); + ESP_LOGV(TAG, " TXT: %s = %s", record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/components/mdns/mdns_component.h b/esphome/components/mdns/mdns_component.h index dfb5b72292..9eb2ba11d0 100644 --- a/esphome/components/mdns/mdns_component.h +++ b/esphome/components/mdns/mdns_component.h @@ -3,6 +3,7 @@ #ifdef USE_MDNS #include #include +#include "esphome/core/automation.h" #include "esphome/core/component.h" namespace esphome { @@ -10,7 +11,7 @@ namespace mdns { struct MDNSTXTRecord { std::string key; - std::string value; + TemplatableValue value; }; struct MDNSService { @@ -20,7 +21,7 @@ struct MDNSService { // second label indicating protocol _including_ underscore character prefix // as defined in RFC6763 Section 7, like "_tcp" or "_udp" std::string proto; - uint16_t port; + TemplatableValue port; std::vector txt_records; }; diff --git a/esphome/components/mdns/mdns_esp32.cpp b/esphome/components/mdns/mdns_esp32.cpp index 8006eb27f1..fed18d3630 100644 --- a/esphome/components/mdns/mdns_esp32.cpp +++ b/esphome/components/mdns/mdns_esp32.cpp @@ -31,11 +31,12 @@ void MDNSComponent::setup() { mdns_txt_item_t it{}; // dup strings to ensure the pointer is valid even after the record loop it.key = strdup(record.key.c_str()); - it.value = strdup(record.value.c_str()); + it.value = strdup(const_cast &>(record.value).value().c_str()); txt_records.push_back(it); } - err = mdns_service_add(nullptr, service.service_type.c_str(), service.proto.c_str(), service.port, - txt_records.data(), txt_records.size()); + uint16_t port = const_cast &>(service.port).value(); + err = mdns_service_add(nullptr, service.service_type.c_str(), service.proto.c_str(), port, txt_records.data(), + txt_records.size()); // free records for (const auto &it : txt_records) { diff --git a/esphome/components/mdns/mdns_esp8266.cpp b/esphome/components/mdns/mdns_esp8266.cpp index 7b6e7ec448..2c90d57021 100644 --- a/esphome/components/mdns/mdns_esp8266.cpp +++ b/esphome/components/mdns/mdns_esp8266.cpp @@ -29,9 +29,11 @@ void MDNSComponent::setup() { while (*service_type == '_') { service_type++; } - MDNS.addService(service_type, proto, service.port); + uint16_t port = const_cast &>(service.port).value(); + MDNS.addService(service_type, proto, port); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), record.value.c_str()); + MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/components/mdns/mdns_libretiny.cpp b/esphome/components/mdns/mdns_libretiny.cpp index c9a9a289dd..7a41ec9dce 100644 --- a/esphome/components/mdns/mdns_libretiny.cpp +++ b/esphome/components/mdns/mdns_libretiny.cpp @@ -29,9 +29,11 @@ void MDNSComponent::setup() { while (*service_type == '_') { service_type++; } - MDNS.addService(service_type, proto, service.port); + uint16_t port_ = const_cast &>(service.port).value(); + MDNS.addService(service_type, proto, port_); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), record.value.c_str()); + MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/components/mdns/mdns_rp2040.cpp b/esphome/components/mdns/mdns_rp2040.cpp index 89e668ee59..95894323f4 100644 --- a/esphome/components/mdns/mdns_rp2040.cpp +++ b/esphome/components/mdns/mdns_rp2040.cpp @@ -29,9 +29,11 @@ void MDNSComponent::setup() { while (*service_type == '_') { service_type++; } - MDNS.addService(service_type, proto, service.port); + uint16_t port = const_cast &>(service.port).value(); + MDNS.addService(service_type, proto, port); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), record.value.c_str()); + MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/core/automation.h b/esphome/core/automation.h index e77e453431..02c9d44f16 100644 --- a/esphome/core/automation.h +++ b/esphome/core/automation.h @@ -1,10 +1,11 @@ #pragma once -#include #include "esphome/core/component.h" -#include "esphome/core/helpers.h" #include "esphome/core/defines.h" +#include "esphome/core/helpers.h" #include "esphome/core/preferences.h" +#include +#include namespace esphome { @@ -27,7 +28,7 @@ template class TemplatableValue { TemplatableValue() : type_(NONE) {} template::value, int> = 0> - TemplatableValue(F value) : type_(VALUE), value_(value) {} + TemplatableValue(F value) : type_(VALUE), value_(std::move(value)) {} template::value, int> = 0> TemplatableValue(F f) : type_(LAMBDA), f_(f) {} diff --git a/tests/components/mdns/common-enabled.yaml b/tests/components/mdns/common-enabled.yaml index bc31e32783..8b3d81cf69 100644 --- a/tests/components/mdns/common-enabled.yaml +++ b/tests/components/mdns/common-enabled.yaml @@ -4,3 +4,10 @@ wifi: mdns: disabled: false + services: + - service: _test_service + protocol: _tcp + port: 8888 + txt: + static_string: Anything + templated_string: !lambda return "Something else"; From d2ee2d3b23e364326827fc21d2015985a7d2da38 Mon Sep 17 00:00:00 2001 From: baal86 <43853528+baal86@users.noreply.github.com> Date: Mon, 28 Apr 2025 07:21:24 +0200 Subject: [PATCH 027/102] Fix support for ESP32-H2 in deep_sleep (#8290) --- esphome/components/deep_sleep/deep_sleep_esp32.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/esphome/components/deep_sleep/deep_sleep_esp32.cpp b/esphome/components/deep_sleep/deep_sleep_esp32.cpp index d647140865..4582d695f6 100644 --- a/esphome/components/deep_sleep/deep_sleep_esp32.cpp +++ b/esphome/components/deep_sleep/deep_sleep_esp32.cpp @@ -31,9 +31,12 @@ void DeepSleepComponent::set_wakeup_pin_mode(WakeupPinMode wakeup_pin_mode) { #if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) void DeepSleepComponent::set_ext1_wakeup(Ext1Wakeup ext1_wakeup) { this->ext1_wakeup_ = ext1_wakeup; } +#if !defined(USE_ESP32_VARIANT_ESP32H2) void DeepSleepComponent::set_touch_wakeup(bool touch_wakeup) { this->touch_wakeup_ = touch_wakeup; } #endif +#endif + void DeepSleepComponent::set_run_duration(WakeupCauseToRunDuration wakeup_cause_to_run_duration) { wakeup_cause_to_run_duration_ = wakeup_cause_to_run_duration; } @@ -65,7 +68,7 @@ bool DeepSleepComponent::prepare_to_sleep_() { } void DeepSleepComponent::deep_sleep_() { -#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) +#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) if (this->sleep_duration_.has_value()) esp_sleep_enable_timer_wakeup(*this->sleep_duration_); if (this->wakeup_pin_ != nullptr) { @@ -84,6 +87,15 @@ void DeepSleepComponent::deep_sleep_() { esp_sleep_pd_config(ESP_PD_DOMAIN_RTC_PERIPH, ESP_PD_OPTION_ON); } #endif + +#if defined(USE_ESP32_VARIANT_ESP32H2) + if (this->sleep_duration_.has_value()) + esp_sleep_enable_timer_wakeup(*this->sleep_duration_); + if (this->ext1_wakeup_.has_value()) { + esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode); + } +#endif + #if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) if (this->sleep_duration_.has_value()) esp_sleep_enable_timer_wakeup(*this->sleep_duration_); From 3291a11824d03863cb4bc398e08da0ddf1204ebc Mon Sep 17 00:00:00 2001 From: Steffen Banhardt Date: Mon, 28 Apr 2025 21:18:46 +0200 Subject: [PATCH 028/102] =?UTF-8?q?Update=20ens160=5Fbase.cpp=20=E2=80=93?= =?UTF-8?q?=20fix=20wrong=20double=20negative=20(#8639)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- esphome/components/ens160_base/ens160_base.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esphome/components/ens160_base/ens160_base.cpp b/esphome/components/ens160_base/ens160_base.cpp index 71082c58c2..852328d4bb 100644 --- a/esphome/components/ens160_base/ens160_base.cpp +++ b/esphome/components/ens160_base/ens160_base.cpp @@ -187,7 +187,7 @@ void ENS160Component::update() { } return; case INVALID_OUTPUT: - ESP_LOGE(TAG, "ENS160 Invalid Status - No Invalid Output"); + ESP_LOGE(TAG, "ENS160 Invalid Status - No valid output"); this->status_set_warning(); return; } From 629481a5261906a3905c6da746efd7d80500188b Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:46:39 +1200 Subject: [PATCH 029/102] [esp32_ble] Remove explicit and now incorrect ble override for esp32-c6 (#8643) --- esphome/components/esp32_ble/ble.cpp | 8 --- esphome/components/esp32_ble/const_esp32c6.h | 74 -------------------- 2 files changed, 82 deletions(-) delete mode 100644 esphome/components/esp32_ble/const_esp32c6.h diff --git a/esphome/components/esp32_ble/ble.cpp b/esphome/components/esp32_ble/ble.cpp index b10e454c21..ab2647b738 100644 --- a/esphome/components/esp32_ble/ble.cpp +++ b/esphome/components/esp32_ble/ble.cpp @@ -2,10 +2,6 @@ #include "ble.h" -#ifdef USE_ESP32_VARIANT_ESP32C6 -#include "const_esp32c6.h" -#endif // USE_ESP32_VARIANT_ESP32C6 - #include "esphome/core/application.h" #include "esphome/core/log.h" @@ -127,11 +123,7 @@ bool ESP32BLE::ble_setup_() { if (esp_bt_controller_get_status() != ESP_BT_CONTROLLER_STATUS_ENABLED) { // start bt controller if (esp_bt_controller_get_status() == ESP_BT_CONTROLLER_STATUS_IDLE) { -#ifdef USE_ESP32_VARIANT_ESP32C6 - esp_bt_controller_config_t cfg = BT_CONTROLLER_CONFIG; -#else esp_bt_controller_config_t cfg = BT_CONTROLLER_INIT_CONFIG_DEFAULT(); -#endif err = esp_bt_controller_init(&cfg); if (err != ESP_OK) { ESP_LOGE(TAG, "esp_bt_controller_init failed: %s", esp_err_to_name(err)); diff --git a/esphome/components/esp32_ble/const_esp32c6.h b/esphome/components/esp32_ble/const_esp32c6.h deleted file mode 100644 index 89179d8dd9..0000000000 --- a/esphome/components/esp32_ble/const_esp32c6.h +++ /dev/null @@ -1,74 +0,0 @@ -#pragma once - -#ifdef USE_ESP32_VARIANT_ESP32C6 - -#include - -namespace esphome { -namespace esp32_ble { - -static const esp_bt_controller_config_t BT_CONTROLLER_CONFIG = { - .config_version = CONFIG_VERSION, - .ble_ll_resolv_list_size = CONFIG_BT_LE_LL_RESOLV_LIST_SIZE, - .ble_hci_evt_hi_buf_count = DEFAULT_BT_LE_HCI_EVT_HI_BUF_COUNT, - .ble_hci_evt_lo_buf_count = DEFAULT_BT_LE_HCI_EVT_LO_BUF_COUNT, - .ble_ll_sync_list_cnt = DEFAULT_BT_LE_MAX_PERIODIC_ADVERTISER_LIST, - .ble_ll_sync_cnt = DEFAULT_BT_LE_MAX_PERIODIC_SYNCS, - .ble_ll_rsp_dup_list_count = CONFIG_BT_LE_LL_DUP_SCAN_LIST_COUNT, - .ble_ll_adv_dup_list_count = CONFIG_BT_LE_LL_DUP_SCAN_LIST_COUNT, - .ble_ll_tx_pwr_dbm = BLE_LL_TX_PWR_DBM_N, - .rtc_freq = RTC_FREQ_N, - .ble_ll_sca = CONFIG_BT_LE_LL_SCA, - .ble_ll_scan_phy_number = BLE_LL_SCAN_PHY_NUMBER_N, - .ble_ll_conn_def_auth_pyld_tmo = BLE_LL_CONN_DEF_AUTH_PYLD_TMO_N, - .ble_ll_jitter_usecs = BLE_LL_JITTER_USECS_N, - .ble_ll_sched_max_adv_pdu_usecs = BLE_LL_SCHED_MAX_ADV_PDU_USECS_N, - .ble_ll_sched_direct_adv_max_usecs = BLE_LL_SCHED_DIRECT_ADV_MAX_USECS_N, - .ble_ll_sched_adv_max_usecs = BLE_LL_SCHED_ADV_MAX_USECS_N, - .ble_scan_rsp_data_max_len = DEFAULT_BT_LE_SCAN_RSP_DATA_MAX_LEN_N, - .ble_ll_cfg_num_hci_cmd_pkts = BLE_LL_CFG_NUM_HCI_CMD_PKTS_N, - .ble_ll_ctrl_proc_timeout_ms = BLE_LL_CTRL_PROC_TIMEOUT_MS_N, - .nimble_max_connections = DEFAULT_BT_LE_MAX_CONNECTIONS, - .ble_whitelist_size = DEFAULT_BT_NIMBLE_WHITELIST_SIZE, // NOLINT - .ble_acl_buf_size = DEFAULT_BT_LE_ACL_BUF_SIZE, - .ble_acl_buf_count = DEFAULT_BT_LE_ACL_BUF_COUNT, - .ble_hci_evt_buf_size = DEFAULT_BT_LE_HCI_EVT_BUF_SIZE, - .ble_multi_adv_instances = DEFAULT_BT_LE_MAX_EXT_ADV_INSTANCES, - .ble_ext_adv_max_size = DEFAULT_BT_LE_EXT_ADV_MAX_SIZE, - .controller_task_stack_size = NIMBLE_LL_STACK_SIZE, - .controller_task_prio = ESP_TASK_BT_CONTROLLER_PRIO, - .controller_run_cpu = 0, - .enable_qa_test = RUN_QA_TEST, - .enable_bqb_test = RUN_BQB_TEST, -#if ESP_IDF_VERSION < ESP_IDF_VERSION_VAL(5, 3, 1) - // The following fields have been removed since ESP IDF version 5.3.1, see commit: - // https://github.com/espressif/esp-idf/commit/e761c1de8f9c0777829d597b4d5a33bb070a30a8 - .enable_uart_hci = HCI_UART_EN, - .ble_hci_uart_port = DEFAULT_BT_LE_HCI_UART_PORT, - .ble_hci_uart_baud = DEFAULT_BT_LE_HCI_UART_BAUD, - .ble_hci_uart_data_bits = DEFAULT_BT_LE_HCI_UART_DATA_BITS, - .ble_hci_uart_stop_bits = DEFAULT_BT_LE_HCI_UART_STOP_BITS, - .ble_hci_uart_flow_ctrl = DEFAULT_BT_LE_HCI_UART_FLOW_CTRL, - .ble_hci_uart_uart_parity = DEFAULT_BT_LE_HCI_UART_PARITY, -#endif - .enable_tx_cca = DEFAULT_BT_LE_TX_CCA_ENABLED, - .cca_rssi_thresh = 256 - DEFAULT_BT_LE_CCA_RSSI_THRESH, - .sleep_en = NIMBLE_SLEEP_ENABLE, - .coex_phy_coded_tx_rx_time_limit = DEFAULT_BT_LE_COEX_PHY_CODED_TX_RX_TLIM_EFF, - .dis_scan_backoff = NIMBLE_DISABLE_SCAN_BACKOFF, - .ble_scan_classify_filter_enable = 1, - .main_xtal_freq = CONFIG_XTAL_FREQ, - .version_num = (uint8_t) efuse_hal_chip_revision(), - .cpu_freq_mhz = CONFIG_ESP_DEFAULT_CPU_FREQ_MHZ, - .ignore_wl_for_direct_adv = 0, - .enable_pcl = DEFAULT_BT_LE_POWER_CONTROL_ENABLED, -#if ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 1, 3) - .csa2_select = DEFAULT_BT_LE_50_FEATURE_SUPPORT, -#endif - .config_magic = CONFIG_MAGIC, -}; - -} // namespace esp32_ble -} // namespace esphome - -#endif // USE_ESP32_VARIANT_ESP32C6 From a31a5e74bdfa3ece19c502a91f81b0d93414bd67 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 29 Apr 2025 11:35:38 +1200 Subject: [PATCH 030/102] [const] Move CONF_GAIN_FACTOR to const.py (#8646) --- esphome/components/sen5x/sensor.py | 2 +- esphome/components/sgp4x/sensor.py | 2 +- esphome/const.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/esphome/components/sen5x/sensor.py b/esphome/components/sen5x/sensor.py index a8a796853e..f52de5fe85 100644 --- a/esphome/components/sen5x/sensor.py +++ b/esphome/components/sen5x/sensor.py @@ -4,6 +4,7 @@ import esphome.codegen as cg from esphome.components import i2c, sensirion_common, sensor import esphome.config_validation as cv from esphome.const import ( + CONF_GAIN_FACTOR, CONF_HUMIDITY, CONF_ID, CONF_OFFSET, @@ -43,7 +44,6 @@ RhtAccelerationMode = sen5x_ns.enum("RhtAccelerationMode") CONF_ACCELERATION_MODE = "acceleration_mode" CONF_ALGORITHM_TUNING = "algorithm_tuning" CONF_AUTO_CLEANING_INTERVAL = "auto_cleaning_interval" -CONF_GAIN_FACTOR = "gain_factor" CONF_GATING_MAX_DURATION_MINUTES = "gating_max_duration_minutes" CONF_INDEX_OFFSET = "index_offset" CONF_LEARNING_TIME_GAIN_HOURS = "learning_time_gain_hours" diff --git a/esphome/components/sgp4x/sensor.py b/esphome/components/sgp4x/sensor.py index 9317187df3..4f29248881 100644 --- a/esphome/components/sgp4x/sensor.py +++ b/esphome/components/sgp4x/sensor.py @@ -3,6 +3,7 @@ from esphome.components import i2c, sensirion_common, sensor import esphome.config_validation as cv from esphome.const import ( CONF_COMPENSATION, + CONF_GAIN_FACTOR, CONF_ID, CONF_STORE_BASELINE, CONF_TEMPERATURE_SOURCE, @@ -24,7 +25,6 @@ SGP4xComponent = sgp4x_ns.class_( ) CONF_ALGORITHM_TUNING = "algorithm_tuning" -CONF_GAIN_FACTOR = "gain_factor" CONF_GATING_MAX_DURATION_MINUTES = "gating_max_duration_minutes" CONF_HUMIDITY_SOURCE = "humidity_source" CONF_INDEX_OFFSET = "index_offset" diff --git a/esphome/const.py b/esphome/const.py index b2437eca7e..ffa5de2de3 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -333,6 +333,7 @@ CONF_FULL_SPECTRUM = "full_spectrum" CONF_FULL_SPECTRUM_COUNTS = "full_spectrum_counts" CONF_FULL_UPDATE_EVERY = "full_update_every" CONF_GAIN = "gain" +CONF_GAIN_FACTOR = "gain_factor" CONF_GAMMA_CORRECT = "gamma_correct" CONF_GAS_RESISTANCE = "gas_resistance" CONF_GATEWAY = "gateway" From b5bdfb30890a1d8dac9d97e77ed30da00c7a7242 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 29 Apr 2025 11:45:41 +1200 Subject: [PATCH 031/102] [http_request] Fix request headers (#8644) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- esphome/components/http_request/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/esphome/components/http_request/__init__.py b/esphome/components/http_request/__init__.py index 4da49ddde1..9aa0c42fa2 100644 --- a/esphome/components/http_request/__init__.py +++ b/esphome/components/http_request/__init__.py @@ -295,8 +295,8 @@ async def http_request_action_to_code(config, action_id, template_arg, args): for key in json_: template_ = await cg.templatable(json_[key], args, cg.std_string) cg.add(var.add_json(key, template_)) - for key in config.get(CONF_REQUEST_HEADERS, []): - template_ = await cg.templatable(key, args, cg.std_string) + for key, value in config.get(CONF_REQUEST_HEADERS, {}).items(): + template_ = await cg.templatable(value, args, cg.const_char_ptr) cg.add(var.add_request_header(key, template_)) for value in config.get(CONF_COLLECT_HEADERS, []): From 59b4a1f5541ac053a5968efabddb6bd7402a6368 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Thu, 17 Apr 2025 17:04:43 +1200 Subject: [PATCH 032/102] Fix psram below idf 5 (#8584) --- esphome/components/psram/psram.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/esphome/components/psram/psram.cpp b/esphome/components/psram/psram.cpp index f592ada246..162543545e 100644 --- a/esphome/components/psram/psram.cpp +++ b/esphome/components/psram/psram.cpp @@ -1,7 +1,8 @@ #ifdef USE_ESP32 #include "psram.h" -#ifdef USE_ESP_IDF +#include +#if defined(USE_ESP_IDF) && ESP_IDF_VERSION_MAJOR >= 5 #include #endif // USE_ESP_IDF @@ -15,7 +16,7 @@ static const char *const TAG = "psram"; void PsramComponent::dump_config() { ESP_LOGCONFIG(TAG, "PSRAM:"); -#ifdef USE_ESP_IDF +#if defined(USE_ESP_IDF) && ESP_IDF_VERSION_MAJOR >= 5 bool available = esp_psram_is_initialized(); ESP_LOGCONFIG(TAG, " Available: %s", YESNO(available)); From 86033b661208cfa22f1fd8bbc38241aa8aa4b523 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Tue, 22 Apr 2025 06:51:52 +1000 Subject: [PATCH 033/102] [lvgl] Ensure pages are created on the correct display (#8596) --- esphome/components/lvgl/lvgl_esphome.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/esphome/components/lvgl/lvgl_esphome.cpp b/esphome/components/lvgl/lvgl_esphome.cpp index 2560cd2168..2e5ba25851 100644 --- a/esphome/components/lvgl/lvgl_esphome.cpp +++ b/esphome/components/lvgl/lvgl_esphome.cpp @@ -120,6 +120,7 @@ void LvglComponent::add_event_cb(lv_obj_t *obj, event_callback_t callback, lv_ev void LvglComponent::add_page(LvPageType *page) { this->pages_.push_back(page); page->set_parent(this); + lv_disp_set_default(this->disp_); page->setup(this->pages_.size() - 1); } void LvglComponent::show_page(size_t index, lv_scr_load_anim_t anim, uint32_t time) { From aa6e172e14e042f27c4de2ed11a905789b19165f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Apr 2025 23:17:09 -1000 Subject: [PATCH 034/102] Fix BLE connection loop caused by timeout and pending disconnect race (#8597) --- esphome/components/bluetooth_proxy/bluetooth_proxy.cpp | 6 ++++++ esphome/components/esp32_ble_tracker/esp32_ble_tracker.h | 2 ++ 2 files changed, 8 insertions(+) diff --git a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp index a263aca456..03213432cd 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp +++ b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp @@ -265,6 +265,12 @@ void BluetoothProxy::bluetooth_device_request(const api::BluetoothDeviceRequest connection->get_connection_index(), connection->address_str().c_str()); return; } else if (connection->state() == espbt::ClientState::CONNECTING) { + if (connection->disconnect_pending()) { + ESP_LOGW(TAG, "[%d] [%s] Connection request while pending disconnect, cancelling pending disconnect", + connection->get_connection_index(), connection->address_str().c_str()); + connection->cancel_pending_disconnect(); + return; + } ESP_LOGW(TAG, "[%d] [%s] Connection request ignored, already connecting", connection->get_connection_index(), connection->address_str().c_str()); return; diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h index 99126f9173..8b712a01ea 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h @@ -173,6 +173,8 @@ class ESPBTClient : public ESPBTDeviceListener { virtual void gap_event_handler(esp_gap_ble_cb_event_t event, esp_ble_gap_cb_param_t *param) = 0; virtual void connect() = 0; virtual void disconnect() = 0; + bool disconnect_pending() const { return this->want_disconnect_; } + void cancel_pending_disconnect() { this->want_disconnect_ = false; } virtual void set_state(ClientState st) { this->state_ = st; if (st == ClientState::IDLE) { From b940db65492b937d6abdea98ae45e735ee4b3d22 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Wed, 23 Apr 2025 18:47:15 +1000 Subject: [PATCH 035/102] [online_image] Fix printf format; comment fixes (#8607) --- esphome/components/online_image/online_image.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/esphome/components/online_image/online_image.cpp b/esphome/components/online_image/online_image.cpp index 3411018901..cb4a3be9e8 100644 --- a/esphome/components/online_image/online_image.cpp +++ b/esphome/components/online_image/online_image.cpp @@ -111,7 +111,7 @@ void OnlineImage::update() { case ImageFormat::BMP: accept_mime_type = "image/bmp"; break; -#endif // ONLINE_IMAGE_BMP_SUPPORT +#endif // USE_ONLINE_IMAGE_BMP_SUPPORT #ifdef USE_ONLINE_IMAGE_JPEG_SUPPORT case ImageFormat::JPEG: accept_mime_type = "image/jpeg"; @@ -121,7 +121,7 @@ void OnlineImage::update() { case ImageFormat::PNG: accept_mime_type = "image/png"; break; -#endif // ONLINE_IMAGE_PNG_SUPPORT +#endif // USE_ONLINE_IMAGE_PNG_SUPPORT default: accept_mime_type = "image/*"; } @@ -159,7 +159,7 @@ void OnlineImage::update() { ESP_LOGD(TAG, "Allocating BMP decoder"); this->decoder_ = make_unique(this); } -#endif // ONLINE_IMAGE_BMP_SUPPORT +#endif // USE_ONLINE_IMAGE_BMP_SUPPORT #ifdef USE_ONLINE_IMAGE_JPEG_SUPPORT if (this->format_ == ImageFormat::JPEG) { ESP_LOGD(TAG, "Allocating JPEG decoder"); @@ -171,7 +171,7 @@ void OnlineImage::update() { ESP_LOGD(TAG, "Allocating PNG decoder"); this->decoder_ = make_unique(this); } -#endif // ONLINE_IMAGE_PNG_SUPPORT +#endif // USE_ONLINE_IMAGE_PNG_SUPPORT if (!this->decoder_) { ESP_LOGE(TAG, "Could not instantiate decoder. Image format unsupported: %d", this->format_); @@ -185,7 +185,7 @@ void OnlineImage::update() { this->download_error_callback_.call(); return; } - ESP_LOGI(TAG, "Downloading image (Size: %d)", total_size); + ESP_LOGI(TAG, "Downloading image (Size: %zu)", total_size); this->start_time_ = ::time(nullptr); } From 1c60038111b6357972fb6531f8e07e499f6f033a Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Wed, 23 Apr 2025 20:49:33 +1200 Subject: [PATCH 036/102] [watchdog] Fix for variants with single core (#8602) --- esphome/components/watchdog/watchdog.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/esphome/components/watchdog/watchdog.cpp b/esphome/components/watchdog/watchdog.cpp index 3a94a658e8..f6f2992a11 100644 --- a/esphome/components/watchdog/watchdog.cpp +++ b/esphome/components/watchdog/watchdog.cpp @@ -6,6 +6,7 @@ #include #include #ifdef USE_ESP32 +#include #include "esp_idf_version.h" #include "esp_task_wdt.h" #endif @@ -40,7 +41,7 @@ void WatchdogManager::set_timeout_(uint32_t timeout_ms) { #if ESP_IDF_VERSION_MAJOR >= 5 esp_task_wdt_config_t wdt_config = { .timeout_ms = timeout_ms, - .idle_core_mask = 0x03, + .idle_core_mask = (1 << SOC_CPU_CORES_NUM) - 1, .trigger_panic = true, }; esp_task_wdt_reconfigure(&wdt_config); From 5bfb5ccc348b68deddffed389e8f771f76ed791d Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:19:50 +1000 Subject: [PATCH 037/102] [core] Fix setting of log level/verbose (#8600) --- esphome/log.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/esphome/log.py b/esphome/log.py index 835cd6b44d..516f27be45 100644 --- a/esphome/log.py +++ b/esphome/log.py @@ -74,13 +74,14 @@ def setup_log( colorama.init() - if log_level == logging.DEBUG: - CORE.verbose = True - elif log_level == logging.CRITICAL: - CORE.quiet = True - + # Setup logging - will map log level from string to constant logging.basicConfig(level=log_level) + if logging.root.level == logging.DEBUG: + CORE.verbose = True + elif logging.root.level == logging.CRITICAL: + CORE.quiet = True + logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger().handlers[0].setFormatter( From f096567ac750e3b7f943721c1b5ef7769f808238 Mon Sep 17 00:00:00 2001 From: Steffen Banhardt Date: Mon, 28 Apr 2025 21:18:46 +0200 Subject: [PATCH 038/102] =?UTF-8?q?Update=20ens160=5Fbase.cpp=20=E2=80=93?= =?UTF-8?q?=20fix=20wrong=20double=20negative=20(#8639)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- esphome/components/ens160_base/ens160_base.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esphome/components/ens160_base/ens160_base.cpp b/esphome/components/ens160_base/ens160_base.cpp index 71082c58c2..852328d4bb 100644 --- a/esphome/components/ens160_base/ens160_base.cpp +++ b/esphome/components/ens160_base/ens160_base.cpp @@ -187,7 +187,7 @@ void ENS160Component::update() { } return; case INVALID_OUTPUT: - ESP_LOGE(TAG, "ENS160 Invalid Status - No Invalid Output"); + ESP_LOGE(TAG, "ENS160 Invalid Status - No valid output"); this->status_set_warning(); return; } From 7900660bb879a214744f869480bc21128013a877 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 29 Apr 2025 11:46:20 +1200 Subject: [PATCH 039/102] Bump version to 2025.4.1 --- esphome/const.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esphome/const.py b/esphome/const.py index 6fe79eb2c9..6d1ff157bf 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -1,6 +1,6 @@ """Constants used by esphome.""" -__version__ = "2025.4.0" +__version__ = "2025.4.1" ALLOWED_NAME_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789-_" VALID_SUBSTITUTIONS_CHARACTERS = ( From c9f7ab69482bd27d2ed991115668700d562a460e Mon Sep 17 00:00:00 2001 From: aanban Date: Tue, 29 Apr 2025 01:50:40 +0200 Subject: [PATCH 040/102] add beo4_protocol to remote_base component (#8307) --- esphome/components/remote_base/__init__.py | 48 ++++++ .../components/remote_base/beo4_protocol.cpp | 151 ++++++++++++++++++ .../components/remote_base/beo4_protocol.h | 43 +++++ .../remote_receiver/common-actions.yaml | 5 + .../remote_transmitter/common-buttons.yaml | 7 + 5 files changed, 254 insertions(+) create mode 100644 esphome/components/remote_base/beo4_protocol.cpp create mode 100644 esphome/components/remote_base/beo4_protocol.h diff --git a/esphome/components/remote_base/__init__.py b/esphome/components/remote_base/__init__.py index daea4e5c11..adacb83a30 100644 --- a/esphome/components/remote_base/__init__.py +++ b/esphome/components/remote_base/__init__.py @@ -28,6 +28,7 @@ from esphome.const import ( CONF_RC_CODE_2, CONF_REPEAT, CONF_SECOND, + CONF_SOURCE, CONF_STATE, CONF_SYNC, CONF_TIMES, @@ -265,6 +266,53 @@ async def build_dumpers(config): return dumpers +# Beo4 +Beo4Data, Beo4BinarySensor, Beo4Trigger, Beo4Action, Beo4Dumper = declare_protocol( + "Beo4" +) +BEO4_SCHEMA = cv.Schema( + { + cv.Required(CONF_SOURCE): cv.hex_uint8_t, + cv.Required(CONF_COMMAND): cv.hex_uint8_t, + cv.Optional(CONF_COMMAND_REPEATS, default=1): cv.uint8_t, + } +) + + +@register_binary_sensor("beo4", Beo4BinarySensor, BEO4_SCHEMA) +def beo4_binary_sensor(var, config): + cg.add( + var.set_data( + cg.StructInitializer( + Beo4Data, + ("source", config[CONF_SOURCE]), + ("command", config[CONF_COMMAND]), + ("repeats", config[CONF_COMMAND_REPEATS]), + ) + ) + ) + + +@register_trigger("beo4", Beo4Trigger, Beo4Data) +def beo4_trigger(var, config): + pass + + +@register_dumper("beo4", Beo4Dumper) +def beo4_dumper(var, config): + pass + + +@register_action("beo4", Beo4Action, BEO4_SCHEMA) +async def beo4_action(var, config, args): + template_ = await cg.templatable(config[CONF_SOURCE], args, cg.uint8) + cg.add(var.set_source(template_)) + template_ = await cg.templatable(config[CONF_COMMAND], args, cg.uint8) + cg.add(var.set_command(template_)) + template_ = await cg.templatable(config[CONF_COMMAND_REPEATS], args, cg.uint8) + cg.add(var.set_repeats(template_)) + + # ByronSX ( ByronSXData, diff --git a/esphome/components/remote_base/beo4_protocol.cpp b/esphome/components/remote_base/beo4_protocol.cpp new file mode 100644 index 0000000000..9f8d5e72c9 --- /dev/null +++ b/esphome/components/remote_base/beo4_protocol.cpp @@ -0,0 +1,151 @@ +#include "beo4_protocol.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace remote_base { + +static const char *const TAG = "remote.beo4"; + +// beo4 pulse width, high=carrier_pulse low=data_pulse +constexpr uint16_t PW_CARR_US = 200; // carrier pulse length +constexpr uint16_t PW_ZERO_US = 2925; // + 200 = 3125 µs +constexpr uint16_t PW_SAME_US = 6050; // + 200 = 6250 µs +constexpr uint16_t PW_ONE_US = 9175; // + 200 = 9375 µs +constexpr uint16_t PW_STOP_US = 12300; // + 200 = 12500 µs +constexpr uint16_t PW_START_US = 15425; // + 200 = 15625 µs + +// beo4 pulse codes +constexpr uint8_t PC_ZERO = (PW_CARR_US + PW_ZERO_US) / 3125; // =1 +constexpr uint8_t PC_SAME = (PW_CARR_US + PW_SAME_US) / 3125; // =2 +constexpr uint8_t PC_ONE = (PW_CARR_US + PW_ONE_US) / 3125; // =3 +constexpr uint8_t PC_STOP = (PW_CARR_US + PW_STOP_US) / 3125; // =4 +constexpr uint8_t PC_START = (PW_CARR_US + PW_START_US) / 3125; // =5 + +// beo4 number of data bits = beoLink+beoSrc+beoCmd = 1+8+8 = 17 +constexpr uint32_t N_BITS = 1 + 8 + 8; + +// required symbols = 2*(start_sequence + n_bits + stop) = 2*(3+17+1) = 42 +constexpr uint32_t N_SYM = 2 + ((3 + 17 + 1) * 2u); // + 2 = 44 + +// states finite-state-machine decoder +enum class RxSt { RX_IDLE, RX_DATA, RX_STOP }; + +void Beo4Protocol::encode(RemoteTransmitData *dst, const Beo4Data &data) { + uint32_t beo_code = ((uint32_t) data.source << 8) + (uint32_t) data.command; + uint32_t jc = 0, ic = 0; + uint32_t cur_bit = 0; + uint32_t pre_bit = 0; + dst->set_carrier_frequency(455000); + dst->reserve(N_SYM); + + // start sequence=zero,zero,start + dst->item(PW_CARR_US, PW_ZERO_US); + dst->item(PW_CARR_US, PW_ZERO_US); + dst->item(PW_CARR_US, PW_START_US); + + // the data-bit BeoLink is always 0 + dst->item(PW_CARR_US, PW_ZERO_US); + + // The B&O trick to avoid extra long and extra short + // code-frames by extracting the data-bits from left + // to right, then comparing current with previous bit + // and set pulse to "same" "one" or "zero" + for (jc = 15, ic = 0; ic < 16; ic++, jc--) { + cur_bit = ((beo_code) >> jc) & 1; + if (cur_bit == pre_bit) { + dst->item(PW_CARR_US, PW_SAME_US); + } else if (1 == cur_bit) { + dst->item(PW_CARR_US, PW_ONE_US); + } else { + dst->item(PW_CARR_US, PW_ZERO_US); + } + pre_bit = cur_bit; + } + // complete the frame with stop-symbol and final carrier pulse + dst->item(PW_CARR_US, PW_STOP_US); + dst->mark(PW_CARR_US); +} + +optional Beo4Protocol::decode(RemoteReceiveData src) { + int32_t n_sym = src.size(); + Beo4Data data{ + .source = 0, + .command = 0, + .repeats = 0, + }; + // suppress dummy codes (TSO7000 hiccups) + if (n_sym > 42) { + static uint32_t beo_code = 0; + RxSt fsm = RxSt::RX_IDLE; + int32_t ic = 0; + int32_t jc = 0; + uint32_t pre_bit = 0; + uint32_t cnt_bit = 0; + ESP_LOGD(TAG, "Beo4: n_sym=%d ", n_sym); + for (jc = 0, ic = 0; ic < (n_sym - 1); ic += 2, jc++) { + int32_t pulse_width = src[ic] - src[ic + 1]; + // suppress TSOP7000 (dummy pulses) + if (pulse_width > 1500) { + int32_t pulse_code = (pulse_width + 1560) / 3125; + switch (fsm) { + case RxSt::RX_IDLE: { + beo_code = 0; + cnt_bit = 0; + pre_bit = 0; + if (PC_START == pulse_code) { + fsm = RxSt::RX_DATA; + } + break; + } + case RxSt::RX_DATA: { + uint32_t cur_bit = 0; + switch (pulse_code) { + case PC_ZERO: { + cur_bit = pre_bit = 0; + break; + } + case PC_SAME: { + cur_bit = pre_bit; + break; + } + case PC_ONE: { + cur_bit = pre_bit = 1; + break; + } + default: { + fsm = RxSt::RX_IDLE; + break; + } + } + beo_code = (beo_code << 1) + cur_bit; + if (++cnt_bit == N_BITS) { + fsm = RxSt::RX_STOP; + } + break; + } + case RxSt::RX_STOP: { + if (PC_STOP == pulse_code) { + data.source = (uint8_t) ((beo_code >> 8) & 0xff); + data.command = (uint8_t) ((beo_code) &0xff); + data.repeats++; + } + if ((n_sym - ic) < 42) { + return data; + } else { + fsm = RxSt::RX_IDLE; + } + break; + } + } + } + } + } + return {}; // decoding failed +} + +void Beo4Protocol::dump(const Beo4Data &data) { + ESP_LOGI(TAG, "Beo4: source=0x%02x command=0x%02x repeats=%d ", data.source, data.command, data.repeats); +} + +} // namespace remote_base +} // namespace esphome diff --git a/esphome/components/remote_base/beo4_protocol.h b/esphome/components/remote_base/beo4_protocol.h new file mode 100644 index 0000000000..445e792cbc --- /dev/null +++ b/esphome/components/remote_base/beo4_protocol.h @@ -0,0 +1,43 @@ +#pragma once + +#include "remote_base.h" + +#include + +namespace esphome { +namespace remote_base { + +struct Beo4Data { + uint8_t source; // beoSource, e.g. video, audio, light... + uint8_t command; // beoCommend, e.g. volume+, mute,... + uint8_t repeats; // beoRepeat for repeat commands, e.g. up, down... + + bool operator==(const Beo4Data &rhs) const { return source == rhs.source && command == rhs.command; } +}; + +class Beo4Protocol : public RemoteProtocol { + public: + void encode(RemoteTransmitData *dst, const Beo4Data &data) override; + optional decode(RemoteReceiveData src) override; + void dump(const Beo4Data &data) override; +}; + +DECLARE_REMOTE_PROTOCOL(Beo4) + +template class Beo4Action : public RemoteTransmitterActionBase { + public: + TEMPLATABLE_VALUE(uint8_t, source) + TEMPLATABLE_VALUE(uint8_t, command) + TEMPLATABLE_VALUE(uint8_t, repeats) + + void encode(RemoteTransmitData *dst, Ts... x) override { + Beo4Data data{}; + data.source = this->source_.value(x...); + data.command = this->command_.value(x...); + data.repeats = this->repeats_.value(x...); + Beo4Protocol().encode(dst, data); + } +}; + +} // namespace remote_base +} // namespace esphome diff --git a/tests/components/remote_receiver/common-actions.yaml b/tests/components/remote_receiver/common-actions.yaml index c1f576d20e..08b1091116 100644 --- a/tests/components/remote_receiver/common-actions.yaml +++ b/tests/components/remote_receiver/common-actions.yaml @@ -3,6 +3,11 @@ on_abbwelcome: - logger.log: format: "on_abbwelcome: %u" args: ["x.data()[0]"] +on_beo4: + then: + - logger.log: + format: "on_beo4: %u %u" + args: ["x.source", "x.command"] on_aeha: then: - logger.log: diff --git a/tests/components/remote_transmitter/common-buttons.yaml b/tests/components/remote_transmitter/common-buttons.yaml index b037c50e12..1fb7ef6dbe 100644 --- a/tests/components/remote_transmitter/common-buttons.yaml +++ b/tests/components/remote_transmitter/common-buttons.yaml @@ -1,4 +1,11 @@ button: + - platform: template + name: Beo4 audio mute + id: beo4_audio_mute + on_press: + remote_transmitter.transmit_beo4: + source: 0x01 + command: 0x0C - platform: template name: JVC Off id: living_room_lights_on From 43580739accaab3882ffcff0b7a0ef50fd50eac3 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 29 Apr 2025 11:58:13 +1200 Subject: [PATCH 041/102] Ensure new const file stays in order (#8642) --- script/ci-custom.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/script/ci-custom.py b/script/ci-custom.py index 6a5fb32180..dda5410778 100755 --- a/script/ci-custom.py +++ b/script/ci-custom.py @@ -318,7 +318,12 @@ def lint_no_long_delays(fname, match): ) -@lint_content_check(include=["esphome/const.py"]) +@lint_content_check( + include=[ + "esphome/const.py", + "esphome/components/const/__init__.py", + ] +) def lint_const_ordered(fname, content): """Lint that value in const.py are ordered. From 844569e96bd8e0431ea0082ebb19f633bd6da3c3 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Mon, 28 Apr 2025 19:05:07 -0500 Subject: [PATCH 042/102] [audio, microphone] Add MicrophoneSource helper class (#8641) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- CODEOWNERS | 2 +- esphome/components/audio/__init__.py | 112 +++++++++++----- esphome/components/audio/audio_resampler.cpp | 2 + esphome/components/audio/audio_resampler.h | 1 + esphome/components/microphone/__init__.py | 122 +++++++++++++++++- esphome/components/microphone/microphone.h | 6 + .../microphone/microphone_source.cpp | 96 ++++++++++++++ .../components/microphone/microphone_source.h | 63 +++++++++ 8 files changed, 365 insertions(+), 39 deletions(-) create mode 100644 esphome/components/microphone/microphone_source.cpp create mode 100644 esphome/components/microphone/microphone_source.h diff --git a/CODEOWNERS b/CODEOWNERS index 73973f420f..6f5eae1a9c 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -278,7 +278,7 @@ esphome/components/mdns/* @esphome/core esphome/components/media_player/* @jesserockz esphome/components/micro_wake_word/* @jesserockz @kahrendt esphome/components/micronova/* @jorre05 -esphome/components/microphone/* @jesserockz +esphome/components/microphone/* @jesserockz @kahrendt esphome/components/mics_4514/* @jesserockz esphome/components/midea/* @dudanov esphome/components/midea_ir/* @dudanov diff --git a/esphome/components/audio/__init__.py b/esphome/components/audio/__init__.py index f8ec8cbd85..9f08c81e77 100644 --- a/esphome/components/audio/__init__.py +++ b/esphome/components/audio/__init__.py @@ -48,6 +48,12 @@ def set_stream_limits( min_sample_rate: int = _UNDEF, max_sample_rate: int = _UNDEF, ): + """Sets the limits for the audio stream that audio component can handle + + When the component sinks audio (e.g., a speaker), these indicate the limits to the audio it can receive. + When the component sources audio (e.g., a microphone), these indicate the limits to the audio it can send. + """ + def set_limits_in_config(config): if min_bits_per_sample is not _UNDEF: config[CONF_MIN_BITS_PER_SAMPLE] = min_bits_per_sample @@ -69,43 +75,87 @@ def final_validate_audio_schema( name: str, *, audio_device: str, - bits_per_sample: int, - channels: int, - sample_rate: int, + bits_per_sample: int = _UNDEF, + channels: int = _UNDEF, + sample_rate: int = _UNDEF, + enabled_channels: list[int] = _UNDEF, + audio_device_issue: bool = False, ): + """Validates audio compatibility when passed between different components. + + The component derived from ``AUDIO_COMPONENT_SCHEMA`` should call ``set_stream_limits`` in a validator to specify its compatible settings + + - If audio_device_issue is True, then the error message indicates the user should adjust the AUDIO_COMPONENT_SCHEMA derived component's configuration to match the values passed to this function + - If audio_device_issue is False, then the error message indicates the user should adjust the configuration of the component calling this function, as it falls out of the valid stream limits + + Args: + name (str): Friendly name of the component calling this function with an audio component to validate + audio_device (str): The configuration parameter name that contains the ID of an AUDIO_COMPONENT_SCHEMA derived component to validate against + bits_per_sample (int, optional): The desired bits per sample + channels (int, optional): The desired number of channels + sample_rate (int, optional): The desired sample rate + enabled_channels (list[int], optional): The desired enabled channels + audio_device_issue (bool, optional): Format the error message to indicate the problem is in the configuration for the ``audio_device`` component. Defaults to False. + """ + def validate_audio_compatiblity(audio_config): audio_schema = {} - try: - cv.int_range( - min=audio_config.get(CONF_MIN_BITS_PER_SAMPLE), - max=audio_config.get(CONF_MAX_BITS_PER_SAMPLE), - )(bits_per_sample) - except cv.Invalid as exc: - raise cv.Invalid( - f"Invalid configuration for the {name} component. The {CONF_BITS_PER_SAMPLE} {str(exc)}" - ) from exc + if bits_per_sample is not _UNDEF: + try: + cv.int_range( + min=audio_config.get(CONF_MIN_BITS_PER_SAMPLE), + max=audio_config.get(CONF_MAX_BITS_PER_SAMPLE), + )(bits_per_sample) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires {bits_per_sample} bits per sample." + else: + error_string = f"Invalid configuration for the {name} component. The {CONF_BITS_PER_SAMPLE} {str(exc)}" + raise cv.Invalid(error_string) from exc - try: - cv.int_range( - min=audio_config.get(CONF_MIN_CHANNELS), - max=audio_config.get(CONF_MAX_CHANNELS), - )(channels) - except cv.Invalid as exc: - raise cv.Invalid( - f"Invalid configuration for the {name} component. The {CONF_NUM_CHANNELS} {str(exc)}" - ) from exc + if channels is not _UNDEF: + try: + cv.int_range( + min=audio_config.get(CONF_MIN_CHANNELS), + max=audio_config.get(CONF_MAX_CHANNELS), + )(channels) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires {channels} channels." + else: + error_string = f"Invalid configuration for the {name} component. The {CONF_NUM_CHANNELS} {str(exc)}" + raise cv.Invalid(error_string) from exc - try: - cv.int_range( - min=audio_config.get(CONF_MIN_SAMPLE_RATE), - max=audio_config.get(CONF_MAX_SAMPLE_RATE), - )(sample_rate) - return cv.Schema(audio_schema, extra=cv.ALLOW_EXTRA)(audio_config) - except cv.Invalid as exc: - raise cv.Invalid( - f"Invalid configuration for the {name} component. The {CONF_SAMPLE_RATE} {str(exc)}" - ) from exc + if sample_rate is not _UNDEF: + try: + cv.int_range( + min=audio_config.get(CONF_MIN_SAMPLE_RATE), + max=audio_config.get(CONF_MAX_SAMPLE_RATE), + )(sample_rate) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires a {sample_rate} sample rate." + else: + error_string = f"Invalid configuration for the {name} component. The {CONF_SAMPLE_RATE} {str(exc)}" + raise cv.Invalid(error_string) from exc + + if enabled_channels is not _UNDEF: + for channel in enabled_channels: + try: + # Channels are 0-indexed + cv.int_range( + min=0, + max=audio_config.get(CONF_MAX_CHANNELS) - 1, + )(channel) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires channel {channel}." + else: + error_string = f"Invalid configuration for the {name} component. Enabled channel {channel} {str(exc)}" + raise cv.Invalid(error_string) from exc + + return cv.Schema(audio_schema, extra=cv.ALLOW_EXTRA)(audio_config) return cv.Schema( { diff --git a/esphome/components/audio/audio_resampler.cpp b/esphome/components/audio/audio_resampler.cpp index a7621225a1..20d246f1e0 100644 --- a/esphome/components/audio/audio_resampler.cpp +++ b/esphome/components/audio/audio_resampler.cpp @@ -4,6 +4,8 @@ #include "esphome/core/hal.h" +#include + namespace esphome { namespace audio { diff --git a/esphome/components/audio/audio_resampler.h b/esphome/components/audio/audio_resampler.h index 7f4e987b4c..082ade3371 100644 --- a/esphome/components/audio/audio_resampler.h +++ b/esphome/components/audio/audio_resampler.h @@ -6,6 +6,7 @@ #include "audio_transfer_buffer.h" #include "esphome/core/defines.h" +#include "esphome/core/helpers.h" #include "esphome/core/ring_buffer.h" #ifdef USE_SPEAKER diff --git a/esphome/components/microphone/__init__.py b/esphome/components/microphone/__init__.py index 4e5471b117..b9d24bc4a7 100644 --- a/esphome/components/microphone/__init__.py +++ b/esphome/components/microphone/__init__.py @@ -1,12 +1,21 @@ from esphome import automation from esphome.automation import maybe_simple_id import esphome.codegen as cg +from esphome.components import audio import esphome.config_validation as cv -from esphome.const import CONF_ID, CONF_TRIGGER_ID +from esphome.const import ( + CONF_BITS_PER_SAMPLE, + CONF_CHANNELS, + CONF_GAIN_FACTOR, + CONF_ID, + CONF_MICROPHONE, + CONF_TRIGGER_ID, +) from esphome.core import CORE from esphome.coroutine import coroutine_with_priority -CODEOWNERS = ["@jesserockz"] +AUTO_LOAD = ["audio"] +CODEOWNERS = ["@jesserockz", "@kahrendt"] IS_PLATFORM_COMPONENT = True @@ -15,6 +24,7 @@ CONF_ON_DATA = "on_data" microphone_ns = cg.esphome_ns.namespace("microphone") Microphone = microphone_ns.class_("Microphone") +MicrophoneSource = microphone_ns.class_("MicrophoneSource") CaptureAction = microphone_ns.class_( "CaptureAction", automation.Action, cg.Parented.template(Microphone) @@ -37,6 +47,7 @@ IsCapturingCondition = microphone_ns.class_( async def setup_microphone_core_(var, config): for conf in config.get(CONF_ON_DATA, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) + # Future PR will change the vector type to uint8 await automation.build_automation( trigger, [(cg.std_vector.template(cg.int16).operator("ref").operator("const"), "x")], @@ -50,7 +61,7 @@ async def register_microphone(var, config): await setup_microphone_core_(var, config) -MICROPHONE_SCHEMA = cv.Schema( +MICROPHONE_SCHEMA = cv.Schema.extend(audio.AUDIO_COMPONENT_SCHEMA).extend( { cv.Optional(CONF_ON_DATA): automation.validate_automation( { @@ -64,7 +75,104 @@ MICROPHONE_SCHEMA = cv.Schema( MICROPHONE_ACTION_SCHEMA = maybe_simple_id({cv.GenerateID(): cv.use_id(Microphone)}) -async def media_player_action(config, action_id, template_arg, args): +def microphone_source_schema( + min_bits_per_sample: int = 16, + max_bits_per_sample: int = 16, + min_channels: int = 1, + max_channels: int = 1, +): + """Schema for a microphone source + + Components requesting microphone data should use this schema instead of accessing a microphone directly. + + Args: + min_bits_per_sample (int, optional): Minimum number of bits per sample the requesting component supports. Defaults to 16. + max_bits_per_sample (int, optional): Maximum number of bits per sample the requesting component supports. Defaults to 16. + min_channels (int, optional): Minimum number of channels the requesting component supports. Defaults to 1. + max_channels (int, optional): Maximum number of channels the requesting component supports. Defaults to 1. + """ + + def _validate_unique_channels(config): + if len(config) != len(set(config)): + raise cv.Invalid("Channels must be unique") + return config + + return cv.All( + cv.maybe_simple_value( + { + cv.GenerateID(CONF_ID): cv.declare_id(MicrophoneSource), + cv.Required(CONF_MICROPHONE): cv.use_id(Microphone), + cv.Optional(CONF_BITS_PER_SAMPLE, default=16): cv.int_range( + min_bits_per_sample, max_bits_per_sample + ), + cv.Optional(CONF_CHANNELS, default="0"): cv.All( + cv.ensure_list(cv.int_range(0, 7)), + cv.Length(min=min_channels, max=max_channels), + _validate_unique_channels, + ), + cv.Optional(CONF_GAIN_FACTOR, default="1"): cv.int_range(1, 64), + }, + key=CONF_MICROPHONE, + ), + ) + + +_UNDEF = object() + + +def final_validate_microphone_source_schema( + component_name: str, sample_rate: int = _UNDEF +): + """Validates that the microphone source can provide audio in the correct format. In particular it validates the sample rate and the enabled channels. + + Note that: + - MicrophoneSource class automatically handles converting bits per sample, so no need to validate + - microphone_source_schema already validates that channels are unique and specifies the max number of channels the component supports + + Args: + component_name (str): The name of the component requesting mic audio + sample_rate (int, optional): The sample rate the component requesting mic audio requires + """ + + def _validate_audio_compatability(config): + if sample_rate is not _UNDEF: + # Issues require changing the microphone configuration + # - Verifies sample rates match + audio.final_validate_audio_schema( + component_name, + audio_device=CONF_MICROPHONE, + sample_rate=sample_rate, + audio_device_issue=True, + )(config) + + # Issues require changing the MicrophoneSource configuration + # - Verifies that each of the enabled channels are available + audio.final_validate_audio_schema( + component_name, + audio_device=CONF_MICROPHONE, + enabled_channels=config[CONF_CHANNELS], + audio_device_issue=False, + )(config) + + return config + + return _validate_audio_compatability + + +async def microphone_source_to_code(config): + mic = await cg.get_variable(config[CONF_MICROPHONE]) + mic_source = cg.new_Pvariable( + config[CONF_ID], + mic, + config[CONF_BITS_PER_SAMPLE], + config[CONF_GAIN_FACTOR], + ) + for channel in config[CONF_CHANNELS]: + cg.add(mic_source.add_channel(channel)) + return mic_source + + +async def microphone_action(config, action_id, template_arg, args): var = cg.new_Pvariable(action_id, template_arg) await cg.register_parented(var, config[CONF_ID]) return var @@ -72,15 +180,15 @@ async def media_player_action(config, action_id, template_arg, args): automation.register_action( "microphone.capture", CaptureAction, MICROPHONE_ACTION_SCHEMA -)(media_player_action) +)(microphone_action) automation.register_action( "microphone.stop_capture", StopCaptureAction, MICROPHONE_ACTION_SCHEMA -)(media_player_action) +)(microphone_action) automation.register_condition( "microphone.is_capturing", IsCapturingCondition, MICROPHONE_ACTION_SCHEMA -)(media_player_action) +)(microphone_action) @coroutine_with_priority(100.0) diff --git a/esphome/components/microphone/microphone.h b/esphome/components/microphone/microphone.h index 914ad80bea..58552aa34a 100644 --- a/esphome/components/microphone/microphone.h +++ b/esphome/components/microphone/microphone.h @@ -1,5 +1,7 @@ #pragma once +#include "esphome/components/audio/audio.h" + #include #include #include @@ -28,9 +30,13 @@ class Microphone { bool is_running() const { return this->state_ == STATE_RUNNING; } bool is_stopped() const { return this->state_ == STATE_STOPPED; } + audio::AudioStreamInfo get_audio_stream_info() { return this->audio_stream_info_; } + protected: State state_{STATE_STOPPED}; + audio::AudioStreamInfo audio_stream_info_; + CallbackManager &)> data_callbacks_{}; }; diff --git a/esphome/components/microphone/microphone_source.cpp b/esphome/components/microphone/microphone_source.cpp new file mode 100644 index 0000000000..7e397348b9 --- /dev/null +++ b/esphome/components/microphone/microphone_source.cpp @@ -0,0 +1,96 @@ +#include "microphone_source.h" + +namespace esphome { +namespace microphone { + +void MicrophoneSource::add_data_callback(std::function &)> &&data_callback) { + std::function &)> filtered_callback = + [this, data_callback](const std::vector &data) { + if (this->enabled_) { + data_callback(this->process_audio_(data)); + } + }; + // Future PR will uncomment this! It requires changing the callback vector to an uint8_t in every component using a + // mic callback. + // this->mic_->add_data_callback(std::move(filtered_callback)); +} + +void MicrophoneSource::start() { + this->enabled_ = true; + this->mic_->start(); +} +void MicrophoneSource::stop() { + this->enabled_ = false; + this->mic_->stop(); +} + +std::vector MicrophoneSource::process_audio_(const std::vector &data) { + // Bit depth conversions are obtained by truncating bits or padding with zeros - no dithering is applied. + + const size_t source_bytes_per_sample = this->mic_->get_audio_stream_info().samples_to_bytes(1); + const size_t source_channels = this->mic_->get_audio_stream_info().get_channels(); + + const size_t source_bytes_per_frame = this->mic_->get_audio_stream_info().frames_to_bytes(1); + + const uint32_t total_frames = this->mic_->get_audio_stream_info().bytes_to_frames(data.size()); + const size_t target_bytes_per_sample = (this->bits_per_sample_ + 7) / 8; + const size_t target_bytes_per_frame = target_bytes_per_sample * this->channels_.count(); + + std::vector filtered_data; + filtered_data.reserve(target_bytes_per_frame * total_frames); + + const int32_t target_min_value = -(1 << (8 * target_bytes_per_sample - 1)); + const int32_t target_max_value = (1 << (8 * target_bytes_per_sample - 1)) - 1; + + for (size_t frame_index = 0; frame_index < total_frames; ++frame_index) { + for (size_t channel_index = 0; channel_index < source_channels; ++channel_index) { + if (this->channels_.test(channel_index)) { + // Channel's current sample is included in the target mask. Convert bits per sample, if necessary. + + size_t sample_index = frame_index * source_bytes_per_frame + channel_index * source_bytes_per_sample; + + int32_t sample = 0; + + // Copy the data into the most significant bits of the sample variable to ensure the sign bit is correct + uint8_t bit_offset = (4 - source_bytes_per_sample) * 8; + for (int i = 0; i < source_bytes_per_sample; ++i) { + sample |= data[sample_index + i] << bit_offset; + bit_offset += 8; + } + + // Shift data back to the least significant bits + if (source_bytes_per_sample >= target_bytes_per_sample) { + // Keep source bytes per sample of data so that the gain multiplication uses all significant bits instead of + // shifting to the target bytes per sample immediately, potentially losing information. + sample >>= (4 - source_bytes_per_sample) * 8; // ``source_bytes_per_sample`` bytes of valid data + } else { + // Keep padded zeros to match the target bytes per sample + sample >>= (4 - target_bytes_per_sample) * 8; // ``target_bytes_per_sample`` bytes of valid data + } + + // Apply gain using multiplication + sample *= this->gain_factor_; + + // Match target output bytes by shifting out the least significant bits + if (source_bytes_per_sample > target_bytes_per_sample) { + sample >>= 8 * (source_bytes_per_sample - + target_bytes_per_sample); // ``target_bytes_per_sample`` bytes of valid data + } + + // Clamp ``sample`` to the target bytes per sample range in case gain multiplication overflows + sample = clamp(sample, target_min_value, target_max_value); + + // Copy ``target_bytes_per_sample`` bytes to the output buffer. + for (int i = 0; i < target_bytes_per_sample; ++i) { + filtered_data.push_back(static_cast(sample)); + sample >>= 8; + } + } + } + } + + return filtered_data; +} + +} // namespace microphone +} // namespace esphome diff --git a/esphome/components/microphone/microphone_source.h b/esphome/components/microphone/microphone_source.h new file mode 100644 index 0000000000..028920f101 --- /dev/null +++ b/esphome/components/microphone/microphone_source.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "microphone.h" + +namespace esphome { +namespace microphone { + +class MicrophoneSource { + /* + * @brief Helper class that handles converting raw microphone data to a requested format. + * Components requesting microphone audio should register a callback through this class instead of registering a + * callback directly with the microphone if a particular format is required. + * + * Raw microphone data may have a different number of bits per sample and number of channels than the requesting + * component needs. This class handles the conversion by: + * - Internally adds a callback to receive the raw microphone data + * - The ``process_audio_`` handles the raw data + * - Only the channels set in the ``channels_`` bitset are passed through + * - Passed through samples have the bits per sample converted + * - A gain factor is optionally applied to increase the volume - audio may clip! + * - The processed audio is passed to the callback of the component requesting microphone data + * - It tracks an internal enabled state, so it ignores raw microphone data when the component requesting + * microphone data is not actively requesting audio. + * + * Note that this class cannot convert sample rates! + */ + public: + MicrophoneSource(Microphone *mic, uint8_t bits_per_sample, int32_t gain_factor) + : mic_(mic), bits_per_sample_(bits_per_sample), gain_factor_(gain_factor) {} + + /// @brief Enables a channel to be processed through the callback. + /// + /// If the microphone component only has reads from one channel, it is always in channel number 0, regardless if it + /// represents left or right. If the microphone reads from both left and right, channel number 0 and 1 represent the + /// left and right channels respectively. + /// + /// @param channel 0-indexed channel number to enable + void add_channel(uint8_t channel) { this->channels_.set(channel); } + + void add_data_callback(std::function &)> &&data_callback); + + void start(); + void stop(); + bool is_running() const { return (this->mic_->is_running() && this->enabled_); } + bool is_stopped() const { return !this->enabled_; } + + protected: + std::vector process_audio_(const std::vector &data); + + Microphone *mic_; + uint8_t bits_per_sample_; + std::bitset<8> channels_; + int32_t gain_factor_; + bool enabled_{false}; +}; + +} // namespace microphone +} // namespace esphome From b8ba26787e80f6b05107401c21b63e4b4e28b32b Mon Sep 17 00:00:00 2001 From: Thomas Rupprecht Date: Tue, 29 Apr 2025 02:24:48 +0200 Subject: [PATCH 043/102] [pmsx003] Refactor Imports, Extract Constants, Improve Data Handling & Logging (#8344) --- CODEOWNERS | 1 + esphome/components/pmsx003/pmsx003.cpp | 402 +++++++++++-------------- esphome/components/pmsx003/pmsx003.h | 74 +++-- esphome/components/pmsx003/sensor.py | 32 +- 4 files changed, 251 insertions(+), 258 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 6f5eae1a9c..06d3601858 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -328,6 +328,7 @@ esphome/components/pipsolar/* @andreashergert1984 esphome/components/pm1006/* @habbie esphome/components/pm2005/* @andrewjswan esphome/components/pmsa003i/* @sjtrny +esphome/components/pmsx003/* @ximex esphome/components/pmwcs3/* @SeByDocKy esphome/components/pn532/* @OttoWinter @jesserockz esphome/components/pn532_i2c/* @OttoWinter @jesserockz diff --git a/esphome/components/pmsx003/pmsx003.cpp b/esphome/components/pmsx003/pmsx003.cpp index de2b23b8eb..11626768d8 100644 --- a/esphome/components/pmsx003/pmsx003.cpp +++ b/esphome/components/pmsx003/pmsx003.cpp @@ -6,45 +6,39 @@ namespace pmsx003 { static const char *const TAG = "pmsx003"; -void PMSX003Component::set_pm_1_0_std_sensor(sensor::Sensor *pm_1_0_std_sensor) { - pm_1_0_std_sensor_ = pm_1_0_std_sensor; -} -void PMSX003Component::set_pm_2_5_std_sensor(sensor::Sensor *pm_2_5_std_sensor) { - pm_2_5_std_sensor_ = pm_2_5_std_sensor; -} -void PMSX003Component::set_pm_10_0_std_sensor(sensor::Sensor *pm_10_0_std_sensor) { - pm_10_0_std_sensor_ = pm_10_0_std_sensor; -} +static const uint8_t START_CHARACTER_1 = 0x42; +static const uint8_t START_CHARACTER_2 = 0x4D; -void PMSX003Component::set_pm_1_0_sensor(sensor::Sensor *pm_1_0_sensor) { pm_1_0_sensor_ = pm_1_0_sensor; } -void PMSX003Component::set_pm_2_5_sensor(sensor::Sensor *pm_2_5_sensor) { pm_2_5_sensor_ = pm_2_5_sensor; } -void PMSX003Component::set_pm_10_0_sensor(sensor::Sensor *pm_10_0_sensor) { pm_10_0_sensor_ = pm_10_0_sensor; } +static const uint16_t PMS_STABILISING_MS = 30000; // time taken for the sensor to become stable after power on in ms -void PMSX003Component::set_pm_particles_03um_sensor(sensor::Sensor *pm_particles_03um_sensor) { - pm_particles_03um_sensor_ = pm_particles_03um_sensor; -} -void PMSX003Component::set_pm_particles_05um_sensor(sensor::Sensor *pm_particles_05um_sensor) { - pm_particles_05um_sensor_ = pm_particles_05um_sensor; -} -void PMSX003Component::set_pm_particles_10um_sensor(sensor::Sensor *pm_particles_10um_sensor) { - pm_particles_10um_sensor_ = pm_particles_10um_sensor; -} -void PMSX003Component::set_pm_particles_25um_sensor(sensor::Sensor *pm_particles_25um_sensor) { - pm_particles_25um_sensor_ = pm_particles_25um_sensor; -} -void PMSX003Component::set_pm_particles_50um_sensor(sensor::Sensor *pm_particles_50um_sensor) { - pm_particles_50um_sensor_ = pm_particles_50um_sensor; -} -void PMSX003Component::set_pm_particles_100um_sensor(sensor::Sensor *pm_particles_100um_sensor) { - pm_particles_100um_sensor_ = pm_particles_100um_sensor; -} +static const uint16_t PMS_CMD_MEASUREMENT_MODE_PASSIVE = + 0x0000; // use `PMS_CMD_MANUAL_MEASUREMENT` to trigger a measurement +static const uint16_t PMS_CMD_MEASUREMENT_MODE_ACTIVE = 0x0001; // automatically perform measurements +static const uint16_t PMS_CMD_SLEEP_MODE_SLEEP = 0x0000; // go to sleep mode +static const uint16_t PMS_CMD_SLEEP_MODE_WAKEUP = 0x0001; // wake up from sleep mode -void PMSX003Component::set_temperature_sensor(sensor::Sensor *temperature_sensor) { - temperature_sensor_ = temperature_sensor; -} -void PMSX003Component::set_humidity_sensor(sensor::Sensor *humidity_sensor) { humidity_sensor_ = humidity_sensor; } -void PMSX003Component::set_formaldehyde_sensor(sensor::Sensor *formaldehyde_sensor) { - formaldehyde_sensor_ = formaldehyde_sensor; +void PMSX003Component::dump_config() { + ESP_LOGCONFIG(TAG, "PMSX003:"); + LOG_SENSOR(" ", "PM1.0STD", this->pm_1_0_std_sensor_); + LOG_SENSOR(" ", "PM2.5STD", this->pm_2_5_std_sensor_); + LOG_SENSOR(" ", "PM10.0STD", this->pm_10_0_std_sensor_); + + LOG_SENSOR(" ", "PM1.0", this->pm_1_0_sensor_); + LOG_SENSOR(" ", "PM2.5", this->pm_2_5_sensor_); + LOG_SENSOR(" ", "PM10.0", this->pm_10_0_sensor_); + + LOG_SENSOR(" ", "PM0.3um", this->pm_particles_03um_sensor_); + LOG_SENSOR(" ", "PM0.5um", this->pm_particles_05um_sensor_); + LOG_SENSOR(" ", "PM1.0um", this->pm_particles_10um_sensor_); + LOG_SENSOR(" ", "PM2.5um", this->pm_particles_25um_sensor_); + LOG_SENSOR(" ", "PM5.0um", this->pm_particles_50um_sensor_); + LOG_SENSOR(" ", "PM10.0um", this->pm_particles_100um_sensor_); + + LOG_SENSOR(" ", "Formaldehyde", this->formaldehyde_sensor_); + + LOG_SENSOR(" ", "Temperature", this->temperature_sensor_); + LOG_SENSOR(" ", "Humidity", this->humidity_sensor_); + this->check_uart_settings(9600); } void PMSX003Component::loop() { @@ -55,8 +49,8 @@ void PMSX003Component::loop() { // need to keep track of what state we're in. if (this->update_interval_ > PMS_STABILISING_MS) { if (this->initialised_ == 0) { - this->send_command_(PMS_CMD_AUTO_MANUAL, 0); - this->send_command_(PMS_CMD_ON_STANDBY, 1); + this->send_command_(PMS_CMD_MEASUREMENT_MODE, PMS_CMD_MEASUREMENT_MODE_PASSIVE); + this->send_command_(PMS_CMD_SLEEP_MODE, PMS_CMD_SLEEP_MODE_WAKEUP); this->initialised_ = 1; } switch (this->state_) { @@ -66,7 +60,7 @@ void PMSX003Component::loop() { return; this->state_ = PMSX003_STATE_STABILISING; - this->send_command_(PMS_CMD_ON_STANDBY, 1); + this->send_command_(PMS_CMD_SLEEP_MODE, PMS_CMD_SLEEP_MODE_WAKEUP); this->fan_on_time_ = now; return; case PMSX003_STATE_STABILISING: @@ -77,7 +71,7 @@ void PMSX003Component::loop() { while (this->available()) this->read_byte(&this->data_[0]); // Trigger a new read - this->send_command_(PMS_CMD_TRIG_MANUAL, 0); + this->send_command_(PMS_CMD_MANUAL_MEASUREMENT, 0); this->state_ = PMSX003_STATE_WAITING; break; case PMSX003_STATE_WAITING: @@ -116,242 +110,212 @@ void PMSX003Component::loop() { } } } -float PMSX003Component::get_setup_priority() const { return setup_priority::DATA; } + optional PMSX003Component::check_byte_() { - uint8_t index = this->data_index_; - uint8_t byte = this->data_[index]; + const uint8_t index = this->data_index_; + const uint8_t byte = this->data_[index]; - if (index == 0) - return byte == 0x42; - - if (index == 1) - return byte == 0x4D; - - if (index == 2) - return true; - - uint16_t payload_length = this->get_16_bit_uint_(2); - if (index == 3) { - bool length_matches = false; - switch (this->type_) { - case PMSX003_TYPE_X003: - length_matches = payload_length == 28 || payload_length == 20; - break; - case PMSX003_TYPE_5003T: - case PMSX003_TYPE_5003S: - length_matches = payload_length == 28; - break; - case PMSX003_TYPE_5003ST: - length_matches = payload_length == 36; - break; + if (index == 0 || index == 1) { + const uint8_t start_char = index == 0 ? START_CHARACTER_1 : START_CHARACTER_2; + if (byte == start_char) { + return true; } - if (!length_matches) { - ESP_LOGW(TAG, "PMSX003 length %u doesn't match. Are you using the correct PMSX003 type?", payload_length); - return false; - } + ESP_LOGW(TAG, "Start character %u mismatch: 0x%02X != 0x%02X", index + 1, byte, START_CHARACTER_1); + return false; + } + + if (index == 2) { return true; } - // start (16bit) + length (16bit) + DATA (payload_length-2 bytes) + checksum (16bit) - uint8_t total_size = 4 + payload_length; + const uint16_t payload_length = this->get_16_bit_uint_(2); + if (index == 3) { + if (this->check_payload_length_(payload_length)) { + return true; + } else { + ESP_LOGW(TAG, "Payload length %u doesn't match. Are you using the correct PMSX003 type?", payload_length); + return false; + } + } - if (index < total_size - 1) + // start (16bit) + length (16bit) + DATA (payload_length - 16bit) + checksum (16bit) + const uint16_t total_size = 4 + payload_length; + + if (index < total_size - 1) { return true; + } // checksum is without checksum bytes uint16_t checksum = 0; - for (uint8_t i = 0; i < total_size - 2; i++) + for (uint16_t i = 0; i < total_size - 2; i++) { checksum += this->data_[i]; + } - uint16_t check = this->get_16_bit_uint_(total_size - 2); + const uint16_t check = this->get_16_bit_uint_(total_size - 2); if (checksum != check) { - ESP_LOGW(TAG, "PMSX003 checksum mismatch! 0x%02X!=0x%02X", checksum, check); + ESP_LOGW(TAG, "PMSX003 checksum mismatch! 0x%02X != 0x%02X", checksum, check); return false; } return {}; } -void PMSX003Component::send_command_(uint8_t cmd, uint16_t data) { - this->data_index_ = 0; - this->data_[data_index_++] = 0x42; - this->data_[data_index_++] = 0x4D; - this->data_[data_index_++] = cmd; - this->data_[data_index_++] = (data >> 8) & 0xFF; - this->data_[data_index_++] = (data >> 0) & 0xFF; - int sum = 0; - for (int i = 0; i < data_index_; i++) { - sum += this->data_[i]; +bool PMSX003Component::check_payload_length_(uint16_t payload_length) { + switch (this->type_) { + case PMSX003_TYPE_X003: + // The expected payload length is typically 28 bytes. + // However, a 20-byte payload check was already present in the code. + // No official documentation was found confirming this. + // Retaining this check to avoid breaking existing behavior. + return payload_length == 28 || payload_length == 20; // 2*13+2 + case PMSX003_TYPE_5003T: + case PMSX003_TYPE_5003S: + return payload_length == 28; // 2*13+2 (Data 13 not set/reserved) + case PMSX003_TYPE_5003ST: + return payload_length == 36; // 2*17+2 (Data 16 not set/reserved) } - this->data_[data_index_++] = (sum >> 8) & 0xFF; - this->data_[data_index_++] = (sum >> 0) & 0xFF; - for (int i = 0; i < data_index_; i++) { - this->write_byte(this->data_[i]); + return false; +} + +void PMSX003Component::send_command_(PMSX0003Command cmd, uint16_t data) { + uint8_t send_data[7] = { + START_CHARACTER_1, // Start Byte 1 + START_CHARACTER_2, // Start Byte 2 + cmd, // Command + uint8_t((data >> 8) & 0xFF), // Data 1 + uint8_t((data >> 0) & 0xFF), // Data 2 + 0, // Verify Byte 1 + 0, // Verify Byte 2 + }; + + // Calculate checksum + uint16_t checksum = 0; + for (uint8_t i = 0; i < 5; i++) { + checksum += send_data[i]; + } + send_data[5] = (checksum >> 8) & 0xFF; // Verify Byte 1 + send_data[6] = (checksum >> 0) & 0xFF; // Verify Byte 2 + + for (auto send_byte : send_data) { + this->write_byte(send_byte); } - this->data_index_ = 0; } void PMSX003Component::parse_data_() { - switch (this->type_) { - case PMSX003_TYPE_5003ST: { - float temperature = (int16_t) this->get_16_bit_uint_(30) / 10.0f; - float humidity = this->get_16_bit_uint_(32) / 10.0f; + // Particle Matter + const uint16_t pm_1_0_std_concentration = this->get_16_bit_uint_(4); + const uint16_t pm_2_5_std_concentration = this->get_16_bit_uint_(6); + const uint16_t pm_10_0_std_concentration = this->get_16_bit_uint_(8); - ESP_LOGD(TAG, "Got Temperature: %.1f°C, Humidity: %.1f%%", temperature, humidity); + const uint16_t pm_1_0_concentration = this->get_16_bit_uint_(10); + const uint16_t pm_2_5_concentration = this->get_16_bit_uint_(12); + const uint16_t pm_10_0_concentration = this->get_16_bit_uint_(14); - if (this->temperature_sensor_ != nullptr) - this->temperature_sensor_->publish_state(temperature); - if (this->humidity_sensor_ != nullptr) - this->humidity_sensor_->publish_state(humidity); - // The rest of the PMS5003ST matches the PMS5003S, continue on - } - case PMSX003_TYPE_5003S: { - uint16_t formaldehyde = this->get_16_bit_uint_(28); + const uint16_t pm_particles_03um = this->get_16_bit_uint_(16); + const uint16_t pm_particles_05um = this->get_16_bit_uint_(18); + const uint16_t pm_particles_10um = this->get_16_bit_uint_(20); + const uint16_t pm_particles_25um = this->get_16_bit_uint_(22); - ESP_LOGD(TAG, "Got Formaldehyde: %u µg/m^3", formaldehyde); + ESP_LOGD(TAG, + "Got PM1.0 Standard Concentration: %u µg/m³, PM2.5 Standard Concentration %u µg/m³, PM10.0 Standard " + "Concentration: %u µg/m³, PM1.0 Concentration: %u µg/m³, PM2.5 Concentration %u µg/m³, PM10.0 " + "Concentration: %u µg/m³", + pm_1_0_std_concentration, pm_2_5_std_concentration, pm_10_0_std_concentration, pm_1_0_concentration, + pm_2_5_concentration, pm_10_0_concentration); - if (this->formaldehyde_sensor_ != nullptr) - this->formaldehyde_sensor_->publish_state(formaldehyde); - // The rest of the PMS5003S matches the PMS5003, continue on - } - case PMSX003_TYPE_X003: { - uint16_t pm_1_0_std_concentration = this->get_16_bit_uint_(4); - uint16_t pm_2_5_std_concentration = this->get_16_bit_uint_(6); - uint16_t pm_10_0_std_concentration = this->get_16_bit_uint_(8); + if (this->pm_1_0_std_sensor_ != nullptr) + this->pm_1_0_std_sensor_->publish_state(pm_1_0_std_concentration); + if (this->pm_2_5_std_sensor_ != nullptr) + this->pm_2_5_std_sensor_->publish_state(pm_2_5_std_concentration); + if (this->pm_10_0_std_sensor_ != nullptr) + this->pm_10_0_std_sensor_->publish_state(pm_10_0_std_concentration); - uint16_t pm_1_0_concentration = this->get_16_bit_uint_(10); - uint16_t pm_2_5_concentration = this->get_16_bit_uint_(12); - uint16_t pm_10_0_concentration = this->get_16_bit_uint_(14); + if (this->pm_1_0_sensor_ != nullptr) + this->pm_1_0_sensor_->publish_state(pm_1_0_concentration); + if (this->pm_2_5_sensor_ != nullptr) + this->pm_2_5_sensor_->publish_state(pm_2_5_concentration); + if (this->pm_10_0_sensor_ != nullptr) + this->pm_10_0_sensor_->publish_state(pm_10_0_concentration); - uint16_t pm_particles_03um = this->get_16_bit_uint_(16); - uint16_t pm_particles_05um = this->get_16_bit_uint_(18); - uint16_t pm_particles_10um = this->get_16_bit_uint_(20); - uint16_t pm_particles_25um = this->get_16_bit_uint_(22); - uint16_t pm_particles_50um = this->get_16_bit_uint_(24); - uint16_t pm_particles_100um = this->get_16_bit_uint_(26); + if (this->pm_particles_03um_sensor_ != nullptr) + this->pm_particles_03um_sensor_->publish_state(pm_particles_03um); + if (this->pm_particles_05um_sensor_ != nullptr) + this->pm_particles_05um_sensor_->publish_state(pm_particles_05um); + if (this->pm_particles_10um_sensor_ != nullptr) + this->pm_particles_10um_sensor_->publish_state(pm_particles_10um); + if (this->pm_particles_25um_sensor_ != nullptr) + this->pm_particles_25um_sensor_->publish_state(pm_particles_25um); - ESP_LOGD(TAG, - "Got PM1.0 Concentration: %u µg/m^3, PM2.5 Concentration %u µg/m^3, PM10.0 Concentration: %u µg/m^3", - pm_1_0_concentration, pm_2_5_concentration, pm_10_0_concentration); + if (this->type_ == PMSX003_TYPE_5003T) { + ESP_LOGD(TAG, + "Got PM0.3 Particles: %u Count/0.1L, PM0.5 Particles: %u Count/0.1L, PM1.0 Particles: %u Count/0.1L, " + "PM2.5 Particles %u Count/0.1L", + pm_particles_03um, pm_particles_05um, pm_particles_10um, pm_particles_25um); + } else { + // Note the pm particles 50um & 100um are not returned, + // as PMS5003T uses those data values for temperature and humidity. + const uint16_t pm_particles_50um = this->get_16_bit_uint_(24); + const uint16_t pm_particles_100um = this->get_16_bit_uint_(26); - if (this->pm_1_0_std_sensor_ != nullptr) - this->pm_1_0_std_sensor_->publish_state(pm_1_0_std_concentration); - if (this->pm_2_5_std_sensor_ != nullptr) - this->pm_2_5_std_sensor_->publish_state(pm_2_5_std_concentration); - if (this->pm_10_0_std_sensor_ != nullptr) - this->pm_10_0_std_sensor_->publish_state(pm_10_0_std_concentration); + ESP_LOGD(TAG, + "Got PM0.3 Particles: %u Count/0.1L, PM0.5 Particles: %u Count/0.1L, PM1.0 Particles: %u Count/0.1L, " + "PM2.5 Particles %u Count/0.1L, PM5.0 Particles: %u Count/0.1L, PM10.0 Particles %u Count/0.1L", + pm_particles_03um, pm_particles_05um, pm_particles_10um, pm_particles_25um, pm_particles_50um, + pm_particles_100um); - if (this->pm_1_0_sensor_ != nullptr) - this->pm_1_0_sensor_->publish_state(pm_1_0_concentration); - if (this->pm_2_5_sensor_ != nullptr) - this->pm_2_5_sensor_->publish_state(pm_2_5_concentration); - if (this->pm_10_0_sensor_ != nullptr) - this->pm_10_0_sensor_->publish_state(pm_10_0_concentration); + if (this->pm_particles_50um_sensor_ != nullptr) + this->pm_particles_50um_sensor_->publish_state(pm_particles_50um); + if (this->pm_particles_100um_sensor_ != nullptr) + this->pm_particles_100um_sensor_->publish_state(pm_particles_100um); + } - if (this->pm_particles_03um_sensor_ != nullptr) - this->pm_particles_03um_sensor_->publish_state(pm_particles_03um); - if (this->pm_particles_05um_sensor_ != nullptr) - this->pm_particles_05um_sensor_->publish_state(pm_particles_05um); - if (this->pm_particles_10um_sensor_ != nullptr) - this->pm_particles_10um_sensor_->publish_state(pm_particles_10um); - if (this->pm_particles_25um_sensor_ != nullptr) - this->pm_particles_25um_sensor_->publish_state(pm_particles_25um); - if (this->pm_particles_50um_sensor_ != nullptr) - this->pm_particles_50um_sensor_->publish_state(pm_particles_50um); - if (this->pm_particles_100um_sensor_ != nullptr) - this->pm_particles_100um_sensor_->publish_state(pm_particles_100um); - break; - } - case PMSX003_TYPE_5003T: { - uint16_t pm_1_0_std_concentration = this->get_16_bit_uint_(4); - uint16_t pm_2_5_std_concentration = this->get_16_bit_uint_(6); - uint16_t pm_10_0_std_concentration = this->get_16_bit_uint_(8); + // Formaldehyde + if (this->type_ == PMSX003_TYPE_5003ST || this->type_ == PMSX003_TYPE_5003S) { + const uint16_t formaldehyde = this->get_16_bit_uint_(28); - uint16_t pm_1_0_concentration = this->get_16_bit_uint_(10); - uint16_t pm_2_5_concentration = this->get_16_bit_uint_(12); - uint16_t pm_10_0_concentration = this->get_16_bit_uint_(14); + ESP_LOGD(TAG, "Got Formaldehyde: %u µg/m^3", formaldehyde); - uint16_t pm_particles_03um = this->get_16_bit_uint_(16); - uint16_t pm_particles_05um = this->get_16_bit_uint_(18); - uint16_t pm_particles_10um = this->get_16_bit_uint_(20); - uint16_t pm_particles_25um = this->get_16_bit_uint_(22); - // Note the pm particles 50um & 100um are not returned, - // as PMS5003T uses those data values for temperature and humidity. + if (this->formaldehyde_sensor_ != nullptr) + this->formaldehyde_sensor_->publish_state(formaldehyde); + } - float temperature = (int16_t) this->get_16_bit_uint_(24) / 10.0f; - float humidity = this->get_16_bit_uint_(26) / 10.0f; + // Temperature and Humidity + if (this->type_ == PMSX003_TYPE_5003ST || this->type_ == PMSX003_TYPE_5003T) { + const uint8_t temperature_offset = (this->type_ == PMSX003_TYPE_5003T) ? 24 : 30; - ESP_LOGD(TAG, - "Got PM1.0 Concentration: %u µg/m^3, PM2.5 Concentration %u µg/m^3, PM10.0 Concentration: %u µg/m^3, " - "Temperature: %.1f°C, Humidity: %.1f%%", - pm_1_0_concentration, pm_2_5_concentration, pm_10_0_concentration, temperature, humidity); + const float temperature = static_cast(this->get_16_bit_uint_(temperature_offset)) / 10.0f; + const float humidity = this->get_16_bit_uint_(temperature_offset + 2) / 10.0f; - if (this->pm_1_0_std_sensor_ != nullptr) - this->pm_1_0_std_sensor_->publish_state(pm_1_0_std_concentration); - if (this->pm_2_5_std_sensor_ != nullptr) - this->pm_2_5_std_sensor_->publish_state(pm_2_5_std_concentration); - if (this->pm_10_0_std_sensor_ != nullptr) - this->pm_10_0_std_sensor_->publish_state(pm_10_0_std_concentration); + ESP_LOGD(TAG, "Got Temperature: %.1f°C, Humidity: %.1f%%", temperature, humidity); - if (this->pm_1_0_sensor_ != nullptr) - this->pm_1_0_sensor_->publish_state(pm_1_0_concentration); - if (this->pm_2_5_sensor_ != nullptr) - this->pm_2_5_sensor_->publish_state(pm_2_5_concentration); - if (this->pm_10_0_sensor_ != nullptr) - this->pm_10_0_sensor_->publish_state(pm_10_0_concentration); + if (this->temperature_sensor_ != nullptr) + this->temperature_sensor_->publish_state(temperature); + if (this->humidity_sensor_ != nullptr) + this->humidity_sensor_->publish_state(humidity); + } - if (this->pm_particles_03um_sensor_ != nullptr) - this->pm_particles_03um_sensor_->publish_state(pm_particles_03um); - if (this->pm_particles_05um_sensor_ != nullptr) - this->pm_particles_05um_sensor_->publish_state(pm_particles_05um); - if (this->pm_particles_10um_sensor_ != nullptr) - this->pm_particles_10um_sensor_->publish_state(pm_particles_10um); - if (this->pm_particles_25um_sensor_ != nullptr) - this->pm_particles_25um_sensor_->publish_state(pm_particles_25um); + // Firmware Version and Error Code + if (this->type_ == PMSX003_TYPE_5003ST) { + const uint8_t firmware_version = this->data_[36]; + const uint8_t error_code = this->data_[37]; - if (this->temperature_sensor_ != nullptr) - this->temperature_sensor_->publish_state(temperature); - if (this->humidity_sensor_ != nullptr) - this->humidity_sensor_->publish_state(humidity); - break; - } + ESP_LOGD(TAG, "Got Firmware Version: 0x%02X, Error Code: 0x%02X", firmware_version, error_code); } // Spin down the sensor again if we aren't going to need it until more time has // passed than it takes to stabilise if (this->update_interval_ > PMS_STABILISING_MS) { - this->send_command_(PMS_CMD_ON_STANDBY, 0); + this->send_command_(PMS_CMD_SLEEP_MODE, PMS_CMD_SLEEP_MODE_SLEEP); this->state_ = PMSX003_STATE_IDLE; } this->status_clear_warning(); } + uint16_t PMSX003Component::get_16_bit_uint_(uint8_t start_index) { return (uint16_t(this->data_[start_index]) << 8) | uint16_t(this->data_[start_index + 1]); } -void PMSX003Component::dump_config() { - ESP_LOGCONFIG(TAG, "PMSX003:"); - LOG_SENSOR(" ", "PM1.0STD", this->pm_1_0_std_sensor_); - LOG_SENSOR(" ", "PM2.5STD", this->pm_2_5_std_sensor_); - LOG_SENSOR(" ", "PM10.0STD", this->pm_10_0_std_sensor_); - - LOG_SENSOR(" ", "PM1.0", this->pm_1_0_sensor_); - LOG_SENSOR(" ", "PM2.5", this->pm_2_5_sensor_); - LOG_SENSOR(" ", "PM10.0", this->pm_10_0_sensor_); - - LOG_SENSOR(" ", "PM0.3um", this->pm_particles_03um_sensor_); - LOG_SENSOR(" ", "PM0.5um", this->pm_particles_05um_sensor_); - LOG_SENSOR(" ", "PM1.0um", this->pm_particles_10um_sensor_); - LOG_SENSOR(" ", "PM2.5um", this->pm_particles_25um_sensor_); - LOG_SENSOR(" ", "PM5.0um", this->pm_particles_50um_sensor_); - LOG_SENSOR(" ", "PM10.0um", this->pm_particles_100um_sensor_); - - LOG_SENSOR(" ", "Temperature", this->temperature_sensor_); - LOG_SENSOR(" ", "Humidity", this->humidity_sensor_); - LOG_SENSOR(" ", "Formaldehyde", this->formaldehyde_sensor_); - this->check_uart_settings(9600); -} } // namespace pmsx003 } // namespace esphome diff --git a/esphome/components/pmsx003/pmsx003.h b/esphome/components/pmsx003/pmsx003.h index cb5c16aecf..85bb1ff9f3 100644 --- a/esphome/components/pmsx003/pmsx003.h +++ b/esphome/components/pmsx003/pmsx003.h @@ -7,13 +7,12 @@ namespace esphome { namespace pmsx003 { -// known command bytes -static const uint8_t PMS_CMD_AUTO_MANUAL = - 0xE1; // data=0: perform measurement manually, data=1: perform measurement automatically -static const uint8_t PMS_CMD_TRIG_MANUAL = 0xE2; // trigger a manual measurement -static const uint8_t PMS_CMD_ON_STANDBY = 0xE4; // data=0: go to standby mode, data=1: go to normal mode - -static const uint16_t PMS_STABILISING_MS = 30000; // time taken for the sensor to become stable after power on +enum PMSX0003Command : uint8_t { + PMS_CMD_MEASUREMENT_MODE = + 0xE1, // Data Options: `PMS_CMD_MEASUREMENT_MODE_PASSIVE`, `PMS_CMD_MEASUREMENT_MODE_ACTIVE` + PMS_CMD_MANUAL_MEASUREMENT = 0xE2, + PMS_CMD_SLEEP_MODE = 0xE4, // Data Options: `PMS_CMD_SLEEP_MODE_SLEEP`, `PMS_CMD_SLEEP_MODE_WAKEUP` +}; enum PMSX003Type { PMSX003_TYPE_X003 = 0, @@ -31,37 +30,53 @@ enum PMSX003State { class PMSX003Component : public uart::UARTDevice, public Component { public: PMSX003Component() = default; - void loop() override; - float get_setup_priority() const override; + float get_setup_priority() const override { return setup_priority::DATA; } void dump_config() override; + void loop() override; - void set_type(PMSX003Type type) { type_ = type; } + void set_update_interval(uint32_t update_interval) { this->update_interval_ = update_interval; } - void set_update_interval(uint32_t val) { update_interval_ = val; }; + void set_type(PMSX003Type type) { this->type_ = type; } - void set_pm_1_0_std_sensor(sensor::Sensor *pm_1_0_std_sensor); - void set_pm_2_5_std_sensor(sensor::Sensor *pm_2_5_std_sensor); - void set_pm_10_0_std_sensor(sensor::Sensor *pm_10_0_std_sensor); + void set_pm_1_0_std_sensor(sensor::Sensor *pm_1_0_std_sensor) { this->pm_1_0_std_sensor_ = pm_1_0_std_sensor; } + void set_pm_2_5_std_sensor(sensor::Sensor *pm_2_5_std_sensor) { this->pm_2_5_std_sensor_ = pm_2_5_std_sensor; } + void set_pm_10_0_std_sensor(sensor::Sensor *pm_10_0_std_sensor) { this->pm_10_0_std_sensor_ = pm_10_0_std_sensor; } - void set_pm_1_0_sensor(sensor::Sensor *pm_1_0_sensor); - void set_pm_2_5_sensor(sensor::Sensor *pm_2_5_sensor); - void set_pm_10_0_sensor(sensor::Sensor *pm_10_0_sensor); + void set_pm_1_0_sensor(sensor::Sensor *pm_1_0_sensor) { this->pm_1_0_sensor_ = pm_1_0_sensor; } + void set_pm_2_5_sensor(sensor::Sensor *pm_2_5_sensor) { this->pm_2_5_sensor_ = pm_2_5_sensor; } + void set_pm_10_0_sensor(sensor::Sensor *pm_10_0_sensor) { this->pm_10_0_sensor_ = pm_10_0_sensor; } - void set_pm_particles_03um_sensor(sensor::Sensor *pm_particles_03um_sensor); - void set_pm_particles_05um_sensor(sensor::Sensor *pm_particles_05um_sensor); - void set_pm_particles_10um_sensor(sensor::Sensor *pm_particles_10um_sensor); - void set_pm_particles_25um_sensor(sensor::Sensor *pm_particles_25um_sensor); - void set_pm_particles_50um_sensor(sensor::Sensor *pm_particles_50um_sensor); - void set_pm_particles_100um_sensor(sensor::Sensor *pm_particles_100um_sensor); + void set_pm_particles_03um_sensor(sensor::Sensor *pm_particles_03um_sensor) { + this->pm_particles_03um_sensor_ = pm_particles_03um_sensor; + } + void set_pm_particles_05um_sensor(sensor::Sensor *pm_particles_05um_sensor) { + this->pm_particles_05um_sensor_ = pm_particles_05um_sensor; + } + void set_pm_particles_10um_sensor(sensor::Sensor *pm_particles_10um_sensor) { + this->pm_particles_10um_sensor_ = pm_particles_10um_sensor; + } + void set_pm_particles_25um_sensor(sensor::Sensor *pm_particles_25um_sensor) { + this->pm_particles_25um_sensor_ = pm_particles_25um_sensor; + } + void set_pm_particles_50um_sensor(sensor::Sensor *pm_particles_50um_sensor) { + this->pm_particles_50um_sensor_ = pm_particles_50um_sensor; + } + void set_pm_particles_100um_sensor(sensor::Sensor *pm_particles_100um_sensor) { + this->pm_particles_100um_sensor_ = pm_particles_100um_sensor; + } - void set_temperature_sensor(sensor::Sensor *temperature_sensor); - void set_humidity_sensor(sensor::Sensor *humidity_sensor); - void set_formaldehyde_sensor(sensor::Sensor *formaldehyde_sensor); + void set_formaldehyde_sensor(sensor::Sensor *formaldehyde_sensor) { + this->formaldehyde_sensor_ = formaldehyde_sensor; + } + + void set_temperature_sensor(sensor::Sensor *temperature_sensor) { this->temperature_sensor_ = temperature_sensor; } + void set_humidity_sensor(sensor::Sensor *humidity_sensor) { this->humidity_sensor_ = humidity_sensor; } protected: optional check_byte_(); void parse_data_(); - void send_command_(uint8_t cmd, uint16_t data); + bool check_payload_length_(uint16_t payload_length); + void send_command_(PMSX0003Command cmd, uint16_t data); uint16_t get_16_bit_uint_(uint8_t start_index); uint8_t data_[64]; @@ -92,9 +107,12 @@ class PMSX003Component : public uart::UARTDevice, public Component { sensor::Sensor *pm_particles_50um_sensor_{nullptr}; sensor::Sensor *pm_particles_100um_sensor_{nullptr}; + // Formaldehyde + sensor::Sensor *formaldehyde_sensor_{nullptr}; + + // Temperature and Humidity sensor::Sensor *temperature_sensor_{nullptr}; sensor::Sensor *humidity_sensor_{nullptr}; - sensor::Sensor *formaldehyde_sensor_{nullptr}; }; } // namespace pmsx003 diff --git a/esphome/components/pmsx003/sensor.py b/esphome/components/pmsx003/sensor.py index 1556b3c983..bebd3a01ee 100644 --- a/esphome/components/pmsx003/sensor.py +++ b/esphome/components/pmsx003/sensor.py @@ -33,6 +33,7 @@ from esphome.const import ( UNIT_PERCENT, ) +CODEOWNERS = ["@ximex"] DEPENDENCIES = ["uart"] pmsx003_ns = cg.esphome_ns.namespace("pmsx003") @@ -57,9 +58,18 @@ SENSORS_TO_TYPE = { CONF_PM_1_0: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], CONF_PM_2_5: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], CONF_PM_10_0: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_1_0_STD: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_2_5_STD: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_10_0_STD: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_0_3UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_0_5UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_1_0UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_2_5UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_5_0UM: [TYPE_PMSX003, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_10_0UM: [TYPE_PMSX003, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_FORMALDEHYDE: [TYPE_PMS5003ST, TYPE_PMS5003S], CONF_TEMPERATURE: [TYPE_PMS5003T, TYPE_PMS5003ST], CONF_HUMIDITY: [TYPE_PMS5003T, TYPE_PMS5003ST], - CONF_FORMALDEHYDE: [TYPE_PMS5003ST, TYPE_PMS5003S], } @@ -164,6 +174,12 @@ CONFIG_SCHEMA = ( accuracy_decimals=0, state_class=STATE_CLASS_MEASUREMENT, ), + cv.Optional(CONF_FORMALDEHYDE): sensor.sensor_schema( + unit_of_measurement=UNIT_MICROGRAMS_PER_CUBIC_METER, + icon=ICON_CHEMICAL_WEAPON, + accuracy_decimals=0, + state_class=STATE_CLASS_MEASUREMENT, + ), cv.Optional(CONF_TEMPERATURE): sensor.sensor_schema( unit_of_measurement=UNIT_CELSIUS, accuracy_decimals=1, @@ -176,12 +192,6 @@ CONFIG_SCHEMA = ( device_class=DEVICE_CLASS_HUMIDITY, state_class=STATE_CLASS_MEASUREMENT, ), - cv.Optional(CONF_FORMALDEHYDE): sensor.sensor_schema( - unit_of_measurement=UNIT_MICROGRAMS_PER_CUBIC_METER, - icon=ICON_CHEMICAL_WEAPON, - accuracy_decimals=0, - state_class=STATE_CLASS_MEASUREMENT, - ), cv.Optional(CONF_UPDATE_INTERVAL, default="0s"): validate_update_interval, } ) @@ -256,6 +266,10 @@ async def to_code(config): sens = await sensor.new_sensor(config[CONF_PM_10_0UM]) cg.add(var.set_pm_particles_100um_sensor(sens)) + if CONF_FORMALDEHYDE in config: + sens = await sensor.new_sensor(config[CONF_FORMALDEHYDE]) + cg.add(var.set_formaldehyde_sensor(sens)) + if CONF_TEMPERATURE in config: sens = await sensor.new_sensor(config[CONF_TEMPERATURE]) cg.add(var.set_temperature_sensor(sens)) @@ -264,8 +278,4 @@ async def to_code(config): sens = await sensor.new_sensor(config[CONF_HUMIDITY]) cg.add(var.set_humidity_sensor(sens)) - if CONF_FORMALDEHYDE in config: - sens = await sensor.new_sensor(config[CONF_FORMALDEHYDE]) - cg.add(var.set_formaldehyde_sensor(sens)) - cg.add(var.set_update_interval(config[CONF_UPDATE_INTERVAL])) From 5f9a509bdcb80bd74fa5dcd8f7ec88cb766675d8 Mon Sep 17 00:00:00 2001 From: cvwillegen Date: Tue, 29 Apr 2025 10:21:05 +0200 Subject: [PATCH 044/102] Add code to send/receive GoBox infrared control messages. (#7554) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/components/remote_base/__init__.py | 43 ++++++ .../components/remote_base/gobox_protocol.cpp | 131 ++++++++++++++++++ .../components/remote_base/gobox_protocol.h | 54 ++++++++ .../remote_receiver/common-actions.yaml | 5 + 4 files changed, 233 insertions(+) create mode 100644 esphome/components/remote_base/gobox_protocol.cpp create mode 100644 esphome/components/remote_base/gobox_protocol.h diff --git a/esphome/components/remote_base/__init__.py b/esphome/components/remote_base/__init__.py index adacb83a30..836b98104b 100644 --- a/esphome/components/remote_base/__init__.py +++ b/esphome/components/remote_base/__init__.py @@ -929,6 +929,49 @@ async def pronto_action(var, config, args): cg.add(var.set_data(template_)) +# Gobox +( + GoboxData, + GoboxBinarySensor, + GoboxTrigger, + GoboxAction, + GoboxDumper, +) = declare_protocol("Gobox") +GOBOX_SCHEMA = cv.Schema( + { + cv.Required(CONF_CODE): cv.int_, + } +) + + +@register_binary_sensor("gobox", GoboxBinarySensor, GOBOX_SCHEMA) +def gobox_binary_sensor(var, config): + cg.add( + var.set_data( + cg.StructInitializer( + GoboxData, + ("code", config[CONF_CODE]), + ) + ) + ) + + +@register_trigger("gobox", GoboxTrigger, GoboxData) +def gobox_trigger(var, config): + pass + + +@register_dumper("gobox", GoboxDumper) +def gobox_dumper(var, config): + pass + + +@register_action("gobox", GoboxAction, GOBOX_SCHEMA) +async def gobox_action(var, config, args): + template_ = await cg.templatable(config[CONF_CODE], args, cg.int_) + cg.add(var.set_code(template_)) + + # Roomba ( RoombaData, diff --git a/esphome/components/remote_base/gobox_protocol.cpp b/esphome/components/remote_base/gobox_protocol.cpp new file mode 100644 index 0000000000..54e0dff663 --- /dev/null +++ b/esphome/components/remote_base/gobox_protocol.cpp @@ -0,0 +1,131 @@ +#include "gobox_protocol.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace remote_base { + +static const char *const TAG = "remote.gobox"; + +constexpr uint32_t BIT_MARK_US = 580; // 70us seems like a safe time delta for the receiver... +constexpr uint32_t BIT_ONE_SPACE_US = 1640; +constexpr uint32_t BIT_ZERO_SPACE_US = 545; +constexpr uint64_t HEADER = 0b011001001100010uL; // 15 bits +constexpr uint64_t HEADER_SIZE = 15; +constexpr uint64_t CODE_SIZE = 17; + +void GoboxProtocol::dump_timings_(const RawTimings &timings) const { + ESP_LOGD(TAG, "Gobox: size=%u", timings.size()); + for (int32_t timing : timings) { + ESP_LOGD(TAG, "Gobox: timing=%ld", (long) timing); + } +} + +void GoboxProtocol::encode(RemoteTransmitData *dst, const GoboxData &data) { + ESP_LOGI(TAG, "Send Gobox: code=0x%x", data.code); + dst->set_carrier_frequency(38000); + dst->reserve((HEADER_SIZE + CODE_SIZE + 1) * 2); + uint64_t code = (HEADER << CODE_SIZE) | (data.code & ((1UL << CODE_SIZE) - 1)); + ESP_LOGI(TAG, "Send Gobox: code=0x%Lx", code); + for (int16_t i = (HEADER_SIZE + CODE_SIZE - 1); i >= 0; i--) { + if (code & ((uint64_t) 1 << i)) { + dst->item(BIT_MARK_US, BIT_ONE_SPACE_US); + } else { + dst->item(BIT_MARK_US, BIT_ZERO_SPACE_US); + } + } + dst->item(BIT_MARK_US, 2000); + + dump_timings_(dst->get_data()); +} + +optional GoboxProtocol::decode(RemoteReceiveData src) { + if (src.size() < ((HEADER_SIZE + CODE_SIZE) * 2 + 1)) { + return {}; + } + + // First check for the header + uint64_t code = HEADER; + for (int16_t i = HEADER_SIZE - 1; i >= 0; i--) { + if (code & ((uint64_t) 1 << i)) { + if (!src.expect_item(BIT_MARK_US, BIT_ONE_SPACE_US)) { + return {}; + } + } else { + if (!src.expect_item(BIT_MARK_US, BIT_ZERO_SPACE_US)) { + return {}; + } + } + } + + // Next, build up the code + code = 0UL; + for (int16_t i = CODE_SIZE - 1; i >= 0; i--) { + if (!src.expect_mark(BIT_MARK_US)) { + return {}; + } + if (src.expect_space(BIT_ONE_SPACE_US)) { + code |= (1UL << i); + } else if (!src.expect_space(BIT_ZERO_SPACE_US)) { + return {}; + } + } + + if (!src.expect_mark(BIT_MARK_US)) { + return {}; + } + + dump_timings_(src.get_raw_data()); + + GoboxData out; + out.code = code; + + return out; +} + +void GoboxProtocol::dump(const GoboxData &data) { + ESP_LOGI(TAG, "Received Gobox: code=0x%x", data.code); + switch (data.code) { + case GOBOX_MENU: + ESP_LOGI(TAG, "Received Gobox: key=MENU"); + break; + case GOBOX_RETURN: + ESP_LOGI(TAG, "Received Gobox: key=RETURN"); + break; + case GOBOX_UP: + ESP_LOGI(TAG, "Received Gobox: key=UP"); + break; + case GOBOX_LEFT: + ESP_LOGI(TAG, "Received Gobox: key=LEFT"); + break; + case GOBOX_RIGHT: + ESP_LOGI(TAG, "Received Gobox: key=RIGHT"); + break; + case GOBOX_DOWN: + ESP_LOGI(TAG, "Received Gobox: key=DOWN"); + break; + case GOBOX_OK: + ESP_LOGI(TAG, "Received Gobox: key=OK"); + break; + case GOBOX_TOGGLE: + ESP_LOGI(TAG, "Received Gobox: key=TOGGLE"); + break; + case GOBOX_PROFILE: + ESP_LOGI(TAG, "Received Gobox: key=PROFILE"); + break; + case GOBOX_FASTER: + ESP_LOGI(TAG, "Received Gobox: key=FASTER"); + break; + case GOBOX_SLOWER: + ESP_LOGI(TAG, "Received Gobox: key=SLOWER"); + break; + case GOBOX_LOUDER: + ESP_LOGI(TAG, "Received Gobox: key=LOUDER"); + break; + case GOBOX_SOFTER: + ESP_LOGI(TAG, "Received Gobox: key=SOFTER"); + break; + } +} + +} // namespace remote_base +} // namespace esphome diff --git a/esphome/components/remote_base/gobox_protocol.h b/esphome/components/remote_base/gobox_protocol.h new file mode 100644 index 0000000000..7e18b61458 --- /dev/null +++ b/esphome/components/remote_base/gobox_protocol.h @@ -0,0 +1,54 @@ +#pragma once + +#include "esphome/core/component.h" +#include "remote_base.h" + +namespace esphome { +namespace remote_base { + +struct GoboxData { + int code; + bool operator==(const GoboxData &rhs) const { return code == rhs.code; } +}; + +enum { + GOBOX_MENU = 0xaa55, + GOBOX_RETURN = 0x22dd, + GOBOX_UP = 0x0af5, + GOBOX_LEFT = 0x8a75, + GOBOX_RIGHT = 0x48b7, + GOBOX_DOWN = 0xa25d, + GOBOX_OK = 0xc837, + GOBOX_TOGGLE = 0xb847, + GOBOX_PROFILE = 0xfa05, + GOBOX_FASTER = 0xf00f, + GOBOX_SLOWER = 0xd02f, + GOBOX_LOUDER = 0xb04f, + GOBOX_SOFTER = 0xf807, +}; + +class GoboxProtocol : public RemoteProtocol { + private: + void dump_timings_(const RawTimings &timings) const; + + public: + void encode(RemoteTransmitData *dst, const GoboxData &data) override; + optional decode(RemoteReceiveData src) override; + void dump(const GoboxData &data) override; +}; + +DECLARE_REMOTE_PROTOCOL(Gobox) + +template class GoboxAction : public RemoteTransmitterActionBase { + public: + TEMPLATABLE_VALUE(uint64_t, code); + + void encode(RemoteTransmitData *dst, Ts... x) override { + GoboxData data{}; + data.code = this->code_.value(x...); + GoboxProtocol().encode(dst, data); + } +}; + +} // namespace remote_base +} // namespace esphome diff --git a/tests/components/remote_receiver/common-actions.yaml b/tests/components/remote_receiver/common-actions.yaml index 08b1091116..ca7713f58a 100644 --- a/tests/components/remote_receiver/common-actions.yaml +++ b/tests/components/remote_receiver/common-actions.yaml @@ -48,6 +48,11 @@ on_drayton: - logger.log: format: "on_drayton: %u %u %u" args: ["x.address", "x.channel", "x.command"] +on_gobox: + then: + - logger.log: + format: "on_gobox: %d" + args: ["x.code"] on_jvc: then: - logger.log: From ecb91b0101874154b5273f672594e616362a3cae Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Wed, 30 Apr 2025 00:43:55 +1200 Subject: [PATCH 045/102] [bluetooth_proxy] Allow changing active/passive via api (#8649) --- esphome/components/api/api.proto | 32 +++++++ esphome/components/api/api_connection.cpp | 5 ++ esphome/components/api/api_connection.h | 1 + esphome/components/api/api_pb2.cpp | 87 +++++++++++++++++++ esphome/components/api/api_pb2.h | 35 ++++++++ esphome/components/api/api_pb2_service.cpp | 34 ++++++++ esphome/components/api/api_pb2_service.h | 12 +++ .../bluetooth_proxy/bluetooth_proxy.cpp | 29 +++++++ .../bluetooth_proxy/bluetooth_proxy.h | 6 ++ .../esp32_ble_tracker/esp32_ble_tracker.cpp | 19 ++-- .../esp32_ble_tracker/esp32_ble_tracker.h | 9 ++ 11 files changed, 262 insertions(+), 7 deletions(-) diff --git a/esphome/components/api/api.proto b/esphome/components/api/api.proto index a7e6af427f..55dc3984b0 100644 --- a/esphome/components/api/api.proto +++ b/esphome/components/api/api.proto @@ -61,6 +61,7 @@ service APIConnection { rpc bluetooth_gatt_notify(BluetoothGATTNotifyRequest) returns (void) {} rpc subscribe_bluetooth_connections_free(SubscribeBluetoothConnectionsFreeRequest) returns (BluetoothConnectionsFreeResponse) {} rpc unsubscribe_bluetooth_le_advertisements(UnsubscribeBluetoothLEAdvertisementsRequest) returns (void) {} + rpc bluetooth_scanner_set_mode(BluetoothScannerSetModeRequest) returns (void) {} rpc subscribe_voice_assistant(SubscribeVoiceAssistantRequest) returns (void) {} rpc voice_assistant_get_configuration(VoiceAssistantConfigurationRequest) returns (VoiceAssistantConfigurationResponse) {} @@ -1472,6 +1473,37 @@ message BluetoothDeviceClearCacheResponse { int32 error = 3; } +enum BluetoothScannerState { + BLUETOOTH_SCANNER_STATE_IDLE = 0; + BLUETOOTH_SCANNER_STATE_STARTING = 1; + BLUETOOTH_SCANNER_STATE_RUNNING = 2; + BLUETOOTH_SCANNER_STATE_FAILED = 3; + BLUETOOTH_SCANNER_STATE_STOPPING = 4; + BLUETOOTH_SCANNER_STATE_STOPPED = 5; +} + +enum BluetoothScannerMode { + BLUETOOTH_SCANNER_MODE_PASSIVE = 0; + BLUETOOTH_SCANNER_MODE_ACTIVE = 1; +} + +message BluetoothScannerStateResponse { + option(id) = 126; + option(source) = SOURCE_SERVER; + option(ifdef) = "USE_BLUETOOTH_PROXY"; + + BluetoothScannerState state = 1; + BluetoothScannerMode mode = 2; +} + +message BluetoothScannerSetModeRequest { + option(id) = 127; + option(source) = SOURCE_CLIENT; + option(ifdef) = "USE_BLUETOOTH_PROXY"; + + BluetoothScannerMode mode = 1; +} + // ==================== PUSH TO TALK ==================== enum VoiceAssistantSubscribeFlag { VOICE_ASSISTANT_SUBSCRIBE_NONE = 0; diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 27db953329..4670aeca63 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1475,6 +1475,11 @@ BluetoothConnectionsFreeResponse APIConnection::subscribe_bluetooth_connections_ resp.limit = bluetooth_proxy::global_bluetooth_proxy->get_bluetooth_connections_limit(); return resp; } + +void APIConnection::bluetooth_scanner_set_mode(const BluetoothScannerSetModeRequest &msg) { + bluetooth_proxy::global_bluetooth_proxy->bluetooth_scanner_set_mode( + msg.mode == enums::BluetoothScannerMode::BLUETOOTH_SCANNER_MODE_ACTIVE); +} #endif #ifdef USE_VOICE_ASSISTANT diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 09534af8dc..3fefe71cbb 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -221,6 +221,7 @@ class APIConnection : public APIServerConnection { void bluetooth_gatt_notify(const BluetoothGATTNotifyRequest &msg) override; BluetoothConnectionsFreeResponse subscribe_bluetooth_connections_free( const SubscribeBluetoothConnectionsFreeRequest &msg) override; + void bluetooth_scanner_set_mode(const BluetoothScannerSetModeRequest &msg) override; #endif #ifdef USE_HOMEASSISTANT_TIME diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp index 45d620715a..90e5bcb548 100644 --- a/esphome/components/api/api_pb2.cpp +++ b/esphome/components/api/api_pb2.cpp @@ -422,6 +422,38 @@ const char *proto_enum_to_string(enums::Bluet } #endif #ifdef HAS_PROTO_MESSAGE_DUMP +template<> const char *proto_enum_to_string(enums::BluetoothScannerState value) { + switch (value) { + case enums::BLUETOOTH_SCANNER_STATE_IDLE: + return "BLUETOOTH_SCANNER_STATE_IDLE"; + case enums::BLUETOOTH_SCANNER_STATE_STARTING: + return "BLUETOOTH_SCANNER_STATE_STARTING"; + case enums::BLUETOOTH_SCANNER_STATE_RUNNING: + return "BLUETOOTH_SCANNER_STATE_RUNNING"; + case enums::BLUETOOTH_SCANNER_STATE_FAILED: + return "BLUETOOTH_SCANNER_STATE_FAILED"; + case enums::BLUETOOTH_SCANNER_STATE_STOPPING: + return "BLUETOOTH_SCANNER_STATE_STOPPING"; + case enums::BLUETOOTH_SCANNER_STATE_STOPPED: + return "BLUETOOTH_SCANNER_STATE_STOPPED"; + default: + return "UNKNOWN"; + } +} +#endif +#ifdef HAS_PROTO_MESSAGE_DUMP +template<> const char *proto_enum_to_string(enums::BluetoothScannerMode value) { + switch (value) { + case enums::BLUETOOTH_SCANNER_MODE_PASSIVE: + return "BLUETOOTH_SCANNER_MODE_PASSIVE"; + case enums::BLUETOOTH_SCANNER_MODE_ACTIVE: + return "BLUETOOTH_SCANNER_MODE_ACTIVE"; + default: + return "UNKNOWN"; + } +} +#endif +#ifdef HAS_PROTO_MESSAGE_DUMP template<> const char *proto_enum_to_string(enums::VoiceAssistantSubscribeFlag value) { switch (value) { @@ -6775,6 +6807,61 @@ void BluetoothDeviceClearCacheResponse::dump_to(std::string &out) const { out.append("}"); } #endif +bool BluetoothScannerStateResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 1: { + this->state = value.as_enum(); + return true; + } + case 2: { + this->mode = value.as_enum(); + return true; + } + default: + return false; + } +} +void BluetoothScannerStateResponse::encode(ProtoWriteBuffer buffer) const { + buffer.encode_enum(1, this->state); + buffer.encode_enum(2, this->mode); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void BluetoothScannerStateResponse::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("BluetoothScannerStateResponse {\n"); + out.append(" state: "); + out.append(proto_enum_to_string(this->state)); + out.append("\n"); + + out.append(" mode: "); + out.append(proto_enum_to_string(this->mode)); + out.append("\n"); + out.append("}"); +} +#endif +bool BluetoothScannerSetModeRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 1: { + this->mode = value.as_enum(); + return true; + } + default: + return false; + } +} +void BluetoothScannerSetModeRequest::encode(ProtoWriteBuffer buffer) const { + buffer.encode_enum(1, this->mode); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void BluetoothScannerSetModeRequest::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("BluetoothScannerSetModeRequest {\n"); + out.append(" mode: "); + out.append(proto_enum_to_string(this->mode)); + out.append("\n"); + out.append("}"); +} +#endif bool SubscribeVoiceAssistantRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { switch (field_id) { case 1: { diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h index 383d566a16..18e4002107 100644 --- a/esphome/components/api/api_pb2.h +++ b/esphome/components/api/api_pb2.h @@ -169,6 +169,18 @@ enum BluetoothDeviceRequestType : uint32_t { BLUETOOTH_DEVICE_REQUEST_TYPE_CONNECT_V3_WITHOUT_CACHE = 5, BLUETOOTH_DEVICE_REQUEST_TYPE_CLEAR_CACHE = 6, }; +enum BluetoothScannerState : uint32_t { + BLUETOOTH_SCANNER_STATE_IDLE = 0, + BLUETOOTH_SCANNER_STATE_STARTING = 1, + BLUETOOTH_SCANNER_STATE_RUNNING = 2, + BLUETOOTH_SCANNER_STATE_FAILED = 3, + BLUETOOTH_SCANNER_STATE_STOPPING = 4, + BLUETOOTH_SCANNER_STATE_STOPPED = 5, +}; +enum BluetoothScannerMode : uint32_t { + BLUETOOTH_SCANNER_MODE_PASSIVE = 0, + BLUETOOTH_SCANNER_MODE_ACTIVE = 1, +}; enum VoiceAssistantSubscribeFlag : uint32_t { VOICE_ASSISTANT_SUBSCRIBE_NONE = 0, VOICE_ASSISTANT_SUBSCRIBE_API_AUDIO = 1, @@ -1742,6 +1754,29 @@ class BluetoothDeviceClearCacheResponse : public ProtoMessage { protected: bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; +class BluetoothScannerStateResponse : public ProtoMessage { + public: + enums::BluetoothScannerState state{}; + enums::BluetoothScannerMode mode{}; + void encode(ProtoWriteBuffer buffer) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +class BluetoothScannerSetModeRequest : public ProtoMessage { + public: + enums::BluetoothScannerMode mode{}; + void encode(ProtoWriteBuffer buffer) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; class SubscribeVoiceAssistantRequest : public ProtoMessage { public: bool subscribe{false}; diff --git a/esphome/components/api/api_pb2_service.cpp b/esphome/components/api/api_pb2_service.cpp index 8238bcf96d..dd86c9538a 100644 --- a/esphome/components/api/api_pb2_service.cpp +++ b/esphome/components/api/api_pb2_service.cpp @@ -472,6 +472,16 @@ bool APIServerConnectionBase::send_bluetooth_device_clear_cache_response(const B return this->send_message_(msg, 88); } #endif +#ifdef USE_BLUETOOTH_PROXY +bool APIServerConnectionBase::send_bluetooth_scanner_state_response(const BluetoothScannerStateResponse &msg) { +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "send_bluetooth_scanner_state_response: %s", msg.dump().c_str()); +#endif + return this->send_message_(msg, 126); +} +#endif +#ifdef USE_BLUETOOTH_PROXY +#endif #ifdef USE_VOICE_ASSISTANT #endif #ifdef USE_VOICE_ASSISTANT @@ -1212,6 +1222,17 @@ bool APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, ESP_LOGVV(TAG, "on_noise_encryption_set_key_request: %s", msg.dump().c_str()); #endif this->on_noise_encryption_set_key_request(msg); +#endif + break; + } + case 127: { +#ifdef USE_BLUETOOTH_PROXY + BluetoothScannerSetModeRequest msg; + msg.decode(msg_data, msg_size); +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "on_bluetooth_scanner_set_mode_request: %s", msg.dump().c_str()); +#endif + this->on_bluetooth_scanner_set_mode_request(msg); #endif break; } @@ -1705,6 +1726,19 @@ void APIServerConnection::on_unsubscribe_bluetooth_le_advertisements_request( this->unsubscribe_bluetooth_le_advertisements(msg); } #endif +#ifdef USE_BLUETOOTH_PROXY +void APIServerConnection::on_bluetooth_scanner_set_mode_request(const BluetoothScannerSetModeRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->bluetooth_scanner_set_mode(msg); +} +#endif #ifdef USE_VOICE_ASSISTANT void APIServerConnection::on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &msg) { if (!this->is_connection_setup()) { diff --git a/esphome/components/api/api_pb2_service.h b/esphome/components/api/api_pb2_service.h index 4a3a1da8f0..1012d8a65b 100644 --- a/esphome/components/api/api_pb2_service.h +++ b/esphome/components/api/api_pb2_service.h @@ -234,6 +234,12 @@ class APIServerConnectionBase : public ProtoService { #ifdef USE_BLUETOOTH_PROXY bool send_bluetooth_device_clear_cache_response(const BluetoothDeviceClearCacheResponse &msg); #endif +#ifdef USE_BLUETOOTH_PROXY + bool send_bluetooth_scanner_state_response(const BluetoothScannerStateResponse &msg); +#endif +#ifdef USE_BLUETOOTH_PROXY + virtual void on_bluetooth_scanner_set_mode_request(const BluetoothScannerSetModeRequest &value){}; +#endif #ifdef USE_VOICE_ASSISTANT virtual void on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &value){}; #endif @@ -440,6 +446,9 @@ class APIServerConnection : public APIServerConnectionBase { #ifdef USE_BLUETOOTH_PROXY virtual void unsubscribe_bluetooth_le_advertisements(const UnsubscribeBluetoothLEAdvertisementsRequest &msg) = 0; #endif +#ifdef USE_BLUETOOTH_PROXY + virtual void bluetooth_scanner_set_mode(const BluetoothScannerSetModeRequest &msg) = 0; +#endif #ifdef USE_VOICE_ASSISTANT virtual void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) = 0; #endif @@ -551,6 +560,9 @@ class APIServerConnection : public APIServerConnectionBase { void on_unsubscribe_bluetooth_le_advertisements_request( const UnsubscribeBluetoothLEAdvertisementsRequest &msg) override; #endif +#ifdef USE_BLUETOOTH_PROXY + void on_bluetooth_scanner_set_mode_request(const BluetoothScannerSetModeRequest &msg) override; +#endif #ifdef USE_VOICE_ASSISTANT void on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &msg) override; #endif diff --git a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp index 03213432cd..e40f4e5dcc 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp +++ b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp @@ -25,6 +25,22 @@ std::vector get_128bit_uuid_vec(esp_bt_uuid_t uuid_source) { BluetoothProxy::BluetoothProxy() { global_bluetooth_proxy = this; } +void BluetoothProxy::setup() { + this->parent_->add_scanner_state_callback([this](esp32_ble_tracker::ScannerState state) { + if (this->api_connection_ != nullptr) { + this->send_bluetooth_scanner_state_(state); + } + }); +} + +void BluetoothProxy::send_bluetooth_scanner_state_(esp32_ble_tracker::ScannerState state) { + api::BluetoothScannerStateResponse resp; + resp.state = static_cast(state); + resp.mode = this->parent_->get_scan_active() ? api::enums::BluetoothScannerMode::BLUETOOTH_SCANNER_MODE_ACTIVE + : api::enums::BluetoothScannerMode::BLUETOOTH_SCANNER_MODE_PASSIVE; + this->api_connection_->send_bluetooth_scanner_state_response(resp); +} + bool BluetoothProxy::parse_device(const esp32_ble_tracker::ESPBTDevice &device) { if (!api::global_api_server->is_connected() || this->api_connection_ == nullptr || this->raw_advertisements_) return false; @@ -453,6 +469,8 @@ void BluetoothProxy::subscribe_api_connection(api::APIConnection *api_connection this->api_connection_ = api_connection; this->raw_advertisements_ = flags & BluetoothProxySubscriptionFlag::SUBSCRIPTION_RAW_ADVERTISEMENTS; this->parent_->recalculate_advertisement_parser_types(); + + this->send_bluetooth_scanner_state_(this->parent_->get_scanner_state()); } void BluetoothProxy::unsubscribe_api_connection(api::APIConnection *api_connection) { @@ -525,6 +543,17 @@ void BluetoothProxy::send_device_unpairing(uint64_t address, bool success, esp_e this->api_connection_->send_bluetooth_device_unpairing_response(call); } +void BluetoothProxy::bluetooth_scanner_set_mode(bool active) { + if (this->parent_->get_scan_active() == active) { + return; + } + ESP_LOGD(TAG, "Setting scanner mode to %s", active ? "active" : "passive"); + this->parent_->set_scan_active(active); + this->parent_->stop_scan(); + this->parent_->set_scan_continuous( + true); // Set this to true to automatically start scanning again when it has cleaned up. +} + BluetoothProxy *global_bluetooth_proxy = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) } // namespace bluetooth_proxy diff --git a/esphome/components/bluetooth_proxy/bluetooth_proxy.h b/esphome/components/bluetooth_proxy/bluetooth_proxy.h index e0345ff248..de24165fe8 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_proxy.h +++ b/esphome/components/bluetooth_proxy/bluetooth_proxy.h @@ -41,6 +41,7 @@ enum BluetoothProxyFeature : uint32_t { FEATURE_PAIRING = 1 << 3, FEATURE_CACHE_CLEARING = 1 << 4, FEATURE_RAW_ADVERTISEMENTS = 1 << 5, + FEATURE_STATE_AND_MODE = 1 << 6, }; enum BluetoothProxySubscriptionFlag : uint32_t { @@ -53,6 +54,7 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com bool parse_device(const esp32_ble_tracker::ESPBTDevice &device) override; bool parse_devices(esp_ble_gap_cb_param_t::ble_scan_result_evt_param *advertisements, size_t count) override; void dump_config() override; + void setup() override; void loop() override; esp32_ble_tracker::AdvertisementParserType get_advertisement_parser_type() override; @@ -84,6 +86,8 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com void send_device_unpairing(uint64_t address, bool success, esp_err_t error = ESP_OK); void send_device_clear_cache(uint64_t address, bool success, esp_err_t error = ESP_OK); + void bluetooth_scanner_set_mode(bool active); + static void uint64_to_bd_addr(uint64_t address, esp_bd_addr_t bd_addr) { bd_addr[0] = (address >> 40) & 0xff; bd_addr[1] = (address >> 32) & 0xff; @@ -107,6 +111,7 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com uint32_t flags = 0; flags |= BluetoothProxyFeature::FEATURE_PASSIVE_SCAN; flags |= BluetoothProxyFeature::FEATURE_RAW_ADVERTISEMENTS; + flags |= BluetoothProxyFeature::FEATURE_STATE_AND_MODE; if (this->active_) { flags |= BluetoothProxyFeature::FEATURE_ACTIVE_CONNECTIONS; flags |= BluetoothProxyFeature::FEATURE_REMOTE_CACHING; @@ -124,6 +129,7 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com protected: void send_api_packet_(const esp32_ble_tracker::ESPBTDevice &device); + void send_bluetooth_scanner_state_(esp32_ble_tracker::ScannerState state); BluetoothConnection *get_connection_(uint64_t address, bool reserve); diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp index 34d4e6727a..0dc0f58fa2 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp @@ -245,7 +245,7 @@ void ESP32BLETracker::stop_scan_() { return; } this->cancel_timeout("scan"); - this->scanner_state_ = ScannerState::STOPPING; + this->set_scanner_state_(ScannerState::STOPPING); esp_err_t err = esp_ble_gap_stop_scanning(); if (err != ESP_OK) { ESP_LOGE(TAG, "esp_ble_gap_stop_scanning failed: %d", err); @@ -272,7 +272,7 @@ void ESP32BLETracker::start_scan_(bool first) { } return; } - this->scanner_state_ = ScannerState::STARTING; + this->set_scanner_state_(ScannerState::STARTING); ESP_LOGD(TAG, "Starting scan, set scanner state to STARTING."); if (!first) { for (auto *listener : this->listeners_) @@ -315,7 +315,7 @@ void ESP32BLETracker::end_of_scan_() { for (auto *listener : this->listeners_) listener->on_scan_end(); - this->scanner_state_ = ScannerState::IDLE; + this->set_scanner_state_(ScannerState::IDLE); } void ESP32BLETracker::register_client(ESPBTClient *client) { @@ -398,9 +398,9 @@ void ESP32BLETracker::gap_scan_start_complete_(const esp_ble_gap_cb_param_t::ble } if (param.status == ESP_BT_STATUS_SUCCESS) { this->scan_start_fail_count_ = 0; - this->scanner_state_ = ScannerState::RUNNING; + this->set_scanner_state_(ScannerState::RUNNING); } else { - this->scanner_state_ = ScannerState::FAILED; + this->set_scanner_state_(ScannerState::FAILED); if (this->scan_start_fail_count_ != std::numeric_limits::max()) { this->scan_start_fail_count_++; } @@ -422,7 +422,7 @@ void ESP32BLETracker::gap_scan_stop_complete_(const esp_ble_gap_cb_param_t::ble_ ESP_LOGE(TAG, "Scan was stopped when stop complete."); } } - this->scanner_state_ = ScannerState::STOPPED; + this->set_scanner_state_(ScannerState::STOPPED); } void ESP32BLETracker::gap_scan_result_(const esp_ble_gap_cb_param_t::ble_scan_result_evt_param ¶m) { @@ -449,7 +449,7 @@ void ESP32BLETracker::gap_scan_result_(const esp_ble_gap_cb_param_t::ble_scan_re ESP_LOGE(TAG, "Scan was stopped when scan completed."); } } - this->scanner_state_ = ScannerState::STOPPED; + this->set_scanner_state_(ScannerState::STOPPED); } } @@ -460,6 +460,11 @@ void ESP32BLETracker::gattc_event_handler(esp_gattc_cb_event_t event, esp_gatt_i } } +void ESP32BLETracker::set_scanner_state_(ScannerState state) { + this->scanner_state_ = state; + this->scanner_state_callbacks_.call(state); +} + ESPBLEiBeacon::ESPBLEiBeacon(const uint8_t *data) { memcpy(&this->beacon_data_, data, sizeof(beacon_data_)); } optional ESPBLEiBeacon::from_manufacturer_data(const ServiceData &data) { if (!data.uuid.contains(0x4C, 0x00)) diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h index 6ca763db07..ca2e53c343 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h @@ -218,6 +218,7 @@ class ESP32BLETracker : public Component, void set_scan_interval(uint32_t scan_interval) { scan_interval_ = scan_interval; } void set_scan_window(uint32_t scan_window) { scan_window_ = scan_window; } void set_scan_active(bool scan_active) { scan_active_ = scan_active; } + bool get_scan_active() const { return scan_active_; } void set_scan_continuous(bool scan_continuous) { scan_continuous_ = scan_continuous; } /// Setup the FreeRTOS task and the Bluetooth stack. @@ -241,6 +242,11 @@ class ESP32BLETracker : public Component, void gap_event_handler(esp_gap_ble_cb_event_t event, esp_ble_gap_cb_param_t *param) override; void ble_before_disabled_event_handler() override; + void add_scanner_state_callback(std::function &&callback) { + this->scanner_state_callbacks_.add(std::move(callback)); + } + ScannerState get_scanner_state() const { return this->scanner_state_; } + protected: void stop_scan_(); /// Start a single scan by setting up the parameters and doing some esp-idf calls. @@ -255,6 +261,8 @@ class ESP32BLETracker : public Component, void gap_scan_start_complete_(const esp_ble_gap_cb_param_t::ble_scan_start_cmpl_evt_param ¶m); /// Called when a `ESP_GAP_BLE_SCAN_STOP_COMPLETE_EVT` event is received. void gap_scan_stop_complete_(const esp_ble_gap_cb_param_t::ble_scan_stop_cmpl_evt_param ¶m); + /// Called to set the scanner state. Will also call callbacks to let listeners know when state is changed. + void set_scanner_state_(ScannerState state); int app_id_{0}; @@ -273,6 +281,7 @@ class ESP32BLETracker : public Component, bool scan_continuous_; bool scan_active_; ScannerState scanner_state_{ScannerState::IDLE}; + CallbackManager scanner_state_callbacks_; bool ble_was_disabled_{true}; bool raw_advertisements_{false}; bool parse_advertisements_{false}; From c756bb3b3e294f2ba1d967c05358dbcf7a920bd1 Mon Sep 17 00:00:00 2001 From: Thomas Rupprecht Date: Tue, 29 Apr 2025 21:29:04 +0200 Subject: [PATCH 046/102] [pmsa003i] code improvements (#8485) --- esphome/components/pmsa003i/pmsa003i.cpp | 59 +++++++++++++++++------- esphome/components/pmsa003i/pmsa003i.h | 50 ++++++++++---------- 2 files changed, 68 insertions(+), 41 deletions(-) diff --git a/esphome/components/pmsa003i/pmsa003i.cpp b/esphome/components/pmsa003i/pmsa003i.cpp index a9665c6a5a..36f9c9a132 100644 --- a/esphome/components/pmsa003i/pmsa003i.cpp +++ b/esphome/components/pmsa003i/pmsa003i.cpp @@ -1,5 +1,6 @@ #include "pmsa003i.h" #include "esphome/core/log.h" +#include "esphome/core/helpers.h" #include namespace esphome { @@ -7,6 +8,16 @@ namespace pmsa003i { static const char *const TAG = "pmsa003i"; +static const uint8_t COUNT_PAYLOAD_BYTES = 28; +static const uint8_t COUNT_PAYLOAD_LENGTH_BYTES = 2; +static const uint8_t COUNT_START_CHARACTER_BYTES = 2; +static const uint8_t COUNT_DATA_BYTES = COUNT_START_CHARACTER_BYTES + COUNT_PAYLOAD_LENGTH_BYTES + COUNT_PAYLOAD_BYTES; +static const uint8_t CHECKSUM_START_INDEX = COUNT_DATA_BYTES - 2; +static const uint8_t COUNT_16_BIT_VALUES = (COUNT_PAYLOAD_LENGTH_BYTES + COUNT_PAYLOAD_BYTES) / 2; +static const uint8_t START_CHARACTER_1 = 0x42; +static const uint8_t START_CHARACTER_2 = 0x4D; +static const uint8_t READ_DATA_RETRY_COUNT = 3; + void PMSA003IComponent::setup() { ESP_LOGCONFIG(TAG, "Setting up pmsa003i..."); @@ -14,7 +25,7 @@ void PMSA003IComponent::setup() { bool successful_read = this->read_data_(&data); if (!successful_read) { - for (int i = 0; i < 3; i++) { + for (uint8_t i = 0; i < READ_DATA_RETRY_COUNT; i++) { successful_read = this->read_data_(&data); if (successful_read) { break; @@ -28,7 +39,10 @@ void PMSA003IComponent::setup() { } } -void PMSA003IComponent::dump_config() { LOG_I2C_DEVICE(this); } +void PMSA003IComponent::dump_config() { + ESP_LOGCONFIG(TAG, "PMSA003I:"); + LOG_I2C_DEVICE(this); +} void PMSA003IComponent::update() { PM25AQIData data; @@ -75,35 +89,48 @@ void PMSA003IComponent::update() { } bool PMSA003IComponent::read_data_(PM25AQIData *data) { - const uint8_t num_bytes = 32; - uint8_t buffer[num_bytes]; + uint8_t buffer[COUNT_DATA_BYTES]; - this->read_bytes_raw(buffer, num_bytes); + this->read_bytes_raw(buffer, COUNT_DATA_BYTES); // https://github.com/adafruit/Adafruit_PM25AQI // Check that start byte is correct! - if (buffer[0] != 0x42) { + if (buffer[0] != START_CHARACTER_1 || buffer[1] != START_CHARACTER_2) { + ESP_LOGW(TAG, "Start character mismatch: %02X %02X != %02X %02X", buffer[0], buffer[1], START_CHARACTER_1, + START_CHARACTER_2); return false; } - // get checksum ready - int16_t sum = 0; - for (uint8_t i = 0; i < 30; i++) { - sum += buffer[i]; + const uint16_t payload_length = encode_uint16(buffer[2], buffer[3]); + if (payload_length != COUNT_PAYLOAD_BYTES) { + ESP_LOGW(TAG, "Payload length mismatch: %u != %u", payload_length, COUNT_PAYLOAD_BYTES); + return false; + } + + // Calculate checksum + uint16_t checksum = 0; + for (uint8_t i = 0; i < CHECKSUM_START_INDEX; i++) { + checksum += buffer[i]; + } + + const uint16_t check = encode_uint16(buffer[CHECKSUM_START_INDEX], buffer[CHECKSUM_START_INDEX + 1]); + if (checksum != check) { + ESP_LOGW(TAG, "Checksum mismatch: %u != %u", checksum, check); + return false; } // The data comes in endian'd, this solves it so it works on all platforms - uint16_t buffer_u16[15]; - for (uint8_t i = 0; i < 15; i++) { - buffer_u16[i] = buffer[2 + i * 2 + 1]; - buffer_u16[i] += (buffer[2 + i * 2] << 8); + uint16_t buffer_u16[COUNT_16_BIT_VALUES]; + for (uint8_t i = 0; i < COUNT_16_BIT_VALUES; i++) { + const uint8_t buffer_index = COUNT_START_CHARACTER_BYTES + i * 2; + buffer_u16[i] = encode_uint16(buffer[buffer_index], buffer[buffer_index + 1]); } // put it into a nice struct :) - memcpy((void *) data, (void *) buffer_u16, 30); + memcpy((void *) data, (void *) buffer_u16, COUNT_16_BIT_VALUES * 2); - return (sum == data->checksum); + return true; } } // namespace pmsa003i diff --git a/esphome/components/pmsa003i/pmsa003i.h b/esphome/components/pmsa003i/pmsa003i.h index 1fe4139951..59f39a7314 100644 --- a/esphome/components/pmsa003i/pmsa003i.h +++ b/esphome/components/pmsa003i/pmsa003i.h @@ -10,21 +10,21 @@ namespace pmsa003i { /**! Structure holding Plantower's standard packet **/ // From https://github.com/adafruit/Adafruit_PM25AQI struct PM25AQIData { - uint16_t framelen; ///< How long this data chunk is - uint16_t pm10_standard, ///< Standard PM1.0 - pm25_standard, ///< Standard PM2.5 - pm100_standard; ///< Standard PM10.0 - uint16_t pm10_env, ///< Environmental PM1.0 - pm25_env, ///< Environmental PM2.5 - pm100_env; ///< Environmental PM10.0 - uint16_t particles_03um, ///> 0.3um Particle Count - particles_05um, ///> 0.5um Particle Count - particles_10um, ///> 1.0um Particle Count - particles_25um, ///> 2.5um Particle Count - particles_50um, ///> 5.0um Particle Count - particles_100um; ///> 10.0um Particle Count - uint16_t unused; ///< Unused - uint16_t checksum; ///< Packet checksum + uint16_t framelen; ///< How long this data chunk is + uint16_t pm10_standard; ///< Standard PM1.0 + uint16_t pm25_standard; ///< Standard PM2.5 + uint16_t pm100_standard; ///< Standard PM10.0 + uint16_t pm10_env; ///< Environmental PM1.0 + uint16_t pm25_env; ///< Environmental PM2.5 + uint16_t pm100_env; ///< Environmental PM10.0 + uint16_t particles_03um; ///< 0.3um Particle Count + uint16_t particles_05um; ///< 0.5um Particle Count + uint16_t particles_10um; ///< 1.0um Particle Count + uint16_t particles_25um; ///< 2.5um Particle Count + uint16_t particles_50um; ///< 5.0um Particle Count + uint16_t particles_100um; ///< 10.0um Particle Count + uint16_t unused; ///< Unused + uint16_t checksum; ///< Packet checksum }; class PMSA003IComponent : public PollingComponent, public i2c::I2CDevice { @@ -34,18 +34,18 @@ class PMSA003IComponent : public PollingComponent, public i2c::I2CDevice { void update() override; float get_setup_priority() const override { return setup_priority::DATA; } - void set_standard_units(bool standard_units) { standard_units_ = standard_units; } + void set_standard_units(bool standard_units) { this->standard_units_ = standard_units; } - void set_pm_1_0_sensor(sensor::Sensor *pm_1_0) { pm_1_0_sensor_ = pm_1_0; } - void set_pm_2_5_sensor(sensor::Sensor *pm_2_5) { pm_2_5_sensor_ = pm_2_5; } - void set_pm_10_0_sensor(sensor::Sensor *pm_10_0) { pm_10_0_sensor_ = pm_10_0; } + void set_pm_1_0_sensor(sensor::Sensor *pm_1_0) { this->pm_1_0_sensor_ = pm_1_0; } + void set_pm_2_5_sensor(sensor::Sensor *pm_2_5) { this->pm_2_5_sensor_ = pm_2_5; } + void set_pm_10_0_sensor(sensor::Sensor *pm_10_0) { this->pm_10_0_sensor_ = pm_10_0; } - void set_pmc_0_3_sensor(sensor::Sensor *pmc_0_3) { pmc_0_3_sensor_ = pmc_0_3; } - void set_pmc_0_5_sensor(sensor::Sensor *pmc_0_5) { pmc_0_5_sensor_ = pmc_0_5; } - void set_pmc_1_0_sensor(sensor::Sensor *pmc_1_0) { pmc_1_0_sensor_ = pmc_1_0; } - void set_pmc_2_5_sensor(sensor::Sensor *pmc_2_5) { pmc_2_5_sensor_ = pmc_2_5; } - void set_pmc_5_0_sensor(sensor::Sensor *pmc_5_0) { pmc_5_0_sensor_ = pmc_5_0; } - void set_pmc_10_0_sensor(sensor::Sensor *pmc_10_0) { pmc_10_0_sensor_ = pmc_10_0; } + void set_pmc_0_3_sensor(sensor::Sensor *pmc_0_3) { this->pmc_0_3_sensor_ = pmc_0_3; } + void set_pmc_0_5_sensor(sensor::Sensor *pmc_0_5) { this->pmc_0_5_sensor_ = pmc_0_5; } + void set_pmc_1_0_sensor(sensor::Sensor *pmc_1_0) { this->pmc_1_0_sensor_ = pmc_1_0; } + void set_pmc_2_5_sensor(sensor::Sensor *pmc_2_5) { this->pmc_2_5_sensor_ = pmc_2_5; } + void set_pmc_5_0_sensor(sensor::Sensor *pmc_5_0) { this->pmc_5_0_sensor_ = pmc_5_0; } + void set_pmc_10_0_sensor(sensor::Sensor *pmc_10_0) { this->pmc_10_0_sensor_ = pmc_10_0; } protected: bool read_data_(PM25AQIData *data); From 0fe6c65ba37d2919b24bb8bc7191ddaa4defe219 Mon Sep 17 00:00:00 2001 From: Thomas Rupprecht Date: Tue, 29 Apr 2025 22:08:08 +0200 Subject: [PATCH 047/102] [adc] sort variants and add links to reference implementations (#8327) --- esphome/components/adc/__init__.py | 93 ++++++++++++++++++------------ 1 file changed, 55 insertions(+), 38 deletions(-) diff --git a/esphome/components/adc/__init__.py b/esphome/components/adc/__init__.py index be420475fb..5f94c61a08 100644 --- a/esphome/components/adc/__init__.py +++ b/esphome/components/adc/__init__.py @@ -47,9 +47,10 @@ SAMPLING_MODES = { adc1_channel_t = cg.global_ns.enum("adc1_channel_t") adc2_channel_t = cg.global_ns.enum("adc2_channel_t") -# From https://github.com/espressif/esp-idf/blob/master/components/driver/include/driver/adc_common.h # pin to adc1 channel mapping +# https://github.com/espressif/esp-idf/blob/v4.4.8/components/driver/include/driver/adc.h ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32/include/soc/adc_channel.h VARIANT_ESP32: { 36: adc1_channel_t.ADC1_CHANNEL_0, 37: adc1_channel_t.ADC1_CHANNEL_1, @@ -60,6 +61,41 @@ ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { 34: adc1_channel_t.ADC1_CHANNEL_6, 35: adc1_channel_t.ADC1_CHANNEL_7, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c2/include/soc/adc_channel.h + VARIANT_ESP32C2: { + 0: adc1_channel_t.ADC1_CHANNEL_0, + 1: adc1_channel_t.ADC1_CHANNEL_1, + 2: adc1_channel_t.ADC1_CHANNEL_2, + 3: adc1_channel_t.ADC1_CHANNEL_3, + 4: adc1_channel_t.ADC1_CHANNEL_4, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c3/include/soc/adc_channel.h + VARIANT_ESP32C3: { + 0: adc1_channel_t.ADC1_CHANNEL_0, + 1: adc1_channel_t.ADC1_CHANNEL_1, + 2: adc1_channel_t.ADC1_CHANNEL_2, + 3: adc1_channel_t.ADC1_CHANNEL_3, + 4: adc1_channel_t.ADC1_CHANNEL_4, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c6/include/soc/adc_channel.h + VARIANT_ESP32C6: { + 0: adc1_channel_t.ADC1_CHANNEL_0, + 1: adc1_channel_t.ADC1_CHANNEL_1, + 2: adc1_channel_t.ADC1_CHANNEL_2, + 3: adc1_channel_t.ADC1_CHANNEL_3, + 4: adc1_channel_t.ADC1_CHANNEL_4, + 5: adc1_channel_t.ADC1_CHANNEL_5, + 6: adc1_channel_t.ADC1_CHANNEL_6, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32h2/include/soc/adc_channel.h + VARIANT_ESP32H2: { + 1: adc1_channel_t.ADC1_CHANNEL_0, + 2: adc1_channel_t.ADC1_CHANNEL_1, + 3: adc1_channel_t.ADC1_CHANNEL_2, + 4: adc1_channel_t.ADC1_CHANNEL_3, + 5: adc1_channel_t.ADC1_CHANNEL_4, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s2/include/soc/adc_channel.h VARIANT_ESP32S2: { 1: adc1_channel_t.ADC1_CHANNEL_0, 2: adc1_channel_t.ADC1_CHANNEL_1, @@ -72,6 +108,7 @@ ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { 9: adc1_channel_t.ADC1_CHANNEL_8, 10: adc1_channel_t.ADC1_CHANNEL_9, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s3/include/soc/adc_channel.h VARIANT_ESP32S3: { 1: adc1_channel_t.ADC1_CHANNEL_0, 2: adc1_channel_t.ADC1_CHANNEL_1, @@ -84,40 +121,12 @@ ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { 9: adc1_channel_t.ADC1_CHANNEL_8, 10: adc1_channel_t.ADC1_CHANNEL_9, }, - VARIANT_ESP32C3: { - 0: adc1_channel_t.ADC1_CHANNEL_0, - 1: adc1_channel_t.ADC1_CHANNEL_1, - 2: adc1_channel_t.ADC1_CHANNEL_2, - 3: adc1_channel_t.ADC1_CHANNEL_3, - 4: adc1_channel_t.ADC1_CHANNEL_4, - }, - VARIANT_ESP32C2: { - 0: adc1_channel_t.ADC1_CHANNEL_0, - 1: adc1_channel_t.ADC1_CHANNEL_1, - 2: adc1_channel_t.ADC1_CHANNEL_2, - 3: adc1_channel_t.ADC1_CHANNEL_3, - 4: adc1_channel_t.ADC1_CHANNEL_4, - }, - VARIANT_ESP32C6: { - 0: adc1_channel_t.ADC1_CHANNEL_0, - 1: adc1_channel_t.ADC1_CHANNEL_1, - 2: adc1_channel_t.ADC1_CHANNEL_2, - 3: adc1_channel_t.ADC1_CHANNEL_3, - 4: adc1_channel_t.ADC1_CHANNEL_4, - 5: adc1_channel_t.ADC1_CHANNEL_5, - 6: adc1_channel_t.ADC1_CHANNEL_6, - }, - VARIANT_ESP32H2: { - 1: adc1_channel_t.ADC1_CHANNEL_0, - 2: adc1_channel_t.ADC1_CHANNEL_1, - 3: adc1_channel_t.ADC1_CHANNEL_2, - 4: adc1_channel_t.ADC1_CHANNEL_3, - 5: adc1_channel_t.ADC1_CHANNEL_4, - }, } +# pin to adc2 channel mapping +# https://github.com/espressif/esp-idf/blob/v4.4.8/components/driver/include/driver/adc.h ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { - # TODO: add other variants + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32/include/soc/adc_channel.h VARIANT_ESP32: { 4: adc2_channel_t.ADC2_CHANNEL_0, 0: adc2_channel_t.ADC2_CHANNEL_1, @@ -130,6 +139,19 @@ ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { 25: adc2_channel_t.ADC2_CHANNEL_8, 26: adc2_channel_t.ADC2_CHANNEL_9, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c2/include/soc/adc_channel.h + VARIANT_ESP32C2: { + 5: adc2_channel_t.ADC2_CHANNEL_0, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c3/include/soc/adc_channel.h + VARIANT_ESP32C3: { + 5: adc2_channel_t.ADC2_CHANNEL_0, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c6/include/soc/adc_channel.h + VARIANT_ESP32C6: {}, # no ADC2 + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32h2/include/soc/adc_channel.h + VARIANT_ESP32H2: {}, # no ADC2 + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s2/include/soc/adc_channel.h VARIANT_ESP32S2: { 11: adc2_channel_t.ADC2_CHANNEL_0, 12: adc2_channel_t.ADC2_CHANNEL_1, @@ -142,6 +164,7 @@ ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { 19: adc2_channel_t.ADC2_CHANNEL_8, 20: adc2_channel_t.ADC2_CHANNEL_9, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s3/include/soc/adc_channel.h VARIANT_ESP32S3: { 11: adc2_channel_t.ADC2_CHANNEL_0, 12: adc2_channel_t.ADC2_CHANNEL_1, @@ -154,12 +177,6 @@ ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { 19: adc2_channel_t.ADC2_CHANNEL_8, 20: adc2_channel_t.ADC2_CHANNEL_9, }, - VARIANT_ESP32C3: { - 5: adc2_channel_t.ADC2_CHANNEL_0, - }, - VARIANT_ESP32C2: {}, - VARIANT_ESP32C6: {}, - VARIANT_ESP32H2: {}, } From 9f629dcaa245053d313f9db26c778ca33c27541c Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Tue, 29 Apr 2025 17:27:03 -0500 Subject: [PATCH 048/102] [i2s_audio, microphone, micro_wake_word, voice_assistant] Use microphone source to process incoming audio (#8645) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- .../i2s_audio/microphone/__init__.py | 45 ++++++++-- .../microphone/i2s_audio_microphone.cpp | 87 +++++++++---------- .../microphone/i2s_audio_microphone.h | 4 +- .../components/micro_wake_word/__init__.py | 27 +++++- .../micro_wake_word/micro_wake_word.cpp | 14 +-- .../micro_wake_word/micro_wake_word.h | 8 +- esphome/components/microphone/__init__.py | 8 +- esphome/components/microphone/automation.h | 4 +- esphome/components/microphone/microphone.h | 5 +- .../microphone/microphone_source.cpp | 4 +- .../components/voice_assistant/__init__.py | 26 +++++- .../voice_assistant/voice_assistant.cpp | 20 ++--- .../voice_assistant/voice_assistant.h | 6 +- tests/components/micro_wake_word/common.yaml | 1 + tests/components/voice_assistant/common.yaml | 5 +- 15 files changed, 166 insertions(+), 98 deletions(-) diff --git a/esphome/components/i2s_audio/microphone/__init__.py b/esphome/components/i2s_audio/microphone/__init__.py index 4950a25751..06eb29986d 100644 --- a/esphome/components/i2s_audio/microphone/__init__.py +++ b/esphome/components/i2s_audio/microphone/__init__.py @@ -1,13 +1,20 @@ from esphome import pins import esphome.codegen as cg -from esphome.components import esp32, microphone +from esphome.components import audio, esp32, microphone from esphome.components.adc import ESP32_VARIANT_ADC1_PIN_TO_CHANNEL, validate_adc_pin import esphome.config_validation as cv -from esphome.const import CONF_ID, CONF_NUMBER +from esphome.const import ( + CONF_BITS_PER_SAMPLE, + CONF_CHANNEL, + CONF_ID, + CONF_NUM_CHANNELS, + CONF_NUMBER, + CONF_SAMPLE_RATE, +) from .. import ( - CONF_CHANNEL, CONF_I2S_DIN_PIN, + CONF_LEFT, CONF_MONO, CONF_RIGHT, I2SAudioIn, @@ -32,7 +39,7 @@ INTERNAL_ADC_VARIANTS = [esp32.const.VARIANT_ESP32] PDM_VARIANTS = [esp32.const.VARIANT_ESP32, esp32.const.VARIANT_ESP32S3] -def validate_esp32_variant(config): +def _validate_esp32_variant(config): variant = esp32.get_esp32_variant() if config[CONF_ADC_TYPE] == "external": if config[CONF_PDM]: @@ -46,12 +53,34 @@ def validate_esp32_variant(config): raise NotImplementedError -def validate_channel(config): +def _validate_channel(config): if config[CONF_CHANNEL] == CONF_MONO: raise cv.Invalid(f"I2S microphone does not support {CONF_MONO}.") return config +def _set_num_channels_from_config(config): + if config[CONF_CHANNEL] in (CONF_LEFT, CONF_RIGHT): + config[CONF_NUM_CHANNELS] = 1 + else: + config[CONF_NUM_CHANNELS] = 2 + + return config + + +def _set_stream_limits(config): + audio.set_stream_limits( + min_bits_per_sample=config.get(CONF_BITS_PER_SAMPLE), + max_bits_per_sample=config.get(CONF_BITS_PER_SAMPLE), + min_channels=config.get(CONF_NUM_CHANNELS), + max_channels=config.get(CONF_NUM_CHANNELS), + min_sample_rate=config.get(CONF_SAMPLE_RATE), + max_sample_rate=config.get(CONF_SAMPLE_RATE), + )(config) + + return config + + BASE_SCHEMA = microphone.MICROPHONE_SCHEMA.extend( i2s_audio_component_schema( I2SAudioMicrophone, @@ -79,8 +108,10 @@ CONFIG_SCHEMA = cv.All( }, key=CONF_ADC_TYPE, ), - validate_esp32_variant, - validate_channel, + _validate_esp32_variant, + _validate_channel, + _set_num_channels_from_config, + _set_stream_limits, ) diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp index 3ab3c88142..78a7f92c2f 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp @@ -56,6 +56,35 @@ void I2SAudioMicrophone::start_() { } esp_err_t err; + uint8_t channel_count = 1; +#ifdef USE_I2S_LEGACY + uint8_t bits_per_sample = this->bits_per_sample_; + + if (this->channel_ == I2S_CHANNEL_FMT_RIGHT_LEFT) { + channel_count = 2; + } +#else + if (this->slot_bit_width_ == I2S_SLOT_BIT_WIDTH_AUTO) { + this->slot_bit_width_ = I2S_SLOT_BIT_WIDTH_16BIT; + } + uint8_t bits_per_sample = this->slot_bit_width_; + + if (this->slot_mode_ == I2S_SLOT_MODE_STEREO) { + channel_count = 2; + } +#endif + +#ifdef USE_ESP32_VARIANT_ESP32 + // ESP32 reads audio aligned to a multiple of 2 bytes. For example, if configured for 24 bits per sample, then it will + // produce 32 bits per sample, where the actual data is in the most significant bits. Other ESP32 variants produce 24 + // bits per sample in this situation. + if (bits_per_sample < 16) { + bits_per_sample = 16; + } else if ((bits_per_sample > 16) && (bits_per_sample <= 32)) { + bits_per_sample = 32; + } +#endif + #ifdef USE_I2S_LEGACY i2s_driver_config_t config = { .mode = (i2s_mode_t) (this->i2s_mode_ | I2S_MODE_RX), @@ -144,6 +173,8 @@ void I2SAudioMicrophone::start_() { i2s_std_gpio_config_t pin_config = this->parent_->get_pin_config(); #if SOC_I2S_SUPPORTS_PDM_RX if (this->pdm_) { + bits_per_sample = 16; // PDM mics are always 16 bits per sample with the IDF 5 driver + i2s_pdm_rx_clk_config_t clk_cfg = { .sample_rate_hz = this->sample_rate_, .clk_src = clk_src, @@ -187,13 +218,8 @@ void I2SAudioMicrophone::start_() { .clk_src = clk_src, .mclk_multiple = I2S_MCLK_MULTIPLE_256, }; - i2s_data_bit_width_t data_bit_width; - if (this->slot_bit_width_ != I2S_SLOT_BIT_WIDTH_8BIT) { - data_bit_width = I2S_DATA_BIT_WIDTH_16BIT; - } else { - data_bit_width = I2S_DATA_BIT_WIDTH_8BIT; - } - i2s_std_slot_config_t std_slot_cfg = I2S_STD_PHILIPS_SLOT_DEFAULT_CONFIG(data_bit_width, this->slot_mode_); + i2s_std_slot_config_t std_slot_cfg = + I2S_STD_PHILIPS_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) this->slot_bit_width_, this->slot_mode_); std_slot_cfg.slot_bit_width = this->slot_bit_width_; std_slot_cfg.slot_mask = this->std_slot_mask_; @@ -222,6 +248,8 @@ void I2SAudioMicrophone::start_() { } #endif + this->audio_stream_info_ = audio::AudioStreamInfo(bits_per_sample, channel_count, this->sample_rate_); + this->state_ = microphone::STATE_RUNNING; this->high_freq_.start(); this->status_clear_error(); @@ -284,7 +312,7 @@ void I2SAudioMicrophone::stop_() { this->status_clear_error(); } -size_t I2SAudioMicrophone::read(int16_t *buf, size_t len, TickType_t ticks_to_wait) { +size_t I2SAudioMicrophone::read_(uint8_t *buf, size_t len, TickType_t ticks_to_wait) { size_t bytes_read = 0; #ifdef USE_I2S_LEGACY esp_err_t err = i2s_read(this->parent_->get_port(), buf, len, &bytes_read, ticks_to_wait); @@ -303,38 +331,7 @@ size_t I2SAudioMicrophone::read(int16_t *buf, size_t len, TickType_t ticks_to_wa return 0; } this->status_clear_warning(); - // ESP-IDF I2S implementation right-extends 8-bit data to 16 bits, - // and 24-bit data to 32 bits. -#ifdef USE_I2S_LEGACY - switch (this->bits_per_sample_) { - case I2S_BITS_PER_SAMPLE_8BIT: - case I2S_BITS_PER_SAMPLE_16BIT: - return bytes_read; - case I2S_BITS_PER_SAMPLE_24BIT: - case I2S_BITS_PER_SAMPLE_32BIT: { - size_t samples_read = bytes_read / sizeof(int32_t); - for (size_t i = 0; i < samples_read; i++) { - int32_t temp = reinterpret_cast(buf)[i] >> 14; - buf[i] = clamp(temp, INT16_MIN, INT16_MAX); - } - return samples_read * sizeof(int16_t); - } - default: - ESP_LOGE(TAG, "Unsupported bits per sample: %d", this->bits_per_sample_); - return 0; - } -#else -#ifndef USE_ESP32_VARIANT_ESP32 - // For newer ESP32 variants 8 bit data needs to be extended to 16 bit. - if (this->slot_bit_width_ == I2S_SLOT_BIT_WIDTH_8BIT) { - size_t samples_read = bytes_read / sizeof(int8_t); - for (size_t i = samples_read - 1; i >= 0; i--) { - int16_t temp = static_cast(reinterpret_cast(buf)[i]) << 8; - buf[i] = temp; - } - return samples_read * sizeof(int16_t); - } -#else +#if defined(USE_ESP32_VARIANT_ESP32) and not defined(USE_I2S_LEGACY) // For ESP32 8/16 bit standard mono mode samples need to be switched. if (this->slot_mode_ == I2S_SLOT_MODE_MONO && this->slot_bit_width_ <= 16 && !this->pdm_) { size_t samples_read = bytes_read / sizeof(int16_t); @@ -346,14 +343,14 @@ size_t I2SAudioMicrophone::read(int16_t *buf, size_t len, TickType_t ticks_to_wa } #endif return bytes_read; -#endif } void I2SAudioMicrophone::read_() { - std::vector samples; - samples.resize(BUFFER_SIZE); - size_t bytes_read = this->read(samples.data(), BUFFER_SIZE * sizeof(int16_t), 0); - samples.resize(bytes_read / sizeof(int16_t)); + std::vector samples; + const size_t bytes_to_read = this->audio_stream_info_.ms_to_bytes(32); + samples.resize(bytes_to_read); + size_t bytes_read = this->read_(samples.data(), bytes_to_read, 0); + samples.resize(bytes_read); this->data_callbacks_.call(samples); } diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h index 2dbacb447e..072d312e0f 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h @@ -25,9 +25,6 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub void set_pdm(bool pdm) { this->pdm_ = pdm; } - size_t read(int16_t *buf, size_t len, TickType_t ticks_to_wait); - size_t read(int16_t *buf, size_t len) override { return this->read(buf, len, pdMS_TO_TICKS(100)); } - #ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC void set_adc_channel(adc1_channel_t channel) { @@ -41,6 +38,7 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub void start_(); void stop_(); void read_(); + size_t read_(uint8_t *buf, size_t len, TickType_t ticks_to_wait); #ifdef USE_I2S_LEGACY int8_t din_pin_{I2S_PIN_NO_CHANGE}; diff --git a/esphome/components/micro_wake_word/__init__.py b/esphome/components/micro_wake_word/__init__.py index 0862406e46..9d5caca937 100644 --- a/esphome/components/micro_wake_word/__init__.py +++ b/esphome/components/micro_wake_word/__init__.py @@ -328,7 +328,14 @@ CONFIG_SCHEMA = cv.All( cv.Schema( { cv.GenerateID(): cv.declare_id(MicroWakeWord), - cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone), + cv.Optional( + CONF_MICROPHONE, default={} + ): microphone.microphone_source_schema( + min_bits_per_sample=16, + max_bits_per_sample=16, + min_channels=1, + max_channels=1, + ), cv.Required(CONF_MODELS): cv.ensure_list( cv.maybe_simple_value(MODEL_SCHEMA, key=CONF_MODEL) ), @@ -404,15 +411,27 @@ def _feature_step_size_validate(config): raise cv.Invalid("Cannot load models with different features step sizes.") -FINAL_VALIDATE_SCHEMA = _feature_step_size_validate +FINAL_VALIDATE_SCHEMA = cv.All( + cv.Schema( + { + cv.Required( + CONF_MICROPHONE + ): microphone.final_validate_microphone_source_schema( + "micro_wake_word", sample_rate=16000 + ), + }, + extra=cv.ALLOW_EXTRA, + ), + _feature_step_size_validate, +) async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) - mic = await cg.get_variable(config[CONF_MICROPHONE]) - cg.add(var.set_microphone(mic)) + mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) + cg.add(var.set_microphone_source(mic_source)) esp32.add_idf_component( name="esp-tflite-micro", diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index 533aa9fb75..dd1a8be378 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -61,7 +61,7 @@ void MicroWakeWord::dump_config() { void MicroWakeWord::setup() { ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); - this->microphone_->add_data_callback([this](const std::vector &data) { + this->microphone_source_->add_data_callback([this](const std::vector &data) { if (this->state_ != State::DETECTING_WAKE_WORD) { return; } @@ -71,7 +71,7 @@ void MicroWakeWord::setup() { size_t bytes_free = temp_ring_buffer->free(); - if (bytes_free < data.size() * sizeof(int16_t)) { + if (bytes_free < data.size()) { ESP_LOGW( TAG, "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " @@ -80,7 +80,7 @@ void MicroWakeWord::setup() { temp_ring_buffer->reset(); } - temp_ring_buffer->write((void *) data.data(), data.size() * sizeof(int16_t)); + temp_ring_buffer->write((void *) data.data(), data.size()); } }); @@ -128,11 +128,11 @@ void MicroWakeWord::loop() { break; case State::START_MICROPHONE: ESP_LOGD(TAG, "Starting Microphone"); - this->microphone_->start(); + this->microphone_source_->start(); this->set_state_(State::STARTING_MICROPHONE); break; case State::STARTING_MICROPHONE: - if (this->microphone_->is_running()) { + if (this->microphone_source_->is_running()) { this->set_state_(State::DETECTING_WAKE_WORD); } break; @@ -148,13 +148,13 @@ void MicroWakeWord::loop() { break; case State::STOP_MICROPHONE: ESP_LOGD(TAG, "Stopping Microphone"); - this->microphone_->stop(); + this->microphone_source_->stop(); this->set_state_(State::STOPPING_MICROPHONE); this->unload_models_(); this->deallocate_buffers_(); break; case State::STOPPING_MICROPHONE: - if (this->microphone_->is_stopped()) { + if (this->microphone_source_->is_stopped()) { this->set_state_(State::IDLE); if (this->detected_) { this->wake_word_detected_trigger_->trigger(this->detected_wake_word_); diff --git a/esphome/components/micro_wake_word/micro_wake_word.h b/esphome/components/micro_wake_word/micro_wake_word.h index 443911b1e4..b06d35ca1f 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.h +++ b/esphome/components/micro_wake_word/micro_wake_word.h @@ -9,7 +9,7 @@ #include "esphome/core/component.h" #include "esphome/core/ring_buffer.h" -#include "esphome/components/microphone/microphone.h" +#include "esphome/components/microphone/microphone_source.h" #include @@ -46,7 +46,9 @@ class MicroWakeWord : public Component { void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; } - void set_microphone(microphone::Microphone *microphone) { this->microphone_ = microphone; } + void set_microphone_source(microphone::MicrophoneSource *microphone_source) { + this->microphone_source_ = microphone_source; + } Trigger *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; } @@ -59,7 +61,7 @@ class MicroWakeWord : public Component { #endif protected: - microphone::Microphone *microphone_{nullptr}; + microphone::MicrophoneSource *microphone_source_{nullptr}; Trigger *wake_word_detected_trigger_ = new Trigger(); State state_{State::IDLE}; diff --git a/esphome/components/microphone/__init__.py b/esphome/components/microphone/__init__.py index b9d24bc4a7..dcae513578 100644 --- a/esphome/components/microphone/__init__.py +++ b/esphome/components/microphone/__init__.py @@ -36,7 +36,7 @@ StopCaptureAction = microphone_ns.class_( DataTrigger = microphone_ns.class_( "DataTrigger", - automation.Trigger.template(cg.std_vector.template(cg.int16).operator("ref")), + automation.Trigger.template(cg.std_vector.template(cg.uint8).operator("ref")), ) IsCapturingCondition = microphone_ns.class_( @@ -98,10 +98,11 @@ def microphone_source_schema( return config return cv.All( - cv.maybe_simple_value( + automation.maybe_conf( + CONF_MICROPHONE, { cv.GenerateID(CONF_ID): cv.declare_id(MicrophoneSource), - cv.Required(CONF_MICROPHONE): cv.use_id(Microphone), + cv.GenerateID(CONF_MICROPHONE): cv.use_id(Microphone), cv.Optional(CONF_BITS_PER_SAMPLE, default=16): cv.int_range( min_bits_per_sample, max_bits_per_sample ), @@ -112,7 +113,6 @@ def microphone_source_schema( ), cv.Optional(CONF_GAIN_FACTOR, default="1"): cv.int_range(1, 64), }, - key=CONF_MICROPHONE, ), ) diff --git a/esphome/components/microphone/automation.h b/esphome/components/microphone/automation.h index 29c0ec5df2..324699c0af 100644 --- a/esphome/components/microphone/automation.h +++ b/esphome/components/microphone/automation.h @@ -16,10 +16,10 @@ template class StopCaptureAction : public Action, public void play(Ts... x) override { this->parent_->stop(); } }; -class DataTrigger : public Trigger &> { +class DataTrigger : public Trigger &> { public: explicit DataTrigger(Microphone *mic) { - mic->add_data_callback([this](const std::vector &data) { this->trigger(data); }); + mic->add_data_callback([this](const std::vector &data) { this->trigger(data); }); } }; diff --git a/esphome/components/microphone/microphone.h b/esphome/components/microphone/microphone.h index 58552aa34a..cef8d0f4c3 100644 --- a/esphome/components/microphone/microphone.h +++ b/esphome/components/microphone/microphone.h @@ -22,10 +22,9 @@ class Microphone { public: virtual void start() = 0; virtual void stop() = 0; - void add_data_callback(std::function &)> &&data_callback) { + void add_data_callback(std::function &)> &&data_callback) { this->data_callbacks_.add(std::move(data_callback)); } - virtual size_t read(int16_t *buf, size_t len) = 0; bool is_running() const { return this->state_ == STATE_RUNNING; } bool is_stopped() const { return this->state_ == STATE_STOPPED; } @@ -37,7 +36,7 @@ class Microphone { audio::AudioStreamInfo audio_stream_info_; - CallbackManager &)> data_callbacks_{}; + CallbackManager &)> data_callbacks_{}; }; } // namespace microphone diff --git a/esphome/components/microphone/microphone_source.cpp b/esphome/components/microphone/microphone_source.cpp index 7e397348b9..dcd3b31622 100644 --- a/esphome/components/microphone/microphone_source.cpp +++ b/esphome/components/microphone/microphone_source.cpp @@ -10,9 +10,7 @@ void MicrophoneSource::add_data_callback(std::functionprocess_audio_(data)); } }; - // Future PR will uncomment this! It requires changing the callback vector to an uint8_t in every component using a - // mic callback. - // this->mic_->add_data_callback(std::move(filtered_callback)); + this->mic_->add_data_callback(std::move(filtered_callback)); } void MicrophoneSource::start() { diff --git a/esphome/components/voice_assistant/__init__.py b/esphome/components/voice_assistant/__init__.py index e8cdca94b8..ca0b6da742 100644 --- a/esphome/components/voice_assistant/__init__.py +++ b/esphome/components/voice_assistant/__init__.py @@ -88,7 +88,14 @@ CONFIG_SCHEMA = cv.All( cv.Schema( { cv.GenerateID(): cv.declare_id(VoiceAssistant), - cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone), + cv.Optional( + CONF_MICROPHONE, default={} + ): microphone.microphone_source_schema( + min_bits_per_sample=16, + max_bits_per_sample=16, + min_channels=1, + max_channels=1, + ), cv.Exclusive(CONF_SPEAKER, "output"): cv.use_id(speaker.Speaker), cv.Exclusive(CONF_MEDIA_PLAYER, "output"): cv.use_id( media_player.MediaPlayer @@ -163,13 +170,26 @@ CONFIG_SCHEMA = cv.All( tts_stream_validate, ) +FINAL_VALIDATE_SCHEMA = cv.All( + cv.Schema( + { + cv.Optional( + CONF_MICROPHONE + ): microphone.final_validate_microphone_source_schema( + "voice_assistant", sample_rate=16000 + ), + }, + extra=cv.ALLOW_EXTRA, + ), +) + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) - mic = await cg.get_variable(config[CONF_MICROPHONE]) - cg.add(var.set_microphone(mic)) + mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) + cg.add(var.set_microphone_source(mic_source)) if CONF_SPEAKER in config: spkr = await cg.get_variable(config[CONF_SPEAKER]) diff --git a/esphome/components/voice_assistant/voice_assistant.cpp b/esphome/components/voice_assistant/voice_assistant.cpp index c62767d7d5..37b97239c8 100644 --- a/esphome/components/voice_assistant/voice_assistant.cpp +++ b/esphome/components/voice_assistant/voice_assistant.cpp @@ -29,10 +29,10 @@ static const size_t SPEAKER_BUFFER_SIZE = 16 * RECEIVE_SIZE; VoiceAssistant::VoiceAssistant() { global_voice_assistant = this; } void VoiceAssistant::setup() { - this->mic_->add_data_callback([this](const std::vector &data) { + this->mic_source_->add_data_callback([this](const std::vector &data) { std::shared_ptr temp_ring_buffer = this->ring_buffer_; if (this->ring_buffer_.use_count() > 1) { - temp_ring_buffer->write((void *) data.data(), data.size() * sizeof(int16_t)); + temp_ring_buffer->write((void *) data.data(), data.size()); } }); } @@ -162,7 +162,7 @@ void VoiceAssistant::reset_conversation_id() { void VoiceAssistant::loop() { if (this->api_client_ == nullptr && this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE && this->state_ != State::STOPPING_MICROPHONE) { - if (this->mic_->is_running() || this->state_ == State::STARTING_MICROPHONE) { + if (this->mic_source_->is_running() || this->state_ == State::STARTING_MICROPHONE) { this->set_state_(State::STOP_MICROPHONE, State::IDLE); } else { this->set_state_(State::IDLE, State::IDLE); @@ -193,12 +193,12 @@ void VoiceAssistant::loop() { } this->clear_buffers_(); - this->mic_->start(); + this->mic_source_->start(); this->set_state_(State::STARTING_MICROPHONE); break; } case State::STARTING_MICROPHONE: { - if (this->mic_->is_running()) { + if (this->mic_source_->is_running()) { this->set_state_(this->desired_state_); } break; @@ -262,8 +262,8 @@ void VoiceAssistant::loop() { break; } case State::STOP_MICROPHONE: { - if (this->mic_->is_running()) { - this->mic_->stop(); + if (this->mic_source_->is_running()) { + this->mic_source_->stop(); this->set_state_(State::STOPPING_MICROPHONE); } else { this->set_state_(this->desired_state_); @@ -271,7 +271,7 @@ void VoiceAssistant::loop() { break; } case State::STOPPING_MICROPHONE: { - if (this->mic_->is_stopped()) { + if (this->mic_source_->is_stopped()) { this->set_state_(this->desired_state_); } break; @@ -478,7 +478,7 @@ void VoiceAssistant::start_streaming() { ESP_LOGD(TAG, "Client started, streaming microphone"); this->audio_mode_ = AUDIO_MODE_API; - if (this->mic_->is_running()) { + if (this->mic_source_->is_running()) { this->set_state_(State::STREAMING_MICROPHONE, State::STREAMING_MICROPHONE); } else { this->set_state_(State::START_MICROPHONE, State::STREAMING_MICROPHONE); @@ -508,7 +508,7 @@ void VoiceAssistant::start_streaming(struct sockaddr_storage *addr, uint16_t por return; } - if (this->mic_->is_running()) { + if (this->mic_source_->is_running()) { this->set_state_(State::STREAMING_MICROPHONE, State::STREAMING_MICROPHONE); } else { this->set_state_(State::START_MICROPHONE, State::STREAMING_MICROPHONE); diff --git a/esphome/components/voice_assistant/voice_assistant.h b/esphome/components/voice_assistant/voice_assistant.h index cb57a6b05d..7122d69527 100644 --- a/esphome/components/voice_assistant/voice_assistant.h +++ b/esphome/components/voice_assistant/voice_assistant.h @@ -11,7 +11,7 @@ #include "esphome/components/api/api_connection.h" #include "esphome/components/api/api_pb2.h" -#include "esphome/components/microphone/microphone.h" +#include "esphome/components/microphone/microphone_source.h" #ifdef USE_SPEAKER #include "esphome/components/speaker/speaker.h" #endif @@ -98,7 +98,7 @@ class VoiceAssistant : public Component { void start_streaming(struct sockaddr_storage *addr, uint16_t port); void failed_to_start(); - void set_microphone(microphone::Microphone *mic) { this->mic_ = mic; } + void set_microphone_source(microphone::MicrophoneSource *mic_source) { this->mic_source_ = mic_source; } #ifdef USE_SPEAKER void set_speaker(speaker::Speaker *speaker) { this->speaker_ = speaker; @@ -249,7 +249,7 @@ class VoiceAssistant : public Component { bool has_timers_{false}; bool timer_tick_running_{false}; - microphone::Microphone *mic_{nullptr}; + microphone::MicrophoneSource *mic_source_{nullptr}; #ifdef USE_SPEAKER void write_speaker_(); speaker::Speaker *speaker_{nullptr}; diff --git a/tests/components/micro_wake_word/common.yaml b/tests/components/micro_wake_word/common.yaml index c5422baa67..b5507397f8 100644 --- a/tests/components/micro_wake_word/common.yaml +++ b/tests/components/micro_wake_word/common.yaml @@ -11,6 +11,7 @@ microphone: bits_per_sample: 16bit micro_wake_word: + microphone: echo_microphone on_wake_word_detected: - logger.log: "Wake word detected" models: diff --git a/tests/components/voice_assistant/common.yaml b/tests/components/voice_assistant/common.yaml index e7374941f7..f248154b7e 100644 --- a/tests/components/voice_assistant/common.yaml +++ b/tests/components/voice_assistant/common.yaml @@ -30,7 +30,10 @@ speaker: i2s_dout_pin: ${i2s_dout_pin} voice_assistant: - microphone: mic_id_external + microphone: + microphone: mic_id_external + gain_factor: 4 + channels: 0 speaker: speaker_id conversation_timeout: 60s on_listening: From c0be2c14f30296a80272e51d4d621dcaa87821d3 Mon Sep 17 00:00:00 2001 From: StriboYar Date: Wed, 30 Apr 2025 09:15:56 +0300 Subject: [PATCH 049/102] [debug] Fix compile errors when using the ESP32-C2 (#7474) Co-authored-by: Keith Burzinski --- esphome/components/debug/debug_esp32.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/esphome/components/debug/debug_esp32.cpp b/esphome/components/debug/debug_esp32.cpp index caa9f8d743..7367f54807 100644 --- a/esphome/components/debug/debug_esp32.cpp +++ b/esphome/components/debug/debug_esp32.cpp @@ -9,6 +9,8 @@ #if defined(USE_ESP32_VARIANT_ESP32) #include +#elif defined(USE_ESP32_VARIANT_ESP32C2) +#include #elif defined(USE_ESP32_VARIANT_ESP32C3) #include #elif defined(USE_ESP32_VARIANT_ESP32C6) @@ -123,9 +125,11 @@ std::string DebugComponent::get_reset_reason_() { case TG0WDT_SYS_RESET: reset_reason = "Timer Group 0 Watch Dog Reset Digital Core"; break; +#if !defined(USE_ESP32_VARIANT_ESP32C2) case TG1WDT_SYS_RESET: reset_reason = "Timer Group 1 Watch Dog Reset Digital Core"; break; +#endif case RTCWDT_SYS_RESET: reset_reason = "RTC Watch Dog Reset Digital Core"; break; @@ -245,6 +249,8 @@ void DebugComponent::get_device_info_(std::string &device_info) { const char *model; #if defined(USE_ESP32_VARIANT_ESP32) model = "ESP32"; +#elif defined(USE_ESP32_VARIANT_ESP32C2) + model = "ESP32-C2"; #elif defined(USE_ESP32_VARIANT_ESP32C3) model = "ESP32-C3"; #elif defined(USE_ESP32_VARIANT_ESP32C6) @@ -344,9 +350,11 @@ void DebugComponent::get_device_info_(std::string &device_info) { case UART1_TRIG: wakeup_reason = "UART1"; break; +#if !defined(USE_ESP32_VARIANT_ESP32C2) case TOUCH_TRIG: wakeup_reason = "Touch"; break; +#endif case SAR_TRIG: wakeup_reason = "SAR"; break; From caa255f5d18e6897cde85202ecae95fedea727ac Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Wed, 30 Apr 2025 20:08:46 +1200 Subject: [PATCH 050/102] [media_player] Fix actions with id as value (#8654) --- esphome/components/media_player/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/esphome/components/media_player/__init__.py b/esphome/components/media_player/__init__.py index b2543ac05f..14fe1fdb6a 100644 --- a/esphome/components/media_player/__init__.py +++ b/esphome/components/media_player/__init__.py @@ -134,11 +134,13 @@ MEDIA_PLAYER_SCHEMA = cv.ENTITY_BASE_SCHEMA.extend( ) -MEDIA_PLAYER_ACTION_SCHEMA = cv.Schema( - { - cv.GenerateID(): cv.use_id(MediaPlayer), - cv.Optional(CONF_ANNOUNCEMENT, default=False): cv.templatable(cv.boolean), - } +MEDIA_PLAYER_ACTION_SCHEMA = automation.maybe_simple_id( + cv.Schema( + { + cv.GenerateID(): cv.use_id(MediaPlayer), + cv.Optional(CONF_ANNOUNCEMENT, default=False): cv.templatable(cv.boolean), + } + ) ) MEDIA_PLAYER_CONDITION_SCHEMA = automation.maybe_simple_id( From 07ba9fdf8f319b5969abd608001f4b4b34ade020 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Wed, 30 Apr 2025 19:10:54 +1000 Subject: [PATCH 051/102] [canbus] Add callback for use by other components (#8578) Co-authored-by: clydeps --- esphome/components/canbus/canbus.cpp | 3 +++ esphome/components/canbus/canbus.h | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/esphome/components/canbus/canbus.cpp b/esphome/components/canbus/canbus.cpp index 696cfff2b7..3b86f209cd 100644 --- a/esphome/components/canbus/canbus.cpp +++ b/esphome/components/canbus/canbus.cpp @@ -86,6 +86,9 @@ void Canbus::loop() { data.push_back(can_message.data[i]); } + this->callback_manager_(can_message.can_id, can_message.use_extended_id, can_message.remote_transmission_request, + data); + // fire all triggers for (auto *trigger : this->triggers_) { if ((trigger->can_id_ == (can_message.can_id & trigger->can_id_mask_)) && diff --git a/esphome/components/canbus/canbus.h b/esphome/components/canbus/canbus.h index 1e5214fef4..7319bfb4ad 100644 --- a/esphome/components/canbus/canbus.h +++ b/esphome/components/canbus/canbus.h @@ -81,6 +81,20 @@ class Canbus : public Component { void set_bitrate(CanSpeed bit_rate) { this->bit_rate_ = bit_rate; } void add_trigger(CanbusTrigger *trigger); + /** + * Add a callback to be called when a CAN message is received. All received messages + * are passed to the callback without filtering. + * + * The callback function receives: + * - can_id of the received data + * - extended_id True if the can_id is an extended id + * - rtr If this is a remote transmission request + * - data The message data + */ + void add_callback( + std::function &data)> callback) { + this->callback_manager_.add(std::move(callback)); + } protected: template friend class CanbusSendAction; @@ -88,6 +102,8 @@ class Canbus : public Component { uint32_t can_id_; bool use_extended_id_; CanSpeed bit_rate_; + CallbackManager &data)> + callback_manager_{}; virtual bool setup_internal(); virtual Error send_message(struct CanFrame *frame); From 20062576a3f088a4f6ab1f9f7bb1a52c78a545fe Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Wed, 30 Apr 2025 04:50:56 -0500 Subject: [PATCH 052/102] [i2s_audio] Move microphone reads into a task (#8651) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/components/i2s_audio/__init__.py | 24 ++- esphome/components/i2s_audio/i2s_audio.h | 2 + .../i2s_audio/microphone/__init__.py | 2 + .../microphone/i2s_audio_microphone.cpp | 187 +++++++++++++----- .../microphone/i2s_audio_microphone.h | 18 +- .../components/i2s_audio/speaker/__init__.py | 2 + .../i2s_audio/speaker/i2s_audio_speaker.cpp | 4 +- tests/components/microphone/common.yaml | 1 + 8 files changed, 185 insertions(+), 55 deletions(-) diff --git a/esphome/components/i2s_audio/__init__.py b/esphome/components/i2s_audio/__init__.py index 291ae4ba95..0d413adb8a 100644 --- a/esphome/components/i2s_audio/__init__.py +++ b/esphome/components/i2s_audio/__init__.py @@ -39,6 +39,7 @@ CONF_SECONDARY = "secondary" CONF_USE_APLL = "use_apll" CONF_BITS_PER_CHANNEL = "bits_per_channel" +CONF_MCLK_MULTIPLE = "mclk_multiple" CONF_MONO = "mono" CONF_LEFT = "left" CONF_RIGHT = "right" @@ -122,8 +123,25 @@ I2S_SLOT_BIT_WIDTH = { 32: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_32BIT, } +i2s_mclk_multiple_t = cg.global_ns.enum("i2s_mclk_multiple_t") +I2S_MCLK_MULTIPLE = { + 128: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_128, + 256: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_256, + 384: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_384, + 512: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_512, +} + _validate_bits = cv.float_with_unit("bits", "bit") + +def validate_mclk_divisible_by_3(config): + if config[CONF_BITS_PER_SAMPLE] == 24 and config[CONF_MCLK_MULTIPLE] % 3 != 0: + raise cv.Invalid( + f"{CONF_MCLK_MULTIPLE} must be divisible by 3 when bits per sample is 24" + ) + return config + + _use_legacy_driver = None @@ -155,6 +173,7 @@ def i2s_audio_component_schema( cv.Any(cv.float_with_unit("bits", "bit"), "default"), cv.one_of(*I2S_BITS_PER_CHANNEL), ), + cv.Optional(CONF_MCLK_MULTIPLE, default=256): cv.one_of(*I2S_MCLK_MULTIPLE), } ) @@ -182,11 +201,10 @@ async def register_i2s_audio_component(var, config): slot_mask = CONF_BOTH cg.add(var.set_slot_mode(I2S_SLOT_MODE[slot_mode])) cg.add(var.set_std_slot_mask(I2S_STD_SLOT_MASK[slot_mask])) - cg.add( - var.set_slot_bit_width(I2S_SLOT_BIT_WIDTH[config[CONF_BITS_PER_CHANNEL]]) - ) + cg.add(var.set_slot_bit_width(I2S_SLOT_BIT_WIDTH[config[CONF_BITS_PER_SAMPLE]])) cg.add(var.set_sample_rate(config[CONF_SAMPLE_RATE])) cg.add(var.set_use_apll(config[CONF_USE_APLL])) + cg.add(var.set_mclk_multiple(I2S_MCLK_MULTIPLE[config[CONF_MCLK_MULTIPLE]])) def validate_use_legacy(value): diff --git a/esphome/components/i2s_audio/i2s_audio.h b/esphome/components/i2s_audio/i2s_audio.h index d8050665e9..e839bcd891 100644 --- a/esphome/components/i2s_audio/i2s_audio.h +++ b/esphome/components/i2s_audio/i2s_audio.h @@ -31,6 +31,7 @@ class I2SAudioBase : public Parented { #endif void set_sample_rate(uint32_t sample_rate) { this->sample_rate_ = sample_rate; } void set_use_apll(uint32_t use_apll) { this->use_apll_ = use_apll; } + void set_mclk_multiple(i2s_mclk_multiple_t mclk_multiple) { this->mclk_multiple_ = mclk_multiple; } protected: #ifdef USE_I2S_LEGACY @@ -46,6 +47,7 @@ class I2SAudioBase : public Parented { #endif uint32_t sample_rate_; bool use_apll_; + i2s_mclk_multiple_t mclk_multiple_; }; class I2SAudioIn : public I2SAudioBase {}; diff --git a/esphome/components/i2s_audio/microphone/__init__.py b/esphome/components/i2s_audio/microphone/__init__.py index 06eb29986d..1fb4e9df99 100644 --- a/esphome/components/i2s_audio/microphone/__init__.py +++ b/esphome/components/i2s_audio/microphone/__init__.py @@ -22,6 +22,7 @@ from .. import ( i2s_audio_ns, register_i2s_audio_component, use_legacy, + validate_mclk_divisible_by_3, ) CODEOWNERS = ["@jesserockz"] @@ -112,6 +113,7 @@ CONFIG_SCHEMA = cv.All( _validate_channel, _set_num_channels_from_config, _set_stream_limits, + validate_mclk_divisible_by_3, ) diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp index 78a7f92c2f..72d1e4476c 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp @@ -15,10 +15,25 @@ namespace esphome { namespace i2s_audio { -static const size_t BUFFER_SIZE = 512; +static const UBaseType_t MAX_LISTENERS = 16; + +static const uint32_t READ_DURATION_MS = 16; + +static const size_t TASK_STACK_SIZE = 4096; +static const ssize_t TASK_PRIORITY = 23; static const char *const TAG = "i2s_audio.microphone"; +enum MicrophoneEventGroupBits : uint32_t { + COMMAND_STOP = (1 << 0), // stops the microphone task + TASK_STARTING = (1 << 10), + TASK_RUNNING = (1 << 11), + TASK_STOPPING = (1 << 12), + TASK_STOPPED = (1 << 13), + + ALL_BITS = 0x00FFFFFF, // All valid FreeRTOS event group bits +}; + void I2SAudioMicrophone::setup() { ESP_LOGCONFIG(TAG, "Setting up I2S Audio Microphone..."); #ifdef USE_I2S_LEGACY @@ -41,18 +56,32 @@ void I2SAudioMicrophone::setup() { } } } + + this->active_listeners_semaphore_ = xSemaphoreCreateCounting(MAX_LISTENERS, MAX_LISTENERS); + if (this->active_listeners_semaphore_ == nullptr) { + ESP_LOGE(TAG, "Failed to create semaphore"); + this->mark_failed(); + return; + } + + this->event_group_ = xEventGroupCreate(); + if (this->event_group_ == nullptr) { + ESP_LOGE(TAG, "Failed to create event group"); + this->mark_failed(); + return; + } } void I2SAudioMicrophone::start() { if (this->is_failed()) return; - if (this->state_ == microphone::STATE_RUNNING) - return; // Already running - this->state_ = microphone::STATE_STARTING; + + xSemaphoreTake(this->active_listeners_semaphore_, 0); } -void I2SAudioMicrophone::start_() { + +bool I2SAudioMicrophone::start_driver_() { if (!this->parent_->try_lock()) { - return; // Waiting for another i2s to return lock + return false; // Waiting for another i2s to return lock } esp_err_t err; @@ -94,11 +123,11 @@ void I2SAudioMicrophone::start_() { .communication_format = I2S_COMM_FORMAT_STAND_I2S, .intr_alloc_flags = ESP_INTR_FLAG_LEVEL1, .dma_buf_count = 4, - .dma_buf_len = 256, + .dma_buf_len = 240, // Must be divisible by 3 to support 24 bits per sample on old driver and newer variants .use_apll = this->use_apll_, .tx_desc_auto_clear = false, .fixed_mclk = 0, - .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .mclk_multiple = this->mclk_multiple_, .bits_per_chan = this->bits_per_channel_, }; @@ -109,20 +138,20 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error installing I2S driver: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } err = i2s_set_adc_mode(ADC_UNIT_1, this->adc_channel_); if (err != ESP_OK) { ESP_LOGW(TAG, "Error setting ADC mode: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } err = i2s_adc_enable(this->parent_->get_port()); if (err != ESP_OK) { ESP_LOGW(TAG, "Error enabling ADC: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } } else @@ -135,7 +164,7 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error installing I2S driver: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } i2s_pin_config_t pin_config = this->parent_->get_pin_config(); @@ -145,7 +174,7 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error setting I2S pin: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } } #else @@ -161,7 +190,7 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error creating new I2S channel: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } i2s_clock_src_t clk_src = I2S_CLK_SRC_DEFAULT; @@ -178,7 +207,7 @@ void I2SAudioMicrophone::start_() { i2s_pdm_rx_clk_config_t clk_cfg = { .sample_rate_hz = this->sample_rate_, .clk_src = clk_src, - .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .mclk_multiple = this->mclk_multiple_, .dn_sample_mode = I2S_PDM_DSR_8S, }; @@ -216,7 +245,7 @@ void I2SAudioMicrophone::start_() { i2s_std_clk_config_t clk_cfg = { .sample_rate_hz = this->sample_rate_, .clk_src = clk_src, - .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .mclk_multiple = this->mclk_multiple_, }; i2s_std_slot_config_t std_slot_cfg = I2S_STD_PHILIPS_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) this->slot_bit_width_, this->slot_mode_); @@ -236,7 +265,7 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error initializing I2S channel: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } /* Before reading data, start the RX channel first */ @@ -244,28 +273,25 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error enabling I2S Microphone: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } #endif this->audio_stream_info_ = audio::AudioStreamInfo(bits_per_sample, channel_count, this->sample_rate_); - this->state_ = microphone::STATE_RUNNING; - this->high_freq_.start(); this->status_clear_error(); + + return true; } void I2SAudioMicrophone::stop() { if (this->state_ == microphone::STATE_STOPPED || this->is_failed()) return; - if (this->state_ == microphone::STATE_STARTING) { - this->state_ = microphone::STATE_STOPPED; - return; - } - this->state_ = microphone::STATE_STOPPING; + + xSemaphoreGive(this->active_listeners_semaphore_); } -void I2SAudioMicrophone::stop_() { +void I2SAudioMicrophone::stop_driver_() { esp_err_t err; #ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC @@ -307,11 +333,51 @@ void I2SAudioMicrophone::stop_() { } #endif this->parent_->unlock(); - this->state_ = microphone::STATE_STOPPED; - this->high_freq_.stop(); this->status_clear_error(); } +void I2SAudioMicrophone::mic_task(void *params) { + I2SAudioMicrophone *this_microphone = (I2SAudioMicrophone *) params; + + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_STARTING); + + uint8_t start_counter = 0; + bool started = this_microphone->start_driver_(); + while (!started && start_counter < 10) { + // Attempt to load the driver again in 100 ms. Doesn't slow down main loop since its in a task. + vTaskDelay(pdMS_TO_TICKS(100)); + ++start_counter; + started = this_microphone->start_driver_(); + } + + if (started) { + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_RUNNING); + const size_t bytes_to_read = this_microphone->audio_stream_info_.ms_to_bytes(READ_DURATION_MS); + std::vector samples; + samples.reserve(bytes_to_read); + + while (!(xEventGroupGetBits(this_microphone->event_group_) & COMMAND_STOP)) { + if (this_microphone->data_callbacks_.size() > 0) { + samples.resize(bytes_to_read); + size_t bytes_read = this_microphone->read_(samples.data(), bytes_to_read, 2 * pdMS_TO_TICKS(READ_DURATION_MS)); + samples.resize(bytes_read); + this_microphone->data_callbacks_.call(samples); + } else { + delay(READ_DURATION_MS); + } + } + } + + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_STOPPING); + this_microphone->stop_driver_(); + + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_STOPPED); + while (true) { + // Continuously delay until the loop method delete the task + delay(10); + } +} + size_t I2SAudioMicrophone::read_(uint8_t *buf, size_t len, TickType_t ticks_to_wait) { size_t bytes_read = 0; #ifdef USE_I2S_LEGACY @@ -345,29 +411,60 @@ size_t I2SAudioMicrophone::read_(uint8_t *buf, size_t len, TickType_t ticks_to_w return bytes_read; } -void I2SAudioMicrophone::read_() { - std::vector samples; - const size_t bytes_to_read = this->audio_stream_info_.ms_to_bytes(32); - samples.resize(bytes_to_read); - size_t bytes_read = this->read_(samples.data(), bytes_to_read, 0); - samples.resize(bytes_read); - this->data_callbacks_.call(samples); -} - void I2SAudioMicrophone::loop() { + uint32_t event_group_bits = xEventGroupGetBits(this->event_group_); + + if (event_group_bits & MicrophoneEventGroupBits::TASK_STARTING) { + ESP_LOGD(TAG, "Task has started, attempting to setup I2S audio driver"); + xEventGroupClearBits(this->event_group_, MicrophoneEventGroupBits::TASK_STARTING); + } + + if (event_group_bits & MicrophoneEventGroupBits::TASK_RUNNING) { + ESP_LOGD(TAG, "Task is running and reading data"); + + xEventGroupClearBits(this->event_group_, MicrophoneEventGroupBits::TASK_RUNNING); + this->state_ = microphone::STATE_RUNNING; + } + + if (event_group_bits & MicrophoneEventGroupBits::TASK_STOPPING) { + ESP_LOGD(TAG, "Task is stopping, attempting to unload the I2S audio driver"); + xEventGroupClearBits(this->event_group_, MicrophoneEventGroupBits::TASK_STOPPING); + } + + if ((event_group_bits & MicrophoneEventGroupBits::TASK_STOPPED)) { + ESP_LOGD(TAG, "Task is finished, freeing resources"); + vTaskDelete(this->task_handle_); + this->task_handle_ = nullptr; + xEventGroupClearBits(this->event_group_, ALL_BITS); + this->state_ = microphone::STATE_STOPPED; + } + + if ((uxSemaphoreGetCount(this->active_listeners_semaphore_) < MAX_LISTENERS) && + (this->state_ == microphone::STATE_STOPPED)) { + this->state_ = microphone::STATE_STARTING; + } + if ((uxSemaphoreGetCount(this->active_listeners_semaphore_) == MAX_LISTENERS) && + (this->state_ == microphone::STATE_RUNNING)) { + this->state_ = microphone::STATE_STOPPING; + } + switch (this->state_) { - case microphone::STATE_STOPPED: - break; case microphone::STATE_STARTING: - this->start_(); - break; - case microphone::STATE_RUNNING: - if (this->data_callbacks_.size() > 0) { - this->read_(); + if ((this->task_handle_ == nullptr) && !this->status_has_error()) { + xTaskCreate(I2SAudioMicrophone::mic_task, "mic_task", TASK_STACK_SIZE, (void *) this, TASK_PRIORITY, + &this->task_handle_); + + if (this->task_handle_ == nullptr) { + this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000); + } } break; + case microphone::STATE_RUNNING: + break; case microphone::STATE_STOPPING: - this->stop_(); + xEventGroupSetBits(this->event_group_, MicrophoneEventGroupBits::COMMAND_STOP); + break; + case microphone::STATE_STOPPED: break; } } diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h index 072d312e0f..8e6d83cad3 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h @@ -7,6 +7,9 @@ #include "esphome/components/microphone/microphone.h" #include "esphome/core/component.h" +#include +#include + namespace esphome { namespace i2s_audio { @@ -35,11 +38,18 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub #endif protected: - void start_(); - void stop_(); - void read_(); + bool start_driver_(); + void stop_driver_(); + size_t read_(uint8_t *buf, size_t len, TickType_t ticks_to_wait); + static void mic_task(void *params); + + SemaphoreHandle_t active_listeners_semaphore_{nullptr}; + EventGroupHandle_t event_group_{nullptr}; + + TaskHandle_t task_handle_{nullptr}; + #ifdef USE_I2S_LEGACY int8_t din_pin_{I2S_PIN_NO_CHANGE}; #if SOC_I2S_SUPPORTS_ADC @@ -51,8 +61,6 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub i2s_chan_handle_t rx_handle_; #endif bool pdm_{false}; - - HighFrequencyLoopRequester high_freq_; }; } // namespace i2s_audio diff --git a/esphome/components/i2s_audio/speaker/__init__.py b/esphome/components/i2s_audio/speaker/__init__.py index 7e41cd3991..bb9f24bf0b 100644 --- a/esphome/components/i2s_audio/speaker/__init__.py +++ b/esphome/components/i2s_audio/speaker/__init__.py @@ -27,6 +27,7 @@ from .. import ( i2s_audio_ns, register_i2s_audio_component, use_legacy, + validate_mclk_divisible_by_3, ) AUTO_LOAD = ["audio"] @@ -155,6 +156,7 @@ CONFIG_SCHEMA = cv.All( _validate_esp32_variant, _set_num_channels_from_config, _set_stream_limits, + validate_mclk_divisible_by_3, ) diff --git a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp index cb3bbc8cf2..7d247003f7 100644 --- a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp +++ b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp @@ -545,7 +545,7 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea .use_apll = this->use_apll_, .tx_desc_auto_clear = true, .fixed_mclk = I2S_PIN_NO_CHANGE, - .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .mclk_multiple = this->mclk_multiple_, .bits_per_chan = this->bits_per_channel_, #if SOC_I2S_SUPPORTS_TDM .chan_mask = (i2s_channel_t) (I2S_TDM_ACTIVE_CH0 | I2S_TDM_ACTIVE_CH1), @@ -614,7 +614,7 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea i2s_std_clk_config_t clk_cfg = { .sample_rate_hz = audio_stream_info.get_sample_rate(), .clk_src = clk_src, - .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .mclk_multiple = this->mclk_multiple_, }; i2s_slot_mode_t slot_mode = this->slot_mode_; diff --git a/tests/components/microphone/common.yaml b/tests/components/microphone/common.yaml index ea79266281..ccadc7aee5 100644 --- a/tests/components/microphone/common.yaml +++ b/tests/components/microphone/common.yaml @@ -9,3 +9,4 @@ microphone: i2s_din_pin: ${i2s_din_pin} adc_type: external pdm: false + mclk_multiple: 384 From 6de6a0c82c07b543fb53af9a7ad5188654ff543b Mon Sep 17 00:00:00 2001 From: Stanislav Meduna Date: Thu, 1 May 2025 01:57:01 +0200 Subject: [PATCH 053/102] Only warn if the component blocked for a longer time than the last time (#8064) --- esphome/core/component.cpp | 22 +++++++++++++++++++--- esphome/core/component.h | 5 +++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/esphome/core/component.cpp b/esphome/core/component.cpp index b20964b872..a7e451b93d 100644 --- a/esphome/core/component.cpp +++ b/esphome/core/component.cpp @@ -39,6 +39,9 @@ const uint32_t STATUS_LED_OK = 0x0000; const uint32_t STATUS_LED_WARNING = 0x0100; const uint32_t STATUS_LED_ERROR = 0x0200; +const uint32_t WARN_IF_BLOCKING_OVER_MS = 50U; ///< Initial blocking time allowed without warning +const uint32_t WARN_IF_BLOCKING_INCREMENT_MS = 10U; ///< How long the blocking time must be larger to warn again + uint32_t global_state = 0; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) float Component::get_loop_priority() const { return 0.0f; } @@ -115,6 +118,13 @@ const char *Component::get_component_source() const { return ""; return this->component_source_; } +bool Component::should_warn_of_blocking(uint32_t blocking_time) { + if (blocking_time > this->warn_if_blocking_over_) { + this->warn_if_blocking_over_ = blocking_time + WARN_IF_BLOCKING_INCREMENT_MS; + return true; + } + return false; +} void Component::mark_failed() { ESP_LOGE(TAG, "Component %s was marked as failed.", this->get_component_source()); this->component_state_ &= ~COMPONENT_STATE_MASK; @@ -233,10 +243,16 @@ void PollingComponent::set_update_interval(uint32_t update_interval) { this->upd WarnIfComponentBlockingGuard::WarnIfComponentBlockingGuard(Component *component) : started_(millis()), component_(component) {} WarnIfComponentBlockingGuard::~WarnIfComponentBlockingGuard() { - uint32_t now = millis(); - if (now - started_ > 50) { + uint32_t blocking_time = millis() - this->started_; + bool should_warn; + if (this->component_ != nullptr) { + should_warn = this->component_->should_warn_of_blocking(blocking_time); + } else { + should_warn = blocking_time > WARN_IF_BLOCKING_OVER_MS; + } + if (should_warn) { const char *src = component_ == nullptr ? "" : component_->get_component_source(); - ESP_LOGW(TAG, "Component %s took a long time for an operation (%" PRIu32 " ms).", src, (now - started_)); + ESP_LOGW(TAG, "Component %s took a long time for an operation (%" PRIu32 " ms).", src, blocking_time); ESP_LOGW(TAG, "Components should block for at most 30 ms."); ; } diff --git a/esphome/core/component.h b/esphome/core/component.h index f5c56459b1..412074282d 100644 --- a/esphome/core/component.h +++ b/esphome/core/component.h @@ -65,6 +65,8 @@ extern const uint32_t STATUS_LED_ERROR; enum class RetryResult { DONE, RETRY }; +extern const uint32_t WARN_IF_BLOCKING_OVER_MS; + class Component { public: /** Where the component's initialization should happen. @@ -158,6 +160,8 @@ class Component { */ const char *get_component_source() const; + bool should_warn_of_blocking(uint32_t blocking_time); + protected: friend class Application; @@ -284,6 +288,7 @@ class Component { uint32_t component_state_{0x0000}; ///< State of this component. float setup_priority_override_{NAN}; const char *component_source_{nullptr}; + uint32_t warn_if_blocking_over_{WARN_IF_BLOCKING_OVER_MS}; std::string error_message_{}; }; From cdc77506de6ad01ecd14c244927bb1ea5495e581 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Wed, 30 Apr 2025 19:22:48 -0500 Subject: [PATCH 054/102] [micro_wake_word] add new VPE features (#8655) --- .../components/micro_wake_word/__init__.py | 103 ++- .../components/micro_wake_word/automation.h | 54 ++ .../micro_wake_word/micro_wake_word.cpp | 665 +++++++++--------- .../micro_wake_word/micro_wake_word.h | 145 ++-- .../micro_wake_word/preprocessor_settings.h | 19 + .../micro_wake_word/streaming_model.cpp | 193 ++++- .../micro_wake_word/streaming_model.h | 95 ++- esphome/core/defines.h | 1 + tests/components/micro_wake_word/common.yaml | 16 + 9 files changed, 788 insertions(+), 503 deletions(-) create mode 100644 esphome/components/micro_wake_word/automation.h diff --git a/esphome/components/micro_wake_word/__init__.py b/esphome/components/micro_wake_word/__init__.py index 9d5caca937..0efe2ac288 100644 --- a/esphome/components/micro_wake_word/__init__.py +++ b/esphome/components/micro_wake_word/__init__.py @@ -12,6 +12,7 @@ import esphome.config_validation as cv from esphome.const import ( CONF_FILE, CONF_ID, + CONF_INTERNAL, CONF_MICROPHONE, CONF_MODEL, CONF_PASSWORD, @@ -40,6 +41,7 @@ CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected" CONF_PROBABILITY_CUTOFF = "probability_cutoff" CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size" CONF_SLIDING_WINDOW_SIZE = "sliding_window_size" +CONF_STOP_AFTER_DETECTION = "stop_after_detection" CONF_TENSOR_ARENA_SIZE = "tensor_arena_size" CONF_VAD = "vad" @@ -49,13 +51,20 @@ micro_wake_word_ns = cg.esphome_ns.namespace("micro_wake_word") MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component) +DisableModelAction = micro_wake_word_ns.class_("DisableModelAction", automation.Action) +EnableModelAction = micro_wake_word_ns.class_("EnableModelAction", automation.Action) StartAction = micro_wake_word_ns.class_("StartAction", automation.Action) StopAction = micro_wake_word_ns.class_("StopAction", automation.Action) +ModelIsEnabledCondition = micro_wake_word_ns.class_( + "ModelIsEnabledCondition", automation.Condition +) IsRunningCondition = micro_wake_word_ns.class_( "IsRunningCondition", automation.Condition ) +WakeWordModel = micro_wake_word_ns.class_("WakeWordModel") + def _validate_json_filename(value): value = cv.string(value) @@ -169,9 +178,10 @@ def _convert_manifest_v1_to_v2(v1_manifest): # Original Inception-based V1 manifest models require a minimum of 45672 bytes v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672 - # Original Inception-based V1 manifest models use a 20 ms feature step size v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20 + # Original Inception-based V1 manifest models were trained only on TTS English samples + v2_manifest[KEY_TRAINED_LANGUAGES] = ["en"] return v2_manifest @@ -296,14 +306,16 @@ MODEL_SOURCE_SCHEMA = cv.Any( MODEL_SCHEMA = cv.Schema( { + cv.GenerateID(CONF_ID): cv.declare_id(WakeWordModel), cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA, cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage, cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int, + cv.Optional(CONF_INTERNAL, default=False): cv.boolean, cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8), } ) -# Provide a default VAD model that could be overridden +# Provides a default VAD model that could be overridden VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend( cv.Schema( { @@ -343,6 +355,7 @@ CONFIG_SCHEMA = cv.All( single=True ), cv.Optional(CONF_VAD): _maybe_empty_vad_schema, + cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean, cv.Optional(CONF_MODEL): cv.invalid( f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter." ), @@ -433,29 +446,20 @@ async def to_code(config): mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) cg.add(var.set_microphone_source(mic_source)) + cg.add_define("USE_MICRO_WAKE_WORD") + cg.add_define("USE_OTA_STATE_CALLBACK") + esp32.add_idf_component( name="esp-tflite-micro", repo="https://github.com/espressif/esp-tflite-micro", - ref="v1.3.1", - ) - # add esp-nn dependency for tflite-micro to work around https://github.com/espressif/esp-nn/issues/17 - # ...remove after switching to IDF 5.1.4+ - esp32.add_idf_component( - name="esp-nn", - repo="https://github.com/espressif/esp-nn", - ref="v1.1.0", + ref="v1.3.3.1", ) cg.add_build_flag("-DTF_LITE_STATIC_MEMORY") cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON") cg.add_build_flag("-DESP_NN") - if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): - await automation.build_automation( - var.get_wake_word_detected_trigger(), - [(cg.std_string, "wake_word")], - on_wake_word_detection_config, - ) + cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") if vad_model := config.get(CONF_VAD): cg.add_define("USE_MICRO_WAKE_WORD_VAD") @@ -463,7 +467,7 @@ async def to_code(config): # Use the general model loading code for the VAD codegen config[CONF_MODELS].append(vad_model) - for model_parameters in config[CONF_MODELS]: + for i, model_parameters in enumerate(config[CONF_MODELS]): model_config = model_parameters.get(CONF_MODEL) data = [] manifest, data = _model_config_to_manifest_data(model_config) @@ -474,6 +478,8 @@ async def to_code(config): probability_cutoff = model_parameters.get( CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF] ) + quantized_probability_cutoff = int(probability_cutoff * 255) + sliding_window_size = model_parameters.get( CONF_SLIDING_WINDOW_SIZE, manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE], @@ -483,24 +489,40 @@ async def to_code(config): cg.add( var.add_vad_model( prog_arr, - probability_cutoff, + quantized_probability_cutoff, sliding_window_size, manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], ) ) else: - cg.add( - var.add_wake_word_model( - prog_arr, - probability_cutoff, - sliding_window_size, - manifest[KEY_WAKE_WORD], - manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], - ) + # Only enable the first wake word by default. After first boot, the enable state is saved/loaded to the flash + default_enabled = i == 0 + wake_word_model = cg.new_Pvariable( + model_parameters[CONF_ID], + str(model_parameters[CONF_ID]), + prog_arr, + quantized_probability_cutoff, + sliding_window_size, + manifest[KEY_WAKE_WORD], + manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], + default_enabled, + model_parameters[CONF_INTERNAL], ) + for lang in manifest[KEY_TRAINED_LANGUAGES]: + cg.add(wake_word_model.add_trained_language(lang)) + + cg.add(var.add_wake_word_model(wake_word_model)) + cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE])) - cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") + cg.add(var.set_stop_after_detection(config[CONF_STOP_AFTER_DETECTION])) + + if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): + await automation.build_automation( + var.get_wake_word_detected_trigger(), + [(cg.std_string, "wake_word")], + on_wake_word_detection_config, + ) MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)}) @@ -515,3 +537,30 @@ async def micro_wake_word_action_to_code(config, action_id, template_arg, args): var = cg.new_Pvariable(action_id, template_arg) await cg.register_parented(var, config[CONF_ID]) return var + + +MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA = automation.maybe_simple_id( + { + cv.Required(CONF_ID): cv.use_id(WakeWordModel), + } +) + + +@register_action( + "micro_wake_word.enable_model", + EnableModelAction, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +@register_action( + "micro_wake_word.disable_model", + DisableModelAction, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +@register_condition( + "micro_wake_word.model_is_enabled", + ModelIsEnabledCondition, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +async def model_action(config, action_id, template_arg, args): + parent = await cg.get_variable(config[CONF_ID]) + return cg.new_Pvariable(action_id, template_arg, parent) diff --git a/esphome/components/micro_wake_word/automation.h b/esphome/components/micro_wake_word/automation.h new file mode 100644 index 0000000000..f10a4ed347 --- /dev/null +++ b/esphome/components/micro_wake_word/automation.h @@ -0,0 +1,54 @@ +#pragma once + +#include "micro_wake_word.h" +#include "streaming_model.h" + +#ifdef USE_ESP_IDF +namespace esphome { +namespace micro_wake_word { + +template class StartAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->start(); } +}; + +template class StopAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->stop(); } +}; + +template class IsRunningCondition : public Condition, public Parented { + public: + bool check(Ts... x) override { return this->parent_->is_running(); } +}; + +template class EnableModelAction : public Action { + public: + explicit EnableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + void play(Ts... x) override { this->wake_word_model_->enable(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +template class DisableModelAction : public Action { + public: + explicit DisableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + void play(Ts... x) override { this->wake_word_model_->disable(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +template class ModelIsEnabledCondition : public Condition { + public: + explicit ModelIsEnabledCondition(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + bool check(Ts... x) override { return this->wake_word_model_->is_enabled(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +} // namespace micro_wake_word +} // namespace esphome +#endif diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index dd1a8be378..f768b661c0 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -1,5 +1,4 @@ #include "micro_wake_word.h" -#include "streaming_model.h" #ifdef USE_ESP_IDF @@ -7,41 +6,57 @@ #include "esphome/core/helpers.h" #include "esphome/core/log.h" -#include -#include +#include "esphome/components/audio/audio_transfer_buffer.h" -#include -#include -#include - -#include +#ifdef USE_OTA +#include "esphome/components/ota/ota_backend.h" +#endif namespace esphome { namespace micro_wake_word { static const char *const TAG = "micro_wake_word"; -static const size_t SAMPLE_RATE_HZ = 16000; // 16 kHz -static const size_t BUFFER_LENGTH = 64; // 0.064 seconds -static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH; -static const size_t INPUT_BUFFER_SIZE = 16 * SAMPLE_RATE_HZ / 1000; // 16ms * 16kHz / 1000ms +static const ssize_t DETECTION_QUEUE_LENGTH = 5; + +static const size_t DATA_TIMEOUT_MS = 50; + +static const uint32_t RING_BUFFER_DURATION_MS = 120; +static const uint32_t RING_BUFFER_SAMPLES = RING_BUFFER_DURATION_MS * (AUDIO_SAMPLE_FREQUENCY / 1000); +static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t); + +static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072; +static const UBaseType_t INFERENCE_TASK_PRIORITY = 3; + +enum EventGroupBits : uint32_t { + COMMAND_STOP = (1 << 0), // Signals the inference task should stop + + TASK_STARTING = (1 << 3), + TASK_RUNNING = (1 << 4), + TASK_STOPPING = (1 << 5), + TASK_STOPPED = (1 << 6), + + ERROR_MEMORY = (1 << 9), + ERROR_INFERENCE = (1 << 10), + + WARNING_FULL_RING_BUFFER = (1 << 13), + + ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE, + ALL_BITS = 0xfffff, // 24 total bits available in an event group +}; float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } static const LogString *micro_wake_word_state_to_string(State state) { switch (state) { - case State::IDLE: - return LOG_STR("IDLE"); - case State::START_MICROPHONE: - return LOG_STR("START_MICROPHONE"); - case State::STARTING_MICROPHONE: - return LOG_STR("STARTING_MICROPHONE"); + case State::STARTING: + return LOG_STR("STARTING"); case State::DETECTING_WAKE_WORD: return LOG_STR("DETECTING_WAKE_WORD"); - case State::STOP_MICROPHONE: - return LOG_STR("STOP_MICROPHONE"); - case State::STOPPING_MICROPHONE: - return LOG_STR("STOPPING_MICROPHONE"); + case State::STOPPING: + return LOG_STR("STOPPING"); + case State::STOPPED: + return LOG_STR("STOPPED"); default: return LOG_STR("UNKNOWN"); } @@ -51,7 +66,7 @@ void MicroWakeWord::dump_config() { ESP_LOGCONFIG(TAG, "microWakeWord:"); ESP_LOGCONFIG(TAG, " models:"); for (auto &model : this->wake_word_models_) { - model.log_model_config(); + model->log_model_config(); } #ifdef USE_MICRO_WAKE_WORD_VAD this->vad_model_->log_model_config(); @@ -61,108 +76,266 @@ void MicroWakeWord::dump_config() { void MicroWakeWord::setup() { ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); + this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; + this->frontend_config_.window.step_size_ms = this->features_step_size_; + this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; + this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT; + this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT; + this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS; + this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING; + this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING; + this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING; + this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN; + this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH; + this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET; + this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS; + this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG; + this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT; + + this->event_group_ = xEventGroupCreate(); + if (this->event_group_ == nullptr) { + ESP_LOGE(TAG, "Failed to create event group"); + this->mark_failed(); + return; + } + + this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent)); + if (this->detection_queue_ == nullptr) { + ESP_LOGE(TAG, "Failed to create detection event queue"); + this->mark_failed(); + return; + } + this->microphone_source_->add_data_callback([this](const std::vector &data) { - if (this->state_ != State::DETECTING_WAKE_WORD) { + if (this->state_ == State::STOPPED) { return; } - std::shared_ptr temp_ring_buffer = this->ring_buffer_; - if (this->ring_buffer_.use_count() == 2) { - // mWW still owns the ring buffer and temp_ring_buffer does as well, proceed to copy audio into ring buffer - + std::shared_ptr temp_ring_buffer = this->ring_buffer_.lock(); + if (this->ring_buffer_.use_count() > 1) { size_t bytes_free = temp_ring_buffer->free(); if (bytes_free < data.size()) { - ESP_LOGW( - TAG, - "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " - "Resetting the ring buffer. Wake word detection accuracy will be reduced.", - bytes_free, data.size()); - + xEventGroupSetBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); temp_ring_buffer->reset(); } temp_ring_buffer->write((void *) data.data(), data.size()); } }); - if (!this->register_streaming_ops_(this->streaming_op_resolver_)) { - this->mark_failed(); - return; +#ifdef USE_OTA + ota::get_global_ota_callback()->add_on_state_callback( + [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { + if (state == ota::OTA_STARTED) { + this->suspend_task_(); + } else if (state == ota::OTA_ERROR) { + this->resume_task_(); + } + }); +#endif + ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); +} + +void MicroWakeWord::inference_task(void *params) { + MicroWakeWord *this_mww = (MicroWakeWord *) params; + + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING); + + { // Ensures any C++ objects fall out of scope to deallocate before deleting the task + const size_t new_samples_to_read = this_mww->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000); + std::unique_ptr audio_buffer; + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]; + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + // Allocate audio transfer buffer + audio_buffer = audio::AudioSourceTransferBuffer::create(new_samples_to_read * sizeof(int16_t)); + + if (audio_buffer == nullptr) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); + } + } + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + // Allocate ring buffer + std::shared_ptr temp_ring_buffer = RingBuffer::create(RING_BUFFER_SIZE); + if (temp_ring_buffer.use_count() == 0) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); + } + audio_buffer->set_source(temp_ring_buffer); + this_mww->ring_buffer_ = temp_ring_buffer; + } + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + this_mww->microphone_source_->start(); + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING); + + while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) { + audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS)); + + if (audio_buffer->available() < new_samples_to_read * sizeof(int16_t)) { + // Insufficient data to generate new spectrogram features, read more next iteration + continue; + } + + // Generate new spectrogram features + size_t processed_samples = this_mww->generate_features_( + (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer); + audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t)); + + // Run inference using the new spectorgram features + if (!this_mww->update_model_probabilities_(features_buffer)) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE); + break; + } + + // Process each model's probabilities and possibly send a Detection Event to the queue + this_mww->process_probabilities_(); + } + } } - ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING); - this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; - this->frontend_config_.window.step_size_ms = this->features_step_size_; - this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; - this->frontend_config_.filterbank.lower_band_limit = 125.0; - this->frontend_config_.filterbank.upper_band_limit = 7500.0; - this->frontend_config_.noise_reduction.smoothing_bits = 10; - this->frontend_config_.noise_reduction.even_smoothing = 0.025; - this->frontend_config_.noise_reduction.odd_smoothing = 0.06; - this->frontend_config_.noise_reduction.min_signal_remaining = 0.05; - this->frontend_config_.pcan_gain_control.enable_pcan = 1; - this->frontend_config_.pcan_gain_control.strength = 0.95; - this->frontend_config_.pcan_gain_control.offset = 80.0; - this->frontend_config_.pcan_gain_control.gain_bits = 21; - this->frontend_config_.log_scale.enable_log = 1; - this->frontend_config_.log_scale.scale_shift = 6; + this_mww->unload_models_(); + this_mww->microphone_source_->stop(); + FrontendFreeStateContents(&this_mww->frontend_state_); + + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED); + while (true) { + // Continuously delay until the main loop deletes the task + delay(10); + } } -void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff, - size_t sliding_window_average_size, const std::string &wake_word, - size_t tensor_arena_size) { - this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word, - tensor_arena_size); +std::vector MicroWakeWord::get_wake_words() { + std::vector external_wake_word_models; + for (auto *model : this->wake_word_models_) { + if (!model->get_internal_only()) { + external_wake_word_models.push_back(model); + } + } + return external_wake_word_models; } +void MicroWakeWord::add_wake_word_model(WakeWordModel *model) { this->wake_word_models_.push_back(model); } + #ifdef USE_MICRO_WAKE_WORD_VAD -void MicroWakeWord::add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, +void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size) { this->vad_model_ = make_unique(model_start, probability_cutoff, sliding_window_size, tensor_arena_size); } #endif +void MicroWakeWord::suspend_task_() { + if (this->inference_task_handle_ != nullptr) { + vTaskSuspend(this->inference_task_handle_); + } +} + +void MicroWakeWord::resume_task_() { + if (this->inference_task_handle_ != nullptr) { + vTaskResume(this->inference_task_handle_); + } +} + void MicroWakeWord::loop() { + uint32_t event_group_bits = xEventGroupGetBits(this->event_group_); + + if (event_group_bits & EventGroupBits::ERROR_MEMORY) { + xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY); + ESP_LOGE(TAG, "Encountered an error allocating buffers"); + } + + if (event_group_bits & EventGroupBits::ERROR_INFERENCE) { + xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE); + ESP_LOGE(TAG, "Encountered an error while performing an inference"); + } + + if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) { + xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); + ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake " + "word detection accuracy will temporarily be reduced."); + } + + if (event_group_bits & EventGroupBits::TASK_STARTING) { + ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers"); + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING); + } + + if (event_group_bits & EventGroupBits::TASK_RUNNING) { + ESP_LOGD(TAG, "Inference task is running"); + + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING); + this->set_state_(State::DETECTING_WAKE_WORD); + } + + if (event_group_bits & EventGroupBits::TASK_STOPPING) { + ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers"); + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING); + } + + if ((event_group_bits & EventGroupBits::TASK_STOPPED)) { + ESP_LOGD(TAG, "Inference task is finished, freeing task resources"); + vTaskDelete(this->inference_task_handle_); + this->inference_task_handle_ = nullptr; + xEventGroupClearBits(this->event_group_, ALL_BITS); + xQueueReset(this->detection_queue_); + this->set_state_(State::STOPPED); + } + + if ((this->pending_start_) && (this->state_ == State::STOPPED)) { + this->set_state_(State::STARTING); + this->pending_start_ = false; + } + + if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) { + this->set_state_(State::STOPPING); + this->pending_stop_ = false; + } + switch (this->state_) { - case State::IDLE: - break; - case State::START_MICROPHONE: - ESP_LOGD(TAG, "Starting Microphone"); - this->microphone_source_->start(); - this->set_state_(State::STARTING_MICROPHONE); - break; - case State::STARTING_MICROPHONE: - if (this->microphone_source_->is_running()) { - this->set_state_(State::DETECTING_WAKE_WORD); - } - break; - case State::DETECTING_WAKE_WORD: - while (this->has_enough_samples_()) { - this->update_model_probabilities_(); - if (this->detect_wake_words_()) { - ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str()); - this->detected_ = true; - this->set_state_(State::STOP_MICROPHONE); + case State::STARTING: + if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) { + // Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it + // uses floating point operations. + if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { + this->status_momentary_error( + "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000); + return; + } + + xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this, + INFERENCE_TASK_PRIORITY, &this->inference_task_handle_); + + if (this->inference_task_handle_ == nullptr) { + FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state + this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000); } } break; - case State::STOP_MICROPHONE: - ESP_LOGD(TAG, "Stopping Microphone"); - this->microphone_source_->stop(); - this->set_state_(State::STOPPING_MICROPHONE); - this->unload_models_(); - this->deallocate_buffers_(); - break; - case State::STOPPING_MICROPHONE: - if (this->microphone_source_->is_stopped()) { - this->set_state_(State::IDLE); - if (this->detected_) { - this->wake_word_detected_trigger_->trigger(this->detected_wake_word_); - this->detected_ = false; - this->detected_wake_word_ = ""; + case State::DETECTING_WAKE_WORD: { + DetectionEvent detection_event; + while (xQueueReceive(this->detection_queue_, &detection_event, 0)) { + if (detection_event.blocked_by_vad) { + ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str()); + } else { + constexpr float uint8_to_float_divisor = + 255.0f; // Converting a quantized uint8 probability to floating point + ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f", + detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor), + (detection_event.max_probability / uint8_to_float_divisor)); + this->wake_word_detected_trigger_->trigger(*detection_event.wake_word); + if (this->stop_after_detection_) { + this->stop(); + } } } break; + } + case State::STOPPING: + xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP); + break; + case State::STOPPED: + break; } } @@ -177,199 +350,40 @@ void MicroWakeWord::start() { return; } - if (this->state_ != State::IDLE) { - ESP_LOGW(TAG, "Wake word is already running"); + if (this->is_running()) { + ESP_LOGW(TAG, "Wake word detection is already running"); return; } - if (!this->load_models_() || !this->allocate_buffers_()) { - ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers"); - this->status_set_error(); - } else { - this->status_clear_error(); - } + ESP_LOGD(TAG, "Starting wake word detection"); - if (this->status_has_error()) { - ESP_LOGW(TAG, "Wake word component has an error. Please check logs"); - return; - } - - this->reset_states_(); - this->set_state_(State::START_MICROPHONE); + this->pending_start_ = true; + this->pending_stop_ = false; } void MicroWakeWord::stop() { - if (this->state_ == State::IDLE) { - ESP_LOGW(TAG, "Wake word is already stopped"); + if (this->state_ == STOPPED) return; - } - if (this->state_ == State::STOPPING_MICROPHONE) { - ESP_LOGW(TAG, "Wake word is already stopping"); - return; - } - this->set_state_(State::STOP_MICROPHONE); + + ESP_LOGD(TAG, "Stopping wake word detection"); + + this->pending_start_ = false; + this->pending_stop_ = true; } void MicroWakeWord::set_state_(State state) { - ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), - LOG_STR_ARG(micro_wake_word_state_to_string(state))); - this->state_ = state; + if (this->state_ != state) { + ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), + LOG_STR_ARG(micro_wake_word_state_to_string(state))); + this->state_ = state; + } } -bool MicroWakeWord::allocate_buffers_() { - ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - - if (this->input_buffer_ == nullptr) { - this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t)); - if (this->input_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate input buffer"); - return false; - } - } - - if (this->preprocessor_audio_buffer_ == nullptr) { - this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(this->new_samples_to_get_()); - if (this->preprocessor_audio_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer."); - return false; - } - } - - if (this->ring_buffer_.use_count() == 0) { - this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); - if (this->ring_buffer_.use_count() == 0) { - ESP_LOGE(TAG, "Could not allocate ring buffer"); - return false; - } - } - - return true; -} - -void MicroWakeWord::deallocate_buffers_() { - ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - if (this->input_buffer_ != nullptr) { - audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - this->input_buffer_ = nullptr; - } - - if (this->preprocessor_audio_buffer_ != nullptr) { - audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_()); - this->preprocessor_audio_buffer_ = nullptr; - } - - this->ring_buffer_.reset(); -} - -bool MicroWakeWord::load_models_() { - // Setup preprocesor feature generator - if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { - ESP_LOGD(TAG, "Failed to populate frontend state"); - FrontendFreeStateContents(&this->frontend_state_); - return false; - } - - // Setup streaming models - for (auto &model : this->wake_word_models_) { - if (!model.load_model(this->streaming_op_resolver_)) { - ESP_LOGE(TAG, "Failed to initialize a wake word model."); - return false; - } - } -#ifdef USE_MICRO_WAKE_WORD_VAD - if (!this->vad_model_->load_model(this->streaming_op_resolver_)) { - ESP_LOGE(TAG, "Failed to initialize VAD model."); - return false; - } -#endif - - return true; -} - -void MicroWakeWord::unload_models_() { - FrontendFreeStateContents(&this->frontend_state_); - - for (auto &model : this->wake_word_models_) { - model.unload_model(); - } -#ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->unload_model(); -#endif -} - -void MicroWakeWord::update_model_probabilities_() { - int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]; - - if (!this->generate_features_for_window_(audio_features)) { - return; - } - - // Increase the counter since the last positive detection - this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); - - for (auto &model : this->wake_word_models_) { - // Perform inference - model.perform_streaming_inference(audio_features); - } -#ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->perform_streaming_inference(audio_features); -#endif -} - -bool MicroWakeWord::detect_wake_words_() { - // Verify we have processed samples since the last positive detection - if (this->ignore_windows_ < 0) { - return false; - } - -#ifdef USE_MICRO_WAKE_WORD_VAD - bool vad_state = this->vad_model_->determine_detected(); -#endif - - for (auto &model : this->wake_word_models_) { - if (model.determine_detected()) { -#ifdef USE_MICRO_WAKE_WORD_VAD - if (vad_state) { -#endif - this->detected_wake_word_ = model.get_wake_word(); - return true; -#ifdef USE_MICRO_WAKE_WORD_VAD - } else { - ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str()); - } -#endif - } - } - - return false; -} - -bool MicroWakeWord::has_enough_samples_() { - return this->ring_buffer_->available() >= - (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)) * sizeof(int16_t); -} - -bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) { - // Ensure we have enough new audio samples in the ring buffer for a full window - if (!this->has_enough_samples_()) { - return false; - } - - size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_), - this->new_samples_to_get_() * sizeof(int16_t), pdMS_TO_TICKS(200)); - - if (bytes_read == 0) { - ESP_LOGE(TAG, "Could not read data from Ring Buffer"); - } else if (bytes_read < this->new_samples_to_get_() * sizeof(int16_t)) { - ESP_LOGD(TAG, "Partial Read of Data by Model"); - ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read, - (int) (this->new_samples_to_get_() * sizeof(int16_t))); - return false; - } - - size_t num_samples_read; - struct FrontendOutput frontend_output = FrontendProcessSamples( - &this->frontend_state_, this->preprocessor_audio_buffer_, this->new_samples_to_get_(), &num_samples_read); +size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available, + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) { + size_t processed_samples = 0; + struct FrontendOutput frontend_output = + FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples); for (size_t i = 0; i < frontend_output.size; ++i) { // These scaling values are set to match the TFLite audio frontend int8 output. @@ -379,8 +393,8 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F // for historical reasons, to match up with the output of other feature // generators. // The process is then further complicated when we quantize the model. This - // means we have to scale the 0.0 to 26.0 real values to the -128 to 127 - // signed integer numbers. + // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN) + // to 127 (INT8_MAX) signed integer numbers. // All this means that to get matching values from our integer feature // output into the tensor input, we have to perform: // input = (((feature / 25.6) / 26.0) * 256) - 128 @@ -389,74 +403,63 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F constexpr int32_t value_scale = 256; constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; - value -= 128; - if (value < -128) { - value = -128; - } - if (value > 127) { - value = 127; - } - features[i] = value; + + value -= INT8_MIN; + features_buffer[i] = clamp(value, INT8_MIN, INT8_MAX); } - return true; + return processed_samples; } -void MicroWakeWord::reset_states_() { - ESP_LOGD(TAG, "Resetting buffers and probabilities"); - this->ring_buffer_->reset(); - this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; +void MicroWakeWord::process_probabilities_() { +#ifdef USE_MICRO_WAKE_WORD_VAD + DetectionEvent vad_state = this->vad_model_->determine_detected(); + + this->vad_state_ = vad_state.detected; // atomic write, so thread safe +#endif + for (auto &model : this->wake_word_models_) { - model.reset_probabilities(); + if (model->get_unprocessed_probability_status()) { + // Only detect wake words if there is a new probability since the last check + DetectionEvent wake_word_state = model->determine_detected(); + if (wake_word_state.detected) { +#ifdef USE_MICRO_WAKE_WORD_VAD + if (vad_state.detected) { +#endif + xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); + model->reset_probabilities(); +#ifdef USE_MICRO_WAKE_WORD_VAD + } else { + wake_word_state.blocked_by_vad = true; + xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); + } +#endif + } + } + } +} + +void MicroWakeWord::unload_models_() { + for (auto &model : this->wake_word_models_) { + model->unload_model(); } #ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->reset_probabilities(); + this->vad_model_->unload_model(); #endif } -bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { - if (op_resolver.AddCallOnce() != kTfLiteOk) - return false; - if (op_resolver.AddVarHandle() != kTfLiteOk) - return false; - if (op_resolver.AddReshape() != kTfLiteOk) - return false; - if (op_resolver.AddReadVariable() != kTfLiteOk) - return false; - if (op_resolver.AddStridedSlice() != kTfLiteOk) - return false; - if (op_resolver.AddConcatenation() != kTfLiteOk) - return false; - if (op_resolver.AddAssignVariable() != kTfLiteOk) - return false; - if (op_resolver.AddConv2D() != kTfLiteOk) - return false; - if (op_resolver.AddMul() != kTfLiteOk) - return false; - if (op_resolver.AddAdd() != kTfLiteOk) - return false; - if (op_resolver.AddMean() != kTfLiteOk) - return false; - if (op_resolver.AddFullyConnected() != kTfLiteOk) - return false; - if (op_resolver.AddLogistic() != kTfLiteOk) - return false; - if (op_resolver.AddQuantize() != kTfLiteOk) - return false; - if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) - return false; - if (op_resolver.AddAveragePool2D() != kTfLiteOk) - return false; - if (op_resolver.AddMaxPool2D() != kTfLiteOk) - return false; - if (op_resolver.AddPad() != kTfLiteOk) - return false; - if (op_resolver.AddPack() != kTfLiteOk) - return false; - if (op_resolver.AddSplitV() != kTfLiteOk) - return false; +bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) { + bool success = true; - return true; + for (auto &model : this->wake_word_models_) { + // Perform inference + success = success & model->perform_streaming_inference(audio_features); + } +#ifdef USE_MICRO_WAKE_WORD_VAD + success = success & this->vad_model_->perform_streaming_inference(audio_features); +#endif + + return success; } } // namespace micro_wake_word diff --git a/esphome/components/micro_wake_word/micro_wake_word.h b/esphome/components/micro_wake_word/micro_wake_word.h index b06d35ca1f..626b8bffb8 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.h +++ b/esphome/components/micro_wake_word/micro_wake_word.h @@ -5,33 +5,27 @@ #include "preprocessor_settings.h" #include "streaming_model.h" +#include "esphome/components/microphone/microphone_source.h" + #include "esphome/core/automation.h" #include "esphome/core/component.h" #include "esphome/core/ring_buffer.h" -#include "esphome/components/microphone/microphone_source.h" +#include +#include #include -#include -#include -#include - namespace esphome { namespace micro_wake_word { enum State { - IDLE, - START_MICROPHONE, - STARTING_MICROPHONE, + STARTING, DETECTING_WAKE_WORD, - STOP_MICROPHONE, - STOPPING_MICROPHONE, + STOPPING, + STOPPED, }; -// The number of audio slices to process before accepting a positive detection -static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74; - class MicroWakeWord : public Component { public: void setup() override; @@ -42,7 +36,7 @@ class MicroWakeWord : public Component { void start(); void stop(); - bool is_running() const { return this->state_ != State::IDLE; } + bool is_running() const { return this->state_ != State::STOPPED; } void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; } @@ -50,118 +44,87 @@ class MicroWakeWord : public Component { this->microphone_source_ = microphone_source; } + void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; } + Trigger *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; } - void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size); + void add_wake_word_model(WakeWordModel *model); #ifdef USE_MICRO_WAKE_WORD_VAD - void add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, + void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); + + // Intended for the voice assistant component to fetch VAD status + bool get_vad_state() { return this->vad_state_; } #endif + // Intended for the voice assistant component to access which wake words are available + // Since these are pointers to the WakeWordModel objects, the voice assistant component can enable or disable them + std::vector get_wake_words(); + protected: microphone::MicrophoneSource *microphone_source_{nullptr}; Trigger *wake_word_detected_trigger_ = new Trigger(); - State state_{State::IDLE}; + State state_{State::STOPPED}; - std::shared_ptr ring_buffer_; - - std::vector wake_word_models_; + std::weak_ptr ring_buffer_; + std::vector wake_word_models_; #ifdef USE_MICRO_WAKE_WORD_VAD std::unique_ptr vad_model_; + bool vad_state_{false}; #endif - tflite::MicroMutableOpResolver<20> streaming_op_resolver_; + bool pending_start_{false}; + bool pending_stop_{false}; + + bool stop_after_detection_; + + uint8_t features_step_size_; // Audio frontend handles generating spectrogram features struct FrontendConfig frontend_config_; struct FrontendState frontend_state_; - // When the wake word detection first starts, we ignore this many audio - // feature slices before accepting a positive detection - int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; + // Handles managing the stop/state of the inference task + EventGroupHandle_t event_group_; - uint8_t features_step_size_; + // Used to send messages about the models' states to the main loop + QueueHandle_t detection_queue_; - // Stores audio read from the microphone before being added to the ring buffer. - int16_t *input_buffer_{nullptr}; - // Stores audio to be fed into the audio frontend for generating features. - int16_t *preprocessor_audio_buffer_{nullptr}; + static void inference_task(void *params); + TaskHandle_t inference_task_handle_{nullptr}; - bool detected_{false}; - std::string detected_wake_word_{""}; + /// @brief Suspends the inference task + void suspend_task_(); + /// @brief Resumes the inference task + void resume_task_(); void set_state_(State state); - /// @brief Tests if there are enough samples in the ring buffer to generate new features. - /// @return True if enough samples, false otherwise. - bool has_enough_samples_(); + /// @brief Generates spectrogram features from an input buffer of audio samples + /// @param audio_buffer (int16_t *) Buffer containing input audio samples + /// @param samples_available (size_t) Number of samples avaiable in the input buffer + /// @param features_buffer (int8_t *) Buffer to store generated features + /// @return (size_t) Number of samples processed from the input buffer + size_t generate_features_(int16_t *audio_buffer, size_t samples_available, + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]); - /// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_ - /// @return True if successful, false otherwise - bool allocate_buffers_(); + /// @brief Processes any new probabilities for each model. If any wake word is detected, it will send a DetectionEvent + /// to the detection_queue_. + void process_probabilities_(); - /// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_ - void deallocate_buffers_(); - - /// @brief Loads streaming models and prepares the feature generation frontend - /// @return True if successful, false otherwise - bool load_models_(); - - /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. Frees memory used by the feature - /// generation frontend. + /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. void unload_models_(); - /** Performs inference with each configured model - * - * If enough audio samples are available, it will generate one slice of new features. - * It then loops through and performs inference with each of the loaded models. - */ - void update_model_probabilities_(); - - /** Checks every model's recent probabilities to determine if the wake word has been predicted - * - * Verifies the models have processed enough new samples for accurate predictions. - * Sets detected_wake_word_ to the wake word, if one is detected. - * @return True if a wake word is predicted, false otherwise - */ - bool detect_wake_words_(); - - /** Generates features for a window of audio samples - * - * Reads samples from the ring buffer and feeds them into the preprocessor frontend. - * Adapted from TFLite microspeech frontend. - * @param features int8_t array to store the audio features - * @return True if successful, false otherwise. - */ - bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]); - - /// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities - void reset_states_(); - - /// @brief Returns true if successfully registered the streaming model's TensorFlow operations - bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); + /// @brief Runs an inference with each model using the new spectrogram features + /// @param audio_features (int8_t *) Buffer containing new spectrogram features + /// @return True if successful, false if any errors were encountered + bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]); inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); } }; -template class StartAction : public Action, public Parented { - public: - void play(Ts... x) override { this->parent_->start(); } -}; - -template class StopAction : public Action, public Parented { - public: - void play(Ts... x) override { this->parent_->stop(); } -}; - -template class IsRunningCondition : public Condition, public Parented { - public: - bool check(Ts... x) override { return this->parent_->is_running(); } -}; - } // namespace micro_wake_word } // namespace esphome diff --git a/esphome/components/micro_wake_word/preprocessor_settings.h b/esphome/components/micro_wake_word/preprocessor_settings.h index 03f4fb5230..025e21c5f7 100644 --- a/esphome/components/micro_wake_word/preprocessor_settings.h +++ b/esphome/components/micro_wake_word/preprocessor_settings.h @@ -7,6 +7,10 @@ namespace esphome { namespace micro_wake_word { +// Settings for controlling the spectrogram feature generation by the preprocessor. +// These must match the settings used when training a particular model. +// All microWakeWord models have been trained with these specific paramters. + // The number of features the audio preprocessor generates per slice static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40; // Duration of each slice used as input into the preprocessor @@ -14,6 +18,21 @@ static const uint8_t FEATURE_DURATION_MS = 30; // Audio sample frequency in hertz static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000; +static const float FILTERBANK_LOWER_BAND_LIMIT = 125.0; +static const float FILTERBANK_UPPER_BAND_LIMIT = 7500.0; + +static const uint8_t NOISE_REDUCTION_SMOOTHING_BITS = 10; +static const float NOISE_REDUCTION_EVEN_SMOOTHING = 0.025; +static const float NOISE_REDUCTION_ODD_SMOOTHING = 0.06; +static const float NOISE_REDUCTION_MIN_SIGNAL_REMAINING = 0.05; + +static const bool PCAN_GAIN_CONTROL_ENABLE_PCAN = true; +static const float PCAN_GAIN_CONTROL_STRENGTH = 0.95; +static const float PCAN_GAIN_CONTROL_OFFSET = 80.0; +static const uint8_t PCAN_GAIN_CONTROL_GAIN_BITS = 21; + +static const bool LOG_SCALE_ENABLE_LOG = true; +static const uint8_t LOG_SCALE_SCALE_SHIFT = 6; } // namespace micro_wake_word } // namespace esphome diff --git a/esphome/components/micro_wake_word/streaming_model.cpp b/esphome/components/micro_wake_word/streaming_model.cpp index d0d2e2df05..6512c0f569 100644 --- a/esphome/components/micro_wake_word/streaming_model.cpp +++ b/esphome/components/micro_wake_word/streaming_model.cpp @@ -1,8 +1,7 @@ -#ifdef USE_ESP_IDF - #include "streaming_model.h" -#include "esphome/core/hal.h" +#ifdef USE_ESP_IDF + #include "esphome/core/helpers.h" #include "esphome/core/log.h" @@ -13,18 +12,18 @@ namespace micro_wake_word { void WakeWordModel::log_model_config() { ESP_LOGCONFIG(TAG, " - Wake Word: %s", this->wake_word_.c_str()); - ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_); + ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_); } void VADModel::log_model_config() { ESP_LOGCONFIG(TAG, " - VAD Model"); - ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_); + ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_); } -bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) { - ExternalRAMAllocator arena_allocator(ExternalRAMAllocator::ALLOW_FAILURE); +bool StreamingModel::load_model_() { + RAMAllocator arena_allocator(RAMAllocator::ALLOW_FAILURE); if (this->tensor_arena_ == nullptr) { this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); @@ -51,8 +50,9 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) } if (this->interpreter_ == nullptr) { - this->interpreter_ = make_unique( - tflite::GetModel(this->model_start_), op_resolver, this->tensor_arena_, this->tensor_arena_size_, this->mrv_); + this->interpreter_ = + make_unique(tflite::GetModel(this->model_start_), this->streaming_op_resolver_, + this->tensor_arena_, this->tensor_arena_size_, this->mrv_); if (this->interpreter_->AllocateTensors() != kTfLiteOk) { ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model"); return false; @@ -84,34 +84,55 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) } } + this->loaded_ = true; + this->reset_probabilities(); return true; } void StreamingModel::unload_model() { this->interpreter_.reset(); - ExternalRAMAllocator arena_allocator(ExternalRAMAllocator::ALLOW_FAILURE); + RAMAllocator arena_allocator(RAMAllocator::ALLOW_FAILURE); - arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); - this->tensor_arena_ = nullptr; - arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); - this->var_arena_ = nullptr; + if (this->tensor_arena_ != nullptr) { + arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); + this->tensor_arena_ = nullptr; + } + + if (this->var_arena_ != nullptr) { + arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); + this->var_arena_ = nullptr; + } + + this->loaded_ = false; } bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) { - if (this->interpreter_ != nullptr) { + if (this->enabled_ && !this->loaded_) { + // Model is enabled but isn't loaded + if (!this->load_model_()) { + return false; + } + } + + if (!this->enabled_ && this->loaded_) { + // Model is disabled but still loaded + this->unload_model(); + return true; + } + + if (this->loaded_) { TfLiteTensor *input = this->interpreter_->input(0); + uint8_t stride = this->interpreter_->input(0)->dims->data[1]; + this->current_stride_step_ = this->current_stride_step_ % stride; + std::memmove( (int8_t *) (tflite::GetTensorData(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_, features, PREPROCESSOR_FEATURE_SIZE); ++this->current_stride_step_; - uint8_t stride = this->interpreter_->input(0)->dims->data[1]; - if (this->current_stride_step_ >= stride) { - this->current_stride_step_ = 0; - TfLiteStatus invoke_status = this->interpreter_->Invoke(); if (invoke_status != kTfLiteOk) { ESP_LOGW(TAG, "Streaming interpreter invoke failed"); @@ -124,65 +145,159 @@ bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCES if (this->last_n_index_ == this->sliding_window_size_) this->last_n_index_ = 0; this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability; + this->unprocessed_probability_status_ = true; } - return true; + this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); } - ESP_LOGE(TAG, "Streaming interpreter is not initialized."); - return false; + return true; } void StreamingModel::reset_probabilities() { for (auto &prob : this->recent_streaming_probabilities_) { prob = 0; } + this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; } -WakeWordModel::WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size) { +WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, + size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, + bool default_enabled, bool internal_only) { + this->id_ = id; this->model_start_ = model_start; this->probability_cutoff_ = probability_cutoff; this->sliding_window_size_ = sliding_window_average_size; this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0); this->wake_word_ = wake_word; this->tensor_arena_size_ = tensor_arena_size; + this->register_streaming_ops_(this->streaming_op_resolver_); + this->current_stride_step_ = 0; + this->internal_only_ = internal_only; + + this->pref_ = global_preferences->make_preference(fnv1_hash(id)); + bool enabled; + if (this->pref_.load(&enabled)) { + // Use the enabled state loaded from flash + this->enabled_ = enabled; + } else { + // If no state saved, then use the default + this->enabled_ = default_enabled; + } }; -bool WakeWordModel::determine_detected() { +void WakeWordModel::enable() { + this->enabled_ = true; + if (!this->internal_only_) { + this->pref_.save(&this->enabled_); + } +} + +void WakeWordModel::disable() { + this->enabled_ = false; + if (!this->internal_only_) { + this->pref_.save(&this->enabled_); + } +} + +DetectionEvent WakeWordModel::determine_detected() { + DetectionEvent detection_event; + detection_event.wake_word = &this->wake_word_; + detection_event.max_probability = 0; + detection_event.average_probability = 0; + + if ((this->ignore_windows_ < 0) || !this->enabled_) { + detection_event.detected = false; + return detection_event; + } + uint32_t sum = 0; for (auto &prob : this->recent_streaming_probabilities_) { + detection_event.max_probability = std::max(detection_event.max_probability, prob); sum += prob; } - float sliding_window_average = static_cast(sum) / static_cast(255 * this->sliding_window_size_); + detection_event.average_probability = sum / this->sliding_window_size_; + detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_; - // Detect the wake word if the sliding window average is above the cutoff - if (sliding_window_average > this->probability_cutoff_) { - ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f", - this->wake_word_.c_str(), sliding_window_average, - this->recent_streaming_probabilities_[this->last_n_index_] / (255.0)); - return true; - } - return false; + this->unprocessed_probability_status_ = false; + return detection_event; } -VADModel::VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, +VADModel::VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size) { this->model_start_ = model_start; this->probability_cutoff_ = probability_cutoff; this->sliding_window_size_ = sliding_window_size; this->recent_streaming_probabilities_.resize(sliding_window_size, 0); this->tensor_arena_size_ = tensor_arena_size; -}; + this->register_streaming_ops_(this->streaming_op_resolver_); +} + +DetectionEvent VADModel::determine_detected() { + DetectionEvent detection_event; + detection_event.max_probability = 0; + detection_event.average_probability = 0; + + if (!this->enabled_) { + // We disabled the VAD model for some reason... so we shouldn't block wake words from being detected + detection_event.detected = true; + return detection_event; + } -bool VADModel::determine_detected() { uint32_t sum = 0; for (auto &prob : this->recent_streaming_probabilities_) { + detection_event.max_probability = std::max(detection_event.max_probability, prob); sum += prob; } - float sliding_window_average = static_cast(sum) / static_cast(255 * this->sliding_window_size_); + detection_event.average_probability = sum / this->sliding_window_size_; + detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_); - return sliding_window_average > this->probability_cutoff_; + return detection_event; +} + +bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { + if (op_resolver.AddCallOnce() != kTfLiteOk) + return false; + if (op_resolver.AddVarHandle() != kTfLiteOk) + return false; + if (op_resolver.AddReshape() != kTfLiteOk) + return false; + if (op_resolver.AddReadVariable() != kTfLiteOk) + return false; + if (op_resolver.AddStridedSlice() != kTfLiteOk) + return false; + if (op_resolver.AddConcatenation() != kTfLiteOk) + return false; + if (op_resolver.AddAssignVariable() != kTfLiteOk) + return false; + if (op_resolver.AddConv2D() != kTfLiteOk) + return false; + if (op_resolver.AddMul() != kTfLiteOk) + return false; + if (op_resolver.AddAdd() != kTfLiteOk) + return false; + if (op_resolver.AddMean() != kTfLiteOk) + return false; + if (op_resolver.AddFullyConnected() != kTfLiteOk) + return false; + if (op_resolver.AddLogistic() != kTfLiteOk) + return false; + if (op_resolver.AddQuantize() != kTfLiteOk) + return false; + if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) + return false; + if (op_resolver.AddAveragePool2D() != kTfLiteOk) + return false; + if (op_resolver.AddMaxPool2D() != kTfLiteOk) + return false; + if (op_resolver.AddPad() != kTfLiteOk) + return false; + if (op_resolver.AddPack() != kTfLiteOk) + return false; + if (op_resolver.AddSplitV() != kTfLiteOk) + return false; + + return true; } } // namespace micro_wake_word diff --git a/esphome/components/micro_wake_word/streaming_model.h b/esphome/components/micro_wake_word/streaming_model.h index 0d85579f35..5bd1cf356a 100644 --- a/esphome/components/micro_wake_word/streaming_model.h +++ b/esphome/components/micro_wake_word/streaming_model.h @@ -4,6 +4,8 @@ #include "preprocessor_settings.h" +#include "esphome/core/preferences.h" + #include #include #include @@ -11,30 +13,63 @@ namespace esphome { namespace micro_wake_word { +static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100; static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024; +struct DetectionEvent { + std::string *wake_word; + bool detected; + bool partially_detection; // Set if the most recent probability exceed the threshold, but the sliding window average + // hasn't yet + uint8_t max_probability; + uint8_t average_probability; + bool blocked_by_vad = false; +}; + class StreamingModel { public: virtual void log_model_config() = 0; - virtual bool determine_detected() = 0; + virtual DetectionEvent determine_detected() = 0; + // Performs inference on the given features. + // - If the model is enabled but not loaded, it will load it + // - If the model is disabled but loaded, it will unload it + // Returns true if sucessful or false if there is an error bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]); - /// @brief Sets all recent_streaming_probabilities to 0 + /// @brief Sets all recent_streaming_probabilities to 0 and resets the ignore window count void reset_probabilities(); - /// @brief Allocates tensor and variable arenas and sets up the model interpreter - /// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded - /// @return True if successful, false otherwise - bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver); - /// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory void unload_model(); - protected: - uint8_t current_stride_step_{0}; + /// @brief Enable the model. The next performing_streaming_inference call will load it. + virtual void enable() { this->enabled_ = true; } - float probability_cutoff_; + /// @brief Disable the model. The next performing_streaming_inference call will unload it. + virtual void disable() { this->enabled_ = false; } + + /// @brief Return true if the model is enabled. + bool is_enabled() { return this->enabled_; } + + bool get_unprocessed_probability_status() { return this->unprocessed_probability_status_; } + + protected: + /// @brief Allocates tensor and variable arenas and sets up the model interpreter + /// @return True if successful, false otherwise + bool load_model_(); + /// @brief Returns true if successfully registered the streaming model's TensorFlow operations + bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); + + tflite::MicroMutableOpResolver<20> streaming_op_resolver_; + + bool loaded_{false}; + bool enabled_{true}; + bool unprocessed_probability_status_{false}; + uint8_t current_stride_step_{0}; + int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; + + uint8_t probability_cutoff_; // Quantized probability cutoff mapping 0.0 - 1.0 to 0 - 255 size_t sliding_window_size_; size_t last_n_index_{0}; size_t tensor_arena_size_; @@ -50,32 +85,62 @@ class StreamingModel { class WakeWordModel final : public StreamingModel { public: - WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size); + /// @brief Constructs a wake word model object + /// @param id (std::string) identifier for this model + /// @param model_start (const uint8_t *) pointer to the start of the model's TFLite FlatBuffer + /// @param probability_cutoff (uint8_t) probability cutoff for acceping the wake word has been said + /// @param sliding_window_average_size (size_t) the length of the sliding window computing the mean rolling + /// probability + /// @param wake_word (std::string) Friendly name of the wake word + /// @param tensor_arena_size (size_t) Size in bytes for allocating the tensor arena + /// @param default_enabled (bool) If true, it will be enabled by default on first boot + /// @param internal_only (bool) If true, the model will not be exposed to HomeAssistant as an available model + WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, + size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, + bool default_enabled, bool internal_only); void log_model_config() override; /// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability /// cutoff /// @return True if wake word is detected, false otherwise - bool determine_detected() override; + DetectionEvent determine_detected() override; + const std::string &get_id() const { return this->id_; } const std::string &get_wake_word() const { return this->wake_word_; } + void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); } + const std::vector &get_trained_languages() const { return this->trained_languages_; } + + /// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it. + void enable() override; + + /// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it. + void disable() override; + + bool get_internal_only() { return this->internal_only_; } + protected: + std::string id_; std::string wake_word_; + std::vector trained_languages_; + + bool internal_only_; + + ESPPreferenceObject pref_; }; class VADModel final : public StreamingModel { public: - VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); + VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, + size_t tensor_arena_size); void log_model_config() override; /// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability /// cutoff /// @return True if voice activity is detected, false otherwise - bool determine_detected() override; + DetectionEvent determine_detected() override; }; } // namespace micro_wake_word diff --git a/esphome/core/defines.h b/esphome/core/defines.h index 81ff6999ba..de963313db 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -79,6 +79,7 @@ #define USE_LVGL_TEXTAREA #define USE_LVGL_TILEVIEW #define USE_LVGL_TOUCHSCREEN +#define USE_MICRO_WAKE_WORD #define USE_MD5 #define USE_MDNS #define USE_MEDIA_PLAYER diff --git a/tests/components/micro_wake_word/common.yaml b/tests/components/micro_wake_word/common.yaml index b5507397f8..c051c8dd57 100644 --- a/tests/components/micro_wake_word/common.yaml +++ b/tests/components/micro_wake_word/common.yaml @@ -14,8 +14,24 @@ micro_wake_word: microphone: echo_microphone on_wake_word_detected: - logger.log: "Wake word detected" + - micro_wake_word.stop: + - if: + condition: + - micro_wake_word.model_is_enabled: hey_jarvis_model + then: + - micro_wake_word.disable_model: hey_jarvis_model + else: + - micro_wake_word.enable_model: hey_jarvis_model + - if: + condition: + - not: + - micro_wake_word.is_running: + then: + micro_wake_word.start: + stop_after_detection: false models: - model: hey_jarvis probability_cutoff: 0.7 + id: hey_jarvis_model - model: okay_nabu sliding_window_size: 5 From bf527b033147abe29184ee6a0e78e22f4f8c005d Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Wed, 30 Apr 2025 19:45:33 -0500 Subject: [PATCH 055/102] [microphone] Bugfix: protect against starting mic if already started (#8656) --- esphome/components/microphone/microphone_source.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/esphome/components/microphone/microphone_source.cpp b/esphome/components/microphone/microphone_source.cpp index dcd3b31622..35e8d5dd4d 100644 --- a/esphome/components/microphone/microphone_source.cpp +++ b/esphome/components/microphone/microphone_source.cpp @@ -14,12 +14,16 @@ void MicrophoneSource::add_data_callback(std::functionenabled_ = true; - this->mic_->start(); + if (!this->enabled_) { + this->enabled_ = true; + this->mic_->start(); + } } void MicrophoneSource::stop() { - this->enabled_ = false; - this->mic_->stop(); + if (this->enabled_) { + this->enabled_ = false; + this->mic_->stop(); + } } std::vector MicrophoneSource::process_audio_(const std::vector &data) { From d2b4dba51f8eb4e75bf9459892d6ddb5e0d74a9b Mon Sep 17 00:00:00 2001 From: Ben Winslow Date: Wed, 30 Apr 2025 20:55:36 -0400 Subject: [PATCH 056/102] Fix typo preventing tt21100 from autosetting the touchscreen res. (#8662) --- esphome/components/tt21100/touchscreen/tt21100.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esphome/components/tt21100/touchscreen/tt21100.cpp b/esphome/components/tt21100/touchscreen/tt21100.cpp index 2bea72a59e..d18ce835c1 100644 --- a/esphome/components/tt21100/touchscreen/tt21100.cpp +++ b/esphome/components/tt21100/touchscreen/tt21100.cpp @@ -68,7 +68,7 @@ void TT21100Touchscreen::setup() { this->x_raw_max_ = this->display_->get_native_width(); } if (this->y_raw_max_ == this->y_raw_min_) { - this->x_raw_max_ = this->display_->get_native_height(); + this->y_raw_max_ = this->display_->get_native_height(); } } From e8a3de26424e2431747fea387ff992009e76b962 Mon Sep 17 00:00:00 2001 From: "Andrew J.Swan" Date: Thu, 1 May 2025 04:07:55 +0300 Subject: [PATCH 057/102] Bump FastLed version to 3.9.16 (#8402) --- esphome/components/fastled_base/__init__.py | 5 +---- esphome/components/fastled_base/fastled_light.cpp | 2 +- platformio.ini | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/esphome/components/fastled_base/__init__.py b/esphome/components/fastled_base/__init__.py index 1e70e14f10..11e8423258 100644 --- a/esphome/components/fastled_base/__init__.py +++ b/esphome/components/fastled_base/__init__.py @@ -40,9 +40,6 @@ async def new_fastled_light(config): if CONF_MAX_REFRESH_RATE in config: cg.add(var.set_max_refresh_rate(config[CONF_MAX_REFRESH_RATE])) + cg.add_library("fastled/FastLED", "3.9.16") await light.register_light(var, config) - # https://github.com/FastLED/FastLED/blob/master/library.json - # 3.3.3 has an issue on ESP32 with RMT and fastled_clockless: - # https://github.com/esphome/issues/issues/1375 - cg.add_library("fastled/FastLED", "3.3.2") return var diff --git a/esphome/components/fastled_base/fastled_light.cpp b/esphome/components/fastled_base/fastled_light.cpp index 486364d0c0..3ecdee61b1 100644 --- a/esphome/components/fastled_base/fastled_light.cpp +++ b/esphome/components/fastled_base/fastled_light.cpp @@ -34,7 +34,7 @@ void FastLEDLightOutput::write_state(light::LightState *state) { this->mark_shown_(); ESP_LOGVV(TAG, "Writing RGB values to bus..."); - this->controller_->showLeds(); + this->controller_->showLeds(this->state_parent_->current_values.get_brightness() * 255); } } // namespace fastled_base diff --git a/platformio.ini b/platformio.ini index 656202e372..a2d5d27faf 100644 --- a/platformio.ini +++ b/platformio.ini @@ -63,7 +63,7 @@ lib_deps = Wire ; i2c (Arduino built-int) heman/AsyncMqttClient-esphome@1.0.0 ; mqtt esphome/ESPAsyncWebServer-esphome@3.3.0 ; web_server_base - fastled/FastLED@3.3.2 ; fastled_base + fastled/FastLED@3.9.16 ; fastled_base mikalhart/TinyGPSPlus@1.0.2 ; gps freekode/TM1651@1.0.1 ; tm1651 glmnet/Dsmr@0.7 ; dsmr From 9dcf295df81256b17f67e11f41cd33662a536253 Mon Sep 17 00:00:00 2001 From: Simon <965089+sarthurdev@users.noreply.github.com> Date: Thu, 1 May 2025 03:12:17 +0200 Subject: [PATCH 058/102] [gree] Add support for YAG remotes (#7418) --- esphome/components/gree/climate.py | 1 + esphome/components/gree/gree.cpp | 18 ++++++++++++++++-- esphome/components/gree/gree.h | 4 ++-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/esphome/components/gree/climate.py b/esphome/components/gree/climate.py index 75436f2cf5..389c9fb3c7 100644 --- a/esphome/components/gree/climate.py +++ b/esphome/components/gree/climate.py @@ -18,6 +18,7 @@ MODELS = { "yac": Model.GREE_YAC, "yac1fb9": Model.GREE_YAC1FB9, "yx1ff": Model.GREE_YX1FF, + "yag": Model.GREE_YAG, } CONFIG_SCHEMA = climate_ir.CLIMATE_IR_WITH_RECEIVER_SCHEMA.extend( diff --git a/esphome/components/gree/gree.cpp b/esphome/components/gree/gree.cpp index 6d179a947b..e0cacb4f1e 100644 --- a/esphome/components/gree/gree.cpp +++ b/esphome/components/gree/gree.cpp @@ -22,13 +22,21 @@ void GreeClimate::transmit_state() { remote_state[0] = this->fan_speed_() | this->operation_mode_(); remote_state[1] = this->temperature_(); - if (this->model_ == GREE_YAN || this->model_ == GREE_YX1FF) { + if (this->model_ == GREE_YAN || this->model_ == GREE_YX1FF || this->model_ == GREE_YAG) { remote_state[2] = 0x60; remote_state[3] = 0x50; remote_state[4] = this->vertical_swing_(); } - if (this->model_ == GREE_YAC) { + if (this->model_ == GREE_YAG) { + remote_state[5] = 0x40; + + if (this->vertical_swing_() == GREE_VDIR_SWING || this->horizontal_swing_() == GREE_HDIR_SWING) { + remote_state[0] |= (1 << 6); + } + } + + if (this->model_ == GREE_YAC || this->model_ == GREE_YAG) { remote_state[4] |= (this->horizontal_swing_() << 4); } @@ -57,6 +65,12 @@ void GreeClimate::transmit_state() { // Calculate the checksum if (this->model_ == GREE_YAN || this->model_ == GREE_YX1FF) { remote_state[7] = ((remote_state[0] << 4) + (remote_state[1] << 4) + 0xC0); + } else if (this->model_ == GREE_YAG) { + remote_state[7] = + ((((remote_state[0] & 0x0F) + (remote_state[1] & 0x0F) + (remote_state[2] & 0x0F) + (remote_state[3] & 0x0F) + + ((remote_state[4] & 0xF0) >> 4) + ((remote_state[5] & 0xF0) >> 4) + ((remote_state[6] & 0xF0) >> 4) + 0x0A) & + 0x0F) + << 4); } else { remote_state[7] = ((((remote_state[0] & 0x0F) + (remote_state[1] & 0x0F) + (remote_state[2] & 0x0F) + (remote_state[3] & 0x0F) + diff --git a/esphome/components/gree/gree.h b/esphome/components/gree/gree.h index 6762b41eb0..f91d78cabd 100644 --- a/esphome/components/gree/gree.h +++ b/esphome/components/gree/gree.h @@ -58,7 +58,7 @@ const uint8_t GREE_VDIR_MIDDLE = 0x04; const uint8_t GREE_VDIR_MDOWN = 0x05; const uint8_t GREE_VDIR_DOWN = 0x06; -// Only available on YAC +// Only available on YAC/YAG // Horizontal air directions. Note that these cannot be set on all heat pumps const uint8_t GREE_HDIR_AUTO = 0x00; const uint8_t GREE_HDIR_MANUAL = 0x00; @@ -78,7 +78,7 @@ const uint8_t GREE_PRESET_SLEEP = 0x01; const uint8_t GREE_PRESET_SLEEP_BIT = 0x80; // Model codes -enum Model { GREE_GENERIC, GREE_YAN, GREE_YAA, GREE_YAC, GREE_YAC1FB9, GREE_YX1FF }; +enum Model { GREE_GENERIC, GREE_YAN, GREE_YAA, GREE_YAC, GREE_YAC1FB9, GREE_YX1FF, GREE_YAG }; class GreeClimate : public climate_ir::ClimateIR { public: From 9a9b91b180374eda0082d76ac72ba1ceb914c439 Mon Sep 17 00:00:00 2001 From: Jannik <33796278+SuperPlusUser@users.noreply.github.com> Date: Thu, 1 May 2025 03:12:51 +0200 Subject: [PATCH 059/102] Fix HLW8012 sensor not returning values if change_mode_every is set to never (#8456) --- esphome/components/hlw8012/hlw8012.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esphome/components/hlw8012/hlw8012.cpp b/esphome/components/hlw8012/hlw8012.cpp index 14e83f60e1..1efc57ab66 100644 --- a/esphome/components/hlw8012/hlw8012.cpp +++ b/esphome/components/hlw8012/hlw8012.cpp @@ -69,7 +69,7 @@ void HLW8012Component::update() { float power = cf_hz * this->power_multiplier_; - if (this->change_mode_at_ != 0) { + if (this->change_mode_at_ != 0 || this->change_mode_every_ == 0) { // Only read cf1 after one cycle. Apparently it's quite unstable after being changed. if (this->current_mode_) { float current = cf1_hz * this->current_multiplier_; From b5975651650677e4b3478b57317d5288dd94b065 Mon Sep 17 00:00:00 2001 From: Pat Satyshur Date: Wed, 30 Apr 2025 20:14:29 -0500 Subject: [PATCH 060/102] Add a function to return the I2C address from an I2CDevice object (#8454) Co-authored-by: Djordje Mandic <6750655+DjordjeMandic@users.noreply.github.com> --- esphome/components/i2c/i2c.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/esphome/components/i2c/i2c.h b/esphome/components/i2c/i2c.h index 8d8e139c61..15f786245b 100644 --- a/esphome/components/i2c/i2c.h +++ b/esphome/components/i2c/i2c.h @@ -139,6 +139,10 @@ class I2CDevice { /// @param address of the device void set_i2c_address(uint8_t address) { address_ = address; } + /// @brief Returns the I2C address of the object. + /// @return the I2C address + uint8_t get_i2c_address() const { return this->address_; } + /// @brief we store the pointer to the I2CBus to use /// @param bus pointer to the I2CBus object void set_i2c_bus(I2CBus *bus) { bus_ = bus; } From 807925fd38f97fc898e9b27c1927d727e4a89f70 Mon Sep 17 00:00:00 2001 From: Anton Sergunov Date: Thu, 1 May 2025 08:03:35 +0600 Subject: [PATCH 061/102] Fix second scrolling run ussue (#8347) --- .../components/max7219digit/max7219digit.cpp | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/esphome/components/max7219digit/max7219digit.cpp b/esphome/components/max7219digit/max7219digit.cpp index ec9970d1a0..13b75ca734 100644 --- a/esphome/components/max7219digit/max7219digit.cpp +++ b/esphome/components/max7219digit/max7219digit.cpp @@ -4,6 +4,8 @@ #include "esphome/core/hal.h" #include "max7219font.h" +#include + namespace esphome { namespace max7219digit { @@ -61,45 +63,42 @@ void MAX7219Component::dump_config() { } void MAX7219Component::loop() { - uint32_t now = millis(); - + const uint32_t now = millis(); + const uint32_t millis_since_last_scroll = now - this->last_scroll_; + const size_t first_line_size = this->max_displaybuffer_[0].size(); // check if the buffer has shrunk past the current position since last update - if ((this->max_displaybuffer_[0].size() >= this->old_buffer_size_ + 3) || - (this->max_displaybuffer_[0].size() <= this->old_buffer_size_ - 3)) { + if ((first_line_size >= this->old_buffer_size_ + 3) || (first_line_size <= this->old_buffer_size_ - 3)) { + ESP_LOGV(TAG, "Buffer size changed %d to %d", this->old_buffer_size_, first_line_size); this->stepsleft_ = 0; this->display(); - this->old_buffer_size_ = this->max_displaybuffer_[0].size(); + this->old_buffer_size_ = first_line_size; } - // Reset the counter back to 0 when full string has been displayed. - if (this->stepsleft_ > this->max_displaybuffer_[0].size()) - this->stepsleft_ = 0; - - // Return if there is no need to scroll or scroll is off - if (!this->scroll_ || (this->max_displaybuffer_[0].size() <= (size_t) get_width_internal())) { + if (!this->scroll_ || (first_line_size <= (size_t) get_width_internal())) { + ESP_LOGVV(TAG, "Return if there is no need to scroll or scroll is off."); this->display(); return; } - if ((this->stepsleft_ == 0) && (now - this->last_scroll_ < this->scroll_delay_)) { + if ((this->stepsleft_ == 0) && (millis_since_last_scroll < this->scroll_delay_)) { + ESP_LOGVV(TAG, "At first step. Waiting for scroll delay"); this->display(); return; } - // Dwell time at end of string in case of stop at end if (this->scroll_mode_ == ScrollMode::STOP) { - if (this->stepsleft_ >= this->max_displaybuffer_[0].size() - (size_t) get_width_internal() + 1) { - if (now - this->last_scroll_ >= this->scroll_dwell_) { - this->stepsleft_ = 0; - this->last_scroll_ = now; - this->display(); + if (this->stepsleft_ + get_width_internal() == first_line_size + 1) { + if (millis_since_last_scroll < this->scroll_dwell_) { + ESP_LOGVV(TAG, "Dwell time at end of string in case of stop at end. Step %d, since last scroll %d, dwell %d.", + this->stepsleft_, millis_since_last_scroll, this->scroll_dwell_); + return; } - return; + ESP_LOGV(TAG, "Dwell time passed. Continue scrolling."); } } - // Actual call to scroll left action - if (now - this->last_scroll_ >= this->scroll_speed_) { + if (millis_since_last_scroll >= this->scroll_speed_) { + ESP_LOGVV(TAG, "Call to scroll left action"); this->last_scroll_ = now; this->scroll_left(); this->display(); @@ -227,19 +226,20 @@ void MAX7219Component::scroll(bool on_off) { this->set_scroll(on_off); } void MAX7219Component::scroll_left() { for (int chip_line = 0; chip_line < this->num_chip_lines_; chip_line++) { + auto scroll = [&](std::vector &line, uint16_t steps) { + std::rotate(line.begin(), std::next(line.begin(), steps), line.end()); + }; if (this->update_) { this->max_displaybuffer_[chip_line].push_back(this->bckgrnd_); - for (uint16_t i = 0; i < this->stepsleft_; i++) { - this->max_displaybuffer_[chip_line].push_back(this->max_displaybuffer_[chip_line].front()); - this->max_displaybuffer_[chip_line].erase(this->max_displaybuffer_[chip_line].begin()); - } + scroll(this->max_displaybuffer_[chip_line], + (this->stepsleft_ + 1) % (this->max_displaybuffer_[chip_line].size())); } else { - this->max_displaybuffer_[chip_line].push_back(this->max_displaybuffer_[chip_line].front()); - this->max_displaybuffer_[chip_line].erase(this->max_displaybuffer_[chip_line].begin()); + scroll(this->max_displaybuffer_[chip_line], 1); } } this->update_ = false; this->stepsleft_++; + this->stepsleft_ %= this->max_displaybuffer_[0].size(); } void MAX7219Component::send_char(uint8_t chip, uint8_t data) { From 4ec8414050129693d3bb00146330ef7acccbf9a4 Mon Sep 17 00:00:00 2001 From: nworbneb Date: Thu, 1 May 2025 03:27:14 +0100 Subject: [PATCH 062/102] [alarm_control_panel] Allow sensor to trigger when alarm disarmed (#7746) --- .../template/alarm_control_panel/__init__.py | 1 + .../template_alarm_control_panel.cpp | 38 ++++++++++--------- .../template_alarm_control_panel.h | 3 +- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/esphome/components/template/alarm_control_panel/__init__.py b/esphome/components/template/alarm_control_panel/__init__.py index 0f213857dc..8b13bcd29f 100644 --- a/esphome/components/template/alarm_control_panel/__init__.py +++ b/esphome/components/template/alarm_control_panel/__init__.py @@ -51,6 +51,7 @@ ALARM_SENSOR_TYPES = { "DELAYED": AlarmSensorType.ALARM_SENSOR_TYPE_DELAYED, "INSTANT": AlarmSensorType.ALARM_SENSOR_TYPE_INSTANT, "DELAYED_FOLLOWER": AlarmSensorType.ALARM_SENSOR_TYPE_DELAYED_FOLLOWER, + "INSTANT_ALWAYS": AlarmSensorType.ALARM_SENSOR_TYPE_INSTANT_ALWAYS, } diff --git a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp index 99843417fa..bf1338edbe 100644 --- a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp +++ b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp @@ -58,6 +58,9 @@ void TemplateAlarmControlPanel::dump_config() { case ALARM_SENSOR_TYPE_DELAYED_FOLLOWER: sensor_type = "delayed_follower"; break; + case ALARM_SENSOR_TYPE_INSTANT_ALWAYS: + sensor_type = "instant_always"; + break; case ALARM_SENSOR_TYPE_DELAYED: default: sensor_type = "delayed"; @@ -145,24 +148,25 @@ void TemplateAlarmControlPanel::loop() { continue; } - // If sensor type is of type instant - if (sensor_info.second.type == ALARM_SENSOR_TYPE_INSTANT) { - instant_sensor_not_ready = true; - break; - } - // If sensor type is of type interior follower - if (sensor_info.second.type == ALARM_SENSOR_TYPE_DELAYED_FOLLOWER) { - // Look to see if we are in the pending state - if (this->current_state_ == ACP_STATE_PENDING) { - delayed_sensor_not_ready = true; - } else { + switch (sensor_info.second.type) { + case ALARM_SENSOR_TYPE_INSTANT: instant_sensor_not_ready = true; - } - } - // If sensor type is of type delayed - if (sensor_info.second.type == ALARM_SENSOR_TYPE_DELAYED) { - delayed_sensor_not_ready = true; - break; + break; + case ALARM_SENSOR_TYPE_INSTANT_ALWAYS: + instant_sensor_not_ready = true; + future_state = ACP_STATE_TRIGGERED; + break; + case ALARM_SENSOR_TYPE_DELAYED_FOLLOWER: + // Look to see if we are in the pending state + if (this->current_state_ == ACP_STATE_PENDING) { + delayed_sensor_not_ready = true; + } else { + instant_sensor_not_ready = true; + } + break; + case ALARM_SENSOR_TYPE_DELAYED: + default: + delayed_sensor_not_ready = true; } } } diff --git a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h index 9ae69a0422..b29a48dfd7 100644 --- a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h +++ b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h @@ -27,7 +27,8 @@ enum BinarySensorFlags : uint16_t { enum AlarmSensorType : uint16_t { ALARM_SENSOR_TYPE_DELAYED = 0, ALARM_SENSOR_TYPE_INSTANT, - ALARM_SENSOR_TYPE_DELAYED_FOLLOWER + ALARM_SENSOR_TYPE_DELAYED_FOLLOWER, + ALARM_SENSOR_TYPE_INSTANT_ALWAYS, }; #endif From 71f81d2f18eaef7232433c367083b0d32cfaeb94 Mon Sep 17 00:00:00 2001 From: uae007 <74835465+uae007@users.noreply.github.com> Date: Thu, 1 May 2025 02:27:59 +0000 Subject: [PATCH 063/102] Component pca9685 - phase_begin always set to zero (#8379) Co-authored-by: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> --- esphome/components/pca9685/pca9685_output.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/esphome/components/pca9685/pca9685_output.cpp b/esphome/components/pca9685/pca9685_output.cpp index d92312355a..1998f8d12f 100644 --- a/esphome/components/pca9685/pca9685_output.cpp +++ b/esphome/components/pca9685/pca9685_output.cpp @@ -101,8 +101,9 @@ void PCA9685Output::loop() { return; const uint16_t num_channels = this->max_channel_ - this->min_channel_ + 1; + const uint16_t phase_delta_begin = 4096 / num_channels; for (uint8_t channel = this->min_channel_; channel <= this->max_channel_; channel++) { - uint16_t phase_begin = uint16_t(channel - this->min_channel_) / num_channels * 4096; + uint16_t phase_begin = (channel - this->min_channel_) * phase_delta_begin; uint16_t phase_end; uint16_t amount = this->pwm_amounts_[channel]; if (amount == 0) { From 62646f5f321821f3a5d5bd578a6e82c2f0811be8 Mon Sep 17 00:00:00 2001 From: Keith Burzinski Date: Wed, 30 Apr 2025 21:30:36 -0500 Subject: [PATCH 064/102] [remote_base] Fix compile error on IDF (#8664) --- esphome/components/remote_base/beo4_protocol.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/esphome/components/remote_base/beo4_protocol.cpp b/esphome/components/remote_base/beo4_protocol.cpp index 9f8d5e72c9..8f5a642401 100644 --- a/esphome/components/remote_base/beo4_protocol.cpp +++ b/esphome/components/remote_base/beo4_protocol.cpp @@ -1,6 +1,8 @@ #include "beo4_protocol.h" #include "esphome/core/log.h" +#include + namespace esphome { namespace remote_base { @@ -81,7 +83,7 @@ optional Beo4Protocol::decode(RemoteReceiveData src) { int32_t jc = 0; uint32_t pre_bit = 0; uint32_t cnt_bit = 0; - ESP_LOGD(TAG, "Beo4: n_sym=%d ", n_sym); + ESP_LOGD(TAG, "Beo4: n_sym=%" PRId32, n_sym); for (jc = 0, ic = 0; ic < (n_sym - 1); ic += 2, jc++) { int32_t pulse_width = src[ic] - src[ic + 1]; // suppress TSOP7000 (dummy pulses) From 0f8a0af2447c56f28a0d1b471b5e20f04c91a6c1 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Thu, 1 May 2025 14:32:23 +1200 Subject: [PATCH 065/102] [defines] Fix USE_MICRO_WAKE_WORD position (#8663) --- esphome/core/defines.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/esphome/core/defines.h b/esphome/core/defines.h index de963313db..9f4099e67f 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -79,7 +79,6 @@ #define USE_LVGL_TEXTAREA #define USE_LVGL_TILEVIEW #define USE_LVGL_TOUCHSCREEN -#define USE_MICRO_WAKE_WORD #define USE_MD5 #define USE_MDNS #define USE_MEDIA_PLAYER @@ -133,7 +132,6 @@ #define USE_ESP32_BLE_SERVER #define USE_ESP32_CAMERA #define USE_IMPROV -#define USE_MICRO_WAKE_WORD_VAD #define USE_MICROPHONE #define USE_PSRAM #define USE_SOCKET_IMPL_BSD_SOCKETS @@ -151,6 +149,8 @@ #ifdef USE_ESP_IDF #define USE_ESP_IDF_VERSION_CODE VERSION_CODE(5, 1, 6) +#define USE_MICRO_WAKE_WORD +#define USE_MICRO_WAKE_WORD_VAD #endif #if defined(USE_ESP32_VARIANT_ESP32S2) From f03b42ced5470a1f8bf54a5851922af3caa15b8a Mon Sep 17 00:00:00 2001 From: lastradanet <101437425+lastradanet@users.noreply.github.com> Date: Wed, 30 Apr 2025 23:17:27 -0400 Subject: [PATCH 066/102] Adding timing budget support for vl53l0x (#7991) Co-authored-by: Brian Davis Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/components/vl53l0x/sensor.py | 11 +++++++++++ esphome/components/vl53l0x/vl53l0x_sensor.cpp | 6 +++++- esphome/components/vl53l0x/vl53l0x_sensor.h | 3 ++- tests/components/vl53l0x/common.yaml | 1 + 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/esphome/components/vl53l0x/sensor.py b/esphome/components/vl53l0x/sensor.py index 8055d5ff77..583d6ccca9 100644 --- a/esphome/components/vl53l0x/sensor.py +++ b/esphome/components/vl53l0x/sensor.py @@ -20,6 +20,7 @@ VL53L0XSensor = vl53l0x_ns.class_( CONF_SIGNAL_RATE_LIMIT = "signal_rate_limit" CONF_LONG_RANGE = "long_range" +CONF_TIMING_BUDGET = "timing_budget" def check_keys(obj): @@ -54,6 +55,13 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_LONG_RANGE, default=False): cv.boolean, cv.Optional(CONF_TIMEOUT, default="10ms"): check_timeout, cv.Optional(CONF_ENABLE_PIN): pins.gpio_output_pin_schema, + cv.Optional(CONF_TIMING_BUDGET): cv.All( + cv.positive_time_period_microseconds, + cv.Range( + min=cv.TimePeriod(microseconds=20000), + max=cv.TimePeriod(microseconds=4294967295), + ), + ), } ) .extend(cv.polling_component_schema("60s")) @@ -73,4 +81,7 @@ async def to_code(config): enable = await cg.gpio_pin_expression(config[CONF_ENABLE_PIN]) cg.add(var.set_enable_pin(enable)) + if timing_budget := config.get(CONF_TIMING_BUDGET): + cg.add(var.set_timing_budget(timing_budget)) + await i2c.register_i2c_device(var, config) diff --git a/esphome/components/vl53l0x/vl53l0x_sensor.cpp b/esphome/components/vl53l0x/vl53l0x_sensor.cpp index b07779a653..d0b7116eb8 100644 --- a/esphome/components/vl53l0x/vl53l0x_sensor.cpp +++ b/esphome/components/vl53l0x/vl53l0x_sensor.cpp @@ -28,6 +28,7 @@ void VL53L0XSensor::dump_config() { LOG_PIN(" Enable Pin: ", this->enable_pin_); } ESP_LOGCONFIG(TAG, " Timeout: %u%s", this->timeout_us_, this->timeout_us_ > 0 ? "us" : " (no timeout)"); + ESP_LOGCONFIG(TAG, " Timing Budget %uus ", this->measurement_timing_budget_us_); } void VL53L0XSensor::setup() { @@ -230,7 +231,10 @@ void VL53L0XSensor::setup() { reg(0x84) &= ~0x10; reg(0x0B) = 0x01; - measurement_timing_budget_us_ = get_measurement_timing_budget_(); + if (this->measurement_timing_budget_us_ == 0) { + this->measurement_timing_budget_us_ = this->get_measurement_timing_budget_(); + } + reg(0x01) = 0xE8; set_measurement_timing_budget_(measurement_timing_budget_us_); reg(0x01) = 0x01; diff --git a/esphome/components/vl53l0x/vl53l0x_sensor.h b/esphome/components/vl53l0x/vl53l0x_sensor.h index 971fb458bb..dd76e8e0ab 100644 --- a/esphome/components/vl53l0x/vl53l0x_sensor.h +++ b/esphome/components/vl53l0x/vl53l0x_sensor.h @@ -39,6 +39,7 @@ class VL53L0XSensor : public sensor::Sensor, public PollingComponent, public i2c void set_long_range(bool long_range) { long_range_ = long_range; } void set_timeout_us(uint32_t timeout_us) { this->timeout_us_ = timeout_us; } void set_enable_pin(GPIOPin *enable) { this->enable_pin_ = enable; } + void set_timing_budget(uint32_t timing_budget) { this->measurement_timing_budget_us_ = timing_budget; } protected: uint32_t get_measurement_timing_budget_(); @@ -59,7 +60,7 @@ class VL53L0XSensor : public sensor::Sensor, public PollingComponent, public i2c float signal_rate_limit_; bool long_range_; GPIOPin *enable_pin_{nullptr}; - uint32_t measurement_timing_budget_us_; + uint32_t measurement_timing_budget_us_{0}; bool initiated_read_{false}; bool waiting_for_interrupt_{false}; uint8_t stop_variable_; diff --git a/tests/components/vl53l0x/common.yaml b/tests/components/vl53l0x/common.yaml index 973e481b1a..8346eae854 100644 --- a/tests/components/vl53l0x/common.yaml +++ b/tests/components/vl53l0x/common.yaml @@ -10,3 +10,4 @@ sensor: enable_pin: 3 timeout: 200us update_interval: 60s + timing_budget: 30000us From 2dca2d5f859ee9d2793df56962e9614a69887307 Mon Sep 17 00:00:00 2001 From: Benjamin Pearce Date: Wed, 30 Apr 2025 23:52:51 -0400 Subject: [PATCH 067/102] Daikin IR Climate Remote Target Temperature and Fan Modes (#7946) Co-authored-by: Benjamin Pearce --- esphome/components/daikin/daikin.cpp | 24 +++++++++++++++--------- esphome/components/daikin/daikin.h | 10 +++++----- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/esphome/components/daikin/daikin.cpp b/esphome/components/daikin/daikin.cpp index bb8587fbeb..359c63aeca 100644 --- a/esphome/components/daikin/daikin.cpp +++ b/esphome/components/daikin/daikin.cpp @@ -65,7 +65,7 @@ void DaikinClimate::transmit_state() { transmit.perform(); } -uint8_t DaikinClimate::operation_mode_() { +uint8_t DaikinClimate::operation_mode_() const { uint8_t operating_mode = DAIKIN_MODE_ON; switch (this->mode) { case climate::CLIMATE_MODE_COOL: @@ -92,9 +92,12 @@ uint8_t DaikinClimate::operation_mode_() { return operating_mode; } -uint16_t DaikinClimate::fan_speed_() { +uint16_t DaikinClimate::fan_speed_() const { uint16_t fan_speed; switch (this->fan_mode.value()) { + case climate::CLIMATE_FAN_QUIET: + fan_speed = DAIKIN_FAN_SILENT << 8; + break; case climate::CLIMATE_FAN_LOW: fan_speed = DAIKIN_FAN_1 << 8; break; @@ -126,12 +129,11 @@ uint16_t DaikinClimate::fan_speed_() { return fan_speed; } -uint8_t DaikinClimate::temperature_() { +uint8_t DaikinClimate::temperature_() const { // Force special temperatures depending on the mode switch (this->mode) { case climate::CLIMATE_MODE_FAN_ONLY: return 0x32; - case climate::CLIMATE_MODE_HEAT_COOL: case climate::CLIMATE_MODE_DRY: return 0xc0; default: @@ -148,19 +150,25 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { if (frame[DAIKIN_STATE_FRAME_SIZE - 1] != checksum) return false; uint8_t mode = frame[5]; + // Temperature is given in degrees celcius * 2 + // only update for states that use the temperature + uint8_t temperature = frame[6]; if (mode & DAIKIN_MODE_ON) { switch (mode & 0xF0) { case DAIKIN_MODE_COOL: this->mode = climate::CLIMATE_MODE_COOL; + this->target_temperature = static_cast(temperature * 0.5f); break; case DAIKIN_MODE_DRY: this->mode = climate::CLIMATE_MODE_DRY; break; case DAIKIN_MODE_HEAT: this->mode = climate::CLIMATE_MODE_HEAT; + this->target_temperature = static_cast(temperature * 0.5f); break; case DAIKIN_MODE_AUTO: this->mode = climate::CLIMATE_MODE_HEAT_COOL; + this->target_temperature = static_cast(temperature * 0.5f); break; case DAIKIN_MODE_FAN: this->mode = climate::CLIMATE_MODE_FAN_ONLY; @@ -169,10 +177,6 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { } else { this->mode = climate::CLIMATE_MODE_OFF; } - uint8_t temperature = frame[6]; - if (!(temperature & 0xC0)) { - this->target_temperature = temperature >> 1; - } uint8_t fan_mode = frame[8]; uint8_t swing_mode = frame[9]; if (fan_mode & 0xF && swing_mode & 0xF) { @@ -187,7 +191,6 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { switch (fan_mode & 0xF0) { case DAIKIN_FAN_1: case DAIKIN_FAN_2: - case DAIKIN_FAN_SILENT: this->fan_mode = climate::CLIMATE_FAN_LOW; break; case DAIKIN_FAN_3: @@ -200,6 +203,9 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { case DAIKIN_FAN_AUTO: this->fan_mode = climate::CLIMATE_FAN_AUTO; break; + case DAIKIN_FAN_SILENT: + this->fan_mode = climate::CLIMATE_FAN_QUIET; + break; } this->publish_state(); return true; diff --git a/esphome/components/daikin/daikin.h b/esphome/components/daikin/daikin.h index b4ac309de9..159292cb55 100644 --- a/esphome/components/daikin/daikin.h +++ b/esphome/components/daikin/daikin.h @@ -44,17 +44,17 @@ class DaikinClimate : public climate_ir::ClimateIR { public: DaikinClimate() : climate_ir::ClimateIR(DAIKIN_TEMP_MIN, DAIKIN_TEMP_MAX, 1.0f, true, true, - {climate::CLIMATE_FAN_AUTO, climate::CLIMATE_FAN_LOW, climate::CLIMATE_FAN_MEDIUM, - climate::CLIMATE_FAN_HIGH}, + {climate::CLIMATE_FAN_QUIET, climate::CLIMATE_FAN_AUTO, climate::CLIMATE_FAN_LOW, + climate::CLIMATE_FAN_MEDIUM, climate::CLIMATE_FAN_HIGH}, {climate::CLIMATE_SWING_OFF, climate::CLIMATE_SWING_VERTICAL, climate::CLIMATE_SWING_HORIZONTAL, climate::CLIMATE_SWING_BOTH}) {} protected: // Transmit via IR the state of this climate controller. void transmit_state() override; - uint8_t operation_mode_(); - uint16_t fan_speed_(); - uint8_t temperature_(); + uint8_t operation_mode_() const; + uint16_t fan_speed_() const; + uint8_t temperature_() const; // Handle received IR Buffer bool on_receive(remote_base::RemoteReceiveData data) override; bool parse_state_frame_(const uint8_t frame[]); From 1aa2b79311a6d9000a68cd7e95dce6639eb4351e Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Thu, 1 May 2025 13:54:56 +1000 Subject: [PATCH 068/102] [i2c] Allow buffers in PSRAM (#8640) --- esphome/components/i2c/i2c_bus_esp_idf.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esphome/components/i2c/i2c_bus_esp_idf.cpp b/esphome/components/i2c/i2c_bus_esp_idf.cpp index c5d6dd8b2a..c14300f725 100644 --- a/esphome/components/i2c/i2c_bus_esp_idf.cpp +++ b/esphome/components/i2c/i2c_bus_esp_idf.cpp @@ -67,7 +67,7 @@ void IDFI2CBus::setup() { ESP_LOGV(TAG, "i2c_timeout set to %" PRIu32 " ticks (%" PRIu32 " us)", timeout_ * 80, timeout_); } } - err = i2c_driver_install(port_, I2C_MODE_MASTER, 0, 0, ESP_INTR_FLAG_IRAM); + err = i2c_driver_install(port_, I2C_MODE_MASTER, 0, 0, 0); if (err != ESP_OK) { ESP_LOGW(TAG, "i2c_driver_install failed: %s", esp_err_to_name(err)); this->mark_failed(); From f5241ff777038b19aa273c87c21fb8bc816bfa77 Mon Sep 17 00:00:00 2001 From: rwrozelle Date: Wed, 30 Apr 2025 23:55:30 -0400 Subject: [PATCH 069/102] Fix CONFIG_LWIP_TCP_RCV_SCALE and CONFIG_TCP_WND_DEFAULT (#8425) --- esphome/components/speaker/media_player/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/esphome/components/speaker/media_player/__init__.py b/esphome/components/speaker/media_player/__init__.py index 14b72cacc0..35d763b1f8 100644 --- a/esphome/components/speaker/media_player/__init__.py +++ b/esphome/components/speaker/media_player/__init__.py @@ -332,14 +332,12 @@ async def to_code(config): esp32.add_idf_sdkconfig_option("CONFIG_TCP_MSS", 1436) esp32.add_idf_sdkconfig_option("CONFIG_TCP_MSL", 60000) esp32.add_idf_sdkconfig_option("CONFIG_TCP_SND_BUF_DEFAULT", 65535) - esp32.add_idf_sdkconfig_option( - "CONFIG_TCP_WND_DEFAULT", 65535 - ) # Adjusted from referenced settings to avoid compilation error + esp32.add_idf_sdkconfig_option("CONFIG_TCP_WND_DEFAULT", 512000) esp32.add_idf_sdkconfig_option("CONFIG_TCP_RECVMBOX_SIZE", 512) esp32.add_idf_sdkconfig_option("CONFIG_TCP_QUEUE_OOSEQ", True) esp32.add_idf_sdkconfig_option("CONFIG_TCP_OVERSIZE_MSS", True) esp32.add_idf_sdkconfig_option("CONFIG_LWIP_WND_SCALE", True) - esp32.add_idf_sdkconfig_option("CONFIG_TCP_RCV_SCALE", 3) + esp32.add_idf_sdkconfig_option("CONFIG_LWIP_TCP_RCV_SCALE", 3) esp32.add_idf_sdkconfig_option("CONFIG_LWIP_TCPIP_RECVMBOX_SIZE", 512) # Allocate wifi buffers in PSRAM From 8cd62c0308e3052321f4e89d09172df14b2812d9 Mon Sep 17 00:00:00 2001 From: scaiper Date: Thu, 1 May 2025 06:57:52 +0300 Subject: [PATCH 070/102] support self-signed cert in mqtt (#8650) --- esphome/components/mqtt/__init__.py | 2 +- esphome/const.py | 1 + esphome/mqtt.py | 29 +++++++++++++++++++++-------- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/esphome/components/mqtt/__init__.py b/esphome/components/mqtt/__init__.py index 99f8ad76d8..63d8da5788 100644 --- a/esphome/components/mqtt/__init__.py +++ b/esphome/components/mqtt/__init__.py @@ -41,6 +41,7 @@ from esphome.const import ( CONF_REBOOT_TIMEOUT, CONF_RETAIN, CONF_SHUTDOWN_MESSAGE, + CONF_SKIP_CERT_CN_CHECK, CONF_SSL_FINGERPRINTS, CONF_STATE_TOPIC, CONF_SUBSCRIBE_QOS, @@ -67,7 +68,6 @@ def AUTO_LOAD(): CONF_DISCOVER_IP = "discover_ip" CONF_IDF_SEND_ASYNC = "idf_send_async" -CONF_SKIP_CERT_CN_CHECK = "skip_cert_cn_check" def validate_message_just_topic(value): diff --git a/esphome/const.py b/esphome/const.py index ffa5de2de3..21cf7367de 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -800,6 +800,7 @@ CONF_SHUTDOWN_MESSAGE = "shutdown_message" CONF_SIGNAL_STRENGTH = "signal_strength" CONF_SINGLE_LIGHT_ID = "single_light_id" CONF_SIZE = "size" +CONF_SKIP_CERT_CN_CHECK = "skip_cert_cn_check" CONF_SLEEP_DURATION = "sleep_duration" CONF_SLEEP_PIN = "sleep_pin" CONF_SLEEP_WHEN_DONE = "sleep_when_done" diff --git a/esphome/mqtt.py b/esphome/mqtt.py index 2f90c49025..2403a4a1d9 100644 --- a/esphome/mqtt.py +++ b/esphome/mqtt.py @@ -3,6 +3,7 @@ import hashlib import json import logging import ssl +import tempfile import time import paho.mqtt.client as mqtt @@ -10,6 +11,8 @@ import paho.mqtt.client as mqtt from esphome.const import ( CONF_BROKER, CONF_CERTIFICATE_AUTHORITY, + CONF_CLIENT_CERTIFICATE, + CONF_CLIENT_CERTIFICATE_KEY, CONF_DISCOVERY_PREFIX, CONF_ESPHOME, CONF_LOG_TOPIC, @@ -17,6 +20,7 @@ from esphome.const import ( CONF_NAME, CONF_PASSWORD, CONF_PORT, + CONF_SKIP_CERT_CN_CHECK, CONF_SSL_FINGERPRINTS, CONF_TOPIC, CONF_TOPIC_PREFIX, @@ -102,15 +106,24 @@ def prepare( if config[CONF_MQTT].get(CONF_SSL_FINGERPRINTS) or config[CONF_MQTT].get( CONF_CERTIFICATE_AUTHORITY ): - tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member - client.tls_set( - ca_certs=None, - certfile=None, - keyfile=None, - cert_reqs=ssl.CERT_REQUIRED, - tls_version=tls_version, - ciphers=None, + context = ssl.create_default_context( + cadata=config[CONF_MQTT].get(CONF_CERTIFICATE_AUTHORITY) ) + if config[CONF_MQTT].get(CONF_SKIP_CERT_CN_CHECK): + context.check_hostname = False + if config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE) and config[CONF_MQTT].get( + CONF_CLIENT_CERTIFICATE_KEY + ): + with ( + tempfile.NamedTemporaryFile(mode="w+") as cert_file, + tempfile.NamedTemporaryFile(mode="w+") as key_file, + ): + cert_file.write(config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE)) + cert_file.flush() + key_file.write(config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE_KEY)) + key_file.flush() + context.load_cert_chain(cert_file, key_file) + client.tls_set_context(context) try: host = str(config[CONF_MQTT][CONF_BROKER]) From 087ff865a787a1cf3595890ead4deb36ce053f62 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Thu, 1 May 2025 13:58:35 +1000 Subject: [PATCH 071/102] [binary_sensor] initial state refactor (#8648) Co-authored-by: Zsombor Welker --- .../binary_sensor/binary_sensor.cpp | 16 +++---- .../components/binary_sensor/binary_sensor.h | 2 +- esphome/components/binary_sensor/filter.cpp | 42 +++++++++---------- esphome/components/binary_sensor/filter.h | 20 ++++----- 4 files changed, 38 insertions(+), 42 deletions(-) diff --git a/esphome/components/binary_sensor/binary_sensor.cpp b/esphome/components/binary_sensor/binary_sensor.cpp index 20604a0b7e..30fbe4f0b4 100644 --- a/esphome/components/binary_sensor/binary_sensor.cpp +++ b/esphome/components/binary_sensor/binary_sensor.cpp @@ -15,21 +15,17 @@ void BinarySensor::publish_state(bool state) { if (!this->publish_dedup_.next(state)) return; if (this->filter_list_ == nullptr) { - this->send_state_internal(state, false); + this->send_state_internal(state); } else { - this->filter_list_->input(state, false); + this->filter_list_->input(state); } } void BinarySensor::publish_initial_state(bool state) { - if (!this->publish_dedup_.next(state)) - return; - if (this->filter_list_ == nullptr) { - this->send_state_internal(state, true); - } else { - this->filter_list_->input(state, true); - } + this->has_state_ = false; + this->publish_state(state); } -void BinarySensor::send_state_internal(bool state, bool is_initial) { +void BinarySensor::send_state_internal(bool state) { + bool is_initial = !this->has_state_; if (is_initial) { ESP_LOGD(TAG, "'%s': Sending initial state %s", this->get_name().c_str(), ONOFF(state)); } else { diff --git a/esphome/components/binary_sensor/binary_sensor.h b/esphome/components/binary_sensor/binary_sensor.h index 57cae9e2f5..9ba7aeeeff 100644 --- a/esphome/components/binary_sensor/binary_sensor.h +++ b/esphome/components/binary_sensor/binary_sensor.h @@ -67,7 +67,7 @@ class BinarySensor : public EntityBase, public EntityBase_DeviceClass { // ========== INTERNAL METHODS ========== // (In most use cases you won't need these) - void send_state_internal(bool state, bool is_initial); + void send_state_internal(bool state); /// Return whether this binary sensor has outputted a state. virtual bool has_state() const; diff --git a/esphome/components/binary_sensor/filter.cpp b/esphome/components/binary_sensor/filter.cpp index 8f94b108ac..fd6cc31008 100644 --- a/esphome/components/binary_sensor/filter.cpp +++ b/esphome/components/binary_sensor/filter.cpp @@ -9,37 +9,37 @@ namespace binary_sensor { static const char *const TAG = "sensor.filter"; -void Filter::output(bool value, bool is_initial) { +void Filter::output(bool value) { if (!this->dedup_.next(value)) return; if (this->next_ == nullptr) { - this->parent_->send_state_internal(value, is_initial); + this->parent_->send_state_internal(value); } else { - this->next_->input(value, is_initial); + this->next_->input(value); } } -void Filter::input(bool value, bool is_initial) { - auto b = this->new_value(value, is_initial); +void Filter::input(bool value) { + auto b = this->new_value(value); if (b.has_value()) { - this->output(*b, is_initial); + this->output(*b); } } -optional DelayedOnOffFilter::new_value(bool value, bool is_initial) { +optional DelayedOnOffFilter::new_value(bool value) { if (value) { - this->set_timeout("ON_OFF", this->on_delay_.value(), [this, is_initial]() { this->output(true, is_initial); }); + this->set_timeout("ON_OFF", this->on_delay_.value(), [this]() { this->output(true); }); } else { - this->set_timeout("ON_OFF", this->off_delay_.value(), [this, is_initial]() { this->output(false, is_initial); }); + this->set_timeout("ON_OFF", this->off_delay_.value(), [this]() { this->output(false); }); } return {}; } float DelayedOnOffFilter::get_setup_priority() const { return setup_priority::HARDWARE; } -optional DelayedOnFilter::new_value(bool value, bool is_initial) { +optional DelayedOnFilter::new_value(bool value) { if (value) { - this->set_timeout("ON", this->delay_.value(), [this, is_initial]() { this->output(true, is_initial); }); + this->set_timeout("ON", this->delay_.value(), [this]() { this->output(true); }); return {}; } else { this->cancel_timeout("ON"); @@ -49,9 +49,9 @@ optional DelayedOnFilter::new_value(bool value, bool is_initial) { float DelayedOnFilter::get_setup_priority() const { return setup_priority::HARDWARE; } -optional DelayedOffFilter::new_value(bool value, bool is_initial) { +optional DelayedOffFilter::new_value(bool value) { if (!value) { - this->set_timeout("OFF", this->delay_.value(), [this, is_initial]() { this->output(false, is_initial); }); + this->set_timeout("OFF", this->delay_.value(), [this]() { this->output(false); }); return {}; } else { this->cancel_timeout("OFF"); @@ -61,11 +61,11 @@ optional DelayedOffFilter::new_value(bool value, bool is_initial) { float DelayedOffFilter::get_setup_priority() const { return setup_priority::HARDWARE; } -optional InvertFilter::new_value(bool value, bool is_initial) { return !value; } +optional InvertFilter::new_value(bool value) { return !value; } AutorepeatFilter::AutorepeatFilter(std::vector timings) : timings_(std::move(timings)) {} -optional AutorepeatFilter::new_value(bool value, bool is_initial) { +optional AutorepeatFilter::new_value(bool value) { if (value) { // Ignore if already running if (this->active_timing_ != 0) @@ -101,7 +101,7 @@ void AutorepeatFilter::next_timing_() { void AutorepeatFilter::next_value_(bool val) { const AutorepeatFilterTiming &timing = this->timings_[this->active_timing_ - 2]; - this->output(val, false); // This is at least the second one so not initial + this->output(val); this->set_timeout("ON_OFF", val ? timing.time_on : timing.time_off, [this, val]() { this->next_value_(!val); }); } @@ -109,18 +109,18 @@ float AutorepeatFilter::get_setup_priority() const { return setup_priority::HARD LambdaFilter::LambdaFilter(std::function(bool)> f) : f_(std::move(f)) {} -optional LambdaFilter::new_value(bool value, bool is_initial) { return this->f_(value); } +optional LambdaFilter::new_value(bool value) { return this->f_(value); } -optional SettleFilter::new_value(bool value, bool is_initial) { +optional SettleFilter::new_value(bool value) { if (!this->steady_) { - this->set_timeout("SETTLE", this->delay_.value(), [this, value, is_initial]() { + this->set_timeout("SETTLE", this->delay_.value(), [this, value]() { this->steady_ = true; - this->output(value, is_initial); + this->output(value); }); return {}; } else { this->steady_ = false; - this->output(value, is_initial); + this->output(value); this->set_timeout("SETTLE", this->delay_.value(), [this]() { this->steady_ = true; }); return value; } diff --git a/esphome/components/binary_sensor/filter.h b/esphome/components/binary_sensor/filter.h index f7342db2fb..65838da49d 100644 --- a/esphome/components/binary_sensor/filter.h +++ b/esphome/components/binary_sensor/filter.h @@ -14,11 +14,11 @@ class BinarySensor; class Filter { public: - virtual optional new_value(bool value, bool is_initial) = 0; + virtual optional new_value(bool value) = 0; - void input(bool value, bool is_initial); + void input(bool value); - void output(bool value, bool is_initial); + void output(bool value); protected: friend BinarySensor; @@ -30,7 +30,7 @@ class Filter { class DelayedOnOffFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -44,7 +44,7 @@ class DelayedOnOffFilter : public Filter, public Component { class DelayedOnFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -56,7 +56,7 @@ class DelayedOnFilter : public Filter, public Component { class DelayedOffFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -68,7 +68,7 @@ class DelayedOffFilter : public Filter, public Component { class InvertFilter : public Filter { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; }; struct AutorepeatFilterTiming { @@ -86,7 +86,7 @@ class AutorepeatFilter : public Filter, public Component { public: explicit AutorepeatFilter(std::vector timings); - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -102,7 +102,7 @@ class LambdaFilter : public Filter { public: explicit LambdaFilter(std::function(bool)> f); - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; protected: std::function(bool)> f_; @@ -110,7 +110,7 @@ class LambdaFilter : public Filter { class SettleFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; From da9c755f6730f38f173fafceac6848d813cca9df Mon Sep 17 00:00:00 2001 From: Ralf Habacker Date: Thu, 1 May 2025 09:53:12 +0200 Subject: [PATCH 072/102] Add to_ntc_resistance|temperature sensor filter (esphome/feature-requests#2967) (#7898) Co-authored-by: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> --- esphome/components/sensor/__init__.py | 139 ++++++++++++++++++++++++++ esphome/components/sensor/filter.cpp | 23 +++++ esphome/components/sensor/filter.h | 22 ++++ esphome/const.py | 2 + tests/components/template/common.yaml | 10 ++ 5 files changed, 196 insertions(+) diff --git a/esphome/components/sensor/__init__.py b/esphome/components/sensor/__init__.py index 9dbad27102..5f990466c8 100644 --- a/esphome/components/sensor/__init__.py +++ b/esphome/components/sensor/__init__.py @@ -1,3 +1,4 @@ +import logging import math from esphome import automation @@ -9,6 +10,7 @@ from esphome.const import ( CONF_ACCURACY_DECIMALS, CONF_ALPHA, CONF_BELOW, + CONF_CALIBRATION, CONF_DEVICE_CLASS, CONF_ENTITY_CATEGORY, CONF_EXPIRE_AFTER, @@ -30,6 +32,7 @@ from esphome.const import ( CONF_SEND_EVERY, CONF_SEND_FIRST_AT, CONF_STATE_CLASS, + CONF_TEMPERATURE, CONF_TIMEOUT, CONF_TO, CONF_TRIGGER_ID, @@ -153,6 +156,8 @@ DEVICE_CLASSES = [ DEVICE_CLASS_WIND_SPEED, ] +_LOGGER = logging.getLogger(__name__) + sensor_ns = cg.esphome_ns.namespace("sensor") StateClasses = sensor_ns.enum("StateClass") STATE_CLASSES = { @@ -246,6 +251,8 @@ HeartbeatFilter = sensor_ns.class_("HeartbeatFilter", Filter, cg.Component) DeltaFilter = sensor_ns.class_("DeltaFilter", Filter) OrFilter = sensor_ns.class_("OrFilter", Filter) CalibrateLinearFilter = sensor_ns.class_("CalibrateLinearFilter", Filter) +ToNTCResistanceFilter = sensor_ns.class_("ToNTCResistanceFilter", Filter) +ToNTCTemperatureFilter = sensor_ns.class_("ToNTCTemperatureFilter", Filter) CalibratePolynomialFilter = sensor_ns.class_("CalibratePolynomialFilter", Filter) SensorInRangeCondition = sensor_ns.class_("SensorInRangeCondition", Filter) ClampFilter = sensor_ns.class_("ClampFilter", Filter) @@ -852,6 +859,138 @@ async def sensor_in_range_to_code(config, condition_id, template_arg, args): return var +def validate_ntc_calibration_parameter(value): + if isinstance(value, dict): + return cv.Schema( + { + cv.Required(CONF_TEMPERATURE): cv.temperature, + cv.Required(CONF_VALUE): cv.resistance, + } + )(value) + + value = cv.string(value) + parts = value.split("->") + if len(parts) != 2: + raise cv.Invalid("Calibration parameter must be of form 3000 -> 23°C") + resistance = cv.resistance(parts[0].strip()) + temperature = cv.temperature(parts[1].strip()) + return validate_ntc_calibration_parameter( + { + CONF_TEMPERATURE: temperature, + CONF_VALUE: resistance, + } + ) + + +CONF_A = "a" +CONF_B = "b" +CONF_C = "c" +ZERO_POINT = 273.15 + + +def ntc_calc_steinhart_hart(value): + r1 = value[0][CONF_VALUE] + r2 = value[1][CONF_VALUE] + r3 = value[2][CONF_VALUE] + t1 = value[0][CONF_TEMPERATURE] + ZERO_POINT + t2 = value[1][CONF_TEMPERATURE] + ZERO_POINT + t3 = value[2][CONF_TEMPERATURE] + ZERO_POINT + + l1 = math.log(r1) + l2 = math.log(r2) + l3 = math.log(r3) + + y1 = 1 / t1 + y2 = 1 / t2 + y3 = 1 / t3 + + g2 = (y2 - y1) / (l2 - l1) + g3 = (y3 - y1) / (l3 - l1) + + c = (g3 - g2) / (l3 - l2) * 1 / (l1 + l2 + l3) + b = g2 - c * (l1 * l1 + l1 * l2 + l2 * l2) + a = y1 - (b + l1 * l1 * c) * l1 + return a, b, c + + +def ntc_get_abc(value): + a = value[CONF_A] + b = value[CONF_B] + c = value[CONF_C] + return a, b, c + + +def ntc_process_calibration(value): + if isinstance(value, dict): + value = cv.Schema( + { + cv.Required(CONF_A): cv.float_, + cv.Required(CONF_B): cv.float_, + cv.Required(CONF_C): cv.float_, + } + )(value) + a, b, c = ntc_get_abc(value) + elif isinstance(value, list): + if len(value) != 3: + raise cv.Invalid( + "Steinhart–Hart Calibration must consist of exactly three values" + ) + value = cv.Schema([validate_ntc_calibration_parameter])(value) + a, b, c = ntc_calc_steinhart_hart(value) + else: + raise cv.Invalid( + f"Calibration parameter accepts either a list for steinhart-hart calibration, or mapping for b-constant calibration, not {type(value)}" + ) + _LOGGER.info("Coefficient: a:%s, b:%s, c:%s", a, b, c) + return { + CONF_A: a, + CONF_B: b, + CONF_C: c, + } + + +@FILTER_REGISTRY.register( + "to_ntc_resistance", + ToNTCResistanceFilter, + cv.All( + cv.Schema( + { + cv.Required(CONF_CALIBRATION): ntc_process_calibration, + } + ), + ), +) +async def calibrate_ntc_resistance_filter_to_code(config, filter_id): + calib = config[CONF_CALIBRATION] + return cg.new_Pvariable( + filter_id, + calib[CONF_A], + calib[CONF_B], + calib[CONF_C], + ) + + +@FILTER_REGISTRY.register( + "to_ntc_temperature", + ToNTCTemperatureFilter, + cv.All( + cv.Schema( + { + cv.Required(CONF_CALIBRATION): ntc_process_calibration, + } + ), + ), +) +async def calibrate_ntc_temperature_filter_to_code(config, filter_id): + calib = config[CONF_CALIBRATION] + return cg.new_Pvariable( + filter_id, + calib[CONF_A], + calib[CONF_B], + calib[CONF_C], + ) + + def _mean(xs): return sum(xs) / len(xs) diff --git a/esphome/components/sensor/filter.cpp b/esphome/components/sensor/filter.cpp index 0a8740dd5b..ce23c1f800 100644 --- a/esphome/components/sensor/filter.cpp +++ b/esphome/components/sensor/filter.cpp @@ -481,5 +481,28 @@ optional RoundMultipleFilter::new_value(float value) { return value; } +optional ToNTCResistanceFilter::new_value(float value) { + if (!std::isfinite(value)) { + return NAN; + } + double k = 273.15; + // https://de.wikipedia.org/wiki/Steinhart-Hart-Gleichung#cite_note-stein2_s4-3 + double t = value + k; + double y = (this->a_ - 1 / (t)) / (2 * this->c_); + double x = sqrt(pow(this->b_ / (3 * this->c_), 3) + y * y); + double resistance = exp(pow(x - y, 1 / 3.0) - pow(x + y, 1 / 3.0)); + return resistance; +} + +optional ToNTCTemperatureFilter::new_value(float value) { + if (!std::isfinite(value)) { + return NAN; + } + double lr = log(double(value)); + double v = this->a_ + this->b_ * lr + this->c_ * lr * lr * lr; + double temp = float(1.0 / v - 273.15); + return temp; +} + } // namespace sensor } // namespace esphome diff --git a/esphome/components/sensor/filter.h b/esphome/components/sensor/filter.h index 86586b458d..3cfaebb708 100644 --- a/esphome/components/sensor/filter.h +++ b/esphome/components/sensor/filter.h @@ -439,5 +439,27 @@ class RoundMultipleFilter : public Filter { float multiple_; }; +class ToNTCResistanceFilter : public Filter { + public: + ToNTCResistanceFilter(double a, double b, double c) : a_(a), b_(b), c_(c) {} + optional new_value(float value) override; + + protected: + double a_; + double b_; + double c_; +}; + +class ToNTCTemperatureFilter : public Filter { + public: + ToNTCTemperatureFilter(double a, double b, double c) : a_(a), b_(b), c_(c) {} + optional new_value(float value) override; + + protected: + double a_; + double b_; + double c_; +}; + } // namespace sensor } // namespace esphome diff --git a/esphome/const.py b/esphome/const.py index 21cf7367de..f78312a5b0 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -897,6 +897,8 @@ CONF_TIMES = "times" CONF_TIMEZONE = "timezone" CONF_TIMING = "timing" CONF_TO = "to" +CONF_TO_NTC_RESISTANCE = "to_ntc_resistance" +CONF_TO_NTC_TEMPERATURE = "to_ntc_temperature" CONF_TOLERANCE = "tolerance" CONF_TOPIC = "topic" CONF_TOPIC_PREFIX = "topic_prefix" diff --git a/tests/components/template/common.yaml b/tests/components/template/common.yaml index 79201fbe07..987849a80c 100644 --- a/tests/components/template/common.yaml +++ b/tests/components/template/common.yaml @@ -28,6 +28,16 @@ sensor: value: 20.0 - timeout: timeout: 1d + - to_ntc_resistance: + calibration: + - 10.0kOhm -> 25°C + - 27.219kOhm -> 0°C + - 14.674kOhm -> 15°C + - to_ntc_temperature: + calibration: + - 10.0kOhm -> 25°C + - 27.219kOhm -> 0°C + - 14.674kOhm -> 15°C esphome: on_boot: From e215fafebe082f09b7867fa6261d1d8ab4430b97 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Thu, 1 May 2025 18:28:07 +1000 Subject: [PATCH 073/102] [esp32, debug] Add ``cpu_frequency`` config option and debug sensor (#8542) --- esphome/components/debug/debug_component.cpp | 4 + esphome/components/debug/debug_component.h | 5 + esphome/components/debug/debug_esp32.cpp | 345 +++++------------- esphome/components/debug/sensor.py | 14 + esphome/components/esp32/__init__.py | 54 +++ esphome/components/esp32/core.cpp | 16 +- tests/components/debug/common.yaml | 17 + tests/components/debug/test.esp32-ard.yaml | 3 + tests/components/debug/test.esp32-c3-ard.yaml | 3 + tests/components/debug/test.esp32-idf.yaml | 12 + 10 files changed, 206 insertions(+), 267 deletions(-) diff --git a/esphome/components/debug/debug_component.cpp b/esphome/components/debug/debug_component.cpp index 7d25bf5472..fcded02ba5 100644 --- a/esphome/components/debug/debug_component.cpp +++ b/esphome/components/debug/debug_component.cpp @@ -25,6 +25,7 @@ void DebugComponent::dump_config() { #ifdef USE_SENSOR LOG_SENSOR(" ", "Free space on heap", this->free_sensor_); LOG_SENSOR(" ", "Largest free heap block", this->block_sensor_); + LOG_SENSOR(" ", "CPU frequency", this->cpu_frequency_sensor_); #if defined(USE_ESP8266) && USE_ARDUINO_VERSION_CODE >= VERSION_CODE(2, 5, 2) LOG_SENSOR(" ", "Heap fragmentation", this->fragmentation_sensor_); #endif // defined(USE_ESP8266) && USE_ARDUINO_VERSION_CODE >= VERSION_CODE(2, 5, 2) @@ -86,6 +87,9 @@ void DebugComponent::update() { this->loop_time_sensor_->publish_state(this->max_loop_time_); this->max_loop_time_ = 0; } + if (this->cpu_frequency_sensor_ != nullptr) { + this->cpu_frequency_sensor_->publish_state(arch_get_cpu_freq_hz()); + } #endif // USE_SENSOR update_platform_(); diff --git a/esphome/components/debug/debug_component.h b/esphome/components/debug/debug_component.h index 608addb4a3..f887d52864 100644 --- a/esphome/components/debug/debug_component.h +++ b/esphome/components/debug/debug_component.h @@ -36,6 +36,9 @@ class DebugComponent : public PollingComponent { #ifdef USE_ESP32 void set_psram_sensor(sensor::Sensor *psram_sensor) { this->psram_sensor_ = psram_sensor; } #endif // USE_ESP32 + void set_cpu_frequency_sensor(sensor::Sensor *cpu_frequency_sensor) { + this->cpu_frequency_sensor_ = cpu_frequency_sensor; + } #endif // USE_SENSOR protected: uint32_t free_heap_{}; @@ -53,6 +56,7 @@ class DebugComponent : public PollingComponent { #ifdef USE_ESP32 sensor::Sensor *psram_sensor_{nullptr}; #endif // USE_ESP32 + sensor::Sensor *cpu_frequency_sensor_{nullptr}; #endif // USE_SENSOR #ifdef USE_ESP32 @@ -75,6 +79,7 @@ class DebugComponent : public PollingComponent { #endif // USE_TEXT_SENSOR std::string get_reset_reason_(); + std::string get_wakeup_cause_(); uint32_t get_free_heap_(); void get_device_info_(std::string &device_info); void update_platform_(); diff --git a/esphome/components/debug/debug_esp32.cpp b/esphome/components/debug/debug_esp32.cpp index 7367f54807..bc772a1d58 100644 --- a/esphome/components/debug/debug_esp32.cpp +++ b/esphome/components/debug/debug_esp32.cpp @@ -1,27 +1,15 @@ #include "debug_component.h" + #ifdef USE_ESP32 #include "esphome/core/log.h" +#include "esphome/core/hal.h" +#include #include #include #include #include -#if defined(USE_ESP32_VARIANT_ESP32) -#include -#elif defined(USE_ESP32_VARIANT_ESP32C2) -#include -#elif defined(USE_ESP32_VARIANT_ESP32C3) -#include -#elif defined(USE_ESP32_VARIANT_ESP32C6) -#include -#elif defined(USE_ESP32_VARIANT_ESP32S2) -#include -#elif defined(USE_ESP32_VARIANT_ESP32S3) -#include -#elif defined(USE_ESP32_VARIANT_ESP32H2) -#include -#endif #ifdef USE_ARDUINO #include #endif @@ -31,6 +19,67 @@ namespace debug { static const char *const TAG = "debug"; +// index by values returned by esp_reset_reason + +static const char *const RESET_REASONS[] = { + "unknown source", + "power-on event", + "external pin", + "software via esp_restart", + "exception/panic", + "interrupt watchdog", + "task watchdog", + "other watchdogs", + "exiting deep sleep mode", + "brownout", + "SDIO", + "USB peripheral", + "JTAG", + "efuse error", + "power glitch detected", + "CPU lock up", +}; + +std::string DebugComponent::get_reset_reason_() { + std::string reset_reason; + unsigned reason = esp_reset_reason(); + if (reason < sizeof(RESET_REASONS) / sizeof(RESET_REASONS[0])) { + reset_reason = RESET_REASONS[reason]; + } else { + reset_reason = "unknown source"; + } + ESP_LOGD(TAG, "Reset Reason: %s", reset_reason.c_str()); + return "Reset by " + reset_reason; +} + +static const char *const WAKEUP_CAUSES[] = { + "undefined", + "undefined", + "external signal using RTC_IO", + "external signal using RTC_CNTL", + "timer", + "touchpad", + "ULP program", + "GPIO", + "UART", + "WIFI", + "COCPU int", + "COCPU crash", + "BT", +}; + +std::string DebugComponent::get_wakeup_cause_() { + const char *wake_reason; + unsigned reason = esp_sleep_get_wakeup_cause(); + if (reason < sizeof(WAKEUP_CAUSES) / sizeof(WAKEUP_CAUSES[0])) { + wake_reason = WAKEUP_CAUSES[reason]; + } else { + wake_reason = "unknown source"; + } + ESP_LOGD(TAG, "Wakeup Reason: %s", wake_reason); + return wake_reason; +} + void DebugComponent::log_partition_info_() { ESP_LOGCONFIG(TAG, "Partition table:"); ESP_LOGCONFIG(TAG, " %-12s %-4s %-8s %-10s %-10s", "Name", "Type", "Subtype", "Address", "Size"); @@ -44,173 +93,16 @@ void DebugComponent::log_partition_info_() { esp_partition_iterator_release(it); } -std::string DebugComponent::get_reset_reason_() { - std::string reset_reason; - switch (esp_reset_reason()) { - case ESP_RST_POWERON: - reset_reason = "Reset due to power-on event"; - break; - case ESP_RST_EXT: - reset_reason = "Reset by external pin"; - break; - case ESP_RST_SW: - reset_reason = "Software reset via esp_restart"; - break; - case ESP_RST_PANIC: - reset_reason = "Software reset due to exception/panic"; - break; - case ESP_RST_INT_WDT: - reset_reason = "Reset (software or hardware) due to interrupt watchdog"; - break; - case ESP_RST_TASK_WDT: - reset_reason = "Reset due to task watchdog"; - break; - case ESP_RST_WDT: - reset_reason = "Reset due to other watchdogs"; - break; - case ESP_RST_DEEPSLEEP: - reset_reason = "Reset after exiting deep sleep mode"; - break; - case ESP_RST_BROWNOUT: - reset_reason = "Brownout reset (software or hardware)"; - break; - case ESP_RST_SDIO: - reset_reason = "Reset over SDIO"; - break; -#ifdef USE_ESP32_VARIANT_ESP32 -#if (ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 1, 4)) - case ESP_RST_USB: - reset_reason = "Reset by USB peripheral"; - break; - case ESP_RST_JTAG: - reset_reason = "Reset by JTAG"; - break; - case ESP_RST_EFUSE: - reset_reason = "Reset due to efuse error"; - break; - case ESP_RST_PWR_GLITCH: - reset_reason = "Reset due to power glitch detected"; - break; - case ESP_RST_CPU_LOCKUP: - reset_reason = "Reset due to CPU lock up (double exception)"; - break; -#endif // ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 1, 4) -#endif // USE_ESP32_VARIANT_ESP32 - default: // Includes ESP_RST_UNKNOWN - switch (rtc_get_reset_reason(0)) { - case POWERON_RESET: - reset_reason = "Power On Reset"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case SW_RESET: -#elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case RTC_SW_SYS_RESET: -#endif - reset_reason = "Software Reset Digital Core"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case OWDT_RESET: - reset_reason = "Watch Dog Reset Digital Core"; - break; -#endif - case DEEPSLEEP_RESET: - reset_reason = "Deep Sleep Reset Digital Core"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case SDIO_RESET: - reset_reason = "SLC Module Reset Digital Core"; - break; -#endif - case TG0WDT_SYS_RESET: - reset_reason = "Timer Group 0 Watch Dog Reset Digital Core"; - break; -#if !defined(USE_ESP32_VARIANT_ESP32C2) - case TG1WDT_SYS_RESET: - reset_reason = "Timer Group 1 Watch Dog Reset Digital Core"; - break; -#endif - case RTCWDT_SYS_RESET: - reset_reason = "RTC Watch Dog Reset Digital Core"; - break; -#if !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) - case INTRUSION_RESET: - reset_reason = "Intrusion Reset CPU"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32) - case TGWDT_CPU_RESET: - reset_reason = "Timer Group Reset CPU"; - break; -#elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case TG0WDT_CPU_RESET: - reset_reason = "Timer Group 0 Reset CPU"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32) - case SW_CPU_RESET: -#elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case RTC_SW_CPU_RESET: -#endif - reset_reason = "Software Reset CPU"; - break; - case RTCWDT_CPU_RESET: - reset_reason = "RTC Watch Dog Reset CPU"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case EXT_CPU_RESET: - reset_reason = "External CPU Reset"; - break; -#endif - case RTCWDT_BROWN_OUT_RESET: - reset_reason = "Voltage Unstable Reset"; - break; - case RTCWDT_RTC_RESET: - reset_reason = "RTC Watch Dog Reset Digital Core And RTC Module"; - break; -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || \ - defined(USE_ESP32_VARIANT_ESP32C6) - case TG1WDT_CPU_RESET: - reset_reason = "Timer Group 1 Reset CPU"; - break; - case SUPER_WDT_RESET: - reset_reason = "Super Watchdog Reset Digital Core And RTC Module"; - break; - case EFUSE_RESET: - reset_reason = "eFuse Reset Digital Core"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - case GLITCH_RTC_RESET: - reset_reason = "Glitch Reset Digital Core And RTC Module"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case USB_UART_CHIP_RESET: - reset_reason = "USB UART Reset Digital Core"; - break; - case USB_JTAG_CHIP_RESET: - reset_reason = "USB JTAG Reset Digital Core"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S3) - case POWER_GLITCH_RESET: - reset_reason = "Power Glitch Reset Digital Core And RTC Module"; - break; -#endif - default: - reset_reason = "Unknown Reset Reason"; - } - break; - } - ESP_LOGD(TAG, "Reset Reason: %s", reset_reason.c_str()); - return reset_reason; -} - uint32_t DebugComponent::get_free_heap_() { return heap_caps_get_free_size(MALLOC_CAP_INTERNAL); } +static const std::map CHIP_FEATURES = { + {CHIP_FEATURE_BLE, "BLE"}, + {CHIP_FEATURE_BT, "BT"}, + {CHIP_FEATURE_EMB_FLASH, "EMB Flash"}, + {CHIP_FEATURE_EMB_PSRAM, "EMB PSRAM"}, + {CHIP_FEATURE_WIFI_BGN, "2.4GHz WiFi"}, +}; + void DebugComponent::get_device_info_(std::string &device_info) { #if defined(USE_ARDUINO) const char *flash_mode; @@ -246,46 +138,16 @@ void DebugComponent::get_device_info_(std::string &device_info) { esp_chip_info_t info; esp_chip_info(&info); - const char *model; -#if defined(USE_ESP32_VARIANT_ESP32) - model = "ESP32"; -#elif defined(USE_ESP32_VARIANT_ESP32C2) - model = "ESP32-C2"; -#elif defined(USE_ESP32_VARIANT_ESP32C3) - model = "ESP32-C3"; -#elif defined(USE_ESP32_VARIANT_ESP32C6) - model = "ESP32-C6"; -#elif defined(USE_ESP32_VARIANT_ESP32S2) - model = "ESP32-S2"; -#elif defined(USE_ESP32_VARIANT_ESP32S3) - model = "ESP32-S3"; -#elif defined(USE_ESP32_VARIANT_ESP32H2) - model = "ESP32-H2"; -#else - model = "UNKNOWN"; -#endif + const char *model = ESPHOME_VARIANT; std::string features; - if (info.features & CHIP_FEATURE_EMB_FLASH) { - features += "EMB_FLASH,"; - info.features &= ~CHIP_FEATURE_EMB_FLASH; + for (auto feature : CHIP_FEATURES) { + if (info.features & feature.first) { + features += feature.second; + features += ", "; + info.features &= ~feature.first; + } } - if (info.features & CHIP_FEATURE_WIFI_BGN) { - features += "WIFI_BGN,"; - info.features &= ~CHIP_FEATURE_WIFI_BGN; - } - if (info.features & CHIP_FEATURE_BLE) { - features += "BLE,"; - info.features &= ~CHIP_FEATURE_BLE; - } - if (info.features & CHIP_FEATURE_BT) { - features += "BT,"; - info.features &= ~CHIP_FEATURE_BT; - } - if (info.features & CHIP_FEATURE_EMB_PSRAM) { - features += "EMB_PSRAM,"; - info.features &= ~CHIP_FEATURE_EMB_PSRAM; - } - if (info.features) + if (info.features != 0) features += "Other:" + format_hex(info.features); ESP_LOGD(TAG, "Chip: Model=%s, Features=%s Cores=%u, Revision=%u", model, features.c_str(), info.cores, info.revision); @@ -295,6 +157,8 @@ void DebugComponent::get_device_info_(std::string &device_info) { device_info += features; device_info += " Cores:" + to_string(info.cores); device_info += " Revision:" + to_string(info.revision); + device_info += str_sprintf("|CPU Frequency: %" PRIu32 " MHz", arch_get_cpu_freq_hz() / 1000000); + ESP_LOGD(TAG, "CPU Frequency: %" PRIu32 " MHz", arch_get_cpu_freq_hz() / 1000000); // Framework detection device_info += "|Framework: "; @@ -321,50 +185,7 @@ void DebugComponent::get_device_info_(std::string &device_info) { device_info += "|Reset: "; device_info += get_reset_reason_(); - const char *wakeup_reason; - switch (rtc_get_wakeup_cause()) { - case NO_SLEEP: - wakeup_reason = "No Sleep"; - break; - case EXT_EVENT0_TRIG: - wakeup_reason = "External Event 0"; - break; - case EXT_EVENT1_TRIG: - wakeup_reason = "External Event 1"; - break; - case GPIO_TRIG: - wakeup_reason = "GPIO"; - break; - case TIMER_EXPIRE: - wakeup_reason = "Wakeup Timer"; - break; - case SDIO_TRIG: - wakeup_reason = "SDIO"; - break; - case MAC_TRIG: - wakeup_reason = "MAC"; - break; - case UART0_TRIG: - wakeup_reason = "UART0"; - break; - case UART1_TRIG: - wakeup_reason = "UART1"; - break; -#if !defined(USE_ESP32_VARIANT_ESP32C2) - case TOUCH_TRIG: - wakeup_reason = "Touch"; - break; -#endif - case SAR_TRIG: - wakeup_reason = "SAR"; - break; - case BT_TRIG: - wakeup_reason = "BT"; - break; - default: - wakeup_reason = "Unknown"; - } - ESP_LOGD(TAG, "Wakeup Reason: %s", wakeup_reason); + std::string wakeup_reason = this->get_wakeup_cause_(); device_info += "|Wakeup: "; device_info += wakeup_reason; } diff --git a/esphome/components/debug/sensor.py b/esphome/components/debug/sensor.py index 0a23658907..4669095d5d 100644 --- a/esphome/components/debug/sensor.py +++ b/esphome/components/debug/sensor.py @@ -1,5 +1,6 @@ import esphome.codegen as cg from esphome.components import sensor +from esphome.components.esp32 import CONF_CPU_FREQUENCY import esphome.config_validation as cv from esphome.const import ( CONF_BLOCK, @@ -10,6 +11,7 @@ from esphome.const import ( ICON_COUNTER, ICON_TIMER, UNIT_BYTES, + UNIT_HERTZ, UNIT_MILLISECOND, UNIT_PERCENT, ) @@ -60,6 +62,14 @@ CONFIG_SCHEMA = { entity_category=ENTITY_CATEGORY_DIAGNOSTIC, ), ), + cv.Optional(CONF_CPU_FREQUENCY): cv.All( + sensor.sensor_schema( + unit_of_measurement=UNIT_HERTZ, + icon="mdi:speedometer", + accuracy_decimals=0, + entity_category=ENTITY_CATEGORY_DIAGNOSTIC, + ), + ), } @@ -85,3 +95,7 @@ async def to_code(config): if psram_conf := config.get(CONF_PSRAM): sens = await sensor.new_sensor(psram_conf) cg.add(debug_component.set_psram_sensor(sens)) + + if cpu_freq_conf := config.get(CONF_CPU_FREQUENCY): + sens = await sensor.new_sensor(cpu_freq_conf) + cg.add(debug_component.set_cpu_frequency_sensor(sens)) diff --git a/esphome/components/esp32/__init__.py b/esphome/components/esp32/__init__.py index 307766ff94..12d0f9fcd5 100644 --- a/esphome/components/esp32/__init__.py +++ b/esphome/components/esp32/__init__.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import itertools import logging import os from pathlib import Path @@ -37,6 +38,7 @@ from esphome.const import ( __version__, ) from esphome.core import CORE, HexInt, TimePeriod +from esphome.cpp_generator import RawExpression import esphome.final_validate as fv from esphome.helpers import copy_file_if_changed, mkdir_p, write_file_if_changed @@ -54,6 +56,12 @@ from .const import ( # noqa KEY_SUBMODULES, KEY_VARIANT, VARIANT_ESP32, + VARIANT_ESP32C2, + VARIANT_ESP32C3, + VARIANT_ESP32C6, + VARIANT_ESP32H2, + VARIANT_ESP32S2, + VARIANT_ESP32S3, VARIANT_FRIENDLY, VARIANTS, ) @@ -70,7 +78,43 @@ CONF_RELEASE = "release" CONF_ENABLE_IDF_EXPERIMENTAL_FEATURES = "enable_idf_experimental_features" +def get_cpu_frequencies(*frequencies): + return [str(x) + "MHZ" for x in frequencies] + + +CPU_FREQUENCIES = { + VARIANT_ESP32: get_cpu_frequencies(80, 160, 240), + VARIANT_ESP32S2: get_cpu_frequencies(80, 160, 240), + VARIANT_ESP32S3: get_cpu_frequencies(80, 160, 240), + VARIANT_ESP32C2: get_cpu_frequencies(80, 120), + VARIANT_ESP32C3: get_cpu_frequencies(80, 160), + VARIANT_ESP32C6: get_cpu_frequencies(80, 120, 160), + VARIANT_ESP32H2: get_cpu_frequencies(16, 32, 48, 64, 96), +} + +# Make sure not missed here if a new variant added. +assert all(v in CPU_FREQUENCIES for v in VARIANTS) + +FULL_CPU_FREQUENCIES = set(itertools.chain.from_iterable(CPU_FREQUENCIES.values())) + + def set_core_data(config): + cpu_frequency = config.get(CONF_CPU_FREQUENCY, None) + variant = config[CONF_VARIANT] + # if not specified in config, set to 160MHz if supported, the fastest otherwise + if cpu_frequency is None: + choices = CPU_FREQUENCIES[variant] + if "160MHZ" in choices: + cpu_frequency = "160MHZ" + else: + cpu_frequency = choices[-1] + config[CONF_CPU_FREQUENCY] = cpu_frequency + elif cpu_frequency not in CPU_FREQUENCIES[variant]: + raise cv.Invalid( + f"Invalid CPU frequency '{cpu_frequency}' for {config[CONF_VARIANT]}", + path=[CONF_CPU_FREQUENCY], + ) + CORE.data[KEY_ESP32] = {} CORE.data[KEY_CORE][KEY_TARGET_PLATFORM] = PLATFORM_ESP32 conf = config[CONF_FRAMEWORK] @@ -83,6 +127,7 @@ def set_core_data(config): CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] = cv.Version.parse( config[CONF_FRAMEWORK][CONF_VERSION] ) + CORE.data[KEY_ESP32][KEY_BOARD] = config[CONF_BOARD] CORE.data[KEY_ESP32][KEY_VARIANT] = config[CONF_VARIANT] CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES] = {} @@ -553,11 +598,15 @@ FLASH_SIZES = [ ] CONF_FLASH_SIZE = "flash_size" +CONF_CPU_FREQUENCY = "cpu_frequency" CONF_PARTITIONS = "partitions" CONFIG_SCHEMA = cv.All( cv.Schema( { cv.Required(CONF_BOARD): cv.string_strict, + cv.Optional(CONF_CPU_FREQUENCY): cv.one_of( + *FULL_CPU_FREQUENCIES, upper=True + ), cv.Optional(CONF_FLASH_SIZE, default="4MB"): cv.one_of( *FLASH_SIZES, upper=True ), @@ -598,6 +647,7 @@ async def to_code(config): os.path.join(os.path.dirname(__file__), "post_build.py.script"), ) + freq = config[CONF_CPU_FREQUENCY][:-3] if conf[CONF_TYPE] == FRAMEWORK_ESP_IDF: cg.add_platformio_option("framework", "espidf") cg.add_build_flag("-DUSE_ESP_IDF") @@ -631,6 +681,9 @@ async def to_code(config): add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0", False) add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1", False) + # Set default CPU frequency + add_idf_sdkconfig_option(f"CONFIG_ESP_DEFAULT_CPU_FREQ_MHZ_{freq}", True) + cg.add_platformio_option("board_build.partitions", "partitions.csv") if CONF_PARTITIONS in config: add_extra_build_file( @@ -696,6 +749,7 @@ async def to_code(config): f"VERSION_CODE({framework_ver.major}, {framework_ver.minor}, {framework_ver.patch})" ), ) + cg.add(RawExpression(f"setCpuFrequencyMhz({freq})")) APP_PARTITION_SIZES = { diff --git a/esphome/components/esp32/core.cpp b/esphome/components/esp32/core.cpp index ff8e663ec1..c90d68d00e 100644 --- a/esphome/components/esp32/core.cpp +++ b/esphome/components/esp32/core.cpp @@ -13,11 +13,13 @@ #include #ifdef USE_ARDUINO -#include -#endif +#include +#else +#include void setup(); void loop(); +#endif namespace esphome { @@ -59,9 +61,13 @@ uint32_t arch_get_cpu_cycle_count() { return esp_cpu_get_cycle_count(); } uint32_t arch_get_cpu_cycle_count() { return cpu_hal_get_cycle_count(); } #endif uint32_t arch_get_cpu_freq_hz() { - rtc_cpu_freq_config_t config; - rtc_clk_cpu_freq_get_config(&config); - return config.freq_mhz * 1000000U; + uint32_t freq = 0; +#ifdef USE_ESP_IDF + esp_clk_tree_src_get_freq_hz(SOC_MOD_CLK_CPU, ESP_CLK_TREE_SRC_FREQ_PRECISION_CACHED, &freq); +#elif defined(USE_ARDUINO) + freq = ESP.getCpuFreqMHz() * 1000000; +#endif + return freq; } #ifdef USE_ESP_IDF diff --git a/tests/components/debug/common.yaml b/tests/components/debug/common.yaml index 5845beaa80..a9d74e6865 100644 --- a/tests/components/debug/common.yaml +++ b/tests/components/debug/common.yaml @@ -1 +1,18 @@ debug: + +text_sensor: + - platform: debug + device: + name: "Device Info" + reset_reason: + name: "Reset Reason" + +sensor: + - platform: debug + free: + name: "Heap Free" + loop_time: + name: "Loop Time" + cpu_frequency: + name: "CPU Frequency" + diff --git a/tests/components/debug/test.esp32-ard.yaml b/tests/components/debug/test.esp32-ard.yaml index dade44d145..8e19a4d627 100644 --- a/tests/components/debug/test.esp32-ard.yaml +++ b/tests/components/debug/test.esp32-ard.yaml @@ -1 +1,4 @@ <<: !include common.yaml + +esp32: + cpu_frequency: 240MHz diff --git a/tests/components/debug/test.esp32-c3-ard.yaml b/tests/components/debug/test.esp32-c3-ard.yaml index dade44d145..7d43491862 100644 --- a/tests/components/debug/test.esp32-c3-ard.yaml +++ b/tests/components/debug/test.esp32-c3-ard.yaml @@ -1 +1,4 @@ <<: !include common.yaml + +esp32: + cpu_frequency: 80MHz diff --git a/tests/components/debug/test.esp32-idf.yaml b/tests/components/debug/test.esp32-idf.yaml index dade44d145..f7483a54b3 100644 --- a/tests/components/debug/test.esp32-idf.yaml +++ b/tests/components/debug/test.esp32-idf.yaml @@ -1 +1,13 @@ <<: !include common.yaml + +esp32: + cpu_frequency: 240MHz + +sensor: + - platform: debug + free: + name: "Heap Free" + psram: + name: "Free PSRAM" + +psram: From c7f597bc753deef9c49d0edb95f7e4c79e5f170d Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Thu, 1 May 2025 06:11:09 -0500 Subject: [PATCH 074/102] [voice_assistant] voice assistant can configure enabled wake words (#8657) --- .../components/voice_assistant/__init__.py | 16 +++-- .../voice_assistant/voice_assistant.cpp | 53 ++++++++++++++ .../voice_assistant/voice_assistant.h | 20 ++++-- .../voice_assistant/common-idf.yaml | 69 +++++++++++++++++++ .../voice_assistant/test.esp32-c3-idf.yaml | 2 +- .../voice_assistant/test.esp32-idf.yaml | 2 +- 6 files changed, 150 insertions(+), 12 deletions(-) create mode 100644 tests/components/voice_assistant/common-idf.yaml diff --git a/esphome/components/voice_assistant/__init__.py b/esphome/components/voice_assistant/__init__.py index ca0b6da742..b9309ab422 100644 --- a/esphome/components/voice_assistant/__init__.py +++ b/esphome/components/voice_assistant/__init__.py @@ -1,7 +1,7 @@ from esphome import automation from esphome.automation import register_action, register_condition import esphome.codegen as cg -from esphome.components import media_player, microphone, speaker +from esphome.components import media_player, micro_wake_word, microphone, speaker import esphome.config_validation as cv from esphome.const import ( CONF_ID, @@ -41,6 +41,7 @@ CONF_AUTO_GAIN = "auto_gain" CONF_NOISE_SUPPRESSION_LEVEL = "noise_suppression_level" CONF_VOLUME_MULTIPLIER = "volume_multiplier" +CONF_MICRO_WAKE_WORD = "micro_wake_word" CONF_WAKE_WORD = "wake_word" CONF_CONVERSATION_TIMEOUT = "conversation_timeout" @@ -96,11 +97,12 @@ CONFIG_SCHEMA = cv.All( min_channels=1, max_channels=1, ), - cv.Exclusive(CONF_SPEAKER, "output"): cv.use_id(speaker.Speaker), cv.Exclusive(CONF_MEDIA_PLAYER, "output"): cv.use_id( media_player.MediaPlayer ), + cv.Exclusive(CONF_SPEAKER, "output"): cv.use_id(speaker.Speaker), cv.Optional(CONF_USE_WAKE_WORD, default=False): cv.boolean, + cv.Optional(CONF_MICRO_WAKE_WORD): cv.use_id(micro_wake_word.MicroWakeWord), cv.Optional(CONF_VAD_THRESHOLD): cv.invalid( "VAD threshold is no longer supported, as it requires the deprecated esp_adf external component. Use an i2s_audio microphone/speaker instead. Additionally, you may need to configure the audio_adc and audio_dac components depending on your hardware." ), @@ -191,14 +193,18 @@ async def to_code(config): mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) cg.add(var.set_microphone_source(mic_source)) - if CONF_SPEAKER in config: - spkr = await cg.get_variable(config[CONF_SPEAKER]) - cg.add(var.set_speaker(spkr)) + if CONF_MICRO_WAKE_WORD in config: + mww = await cg.get_variable(config[CONF_MICRO_WAKE_WORD]) + cg.add(var.set_micro_wake_word(mww)) if CONF_MEDIA_PLAYER in config: mp = await cg.get_variable(config[CONF_MEDIA_PLAYER]) cg.add(var.set_media_player(mp)) + if CONF_SPEAKER in config: + spkr = await cg.get_variable(config[CONF_SPEAKER]) + cg.add(var.set_speaker(spkr)) + cg.add(var.set_use_wake_word(config[CONF_USE_WAKE_WORD])) if (vad_threshold := config.get(CONF_VAD_THRESHOLD)) is not None: diff --git a/esphome/components/voice_assistant/voice_assistant.cpp b/esphome/components/voice_assistant/voice_assistant.cpp index 37b97239c8..d35717ef91 100644 --- a/esphome/components/voice_assistant/voice_assistant.cpp +++ b/esphome/components/voice_assistant/voice_assistant.cpp @@ -869,6 +869,59 @@ void VoiceAssistant::on_announce(const api::VoiceAssistantAnnounceRequest &msg) #endif } +void VoiceAssistant::on_set_configuration(const std::vector &active_wake_words) { +#ifdef USE_MICRO_WAKE_WORD + if (this->micro_wake_word_) { + // Disable all wake words first + for (auto &model : this->micro_wake_word_->get_wake_words()) { + model->disable(); + } + + // Enable only active wake words + for (auto ww_id : active_wake_words) { + for (auto &model : this->micro_wake_word_->get_wake_words()) { + if (model->get_id() == ww_id) { + model->enable(); + ESP_LOGD(TAG, "Enabled wake word: %s (id=%s)", model->get_wake_word().c_str(), model->get_id().c_str()); + } + } + } + } +#endif +}; + +const Configuration &VoiceAssistant::get_configuration() { + this->config_.available_wake_words.clear(); + this->config_.active_wake_words.clear(); + +#ifdef USE_MICRO_WAKE_WORD + if (this->micro_wake_word_) { + this->config_.max_active_wake_words = 1; + + for (auto &model : this->micro_wake_word_->get_wake_words()) { + if (model->is_enabled()) { + this->config_.active_wake_words.push_back(model->get_id()); + } + + WakeWord wake_word; + wake_word.id = model->get_id(); + wake_word.wake_word = model->get_wake_word(); + for (const auto &lang : model->get_trained_languages()) { + wake_word.trained_languages.push_back(lang); + } + this->config_.available_wake_words.push_back(std::move(wake_word)); + } + } else { +#endif + // No microWakeWord + this->config_.max_active_wake_words = 0; +#ifdef USE_MICRO_WAKE_WORD + } +#endif + + return this->config_; +}; + VoiceAssistant *global_voice_assistant = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) } // namespace voice_assistant diff --git a/esphome/components/voice_assistant/voice_assistant.h b/esphome/components/voice_assistant/voice_assistant.h index 7122d69527..865731522f 100644 --- a/esphome/components/voice_assistant/voice_assistant.h +++ b/esphome/components/voice_assistant/voice_assistant.h @@ -12,12 +12,15 @@ #include "esphome/components/api/api_connection.h" #include "esphome/components/api/api_pb2.h" #include "esphome/components/microphone/microphone_source.h" -#ifdef USE_SPEAKER -#include "esphome/components/speaker/speaker.h" -#endif #ifdef USE_MEDIA_PLAYER #include "esphome/components/media_player/media_player.h" #endif +#ifdef USE_MICRO_WAKE_WORD +#include "esphome/components/micro_wake_word/micro_wake_word.h" +#endif +#ifdef USE_SPEAKER +#include "esphome/components/speaker/speaker.h" +#endif #include "esphome/components/socket/socket.h" #include @@ -99,6 +102,9 @@ class VoiceAssistant : public Component { void failed_to_start(); void set_microphone_source(microphone::MicrophoneSource *mic_source) { this->mic_source_ = mic_source; } +#ifdef USE_MICRO_WAKE_WORD + void set_micro_wake_word(micro_wake_word::MicroWakeWord *mww) { this->micro_wake_word_ = mww; } +#endif #ifdef USE_SPEAKER void set_speaker(speaker::Speaker *speaker) { this->speaker_ = speaker; @@ -152,8 +158,8 @@ class VoiceAssistant : public Component { void on_audio(const api::VoiceAssistantAudio &msg); void on_timer_event(const api::VoiceAssistantTimerEventResponse &msg); void on_announce(const api::VoiceAssistantAnnounceRequest &msg); - void on_set_configuration(const std::vector &active_wake_words){}; - const Configuration &get_configuration() { return this->config_; }; + void on_set_configuration(const std::vector &active_wake_words); + const Configuration &get_configuration(); bool is_running() const { return this->state_ != State::IDLE; } void set_continuous(bool continuous) { this->continuous_ = continuous; } @@ -295,6 +301,10 @@ class VoiceAssistant : public Component { bool start_udp_socket_(); Configuration config_{}; + +#ifdef USE_MICRO_WAKE_WORD + micro_wake_word::MicroWakeWord *micro_wake_word_{nullptr}; +#endif }; template class StartAction : public Action, public Parented { diff --git a/tests/components/voice_assistant/common-idf.yaml b/tests/components/voice_assistant/common-idf.yaml new file mode 100644 index 0000000000..b1d249d5b4 --- /dev/null +++ b/tests/components/voice_assistant/common-idf.yaml @@ -0,0 +1,69 @@ +esphome: + on_boot: + then: + - voice_assistant.start + - voice_assistant.start_continuous + - voice_assistant.stop + +wifi: + ssid: MySSID + password: password1 + +api: + +i2s_audio: + i2s_lrclk_pin: ${i2s_lrclk_pin} + i2s_bclk_pin: ${i2s_bclk_pin} + i2s_mclk_pin: ${i2s_mclk_pin} + +micro_wake_word: + id: mww_id + on_wake_word_detected: + - voice_assistant.start: + wake_word: !lambda return wake_word; + models: + - model: okay_nabu + +microphone: + - platform: i2s_audio + id: mic_id_external + i2s_din_pin: ${i2s_din_pin} + adc_type: external + pdm: false + +speaker: + - platform: i2s_audio + id: speaker_id + dac_type: external + i2s_dout_pin: ${i2s_dout_pin} + +voice_assistant: + microphone: + microphone: mic_id_external + gain_factor: 4 + channels: 0 + speaker: speaker_id + micro_wake_word: mww_id + conversation_timeout: 60s + on_listening: + - logger.log: "Voice assistant microphone listening" + on_start: + - logger.log: "Voice assistant started" + on_stt_end: + - logger.log: + format: "Voice assistant STT ended with result %s" + args: [x.c_str()] + on_tts_start: + - logger.log: + format: "Voice assistant TTS started with text %s" + args: [x.c_str()] + on_tts_end: + - logger.log: + format: "Voice assistant TTS ended with url %s" + args: [x.c_str()] + on_end: + - logger.log: "Voice assistant ended" + on_error: + - logger.log: + format: "Voice assistant error - code %s, message: %s" + args: [code.c_str(), message.c_str()] diff --git a/tests/components/voice_assistant/test.esp32-c3-idf.yaml b/tests/components/voice_assistant/test.esp32-c3-idf.yaml index f596d927cb..46745e4308 100644 --- a/tests/components/voice_assistant/test.esp32-c3-idf.yaml +++ b/tests/components/voice_assistant/test.esp32-c3-idf.yaml @@ -5,4 +5,4 @@ substitutions: i2s_din_pin: GPIO3 i2s_dout_pin: GPIO2 -<<: !include common.yaml +<<: !include common-idf.yaml diff --git a/tests/components/voice_assistant/test.esp32-idf.yaml b/tests/components/voice_assistant/test.esp32-idf.yaml index f6e553f9dc..0fe5d347be 100644 --- a/tests/components/voice_assistant/test.esp32-idf.yaml +++ b/tests/components/voice_assistant/test.esp32-idf.yaml @@ -5,4 +5,4 @@ substitutions: i2s_din_pin: GPIO13 i2s_dout_pin: GPIO12 -<<: !include common.yaml +<<: !include common-idf.yaml From 836e5ffa4371f91f1cc1cc804b3f5f15124c3cb0 Mon Sep 17 00:00:00 2001 From: functionpointer Date: Thu, 1 May 2025 14:01:02 +0200 Subject: [PATCH 075/102] [mlx90393] Add verification for register contents (#8279) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/components/mlx90393/sensor.py | 5 + .../components/mlx90393/sensor_mlx90393.cpp | 208 ++++++++++++++++-- esphome/components/mlx90393/sensor_mlx90393.h | 25 ++- tests/components/mlx90393/common.yaml | 3 +- 4 files changed, 222 insertions(+), 19 deletions(-) diff --git a/esphome/components/mlx90393/sensor.py b/esphome/components/mlx90393/sensor.py index cb9cb84aae..372bb05bda 100644 --- a/esphome/components/mlx90393/sensor.py +++ b/esphome/components/mlx90393/sensor.py @@ -63,6 +63,11 @@ def _validate(config): raise cv.Invalid( f"{axis}: {CONF_RESOLUTION} cannot be {res} with {CONF_TEMPERATURE_COMPENSATION} enabled" ) + if config[CONF_HALLCONF] == 0xC: + if (config[CONF_OVERSAMPLING], config[CONF_FILTER]) in [(0, 0), (1, 0), (0, 1)]: + raise cv.Invalid( + f"{CONF_OVERSAMPLING}=={config[CONF_OVERSAMPLING]} and {CONF_FILTER}=={config[CONF_FILTER]} not allowed with {CONF_HALLCONF}=={config[CONF_HALLCONF]:#02x}" + ) return config diff --git a/esphome/components/mlx90393/sensor_mlx90393.cpp b/esphome/components/mlx90393/sensor_mlx90393.cpp index e86080fe9c..46fe68fab0 100644 --- a/esphome/components/mlx90393/sensor_mlx90393.cpp +++ b/esphome/components/mlx90393/sensor_mlx90393.cpp @@ -6,13 +6,41 @@ namespace mlx90393 { static const char *const TAG = "mlx90393"; +const LogString *settings_to_string(MLX90393Setting setting) { + switch (setting) { + case MLX90393_GAIN_SEL: + return LOG_STR("gain"); + case MLX90393_RESOLUTION: + return LOG_STR("resolution"); + case MLX90393_OVER_SAMPLING: + return LOG_STR("oversampling"); + case MLX90393_DIGITAL_FILTERING: + return LOG_STR("digital filtering"); + case MLX90393_TEMPERATURE_OVER_SAMPLING: + return LOG_STR("temperature oversampling"); + case MLX90393_TEMPERATURE_COMPENSATION: + return LOG_STR("temperature compensation"); + case MLX90393_HALLCONF: + return LOG_STR("hallconf"); + case MLX90393_LAST: + return LOG_STR("error"); + default: + return LOG_STR("unknown"); + } +}; + bool MLX90393Cls::transceive(const uint8_t *request, size_t request_size, uint8_t *response, size_t response_size) { i2c::ErrorCode e = this->write(request, request_size); if (e != i2c::ErrorCode::ERROR_OK) { + ESP_LOGV(TAG, "i2c failed to write %u", e); return false; } e = this->read(response, response_size); - return e == i2c::ErrorCode::ERROR_OK; + if (e != i2c::ErrorCode::ERROR_OK) { + ESP_LOGV(TAG, "i2c failed to read %u", e); + return false; + } + return true; } bool MLX90393Cls::has_drdy_pin() { return this->drdy_pin_ != nullptr; } @@ -27,6 +55,53 @@ bool MLX90393Cls::read_drdy_pin() { void MLX90393Cls::sleep_millis(uint32_t millis) { delay(millis); } void MLX90393Cls::sleep_micros(uint32_t micros) { delayMicroseconds(micros); } +uint8_t MLX90393Cls::apply_setting_(MLX90393Setting which) { + uint8_t ret = -1; + switch (which) { + case MLX90393_GAIN_SEL: + ret = this->mlx_.setGainSel(this->gain_); + break; + case MLX90393_RESOLUTION: + ret = this->mlx_.setResolution(this->resolutions_[0], this->resolutions_[1], this->resolutions_[2]); + break; + case MLX90393_OVER_SAMPLING: + ret = this->mlx_.setOverSampling(this->oversampling_); + break; + case MLX90393_DIGITAL_FILTERING: + ret = this->mlx_.setDigitalFiltering(this->filter_); + break; + case MLX90393_TEMPERATURE_OVER_SAMPLING: + ret = this->mlx_.setTemperatureOverSampling(this->temperature_oversampling_); + break; + case MLX90393_TEMPERATURE_COMPENSATION: + ret = this->mlx_.setTemperatureCompensation(this->temperature_compensation_); + break; + case MLX90393_HALLCONF: + ret = this->mlx_.setHallConf(this->hallconf_); + break; + default: + break; + } + if (ret != MLX90393::STATUS_OK) { + ESP_LOGE(TAG, "failed to apply %s", LOG_STR_ARG(settings_to_string(which))); + } + return ret; +} + +bool MLX90393Cls::apply_all_settings_() { + // perform dummy read after reset + // first one always gets NAK even tough everything is fine + uint8_t ignore = 0; + this->mlx_.getGainSel(ignore); + + uint8_t result = MLX90393::STATUS_OK; + for (int i = MLX90393_GAIN_SEL; i != MLX90393_LAST; i++) { + MLX90393Setting stage = static_cast(i); + result |= this->apply_setting_(stage); + } + return result == MLX90393::STATUS_OK; +} + void MLX90393Cls::setup() { ESP_LOGCONFIG(TAG, "Setting up MLX90393..."); // note the two arguments A0 and A1 which are used to construct an i2c address @@ -34,19 +109,12 @@ void MLX90393Cls::setup() { // see the transceive function above, which uses the address from I2CComponent this->mlx_.begin_with_hal(this, 0, 0); - this->mlx_.setGainSel(this->gain_); + if (!this->apply_all_settings_()) { + this->mark_failed(); + } - this->mlx_.setResolution(this->resolutions_[0], this->resolutions_[1], this->resolutions_[2]); - - this->mlx_.setOverSampling(this->oversampling_); - - this->mlx_.setDigitalFiltering(this->filter_); - - this->mlx_.setTemperatureOverSampling(this->temperature_oversampling_); - - this->mlx_.setTemperatureCompensation(this->temperature_compensation_); - - this->mlx_.setHallConf(this->hallconf_); + // start verify settings process + this->set_timeout("verify settings", 3000, [this]() { this->verify_settings_timeout_(MLX90393_GAIN_SEL); }); } void MLX90393Cls::dump_config() { @@ -91,5 +159,119 @@ void MLX90393Cls::update() { } } +bool MLX90393Cls::verify_setting_(MLX90393Setting which) { + uint8_t read_value = 0xFF; + uint8_t expected_value = 0xFF; + uint8_t read_status = -1; + char read_back_str[25] = {0}; + + switch (which) { + case MLX90393_GAIN_SEL: { + read_status = this->mlx_.getGainSel(read_value); + expected_value = this->gain_; + break; + } + + case MLX90393_RESOLUTION: { + uint8_t read_resolutions[3] = {0xFF}; + read_status = this->mlx_.getResolution(read_resolutions[0], read_resolutions[1], read_resolutions[2]); + snprintf(read_back_str, sizeof(read_back_str), "%u %u %u expected %u %u %u", read_resolutions[0], + read_resolutions[1], read_resolutions[2], this->resolutions_[0], this->resolutions_[1], + this->resolutions_[2]); + bool is_correct = true; + for (int i = 0; i < 3; i++) { + is_correct &= read_resolutions[i] == this->resolutions_[i]; + } + if (is_correct) { + // set read_value and expected_value to same number, so the code blow recognizes it is correct + read_value = 0; + expected_value = 0; + } else { + // set to different numbers, to show incorrect + read_value = 1; + expected_value = 0; + } + break; + } + case MLX90393_OVER_SAMPLING: { + read_status = this->mlx_.getOverSampling(read_value); + expected_value = this->oversampling_; + break; + } + case MLX90393_DIGITAL_FILTERING: { + read_status = this->mlx_.getDigitalFiltering(read_value); + expected_value = this->filter_; + break; + } + case MLX90393_TEMPERATURE_OVER_SAMPLING: { + read_status = this->mlx_.getTemperatureOverSampling(read_value); + expected_value = this->temperature_oversampling_; + break; + } + case MLX90393_TEMPERATURE_COMPENSATION: { + read_status = this->mlx_.getTemperatureCompensation(read_value); + expected_value = (bool) this->temperature_compensation_; + break; + } + case MLX90393_HALLCONF: { + read_status = this->mlx_.getHallConf(read_value); + expected_value = this->hallconf_; + break; + } + default: { + return false; + } + } + if (read_status != MLX90393::STATUS_OK) { + ESP_LOGE(TAG, "verify error: failed to read %s", LOG_STR_ARG(settings_to_string(which))); + return false; + } + if (read_back_str[0] == 0x0) { + snprintf(read_back_str, sizeof(read_back_str), "%u expected %u", read_value, expected_value); + } + bool is_correct = read_value == expected_value; + if (!is_correct) { + ESP_LOGW(TAG, "verify failed: read back wrong %s: got %s", LOG_STR_ARG(settings_to_string(which)), read_back_str); + return false; + } + ESP_LOGD(TAG, "verify succeeded for %s. got %s", LOG_STR_ARG(settings_to_string(which)), read_back_str); + return true; +} + +/** + * Regularly checks that our settings are still applied. + * Used to catch spurious chip resets. + * + * returns true if everything is fine. + * false if not + */ +void MLX90393Cls::verify_settings_timeout_(MLX90393Setting stage) { + bool is_setting_ok = this->verify_setting_(stage); + + if (!is_setting_ok) { + if (this->mlx_.checkStatus(this->mlx_.reset()) != MLX90393::STATUS_OK) { + ESP_LOGE(TAG, "failed to reset device"); + this->status_set_error(); + this->mark_failed(); + return; + } + + if (!this->apply_all_settings_()) { + ESP_LOGE(TAG, "failed to re-apply settings"); + this->status_set_error(); + this->mark_failed(); + } else { + ESP_LOGI(TAG, "reset and re-apply settings completed"); + } + } + + MLX90393Setting next_stage = static_cast(static_cast(stage) + 1); + if (next_stage == MLX90393_LAST) { + next_stage = static_cast(0); + } + + this->set_timeout("verify settings", 3000, [this, next_stage]() { this->verify_settings_timeout_(next_stage); }); +} + } // namespace mlx90393 } // namespace esphome diff --git a/esphome/components/mlx90393/sensor_mlx90393.h b/esphome/components/mlx90393/sensor_mlx90393.h index 479891a76c..8a6f3321f9 100644 --- a/esphome/components/mlx90393/sensor_mlx90393.h +++ b/esphome/components/mlx90393/sensor_mlx90393.h @@ -1,15 +1,26 @@ #pragma once -#include "esphome/core/component.h" -#include "esphome/components/sensor/sensor.h" -#include "esphome/components/i2c/i2c.h" -#include "esphome/core/hal.h" #include #include +#include "esphome/components/i2c/i2c.h" +#include "esphome/components/sensor/sensor.h" +#include "esphome/core/component.h" +#include "esphome/core/hal.h" namespace esphome { namespace mlx90393 { +enum MLX90393Setting { + MLX90393_GAIN_SEL = 0, + MLX90393_RESOLUTION, + MLX90393_OVER_SAMPLING, + MLX90393_DIGITAL_FILTERING, + MLX90393_TEMPERATURE_OVER_SAMPLING, + MLX90393_TEMPERATURE_COMPENSATION, + MLX90393_HALLCONF, + MLX90393_LAST, +}; + class MLX90393Cls : public PollingComponent, public i2c::I2CDevice, public MLX90393Hal { public: void setup() override; @@ -58,6 +69,12 @@ class MLX90393Cls : public PollingComponent, public i2c::I2CDevice, public MLX90 bool temperature_compensation_{false}; uint8_t hallconf_{0xC}; GPIOPin *drdy_pin_{nullptr}; + + bool apply_all_settings_(); + uint8_t apply_setting_(MLX90393Setting which); + + bool verify_setting_(MLX90393Setting which); + void verify_settings_timeout_(MLX90393Setting stage); }; } // namespace mlx90393 diff --git a/tests/components/mlx90393/common.yaml b/tests/components/mlx90393/common.yaml index 0b074f9be3..58f3b6ecf5 100644 --- a/tests/components/mlx90393/common.yaml +++ b/tests/components/mlx90393/common.yaml @@ -5,8 +5,7 @@ i2c: sensor: - platform: mlx90393 - oversampling: 1 - filter: 0 + oversampling: 3 gain: 1X temperature_compensation: true x_axis: From d6699fa3c0f2e1b7fe0551441505587863d998f9 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Thu, 1 May 2025 22:29:12 +1000 Subject: [PATCH 076/102] Check for missed pulse_meter ISRs in the main loop (#6126) --- .../pulse_meter/pulse_meter_sensor.cpp | 44 ++++++++++++++----- .../pulse_meter/pulse_meter_sensor.h | 10 ++--- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/esphome/components/pulse_meter/pulse_meter_sensor.cpp b/esphome/components/pulse_meter/pulse_meter_sensor.cpp index 836a84b391..b82cb7a15c 100644 --- a/esphome/components/pulse_meter/pulse_meter_sensor.cpp +++ b/esphome/components/pulse_meter/pulse_meter_sensor.cpp @@ -18,6 +18,9 @@ void PulseMeterSensor::setup() { this->pin_->setup(); this->isr_pin_ = pin_->to_isr(); + // Set the pin value to the current value to avoid a false edge + this->last_pin_val_ = this->pin_->digital_read(); + // Set the last processed edge to now for the first timeout this->last_processed_edge_us_ = micros(); @@ -25,23 +28,37 @@ void PulseMeterSensor::setup() { this->pin_->attach_interrupt(PulseMeterSensor::edge_intr, this, gpio::INTERRUPT_RISING_EDGE); } else if (this->filter_mode_ == FILTER_PULSE) { // Set the pin value to the current value to avoid a false edge - this->pulse_state_.last_pin_val_ = this->isr_pin_.digital_read(); - this->pulse_state_.latched_ = this->pulse_state_.last_pin_val_; + this->pulse_state_.latched_ = this->last_pin_val_; this->pin_->attach_interrupt(PulseMeterSensor::pulse_intr, this, gpio::INTERRUPT_ANY_EDGE); } } void PulseMeterSensor::loop() { - const uint32_t now = micros(); - // Reset the count in get before we pass it back to the ISR as set this->get_->count_ = 0; - // Swap out set and get to get the latest state from the ISR - // The ISR could interrupt on any of these lines and the results would be consistent - auto *temp = this->set_; - this->set_ = this->get_; - this->get_ = temp; + { + // Lock the interrupt so the interrupt code doesn't interfere with itself + InterruptLock lock; + + // Sometimes ESP devices miss interrupts if the edge rises or falls too slowly. + // See https://github.com/espressif/arduino-esp32/issues/4172 + // If the edges are rising too slowly it also implies that the pulse rate is slow. + // Therefore the update rate of the loop is likely fast enough to detect the edges. + // When the main loop detects an edge that the ISR didn't it will run the ISR functions directly. + bool current = this->pin_->digital_read(); + if (this->filter_mode_ == FILTER_EDGE && current && !this->last_pin_val_) { + PulseMeterSensor::edge_intr(this); + } else if (this->filter_mode_ == FILTER_PULSE && current != this->last_pin_val_) { + PulseMeterSensor::pulse_intr(this); + } + this->last_pin_val_ = current; + + // Swap out set and get to get the latest state from the ISR + std::swap(this->set_, this->get_); + } + + const uint32_t now = micros(); // If an edge was peeked, repay the debt if (this->peeked_edge_ && this->get_->count_ > 0) { @@ -131,6 +148,9 @@ void IRAM_ATTR PulseMeterSensor::edge_intr(PulseMeterSensor *sensor) { set.last_rising_edge_us_ = now; set.count_++; } + + // This ISR is bound to rising edges, so the pin is high + sensor->last_pin_val_ = true; } void IRAM_ATTR PulseMeterSensor::pulse_intr(PulseMeterSensor *sensor) { @@ -144,9 +164,9 @@ void IRAM_ATTR PulseMeterSensor::pulse_intr(PulseMeterSensor *sensor) { // Filter length has passed since the last interrupt const bool length = now - state.last_intr_ >= sensor->filter_us_; - if (length && state.latched_ && !state.last_pin_val_) { // Long enough low edge + if (length && state.latched_ && !sensor->last_pin_val_) { // Long enough low edge state.latched_ = false; - } else if (length && !state.latched_ && state.last_pin_val_) { // Long enough high edge + } else if (length && !state.latched_ && sensor->last_pin_val_) { // Long enough high edge state.latched_ = true; set.last_detected_edge_us_ = state.last_intr_; set.count_++; @@ -158,7 +178,7 @@ void IRAM_ATTR PulseMeterSensor::pulse_intr(PulseMeterSensor *sensor) { set.last_rising_edge_us_ = !state.latched_ && pin_val ? now : set.last_detected_edge_us_; state.last_intr_ = now; - state.last_pin_val_ = pin_val; + sensor->last_pin_val_ = pin_val; } } // namespace pulse_meter diff --git a/esphome/components/pulse_meter/pulse_meter_sensor.h b/esphome/components/pulse_meter/pulse_meter_sensor.h index 76c4a35f03..748bab29ac 100644 --- a/esphome/components/pulse_meter/pulse_meter_sensor.h +++ b/esphome/components/pulse_meter/pulse_meter_sensor.h @@ -49,9 +49,7 @@ class PulseMeterSensor : public sensor::Sensor, public Component { // This struct (and the two pointers) are used to pass data between the ISR and loop. // These two pointers are exchanged each loop. - // Therefore you can't use data in the pointer to loop receives to set values in the pointer to loop sends. - // As a result it's easiest if you only use these pointers to send data from the ISR to the loop. - // (except for resetting the values) + // Use these to send data from the ISR to the loop not the other way around (except for resetting the values). struct State { uint32_t last_detected_edge_us_ = 0; uint32_t last_rising_edge_us_ = 0; @@ -61,9 +59,12 @@ class PulseMeterSensor : public sensor::Sensor, public Component { volatile State *set_ = state_; volatile State *get_ = state_ + 1; - // Only use these variables in the ISR + // Only use the following variables in the ISR or while guarded by an InterruptLock ISRInternalGPIOPin isr_pin_; + /// The last pin value seen + bool last_pin_val_ = false; + /// Filter state for edge mode struct EdgeState { uint32_t last_sent_edge_us_ = 0; @@ -74,7 +75,6 @@ class PulseMeterSensor : public sensor::Sensor, public Component { struct PulseState { uint32_t last_intr_ = 0; bool latched_ = false; - bool last_pin_val_ = false; }; PulseState pulse_state_{}; }; From ced7ae1d7a7411e93c584c4f00c20bc85a56baa0 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Thu, 1 May 2025 07:50:32 -0500 Subject: [PATCH 077/102] [debug] add missing header (#8666) --- esphome/components/debug/debug_esp32.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/esphome/components/debug/debug_esp32.cpp b/esphome/components/debug/debug_esp32.cpp index bc772a1d58..662e60501d 100644 --- a/esphome/components/debug/debug_esp32.cpp +++ b/esphome/components/debug/debug_esp32.cpp @@ -14,6 +14,8 @@ #include #endif +#include + namespace esphome { namespace debug { From db97440b0471508ed1be07c2abb09c9c38292694 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Thu, 1 May 2025 14:02:33 -0500 Subject: [PATCH 078/102] [microphone] Add software mute and fix wrong type for automations (#8667) --- esphome/components/microphone/__init__.py | 20 +++++++++++++++++-- esphome/components/microphone/automation.h | 12 +++++++++++ esphome/components/microphone/microphone.cpp | 21 ++++++++++++++++++++ esphome/components/microphone/microphone.h | 10 +++++++--- tests/components/microphone/common.yaml | 8 ++++++++ 5 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 esphome/components/microphone/microphone.cpp diff --git a/esphome/components/microphone/__init__.py b/esphome/components/microphone/__init__.py index dcae513578..f85f0b76f3 100644 --- a/esphome/components/microphone/__init__.py +++ b/esphome/components/microphone/__init__.py @@ -32,6 +32,12 @@ CaptureAction = microphone_ns.class_( StopCaptureAction = microphone_ns.class_( "StopCaptureAction", automation.Action, cg.Parented.template(Microphone) ) +MuteAction = microphone_ns.class_( + "MuteAction", automation.Action, cg.Parented.template(Microphone) +) +UnmuteAction = microphone_ns.class_( + "UnmuteAction", automation.Action, cg.Parented.template(Microphone) +) DataTrigger = microphone_ns.class_( @@ -42,15 +48,15 @@ DataTrigger = microphone_ns.class_( IsCapturingCondition = microphone_ns.class_( "IsCapturingCondition", automation.Condition ) +IsMutedCondition = microphone_ns.class_("IsMutedCondition", automation.Condition) async def setup_microphone_core_(var, config): for conf in config.get(CONF_ON_DATA, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) - # Future PR will change the vector type to uint8 await automation.build_automation( trigger, - [(cg.std_vector.template(cg.int16).operator("ref").operator("const"), "x")], + [(cg.std_vector.template(cg.uint8).operator("ref").operator("const"), "x")], conf, ) @@ -186,9 +192,19 @@ automation.register_action( "microphone.stop_capture", StopCaptureAction, MICROPHONE_ACTION_SCHEMA )(microphone_action) +automation.register_action("microphone.mute", MuteAction, MICROPHONE_ACTION_SCHEMA)( + microphone_action +) +automation.register_action("microphone.unmute", UnmuteAction, MICROPHONE_ACTION_SCHEMA)( + microphone_action +) + automation.register_condition( "microphone.is_capturing", IsCapturingCondition, MICROPHONE_ACTION_SCHEMA )(microphone_action) +automation.register_condition( + "microphone.is_muted", IsMutedCondition, MICROPHONE_ACTION_SCHEMA +)(microphone_action) @coroutine_with_priority(100.0) diff --git a/esphome/components/microphone/automation.h b/esphome/components/microphone/automation.h index 324699c0af..5745909c46 100644 --- a/esphome/components/microphone/automation.h +++ b/esphome/components/microphone/automation.h @@ -16,6 +16,13 @@ template class StopCaptureAction : public Action, public void play(Ts... x) override { this->parent_->stop(); } }; +template class MuteAction : public Action, public Parented { + void play(Ts... x) override { this->parent_->set_mute_state(true); } +}; +template class UnmuteAction : public Action, public Parented { + void play(Ts... x) override { this->parent_->set_mute_state(false); } +}; + class DataTrigger : public Trigger &> { public: explicit DataTrigger(Microphone *mic) { @@ -28,5 +35,10 @@ template class IsCapturingCondition : public Condition, p bool check(Ts... x) override { return this->parent_->is_running(); } }; +template class IsMutedCondition : public Condition, public Parented { + public: + bool check(Ts... x) override { return this->parent_->get_mute_state(); } +}; + } // namespace microphone } // namespace esphome diff --git a/esphome/components/microphone/microphone.cpp b/esphome/components/microphone/microphone.cpp new file mode 100644 index 0000000000..b1289f3791 --- /dev/null +++ b/esphome/components/microphone/microphone.cpp @@ -0,0 +1,21 @@ +#include "microphone.h" + +namespace esphome { +namespace microphone { + +void Microphone::add_data_callback(std::function &)> &&data_callback) { + std::function &)> mute_handled_callback = + [this, data_callback](const std::vector &data) { data_callback(this->silence_audio_(data)); }; + this->data_callbacks_.add(std::move(mute_handled_callback)); +} + +std::vector Microphone::silence_audio_(std::vector data) { + if (this->mute_state_) { + std::memset((void *) data.data(), 0, data.size()); + } + + return data; +} + +} // namespace microphone +} // namespace esphome diff --git a/esphome/components/microphone/microphone.h b/esphome/components/microphone/microphone.h index cef8d0f4c3..ea4e979e20 100644 --- a/esphome/components/microphone/microphone.h +++ b/esphome/components/microphone/microphone.h @@ -22,17 +22,21 @@ class Microphone { public: virtual void start() = 0; virtual void stop() = 0; - void add_data_callback(std::function &)> &&data_callback) { - this->data_callbacks_.add(std::move(data_callback)); - } + void add_data_callback(std::function &)> &&data_callback); bool is_running() const { return this->state_ == STATE_RUNNING; } bool is_stopped() const { return this->state_ == STATE_STOPPED; } + void set_mute_state(bool is_muted) { this->mute_state_ = is_muted; } + bool get_mute_state() { return this->mute_state_; } + audio::AudioStreamInfo get_audio_stream_info() { return this->audio_stream_info_; } protected: + std::vector silence_audio_(std::vector data); + State state_{STATE_STOPPED}; + bool mute_state_{false}; audio::AudioStreamInfo audio_stream_info_; diff --git a/tests/components/microphone/common.yaml b/tests/components/microphone/common.yaml index ccadc7aee5..d8e4abd12a 100644 --- a/tests/components/microphone/common.yaml +++ b/tests/components/microphone/common.yaml @@ -10,3 +10,11 @@ microphone: adc_type: external pdm: false mclk_multiple: 384 + on_data: + - if: + condition: + - microphone.is_muted: + then: + - microphone.unmute: + else: + - microphone.mute: From 2eb9582d0ff9a4b66c5863b01fba4d9c51be8cbc Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Thu, 1 May 2025 14:04:23 -0500 Subject: [PATCH 079/102] [micro_wake_word] Clarify spectrogram features calculation (#8669) --- esphome/components/micro_wake_word/micro_wake_word.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index f768b661c0..46ca328730 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -404,8 +404,8 @@ size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_a constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; - value -= INT8_MIN; - features_buffer[i] = clamp(value, INT8_MIN, INT8_MAX); + value += INT8_MIN; // Adds a -128; i.e., subtracts 128 + features_buffer[i] = static_cast(clamp(value, INT8_MIN, INT8_MAX)); } return processed_samples; From f4b5f32cb43db5a5b6ba0fe3cb541a2bf8ac8cef Mon Sep 17 00:00:00 2001 From: DJTerentjev Date: Fri, 2 May 2025 04:43:58 +0300 Subject: [PATCH 080/102] Update const.py (#8665) --- esphome/const.py | 1 + 1 file changed, 1 insertion(+) diff --git a/esphome/const.py b/esphome/const.py index f78312a5b0..262f5e0033 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -1102,6 +1102,7 @@ UNIT_MILLIGRAMS_PER_CUBIC_METER = "mg/m³" UNIT_MILLIMETER = "mm" UNIT_MILLISECOND = "ms" UNIT_MILLISIEMENS_PER_CENTIMETER = "mS/cm" +UNIT_MILLIVOLT = "mV" UNIT_MINUTE = "min" UNIT_OHM = "Ω" UNIT_PARTS_PER_BILLION = "ppb" From 8d33c6de364e3f231cdb2ae9b1bedb588381eb3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Obrembski?= Date: Sat, 3 May 2025 00:54:27 +0200 Subject: [PATCH 081/102] Added Banking support to tca9555, fixed input bug (#8003) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- .../components/gpio_expander/cached_gpio.h | 29 ++++++++++++++----- esphome/components/tca9555/tca9555.cpp | 17 +++++++---- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/esphome/components/gpio_expander/cached_gpio.h b/esphome/components/gpio_expander/cached_gpio.h index 784c5f0f4a..78c675cdb2 100644 --- a/esphome/components/gpio_expander/cached_gpio.h +++ b/esphome/components/gpio_expander/cached_gpio.h @@ -8,30 +8,45 @@ namespace esphome { namespace gpio_expander { /// @brief A class to cache the read state of a GPIO expander. +/// This class caches reads between GPIO Pins which are on the same bank. +/// This means that for reading whole Port (ex. 8 pins) component needs only one +/// I2C/SPI read per main loop call. It assumes, that one bit in byte identifies one GPIO pin +/// Template parameters: +/// T - Type which represents internal register. Could be uint8_t or uint16_t. Adjust to +/// match size of your internal GPIO bank register. +/// N - Number of pins template class CachedGpioExpander { public: bool digital_read(T pin) { - if (!this->read_cache_invalidated_[pin]) { - this->read_cache_invalidated_[pin] = true; - return this->digital_read_cache(pin); + uint8_t bank = pin / (sizeof(T) * BITS_PER_BYTE); + if (this->read_cache_invalidated_[bank]) { + this->read_cache_invalidated_[bank] = false; + if (!this->digital_read_hw(pin)) + return false; } - return this->digital_read_hw(pin); + return this->digital_read_cache(pin); } void digital_write(T pin, bool value) { this->digital_write_hw(pin, value); } protected: + /// @brief Call component low level function to read GPIO state from device virtual bool digital_read_hw(T pin) = 0; + /// @brief Call component read function from internal cache. virtual bool digital_read_cache(T pin) = 0; + /// @brief Call component low level function to write GPIO state to device virtual void digital_write_hw(T pin, bool value) = 0; + const uint8_t cache_byte_size_ = N / (sizeof(T) * BITS_PER_BYTE); + /// @brief Invalidate cache. This function should be called in component loop(). void reset_pin_cache_() { - for (T i = 0; i < N; i++) { - this->read_cache_invalidated_[i] = false; + for (T i = 0; i < this->cache_byte_size_; i++) { + this->read_cache_invalidated_[i] = true; } } - std::array read_cache_invalidated_{}; + static const uint8_t BITS_PER_BYTE = 8; + std::array read_cache_invalidated_{}; }; } // namespace gpio_expander diff --git a/esphome/components/tca9555/tca9555.cpp b/esphome/components/tca9555/tca9555.cpp index cf0894427f..e065398c46 100644 --- a/esphome/components/tca9555/tca9555.cpp +++ b/esphome/components/tca9555/tca9555.cpp @@ -76,15 +76,20 @@ bool TCA9555Component::read_gpio_modes_() { bool TCA9555Component::digital_read_hw(uint8_t pin) { if (this->is_failed()) return false; - bool success; - uint8_t data[2]; - success = this->read_bytes(TCA9555_INPUT_PORT_REGISTER_0, data, 2); - this->input_mask_ = (uint16_t(data[1]) << 8) | (uint16_t(data[0]) << 0); - - if (!success) { + uint8_t data; + uint8_t bank_number = pin < 8 ? 0 : 1; + uint8_t register_to_read = bank_number ? TCA9555_INPUT_PORT_REGISTER_1 : TCA9555_INPUT_PORT_REGISTER_0; + if (!this->read_bytes(register_to_read, &data, 1)) { this->status_set_warning("Failed to read input register"); return false; } + uint8_t second_half = this->input_mask_ >> 8; + uint8_t first_half = this->input_mask_; + if (bank_number) { + this->input_mask_ = (data << 8) | (uint16_t(first_half) << 0); + } else { + this->input_mask_ = (uint16_t(second_half) << 8) | (data << 0); + } this->status_clear_warning(); return true; From 8aff6d2fdd09417472c9b80d3fb9e93ccc9b73df Mon Sep 17 00:00:00 2001 From: Jani <43068636+myllyja@users.noreply.github.com> Date: Sat, 3 May 2025 06:02:35 +0300 Subject: [PATCH 082/102] Add GDEY0583T81 support (#8668) --- .../components/waveshare_epaper/display.py | 2 + .../waveshare_epaper/waveshare_epaper.cpp | 217 ++++++++++++++++++ .../waveshare_epaper/waveshare_epaper.h | 34 +++ 3 files changed, 253 insertions(+) diff --git a/esphome/components/waveshare_epaper/display.py b/esphome/components/waveshare_epaper/display.py index afce83d553..cea0b2be5e 100644 --- a/esphome/components/waveshare_epaper/display.py +++ b/esphome/components/waveshare_epaper/display.py @@ -79,6 +79,7 @@ WaveshareEPaper5P8In = waveshare_epaper_ns.class_( WaveshareEPaper5P8InV2 = waveshare_epaper_ns.class_( "WaveshareEPaper5P8InV2", WaveshareEPaper ) +GDEY0583T81 = waveshare_epaper_ns.class_("GDEY0583T81", WaveshareEPaper) WaveshareEPaper7P3InF = waveshare_epaper_ns.class_( "WaveshareEPaper7P3InF", WaveshareEPaper7C ) @@ -156,6 +157,7 @@ MODELS = { "5.65in-f": ("b", WaveshareEPaper5P65InF), "5.83in": ("b", WaveshareEPaper5P8In), "5.83inv2": ("b", WaveshareEPaper5P8InV2), + "gdey0583t81": ("c", GDEY0583T81), "7.30in-f": ("b", WaveshareEPaper7P3InF), "7.50in": ("b", WaveshareEPaper7P5In), "7.50in-bv2": ("b", WaveshareEPaper7P5InBV2), diff --git a/esphome/components/waveshare_epaper/waveshare_epaper.cpp b/esphome/components/waveshare_epaper/waveshare_epaper.cpp index 8e30fc4c32..5031446c95 100644 --- a/esphome/components/waveshare_epaper/waveshare_epaper.cpp +++ b/esphome/components/waveshare_epaper/waveshare_epaper.cpp @@ -2938,6 +2938,223 @@ void WaveshareEPaper5P8InV2::dump_config() { LOG_UPDATE_INTERVAL(this); } +// ======================================================== +// Good Display 5.83in black/white GDEY0583T81 +// Product page: +// - https://www.good-display.com/product/440.html +// - https://www.seeedstudio.com/5-83-Monochrome-ePaper-Display-with-648x480-Pixels-p-5785.html +// Datasheet: +// - +// https://www.good-display.com/public/html/pdfjs/viewer/viewernew.html?file=https://v4.cecdn.yun300.cn/100001_1909185148/GDEY0583T81-new.pdf +// - https://v4.cecdn.yun300.cn/100001_1909185148/GDEY0583T81-new.pdf +// Reference code from GoodDisplay: +// - https://www.good-display.com/companyfile/903.html +// ======================================================== + +void GDEY0583T81::initialize() { + // Allocate buffer for old data for partial updates + RAMAllocator allocator{}; + this->old_buffer_ = allocator.allocate(this->get_buffer_length_()); + if (this->old_buffer_ == nullptr) { + ESP_LOGE(TAG, "Could not allocate old buffer for display!"); + return; + } + memset(this->old_buffer_, 0xFF, this->get_buffer_length_()); + + this->init_full_(); + + this->wait_until_idle_(); + + this->deep_sleep(); +} + +void GDEY0583T81::power_on_() { + if (!this->power_is_on_) { + this->command(0x04); + this->wait_until_idle_(); + } + this->power_is_on_ = true; + this->is_deep_sleep_ = false; +} + +void GDEY0583T81::power_off_() { + this->command(0x02); + this->wait_until_idle_(); + this->power_is_on_ = false; +} + +void GDEY0583T81::deep_sleep() { + if (this->is_deep_sleep_) { + return; + } + + // VCOM and data interval setting (CDI) + this->command(0x50); + this->data(0xf7); + + this->power_off_(); + delay(10); + + // Deep sleep (DSLP) + this->command(0x07); + this->data(0xA5); + this->is_deep_sleep_ = true; +} + +void GDEY0583T81::reset_() { + if (this->reset_pin_ != nullptr) { + this->reset_pin_->digital_write(false); + delay(10); + this->reset_pin_->digital_write(true); + delay(10); + } +} + +// Initialize for full screen update in fast mode +void GDEY0583T81::init_full_() { + this->init_display_(); + + // Based on the GD sample code + // VCOM and data interval setting (CDI) + this->command(0x50); + this->data(0x29); + this->data(0x07); + + // Cascade Setting (CCSET) + this->command(0xE0); + this->data(0x02); + + // Force Temperature (TSSET) + this->command(0xE5); + this->data(0x5A); +} + +// Initialize for a partial update of the full screen +void GDEY0583T81::init_partial_() { + this->init_display_(); + + // Cascade Setting (CCSET) + this->command(0xE0); + this->data(0x02); + + // Force Temperature (TSSET) + this->command(0xE5); + this->data(0x6E); +} + +void GDEY0583T81::init_display_() { + this->reset_(); + + // Panel Setting (PSR) + this->command(0x00); + // Sets: REG=0, LUT from OTP (set by CDI) + // KW/R=1, Sets KW mode (Black/White) + // as opposed to the default KWR mode (Black/White/Red) + // UD=1, Gate Scan Direction, 1 = up (default) + // SHL=1, Source Shift Direction, 1 = right (default) + // SHD_N=1, Booster Switch, 1 = ON (default) + // RST_N=1, Soft reset, 1 = No effect (default) + this->data(0x1F); + + // Resolution setting (TRES) + this->command(0x61); + + // Horizontal display resolution (HRES) + this->data(get_width_internal() / 256); + this->data(get_width_internal() % 256); + + // Vertical display resolution (VRES) + this->data(get_height_internal() / 256); + this->data(get_height_internal() % 256); + + this->power_on_(); +} + +void HOT GDEY0583T81::display() { + bool full_update = this->at_update_ == 0; + if (full_update) { + this->init_full_(); + } else { + this->init_partial_(); + + // VCOM and data interval setting (CDI) + this->command(0x50); + this->data(0xA9); + this->data(0x07); + + // Partial In (PTIN), makes the display enter partial mode + this->command(0x91); + + // Partial Window (PTL) + // We use the full screen as the window + this->command(0x90); + + // Horizontal start/end channel bank (HRST/HRED) + this->data(0); + this->data(0); + this->data((get_width_internal() - 1) / 256); + this->data((get_width_internal() - 1) % 256); + + // Vertical start/end line (VRST/VRED) + this->data(0); + this->data(0); + this->data((get_height_internal() - 1) / 256); + this->data((get_height_internal() - 1) % 256); + + this->data(0x01); + + // Display Start Transmission 1 (DTM1) + // in KW mode this writes "OLD" data to SRAM + this->command(0x10); + this->start_data_(); + this->write_array(this->old_buffer_, this->get_buffer_length_()); + this->end_data_(); + } + + // Display Start Transmission 2 (DTM2) + // in KW mode this writes "NEW" data to SRAM + this->command(0x13); + this->start_data_(); + this->write_array(this->buffer_, this->get_buffer_length_()); + this->end_data_(); + + for (size_t i = 0; i < this->get_buffer_length_(); i++) { + this->old_buffer_[i] = this->buffer_[i]; + } + + // Display Refresh (DRF) + this->command(0x12); + delay(10); + this->wait_until_idle_(); + + if (full_update) { + ESP_LOGD(TAG, "Full update done"); + } else { + // Partial out (PTOUT), makes the display exit partial mode + this->command(0x92); + ESP_LOGD(TAG, "Partial update done, next full update after %d cycles", + this->full_update_every_ - this->at_update_ - 1); + } + + this->at_update_ = (this->at_update_ + 1) % this->full_update_every_; + + this->deep_sleep(); +} + +void GDEY0583T81::set_full_update_every(uint32_t full_update_every) { this->full_update_every_ = full_update_every; } +int GDEY0583T81::get_width_internal() { return 648; } +int GDEY0583T81::get_height_internal() { return 480; } +uint32_t GDEY0583T81::idle_timeout_() { return 5000; } +void GDEY0583T81::dump_config() { + LOG_DISPLAY("", "GoodDisplay E-Paper", this); + ESP_LOGCONFIG(TAG, " Model: 5.83in B/W GDEY0583T81"); + ESP_LOGCONFIG(TAG, " Full Update Every: %" PRIu32, this->full_update_every_); + LOG_PIN(" Reset Pin: ", this->reset_pin_); + LOG_PIN(" DC Pin: ", this->dc_pin_); + LOG_PIN(" Busy Pin: ", this->busy_pin_); + LOG_UPDATE_INTERVAL(this); +} + void WaveshareEPaper7P5InBV2::initialize() { // COMMAND POWER SETTING this->command(0x01); diff --git a/esphome/components/waveshare_epaper/waveshare_epaper.h b/esphome/components/waveshare_epaper/waveshare_epaper.h index 9fff1ea6b5..74bb153519 100644 --- a/esphome/components/waveshare_epaper/waveshare_epaper.h +++ b/esphome/components/waveshare_epaper/waveshare_epaper.h @@ -686,6 +686,40 @@ class WaveshareEPaper5P8InV2 : public WaveshareEPaper { int get_height_internal() override; }; +class GDEY0583T81 : public WaveshareEPaper { + public: + void initialize() override; + + void display() override; + + void dump_config() override; + + void deep_sleep() override; + + void set_full_update_every(uint32_t full_update_every); + + protected: + int get_width_internal() override; + int get_height_internal() override; + uint32_t idle_timeout_() override; + + private: + void power_on_(); + void power_off_(); + void reset_(); + void update_full_(); + void update_part_(); + void init_full_(); + void init_partial_(); + void init_display_(); + + uint32_t full_update_every_{30}; + uint32_t at_update_{0}; + bool power_is_on_{false}; + bool is_deep_sleep_{false}; + uint8_t *old_buffer_{nullptr}; +}; + class WaveshareEPaper5P65InF : public WaveshareEPaper7C { public: void initialize() override; From e869a3aec32bc6764328b024aff9f6488b1b9038 Mon Sep 17 00:00:00 2001 From: Thomas Rupprecht Date: Sun, 4 May 2025 05:41:52 +0200 Subject: [PATCH 083/102] [climate] Fix typo and use ``this->`` (#8678) --- esphome/components/climate/climate_mode.h | 2 +- esphome/components/climate/climate_traits.h | 120 ++++++++++---------- 2 files changed, 64 insertions(+), 58 deletions(-) diff --git a/esphome/components/climate/climate_mode.h b/esphome/components/climate/climate_mode.h index c5245812c7..80efb4c048 100644 --- a/esphome/components/climate/climate_mode.h +++ b/esphome/components/climate/climate_mode.h @@ -20,7 +20,7 @@ enum ClimateMode : uint8_t { CLIMATE_MODE_FAN_ONLY = 4, /// The climate device is set to dry/humidity mode CLIMATE_MODE_DRY = 5, - /** The climate device is adjusting the temperatre dynamically. + /** The climate device is adjusting the temperature dynamically. * For example, the target temperature can be adjusted based on a schedule, or learned behavior. * The target temperature can't be adjusted when in this mode. */ diff --git a/esphome/components/climate/climate_traits.h b/esphome/components/climate/climate_traits.h index 58d7b586d7..c3a0dfca8f 100644 --- a/esphome/components/climate/climate_traits.h +++ b/esphome/components/climate/climate_traits.h @@ -40,24 +40,24 @@ namespace climate { */ class ClimateTraits { public: - bool get_supports_current_temperature() const { return supports_current_temperature_; } + bool get_supports_current_temperature() const { return this->supports_current_temperature_; } void set_supports_current_temperature(bool supports_current_temperature) { - supports_current_temperature_ = supports_current_temperature; + this->supports_current_temperature_ = supports_current_temperature; } - bool get_supports_current_humidity() const { return supports_current_humidity_; } + bool get_supports_current_humidity() const { return this->supports_current_humidity_; } void set_supports_current_humidity(bool supports_current_humidity) { - supports_current_humidity_ = supports_current_humidity; + this->supports_current_humidity_ = supports_current_humidity; } - bool get_supports_two_point_target_temperature() const { return supports_two_point_target_temperature_; } + bool get_supports_two_point_target_temperature() const { return this->supports_two_point_target_temperature_; } void set_supports_two_point_target_temperature(bool supports_two_point_target_temperature) { - supports_two_point_target_temperature_ = supports_two_point_target_temperature; + this->supports_two_point_target_temperature_ = supports_two_point_target_temperature; } - bool get_supports_target_humidity() const { return supports_target_humidity_; } + bool get_supports_target_humidity() const { return this->supports_target_humidity_; } void set_supports_target_humidity(bool supports_target_humidity) { - supports_target_humidity_ = supports_target_humidity; + this->supports_target_humidity_ = supports_target_humidity; } - void set_supported_modes(std::set modes) { supported_modes_ = std::move(modes); } - void add_supported_mode(ClimateMode mode) { supported_modes_.insert(mode); } + void set_supported_modes(std::set modes) { this->supported_modes_ = std::move(modes); } + void add_supported_mode(ClimateMode mode) { this->supported_modes_.insert(mode); } ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20") void set_supports_auto_mode(bool supports_auto_mode) { set_mode_support_(CLIMATE_MODE_AUTO, supports_auto_mode); } ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20") @@ -72,15 +72,15 @@ class ClimateTraits { } ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20") void set_supports_dry_mode(bool supports_dry_mode) { set_mode_support_(CLIMATE_MODE_DRY, supports_dry_mode); } - bool supports_mode(ClimateMode mode) const { return supported_modes_.count(mode); } - const std::set &get_supported_modes() const { return supported_modes_; } + bool supports_mode(ClimateMode mode) const { return this->supported_modes_.count(mode); } + const std::set &get_supported_modes() const { return this->supported_modes_; } - void set_supports_action(bool supports_action) { supports_action_ = supports_action; } - bool get_supports_action() const { return supports_action_; } + void set_supports_action(bool supports_action) { this->supports_action_ = supports_action; } + bool get_supports_action() const { return this->supports_action_; } - void set_supported_fan_modes(std::set modes) { supported_fan_modes_ = std::move(modes); } - void add_supported_fan_mode(ClimateFanMode mode) { supported_fan_modes_.insert(mode); } - void add_supported_custom_fan_mode(const std::string &mode) { supported_custom_fan_modes_.insert(mode); } + void set_supported_fan_modes(std::set modes) { this->supported_fan_modes_ = std::move(modes); } + void add_supported_fan_mode(ClimateFanMode mode) { this->supported_fan_modes_.insert(mode); } + void add_supported_custom_fan_mode(const std::string &mode) { this->supported_custom_fan_modes_.insert(mode); } ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20") void set_supports_fan_mode_on(bool supported) { set_fan_mode_support_(CLIMATE_FAN_ON, supported); } ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20") @@ -99,35 +99,37 @@ class ClimateTraits { void set_supports_fan_mode_focus(bool supported) { set_fan_mode_support_(CLIMATE_FAN_FOCUS, supported); } ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20") void set_supports_fan_mode_diffuse(bool supported) { set_fan_mode_support_(CLIMATE_FAN_DIFFUSE, supported); } - bool supports_fan_mode(ClimateFanMode fan_mode) const { return supported_fan_modes_.count(fan_mode); } - bool get_supports_fan_modes() const { return !supported_fan_modes_.empty() || !supported_custom_fan_modes_.empty(); } - const std::set &get_supported_fan_modes() const { return supported_fan_modes_; } + bool supports_fan_mode(ClimateFanMode fan_mode) const { return this->supported_fan_modes_.count(fan_mode); } + bool get_supports_fan_modes() const { + return !this->supported_fan_modes_.empty() || !this->supported_custom_fan_modes_.empty(); + } + const std::set &get_supported_fan_modes() const { return this->supported_fan_modes_; } void set_supported_custom_fan_modes(std::set supported_custom_fan_modes) { - supported_custom_fan_modes_ = std::move(supported_custom_fan_modes); + this->supported_custom_fan_modes_ = std::move(supported_custom_fan_modes); } - const std::set &get_supported_custom_fan_modes() const { return supported_custom_fan_modes_; } + const std::set &get_supported_custom_fan_modes() const { return this->supported_custom_fan_modes_; } bool supports_custom_fan_mode(const std::string &custom_fan_mode) const { - return supported_custom_fan_modes_.count(custom_fan_mode); + return this->supported_custom_fan_modes_.count(custom_fan_mode); } - void set_supported_presets(std::set presets) { supported_presets_ = std::move(presets); } - void add_supported_preset(ClimatePreset preset) { supported_presets_.insert(preset); } - void add_supported_custom_preset(const std::string &preset) { supported_custom_presets_.insert(preset); } - bool supports_preset(ClimatePreset preset) const { return supported_presets_.count(preset); } - bool get_supports_presets() const { return !supported_presets_.empty(); } - const std::set &get_supported_presets() const { return supported_presets_; } + void set_supported_presets(std::set presets) { this->supported_presets_ = std::move(presets); } + void add_supported_preset(ClimatePreset preset) { this->supported_presets_.insert(preset); } + void add_supported_custom_preset(const std::string &preset) { this->supported_custom_presets_.insert(preset); } + bool supports_preset(ClimatePreset preset) const { return this->supported_presets_.count(preset); } + bool get_supports_presets() const { return !this->supported_presets_.empty(); } + const std::set &get_supported_presets() const { return this->supported_presets_; } void set_supported_custom_presets(std::set supported_custom_presets) { - supported_custom_presets_ = std::move(supported_custom_presets); + this->supported_custom_presets_ = std::move(supported_custom_presets); } - const std::set &get_supported_custom_presets() const { return supported_custom_presets_; } + const std::set &get_supported_custom_presets() const { return this->supported_custom_presets_; } bool supports_custom_preset(const std::string &custom_preset) const { - return supported_custom_presets_.count(custom_preset); + return this->supported_custom_presets_.count(custom_preset); } - void set_supported_swing_modes(std::set modes) { supported_swing_modes_ = std::move(modes); } - void add_supported_swing_mode(ClimateSwingMode mode) { supported_swing_modes_.insert(mode); } + void set_supported_swing_modes(std::set modes) { this->supported_swing_modes_ = std::move(modes); } + void add_supported_swing_mode(ClimateSwingMode mode) { this->supported_swing_modes_.insert(mode); } ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20") void set_supports_swing_mode_off(bool supported) { set_swing_mode_support_(CLIMATE_SWING_OFF, supported); } ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20") @@ -138,54 +140,58 @@ class ClimateTraits { void set_supports_swing_mode_horizontal(bool supported) { set_swing_mode_support_(CLIMATE_SWING_HORIZONTAL, supported); } - bool supports_swing_mode(ClimateSwingMode swing_mode) const { return supported_swing_modes_.count(swing_mode); } - bool get_supports_swing_modes() const { return !supported_swing_modes_.empty(); } - const std::set &get_supported_swing_modes() const { return supported_swing_modes_; } + bool supports_swing_mode(ClimateSwingMode swing_mode) const { return this->supported_swing_modes_.count(swing_mode); } + bool get_supports_swing_modes() const { return !this->supported_swing_modes_.empty(); } + const std::set &get_supported_swing_modes() const { return this->supported_swing_modes_; } - float get_visual_min_temperature() const { return visual_min_temperature_; } - void set_visual_min_temperature(float visual_min_temperature) { visual_min_temperature_ = visual_min_temperature; } - float get_visual_max_temperature() const { return visual_max_temperature_; } - void set_visual_max_temperature(float visual_max_temperature) { visual_max_temperature_ = visual_max_temperature; } - float get_visual_target_temperature_step() const { return visual_target_temperature_step_; } - float get_visual_current_temperature_step() const { return visual_current_temperature_step_; } + float get_visual_min_temperature() const { return this->visual_min_temperature_; } + void set_visual_min_temperature(float visual_min_temperature) { + this->visual_min_temperature_ = visual_min_temperature; + } + float get_visual_max_temperature() const { return this->visual_max_temperature_; } + void set_visual_max_temperature(float visual_max_temperature) { + this->visual_max_temperature_ = visual_max_temperature; + } + float get_visual_target_temperature_step() const { return this->visual_target_temperature_step_; } + float get_visual_current_temperature_step() const { return this->visual_current_temperature_step_; } void set_visual_target_temperature_step(float temperature_step) { - visual_target_temperature_step_ = temperature_step; + this->visual_target_temperature_step_ = temperature_step; } void set_visual_current_temperature_step(float temperature_step) { - visual_current_temperature_step_ = temperature_step; + this->visual_current_temperature_step_ = temperature_step; } void set_visual_temperature_step(float temperature_step) { - visual_target_temperature_step_ = temperature_step; - visual_current_temperature_step_ = temperature_step; + this->visual_target_temperature_step_ = temperature_step; + this->visual_current_temperature_step_ = temperature_step; } int8_t get_target_temperature_accuracy_decimals() const; int8_t get_current_temperature_accuracy_decimals() const; - float get_visual_min_humidity() const { return visual_min_humidity_; } - void set_visual_min_humidity(float visual_min_humidity) { visual_min_humidity_ = visual_min_humidity; } - float get_visual_max_humidity() const { return visual_max_humidity_; } - void set_visual_max_humidity(float visual_max_humidity) { visual_max_humidity_ = visual_max_humidity; } + float get_visual_min_humidity() const { return this->visual_min_humidity_; } + void set_visual_min_humidity(float visual_min_humidity) { this->visual_min_humidity_ = visual_min_humidity; } + float get_visual_max_humidity() const { return this->visual_max_humidity_; } + void set_visual_max_humidity(float visual_max_humidity) { this->visual_max_humidity_ = visual_max_humidity; } protected: void set_mode_support_(climate::ClimateMode mode, bool supported) { if (supported) { - supported_modes_.insert(mode); + this->supported_modes_.insert(mode); } else { - supported_modes_.erase(mode); + this->supported_modes_.erase(mode); } } void set_fan_mode_support_(climate::ClimateFanMode mode, bool supported) { if (supported) { - supported_fan_modes_.insert(mode); + this->supported_fan_modes_.insert(mode); } else { - supported_fan_modes_.erase(mode); + this->supported_fan_modes_.erase(mode); } } void set_swing_mode_support_(climate::ClimateSwingMode mode, bool supported) { if (supported) { - supported_swing_modes_.insert(mode); + this->supported_swing_modes_.insert(mode); } else { - supported_swing_modes_.erase(mode); + this->supported_swing_modes_.erase(mode); } } From bc6ee202705aff64f9930d2a4f4413cce3188f4e Mon Sep 17 00:00:00 2001 From: Pat Satyshur Date: Sat, 3 May 2025 22:44:54 -0500 Subject: [PATCH 084/102] Add CONF_CONTINUOUS to const.py (#8682) --- esphome/components/esp32_ble_tracker/__init__.py | 2 +- esphome/components/graph/__init__.py | 3 +-- esphome/const.py | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/esphome/components/esp32_ble_tracker/__init__.py b/esphome/components/esp32_ble_tracker/__init__.py index 68be2cbbe9..2c877367f8 100644 --- a/esphome/components/esp32_ble_tracker/__init__.py +++ b/esphome/components/esp32_ble_tracker/__init__.py @@ -17,6 +17,7 @@ from esphome.components.esp32_ble import ( import esphome.config_validation as cv from esphome.const import ( CONF_ACTIVE, + CONF_CONTINUOUS, CONF_DURATION, CONF_ID, CONF_INTERVAL, @@ -42,7 +43,6 @@ CONF_MAX_CONNECTIONS = "max_connections" CONF_ESP32_BLE_ID = "esp32_ble_id" CONF_SCAN_PARAMETERS = "scan_parameters" CONF_WINDOW = "window" -CONF_CONTINUOUS = "continuous" CONF_ON_SCAN_END = "on_scan_end" DEFAULT_MAX_CONNECTIONS = 3 diff --git a/esphome/components/graph/__init__.py b/esphome/components/graph/__init__.py index 254294619e..6e8ba44bec 100644 --- a/esphome/components/graph/__init__.py +++ b/esphome/components/graph/__init__.py @@ -5,6 +5,7 @@ import esphome.config_validation as cv from esphome.const import ( CONF_BORDER, CONF_COLOR, + CONF_CONTINUOUS, CONF_DIRECTION, CONF_DURATION, CONF_HEIGHT, @@ -61,8 +62,6 @@ VALUE_POSITION_TYPE = { "BELOW": ValuePositionType.VALUE_POSITION_TYPE_BELOW, } -CONF_CONTINUOUS = "continuous" - GRAPH_TRACE_SCHEMA = cv.Schema( { cv.GenerateID(): cv.declare_id(GraphTrace), diff --git a/esphome/const.py b/esphome/const.py index 262f5e0033..3b84055789 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -160,6 +160,7 @@ CONF_CONDITION = "condition" CONF_CONDITION_ID = "condition_id" CONF_CONDUCTIVITY = "conductivity" CONF_CONSTANT_BRIGHTNESS = "constant_brightness" +CONF_CONTINUOUS = "continuous" CONF_CONTRAST = "contrast" CONF_COOL_ACTION = "cool_action" CONF_COOL_DEADBAND = "cool_deadband" From 670ad7192c02869adb8e4385243c851f4398ed78 Mon Sep 17 00:00:00 2001 From: Thomas Rupprecht Date: Sun, 4 May 2025 22:47:57 +0200 Subject: [PATCH 085/102] unify lowercase `x` in hexadecimal values (#8686) --- esphome/components/as7341/as7341.h | 2 +- esphome/components/bl0906/constants.h | 32 +++++++++---------- .../climate_ir_lg/climate_ir_lg.cpp | 2 +- esphome/components/hm3301/hm3301.h | 2 +- esphome/components/ili9xxx/ili9xxx_init.h | 2 +- esphome/components/ld2410/ld2410.h | 2 +- esphome/components/ld2450/ld2450.h | 2 +- esphome/components/mitsubishi/mitsubishi.cpp | 18 +++++------ esphome/components/pn7150/pn7150.h | 4 +-- esphome/components/pn7160/pn7160.h | 4 +-- .../waveshare_epaper/waveshare_epaper.cpp | 12 +++---- .../xpt2046/touchscreen/xpt2046.cpp | 2 +- 12 files changed, 42 insertions(+), 42 deletions(-) diff --git a/esphome/components/as7341/as7341.h b/esphome/components/as7341/as7341.h index e517e1d2bf..aed7996cef 100644 --- a/esphome/components/as7341/as7341.h +++ b/esphome/components/as7341/as7341.h @@ -7,7 +7,7 @@ namespace esphome { namespace as7341 { -static const uint8_t AS7341_CHIP_ID = 0X09; +static const uint8_t AS7341_CHIP_ID = 0x09; static const uint8_t AS7341_CONFIG = 0x70; static const uint8_t AS7341_LED = 0x74; diff --git a/esphome/components/bl0906/constants.h b/esphome/components/bl0906/constants.h index 546916aa3c..a174e54bb2 100644 --- a/esphome/components/bl0906/constants.h +++ b/esphome/components/bl0906/constants.h @@ -45,7 +45,7 @@ static const uint8_t BL0906_WRITE_COMMAND = 0xCA; static const uint8_t BL0906_V_RMS = 0x16; // Total power -static const uint8_t BL0906_WATT_SUM = 0X2C; +static const uint8_t BL0906_WATT_SUM = 0x2C; // Current1~6 static const uint8_t BL0906_I_1_RMS = 0x0D; // current_1 @@ -56,29 +56,29 @@ static const uint8_t BL0906_I_5_RMS = 0x13; static const uint8_t BL0906_I_6_RMS = 0x14; // current_6 // Power1~6 -static const uint8_t BL0906_WATT_1 = 0X23; // power_1 -static const uint8_t BL0906_WATT_2 = 0X24; -static const uint8_t BL0906_WATT_3 = 0X25; -static const uint8_t BL0906_WATT_4 = 0X26; -static const uint8_t BL0906_WATT_5 = 0X29; -static const uint8_t BL0906_WATT_6 = 0X2A; // power_6 +static const uint8_t BL0906_WATT_1 = 0x23; // power_1 +static const uint8_t BL0906_WATT_2 = 0x24; +static const uint8_t BL0906_WATT_3 = 0x25; +static const uint8_t BL0906_WATT_4 = 0x26; +static const uint8_t BL0906_WATT_5 = 0x29; +static const uint8_t BL0906_WATT_6 = 0x2A; // power_6 // Active pulse count, unsigned -static const uint8_t BL0906_CF_1_CNT = 0X30; // Channel_1 -static const uint8_t BL0906_CF_2_CNT = 0X31; -static const uint8_t BL0906_CF_3_CNT = 0X32; -static const uint8_t BL0906_CF_4_CNT = 0X33; -static const uint8_t BL0906_CF_5_CNT = 0X36; -static const uint8_t BL0906_CF_6_CNT = 0X37; // Channel_6 +static const uint8_t BL0906_CF_1_CNT = 0x30; // Channel_1 +static const uint8_t BL0906_CF_2_CNT = 0x31; +static const uint8_t BL0906_CF_3_CNT = 0x32; +static const uint8_t BL0906_CF_4_CNT = 0x33; +static const uint8_t BL0906_CF_5_CNT = 0x36; +static const uint8_t BL0906_CF_6_CNT = 0x37; // Channel_6 // Total active pulse count, unsigned -static const uint8_t BL0906_CF_SUM_CNT = 0X39; +static const uint8_t BL0906_CF_SUM_CNT = 0x39; // Voltage frequency cycle -static const uint8_t BL0906_FREQUENCY = 0X4E; +static const uint8_t BL0906_FREQUENCY = 0x4E; // Internal temperature -static const uint8_t BL0906_TEMPERATURE = 0X5E; +static const uint8_t BL0906_TEMPERATURE = 0x5E; // Calibration register // RMS gain adjustment register diff --git a/esphome/components/climate_ir_lg/climate_ir_lg.cpp b/esphome/components/climate_ir_lg/climate_ir_lg.cpp index c65f24ebc0..7e37639a39 100644 --- a/esphome/components/climate_ir_lg/climate_ir_lg.cpp +++ b/esphome/components/climate_ir_lg/climate_ir_lg.cpp @@ -32,7 +32,7 @@ const uint32_t FAN_MAX = 0x40; // Temperature const uint8_t TEMP_RANGE = TEMP_MAX - TEMP_MIN + 1; -const uint32_t TEMP_MASK = 0XF00; +const uint32_t TEMP_MASK = 0xF00; const uint32_t TEMP_SHIFT = 8; const uint16_t BITS = 28; diff --git a/esphome/components/hm3301/hm3301.h b/esphome/components/hm3301/hm3301.h index bccdd1d35b..6779b4e195 100644 --- a/esphome/components/hm3301/hm3301.h +++ b/esphome/components/hm3301/hm3301.h @@ -8,7 +8,7 @@ namespace esphome { namespace hm3301 { -static const uint8_t SELECT_COMM_CMD = 0X88; +static const uint8_t SELECT_COMM_CMD = 0x88; class HM3301Component : public PollingComponent, public i2c::I2CDevice { public: diff --git a/esphome/components/ili9xxx/ili9xxx_init.h b/esphome/components/ili9xxx/ili9xxx_init.h index f05b884be6..7b176ed57a 100644 --- a/esphome/components/ili9xxx/ili9xxx_init.h +++ b/esphome/components/ili9xxx/ili9xxx_init.h @@ -388,7 +388,7 @@ static const uint8_t PROGMEM INITCMD_GC9D01N[] = { 0x8D, 1, 0xFF, 0x8E, 1, 0xFF, 0x8F, 1, 0xFF, - 0X3A, 1, 0x05, // COLMOD: Pixel Format Set (3Ah) MCU interface, 16 bits / pixel + 0x3A, 1, 0x05, // COLMOD: Pixel Format Set (3Ah) MCU interface, 16 bits / pixel 0xEC, 1, 0x01, // Inversion (ECh) DINV=1+2H1V column for Dual Gate (BFh=0) // According to datasheet Inversion (ECh) value 0x01 isn't valid, but Lilygo uses it everywhere 0x74, 7, 0x02, 0x0E, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/esphome/components/ld2410/ld2410.h b/esphome/components/ld2410/ld2410.h index 8084d4c33e..1bbaa8987a 100644 --- a/esphome/components/ld2410/ld2410.h +++ b/esphome/components/ld2410/ld2410.h @@ -129,7 +129,7 @@ enum PeriodicDataStructure : uint8_t { LIGHT_SENSOR = 37, OUT_PIN_SENSOR = 38, }; -enum PeriodicDataValue : uint8_t { HEAD = 0XAA, END = 0x55, CHECK = 0x00 }; +enum PeriodicDataValue : uint8_t { HEAD = 0xAA, END = 0x55, CHECK = 0x00 }; enum AckDataStructure : uint8_t { COMMAND = 6, COMMAND_STATUS = 7 }; diff --git a/esphome/components/ld2450/ld2450.h b/esphome/components/ld2450/ld2450.h index 32e4bc02e4..e0927e5d7d 100644 --- a/esphome/components/ld2450/ld2450.h +++ b/esphome/components/ld2450/ld2450.h @@ -105,7 +105,7 @@ enum PeriodicDataStructure : uint8_t { TARGET_RESOLUTION = 10, }; -enum PeriodicDataValue : uint8_t { HEAD = 0XAA, END = 0x55, CHECK = 0x00 }; +enum PeriodicDataValue : uint8_t { HEAD = 0xAA, END = 0x55, CHECK = 0x00 }; enum AckDataStructure : uint8_t { COMMAND = 6, COMMAND_STATUS = 7 }; diff --git a/esphome/components/mitsubishi/mitsubishi.cpp b/esphome/components/mitsubishi/mitsubishi.cpp index 449c8fc712..3d9207dd96 100644 --- a/esphome/components/mitsubishi/mitsubishi.cpp +++ b/esphome/components/mitsubishi/mitsubishi.cpp @@ -25,8 +25,8 @@ const uint8_t MITSUBISHI_FAN_AUTO = 0x00; const uint8_t MITSUBISHI_VERTICAL_VANE_SWING = 0x38; -// const uint8_t MITSUBISHI_AUTO = 0X80; -const uint8_t MITSUBISHI_OTHERWISE = 0X40; +// const uint8_t MITSUBISHI_AUTO = 0x80; +const uint8_t MITSUBISHI_OTHERWISE = 0x40; const uint8_t MITSUBISHI_POWERFUL = 0x08; // Optional presets used to enable some model features @@ -42,13 +42,13 @@ const uint16_t MITSUBISHI_HEADER_SPACE = 1700; const uint16_t MITSUBISHI_MIN_GAP = 17500; // Marker bytes -const uint8_t MITSUBISHI_BYTE00 = 0X23; -const uint8_t MITSUBISHI_BYTE01 = 0XCB; -const uint8_t MITSUBISHI_BYTE02 = 0X26; -const uint8_t MITSUBISHI_BYTE03 = 0X01; -const uint8_t MITSUBISHI_BYTE04 = 0X00; -const uint8_t MITSUBISHI_BYTE13 = 0X00; -const uint8_t MITSUBISHI_BYTE16 = 0X00; +const uint8_t MITSUBISHI_BYTE00 = 0x23; +const uint8_t MITSUBISHI_BYTE01 = 0xCB; +const uint8_t MITSUBISHI_BYTE02 = 0x26; +const uint8_t MITSUBISHI_BYTE03 = 0x01; +const uint8_t MITSUBISHI_BYTE04 = 0x00; +const uint8_t MITSUBISHI_BYTE13 = 0x00; +const uint8_t MITSUBISHI_BYTE16 = 0x00; climate::ClimateTraits MitsubishiClimate::traits() { auto traits = climate::ClimateTraits(); diff --git a/esphome/components/pn7150/pn7150.h b/esphome/components/pn7150/pn7150.h index 54038f5085..87af7d629b 100644 --- a/esphome/components/pn7150/pn7150.h +++ b/esphome/components/pn7150/pn7150.h @@ -123,8 +123,8 @@ enum class NCIState : uint8_t { RFST_POLL_ACTIVE, EP_DEACTIVATING, EP_SELECTING, - TEST = 0XFE, - FAILED = 0XFF, + TEST = 0xFE, + FAILED = 0xFF, }; enum class TestMode : uint8_t { diff --git a/esphome/components/pn7160/pn7160.h b/esphome/components/pn7160/pn7160.h index f2e05ea1d0..ff8a492b7b 100644 --- a/esphome/components/pn7160/pn7160.h +++ b/esphome/components/pn7160/pn7160.h @@ -138,8 +138,8 @@ enum class NCIState : uint8_t { RFST_POLL_ACTIVE, EP_DEACTIVATING, EP_SELECTING, - TEST = 0XFE, - FAILED = 0XFF, + TEST = 0xFE, + FAILED = 0xFF, }; enum class TestMode : uint8_t { diff --git a/esphome/components/waveshare_epaper/waveshare_epaper.cpp b/esphome/components/waveshare_epaper/waveshare_epaper.cpp index 5031446c95..79aae70e41 100644 --- a/esphome/components/waveshare_epaper/waveshare_epaper.cpp +++ b/esphome/components/waveshare_epaper/waveshare_epaper.cpp @@ -1004,7 +1004,7 @@ void WaveshareEPaper1P54InBV2::initialize() { this->command(0x4E); // set RAM x address count to 0; this->data(0x00); - this->command(0x4F); // set RAM y address count to 0X199; + this->command(0x4F); // set RAM y address count to 0x199; this->data(0xC7); this->data(0x00); @@ -1878,7 +1878,7 @@ void GDEY029T94::initialize() { this->command(0x4E); // set RAM x address count to 0; this->data(0x00); - this->command(0x4F); // set RAM y address count to 0X199; + this->command(0x4F); // set RAM y address count to 0x199; this->command(0x00); this->command(0x00); this->wait_until_idle_(); @@ -2070,7 +2070,7 @@ void GDEW029T5::init_full_() { this->init_display_(); this->command(0x82); // vcom_DC setting this->data(0x08); - this->command(0X50); // VCOM AND DATA INTERVAL SETTING + this->command(0x50); // VCOM AND DATA INTERVAL SETTING this->data(0x97); // WBmode:VBDF 17|D7 VBDW 97 VBDB 57 WBRmode:VBDF F7 VBDW 77 VBDB 37 VBDR B7 this->command(0x20); this->write_lut_(LUT_20_VCOMDC_29_5, sizeof(LUT_20_VCOMDC_29_5)); @@ -2090,7 +2090,7 @@ void GDEW029T5::init_partial_() { this->init_display_(); this->command(0x82); // vcom_DC setting this->data(0x08); - this->command(0X50); // VCOM AND DATA INTERVAL SETTING + this->command(0x50); // VCOM AND DATA INTERVAL SETTING this->data(0x17); // WBmode:VBDF 17|D7 VBDW 97 VBDB 57 WBRmode:VBDF F7 VBDW 77 VBDB 37 VBDR B7 this->command(0x20); this->write_lut_(LUT_20_VCOMDC_PARTIAL_29_5, sizeof(LUT_20_VCOMDC_PARTIAL_29_5)); @@ -4481,10 +4481,10 @@ void WaveshareEPaper7P5InHDB::initialize() { this->data(0x01); // LUT1, for white this->command(0x18); - this->data(0X80); + this->data(0x80); this->command(0x22); - this->data(0XB1); // Load Temperature and waveform setting. + this->data(0xB1); // Load Temperature and waveform setting. this->command(0x20); diff --git a/esphome/components/xpt2046/touchscreen/xpt2046.cpp b/esphome/components/xpt2046/touchscreen/xpt2046.cpp index a4e2b84656..aa11ed4b77 100644 --- a/esphome/components/xpt2046/touchscreen/xpt2046.cpp +++ b/esphome/components/xpt2046/touchscreen/xpt2046.cpp @@ -32,7 +32,7 @@ void XPT2046Component::update_touches() { int16_t touch_pressure_1 = this->read_adc_(0xB1 /* touch_pressure_1 */); int16_t touch_pressure_2 = this->read_adc_(0xC1 /* touch_pressure_2 */); - z_raw = touch_pressure_1 + 0Xfff - touch_pressure_2; + z_raw = touch_pressure_1 + 0xfff - touch_pressure_2; ESP_LOGVV(TAG, "Touchscreen Update z = %d", z_raw); touch = (z_raw >= this->threshold_); if (touch) { From 84ebbf07629bb8cba808658f32945e5469912267 Mon Sep 17 00:00:00 2001 From: Thomas Rupprecht Date: Sun, 4 May 2025 23:21:57 +0200 Subject: [PATCH 086/102] [climate_ir_lg] use `this->` (#8687) --- esphome/components/climate_ir_lg/climate_ir_lg.cpp | 12 ++++++------ esphome/components/climate_ir_lg/climate_ir_lg.h | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/esphome/components/climate_ir_lg/climate_ir_lg.cpp b/esphome/components/climate_ir_lg/climate_ir_lg.cpp index 7e37639a39..7fe0646230 100644 --- a/esphome/components/climate_ir_lg/climate_ir_lg.cpp +++ b/esphome/components/climate_ir_lg/climate_ir_lg.cpp @@ -43,11 +43,11 @@ void LgIrClimate::transmit_state() { // ESP_LOGD(TAG, "climate_lg_ir mode_before_ code: 0x%02X", modeBefore_); // Set command - if (send_swing_cmd_) { - send_swing_cmd_ = false; + if (this->send_swing_cmd_) { + this->send_swing_cmd_ = false; remote_state |= COMMAND_SWING; } else { - bool climate_is_off = (mode_before_ == climate::CLIMATE_MODE_OFF); + bool climate_is_off = (this->mode_before_ == climate::CLIMATE_MODE_OFF); switch (this->mode) { case climate::CLIMATE_MODE_COOL: remote_state |= climate_is_off ? COMMAND_ON_COOL : COMMAND_COOL; @@ -71,7 +71,7 @@ void LgIrClimate::transmit_state() { } } - mode_before_ = this->mode; + this->mode_before_ = this->mode; ESP_LOGD(TAG, "climate_lg_ir mode code: 0x%02X", this->mode); @@ -102,7 +102,7 @@ void LgIrClimate::transmit_state() { remote_state |= ((temp - 15) << TEMP_SHIFT); } - transmit_(remote_state); + this->transmit_(remote_state); this->publish_state(); } @@ -187,7 +187,7 @@ bool LgIrClimate::on_receive(remote_base::RemoteReceiveData data) { } void LgIrClimate::transmit_(uint32_t value) { - calc_checksum_(value); + this->calc_checksum_(value); ESP_LOGD(TAG, "Sending climate_lg_ir code: 0x%02" PRIX32, value); auto transmit = this->transmitter_->transmit(); diff --git a/esphome/components/climate_ir_lg/climate_ir_lg.h b/esphome/components/climate_ir_lg/climate_ir_lg.h index 7ee041b86f..00fc99ae73 100644 --- a/esphome/components/climate_ir_lg/climate_ir_lg.h +++ b/esphome/components/climate_ir_lg/climate_ir_lg.h @@ -21,7 +21,7 @@ class LgIrClimate : public climate_ir::ClimateIR { /// Override control to change settings of the climate device. void control(const climate::ClimateCall &call) override { - send_swing_cmd_ = call.get_swing_mode().has_value(); + this->send_swing_cmd_ = call.get_swing_mode().has_value(); // swing resets after unit powered off if (call.get_mode().has_value() && *call.get_mode() == climate::CLIMATE_MODE_OFF) this->swing_mode = climate::CLIMATE_SWING_OFF; From 524cd4b4e357477f55a124709f00a47cf3a8bd87 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 07:29:17 +1000 Subject: [PATCH 087/102] [packet_transport] Extract packet encoding functionality (#8187) --- CODEOWNERS | 1 + .../components/packet_transport/__init__.py | 201 +++++++ .../packet_transport/binary_sensor.py | 19 + .../packet_transport/packet_transport.cpp | 534 ++++++++++++++++++ .../packet_transport/packet_transport.h | 155 +++++ esphome/components/packet_transport/sensor.py | 19 + esphome/components/udp/__init__.py | 262 +++++---- esphome/components/udp/automation.h | 38 ++ esphome/components/udp/binary_sensor.py | 28 +- .../udp/packet_transport/__init__.py | 29 + .../udp/packet_transport/udp_transport.cpp | 36 ++ .../udp/packet_transport/udp_transport.h | 26 + esphome/components/udp/sensor.py | 28 +- esphome/components/udp/udp_component.cpp | 496 +--------------- esphome/components/udp/udp_component.h | 149 +---- tests/components/packet_transport/common.yaml | 40 ++ .../packet_transport/test.bk72xx-ard.yaml | 1 + .../packet_transport/test.esp32-ard.yaml | 1 + .../packet_transport/test.esp32-c3-ard.yaml | 1 + .../packet_transport/test.esp32-c3-idf.yaml | 1 + .../packet_transport/test.esp32-idf.yaml | 1 + .../packet_transport/test.esp8266-ard.yaml | 1 + .../packet_transport/test.host.yaml | 4 + .../packet_transport/test.rp2040-ard.yaml | 1 + tests/components/udp/common.yaml | 42 +- 25 files changed, 1305 insertions(+), 809 deletions(-) create mode 100644 esphome/components/packet_transport/__init__.py create mode 100644 esphome/components/packet_transport/binary_sensor.py create mode 100644 esphome/components/packet_transport/packet_transport.cpp create mode 100644 esphome/components/packet_transport/packet_transport.h create mode 100644 esphome/components/packet_transport/sensor.py create mode 100644 esphome/components/udp/automation.h create mode 100644 esphome/components/udp/packet_transport/__init__.py create mode 100644 esphome/components/udp/packet_transport/udp_transport.cpp create mode 100644 esphome/components/udp/packet_transport/udp_transport.h create mode 100644 tests/components/packet_transport/common.yaml create mode 100644 tests/components/packet_transport/test.bk72xx-ard.yaml create mode 100644 tests/components/packet_transport/test.esp32-ard.yaml create mode 100644 tests/components/packet_transport/test.esp32-c3-ard.yaml create mode 100644 tests/components/packet_transport/test.esp32-c3-idf.yaml create mode 100644 tests/components/packet_transport/test.esp32-idf.yaml create mode 100644 tests/components/packet_transport/test.esp8266-ard.yaml create mode 100644 tests/components/packet_transport/test.host.yaml create mode 100644 tests/components/packet_transport/test.rp2040-ard.yaml diff --git a/CODEOWNERS b/CODEOWNERS index 06d3601858..46e0e6c579 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -319,6 +319,7 @@ esphome/components/online_image/* @clydebarrow @guillempages esphome/components/opentherm/* @olegtarasov esphome/components/ota/* @esphome/core esphome/components/output/* @esphome/core +esphome/components/packet_transport/* @clydebarrow esphome/components/pca6416a/* @Mat931 esphome/components/pca9554/* @clydebarrow @hwstar esphome/components/pcf85063/* @brogon diff --git a/esphome/components/packet_transport/__init__.py b/esphome/components/packet_transport/__init__.py new file mode 100644 index 0000000000..99c1d824ca --- /dev/null +++ b/esphome/components/packet_transport/__init__.py @@ -0,0 +1,201 @@ +"""ESPHome packet transport component.""" + +import hashlib +import logging + +import esphome.codegen as cg +from esphome.components.api import CONF_ENCRYPTION +from esphome.components.binary_sensor import BinarySensor +from esphome.components.sensor import Sensor +import esphome.config_validation as cv +from esphome.const import ( + CONF_BINARY_SENSORS, + CONF_ID, + CONF_INTERNAL, + CONF_KEY, + CONF_NAME, + CONF_PLATFORM, + CONF_SENSORS, +) +from esphome.core import CORE +from esphome.cpp_generator import MockObjClass + +CODEOWNERS = ["@clydebarrow"] +AUTO_LOAD = ["xxtea"] + +packet_transport_ns = cg.esphome_ns.namespace("packet_transport") +PacketTransport = packet_transport_ns.class_("PacketTransport", cg.PollingComponent) + +IS_PLATFORM_COMPONENT = True + +DOMAIN = "packet_transport" +CONF_BROADCAST = "broadcast" +CONF_BROADCAST_ID = "broadcast_id" +CONF_PROVIDER = "provider" +CONF_PROVIDERS = "providers" +CONF_REMOTE_ID = "remote_id" +CONF_PING_PONG_ENABLE = "ping_pong_enable" +CONF_PING_PONG_RECYCLE_TIME = "ping_pong_recycle_time" +CONF_ROLLING_CODE_ENABLE = "rolling_code_enable" +CONF_TRANSPORT_ID = "transport_id" + + +_LOGGER = logging.getLogger(__name__) + + +def sensor_validation(cls: MockObjClass): + return cv.maybe_simple_value( + cv.Schema( + { + cv.Required(CONF_ID): cv.use_id(cls), + cv.Optional(CONF_BROADCAST_ID): cv.validate_id_name, + } + ), + key=CONF_ID, + ) + + +def provider_name_validate(value): + value = cv.valid_name(value) + if "_" in value: + _LOGGER.warning( + "Device names typically do not contain underscores - did you mean to use a hyphen in '%s'?", + value, + ) + return value + + +ENCRYPTION_SCHEMA = { + cv.Optional(CONF_ENCRYPTION): cv.maybe_simple_value( + cv.Schema( + { + cv.Required(CONF_KEY): cv.string, + } + ), + key=CONF_KEY, + ) +} + +PROVIDER_SCHEMA = cv.Schema( + { + cv.Required(CONF_NAME): provider_name_validate, + } +).extend(ENCRYPTION_SCHEMA) + + +def validate_(config): + if CONF_ENCRYPTION in config: + if CONF_SENSORS not in config and CONF_BINARY_SENSORS not in config: + raise cv.Invalid("No sensors or binary sensors to encrypt") + elif config[CONF_ROLLING_CODE_ENABLE]: + raise cv.Invalid("Rolling code requires an encryption key") + if config[CONF_PING_PONG_ENABLE]: + if not any(CONF_ENCRYPTION in p for p in config.get(CONF_PROVIDERS) or ()): + raise cv.Invalid("Ping-pong requires at least one encrypted provider") + return config + + +TRANSPORT_SCHEMA = ( + cv.polling_component_schema("15s") + .extend( + { + cv.Optional(CONF_ROLLING_CODE_ENABLE, default=False): cv.boolean, + cv.Optional(CONF_PING_PONG_ENABLE, default=False): cv.boolean, + cv.Optional( + CONF_PING_PONG_RECYCLE_TIME, default="600s" + ): cv.positive_time_period_seconds, + cv.Optional(CONF_SENSORS): cv.ensure_list(sensor_validation(Sensor)), + cv.Optional(CONF_BINARY_SENSORS): cv.ensure_list( + sensor_validation(BinarySensor) + ), + cv.Optional(CONF_PROVIDERS, default=[]): cv.ensure_list(PROVIDER_SCHEMA), + }, + ) + .extend(ENCRYPTION_SCHEMA) + .add_extra(validate_) +) + + +def transport_schema(cls): + return TRANSPORT_SCHEMA.extend({cv.GenerateID(): cv.declare_id(cls)}) + + +# Build a list of sensors for this platform +CORE.data[DOMAIN] = {CONF_SENSORS: []} + + +def get_sensors(transport_id): + """Return the list of sensors for this platform.""" + return ( + sensor + for sensor in CORE.data[DOMAIN][CONF_SENSORS] + if sensor[CONF_TRANSPORT_ID] == transport_id + ) + + +def validate_packet_transport_sensor(config): + if CONF_NAME in config and CONF_INTERNAL not in config: + raise cv.Invalid("Must provide internal: config when using name:") + CORE.data[DOMAIN][CONF_SENSORS].append(config) + return config + + +def packet_transport_sensor_schema(base_schema): + return cv.All( + base_schema.extend( + { + cv.GenerateID(CONF_TRANSPORT_ID): cv.use_id(PacketTransport), + cv.Optional(CONF_REMOTE_ID): cv.string_strict, + cv.Required(CONF_PROVIDER): provider_name_validate, + } + ), + cv.has_at_least_one_key(CONF_ID, CONF_REMOTE_ID), + validate_packet_transport_sensor, + ) + + +def hash_encryption_key(config: dict): + return list(hashlib.sha256(config[CONF_KEY].encode()).digest()) + + +async def register_packet_transport(var, config): + var = await cg.register_component(var, config) + cg.add(var.set_rolling_code_enable(config[CONF_ROLLING_CODE_ENABLE])) + cg.add(var.set_ping_pong_enable(config[CONF_PING_PONG_ENABLE])) + cg.add( + var.set_ping_pong_recycle_time( + config[CONF_PING_PONG_RECYCLE_TIME].total_seconds + ) + ) + # Get directly configured providers, plus those from sensors and binary sensors + providers = { + sensor[CONF_PROVIDER] for sensor in get_sensors(config[CONF_ID]) + }.union(x[CONF_NAME] for x in config[CONF_PROVIDERS]) + for provider in providers: + cg.add(var.add_provider(provider)) + for provider in config[CONF_PROVIDERS]: + name = provider[CONF_NAME] + if encryption := provider.get(CONF_ENCRYPTION): + cg.add(var.set_provider_encryption(name, hash_encryption_key(encryption))) + + for sens_conf in config.get(CONF_SENSORS, ()): + sens_id = sens_conf[CONF_ID] + sensor = await cg.get_variable(sens_id) + bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) + cg.add(var.add_sensor(bcst_id, sensor)) + for sens_conf in config.get(CONF_BINARY_SENSORS, ()): + sens_id = sens_conf[CONF_ID] + sensor = await cg.get_variable(sens_id) + bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) + cg.add(var.add_binary_sensor(bcst_id, sensor)) + + if encryption := config.get(CONF_ENCRYPTION): + cg.add(var.set_encryption_key(hash_encryption_key(encryption))) + return providers + + +async def new_packet_transport(config): + var = cg.new_Pvariable(config[CONF_ID]) + cg.add(var.set_platform_name(config[CONF_PLATFORM])) + providers = await register_packet_transport(var, config) + return var, providers diff --git a/esphome/components/packet_transport/binary_sensor.py b/esphome/components/packet_transport/binary_sensor.py new file mode 100644 index 0000000000..076e37e6bb --- /dev/null +++ b/esphome/components/packet_transport/binary_sensor.py @@ -0,0 +1,19 @@ +import esphome.codegen as cg +from esphome.components import binary_sensor +from esphome.const import CONF_ID + +from . import ( + CONF_PROVIDER, + CONF_REMOTE_ID, + CONF_TRANSPORT_ID, + packet_transport_sensor_schema, +) + +CONFIG_SCHEMA = packet_transport_sensor_schema(binary_sensor.binary_sensor_schema()) + + +async def to_code(config): + var = await binary_sensor.new_binary_sensor(config) + comp = await cg.get_variable(config[CONF_TRANSPORT_ID]) + remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) + cg.add(comp.add_remote_binary_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/packet_transport/packet_transport.cpp b/esphome/components/packet_transport/packet_transport.cpp new file mode 100644 index 0000000000..4514584408 --- /dev/null +++ b/esphome/components/packet_transport/packet_transport.cpp @@ -0,0 +1,534 @@ +#include "esphome/core/log.h" +#include "esphome/core/application.h" +#include "packet_transport.h" + +#include "esphome/components/xxtea/xxtea.h" + +namespace esphome { +namespace packet_transport { +/** + * Structure of a data packet; everything is little-endian + * + * --- In clear text --- + * MAGIC_NUMBER: 16 bits + * host name length: 1 byte + * host name: (length) bytes + * padding: 0 or more null bytes to a 4 byte boundary + * + * --- Encrypted (if key set) ---- + * DATA_KEY: 1 byte: OR ROLLING_CODE_KEY: + * Rolling code (if enabled): 8 bytes + * Ping keys: if any + * repeat: + * PING_KEY: 1 byte + * ping code: 4 bytes + * Sensors: + * repeat: + * SENSOR_KEY: 1 byte + * float value: 4 bytes + * name length: 1 byte + * name + * Binary Sensors: + * repeat: + * BINARY_SENSOR_KEY: 1 byte + * bool value: 1 bytes + * name length: 1 byte + * name + * + * Padded to a 4 byte boundary with nulls + * + * Structure of a ping request packet: + * --- In clear text --- + * MAGIC_PING: 16 bits + * host name length: 1 byte + * host name: (length) bytes + * Ping key (4 bytes) + * + */ +static const char *const TAG = "packet_transport"; + +static size_t round4(size_t value) { return (value + 3) & ~3; } + +union FuData { + uint32_t u32; + float f32; +}; + +static const uint16_t MAGIC_NUMBER = 0x4553; +static const uint16_t MAGIC_PING = 0x5048; +static const uint32_t PREF_HASH = 0x45535043; +enum DataKey { + ZERO_FILL_KEY, + DATA_KEY, + SENSOR_KEY, + BINARY_SENSOR_KEY, + PING_KEY, + ROLLING_CODE_KEY, +}; + +enum DecodeResult { + DECODE_OK, + DECODE_UNMATCHED, + DECODE_ERROR, + DECODE_EMPTY, +}; + +static const size_t MAX_PING_KEYS = 4; + +static inline void add(std::vector &vec, uint32_t data) { + vec.push_back(data & 0xFF); + vec.push_back((data >> 8) & 0xFF); + vec.push_back((data >> 16) & 0xFF); + vec.push_back((data >> 24) & 0xFF); +} + +class PacketDecoder { + public: + PacketDecoder(const uint8_t *buffer, size_t len) : buffer_(buffer), len_(len) {} + + DecodeResult decode_string(char *data, size_t maxlen) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + auto len = this->buffer_[this->position_]; + if (len == 0 || this->position_ + 1 + len > this->len_ || len >= maxlen) + return DECODE_ERROR; + this->position_++; + memcpy(data, this->buffer_ + this->position_, len); + data[len] = 0; + this->position_ += len; + return DECODE_OK; + } + + template DecodeResult get(T &data) { + if (this->position_ + sizeof(T) > this->len_) + return DECODE_ERROR; + T value = 0; + for (size_t i = 0; i != sizeof(T); ++i) { + value += this->buffer_[this->position_++] << (i * 8); + } + data = value; + return DECODE_OK; + } + + template DecodeResult decode(uint8_t key, T &data) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + if (this->buffer_[this->position_] != key) + return DECODE_UNMATCHED; + if (this->position_ + 1 + sizeof(T) > this->len_) + return DECODE_ERROR; + this->position_++; + T value = 0; + for (size_t i = 0; i != sizeof(T); ++i) { + value += this->buffer_[this->position_++] << (i * 8); + } + data = value; + return DECODE_OK; + } + + template DecodeResult decode(uint8_t key, char *buf, size_t buflen, T &data) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + if (this->buffer_[this->position_] != key) + return DECODE_UNMATCHED; + this->position_++; + T value = 0; + for (size_t i = 0; i != sizeof(T); ++i) { + value += this->buffer_[this->position_++] << (i * 8); + } + data = value; + return this->decode_string(buf, buflen); + } + + DecodeResult decode(uint8_t key) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + if (this->buffer_[this->position_] != key) + return DECODE_UNMATCHED; + this->position_++; + return DECODE_OK; + } + + size_t get_remaining_size() const { return this->len_ - this->position_; } + + // align the pointer to the given byte boundary + bool bump_to(size_t boundary) { + auto newpos = this->position_; + auto offset = this->position_ % boundary; + if (offset != 0) { + newpos += boundary - offset; + } + if (newpos >= this->len_) + return false; + this->position_ = newpos; + return true; + } + + bool decrypt(const uint32_t *key) { + if (this->get_remaining_size() % 4 != 0) { + return false; + } + xxtea::decrypt((uint32_t *) (this->buffer_ + this->position_), this->get_remaining_size() / 4, key); + return true; + } + + protected: + const uint8_t *buffer_; + size_t len_; + size_t position_{}; +}; + +static inline void add(std::vector &vec, uint8_t data) { vec.push_back(data); } +static inline void add(std::vector &vec, uint16_t data) { + vec.push_back((uint8_t) data); + vec.push_back((uint8_t) (data >> 8)); +} +static inline void add(std::vector &vec, DataKey data) { vec.push_back(data); } +static void add(std::vector &vec, const char *str) { + auto len = strlen(str); + vec.push_back(len); + for (size_t i = 0; i != len; i++) { + vec.push_back(*str++); + } +} + +void PacketTransport::setup() { + this->name_ = App.get_name().c_str(); + if (strlen(this->name_) > 255) { + this->mark_failed(); + this->status_set_error("Device name exceeds 255 chars"); + return; + } + this->resend_ping_key_ = this->ping_pong_enable_; + this->pref_ = global_preferences->make_preference(PREF_HASH, true); + if (this->rolling_code_enable_) { + // restore the upper 32 bits of the rolling code, increment and save. + this->pref_.load(&this->rolling_code_[1]); + this->rolling_code_[1]++; + this->pref_.save(&this->rolling_code_[1]); + // must make sure it's saved immediately + global_preferences->sync(); + this->ping_key_ = random_uint32(); + ESP_LOGV(TAG, "Rolling code incremented, upper part now %u", (unsigned) this->rolling_code_[1]); + } +#ifdef USE_SENSOR + for (auto &sensor : this->sensors_) { + sensor.sensor->add_on_state_callback([this, &sensor](float x) { + this->updated_ = true; + sensor.updated = true; + }); + } +#endif +#ifdef USE_BINARY_SENSOR + for (auto &sensor : this->binary_sensors_) { + sensor.sensor->add_on_state_callback([this, &sensor](bool value) { + this->updated_ = true; + sensor.updated = true; + }); + } +#endif + // initialise the header. This is invariant. + add(this->header_, MAGIC_NUMBER); + add(this->header_, this->name_); + // pad to a multiple of 4 bytes + while (this->header_.size() & 0x3) + this->header_.push_back(0); +} + +void PacketTransport::init_data_() { + this->data_.clear(); + if (this->rolling_code_enable_) { + add(this->data_, ROLLING_CODE_KEY); + add(this->data_, this->rolling_code_[0]); + add(this->data_, this->rolling_code_[1]); + this->increment_code_(); + } else { + add(this->data_, DATA_KEY); + } + for (auto pkey : this->ping_keys_) { + add(this->data_, PING_KEY); + add(this->data_, pkey.second); + } +} + +void PacketTransport::flush_() { + if (!this->should_send() || this->data_.empty()) + return; + auto header_len = round4(this->header_.size()); + auto len = round4(data_.size()); + auto encode_buffer = std::vector(round4(header_len + len)); + memcpy(encode_buffer.data(), this->header_.data(), this->header_.size()); + memcpy(encode_buffer.data() + header_len, this->data_.data(), this->data_.size()); + if (this->is_encrypted_()) { + xxtea::encrypt((uint32_t *) (encode_buffer.data() + header_len), len / 4, + (uint32_t *) this->encryption_key_.data()); + } + this->send_packet(encode_buffer); +} + +void PacketTransport::add_binary_data_(uint8_t key, const char *id, bool data) { + auto len = 1 + 1 + 1 + strlen(id); + if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { + this->flush_(); + } + add(this->data_, key); + add(this->data_, (uint8_t) data); + add(this->data_, id); +} +void PacketTransport::add_data_(uint8_t key, const char *id, float data) { + FuData udata{.f32 = data}; + this->add_data_(key, id, udata.u32); +} + +void PacketTransport::add_data_(uint8_t key, const char *id, uint32_t data) { + auto len = 4 + 1 + 1 + strlen(id); + if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { + this->flush_(); + } + add(this->data_, key); + add(this->data_, data); + add(this->data_, id); +} +void PacketTransport::send_data_(bool all) { + if (!this->should_send()) + return; + this->init_data_(); +#ifdef USE_SENSOR + for (auto &sensor : this->sensors_) { + if (all || sensor.updated) { + sensor.updated = false; + this->add_data_(SENSOR_KEY, sensor.id, sensor.sensor->get_state()); + } + } +#endif +#ifdef USE_BINARY_SENSOR + for (auto &sensor : this->binary_sensors_) { + if (all || sensor.updated) { + sensor.updated = false; + this->add_binary_data_(BINARY_SENSOR_KEY, sensor.id, sensor.sensor->state); + } + } +#endif + this->flush_(); + this->updated_ = false; +} + +void PacketTransport::update() { + auto now = millis() / 1000; + if (this->last_key_time_ + this->ping_pong_recyle_time_ < now) { + this->resend_ping_key_ = this->ping_pong_enable_; + this->last_key_time_ = now; + } +} + +void PacketTransport::add_key_(const char *name, uint32_t key) { + if (!this->is_encrypted_()) + return; + if (this->ping_keys_.count(name) == 0 && this->ping_keys_.size() == MAX_PING_KEYS) { + ESP_LOGW(TAG, "Ping key from %s discarded", name); + return; + } + this->ping_keys_[name] = key; + this->updated_ = true; + ESP_LOGV(TAG, "Ping key from %s now %X", name, (unsigned) key); +} + +static bool process_rolling_code(Provider &provider, PacketDecoder &decoder) { + uint32_t code0, code1; + if (decoder.get(code0) != DECODE_OK || decoder.get(code1) != DECODE_OK) { + ESP_LOGW(TAG, "Rolling code requires 8 bytes"); + return false; + } + if (code1 < provider.last_code[1] || (code1 == provider.last_code[1] && code0 <= provider.last_code[0])) { + ESP_LOGW(TAG, "Rolling code for %s %08lX:%08lX is old", provider.name, (unsigned long) code1, + (unsigned long) code0); + return false; + } + provider.last_code[0] = code0; + provider.last_code[1] = code1; + ESP_LOGV(TAG, "Saw new rolling code for %s %08lX:%08lX", provider.name, (unsigned long) code1, (unsigned long) code0); + return true; +} + +/** + * Process a received packet + */ +void PacketTransport::process_(std::vector &data) { + auto ping_key_seen = !this->ping_pong_enable_; + PacketDecoder decoder((data.data()), data.size()); + char namebuf[256]{}; + uint8_t byte; + FuData rdata{}; + uint16_t magic; + if (decoder.get(magic) != DECODE_OK) { + ESP_LOGD(TAG, "Short buffer"); + return; + } + if (magic != MAGIC_NUMBER && magic != MAGIC_PING) { + ESP_LOGV(TAG, "Bad magic %X", magic); + return; + } + + if (decoder.decode_string(namebuf, sizeof namebuf) != DECODE_OK) { + ESP_LOGV(TAG, "Bad hostname length"); + return; + } + if (strcmp(this->name_, namebuf) == 0) { + ESP_LOGVV(TAG, "Ignoring our own data"); + return; + } + if (magic == MAGIC_PING) { + uint32_t key; + if (decoder.get(key) != DECODE_OK) { + ESP_LOGW(TAG, "Bad ping request"); + return; + } + this->add_key_(namebuf, key); + ESP_LOGV(TAG, "Updated ping key for %s to %08X", namebuf, (unsigned) key); + return; + } + + if (this->providers_.count(namebuf) == 0) { + ESP_LOGVV(TAG, "Unknown hostname %s", namebuf); + return; + } + ESP_LOGV(TAG, "Found hostname %s", namebuf); + +#ifdef USE_SENSOR + auto &sensors = this->remote_sensors_[namebuf]; +#endif +#ifdef USE_BINARY_SENSOR + auto &binary_sensors = this->remote_binary_sensors_[namebuf]; +#endif + + if (!decoder.bump_to(4)) { + ESP_LOGW(TAG, "Bad packet length %zu", data.size()); + } + auto len = decoder.get_remaining_size(); + if (round4(len) != len) { + ESP_LOGW(TAG, "Bad payload length %zu", len); + return; + } + + auto &provider = this->providers_[namebuf]; + // if encryption not used with this host, ping check is pointless since it would be easily spoofed. + if (provider.encryption_key.empty()) + ping_key_seen = true; + + if (!provider.encryption_key.empty()) { + decoder.decrypt((const uint32_t *) provider.encryption_key.data()); + } + if (decoder.get(byte) != DECODE_OK) { + ESP_LOGV(TAG, "No key byte"); + return; + } + + if (byte == ROLLING_CODE_KEY) { + if (!process_rolling_code(provider, decoder)) + return; + } else if (byte != DATA_KEY) { + ESP_LOGV(TAG, "Expected rolling_key or data_key, got %X", byte); + return; + } + uint32_t key; + while (decoder.get_remaining_size() != 0) { + if (decoder.decode(ZERO_FILL_KEY) == DECODE_OK) + continue; + if (decoder.decode(PING_KEY, key) == DECODE_OK) { + if (key == this->ping_key_) { + ping_key_seen = true; + ESP_LOGV(TAG, "Found good ping key %X", (unsigned) key); + } else { + ESP_LOGV(TAG, "Unknown ping key %X", (unsigned) key); + } + continue; + } + if (!ping_key_seen) { + ESP_LOGW(TAG, "Ping key not seen"); + this->resend_ping_key_ = true; + break; + } + if (decoder.decode(BINARY_SENSOR_KEY, namebuf, sizeof(namebuf), byte) == DECODE_OK) { + ESP_LOGV(TAG, "Got binary sensor %s %d", namebuf, byte); +#ifdef USE_BINARY_SENSOR + if (binary_sensors.count(namebuf) != 0) + binary_sensors[namebuf]->publish_state(byte != 0); +#endif + continue; + } + if (decoder.decode(SENSOR_KEY, namebuf, sizeof(namebuf), rdata.u32) == DECODE_OK) { + ESP_LOGV(TAG, "Got sensor %s %f", namebuf, rdata.f32); +#ifdef USE_SENSOR + if (sensors.count(namebuf) != 0) + sensors[namebuf]->publish_state(rdata.f32); +#endif + continue; + } + if (decoder.get(byte) == DECODE_OK) { + ESP_LOGW(TAG, "Unknown key %X", byte); + ESP_LOGD(TAG, "Buffer pos: %zu contents: %s", data.size() - decoder.get_remaining_size(), + format_hex_pretty(data).c_str()); + } + break; + } +} + +void PacketTransport::dump_config() { + ESP_LOGCONFIG(TAG, "Packet Transport:"); + ESP_LOGCONFIG(TAG, " Platform: %s", this->platform_name_); + ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(this->is_encrypted_())); + ESP_LOGCONFIG(TAG, " Ping-pong: %s", YESNO(this->ping_pong_enable_)); +#ifdef USE_SENSOR + for (auto sensor : this->sensors_) + ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.id); +#endif +#ifdef USE_BINARY_SENSOR + for (auto sensor : this->binary_sensors_) + ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.id); +#endif + for (const auto &host : this->providers_) { + ESP_LOGCONFIG(TAG, " Remote host: %s", host.first.c_str()); + ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(!host.second.encryption_key.empty())); +#ifdef USE_SENSOR + for (const auto &sensor : this->remote_sensors_[host.first.c_str()]) + ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.first.c_str()); +#endif +#ifdef USE_BINARY_SENSOR + for (const auto &sensor : this->remote_binary_sensors_[host.first.c_str()]) + ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.first.c_str()); +#endif + } +} +void PacketTransport::increment_code_() { + if (this->rolling_code_enable_) { + if (++this->rolling_code_[0] == 0) { + this->rolling_code_[1]++; + this->pref_.save(&this->rolling_code_[1]); + // must make sure it's saved immediately + global_preferences->sync(); + } + } +} + +void PacketTransport::loop() { + if (this->resend_ping_key_) + this->send_ping_pong_request_(); + if (this->updated_) { + this->send_data_(this->resend_data_); + } +} + +void PacketTransport::send_ping_pong_request_() { + if (!this->ping_pong_enable_ || !this->should_send()) + return; + this->ping_key_ = random_uint32(); + this->ping_header_.clear(); + add(this->ping_header_, MAGIC_PING); + add(this->ping_header_, this->name_); + add(this->ping_header_, this->ping_key_); + this->send_packet(this->ping_header_); + this->resend_ping_key_ = false; + ESP_LOGV(TAG, "Sent new ping request %08X", (unsigned) this->ping_key_); +} +} // namespace packet_transport +} // namespace esphome diff --git a/esphome/components/packet_transport/packet_transport.h b/esphome/components/packet_transport/packet_transport.h new file mode 100644 index 0000000000..6799cb6ea1 --- /dev/null +++ b/esphome/components/packet_transport/packet_transport.h @@ -0,0 +1,155 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/core/preferences.h" +#ifdef USE_SENSOR +#include "esphome/components/sensor/sensor.h" +#endif +#ifdef USE_BINARY_SENSOR +#include "esphome/components/binary_sensor/binary_sensor.h" +#endif +# +#include +#include + +/** + * Providing packet encoding functions for exchanging data with a remote host. + * + * A transport is required to send the data; this is provided by a child class. + * The child class should implement the virtual functions send_packet_ and get_max_packet_size_. + * On receipt of a data packet, it should call `this->process_()` with the data. + */ + +namespace esphome { +namespace packet_transport { + +struct Provider { + std::vector encryption_key; + const char *name; + uint32_t last_code[2]; +}; + +#ifdef USE_SENSOR +struct Sensor { + sensor::Sensor *sensor; + const char *id; + bool updated; +}; +#endif +#ifdef USE_BINARY_SENSOR +struct BinarySensor { + binary_sensor::BinarySensor *sensor; + const char *id; + bool updated; +}; +#endif + +class PacketTransport : public PollingComponent { + public: + void setup() override; + void loop() override; + void update() override; + void dump_config() override; + +#ifdef USE_SENSOR + void add_sensor(const char *id, sensor::Sensor *sensor) { + Sensor st{sensor, id, true}; + this->sensors_.push_back(st); + } + void add_remote_sensor(const char *hostname, const char *remote_id, sensor::Sensor *sensor) { + this->add_provider(hostname); + this->remote_sensors_[hostname][remote_id] = sensor; + } +#endif +#ifdef USE_BINARY_SENSOR + void add_binary_sensor(const char *id, binary_sensor::BinarySensor *sensor) { + BinarySensor st{sensor, id, true}; + this->binary_sensors_.push_back(st); + } + + void add_remote_binary_sensor(const char *hostname, const char *remote_id, binary_sensor::BinarySensor *sensor) { + this->add_provider(hostname); + this->remote_binary_sensors_[hostname][remote_id] = sensor; + } +#endif + + void add_provider(const char *hostname) { + if (this->providers_.count(hostname) == 0) { + Provider provider; + provider.encryption_key = std::vector{}; + provider.last_code[0] = 0; + provider.last_code[1] = 0; + provider.name = hostname; + this->providers_[hostname] = provider; +#ifdef USE_SENSOR + this->remote_sensors_[hostname] = std::map(); +#endif +#ifdef USE_BINARY_SENSOR + this->remote_binary_sensors_[hostname] = std::map(); +#endif + } + } + + void set_encryption_key(std::vector key) { this->encryption_key_ = std::move(key); } + void set_rolling_code_enable(bool enable) { this->rolling_code_enable_ = enable; } + void set_ping_pong_enable(bool enable) { this->ping_pong_enable_ = enable; } + void set_ping_pong_recycle_time(uint32_t recycle_time) { this->ping_pong_recyle_time_ = recycle_time; } + void set_provider_encryption(const char *name, std::vector key) { + this->providers_[name].encryption_key = std::move(key); + } + void set_platform_name(const char *name) { this->platform_name_ = name; } + + protected: + // child classes must implement this + virtual void send_packet(std::vector &buf) const = 0; + virtual size_t get_max_packet_size() = 0; + virtual bool should_send() { return true; } + + // to be called by child classes when a data packet is received. + void process_(std::vector &data); + void send_data_(bool all); + void flush_(); + void add_data_(uint8_t key, const char *id, float data); + void add_data_(uint8_t key, const char *id, uint32_t data); + void increment_code_(); + void add_binary_data_(uint8_t key, const char *id, bool data); + void init_data_(); + + bool updated_{}; + uint32_t ping_key_{}; + uint32_t rolling_code_[2]{}; + bool rolling_code_enable_{}; + bool ping_pong_enable_{}; + uint32_t ping_pong_recyle_time_{}; + uint32_t last_key_time_{}; + bool resend_ping_key_{}; + bool resend_data_{}; + const char *name_{}; + ESPPreferenceObject pref_{}; + + std::vector encryption_key_{}; + +#ifdef USE_SENSOR + std::vector sensors_{}; + std::map> remote_sensors_{}; +#endif +#ifdef USE_BINARY_SENSOR + std::vector binary_sensors_{}; + std::map> remote_binary_sensors_{}; +#endif + + std::map providers_{}; + std::vector ping_header_{}; + std::vector header_{}; + std::vector data_{}; + std::map ping_keys_{}; + const char *platform_name_{""}; + void add_key_(const char *name, uint32_t key); + void send_ping_pong_request_(); + void process_ping_request_(const char *name, uint8_t *ptr, size_t len); + + inline bool is_encrypted_() { return !this->encryption_key_.empty(); } +}; + +} // namespace packet_transport +} // namespace esphome diff --git a/esphome/components/packet_transport/sensor.py b/esphome/components/packet_transport/sensor.py new file mode 100644 index 0000000000..15c0e33b30 --- /dev/null +++ b/esphome/components/packet_transport/sensor.py @@ -0,0 +1,19 @@ +import esphome.codegen as cg +from esphome.components.sensor import new_sensor, sensor_schema +from esphome.const import CONF_ID + +from . import ( + CONF_PROVIDER, + CONF_REMOTE_ID, + CONF_TRANSPORT_ID, + packet_transport_sensor_schema, +) + +CONFIG_SCHEMA = packet_transport_sensor_schema(sensor_schema()) + + +async def to_code(config): + var = await new_sensor(config) + comp = await cg.get_variable(config[CONF_TRANSPORT_ID]) + remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) + cg.add(comp.add_remote_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/udp/__init__.py b/esphome/components/udp/__init__.py index 140d1e4236..ed405d7c22 100644 --- a/esphome/components/udp/__init__.py +++ b/esphome/components/udp/__init__.py @@ -1,164 +1,162 @@ -import hashlib - +from esphome import automation +from esphome.automation import Trigger import esphome.codegen as cg -from esphome.components.api import CONF_ENCRYPTION -from esphome.components.binary_sensor import BinarySensor -from esphome.components.sensor import Sensor -import esphome.config_validation as cv -from esphome.const import ( +from esphome.components.packet_transport import ( CONF_BINARY_SENSORS, - CONF_ID, - CONF_INTERNAL, - CONF_KEY, - CONF_NAME, - CONF_PORT, + CONF_ENCRYPTION, + CONF_PING_PONG_ENABLE, + CONF_PROVIDERS, + CONF_ROLLING_CODE_ENABLE, CONF_SENSORS, ) -from esphome.cpp_generator import MockObjClass +import esphome.config_validation as cv +from esphome.const import CONF_DATA, CONF_ID, CONF_PORT, CONF_TRIGGER_ID +from esphome.core import Lambda +from esphome.cpp_generator import ExpressionStatement, MockObj CODEOWNERS = ["@clydebarrow"] DEPENDENCIES = ["network"] -AUTO_LOAD = ["socket", "xxtea"] +AUTO_LOAD = ["socket"] + MULTI_CONF = True - udp_ns = cg.esphome_ns.namespace("udp") -UDPComponent = udp_ns.class_("UDPComponent", cg.PollingComponent) +UDPComponent = udp_ns.class_("UDPComponent", cg.Component) +UDPWriteAction = udp_ns.class_("UDPWriteAction", automation.Action) +trigger_args = cg.std_vector.template(cg.uint8) -CONF_BROADCAST = "broadcast" -CONF_BROADCAST_ID = "broadcast_id" CONF_ADDRESSES = "addresses" CONF_LISTEN_ADDRESS = "listen_address" -CONF_PROVIDER = "provider" -CONF_PROVIDERS = "providers" -CONF_REMOTE_ID = "remote_id" CONF_UDP_ID = "udp_id" -CONF_PING_PONG_ENABLE = "ping_pong_enable" -CONF_PING_PONG_RECYCLE_TIME = "ping_pong_recycle_time" -CONF_ROLLING_CODE_ENABLE = "rolling_code_enable" +CONF_ON_RECEIVE = "on_receive" +CONF_LISTEN_PORT = "listen_port" +CONF_BROADCAST_PORT = "broadcast_port" - -def sensor_validation(cls: MockObjClass): - return cv.maybe_simple_value( - cv.Schema( - { - cv.Required(CONF_ID): cv.use_id(cls), - cv.Optional(CONF_BROADCAST_ID): cv.validate_id_name, - } - ), - key=CONF_ID, - ) - - -ENCRYPTION_SCHEMA = { - cv.Optional(CONF_ENCRYPTION): cv.maybe_simple_value( - cv.Schema( - { - cv.Required(CONF_KEY): cv.string, - } - ), - key=CONF_KEY, - ) -} - -PROVIDER_SCHEMA = cv.Schema( +UDP_SCHEMA = cv.Schema( { - cv.Required(CONF_NAME): cv.valid_name, - } -).extend(ENCRYPTION_SCHEMA) - - -def validate_(config): - if CONF_ENCRYPTION in config: - if CONF_SENSORS not in config and CONF_BINARY_SENSORS not in config: - raise cv.Invalid("No sensors or binary sensors to encrypt") - elif config[CONF_ROLLING_CODE_ENABLE]: - raise cv.Invalid("Rolling code requires an encryption key") - if config[CONF_PING_PONG_ENABLE]: - if not any(CONF_ENCRYPTION in p for p in config.get(CONF_PROVIDERS) or ()): - raise cv.Invalid("Ping-pong requires at least one encrypted provider") - return config - - -CONFIG_SCHEMA = cv.All( - cv.polling_component_schema("15s") - .extend( - { - cv.GenerateID(): cv.declare_id(UDPComponent), - cv.Optional(CONF_PORT, default=18511): cv.port, - cv.Optional( - CONF_LISTEN_ADDRESS, default="255.255.255.255" - ): cv.ipv4address_multi_broadcast, - cv.Optional(CONF_ADDRESSES, default=["255.255.255.255"]): cv.ensure_list( - cv.ipv4address, - ), - cv.Optional(CONF_ROLLING_CODE_ENABLE, default=False): cv.boolean, - cv.Optional(CONF_PING_PONG_ENABLE, default=False): cv.boolean, - cv.Optional( - CONF_PING_PONG_RECYCLE_TIME, default="600s" - ): cv.positive_time_period_seconds, - cv.Optional(CONF_SENSORS): cv.ensure_list(sensor_validation(Sensor)), - cv.Optional(CONF_BINARY_SENSORS): cv.ensure_list( - sensor_validation(BinarySensor) - ), - cv.Optional(CONF_PROVIDERS): cv.ensure_list(PROVIDER_SCHEMA), - }, - ) - .extend(ENCRYPTION_SCHEMA), - validate_, -) - -SENSOR_SCHEMA = cv.Schema( - { - cv.Optional(CONF_REMOTE_ID): cv.string_strict, - cv.Required(CONF_PROVIDER): cv.valid_name, cv.GenerateID(CONF_UDP_ID): cv.use_id(UDPComponent), } ) -def require_internal_with_name(config): - if CONF_NAME in config and CONF_INTERNAL not in config: - raise cv.Invalid("Must provide internal: config when using name:") - return config +def is_relocated(option): + def validator(value): + raise cv.Invalid( + f"The '{option}' option should now be configured in the 'packet_transport' component" + ) + + return validator -def hash_encryption_key(config: dict): - return list(hashlib.sha256(config[CONF_KEY].encode()).digest()) +RELOCATED = { + cv.Optional(x): is_relocated(x) + for x in ( + CONF_PROVIDERS, + CONF_ENCRYPTION, + CONF_PING_PONG_ENABLE, + CONF_ROLLING_CODE_ENABLE, + CONF_SENSORS, + CONF_BINARY_SENSORS, + ) +} + +CONFIG_SCHEMA = cv.COMPONENT_SCHEMA.extend( + { + cv.GenerateID(): cv.declare_id(UDPComponent), + cv.Optional(CONF_PORT, default=18511): cv.Any( + cv.port, + cv.Schema( + { + cv.Required(CONF_LISTEN_PORT): cv.port, + cv.Required(CONF_BROADCAST_PORT): cv.port, + } + ), + ), + cv.Optional( + CONF_LISTEN_ADDRESS, default="255.255.255.255" + ): cv.ipv4address_multi_broadcast, + cv.Optional(CONF_ADDRESSES, default=["255.255.255.255"]): cv.ensure_list( + cv.ipv4address, + ), + cv.Optional(CONF_ON_RECEIVE): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id( + Trigger.template(trigger_args) + ), + } + ), + } +).extend(RELOCATED) + + +async def register_udp_client(var, config): + udp_var = await cg.get_variable(config[CONF_UDP_ID]) + cg.add(var.set_parent(udp_var)) + return udp_var async def to_code(config): cg.add_define("USE_UDP") cg.add_global(udp_ns.using) var = cg.new_Pvariable(config[CONF_ID]) - await cg.register_component(var, config) - cg.add(var.set_port(config[CONF_PORT])) - cg.add(var.set_rolling_code_enable(config[CONF_ROLLING_CODE_ENABLE])) - cg.add(var.set_ping_pong_enable(config[CONF_PING_PONG_ENABLE])) - cg.add( - var.set_ping_pong_recycle_time( - config[CONF_PING_PONG_RECYCLE_TIME].total_seconds - ) - ) - for sens_conf in config.get(CONF_SENSORS, ()): - sens_id = sens_conf[CONF_ID] - sensor = await cg.get_variable(sens_id) - bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) - cg.add(var.add_sensor(bcst_id, sensor)) - for sens_conf in config.get(CONF_BINARY_SENSORS, ()): - sens_id = sens_conf[CONF_ID] - sensor = await cg.get_variable(sens_id) - bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) - cg.add(var.add_binary_sensor(bcst_id, sensor)) + var = await cg.register_component(var, config) + conf_port = config[CONF_PORT] + if isinstance(conf_port, int): + cg.add(var.set_listen_port(conf_port)) + cg.add(var.set_broadcast_port(conf_port)) + else: + cg.add(var.set_listen_port(conf_port[CONF_LISTEN_PORT])) + cg.add(var.set_broadcast_port(conf_port[CONF_BROADCAST_PORT])) + if (listen_address := str(config[CONF_LISTEN_ADDRESS])) != "255.255.255.255": + cg.add(var.set_listen_address(listen_address)) for address in config[CONF_ADDRESSES]: cg.add(var.add_address(str(address))) + if on_receive := config.get(CONF_ON_RECEIVE): + on_receive = on_receive[0] + trigger = cg.new_Pvariable(on_receive[CONF_TRIGGER_ID]) + trigger = await automation.build_automation( + trigger, [(trigger_args, "data")], on_receive + ) + trigger = Lambda(str(ExpressionStatement(trigger.trigger(MockObj("data"))))) + trigger = await cg.process_lambda(trigger, [(trigger_args, "data")]) + cg.add(var.add_listener(trigger)) + cg.add(var.set_should_listen()) - if encryption := config.get(CONF_ENCRYPTION): - cg.add(var.set_encryption_key(hash_encryption_key(encryption))) - for provider in config.get(CONF_PROVIDERS, ()): - name = provider[CONF_NAME] - cg.add(var.add_provider(name)) - if (listen_address := str(config[CONF_LISTEN_ADDRESS])) != "255.255.255.255": - cg.add(var.set_listen_address(listen_address)) - if encryption := provider.get(CONF_ENCRYPTION): - cg.add(var.set_provider_encryption(name, hash_encryption_key(encryption))) +def validate_raw_data(value): + if isinstance(value, str): + return value.encode("utf-8") + if isinstance(value, str): + return value + if isinstance(value, list): + return cv.Schema([cv.hex_uint8_t])(value) + raise cv.Invalid( + "data must either be a string wrapped in quotes or a list of bytes" + ) + + +@automation.register_action( + "udp.write", + UDPWriteAction, + cv.maybe_simple_value( + { + cv.GenerateID(): cv.use_id(UDPComponent), + cv.Required(CONF_DATA): cv.templatable(validate_raw_data), + }, + key=CONF_DATA, + ), +) +async def udp_write_to_code(config, action_id, template_arg, args): + var = cg.new_Pvariable(action_id, template_arg) + udp_var = await cg.get_variable(config[CONF_ID]) + await cg.register_parented(var, udp_var) + cg.add(udp_var.set_should_broadcast()) + data = config[CONF_DATA] + if isinstance(data, bytes): + data = list(data) + + if cg.is_template(data): + templ = await cg.templatable(data, args, cg.std_vector.template(cg.uint8)) + cg.add(var.set_data_template(templ)) + else: + cg.add(var.set_data_static(data)) + return var diff --git a/esphome/components/udp/automation.h b/esphome/components/udp/automation.h new file mode 100644 index 0000000000..663daa1c15 --- /dev/null +++ b/esphome/components/udp/automation.h @@ -0,0 +1,38 @@ +#pragma once + +#include "udp_component.h" +#include "esphome/core/automation.h" + +#include + +namespace esphome { +namespace udp { + +template class UDPWriteAction : public Action, public Parented { + public: + void set_data_template(std::function(Ts...)> func) { + this->data_func_ = func; + this->static_ = false; + } + void set_data_static(const std::vector &data) { + this->data_static_ = data; + this->static_ = true; + } + + void play(Ts... x) override { + if (this->static_) { + this->parent_->send_packet(this->data_static_); + } else { + auto val = this->data_func_(x...); + this->parent_->send_packet(val); + } + } + + protected: + bool static_{false}; + std::function(Ts...)> data_func_{}; + std::vector data_static_{}; +}; + +} // namespace udp +} // namespace esphome diff --git a/esphome/components/udp/binary_sensor.py b/esphome/components/udp/binary_sensor.py index d90e495527..7d449efbfd 100644 --- a/esphome/components/udp/binary_sensor.py +++ b/esphome/components/udp/binary_sensor.py @@ -1,27 +1,5 @@ -import esphome.codegen as cg -from esphome.components import binary_sensor -from esphome.config_validation import All, has_at_least_one_key -from esphome.const import CONF_ID +import esphome.config_validation as cv -from . import ( - CONF_PROVIDER, - CONF_REMOTE_ID, - CONF_UDP_ID, - SENSOR_SCHEMA, - require_internal_with_name, +CONFIG_SCHEMA = cv.invalid( + "The 'udp.binary_sensor' component has been migrated to the 'packet_transport.binary_sensor' component." ) - -DEPENDENCIES = ["udp"] - -CONFIG_SCHEMA = All( - binary_sensor.binary_sensor_schema().extend(SENSOR_SCHEMA), - has_at_least_one_key(CONF_ID, CONF_REMOTE_ID), - require_internal_with_name, -) - - -async def to_code(config): - var = await binary_sensor.new_binary_sensor(config) - comp = await cg.get_variable(config[CONF_UDP_ID]) - remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) - cg.add(comp.add_remote_binary_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/udp/packet_transport/__init__.py b/esphome/components/udp/packet_transport/__init__.py new file mode 100644 index 0000000000..b6957a372b --- /dev/null +++ b/esphome/components/udp/packet_transport/__init__.py @@ -0,0 +1,29 @@ +import esphome.codegen as cg +from esphome.components.api import CONF_ENCRYPTION +from esphome.components.packet_transport import ( + CONF_PING_PONG_ENABLE, + PacketTransport, + new_packet_transport, + transport_schema, +) +from esphome.const import CONF_BINARY_SENSORS, CONF_SENSORS +from esphome.cpp_types import PollingComponent + +from .. import UDP_SCHEMA, register_udp_client, udp_ns + +UDPTransport = udp_ns.class_("UDPTransport", PacketTransport, PollingComponent) + +CONFIG_SCHEMA = transport_schema(UDPTransport).extend(UDP_SCHEMA) + + +async def to_code(config): + var, providers = await new_packet_transport(config) + udp_var = await register_udp_client(var, config) + if CONF_ENCRYPTION in config or providers: + cg.add(udp_var.set_should_listen()) + if ( + config[CONF_PING_PONG_ENABLE] + or config.get(CONF_SENSORS, ()) + or config.get(CONF_BINARY_SENSORS, ()) + ): + cg.add(udp_var.set_should_broadcast()) diff --git a/esphome/components/udp/packet_transport/udp_transport.cpp b/esphome/components/udp/packet_transport/udp_transport.cpp new file mode 100644 index 0000000000..3918760627 --- /dev/null +++ b/esphome/components/udp/packet_transport/udp_transport.cpp @@ -0,0 +1,36 @@ +#include "esphome/core/log.h" +#include "esphome/core/application.h" +#include "esphome/components/network/util.h" +#include "udp_transport.h" + +namespace esphome { +namespace udp { + +static const char *const TAG = "udp_transport"; + +bool UDPTransport::should_send() { return this->should_broadcast_ && network::is_connected(); } +void UDPTransport::setup() { + PacketTransport::setup(); + this->should_broadcast_ = this->ping_pong_enable_; +#ifdef USE_SENSOR + this->should_broadcast_ |= !this->sensors_.empty(); +#endif +#ifdef USE_BINARY_SENSOR + this->should_broadcast_ |= !this->binary_sensors_.empty(); +#endif + if (this->should_broadcast_) + this->parent_->set_should_broadcast(); + if (!this->providers_.empty() || this->is_encrypted_()) { + this->parent_->add_listener([this](std::vector &buf) { this->process_(buf); }); + } +} + +void UDPTransport::update() { + PacketTransport::update(); + this->updated_ = true; + this->resend_data_ = this->should_broadcast_; +} + +void UDPTransport::send_packet(std::vector &buf) const { this->parent_->send_packet(buf); } +} // namespace udp +} // namespace esphome diff --git a/esphome/components/udp/packet_transport/udp_transport.h b/esphome/components/udp/packet_transport/udp_transport.h new file mode 100644 index 0000000000..5a27bc32c7 --- /dev/null +++ b/esphome/components/udp/packet_transport/udp_transport.h @@ -0,0 +1,26 @@ +#pragma once + +#include "../udp_component.h" +#include "esphome/core/component.h" +#include "esphome/components/packet_transport/packet_transport.h" +#include + +namespace esphome { +namespace udp { + +class UDPTransport : public packet_transport::PacketTransport, public Parented { + public: + void setup() override; + void update() override; + + float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } + + protected: + void send_packet(std::vector &buf) const override; + bool should_send() override; + bool should_broadcast_{false}; + size_t get_max_packet_size() override { return MAX_PACKET_SIZE; } +}; + +} // namespace udp +} // namespace esphome diff --git a/esphome/components/udp/sensor.py b/esphome/components/udp/sensor.py index 860c277c44..9ce05e7ffb 100644 --- a/esphome/components/udp/sensor.py +++ b/esphome/components/udp/sensor.py @@ -1,27 +1,5 @@ -import esphome.codegen as cg -from esphome.components.sensor import new_sensor, sensor_schema -from esphome.config_validation import All, has_at_least_one_key -from esphome.const import CONF_ID +import esphome.config_validation as cv -from . import ( - CONF_PROVIDER, - CONF_REMOTE_ID, - CONF_UDP_ID, - SENSOR_SCHEMA, - require_internal_with_name, +CONFIG_SCHEMA = cv.invalid( + "The 'udp.sensor' component has been migrated to the 'packet_transport.sensor' component." ) - -DEPENDENCIES = ["udp"] - -CONFIG_SCHEMA = All( - sensor_schema().extend(SENSOR_SCHEMA), - has_at_least_one_key(CONF_ID, CONF_REMOTE_ID), - require_internal_with_name, -) - - -async def to_code(config): - var = await new_sensor(config) - comp = await cg.get_variable(config[CONF_UDP_ID]) - remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) - cg.add(comp.add_remote_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/udp/udp_component.cpp b/esphome/components/udp/udp_component.cpp index 59cba8c7fe..222c73f82e 100644 --- a/esphome/components/udp/udp_component.cpp +++ b/esphome/components/udp/udp_component.cpp @@ -1,164 +1,24 @@ +#include "esphome/core/defines.h" +#ifdef USE_NETWORK #include "esphome/core/log.h" #include "esphome/core/application.h" #include "esphome/components/network/util.h" #include "udp_component.h" -#include "esphome/components/xxtea/xxtea.h" - namespace esphome { namespace udp { -/** - * Structure of a data packet; everything is little-endian - * - * --- In clear text --- - * MAGIC_NUMBER: 16 bits - * host name length: 1 byte - * host name: (length) bytes - * padding: 0 or more null bytes to a 4 byte boundary - * - * --- Encrypted (if key set) ---- - * DATA_KEY: 1 byte: OR ROLLING_CODE_KEY: - * Rolling code (if enabled): 8 bytes - * Ping keys: if any - * repeat: - * PING_KEY: 1 byte - * ping code: 4 bytes - * Sensors: - * repeat: - * SENSOR_KEY: 1 byte - * float value: 4 bytes - * name length: 1 byte - * name - * Binary Sensors: - * repeat: - * BINARY_SENSOR_KEY: 1 byte - * bool value: 1 bytes - * name length: 1 byte - * name - * - * Padded to a 4 byte boundary with nulls - * - * Structure of a ping request packet: - * --- In clear text --- - * MAGIC_PING: 16 bits - * host name length: 1 byte - * host name: (length) bytes - * Ping key (4 bytes) - * - */ static const char *const TAG = "udp"; -static size_t round4(size_t value) { return (value + 3) & ~3; } - -union FuData { - uint32_t u32; - float f32; -}; - -static const size_t MAX_PACKET_SIZE = 508; -static const uint16_t MAGIC_NUMBER = 0x4553; -static const uint16_t MAGIC_PING = 0x5048; -static const uint32_t PREF_HASH = 0x45535043; -enum DataKey { - ZERO_FILL_KEY, - DATA_KEY, - SENSOR_KEY, - BINARY_SENSOR_KEY, - PING_KEY, - ROLLING_CODE_KEY, -}; - -static const size_t MAX_PING_KEYS = 4; - -static inline void add(std::vector &vec, uint32_t data) { - vec.push_back(data & 0xFF); - vec.push_back((data >> 8) & 0xFF); - vec.push_back((data >> 16) & 0xFF); - vec.push_back((data >> 24) & 0xFF); -} - -static inline uint32_t get_uint32(uint8_t *&buf) { - uint32_t data = *buf++; - data += *buf++ << 8; - data += *buf++ << 16; - data += *buf++ << 24; - return data; -} - -static inline uint16_t get_uint16(uint8_t *&buf) { - uint16_t data = *buf++; - data += *buf++ << 8; - return data; -} - -static inline void add(std::vector &vec, uint8_t data) { vec.push_back(data); } -static inline void add(std::vector &vec, uint16_t data) { - vec.push_back((uint8_t) data); - vec.push_back((uint8_t) (data >> 8)); -} -static inline void add(std::vector &vec, DataKey data) { vec.push_back(data); } -static void add(std::vector &vec, const char *str) { - auto len = strlen(str); - vec.push_back(len); - for (size_t i = 0; i != len; i++) { - vec.push_back(*str++); - } -} - void UDPComponent::setup() { - this->name_ = App.get_name().c_str(); - if (strlen(this->name_) > 255) { - this->mark_failed(); - this->status_set_error("Device name exceeds 255 chars"); - return; - } - this->resend_ping_key_ = this->ping_pong_enable_; - // restore the upper 32 bits of the rolling code, increment and save. - this->pref_ = global_preferences->make_preference(PREF_HASH, true); - this->pref_.load(&this->rolling_code_[1]); - this->rolling_code_[1]++; - this->pref_.save(&this->rolling_code_[1]); - this->ping_key_ = random_uint32(); - ESP_LOGV(TAG, "Rolling code incremented, upper part now %u", (unsigned) this->rolling_code_[1]); -#ifdef USE_SENSOR - for (auto &sensor : this->sensors_) { - sensor.sensor->add_on_state_callback([this, &sensor](float x) { - this->updated_ = true; - sensor.updated = true; - }); - } -#endif -#ifdef USE_BINARY_SENSOR - for (auto &sensor : this->binary_sensors_) { - sensor.sensor->add_on_state_callback([this, &sensor](bool value) { - this->updated_ = true; - sensor.updated = true; - }); - } -#endif - this->should_send_ = this->ping_pong_enable_; -#ifdef USE_SENSOR - this->should_send_ |= !this->sensors_.empty(); -#endif -#ifdef USE_BINARY_SENSOR - this->should_send_ |= !this->binary_sensors_.empty(); -#endif - this->should_listen_ = !this->providers_.empty() || this->is_encrypted_(); - // initialise the header. This is invariant. - add(this->header_, MAGIC_NUMBER); - add(this->header_, this->name_); - // pad to a multiple of 4 bytes - while (this->header_.size() & 0x3) - this->header_.push_back(0); #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) for (const auto &address : this->addresses_) { struct sockaddr saddr {}; - socket::set_sockaddr(&saddr, sizeof(saddr), address, this->port_); + socket::set_sockaddr(&saddr, sizeof(saddr), address, this->broadcast_port_); this->sockaddrs_.push_back(saddr); } // set up broadcast socket - if (this->should_send_) { + if (this->should_broadcast_) { this->broadcast_socket_ = socket::socket(AF_INET, SOCK_DGRAM, IPPROTO_IP); if (this->broadcast_socket_ == nullptr) { this->mark_failed(); @@ -202,14 +62,14 @@ void UDPComponent::setup() { server.sin_family = AF_INET; server.sin_addr.s_addr = ESPHOME_INADDR_ANY; - server.sin_port = htons(this->port_); + server.sin_port = htons(this->listen_port_); if (this->listen_address_.has_value()) { struct ip_mreq imreq = {}; imreq.imr_interface.s_addr = ESPHOME_INADDR_ANY; inet_aton(this->listen_address_.value().str().c_str(), &imreq.imr_multiaddr); server.sin_addr.s_addr = imreq.imr_multiaddr.s_addr; - ESP_LOGV(TAG, "Join multicast %s", this->listen_address_.value().str().c_str()); + ESP_LOGD(TAG, "Join multicast %s", this->listen_address_.value().str().c_str()); err = this->listen_socket_->setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, &imreq, sizeof(imreq)); if (err < 0) { ESP_LOGE(TAG, "Failed to set IP_ADD_MEMBERSHIP. Error %d", errno); @@ -236,341 +96,48 @@ void UDPComponent::setup() { this->ipaddrs_.push_back(ipaddr); } if (this->should_listen_) - this->udp_client_.begin(this->port_); + this->udp_client_.begin(this->listen_port_); #endif } -void UDPComponent::init_data_() { - this->data_.clear(); - if (this->rolling_code_enable_) { - add(this->data_, ROLLING_CODE_KEY); - add(this->data_, this->rolling_code_[0]); - add(this->data_, this->rolling_code_[1]); - this->increment_code_(); - } else { - add(this->data_, DATA_KEY); - } - for (auto pkey : this->ping_keys_) { - add(this->data_, PING_KEY); - add(this->data_, pkey.second); - } -} - -void UDPComponent::flush_() { - if (!network::is_connected() || this->data_.empty()) - return; - uint32_t buffer[MAX_PACKET_SIZE / 4]; - memset(buffer, 0, sizeof buffer); - // len must be a multiple of 4 - auto header_len = round4(this->header_.size()) / 4; - auto len = round4(data_.size()) / 4; - memcpy(buffer, this->header_.data(), this->header_.size()); - memcpy(buffer + header_len, this->data_.data(), this->data_.size()); - if (this->is_encrypted_()) { - xxtea::encrypt(buffer + header_len, len, (uint32_t *) this->encryption_key_.data()); - } - auto total_len = (header_len + len) * 4; - this->send_packet_(buffer, total_len); -} - -void UDPComponent::add_binary_data_(uint8_t key, const char *id, bool data) { - auto len = 1 + 1 + 1 + strlen(id); - if (len + this->header_.size() + this->data_.size() > MAX_PACKET_SIZE) { - this->flush_(); - } - add(this->data_, key); - add(this->data_, (uint8_t) data); - add(this->data_, id); -} -void UDPComponent::add_data_(uint8_t key, const char *id, float data) { - FuData udata{.f32 = data}; - this->add_data_(key, id, udata.u32); -} - -void UDPComponent::add_data_(uint8_t key, const char *id, uint32_t data) { - auto len = 4 + 1 + 1 + strlen(id); - if (len + this->header_.size() + this->data_.size() > MAX_PACKET_SIZE) { - this->flush_(); - } - add(this->data_, key); - add(this->data_, data); - add(this->data_, id); -} -void UDPComponent::send_data_(bool all) { - if (!this->should_send_ || !network::is_connected()) - return; - this->init_data_(); -#ifdef USE_SENSOR - for (auto &sensor : this->sensors_) { - if (all || sensor.updated) { - sensor.updated = false; - this->add_data_(SENSOR_KEY, sensor.id, sensor.sensor->get_state()); - } - } -#endif -#ifdef USE_BINARY_SENSOR - for (auto &sensor : this->binary_sensors_) { - if (all || sensor.updated) { - sensor.updated = false; - this->add_binary_data_(BINARY_SENSOR_KEY, sensor.id, sensor.sensor->state); - } - } -#endif - this->flush_(); - this->updated_ = false; - this->resend_data_ = false; -} - -void UDPComponent::update() { - this->updated_ = true; - this->resend_data_ = this->should_send_; - auto now = millis() / 1000; - if (this->last_key_time_ + this->ping_pong_recyle_time_ < now) { - this->resend_ping_key_ = this->ping_pong_enable_; - this->last_key_time_ = now; - } -} - void UDPComponent::loop() { - uint8_t buf[MAX_PACKET_SIZE]; + auto buf = std::vector(MAX_PACKET_SIZE); if (this->should_listen_) { for (;;) { #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) - auto len = this->listen_socket_->read(buf, sizeof(buf)); + auto len = this->listen_socket_->read(buf.data(), buf.size()); #endif #ifdef USE_SOCKET_IMPL_LWIP_TCP auto len = this->udp_client_.parsePacket(); if (len > 0) - len = this->udp_client_.read(buf, sizeof(buf)); + len = this->udp_client_.read(buf.data(), buf.size()); #endif - if (len > 0) { - this->process_(buf, len); - continue; - } - break; + if (len <= 0) + break; + buf.resize(len); + ESP_LOGV(TAG, "Received packet of length %zu", len); + this->packet_listeners_.call(buf); } } - if (this->resend_ping_key_) - this->send_ping_pong_request_(); - if (this->updated_) { - this->send_data_(this->resend_data_); - } -} - -void UDPComponent::add_key_(const char *name, uint32_t key) { - if (!this->is_encrypted_()) - return; - if (this->ping_keys_.count(name) == 0 && this->ping_keys_.size() == MAX_PING_KEYS) { - ESP_LOGW(TAG, "Ping key from %s discarded", name); - return; - } - this->ping_keys_[name] = key; - this->resend_data_ = true; - ESP_LOGV(TAG, "Ping key from %s now %X", name, (unsigned) key); -} - -void UDPComponent::process_ping_request_(const char *name, uint8_t *ptr, size_t len) { - if (len != 4) { - ESP_LOGW(TAG, "Bad ping request"); - return; - } - auto key = get_uint32(ptr); - this->add_key_(name, key); - ESP_LOGV(TAG, "Updated ping key for %s to %08X", name, (unsigned) key); -} - -static bool process_rolling_code(Provider &provider, uint8_t *&buf, const uint8_t *end) { - if (end - buf < 8) - return false; - auto code0 = get_uint32(buf); - auto code1 = get_uint32(buf); - if (code1 < provider.last_code[1] || (code1 == provider.last_code[1] && code0 <= provider.last_code[0])) { - ESP_LOGW(TAG, "Rolling code for %s %08lX:%08lX is old", provider.name, (unsigned long) code1, - (unsigned long) code0); - return false; - } - provider.last_code[0] = code0; - provider.last_code[1] = code1; - return true; -} - -/** - * Process a received packet - */ -void UDPComponent::process_(uint8_t *buf, const size_t len) { - auto ping_key_seen = !this->ping_pong_enable_; - if (len < 8) { - ESP_LOGV(TAG, "Bad length %zu", len); - return; - } - char namebuf[256]{}; - uint8_t byte; - uint8_t *start_ptr = buf; - const uint8_t *end = buf + len; - FuData rdata{}; - auto magic = get_uint16(buf); - if (magic != MAGIC_NUMBER && magic != MAGIC_PING) { - ESP_LOGV(TAG, "Bad magic %X", magic); - return; - } - - auto hlen = *buf++; - if (hlen > len - 3) { - ESP_LOGV(TAG, "Bad hostname length %u > %zu", hlen, len - 3); - return; - } - memcpy(namebuf, buf, hlen); - if (strcmp(this->name_, namebuf) == 0) { - ESP_LOGV(TAG, "Ignoring our own data"); - return; - } - buf += hlen; - if (magic == MAGIC_PING) { - this->process_ping_request_(namebuf, buf, end - buf); - return; - } - if (round4(len) != len) { - ESP_LOGW(TAG, "Bad length %zu", len); - return; - } - hlen = round4(hlen + 3); - buf = start_ptr + hlen; - if (buf == end) { - ESP_LOGV(TAG, "No data after header"); - return; - } - - if (this->providers_.count(namebuf) == 0) { - ESP_LOGVV(TAG, "Unknown hostname %s", namebuf); - return; - } - auto &provider = this->providers_[namebuf]; - // if encryption not used with this host, ping check is pointless since it would be easily spoofed. - if (provider.encryption_key.empty()) - ping_key_seen = true; - - ESP_LOGV(TAG, "Found hostname %s", namebuf); -#ifdef USE_SENSOR - auto &sensors = this->remote_sensors_[namebuf]; -#endif -#ifdef USE_BINARY_SENSOR - auto &binary_sensors = this->remote_binary_sensors_[namebuf]; -#endif - - if (!provider.encryption_key.empty()) { - xxtea::decrypt((uint32_t *) buf, (end - buf) / 4, (uint32_t *) provider.encryption_key.data()); - } - byte = *buf++; - if (byte == ROLLING_CODE_KEY) { - if (!process_rolling_code(provider, buf, end)) - return; - } else if (byte != DATA_KEY) { - ESP_LOGV(TAG, "Expected rolling_key or data_key, got %X", byte); - return; - } - while (buf < end) { - byte = *buf++; - if (byte == ZERO_FILL_KEY) - continue; - if (byte == PING_KEY) { - if (end - buf < 4) { - ESP_LOGV(TAG, "PING_KEY requires 4 more bytes"); - return; - } - auto key = get_uint32(buf); - if (key == this->ping_key_) { - ping_key_seen = true; - ESP_LOGV(TAG, "Found good ping key %X", (unsigned) key); - } else { - ESP_LOGV(TAG, "Unknown ping key %X", (unsigned) key); - } - continue; - } - if (!ping_key_seen) { - ESP_LOGW(TAG, "Ping key not seen"); - this->resend_ping_key_ = true; - break; - } - if (byte == BINARY_SENSOR_KEY) { - if (end - buf < 3) { - ESP_LOGV(TAG, "Binary sensor key requires at least 3 more bytes"); - return; - } - rdata.u32 = *buf++; - } else if (byte == SENSOR_KEY) { - if (end - buf < 6) { - ESP_LOGV(TAG, "Sensor key requires at least 6 more bytes"); - return; - } - rdata.u32 = get_uint32(buf); - } else { - ESP_LOGW(TAG, "Unknown key byte %X", byte); - return; - } - - hlen = *buf++; - if (end - buf < hlen) { - ESP_LOGV(TAG, "Name length of %u not available", hlen); - return; - } - memset(namebuf, 0, sizeof namebuf); - memcpy(namebuf, buf, hlen); - ESP_LOGV(TAG, "Found sensor key %d, id %s, data %lX", byte, namebuf, (unsigned long) rdata.u32); - buf += hlen; -#ifdef USE_SENSOR - if (byte == SENSOR_KEY && sensors.count(namebuf) != 0) - sensors[namebuf]->publish_state(rdata.f32); -#endif -#ifdef USE_BINARY_SENSOR - if (byte == BINARY_SENSOR_KEY && binary_sensors.count(namebuf) != 0) - binary_sensors[namebuf]->publish_state(rdata.u32 != 0); -#endif - } } void UDPComponent::dump_config() { ESP_LOGCONFIG(TAG, "UDP:"); - ESP_LOGCONFIG(TAG, " Port: %u", this->port_); - ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(this->is_encrypted_())); - ESP_LOGCONFIG(TAG, " Ping-pong: %s", YESNO(this->ping_pong_enable_)); + ESP_LOGCONFIG(TAG, " Listen Port: %u", this->listen_port_); + ESP_LOGCONFIG(TAG, " Broadcast Port: %u", this->broadcast_port_); for (const auto &address : this->addresses_) ESP_LOGCONFIG(TAG, " Address: %s", address.c_str()); if (this->listen_address_.has_value()) { ESP_LOGCONFIG(TAG, " Listen address: %s", this->listen_address_.value().str().c_str()); } -#ifdef USE_SENSOR - for (auto sensor : this->sensors_) - ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.id); -#endif -#ifdef USE_BINARY_SENSOR - for (auto sensor : this->binary_sensors_) - ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.id); -#endif - for (const auto &host : this->providers_) { - ESP_LOGCONFIG(TAG, " Remote host: %s", host.first.c_str()); - ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(!host.second.encryption_key.empty())); -#ifdef USE_SENSOR - for (const auto &sensor : this->remote_sensors_[host.first.c_str()]) - ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.first.c_str()); -#endif -#ifdef USE_BINARY_SENSOR - for (const auto &sensor : this->remote_binary_sensors_[host.first.c_str()]) - ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.first.c_str()); -#endif - } + ESP_LOGCONFIG(TAG, " Broadcasting: %s", YESNO(this->should_broadcast_)); + ESP_LOGCONFIG(TAG, " Listening: %s", YESNO(this->should_listen_)); } -void UDPComponent::increment_code_() { - if (this->rolling_code_enable_) { - if (++this->rolling_code_[0] == 0) { - this->rolling_code_[1]++; - this->pref_.save(&this->rolling_code_[1]); - } - } -} -void UDPComponent::send_packet_(void *data, size_t len) { + +void UDPComponent::send_packet(const uint8_t *data, size_t size) { #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) for (const auto &saddr : this->sockaddrs_) { - auto result = this->broadcast_socket_->sendto(data, len, 0, &saddr, sizeof(saddr)); + auto result = this->broadcast_socket_->sendto(data, size, 0, &saddr, sizeof(saddr)); if (result < 0) ESP_LOGW(TAG, "sendto() error %d", errno); } @@ -578,8 +145,8 @@ void UDPComponent::send_packet_(void *data, size_t len) { #ifdef USE_SOCKET_IMPL_LWIP_TCP auto iface = IPAddress(0, 0, 0, 0); for (const auto &saddr : this->ipaddrs_) { - if (this->udp_client_.beginPacketMulticast(saddr, this->port_, iface, 128) != 0) { - this->udp_client_.write((const uint8_t *) data, len); + if (this->udp_client_.beginPacketMulticast(saddr, this->broadcast_port_, iface, 128) != 0) { + this->udp_client_.write(data, size); auto result = this->udp_client_.endPacket(); if (result == 0) ESP_LOGW(TAG, "udp.write() error"); @@ -587,18 +154,7 @@ void UDPComponent::send_packet_(void *data, size_t len) { } #endif } - -void UDPComponent::send_ping_pong_request_() { - if (!this->ping_pong_enable_ || !network::is_connected()) - return; - this->ping_key_ = random_uint32(); - this->ping_header_.clear(); - add(this->ping_header_, MAGIC_PING); - add(this->ping_header_, this->name_); - add(this->ping_header_, this->ping_key_); - this->send_packet_(this->ping_header_.data(), this->ping_header_.size()); - this->resend_ping_key_ = false; - ESP_LOGV(TAG, "Sent new ping request %08X", (unsigned) this->ping_key_); -} } // namespace udp } // namespace esphome + +#endif diff --git a/esphome/components/udp/udp_component.h b/esphome/components/udp/udp_component.h index 02f998ded7..25909eba1d 100644 --- a/esphome/components/udp/udp_component.h +++ b/esphome/components/udp/udp_component.h @@ -1,13 +1,8 @@ #pragma once -#include "esphome/core/component.h" +#include "esphome/core/defines.h" +#ifdef USE_NETWORK #include "esphome/components/network/ip_address.h" -#ifdef USE_SENSOR -#include "esphome/components/sensor/sensor.h" -#endif -#ifdef USE_BINARY_SENSOR -#include "esphome/components/binary_sensor/binary_sensor.h" -#endif #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) #include "esphome/components/socket/socket.h" #endif @@ -15,116 +10,35 @@ #include #endif #include -#include namespace esphome { namespace udp { -struct Provider { - std::vector encryption_key; - const char *name; - uint32_t last_code[2]; -}; - -#ifdef USE_SENSOR -struct Sensor { - sensor::Sensor *sensor; - const char *id; - bool updated; -}; -#endif -#ifdef USE_BINARY_SENSOR -struct BinarySensor { - binary_sensor::BinarySensor *sensor; - const char *id; - bool updated; -}; -#endif - -class UDPComponent : public PollingComponent { +static const size_t MAX_PACKET_SIZE = 508; +class UDPComponent : public Component { public: + void add_address(const char *addr) { this->addresses_.emplace_back(addr); } + void set_listen_address(const char *listen_addr) { this->listen_address_ = network::IPAddress(listen_addr); } + void set_listen_port(uint16_t port) { this->listen_port_ = port; } + void set_broadcast_port(uint16_t port) { this->broadcast_port_ = port; } + void set_should_broadcast() { this->should_broadcast_ = true; } + void set_should_listen() { this->should_listen_ = true; } + void add_listener(std::function &)> &&listener) { + this->packet_listeners_.add(std::move(listener)); + } void setup() override; void loop() override; - void update() override; void dump_config() override; - -#ifdef USE_SENSOR - void add_sensor(const char *id, sensor::Sensor *sensor) { - Sensor st{sensor, id, true}; - this->sensors_.push_back(st); - } - void add_remote_sensor(const char *hostname, const char *remote_id, sensor::Sensor *sensor) { - this->add_provider(hostname); - this->remote_sensors_[hostname][remote_id] = sensor; - } -#endif -#ifdef USE_BINARY_SENSOR - void add_binary_sensor(const char *id, binary_sensor::BinarySensor *sensor) { - BinarySensor st{sensor, id, true}; - this->binary_sensors_.push_back(st); - } - - void add_remote_binary_sensor(const char *hostname, const char *remote_id, binary_sensor::BinarySensor *sensor) { - this->add_provider(hostname); - this->remote_binary_sensors_[hostname][remote_id] = sensor; - } -#endif - void add_address(const char *addr) { this->addresses_.emplace_back(addr); } -#ifdef USE_NETWORK - void set_listen_address(const char *listen_addr) { this->listen_address_ = network::IPAddress(listen_addr); } -#endif - void set_port(uint16_t port) { this->port_ = port; } - float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } - - void add_provider(const char *hostname) { - if (this->providers_.count(hostname) == 0) { - Provider provider; - provider.encryption_key = std::vector{}; - provider.last_code[0] = 0; - provider.last_code[1] = 0; - provider.name = hostname; - this->providers_[hostname] = provider; -#ifdef USE_SENSOR - this->remote_sensors_[hostname] = std::map(); -#endif -#ifdef USE_BINARY_SENSOR - this->remote_binary_sensors_[hostname] = std::map(); -#endif - } - } - - void set_encryption_key(std::vector key) { this->encryption_key_ = std::move(key); } - void set_rolling_code_enable(bool enable) { this->rolling_code_enable_ = enable; } - void set_ping_pong_enable(bool enable) { this->ping_pong_enable_ = enable; } - void set_ping_pong_recycle_time(uint32_t recycle_time) { this->ping_pong_recyle_time_ = recycle_time; } - void set_provider_encryption(const char *name, std::vector key) { - this->providers_[name].encryption_key = std::move(key); - } + void send_packet(const uint8_t *data, size_t size); + void send_packet(std::vector &buf) { this->send_packet(buf.data(), buf.size()); } + float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }; protected: - void send_data_(bool all); - void process_(uint8_t *buf, size_t len); - void flush_(); - void add_data_(uint8_t key, const char *id, float data); - void add_data_(uint8_t key, const char *id, uint32_t data); - void increment_code_(); - void add_binary_data_(uint8_t key, const char *id, bool data); - void init_data_(); - - bool updated_{}; - uint16_t port_{18511}; - uint32_t ping_key_{}; - uint32_t rolling_code_[2]{}; - bool rolling_code_enable_{}; - bool ping_pong_enable_{}; - uint32_t ping_pong_recyle_time_{}; - uint32_t last_key_time_{}; - bool resend_ping_key_{}; - bool resend_data_{}; - bool should_send_{}; - const char *name_{}; + uint16_t listen_port_{}; + uint16_t broadcast_port_{}; + bool should_broadcast_{}; bool should_listen_{}; - ESPPreferenceObject pref_; + CallbackManager &)> packet_listeners_{}; #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) std::unique_ptr broadcast_socket_ = nullptr; @@ -135,32 +49,11 @@ class UDPComponent : public PollingComponent { std::vector ipaddrs_{}; WiFiUDP udp_client_{}; #endif - std::vector encryption_key_{}; std::vector addresses_{}; -#ifdef USE_SENSOR - std::vector sensors_{}; - std::map> remote_sensors_{}; -#endif -#ifdef USE_BINARY_SENSOR - std::vector binary_sensors_{}; - std::map> remote_binary_sensors_{}; -#endif -#ifdef USE_NETWORK optional listen_address_{}; -#endif - std::map providers_{}; - std::vector ping_header_{}; - std::vector header_{}; - std::vector data_{}; - std::map ping_keys_{}; - void add_key_(const char *name, uint32_t key); - void send_ping_pong_request_(); - void send_packet_(void *data, size_t len); - void process_ping_request_(const char *name, uint8_t *ptr, size_t len); - - inline bool is_encrypted_() { return !this->encryption_key_.empty(); } }; } // namespace udp } // namespace esphome +#endif diff --git a/tests/components/packet_transport/common.yaml b/tests/components/packet_transport/common.yaml new file mode 100644 index 0000000000..cbb34c4572 --- /dev/null +++ b/tests/components/packet_transport/common.yaml @@ -0,0 +1,40 @@ +wifi: + ssid: MySSID + password: password1 + +udp: + listen_address: 239.0.60.53 + addresses: ["239.0.60.53"] + +packet_transport: + platform: udp + update_interval: 5s + encryption: "our key goes here" + rolling_code_enable: true + ping_pong_enable: true + binary_sensors: + - binary_sensor_id1 + - id: binary_sensor_id1 + broadcast_id: other_id + sensors: + - sensor_id1 + - id: sensor_id1 + broadcast_id: other_id + providers: + - name: some-device-name + encryption: "their key goes here" + +sensor: + - platform: template + id: sensor_id1 + - platform: packet_transport + provider: some-device-name + id: our_id + remote_id: some_sensor_id + +binary_sensor: + - platform: packet_transport + provider: unencrypted-device + id: other_binary_sensor_id + - platform: template + id: binary_sensor_id1 diff --git a/tests/components/packet_transport/test.bk72xx-ard.yaml b/tests/components/packet_transport/test.bk72xx-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.bk72xx-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-ard.yaml b/tests/components/packet_transport/test.esp32-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-c3-ard.yaml b/tests/components/packet_transport/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-c3-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-c3-idf.yaml b/tests/components/packet_transport/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-c3-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-idf.yaml b/tests/components/packet_transport/test.esp32-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp8266-ard.yaml b/tests/components/packet_transport/test.esp8266-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp8266-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.host.yaml b/tests/components/packet_transport/test.host.yaml new file mode 100644 index 0000000000..e735c37e4d --- /dev/null +++ b/tests/components/packet_transport/test.host.yaml @@ -0,0 +1,4 @@ +packages: + common: !include common.yaml + +wifi: !remove diff --git a/tests/components/packet_transport/test.rp2040-ard.yaml b/tests/components/packet_transport/test.rp2040-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.rp2040-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/udp/common.yaml b/tests/components/udp/common.yaml index e533cb965e..79da02a692 100644 --- a/tests/components/udp/common.yaml +++ b/tests/components/udp/common.yaml @@ -3,34 +3,18 @@ wifi: password: password1 udp: - update_interval: 5s - encryption: "our key goes here" - rolling_code_enable: true - ping_pong_enable: true + id: my_udp listen_address: 239.0.60.53 - binary_sensors: - - binary_sensor_id1 - - id: binary_sensor_id1 - broadcast_id: other_id - sensors: - - sensor_id1 - - id: sensor_id1 - broadcast_id: other_id - providers: - - name: some-device-name - encryption: "their key goes here" + addresses: ["239.0.60.53"] + on_receive: + - logger.log: + format: "Received %d bytes" + args: [data.size()] + - udp.write: + id: my_udp + data: "hello world" + - udp.write: + id: my_udp + data: !lambda |- + return std::vector{1,3,4,5,6}; -sensor: - - platform: template - id: sensor_id1 - - platform: udp - provider: some-device-name - id: our_id - remote_id: some_sensor_id - -binary_sensor: - - platform: udp - provider: unencrypted-device - id: other_binary_sensor_id - - platform: template - id: binary_sensor_id1 From 4dc6cbe2d7d84d5d393879c4987832608646d1c3 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 10:02:33 +1000 Subject: [PATCH 088/102] [esp32_ble_server] Add appearance advertising field (#8672) --- esphome/components/esp32_ble/ble.cpp | 1 + esphome/components/esp32_ble/ble.h | 8 +++++--- esphome/components/esp32_ble/ble_advertising.h | 1 + esphome/components/esp32_ble_server/__init__.py | 3 +++ tests/components/esp32_ble_server/common.yaml | 1 + 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/esphome/components/esp32_ble/ble.cpp b/esphome/components/esp32_ble/ble.cpp index ab2647b738..fc1303673f 100644 --- a/esphome/components/esp32_ble/ble.cpp +++ b/esphome/components/esp32_ble/ble.cpp @@ -110,6 +110,7 @@ void ESP32BLE::advertising_init_() { this->advertising_->set_scan_response(true); this->advertising_->set_min_preferred_interval(0x06); + this->advertising_->set_appearance(this->appearance_); } bool ESP32BLE::ble_setup_() { diff --git a/esphome/components/esp32_ble/ble.h b/esphome/components/esp32_ble/ble.h index ed7575f128..13ec3b6dd9 100644 --- a/esphome/components/esp32_ble/ble.h +++ b/esphome/components/esp32_ble/ble.h @@ -95,6 +95,7 @@ class ESP32BLE : public Component { void advertising_start(); void advertising_set_service_data(const std::vector &data); void advertising_set_manufacturer_data(const std::vector &data); + void advertising_set_appearance(uint16_t appearance) { this->appearance_ = appearance; } void advertising_add_service_uuid(ESPBTUUID uuid); void advertising_remove_service_uuid(ESPBTUUID uuid); void advertising_register_raw_advertisement_callback(std::function &&callback); @@ -128,11 +129,12 @@ class ESP32BLE : public Component { BLEComponentState state_{BLE_COMPONENT_STATE_OFF}; Queue ble_events_; - BLEAdvertising *advertising_; + BLEAdvertising *advertising_{}; esp_ble_io_cap_t io_cap_{ESP_IO_CAP_NONE}; - uint32_t advertising_cycle_time_; - bool enable_on_boot_; + uint32_t advertising_cycle_time_{}; + bool enable_on_boot_{}; optional name_; + uint16_t appearance_{0}; }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/components/esp32_ble/ble_advertising.h b/esphome/components/esp32_ble/ble_advertising.h index 946e414c1d..0b2142115d 100644 --- a/esphome/components/esp32_ble/ble_advertising.h +++ b/esphome/components/esp32_ble/ble_advertising.h @@ -32,6 +32,7 @@ class BLEAdvertising { void set_scan_response(bool scan_response) { this->scan_response_ = scan_response; } void set_min_preferred_interval(uint16_t interval) { this->advertising_data_.min_interval = interval; } void set_manufacturer_data(const std::vector &data); + void set_appearance(uint16_t appearance) { this->advertising_data_.appearance = appearance; } void set_service_data(const std::vector &data); void register_raw_advertisement_callback(std::function &&callback); diff --git a/esphome/components/esp32_ble_server/__init__.py b/esphome/components/esp32_ble_server/__init__.py index ab8e27ec43..0fcb5c9822 100644 --- a/esphome/components/esp32_ble_server/__init__.py +++ b/esphome/components/esp32_ble_server/__init__.py @@ -32,6 +32,7 @@ DEPENDENCIES = ["esp32"] DOMAIN = "esp32_ble_server" CONF_ADVERTISE = "advertise" +CONF_APPEARANCE = "appearance" CONF_BROADCAST = "broadcast" CONF_CHARACTERISTICS = "characteristics" CONF_DESCRIPTION = "description" @@ -421,6 +422,7 @@ CONFIG_SCHEMA = cv.Schema( cv.GenerateID(): cv.declare_id(BLEServer), cv.GenerateID(esp32_ble.CONF_BLE_ID): cv.use_id(esp32_ble.ESP32BLE), cv.Optional(CONF_MANUFACTURER): value_schema("string", templatable=False), + cv.Optional(CONF_APPEARANCE, default=0): cv.uint16_t, cv.Optional(CONF_MODEL): value_schema("string", templatable=False), cv.Optional(CONF_FIRMWARE_VERSION): value_schema("string", templatable=False), cv.Optional(CONF_MANUFACTURER_DATA): cv.Schema([cv.uint8_t]), @@ -531,6 +533,7 @@ async def to_code(config): cg.add(parent.register_gatts_event_handler(var)) cg.add(parent.register_ble_status_event_handler(var)) cg.add(var.set_parent(parent)) + cg.add(parent.advertising_set_appearance(config[CONF_APPEARANCE])) if CONF_MANUFACTURER_DATA in config: cg.add(var.set_manufacturer_data(config[CONF_MANUFACTURER_DATA])) for service_config in config[CONF_SERVICES]: diff --git a/tests/components/esp32_ble_server/common.yaml b/tests/components/esp32_ble_server/common.yaml index 696f4ea8fe..e9576a8262 100644 --- a/tests/components/esp32_ble_server/common.yaml +++ b/tests/components/esp32_ble_server/common.yaml @@ -2,6 +2,7 @@ esp32_ble_server: id: ble_server manufacturer_data: [0x72, 0x4, 0x00, 0x23] manufacturer: ESPHome + appearance: 0x1 model: Test on_connect: - lambda: |- From 3ed03edfec82740ebecda53a6100470fbe6570ee Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 10:04:33 +1000 Subject: [PATCH 089/102] [display] Fix Rect::inside (#8679) --- esphome/components/display/rect.cpp | 13 ++++--------- esphome/components/display/rect.h | 2 +- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/esphome/components/display/rect.cpp b/esphome/components/display/rect.cpp index 49bb7d025f..2c41127860 100644 --- a/esphome/components/display/rect.cpp +++ b/esphome/components/display/rect.cpp @@ -69,21 +69,16 @@ bool Rect::inside(int16_t test_x, int16_t test_y, bool absolute) const { // NOL return true; } if (absolute) { - return ((test_x >= this->x) && (test_x <= this->x2()) && (test_y >= this->y) && (test_y <= this->y2())); - } else { - return ((test_x >= 0) && (test_x <= this->w) && (test_y >= 0) && (test_y <= this->h)); + return test_x >= this->x && test_x < this->x2() && test_y >= this->y && test_y < this->y2(); } + return test_x >= 0 && test_x < this->w && test_y >= 0 && test_y < this->h; } -bool Rect::inside(Rect rect, bool absolute) const { +bool Rect::inside(Rect rect) const { if (!this->is_set() || !rect.is_set()) { return true; } - if (absolute) { - return ((rect.x <= this->x2()) && (rect.x2() >= this->x) && (rect.y <= this->y2()) && (rect.y2() >= this->y)); - } else { - return ((rect.x <= this->w) && (rect.w >= 0) && (rect.y <= this->h) && (rect.h >= 0)); - } + return this->x2() >= rect.x && this->x <= rect.x2() && this->y2() >= rect.y && this->y <= rect.y2(); } void Rect::info(const std::string &prefix) { diff --git a/esphome/components/display/rect.h b/esphome/components/display/rect.h index f55c2fe201..5f11d94681 100644 --- a/esphome/components/display/rect.h +++ b/esphome/components/display/rect.h @@ -26,7 +26,7 @@ class Rect { void extend(Rect rect); void shrink(Rect rect); - bool inside(Rect rect, bool absolute = true) const; + bool inside(Rect rect) const; bool inside(int16_t test_x, int16_t test_y, bool absolute = true) const; bool equal(Rect rect) const; void info(const std::string &prefix = "rect info:"); From a31d8ec309f86214c457e6d8c73737baeb41af6b Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 10:26:59 +1000 Subject: [PATCH 090/102] [packages] Allow list instead of dict for packages (#8688) --- esphome/components/packages/__init__.py | 64 ++++++++++--------- .../component_tests/packages/test_packages.py | 5 +- tests/components/packages/package.yaml | 3 + tests/components/packages/test.esp32-ard.yaml | 11 ++++ tests/components/packages/test.esp32-idf.yaml | 13 ++++ 5 files changed, 63 insertions(+), 33 deletions(-) create mode 100644 tests/components/packages/package.yaml create mode 100644 tests/components/packages/test.esp32-ard.yaml create mode 100644 tests/components/packages/test.esp32-idf.yaml diff --git a/esphome/components/packages/__init__.py b/esphome/components/packages/__init__.py index f4d11e7bd0..08ae798282 100644 --- a/esphome/components/packages/__init__.py +++ b/esphome/components/packages/__init__.py @@ -24,22 +24,13 @@ DOMAIN = CONF_PACKAGES def validate_git_package(config: dict): + if CONF_URL not in config: + return config + config = BASE_SCHEMA(config) new_config = config - for key, conf in config.items(): - if CONF_URL in conf: - try: - conf = BASE_SCHEMA(conf) - if CONF_FILE in conf: - new_config[key][CONF_FILES] = [conf[CONF_FILE]] - del new_config[key][CONF_FILE] - except cv.MultipleInvalid as e: - with cv.prepend_path([key]): - raise e - except cv.Invalid as e: - raise cv.Invalid( - "Extra keys not allowed in git based package", - path=[key] + e.path, - ) from e + if CONF_FILE in config: + new_config[CONF_FILES] = [config[CONF_FILE]] + del new_config[CONF_FILE] return new_config @@ -74,8 +65,8 @@ BASE_SCHEMA = cv.All( cv.Required(CONF_URL): cv.url, cv.Optional(CONF_USERNAME): cv.string, cv.Optional(CONF_PASSWORD): cv.string, - cv.Exclusive(CONF_FILE, "files"): validate_yaml_filename, - cv.Exclusive(CONF_FILES, "files"): cv.All( + cv.Exclusive(CONF_FILE, CONF_FILES): validate_yaml_filename, + cv.Exclusive(CONF_FILES, CONF_FILES): cv.All( cv.ensure_list( cv.Any( validate_yaml_filename, @@ -100,14 +91,17 @@ BASE_SCHEMA = cv.All( cv.has_at_least_one_key(CONF_FILE, CONF_FILES), ) +PACKAGE_SCHEMA = cv.All( + cv.Any(validate_source_shorthand, BASE_SCHEMA, dict), validate_git_package +) -CONFIG_SCHEMA = cv.All( +CONFIG_SCHEMA = cv.Any( cv.Schema( { - str: cv.Any(validate_source_shorthand, BASE_SCHEMA, dict), + str: PACKAGE_SCHEMA, } ), - validate_git_package, + cv.ensure_list(PACKAGE_SCHEMA), ) @@ -183,25 +177,33 @@ def _process_base_package(config: dict) -> dict: return {"packages": packages} +def _process_package(package_config, config): + recursive_package = package_config + if CONF_URL in package_config: + package_config = _process_base_package(package_config) + if isinstance(package_config, dict): + recursive_package = do_packages_pass(package_config) + config = merge_config(recursive_package, config) + return config + + def do_packages_pass(config: dict): if CONF_PACKAGES not in config: return config packages = config[CONF_PACKAGES] with cv.prepend_path(CONF_PACKAGES): packages = CONFIG_SCHEMA(packages) - if not isinstance(packages, dict): + if isinstance(packages, dict): + for package_name, package_config in reversed(packages.items()): + with cv.prepend_path(package_name): + config = _process_package(package_config, config) + elif isinstance(packages, list): + for package_config in reversed(packages): + config = _process_package(package_config, config) + else: raise cv.Invalid( - f"Packages must be a key to value mapping, got {type(packages)} instead" + f"Packages must be a key to value mapping or list, got {type(packages)} instead" ) - for package_name, package_config in reversed(packages.items()): - with cv.prepend_path(package_name): - recursive_package = package_config - if CONF_URL in package_config: - package_config = _process_base_package(package_config) - if isinstance(package_config, dict): - recursive_package = do_packages_pass(package_config) - config = merge_config(recursive_package, config) - del config[CONF_PACKAGES] return config diff --git a/tests/component_tests/packages/test_packages.py b/tests/component_tests/packages/test_packages.py index 3fbbf49afd..4712daad0d 100644 --- a/tests/component_tests/packages/test_packages.py +++ b/tests/component_tests/packages/test_packages.py @@ -76,10 +76,11 @@ def test_package_unused(basic_esphome, basic_wifi): def test_package_invalid_dict(basic_esphome, basic_wifi): """ - Ensures an error is raised if packages is not valid. + If a url: key is present, it's expected to be well-formed remote package spec. Ensure an error is raised if not. + Any other simple dict passed as a package will be merged as usual but may fail later validation. """ - config = {CONF_ESPHOME: basic_esphome, CONF_PACKAGES: basic_wifi} + config = {CONF_ESPHOME: basic_esphome, CONF_PACKAGES: basic_wifi | {CONF_URL: ""}} with pytest.raises(cv.Invalid): do_packages_pass(config) diff --git a/tests/components/packages/package.yaml b/tests/components/packages/package.yaml new file mode 100644 index 0000000000..672d66151e --- /dev/null +++ b/tests/components/packages/package.yaml @@ -0,0 +1,3 @@ +sensor: + - platform: template + id: package_sensor diff --git a/tests/components/packages/test.esp32-ard.yaml b/tests/components/packages/test.esp32-ard.yaml new file mode 100644 index 0000000000..d35c27d997 --- /dev/null +++ b/tests/components/packages/test.esp32-ard.yaml @@ -0,0 +1,11 @@ +packages: + - sensor: + - platform: template + id: inline_sensor + - !include package.yaml + - github://esphome/esphome/tests/components/template/common.yaml@dev + - url: https://github.com/esphome/esphome + file: tests/components/binary_sensor_map/common.yaml + ref: dev + refresh: 1d + diff --git a/tests/components/packages/test.esp32-idf.yaml b/tests/components/packages/test.esp32-idf.yaml new file mode 100644 index 0000000000..9f1484d1fd --- /dev/null +++ b/tests/components/packages/test.esp32-idf.yaml @@ -0,0 +1,13 @@ +packages: + sensor: + sensor: + - platform: template + id: inline_sensor + local: !include package.yaml + shorthand: github://esphome/esphome/tests/components/template/common.yaml@dev + github: + url: https://github.com/esphome/esphome + file: tests/components/binary_sensor_map/common.yaml + ref: dev + refresh: 1d + From 125aff79ec92c87bae2ce6b99599cceef46ffb5c Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 10:28:00 +1000 Subject: [PATCH 091/102] [as3935_i2c] Remove redundant includes (#8677) --- esphome/components/as3935_i2c/as3935_i2c.h | 3 --- tests/components/as3935_i2c/test.esp32-ard.yaml | 7 ++++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/esphome/components/as3935_i2c/as3935_i2c.h b/esphome/components/as3935_i2c/as3935_i2c.h index 1d16397bdf..a2a3d213ef 100644 --- a/esphome/components/as3935_i2c/as3935_i2c.h +++ b/esphome/components/as3935_i2c/as3935_i2c.h @@ -1,10 +1,7 @@ #pragma once -#include "esphome/core/component.h" #include "esphome/components/as3935/as3935.h" #include "esphome/components/i2c/i2c.h" -#include "esphome/components/sensor/sensor.h" -#include "esphome/components/binary_sensor/binary_sensor.h" namespace esphome { namespace as3935_i2c { diff --git a/tests/components/as3935_i2c/test.esp32-ard.yaml b/tests/components/as3935_i2c/test.esp32-ard.yaml index 2c57d412f6..52d5a045cb 100644 --- a/tests/components/as3935_i2c/test.esp32-ard.yaml +++ b/tests/components/as3935_i2c/test.esp32-ard.yaml @@ -3,4 +3,9 @@ substitutions: sda_pin: GPIO17 irq_pin: GPIO15 -<<: !include common.yaml +packages: + as3935: !include common.yaml + +# Trigger issue: https://github.com/esphome/issues/issues/6990 +# Compile with no binary sensor results in error +binary_sensor: !remove From 2a6827e1d21735ea0000f262839f412c01ad38c9 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 10:30:11 +1000 Subject: [PATCH 092/102] [lvgl] Allow padding to be negative (#8671) --- esphome/components/lvgl/lv_validation.py | 11 +++++++- esphome/components/lvgl/schemas.py | 28 +++++++++---------- .../components/lvgl/widgets/buttonmatrix.py | 6 ++-- esphome/components/lvgl/widgets/checkbox.py | 4 +-- tests/components/lvgl/lvgl-package.yaml | 2 ++ 5 files changed, 31 insertions(+), 20 deletions(-) diff --git a/esphome/components/lvgl/lv_validation.py b/esphome/components/lvgl/lv_validation.py index a3b7cc8ed3..3755d35d27 100644 --- a/esphome/components/lvgl/lv_validation.py +++ b/esphome/components/lvgl/lv_validation.py @@ -16,7 +16,7 @@ from esphome.const import ( ) from esphome.core import CORE, ID, Lambda from esphome.cpp_generator import MockObj -from esphome.cpp_types import ESPTime, uint32 +from esphome.cpp_types import ESPTime, int32, uint32 from esphome.helpers import cpp_string_escape from esphome.schema_extractors import SCHEMA_EXTRACT, schema_extractor @@ -263,6 +263,15 @@ def pixels_validator(value): pixels = LValidator(pixels_validator, uint32, retmapper=literal) +def padding_validator(value): + if isinstance(value, str) and value.lower().endswith("px"): + value = value[:-2] + return cv.int_(value) + + +padding = LValidator(padding_validator, int32, retmapper=literal) + + def zoom_validator(value): value = cv.float_range(0.1, 10.0)(value) return value diff --git a/esphome/components/lvgl/schemas.py b/esphome/components/lvgl/schemas.py index 051dbe5e0e..d0dde01421 100644 --- a/esphome/components/lvgl/schemas.py +++ b/esphome/components/lvgl/schemas.py @@ -156,13 +156,13 @@ STYLE_PROPS = { "opa_layered": lvalid.opacity, "outline_color": lvalid.lv_color, "outline_opa": lvalid.opacity, - "outline_pad": lvalid.pixels, + "outline_pad": lvalid.padding, "outline_width": lvalid.pixels, - "pad_all": lvalid.pixels, - "pad_bottom": lvalid.pixels, - "pad_left": lvalid.pixels, - "pad_right": lvalid.pixels, - "pad_top": lvalid.pixels, + "pad_all": lvalid.padding, + "pad_bottom": lvalid.padding, + "pad_left": lvalid.padding, + "pad_right": lvalid.padding, + "pad_top": lvalid.padding, "shadow_color": lvalid.lv_color, "shadow_ofs_x": lvalid.lv_int, "shadow_ofs_y": lvalid.lv_int, @@ -226,8 +226,8 @@ FULL_STYLE_SCHEMA = STYLE_SCHEMA.extend( { cv.Optional(df.CONF_GRID_CELL_X_ALIGN): grid_alignments, cv.Optional(df.CONF_GRID_CELL_Y_ALIGN): grid_alignments, - cv.Optional(df.CONF_PAD_ROW): lvalid.pixels, - cv.Optional(df.CONF_PAD_COLUMN): lvalid.pixels, + cv.Optional(df.CONF_PAD_ROW): lvalid.padding, + cv.Optional(df.CONF_PAD_COLUMN): lvalid.padding, } ) @@ -370,8 +370,8 @@ LAYOUT_SCHEMA = { cv.Required(df.CONF_GRID_COLUMNS): [grid_spec], cv.Optional(df.CONF_GRID_COLUMN_ALIGN): grid_alignments, cv.Optional(df.CONF_GRID_ROW_ALIGN): grid_alignments, - cv.Optional(df.CONF_PAD_ROW): lvalid.pixels, - cv.Optional(df.CONF_PAD_COLUMN): lvalid.pixels, + cv.Optional(df.CONF_PAD_ROW): lvalid.padding, + cv.Optional(df.CONF_PAD_COLUMN): lvalid.padding, }, df.TYPE_FLEX: { cv.Optional( @@ -380,8 +380,8 @@ LAYOUT_SCHEMA = { cv.Optional(df.CONF_FLEX_ALIGN_MAIN, default="start"): flex_alignments, cv.Optional(df.CONF_FLEX_ALIGN_CROSS, default="start"): flex_alignments, cv.Optional(df.CONF_FLEX_ALIGN_TRACK, default="start"): flex_alignments, - cv.Optional(df.CONF_PAD_ROW): lvalid.pixels, - cv.Optional(df.CONF_PAD_COLUMN): lvalid.pixels, + cv.Optional(df.CONF_PAD_ROW): lvalid.padding, + cv.Optional(df.CONF_PAD_COLUMN): lvalid.padding, }, }, lower=True, @@ -427,8 +427,8 @@ ALL_STYLES = { **STYLE_PROPS, **GRID_CELL_SCHEMA, **FLEX_OBJ_SCHEMA, - cv.Optional(df.CONF_PAD_ROW): lvalid.pixels, - cv.Optional(df.CONF_PAD_COLUMN): lvalid.pixels, + cv.Optional(df.CONF_PAD_ROW): lvalid.padding, + cv.Optional(df.CONF_PAD_COLUMN): lvalid.padding, } diff --git a/esphome/components/lvgl/widgets/buttonmatrix.py b/esphome/components/lvgl/widgets/buttonmatrix.py index 0ba1fe4ae1..aa33be722c 100644 --- a/esphome/components/lvgl/widgets/buttonmatrix.py +++ b/esphome/components/lvgl/widgets/buttonmatrix.py @@ -19,7 +19,7 @@ from ..defines import ( CONF_SELECTED, ) from ..helpers import lvgl_components_required -from ..lv_validation import key_code, lv_bool, pixels +from ..lv_validation import key_code, lv_bool, padding from ..lvcode import lv, lv_add, lv_expr from ..schemas import automation_schema from ..types import ( @@ -59,8 +59,8 @@ BUTTONMATRIX_BUTTON_SCHEMA = cv.Schema( BUTTONMATRIX_SCHEMA = cv.Schema( { cv.Optional(CONF_ONE_CHECKED, default=False): lv_bool, - cv.Optional(CONF_PAD_ROW): pixels, - cv.Optional(CONF_PAD_COLUMN): pixels, + cv.Optional(CONF_PAD_ROW): padding, + cv.Optional(CONF_PAD_COLUMN): padding, cv.GenerateID(CONF_BUTTON_TEXT_LIST_ID): cv.declare_id(char_ptr), cv.Required(CONF_ROWS): cv.ensure_list( cv.Schema( diff --git a/esphome/components/lvgl/widgets/checkbox.py b/esphome/components/lvgl/widgets/checkbox.py index 75f4142eb1..c344fbfe75 100644 --- a/esphome/components/lvgl/widgets/checkbox.py +++ b/esphome/components/lvgl/widgets/checkbox.py @@ -2,7 +2,7 @@ from esphome.config_validation import Optional from esphome.const import CONF_TEXT from ..defines import CONF_INDICATOR, CONF_MAIN, CONF_PAD_COLUMN -from ..lv_validation import lv_text, pixels +from ..lv_validation import lv_text, padding from ..lvcode import lv from ..schemas import TEXT_SCHEMA from ..types import LvBoolean @@ -19,7 +19,7 @@ class CheckboxType(WidgetType): (CONF_MAIN, CONF_INDICATOR), TEXT_SCHEMA.extend( { - Optional(CONF_PAD_COLUMN): pixels, + Optional(CONF_PAD_COLUMN): padding, } ), ) diff --git a/tests/components/lvgl/lvgl-package.yaml b/tests/components/lvgl/lvgl-package.yaml index a0b7dd096f..d0e281e583 100644 --- a/tests/components/lvgl/lvgl-package.yaml +++ b/tests/components/lvgl/lvgl-package.yaml @@ -641,6 +641,8 @@ lvgl: knob: radius: 1 width: "4" + pad_left: -5 + pad_top: 5 height: 10% bg_color: 0x000000 width: 100% From c7523ace7881b01eaa1c86ad58c69abbe74c4070 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 10:31:22 +1000 Subject: [PATCH 093/102] [lvgl] Fix image property processing (#8691) --- esphome/components/lvgl/widgets/img.py | 18 ++++++++++-------- tests/components/lvgl/lvgl-package.yaml | 11 +++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/esphome/components/lvgl/widgets/img.py b/esphome/components/lvgl/widgets/img.py index c3e0781489..8ec18e3033 100644 --- a/esphome/components/lvgl/widgets/img.py +++ b/esphome/components/lvgl/widgets/img.py @@ -10,7 +10,7 @@ from ..defines import ( CONF_ZOOM, LvConstant, ) -from ..lv_validation import angle, lv_bool, lv_image, size, zoom +from ..lv_validation import lv_angle, lv_bool, lv_image, size, zoom from ..lvcode import lv from ..types import lv_img_t from . import Widget, WidgetType @@ -22,7 +22,7 @@ BASE_IMG_SCHEMA = cv.Schema( { cv.Optional(CONF_PIVOT_X): size, cv.Optional(CONF_PIVOT_Y): size, - cv.Optional(CONF_ANGLE): angle, + cv.Optional(CONF_ANGLE): lv_angle, cv.Optional(CONF_ZOOM): zoom, cv.Optional(CONF_OFFSET_X): size, cv.Optional(CONF_OFFSET_Y): size, @@ -66,17 +66,19 @@ class ImgType(WidgetType): if (pivot_x := config.get(CONF_PIVOT_X)) and ( pivot_y := config.get(CONF_PIVOT_Y) ): - lv.img_set_pivot(w.obj, pivot_x, pivot_y) + lv.img_set_pivot( + w.obj, await size.process(pivot_x), await size.process(pivot_y) + ) if (cf_angle := config.get(CONF_ANGLE)) is not None: - lv.img_set_angle(w.obj, cf_angle) + lv.img_set_angle(w.obj, await lv_angle.process(cf_angle)) if (img_zoom := config.get(CONF_ZOOM)) is not None: - lv.img_set_zoom(w.obj, img_zoom) + lv.img_set_zoom(w.obj, await zoom.process(img_zoom)) if (offset := config.get(CONF_OFFSET_X)) is not None: - lv.img_set_offset_x(w.obj, offset) + lv.img_set_offset_x(w.obj, await size.process(offset)) if (offset := config.get(CONF_OFFSET_Y)) is not None: - lv.img_set_offset_y(w.obj, offset) + lv.img_set_offset_y(w.obj, await size.process(offset)) if CONF_ANTIALIAS in config: - lv.img_set_antialias(w.obj, config[CONF_ANTIALIAS]) + lv.img_set_antialias(w.obj, await lv_bool.process(config[CONF_ANTIALIAS])) if mode := config.get(CONF_MODE): await w.set_property("size_mode", mode) diff --git a/tests/components/lvgl/lvgl-package.yaml b/tests/components/lvgl/lvgl-package.yaml index d0e281e583..6fd0b5e3c4 100644 --- a/tests/components/lvgl/lvgl-package.yaml +++ b/tests/components/lvgl/lvgl-package.yaml @@ -134,6 +134,15 @@ lvgl: id: style_test bg_color: blue bg_opa: !lambda return 0.5; + - lvgl.image.update: + id: lv_image + zoom: !lambda return 512; + angle: !lambda return 100; + pivot_x: !lambda return 20; + pivot_y: !lambda return 20; + offset_x: !lambda return 20; + offset_y: !lambda return 20; + antialias: !lambda return true; - id: simple_msgbox title: Simple @@ -486,6 +495,8 @@ lvgl: align: top_left y: "50" mode: real + zoom: 2.0 + angle: 45 - tileview: id: tileview_id scrollbar_mode: active From 0b032e5c19796c30fbdcddc622062518b9fa3ff0 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 13:26:16 +1000 Subject: [PATCH 094/102] [lvgl] Add refresh action to re-evaluate initial widget properties (#8675) --- esphome/components/lvgl/__init__.py | 11 +++++- esphome/components/lvgl/automation.py | 51 ++++++++++++++++++++++++- tests/components/lvgl/lvgl-package.yaml | 3 +- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/esphome/components/lvgl/__init__.py b/esphome/components/lvgl/__init__.py index 69286ada88..f60d60d9a4 100644 --- a/esphome/components/lvgl/__init__.py +++ b/esphome/components/lvgl/__init__.py @@ -18,13 +18,13 @@ from esphome.const import ( CONF_TRIGGER_ID, CONF_TYPE, ) -from esphome.core import CORE, ID +from esphome.core import CORE, ID, Lambda from esphome.cpp_generator import MockObj from esphome.final_validate import full_config from esphome.helpers import write_file_if_changed from . import defines as df, helpers, lv_validation as lvalid -from .automation import disp_update, focused_widgets, update_to_code +from .automation import disp_update, focused_widgets, refreshed_widgets, update_to_code from .defines import add_define from .encoders import ( ENCODERS_CONFIG, @@ -240,6 +240,13 @@ def final_validation(configs): "A non adjustable arc may not be focused", path, ) + for w in refreshed_widgets: + path = global_config.get_path_for_id(w) + widget_conf = global_config.get_config_for_path(path[:-1]) + if not any(isinstance(v, Lambda) for v in widget_conf.values()): + raise cv.Invalid( + f"Widget '{w}' does not have any templated properties to refresh", + ) async def to_code(configs): diff --git a/esphome/components/lvgl/automation.py b/esphome/components/lvgl/automation.py index 4a71872022..5fea9bfdb1 100644 --- a/esphome/components/lvgl/automation.py +++ b/esphome/components/lvgl/automation.py @@ -35,7 +35,13 @@ from .lvcode import ( lv_obj, lvgl_comp, ) -from .schemas import DISP_BG_SCHEMA, LIST_ACTION_SCHEMA, LVGL_SCHEMA, base_update_schema +from .schemas import ( + ALL_STYLES, + DISP_BG_SCHEMA, + LIST_ACTION_SCHEMA, + LVGL_SCHEMA, + base_update_schema, +) from .types import ( LV_STATE, LvglAction, @@ -57,6 +63,7 @@ from .widgets import ( # Record widgets that are used in a focused action here focused_widgets = set() +refreshed_widgets = set() async def action_to_code( @@ -361,3 +368,45 @@ async def obj_update_to_code(config, action_id, template_arg, args): return await action_to_code( widgets, do_update, action_id, template_arg, args, config ) + + +def validate_refresh_config(config): + for w in config: + refreshed_widgets.add(w[CONF_ID]) + return config + + +@automation.register_action( + "lvgl.widget.refresh", + ObjUpdateAction, + cv.All( + cv.ensure_list( + cv.maybe_simple_value( + { + cv.Required(CONF_ID): cv.use_id(lv_obj_t), + }, + key=CONF_ID, + ) + ), + validate_refresh_config, + ), +) +async def obj_refresh_to_code(config, action_id, template_arg, args): + widget = await get_widgets(config) + + async def do_refresh(widget: Widget): + # only update style properties that might have changed, i.e. are templated + config = {k: v for k, v in widget.config.items() if isinstance(v, Lambda)} + await set_obj_properties(widget, config) + # must pass all widget-specific options here, even if not templated, but only do so if at least one is + # templated. First filter out common style properties. + config = {k: v for k, v in widget.config.items() if k not in ALL_STYLES} + if any(isinstance(v, Lambda) for v in config.values()): + await widget.type.to_code(widget, config) + if ( + widget.type.w_type.value_property is not None + and widget.type.w_type.value_property in config + ): + lv.event_send(widget.obj, UPDATE_EVENT, nullptr) + + return await action_to_code(widget, do_refresh, action_id, template_arg, args) diff --git a/tests/components/lvgl/lvgl-package.yaml b/tests/components/lvgl/lvgl-package.yaml index 6fd0b5e3c4..db55da9225 100644 --- a/tests/components/lvgl/lvgl-package.yaml +++ b/tests/components/lvgl/lvgl-package.yaml @@ -212,7 +212,7 @@ lvgl: - animimg: height: 60 id: anim_img - src: [cat_image, dog_image] + src: !lambda "return {dog_image, cat_image};" repeat_count: 10 duration: 1s auto_start: true @@ -224,6 +224,7 @@ lvgl: id: anim_img src: !lambda "return {dog_image, cat_image};" duration: 2s + - lvgl.widget.refresh: anim_img - label: on_boot: lvgl.label.update: From ad99d7fb4535692a0d68706b81aa941a67229f9b Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 13:31:16 +1000 Subject: [PATCH 095/102] [image] Support the other Pictogrammers icon sets `memory:` and `mdil:` (#8676) --- esphome/components/image/__init__.py | 52 +++++++++++++++++++--------- tests/components/image/common.yaml | 15 ++++++++ 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/esphome/components/image/__init__.py b/esphome/components/image/__init__.py index fbf61c105c..5d593ac3d4 100644 --- a/esphome/components/image/__init__.py +++ b/esphome/components/image/__init__.py @@ -286,9 +286,18 @@ CONF_TRANSPARENCY = "transparency" IMAGE_DOWNLOAD_TIMEOUT = 30 # seconds SOURCE_LOCAL = "local" -SOURCE_MDI = "mdi" SOURCE_WEB = "web" +SOURCE_MDI = "mdi" +SOURCE_MDIL = "mdil" +SOURCE_MEMORY = "memory" + +MDI_SOURCES = { + SOURCE_MDI: "https://raw.githubusercontent.com/Templarian/MaterialDesign/master/svg/", + SOURCE_MDIL: "https://raw.githubusercontent.com/Pictogrammers/MaterialDesignLight/refs/heads/master/svg/", + SOURCE_MEMORY: "https://raw.githubusercontent.com/Pictogrammers/Memory/refs/heads/main/src/svg/", +} + Image_ = image_ns.class_("Image") INSTANCE_TYPE = Image_ @@ -313,12 +322,12 @@ def download_file(url, path): return str(path) -def download_mdi(value): +def download_gh_svg(value, source): mdi_id = value[CONF_ICON] if isinstance(value, dict) else value - base_dir = external_files.compute_local_file_dir(DOMAIN) / "mdi" + base_dir = external_files.compute_local_file_dir(DOMAIN) / source path = base_dir / f"{mdi_id}.svg" - url = f"https://raw.githubusercontent.com/Templarian/MaterialDesign/master/svg/{mdi_id}.svg" + url = MDI_SOURCES[source] + mdi_id + ".svg" return download_file(url, path) @@ -353,12 +362,12 @@ def validate_cairosvg_installed(): def validate_file_shorthand(value): value = cv.string_strict(value) - if value.startswith("mdi:"): - match = re.search(r"mdi:([a-zA-Z0-9\-]+)", value) + parts = value.strip().split(":") + if len(parts) == 2 and parts[0] in MDI_SOURCES: + match = re.match(r"[a-zA-Z0-9\-]+", parts[1]) if match is None: - raise cv.Invalid("Could not parse mdi icon name.") - icon = match.group(1) - return download_mdi(icon) + raise cv.Invalid(f"Could not parse mdi icon name from '{value}'.") + return download_gh_svg(parts[1], parts[0]) if value.startswith("http://") or value.startswith("https://"): return download_image(value) @@ -374,12 +383,20 @@ LOCAL_SCHEMA = cv.All( local_path, ) -MDI_SCHEMA = cv.All( - { - cv.Required(CONF_ICON): cv.string, - }, - download_mdi, -) + +def mdi_schema(source): + def validate_mdi(value): + return download_gh_svg(value, source) + + return cv.All( + cv.Schema( + { + cv.Required(CONF_ICON): cv.string, + } + ), + validate_mdi, + ) + WEB_SCHEMA = cv.All( { @@ -388,12 +405,13 @@ WEB_SCHEMA = cv.All( download_image, ) + TYPED_FILE_SCHEMA = cv.typed_schema( { SOURCE_LOCAL: LOCAL_SCHEMA, - SOURCE_MDI: MDI_SCHEMA, SOURCE_WEB: WEB_SCHEMA, - }, + } + | {source: mdi_schema(source) for source in MDI_SOURCES}, key=CONF_SOURCE, ) diff --git a/tests/components/image/common.yaml b/tests/components/image/common.yaml index 4c9b9ed670..864ca41c44 100644 --- a/tests/components/image/common.yaml +++ b/tests/components/image/common.yaml @@ -69,3 +69,18 @@ image: - id: another_alert_icon file: mdi:alert-outline type: BINARY + - file: mdil:arrange-bring-to-front + id: mdil_id + resize: 50x50 + type: binary + transparency: chroma_key + - file: mdi:beer + id: mdi_id + resize: 50x50 + type: binary + transparency: chroma_key + - file: memory:alert-octagon + id: memory_id + resize: 50x50 + type: binary + transparency: chroma_key From e7a2b395fd0c822ec0be6c5ddd8cd5a4d4db5c70 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 14:15:46 +1000 Subject: [PATCH 096/102] [uart] Add packet_transport platform (#8214) Co-authored-by: Faidon Liambotis Co-authored-by: clydeps Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- CODEOWNERS | 1 + .../uart/packet_transport/__init__.py | 20 +++++ .../uart/packet_transport/uart_transport.cpp | 88 +++++++++++++++++++ .../uart/packet_transport/uart_transport.h | 41 +++++++++ tests/components/uart/test.esp32-idf.yaml | 3 + 5 files changed, 153 insertions(+) create mode 100644 esphome/components/uart/packet_transport/__init__.py create mode 100644 esphome/components/uart/packet_transport/uart_transport.cpp create mode 100644 esphome/components/uart/packet_transport/uart_transport.h diff --git a/CODEOWNERS b/CODEOWNERS index 46e0e6c579..d6381f9799 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -468,6 +468,7 @@ esphome/components/tuya/switch/* @jesserockz esphome/components/tuya/text_sensor/* @dentra esphome/components/uart/* @esphome/core esphome/components/uart/button/* @ssieb +esphome/components/uart/packet_transport/* @clydebarrow esphome/components/udp/* @clydebarrow esphome/components/ufire_ec/* @pvizeli esphome/components/ufire_ise/* @pvizeli diff --git a/esphome/components/uart/packet_transport/__init__.py b/esphome/components/uart/packet_transport/__init__.py new file mode 100644 index 0000000000..58c6296e2f --- /dev/null +++ b/esphome/components/uart/packet_transport/__init__.py @@ -0,0 +1,20 @@ +from esphome.components.packet_transport import ( + PacketTransport, + new_packet_transport, + transport_schema, +) +from esphome.cpp_types import PollingComponent + +from .. import UART_DEVICE_SCHEMA, register_uart_device, uart_ns + +CODEOWNERS = ["@clydebarrow"] +DEPENDENCIES = ["uart"] + +UARTTransport = uart_ns.class_("UARTTransport", PacketTransport, PollingComponent) + +CONFIG_SCHEMA = transport_schema(UARTTransport).extend(UART_DEVICE_SCHEMA) + + +async def to_code(config): + var, _ = await new_packet_transport(config) + await register_uart_device(var, config) diff --git a/esphome/components/uart/packet_transport/uart_transport.cpp b/esphome/components/uart/packet_transport/uart_transport.cpp new file mode 100644 index 0000000000..aa11ae0772 --- /dev/null +++ b/esphome/components/uart/packet_transport/uart_transport.cpp @@ -0,0 +1,88 @@ +#include "esphome/core/log.h" +#include "esphome/core/application.h" +#include "uart_transport.h" + +namespace esphome { +namespace uart { + +static const char *const TAG = "uart_transport"; + +void UARTTransport::loop() { + PacketTransport::loop(); + + while (this->parent_->available()) { + uint8_t byte; + if (!this->parent_->read_byte(&byte)) { + ESP_LOGW(TAG, "Failed to read byte from UART"); + return; + } + if (byte == FLAG_BYTE) { + if (this->rx_started_ && this->receive_buffer_.size() > 6) { + auto len = this->receive_buffer_.size(); + auto crc = crc16(this->receive_buffer_.data(), len - 2); + if (crc != (this->receive_buffer_[len - 2] | (this->receive_buffer_[len - 1] << 8))) { + ESP_LOGD(TAG, "CRC mismatch, discarding packet"); + this->rx_started_ = false; + this->receive_buffer_.clear(); + continue; + } + this->receive_buffer_.resize(len - 2); + this->process_(this->receive_buffer_); + this->rx_started_ = false; + } else { + this->rx_started_ = true; + } + this->receive_buffer_.clear(); + this->rx_control_ = false; + continue; + } + if (!this->rx_started_) + continue; + if (byte == CONTROL_BYTE) { + this->rx_control_ = true; + continue; + } + if (this->rx_control_) { + byte ^= 0x20; + this->rx_control_ = false; + } + if (this->receive_buffer_.size() == MAX_PACKET_SIZE) { + ESP_LOGD(TAG, "Packet too large, discarding"); + this->rx_started_ = false; + this->receive_buffer_.clear(); + continue; + } + this->receive_buffer_.push_back(byte); + } +} + +void UARTTransport::update() { + this->updated_ = true; + this->resend_data_ = true; + PacketTransport::update(); +} + +/** + * Write a byte to the UART bus. If the byte is a flag or control byte, it will be escaped. + * @param byte The byte to write. + */ +void UARTTransport::write_byte_(uint8_t byte) const { + if (byte == FLAG_BYTE || byte == CONTROL_BYTE) { + this->parent_->write_byte(CONTROL_BYTE); + byte ^= 0x20; + } + this->parent_->write_byte(byte); +} + +void UARTTransport::send_packet(std::vector &buf) const { + this->parent_->write_byte(FLAG_BYTE); + for (uint8_t byte : buf) { + this->write_byte_(byte); + } + auto crc = crc16(buf.data(), buf.size()); + this->write_byte_(crc & 0xFF); + this->write_byte_(crc >> 8); + this->parent_->write_byte(FLAG_BYTE); +} +} // namespace uart +} // namespace esphome diff --git a/esphome/components/uart/packet_transport/uart_transport.h b/esphome/components/uart/packet_transport/uart_transport.h new file mode 100644 index 0000000000..db32859452 --- /dev/null +++ b/esphome/components/uart/packet_transport/uart_transport.h @@ -0,0 +1,41 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/components/packet_transport/packet_transport.h" +#include +#include "../uart.h" + +namespace esphome { +namespace uart { + +/** + * A transport protocol for sending and receiving packets over a UART connection. + * The protocol is based on Asynchronous HDLC framing. (https://en.wikipedia.org/wiki/High-Level_Data_Link_Control) + * There are two special bytes: FLAG_BYTE and CONTROL_BYTE. + * A 16-bit CRC is appended to the packet, then + * the protocol wraps the resulting data between FLAG_BYTEs. + * Any occurrence of FLAG_BYTE or CONTROL_BYTE in the data is escaped by emitting CONTROL_BYTE followed by the byte + * XORed with 0x20. + */ +static const uint16_t MAX_PACKET_SIZE = 508; +static const uint8_t FLAG_BYTE = 0x7E; +static const uint8_t CONTROL_BYTE = 0x7D; + +class UARTTransport : public packet_transport::PacketTransport, public UARTDevice { + public: + void loop() override; + void update() override; + float get_setup_priority() const override { return setup_priority::PROCESSOR; } + + protected: + void write_byte_(uint8_t byte) const; + void send_packet(std::vector &buf) const override; + bool should_send() override { return true; }; + size_t get_max_packet_size() override { return MAX_PACKET_SIZE; } + std::vector receive_buffer_{}; + bool rx_started_{}; + bool rx_control_{}; +}; + +} // namespace uart +} // namespace esphome diff --git a/tests/components/uart/test.esp32-idf.yaml b/tests/components/uart/test.esp32-idf.yaml index bef5b460ab..5a0ed7eba7 100644 --- a/tests/components/uart/test.esp32-idf.yaml +++ b/tests/components/uart/test.esp32-idf.yaml @@ -13,3 +13,6 @@ uart: rx_buffer_size: 512 parity: EVEN stop_bits: 2 + +packet_transport: + - platform: uart From b8d83d07651418e6f4fe7755078d729f328fc998 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 14:31:37 +1000 Subject: [PATCH 097/102] [debug] Show source of last software reboot (#8595) --- esphome/components/debug/debug_component.cpp | 1 + esphome/components/debug/debug_component.h | 1 + esphome/components/debug/debug_esp32.cpp | 30 ++++++++++++++++++-- esphome/core/application.cpp | 1 + esphome/core/application.h | 4 +++ esphome/core/scheduler.cpp | 3 ++ 6 files changed, 37 insertions(+), 3 deletions(-) diff --git a/esphome/components/debug/debug_component.cpp b/esphome/components/debug/debug_component.cpp index fcded02ba5..5bcc676247 100644 --- a/esphome/components/debug/debug_component.cpp +++ b/esphome/components/debug/debug_component.cpp @@ -1,6 +1,7 @@ #include "debug_component.h" #include +#include "esphome/core/application.h" #include "esphome/core/log.h" #include "esphome/core/hal.h" #include "esphome/core/helpers.h" diff --git a/esphome/components/debug/debug_component.h b/esphome/components/debug/debug_component.h index f887d52864..a55cc7bf44 100644 --- a/esphome/components/debug/debug_component.h +++ b/esphome/components/debug/debug_component.h @@ -34,6 +34,7 @@ class DebugComponent : public PollingComponent { #endif void set_loop_time_sensor(sensor::Sensor *loop_time_sensor) { loop_time_sensor_ = loop_time_sensor; } #ifdef USE_ESP32 + void on_shutdown() override; void set_psram_sensor(sensor::Sensor *psram_sensor) { this->psram_sensor_ = psram_sensor; } #endif // USE_ESP32 void set_cpu_frequency_sensor(sensor::Sensor *cpu_frequency_sensor) { diff --git a/esphome/components/debug/debug_esp32.cpp b/esphome/components/debug/debug_esp32.cpp index 662e60501d..999cb927b3 100644 --- a/esphome/components/debug/debug_esp32.cpp +++ b/esphome/components/debug/debug_esp32.cpp @@ -1,6 +1,7 @@ #include "debug_component.h" #ifdef USE_ESP32 +#include "esphome/core/application.h" #include "esphome/core/log.h" #include "esphome/core/hal.h" #include @@ -10,12 +11,12 @@ #include #include +#include + #ifdef USE_ARDUINO #include #endif -#include - namespace esphome { namespace debug { @@ -42,16 +43,39 @@ static const char *const RESET_REASONS[] = { "CPU lock up", }; +static const char *const REBOOT_KEY = "reboot_source"; +static const size_t REBOOT_MAX_LEN = 24; + +// on shutdown, store the source of the reboot request +void DebugComponent::on_shutdown() { + auto *component = App.get_current_component(); + char buffer[REBOOT_MAX_LEN]{}; + auto pref = global_preferences->make_preference(REBOOT_MAX_LEN, fnv1_hash(REBOOT_KEY + App.get_name())); + if (component != nullptr) { + strncpy(buffer, component->get_component_source(), REBOOT_MAX_LEN - 1); + } + ESP_LOGD(TAG, "Storing reboot source: %s", buffer); + pref.save(&buffer); + global_preferences->sync(); +} + std::string DebugComponent::get_reset_reason_() { std::string reset_reason; unsigned reason = esp_reset_reason(); if (reason < sizeof(RESET_REASONS) / sizeof(RESET_REASONS[0])) { reset_reason = RESET_REASONS[reason]; + if (reason == ESP_RST_SW) { + auto pref = global_preferences->make_preference(REBOOT_MAX_LEN, fnv1_hash(REBOOT_KEY + App.get_name())); + char buffer[REBOOT_MAX_LEN]{}; + if (pref.load(&buffer)) { + reset_reason = "Reboot request from " + std::string(buffer); + } + } } else { reset_reason = "unknown source"; } ESP_LOGD(TAG, "Reset Reason: %s", reset_reason.c_str()); - return "Reset by " + reset_reason; + return reset_reason; } static const char *const WAKEUP_CAUSES[] = { diff --git a/esphome/core/application.cpp b/esphome/core/application.cpp index a4550bcd9e..3f5a283fd8 100644 --- a/esphome/core/application.cpp +++ b/esphome/core/application.cpp @@ -70,6 +70,7 @@ void Application::loop() { this->feed_wdt(); for (Component *component : this->looping_components_) { { + this->set_current_component(component); WarnIfComponentBlockingGuard guard{component}; component->call(); } diff --git a/esphome/core/application.h b/esphome/core/application.h index 462beb1f25..e64e2b7655 100644 --- a/esphome/core/application.h +++ b/esphome/core/application.h @@ -97,6 +97,9 @@ class Application { this->compilation_time_ = compilation_time; } + void set_current_component(Component *component) { this->current_component_ = component; } + Component *get_current_component() { return this->current_component_; } + #ifdef USE_BINARY_SENSOR void register_binary_sensor(binary_sensor::BinarySensor *binary_sensor) { this->binary_sensors_.push_back(binary_sensor); @@ -547,6 +550,7 @@ class Application { uint32_t loop_interval_{16}; size_t dump_config_at_{SIZE_MAX}; uint32_t app_state_{0}; + Component *current_component_{nullptr}; }; /// Global storage of Application pointer - only one Application can exist. diff --git a/esphome/core/scheduler.cpp b/esphome/core/scheduler.cpp index 7e83b3b705..b4f617d405 100644 --- a/esphome/core/scheduler.cpp +++ b/esphome/core/scheduler.cpp @@ -1,4 +1,6 @@ #include "scheduler.h" + +#include "application.h" #include "esphome/core/defines.h" #include "esphome/core/log.h" #include "esphome/core/helpers.h" @@ -215,6 +217,7 @@ void HOT Scheduler::call() { this->pop_raw_(); continue; } + App.set_current_component(item->component); #ifdef ESPHOME_DEBUG_SCHEDULER ESP_LOGV(TAG, "Running %s '%s/%s' with interval=%" PRIu32 " next_execution=%" PRIu64 " (now=%" PRIu64 ")", From 3b8a5db97c67abad90c6028b05c7d1a45b3b2a9b Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 14:48:13 +1000 Subject: [PATCH 098/102] [syslog] Implement logging via syslog (#8637) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- CODEOWNERS | 1 + esphome/components/syslog/__init__.py | 41 ++++++++++++++++ esphome/components/syslog/esphome_syslog.cpp | 49 +++++++++++++++++++ esphome/components/syslog/esphome_syslog.h | 25 ++++++++++ tests/components/syslog/common.yaml | 15 ++++++ tests/components/syslog/test.bk72xx-ard.yaml | 1 + tests/components/syslog/test.esp32-ard.yaml | 1 + .../components/syslog/test.esp32-c3-ard.yaml | 1 + .../components/syslog/test.esp32-c3-idf.yaml | 1 + tests/components/syslog/test.esp32-idf.yaml | 1 + tests/components/syslog/test.esp8266-ard.yaml | 1 + tests/components/syslog/test.host.yaml | 4 ++ tests/components/syslog/test.rp2040-ard.yaml | 1 + 13 files changed, 142 insertions(+) create mode 100644 esphome/components/syslog/__init__.py create mode 100644 esphome/components/syslog/esphome_syslog.cpp create mode 100644 esphome/components/syslog/esphome_syslog.h create mode 100644 tests/components/syslog/common.yaml create mode 100644 tests/components/syslog/test.bk72xx-ard.yaml create mode 100644 tests/components/syslog/test.esp32-ard.yaml create mode 100644 tests/components/syslog/test.esp32-c3-ard.yaml create mode 100644 tests/components/syslog/test.esp32-c3-idf.yaml create mode 100644 tests/components/syslog/test.esp32-idf.yaml create mode 100644 tests/components/syslog/test.esp8266-ard.yaml create mode 100644 tests/components/syslog/test.host.yaml create mode 100644 tests/components/syslog/test.rp2040-ard.yaml diff --git a/CODEOWNERS b/CODEOWNERS index d6381f9799..29919b6d70 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -429,6 +429,7 @@ esphome/components/sun/* @OttoWinter esphome/components/sun_gtil2/* @Mat931 esphome/components/switch/* @esphome/core esphome/components/switch/binary_sensor/* @ssieb +esphome/components/syslog/* @clydebarrow esphome/components/t6615/* @tylermenezes esphome/components/tc74/* @sethgirvan esphome/components/tca9548a/* @andreashergert1984 diff --git a/esphome/components/syslog/__init__.py b/esphome/components/syslog/__init__.py new file mode 100644 index 0000000000..80b79d2040 --- /dev/null +++ b/esphome/components/syslog/__init__.py @@ -0,0 +1,41 @@ +import esphome.codegen as cg +from esphome.components import udp +from esphome.components.logger import LOG_LEVELS, is_log_level +from esphome.components.time import RealTimeClock +from esphome.components.udp import CONF_UDP_ID +import esphome.config_validation as cv +from esphome.const import CONF_ID, CONF_LEVEL, CONF_PORT, CONF_TIME_ID +from esphome.cpp_types import Component, Parented + +CODEOWNERS = ["@clydebarrow"] + +DEPENDENCIES = ["udp", "logger", "time"] + +syslog_ns = cg.esphome_ns.namespace("syslog") +Syslog = syslog_ns.class_("Syslog", Component, Parented.template(udp.UDPComponent)) + +CONF_STRIP = "strip" +CONF_FACILITY = "facility" +CONFIG_SCHEMA = udp.UDP_SCHEMA.extend( + { + cv.GenerateID(): cv.declare_id(Syslog), + cv.GenerateID(CONF_TIME_ID): cv.use_id(RealTimeClock), + cv.Optional(CONF_PORT, default=514): cv.port, + cv.Optional(CONF_LEVEL, default="DEBUG"): is_log_level, + cv.Optional(CONF_STRIP, default=True): cv.boolean, + cv.Optional(CONF_FACILITY, default=16): cv.int_range(0, 23), + } +) + + +async def to_code(config): + parent = await cg.get_variable(config[CONF_UDP_ID]) + time = await cg.get_variable(config[CONF_TIME_ID]) + cg.add(parent.set_broadcast_port(config[CONF_PORT])) + cg.add(parent.set_should_broadcast()) + level = LOG_LEVELS[config[CONF_LEVEL]] + var = cg.new_Pvariable(config[CONF_ID], level, time) + await cg.register_component(var, config) + await cg.register_parented(var, parent) + cg.add(var.set_strip(config[CONF_STRIP])) + cg.add(var.set_facility(config[CONF_FACILITY])) diff --git a/esphome/components/syslog/esphome_syslog.cpp b/esphome/components/syslog/esphome_syslog.cpp new file mode 100644 index 0000000000..9d2cda549b --- /dev/null +++ b/esphome/components/syslog/esphome_syslog.cpp @@ -0,0 +1,49 @@ +#include "esphome_syslog.h" + +#include "esphome/components/logger/logger.h" +#include "esphome/core/application.h" +#include "esphome/core/time.h" + +namespace esphome { +namespace syslog { + +// Map log levels to syslog severity using an array, indexed by ESPHome log level (1-7) +constexpr int LOG_LEVEL_TO_SYSLOG_SEVERITY[] = { + 3, // NONE + 3, // ERROR + 4, // WARN + 5, // INFO + 6, // CONFIG + 7, // DEBUG + 7, // VERBOSE + 7 // VERY_VERBOSE +}; + +void Syslog::setup() { + logger::global_logger->add_on_log_callback( + [this](int level, const char *tag, const char *message) { this->log_(level, tag, message); }); +} + +void Syslog::log_(const int level, const char *tag, const char *message) const { + if (level > this->log_level_) + return; + // Syslog PRI calculation: facility * 8 + severity + int severity = 7; + if ((unsigned) level <= 7) { + severity = LOG_LEVEL_TO_SYSLOG_SEVERITY[level]; + } + int pri = this->facility_ * 8 + severity; + auto timestamp = this->time_->now().strftime("%b %d %H:%M:%S"); + unsigned len = strlen(message); + // remove color formatting + if (this->strip_ && message[0] == 0x1B && len > 11) { + message += 7; + len -= 11; + } + + auto data = str_sprintf("<%d>%s %s %s: %.*s", pri, timestamp.c_str(), App.get_name().c_str(), tag, len, message); + this->parent_->send_packet((const uint8_t *) data.data(), data.size()); +} + +} // namespace syslog +} // namespace esphome diff --git a/esphome/components/syslog/esphome_syslog.h b/esphome/components/syslog/esphome_syslog.h new file mode 100644 index 0000000000..3fa077b466 --- /dev/null +++ b/esphome/components/syslog/esphome_syslog.h @@ -0,0 +1,25 @@ +#pragma once +#include "esphome/core/component.h" +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" +#include "esphome/components/udp/udp_component.h" +#include "esphome/components/time/real_time_clock.h" + +namespace esphome { +namespace syslog { +class Syslog : public Component, public Parented { + public: + Syslog(int level, time::RealTimeClock *time) : log_level_(level), time_(time) {} + void setup() override; + void set_strip(bool strip) { this->strip_ = strip; } + void set_facility(int facility) { this->facility_ = facility; } + + protected: + int log_level_; + void log_(int level, const char *tag, const char *message) const; + time::RealTimeClock *time_; + bool strip_{true}; + int facility_{16}; +}; +} // namespace syslog +} // namespace esphome diff --git a/tests/components/syslog/common.yaml b/tests/components/syslog/common.yaml new file mode 100644 index 0000000000..cd6e63c9ec --- /dev/null +++ b/tests/components/syslog/common.yaml @@ -0,0 +1,15 @@ +wifi: + ssid: MySSID + password: password1 + +udp: + addresses: ["239.0.60.53"] + +time: + platform: host + +syslog: + port: 514 + strip: true + level: info + facility: 16 diff --git a/tests/components/syslog/test.bk72xx-ard.yaml b/tests/components/syslog/test.bk72xx-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.bk72xx-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-ard.yaml b/tests/components/syslog/test.esp32-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-c3-ard.yaml b/tests/components/syslog/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-c3-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-c3-idf.yaml b/tests/components/syslog/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-c3-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-idf.yaml b/tests/components/syslog/test.esp32-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp8266-ard.yaml b/tests/components/syslog/test.esp8266-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp8266-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.host.yaml b/tests/components/syslog/test.host.yaml new file mode 100644 index 0000000000..e735c37e4d --- /dev/null +++ b/tests/components/syslog/test.host.yaml @@ -0,0 +1,4 @@ +packages: + common: !include common.yaml + +wifi: !remove diff --git a/tests/components/syslog/test.rp2040-ard.yaml b/tests/components/syslog/test.rp2040-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.rp2040-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml From 6f35d0ac88a2c8e0f44645be3b99819c6e5d8138 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 17:56:30 +1000 Subject: [PATCH 099/102] [cst226] Add support for cst226 binary sensor (#8381) --- .../cst226/binary_sensor/__init__.py | 28 ++++++++++++++++ .../cst226/binary_sensor/cs226_button.h | 22 +++++++++++++ .../cst226/binary_sensor/cstt6_button.cpp | 19 +++++++++++ .../cst226/touchscreen/cst226_touchscreen.cpp | 32 ++++++++++++++++--- .../cst226/touchscreen/cst226_touchscreen.h | 20 ++++++------ tests/components/cst226/common.yaml | 6 ++++ 6 files changed, 112 insertions(+), 15 deletions(-) create mode 100644 esphome/components/cst226/binary_sensor/__init__.py create mode 100644 esphome/components/cst226/binary_sensor/cs226_button.h create mode 100644 esphome/components/cst226/binary_sensor/cstt6_button.cpp diff --git a/esphome/components/cst226/binary_sensor/__init__.py b/esphome/components/cst226/binary_sensor/__init__.py new file mode 100644 index 0000000000..d95f0d2b4d --- /dev/null +++ b/esphome/components/cst226/binary_sensor/__init__.py @@ -0,0 +1,28 @@ +import esphome.codegen as cg +from esphome.components import binary_sensor +import esphome.config_validation as cv + +from .. import cst226_ns +from ..touchscreen import CST226ButtonListener, CST226Touchscreen + +CONF_CST226_ID = "cst226_id" + +CST226Button = cst226_ns.class_( + "CST226Button", + binary_sensor.BinarySensor, + cg.Component, + CST226ButtonListener, + cg.Parented.template(CST226Touchscreen), +) + +CONFIG_SCHEMA = binary_sensor.binary_sensor_schema(CST226Button).extend( + { + cv.GenerateID(CONF_CST226_ID): cv.use_id(CST226Touchscreen), + } +) + + +async def to_code(config): + var = await binary_sensor.new_binary_sensor(config) + await cg.register_component(var, config) + await cg.register_parented(var, config[CONF_CST226_ID]) diff --git a/esphome/components/cst226/binary_sensor/cs226_button.h b/esphome/components/cst226/binary_sensor/cs226_button.h new file mode 100644 index 0000000000..6d409df04f --- /dev/null +++ b/esphome/components/cst226/binary_sensor/cs226_button.h @@ -0,0 +1,22 @@ +#pragma once + +#include "esphome/components/binary_sensor/binary_sensor.h" +#include "../touchscreen/cst226_touchscreen.h" +#include "esphome/core/helpers.h" + +namespace esphome { +namespace cst226 { + +class CST226Button : public binary_sensor::BinarySensor, + public Component, + public CST226ButtonListener, + public Parented { + public: + void setup() override; + void dump_config() override; + + void update_button(bool state) override; +}; + +} // namespace cst226 +} // namespace esphome diff --git a/esphome/components/cst226/binary_sensor/cstt6_button.cpp b/esphome/components/cst226/binary_sensor/cstt6_button.cpp new file mode 100644 index 0000000000..c481ce5d57 --- /dev/null +++ b/esphome/components/cst226/binary_sensor/cstt6_button.cpp @@ -0,0 +1,19 @@ +#include "cs226_button.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace cst226 { + +static const char *const TAG = "CST226.binary_sensor"; + +void CST226Button::setup() { + this->parent_->register_button_listener(this); + this->publish_initial_state(false); +} + +void CST226Button::dump_config() { LOG_BINARY_SENSOR("", "CST226 Button", this); } + +void CST226Button::update_button(bool state) { this->publish_state(state); } + +} // namespace cst226 +} // namespace esphome diff --git a/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp b/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp index a25859fe17..fa8cd9b057 100644 --- a/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp +++ b/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp @@ -3,8 +3,10 @@ namespace esphome { namespace cst226 { +static const char *const TAG = "cst226.touchscreen"; + void CST226Touchscreen::setup() { - esph_log_config(TAG, "Setting up CST226 Touchscreen..."); + ESP_LOGCONFIG(TAG, "Setting up CST226 Touchscreen..."); if (this->reset_pin_ != nullptr) { this->reset_pin_->setup(); this->reset_pin_->digital_write(true); @@ -26,6 +28,11 @@ void CST226Touchscreen::update_touches() { return; } this->status_clear_warning(); + if (data[0] == 0x83 && data[1] == 0x17 && data[5] == 0x80) { + this->update_button_state_(true); + return; + } + this->update_button_state_(false); if (data[6] != 0xAB || data[0] == 0xAB || data[5] == 0x80) { this->skip_update_ = true; return; @@ -43,13 +50,21 @@ void CST226Touchscreen::update_touches() { int16_t y = (data[index + 2] << 4) | (data[index + 3] & 0x0F); int16_t z = data[index + 4]; this->add_raw_touch_position_(id, x, y, z); - esph_log_v(TAG, "Read touch %d: %d/%d", id, x, y); + ESP_LOGV(TAG, "Read touch %d: %d/%d", id, x, y); index += 5; if (i == 0) index += 2; } } +bool CST226Touchscreen::read16_(uint16_t addr, uint8_t *data, size_t len) { + if (this->read_register16(addr, data, len) != i2c::ERROR_OK) { + ESP_LOGE(TAG, "Read data from 0x%04X failed", addr); + this->mark_failed(); + return false; + } + return true; +} void CST226Touchscreen::continue_setup_() { uint8_t buffer[8]; if (this->interrupt_pin_ != nullptr) { @@ -58,7 +73,7 @@ void CST226Touchscreen::continue_setup_() { } buffer[0] = 0xD1; if (this->write_register16(0xD1, buffer, 1) != i2c::ERROR_OK) { - esph_log_e(TAG, "Write byte to 0xD1 failed"); + ESP_LOGE(TAG, "Write byte to 0xD1 failed"); this->mark_failed(); return; } @@ -66,7 +81,7 @@ void CST226Touchscreen::continue_setup_() { if (this->read16_(0xD204, buffer, 4)) { uint16_t chip_id = buffer[2] + (buffer[3] << 8); uint16_t project_id = buffer[0] + (buffer[1] << 8); - esph_log_config(TAG, "Chip ID %X, project ID %x", chip_id, project_id); + ESP_LOGCONFIG(TAG, "Chip ID %X, project ID %x", chip_id, project_id); } if (this->x_raw_max_ == 0 || this->y_raw_max_ == 0) { if (this->read16_(0xD1F8, buffer, 4)) { @@ -80,7 +95,14 @@ void CST226Touchscreen::continue_setup_() { } } this->setup_complete_ = true; - esph_log_config(TAG, "CST226 Touchscreen setup complete"); + ESP_LOGCONFIG(TAG, "CST226 Touchscreen setup complete"); +} +void CST226Touchscreen::update_button_state_(bool state) { + if (this->button_touched_ == state) + return; + this->button_touched_ = state; + for (auto *listener : this->button_listeners_) + listener->update_button(state); } void CST226Touchscreen::dump_config() { diff --git a/esphome/components/cst226/touchscreen/cst226_touchscreen.h b/esphome/components/cst226/touchscreen/cst226_touchscreen.h index 9f518e5068..c744e51fec 100644 --- a/esphome/components/cst226/touchscreen/cst226_touchscreen.h +++ b/esphome/components/cst226/touchscreen/cst226_touchscreen.h @@ -9,10 +9,13 @@ namespace esphome { namespace cst226 { -static const char *const TAG = "cst226.touchscreen"; - static const uint8_t CST226_REG_STATUS = 0x00; +class CST226ButtonListener { + public: + virtual void update_button(bool state) = 0; +}; + class CST226Touchscreen : public touchscreen::Touchscreen, public i2c::I2CDevice { public: void setup() override; @@ -22,22 +25,19 @@ class CST226Touchscreen : public touchscreen::Touchscreen, public i2c::I2CDevice void set_interrupt_pin(InternalGPIOPin *pin) { this->interrupt_pin_ = pin; } void set_reset_pin(GPIOPin *pin) { this->reset_pin_ = pin; } bool can_proceed() override { return this->setup_complete_ || this->is_failed(); } + void register_button_listener(CST226ButtonListener *listener) { this->button_listeners_.push_back(listener); } protected: - bool read16_(uint16_t addr, uint8_t *data, size_t len) { - if (this->read_register16(addr, data, len) != i2c::ERROR_OK) { - esph_log_e(TAG, "Read data from 0x%04X failed", addr); - this->mark_failed(); - return false; - } - return true; - } + bool read16_(uint16_t addr, uint8_t *data, size_t len); void continue_setup_(); + void update_button_state_(bool state); InternalGPIOPin *interrupt_pin_{}; GPIOPin *reset_pin_{}; uint8_t chip_id_{}; bool setup_complete_{}; + std::vector button_listeners_; + bool button_touched_{}; }; } // namespace cst226 diff --git a/tests/components/cst226/common.yaml b/tests/components/cst226/common.yaml index c12d8d872c..d0b8ea3a86 100644 --- a/tests/components/cst226/common.yaml +++ b/tests/components/cst226/common.yaml @@ -23,3 +23,9 @@ touchscreen: interrupt_pin: ${interrupt_pin} reset_pin: ${reset_pin} +binary_sensor: + - id: cst226_touch + platform: cst226 + on_press: + then: + - component.update: ts_cst226 From 8bbc509b0b45d0925d5b64b80d586d2839343f1e Mon Sep 17 00:00:00 2001 From: Edward Firmo <94725493+edwardtfn@users.noreply.github.com> Date: Mon, 5 May 2025 10:08:16 +0200 Subject: [PATCH 100/102] [nextion] Adds a command pacer with `command_spacing` attribute (#7948) Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com> --- esphome/components/nextion/base_component.py | 37 ++++++++------- esphome/components/nextion/display.py | 13 ++++- esphome/components/nextion/nextion.cpp | 40 +++++++++++++++- esphome/components/nextion/nextion.h | 50 ++++++++++++++++++++ tests/components/nextion/common.yaml | 1 + 5 files changed, 120 insertions(+), 21 deletions(-) diff --git a/esphome/components/nextion/base_component.py b/esphome/components/nextion/base_component.py index 9708379861..0058d957dc 100644 --- a/esphome/components/nextion/base_component.py +++ b/esphome/components/nextion/base_component.py @@ -7,28 +7,29 @@ from esphome.const import CONF_BACKGROUND_COLOR, CONF_FOREGROUND_COLOR, CONF_VIS from . import CONF_NEXTION_ID, Nextion -CONF_VARIABLE_NAME = "variable_name" +CONF_AUTO_WAKE_ON_TOUCH = "auto_wake_on_touch" +CONF_BACKGROUND_PRESSED_COLOR = "background_pressed_color" +CONF_COMMAND_SPACING = "command_spacing" CONF_COMPONENT_NAME = "component_name" -CONF_WAVE_CHANNEL_ID = "wave_channel_id" -CONF_WAVE_MAX_VALUE = "wave_max_value" -CONF_PRECISION = "precision" -CONF_WAVEFORM_SEND_LAST_VALUE = "waveform_send_last_value" -CONF_TFT_URL = "tft_url" +CONF_EXIT_REPARSE_ON_START = "exit_reparse_on_start" +CONF_FONT_ID = "font_id" +CONF_FOREGROUND_PRESSED_COLOR = "foreground_pressed_color" +CONF_ON_BUFFER_OVERFLOW = "on_buffer_overflow" +CONF_ON_PAGE = "on_page" +CONF_ON_SETUP = "on_setup" CONF_ON_SLEEP = "on_sleep" CONF_ON_WAKE = "on_wake" -CONF_ON_SETUP = "on_setup" -CONF_ON_PAGE = "on_page" -CONF_ON_BUFFER_OVERFLOW = "on_buffer_overflow" -CONF_TOUCH_SLEEP_TIMEOUT = "touch_sleep_timeout" -CONF_WAKE_UP_PAGE = "wake_up_page" -CONF_START_UP_PAGE = "start_up_page" -CONF_AUTO_WAKE_ON_TOUCH = "auto_wake_on_touch" -CONF_WAVE_MAX_LENGTH = "wave_max_length" -CONF_BACKGROUND_PRESSED_COLOR = "background_pressed_color" -CONF_FOREGROUND_PRESSED_COLOR = "foreground_pressed_color" -CONF_FONT_ID = "font_id" -CONF_EXIT_REPARSE_ON_START = "exit_reparse_on_start" +CONF_PRECISION = "precision" CONF_SKIP_CONNECTION_HANDSHAKE = "skip_connection_handshake" +CONF_START_UP_PAGE = "start_up_page" +CONF_TFT_URL = "tft_url" +CONF_TOUCH_SLEEP_TIMEOUT = "touch_sleep_timeout" +CONF_VARIABLE_NAME = "variable_name" +CONF_WAKE_UP_PAGE = "wake_up_page" +CONF_WAVE_CHANNEL_ID = "wave_channel_id" +CONF_WAVE_MAX_LENGTH = "wave_max_length" +CONF_WAVE_MAX_VALUE = "wave_max_value" +CONF_WAVEFORM_SEND_LAST_VALUE = "waveform_send_last_value" def NextionName(value): diff --git a/esphome/components/nextion/display.py b/esphome/components/nextion/display.py index 60f26e5234..2e7c1c2825 100644 --- a/esphome/components/nextion/display.py +++ b/esphome/components/nextion/display.py @@ -9,16 +9,17 @@ from esphome.const import ( CONF_ON_TOUCH, CONF_TRIGGER_ID, ) -from esphome.core import CORE +from esphome.core import CORE, TimePeriod from . import Nextion, nextion_ns, nextion_ref from .base_component import ( CONF_AUTO_WAKE_ON_TOUCH, + CONF_COMMAND_SPACING, CONF_EXIT_REPARSE_ON_START, CONF_ON_BUFFER_OVERFLOW, - CONF_ON_PAGE, CONF_ON_SETUP, CONF_ON_SLEEP, + CONF_ON_PAGE, CONF_ON_WAKE, CONF_SKIP_CONNECTION_HANDSHAKE, CONF_START_UP_PAGE, @@ -88,6 +89,10 @@ CONFIG_SCHEMA = ( cv.Optional(CONF_AUTO_WAKE_ON_TOUCH, default=True): cv.boolean, cv.Optional(CONF_EXIT_REPARSE_ON_START, default=False): cv.boolean, cv.Optional(CONF_SKIP_CONNECTION_HANDSHAKE, default=False): cv.boolean, + cv.Optional(CONF_COMMAND_SPACING): cv.All( + cv.positive_time_period_milliseconds, + cv.Range(max=TimePeriod(milliseconds=255)), + ), } ) .extend(cv.polling_component_schema("5s")) @@ -120,6 +125,10 @@ async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await uart.register_uart_device(var, config) + if command_spacing := config.get(CONF_COMMAND_SPACING): + cg.add_define("USE_NEXTION_COMMAND_SPACING") + cg.add(var.set_command_spacing(command_spacing.total_milliseconds)) + if CONF_BRIGHTNESS in config: cg.add(var.set_brightness(config[CONF_BRIGHTNESS])) diff --git a/esphome/components/nextion/nextion.cpp b/esphome/components/nextion/nextion.cpp index 67f08f68f8..38e37300af 100644 --- a/esphome/components/nextion/nextion.cpp +++ b/esphome/components/nextion/nextion.cpp @@ -31,11 +31,22 @@ bool Nextion::send_command_(const std::string &command) { return false; } +#ifdef USE_NEXTION_COMMAND_SPACING + if (!this->ignore_is_setup_ && !this->command_pacer_.can_send()) { + return false; + } +#endif // USE_NEXTION_COMMAND_SPACING + ESP_LOGN(TAG, "send_command %s", command.c_str()); this->write_str(command.c_str()); const uint8_t to_send[3] = {0xFF, 0xFF, 0xFF}; this->write_array(to_send, sizeof(to_send)); + +#ifdef USE_NEXTION_COMMAND_SPACING + this->command_pacer_.mark_sent(); +#endif // USE_NEXTION_COMMAND_SPACING + return true; } @@ -158,6 +169,10 @@ void Nextion::dump_config() { if (this->start_up_page_ != -1) { ESP_LOGCONFIG(TAG, " Start Up Page: %" PRId16, this->start_up_page_); } + +#ifdef USE_NEXTION_COMMAND_SPACING + ESP_LOGCONFIG(TAG, " Command spacing: %" PRIu8 "ms", this->command_pacer_.get_spacing()); +#endif // USE_NEXTION_COMMAND_SPACING } float Nextion::get_setup_priority() const { return setup_priority::DATA; } @@ -312,6 +327,11 @@ bool Nextion::remove_from_q_(bool report_empty) { } NextionQueue *nb = this->nextion_queue_.front(); + if (!nb || !nb->component) { + ESP_LOGE(TAG, "Invalid queue entry!"); + this->nextion_queue_.pop_front(); + return false; + } NextionComponentBase *component = nb->component; ESP_LOGN(TAG, "Removing %s from the queue", component->get_variable_name().c_str()); @@ -341,6 +361,12 @@ void Nextion::process_nextion_commands_() { return; } +#ifdef USE_NEXTION_COMMAND_SPACING + if (!this->command_pacer_.can_send()) { + return; // Will try again in next loop iteration + } +#endif + size_t to_process_length = 0; std::string to_process; @@ -380,7 +406,9 @@ void Nextion::process_nextion_commands_() { this->setup_callback_.call(); } } - +#ifdef USE_NEXTION_COMMAND_SPACING + this->command_pacer_.mark_sent(); // Here is where we should mark the command as sent +#endif break; case 0x02: // invalid Component ID or name was used ESP_LOGW(TAG, "Nextion reported component ID or name invalid!"); @@ -524,6 +552,11 @@ void Nextion::process_nextion_commands_() { } NextionQueue *nb = this->nextion_queue_.front(); + if (!nb || !nb->component) { + ESP_LOGE(TAG, "Invalid queue entry!"); + this->nextion_queue_.pop_front(); + return; + } NextionComponentBase *component = nb->component; if (component->get_queue_type() != NextionQueueType::TEXT_SENSOR) { @@ -564,6 +597,11 @@ void Nextion::process_nextion_commands_() { } NextionQueue *nb = this->nextion_queue_.front(); + if (!nb || !nb->component) { + ESP_LOGE(TAG, "Invalid queue entry!"); + this->nextion_queue_.pop_front(); + return; + } NextionComponentBase *component = nb->component; if (component->get_queue_type() != NextionQueueType::SENSOR && diff --git a/esphome/components/nextion/nextion.h b/esphome/components/nextion/nextion.h index b2404e1f0d..4bc5305923 100644 --- a/esphome/components/nextion/nextion.h +++ b/esphome/components/nextion/nextion.h @@ -35,8 +35,54 @@ using nextion_writer_t = std::function; static const std::string COMMAND_DELIMITER{static_cast(255), static_cast(255), static_cast(255)}; +#ifdef USE_NEXTION_COMMAND_SPACING +class NextionCommandPacer { + public: + /** + * @brief Creates command pacer with initial spacing + * @param initial_spacing Initial time between commands in milliseconds + */ + explicit NextionCommandPacer(uint8_t initial_spacing = 0) : spacing_ms_(initial_spacing) {} + + /** + * @brief Set the minimum time between commands + * @param spacing_ms Spacing in milliseconds + */ + void set_spacing(uint8_t spacing_ms) { spacing_ms_ = spacing_ms; } + + /** + * @brief Get current command spacing + * @return Current spacing in milliseconds + */ + uint8_t get_spacing() const { return spacing_ms_; } + + /** + * @brief Check if enough time has passed to send next command + * @return true if enough time has passed since last command + */ + bool can_send() const { return (millis() - last_command_time_) >= spacing_ms_; } + + /** + * @brief Mark a command as sent, updating the timing + */ + void mark_sent() { last_command_time_ = millis(); } + + private: + uint8_t spacing_ms_; + uint32_t last_command_time_{0}; +}; +#endif // USE_NEXTION_COMMAND_SPACING + class Nextion : public NextionBase, public PollingComponent, public uart::UARTDevice { public: +#ifdef USE_NEXTION_COMMAND_SPACING + /** + * @brief Set the command spacing for the display + * @param spacing_ms Time in milliseconds between commands + */ + void set_command_spacing(uint32_t spacing_ms) { this->command_pacer_.set_spacing(spacing_ms); } +#endif // USE_NEXTION_COMMAND_SPACING + /** * Set the text of a component to a static string. * @param component The component name. @@ -1227,6 +1273,9 @@ class Nextion : public NextionBase, public PollingComponent, public uart::UARTDe bool is_connected() { return this->is_connected_; } protected: +#ifdef USE_NEXTION_COMMAND_SPACING + NextionCommandPacer command_pacer_{0}; +#endif // USE_NEXTION_COMMAND_SPACING std::deque nextion_queue_; std::deque waveform_queue_; uint16_t recv_ret_string_(std::string &response, uint32_t timeout, bool recv_flag); @@ -1360,5 +1409,6 @@ class Nextion : public NextionBase, public PollingComponent, public uart::UARTDe uint32_t started_ms_ = 0; bool sent_setup_commands_ = false; }; + } // namespace nextion } // namespace esphome diff --git a/tests/components/nextion/common.yaml b/tests/components/nextion/common.yaml index 589afcfefb..44d6cdfbc9 100644 --- a/tests/components/nextion/common.yaml +++ b/tests/components/nextion/common.yaml @@ -280,6 +280,7 @@ display: - platform: nextion id: main_lcd update_interval: 5s + command_spacing: 5ms on_sleep: then: lambda: 'ESP_LOGD("display","Display went to sleep");' From 1ac56b06c5955ecf1067ac67039e1d72ed9fb3d5 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Mon, 5 May 2025 18:25:24 +1000 Subject: [PATCH 101/102] [arduino] Always include Arduino.h for Arduino (#8693) --- esphome/core/macros.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/esphome/core/macros.h b/esphome/core/macros.h index ee53d20ad1..8b2383321b 100644 --- a/esphome/core/macros.h +++ b/esphome/core/macros.h @@ -2,3 +2,7 @@ // Helper macro to define a version code, whose value can be compared against other version codes. #define VERSION_CODE(major, minor, patch) ((major) << 16 | (minor) << 8 | (patch)) + +#ifdef USE_ARDUINO +#include +#endif From 88be14aaa37398acaf870f6040641b82a53cbe5e Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Mon, 5 May 2025 16:23:50 -0500 Subject: [PATCH 102/102] [audio, microphone] Quantization Improvements (#8695) --- esphome/components/audio/audio.h | 25 +++++ .../microphone/microphone_source.cpp | 97 +++++++++++-------- .../components/microphone/microphone_source.h | 19 +++- 3 files changed, 97 insertions(+), 44 deletions(-) diff --git a/esphome/components/audio/audio.h b/esphome/components/audio/audio.h index 6f0f1aaa46..2c556c68e2 100644 --- a/esphome/components/audio/audio.h +++ b/esphome/components/audio/audio.h @@ -135,5 +135,30 @@ const char *audio_file_type_to_string(AudioFileType file_type); void scale_audio_samples(const int16_t *audio_samples, int16_t *output_buffer, int16_t scale_factor, size_t samples_to_scale); +/// @brief Unpacks a quantized audio sample into a Q31 fixed point number. +/// @param data Pointer to uint8_t array containing the audio sample +/// @param bytes_per_sample The number of bytes per sample +/// @return Q31 sample +inline int32_t unpack_audio_sample_to_q31(const uint8_t *data, size_t bytes_per_sample) { + int32_t sample = 0; + if (bytes_per_sample == 1) { + sample |= data[0] << 24; + } else if (bytes_per_sample == 2) { + sample |= data[0] << 16; + sample |= data[1] << 24; + } else if (bytes_per_sample == 3) { + sample |= data[0] << 8; + sample |= data[1] << 16; + sample |= data[2] << 24; + } else if (bytes_per_sample == 4) { + sample |= data[0]; + sample |= data[1] << 8; + sample |= data[2] << 16; + sample |= data[3] << 24; + } + + return sample; +} + } // namespace audio } // namespace esphome diff --git a/esphome/components/microphone/microphone_source.cpp b/esphome/components/microphone/microphone_source.cpp index 35e8d5dd4d..1ea0deb22b 100644 --- a/esphome/components/microphone/microphone_source.cpp +++ b/esphome/components/microphone/microphone_source.cpp @@ -3,16 +3,34 @@ namespace esphome { namespace microphone { +static const int32_t Q25_MAX_VALUE = (1 << 25) - 1; +static const int32_t Q25_MIN_VALUE = ~Q25_MAX_VALUE; + +static const uint32_t HISTORY_VALUES = 32; + void MicrophoneSource::add_data_callback(std::function &)> &&data_callback) { std::function &)> filtered_callback = [this, data_callback](const std::vector &data) { if (this->enabled_) { - data_callback(this->process_audio_(data)); + if (this->processed_samples_.use_count() == 0) { + // Create vector if its unused + this->processed_samples_ = std::make_shared>(); + } + + // Take temporary ownership of samples vector to avoid deallaction before the callback finishes + std::shared_ptr> output_samples = this->processed_samples_; + this->process_audio_(data, *output_samples); + data_callback(*output_samples); } }; this->mic_->add_data_callback(std::move(filtered_callback)); } +audio::AudioStreamInfo MicrophoneSource::get_audio_stream_info() { + return audio::AudioStreamInfo(this->bits_per_sample_, this->channels_.count(), + this->mic_->get_audio_stream_info().get_sample_rate()); +} + void MicrophoneSource::start() { if (!this->enabled_) { this->enabled_ = true; @@ -23,14 +41,21 @@ void MicrophoneSource::stop() { if (this->enabled_) { this->enabled_ = false; this->mic_->stop(); + this->processed_samples_.reset(); } } -std::vector MicrophoneSource::process_audio_(const std::vector &data) { - // Bit depth conversions are obtained by truncating bits or padding with zeros - no dithering is applied. +void MicrophoneSource::process_audio_(const std::vector &data, std::vector &filtered_data) { + // - Bit depth conversions are obtained by truncating bits or padding with zeros - no dithering is applied. + // - In the comments, Qxx refers to a fixed point number with xx bits of precision for representing fractional values. + // For example, audio with a bit depth of 16 can store a sample in a int16, which can be considered a Q15 number. + // - All samples are converted to Q25 before applying the gain factor - this results in a small precision loss for + // data with 32 bits per sample. Since the maximum gain factor is 64 = (1<<6), this ensures that applying the gain + // will never overflow a 32 bit signed integer. This still retains more bit depth than what is audibly noticeable. + // - Loops for reading/writing data buffers are unrolled, assuming little endian, for a small performance increase. const size_t source_bytes_per_sample = this->mic_->get_audio_stream_info().samples_to_bytes(1); - const size_t source_channels = this->mic_->get_audio_stream_info().get_channels(); + const uint32_t source_channels = this->mic_->get_audio_stream_info().get_channels(); const size_t source_bytes_per_frame = this->mic_->get_audio_stream_info().frames_to_bytes(1); @@ -38,60 +63,48 @@ std::vector MicrophoneSource::process_audio_(const std::vector const size_t target_bytes_per_sample = (this->bits_per_sample_ + 7) / 8; const size_t target_bytes_per_frame = target_bytes_per_sample * this->channels_.count(); - std::vector filtered_data; filtered_data.reserve(target_bytes_per_frame * total_frames); + filtered_data.resize(0); - const int32_t target_min_value = -(1 << (8 * target_bytes_per_sample - 1)); - const int32_t target_max_value = (1 << (8 * target_bytes_per_sample - 1)) - 1; - - for (size_t frame_index = 0; frame_index < total_frames; ++frame_index) { - for (size_t channel_index = 0; channel_index < source_channels; ++channel_index) { + for (uint32_t frame_index = 0; frame_index < total_frames; ++frame_index) { + for (uint32_t channel_index = 0; channel_index < source_channels; ++channel_index) { if (this->channels_.test(channel_index)) { // Channel's current sample is included in the target mask. Convert bits per sample, if necessary. - size_t sample_index = frame_index * source_bytes_per_frame + channel_index * source_bytes_per_sample; + const uint32_t sample_index = frame_index * source_bytes_per_frame + channel_index * source_bytes_per_sample; - int32_t sample = 0; - - // Copy the data into the most significant bits of the sample variable to ensure the sign bit is correct - uint8_t bit_offset = (4 - source_bytes_per_sample) * 8; - for (int i = 0; i < source_bytes_per_sample; ++i) { - sample |= data[sample_index + i] << bit_offset; - bit_offset += 8; - } - - // Shift data back to the least significant bits - if (source_bytes_per_sample >= target_bytes_per_sample) { - // Keep source bytes per sample of data so that the gain multiplication uses all significant bits instead of - // shifting to the target bytes per sample immediately, potentially losing information. - sample >>= (4 - source_bytes_per_sample) * 8; // ``source_bytes_per_sample`` bytes of valid data - } else { - // Keep padded zeros to match the target bytes per sample - sample >>= (4 - target_bytes_per_sample) * 8; // ``target_bytes_per_sample`` bytes of valid data - } + int32_t sample = audio::unpack_audio_sample_to_q31(&data[sample_index], source_bytes_per_sample); // Q31 + sample >>= 6; // Q31 -> Q25 // Apply gain using multiplication - sample *= this->gain_factor_; + sample *= this->gain_factor_; // Q25 - // Match target output bytes by shifting out the least significant bits - if (source_bytes_per_sample > target_bytes_per_sample) { - sample >>= 8 * (source_bytes_per_sample - - target_bytes_per_sample); // ``target_bytes_per_sample`` bytes of valid data - } - - // Clamp ``sample`` to the target bytes per sample range in case gain multiplication overflows - sample = clamp(sample, target_min_value, target_max_value); + // Clamp ``sample`` in case gain multiplication overflows 25 bits + sample = clamp(sample, Q25_MIN_VALUE, Q25_MAX_VALUE); // Q25 // Copy ``target_bytes_per_sample`` bytes to the output buffer. - for (int i = 0; i < target_bytes_per_sample; ++i) { + if (target_bytes_per_sample == 1) { + sample >>= 18; // Q25 -> Q7 filtered_data.push_back(static_cast(sample)); - sample >>= 8; + } else if (target_bytes_per_sample == 2) { + sample >>= 10; // Q25 -> Q15 + filtered_data.push_back(static_cast(sample)); + filtered_data.push_back(static_cast(sample >> 8)); + } else if (target_bytes_per_sample == 3) { + sample >>= 2; // Q25 -> Q23 + filtered_data.push_back(static_cast(sample)); + filtered_data.push_back(static_cast(sample >> 8)); + filtered_data.push_back(static_cast(sample >> 16)); + } else { + sample *= (1 << 6); // Q25 -> Q31 + filtered_data.push_back(static_cast(sample)); + filtered_data.push_back(static_cast(sample >> 8)); + filtered_data.push_back(static_cast(sample >> 16)); + filtered_data.push_back(static_cast(sample >> 24)); } } } } - - return filtered_data; } } // namespace microphone diff --git a/esphome/components/microphone/microphone_source.h b/esphome/components/microphone/microphone_source.h index 028920f101..7f8a37b360 100644 --- a/esphome/components/microphone/microphone_source.h +++ b/esphome/components/microphone/microphone_source.h @@ -1,15 +1,20 @@ #pragma once +#include "microphone.h" + +#include "esphome/components/audio/audio.h" + #include #include #include #include #include -#include "microphone.h" namespace esphome { namespace microphone { +static const int32_t MAX_GAIN_FACTOR = 64; + class MicrophoneSource { /* * @brief Helper class that handles converting raw microphone data to a requested format. @@ -44,13 +49,23 @@ class MicrophoneSource { void add_data_callback(std::function &)> &&data_callback); + void set_gain_factor(int32_t gain_factor) { this->gain_factor_ = clamp(gain_factor, 1, MAX_GAIN_FACTOR); } + int32_t get_gain_factor() { return this->gain_factor_; } + + /// @brief Gets the AudioStreamInfo of the data after processing + /// @return audio::AudioStreamInfo with the configured bits per sample, configured channel count, and source + /// microphone's sample rate + audio::AudioStreamInfo get_audio_stream_info(); + void start(); void stop(); bool is_running() const { return (this->mic_->is_running() && this->enabled_); } bool is_stopped() const { return !this->enabled_; } protected: - std::vector process_audio_(const std::vector &data); + void process_audio_(const std::vector &data, std::vector &filtered_data); + + std::shared_ptr> processed_samples_; Microphone *mic_; uint8_t bits_per_sample_;