From aed6fa14f0fc7ab85ff8e484461ff7004cca5b50 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 Sep 2025 12:23:45 -0500 Subject: [PATCH 1/5] make_captive_portal_captive --- esphome/components/captive_portal/__init__.py | 10 + .../captive_portal/captive_portal.cpp | 49 ++-- .../captive_portal/captive_portal.h | 27 +- .../captive_portal/dns_server_esp32_idf.cpp | 232 ++++++++++++++++++ .../captive_portal/dns_server_esp32_idf.h | 29 +++ 5 files changed, 302 insertions(+), 45 deletions(-) create mode 100644 esphome/components/captive_portal/dns_server_esp32_idf.cpp create mode 100644 esphome/components/captive_portal/dns_server_esp32_idf.h diff --git a/esphome/components/captive_portal/__init__.py b/esphome/components/captive_portal/__init__.py index 9f2af0a230..4e0c0d6093 100644 --- a/esphome/components/captive_portal/__init__.py +++ b/esphome/components/captive_portal/__init__.py @@ -1,6 +1,7 @@ import esphome.codegen as cg from esphome.components import web_server_base from esphome.components.web_server_base import CONF_WEB_SERVER_BASE_ID +from esphome.config_helpers import filter_source_files_from_platform import esphome.config_validation as cv from esphome.const import ( CONF_ID, @@ -9,6 +10,7 @@ from esphome.const import ( PLATFORM_ESP8266, PLATFORM_LN882X, PLATFORM_RTL87XX, + PlatformFramework, ) from esphome.core import CORE, coroutine_with_priority from esphome.coroutine import CoroPriority @@ -58,3 +60,11 @@ async def to_code(config): cg.add_library("DNSServer", None) if CORE.is_libretiny: cg.add_library("DNSServer", None) + + +# Only compile the ESP-IDF DNS server when using ESP-IDF framework +FILTER_SOURCE_FILES = filter_source_files_from_platform( + { + "dns_server_esp32_idf.cpp": {PlatformFramework.ESP32_IDF}, + } +) diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index 7eb0ffa99e..6873f8e93c 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -57,7 +57,7 @@ void CaptivePortal::handle_wifisave(AsyncWebServerRequest *request) { void CaptivePortal::setup() { #ifndef USE_ARDUINO - // No DNS server needed for non-Arduino frameworks + // Disable loop for non-Arduino frameworks (DNS runs in its own task on ESP-IDF) this->disable_loop(); #endif } @@ -67,51 +67,46 @@ void CaptivePortal::start() { this->base_->add_handler(this); } + network::IPAddress ip = wifi::global_wifi_component->wifi_soft_ap_ip(); + ESP_LOGI(TAG, "Starting captive portal on IP: %s", ip.str().c_str()); + +#ifdef USE_ESP_IDF + // Create DNS server instance for ESP-IDF + this->dns_server_ = make_unique(); + this->dns_server_->start(ip); +#endif #ifdef USE_ARDUINO this->dns_server_ = make_unique(); this->dns_server_->setErrorReplyCode(DNSReplyCode::NoError); - network::IPAddress ip = wifi::global_wifi_component->wifi_soft_ap_ip(); this->dns_server_->start(53, F("*"), ip); // Re-enable loop() when DNS server is started this->enable_loop(); #endif - this->base_->get_server()->onNotFound([this](AsyncWebServerRequest *req) { - if (!this->active_ || req->host().c_str() == wifi::global_wifi_component->wifi_soft_ap_ip().str()) { - req->send(404, F("text/html"), F("File not found")); - return; - } - -#ifdef USE_ESP8266 - String url = F("http://"); - url += wifi::global_wifi_component->wifi_soft_ap_ip().str().c_str(); -#else - auto url = "http://" + wifi::global_wifi_component->wifi_soft_ap_ip().str(); -#endif - req->redirect(url.c_str()); - }); - this->initialized_ = true; this->active_ = true; + ESP_LOGI(TAG, "Captive portal started"); } void CaptivePortal::handleRequest(AsyncWebServerRequest *req) { - if (req->url() == F("/")) { -#ifndef USE_ESP8266 - auto *response = req->beginResponse(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); -#else - auto *response = req->beginResponse_P(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); -#endif - response->addHeader(F("Content-Encoding"), F("gzip")); - req->send(response); - return; - } else if (req->url() == F("/config.json")) { + if (req->url() == F("/config.json")) { this->handle_config(req); return; } else if (req->url() == F("/wifisave")) { this->handle_wifisave(req); return; } + + // All other requests get the captive portal page + // This includes OS captive portal detection endpoints which will trigger + // the captive portal when they don't receive their expected responses +#ifndef USE_ESP8266 + auto *response = req->beginResponse(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); +#else + auto *response = req->beginResponse_P(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); +#endif + response->addHeader(F("Content-Encoding"), F("gzip")); + req->send(response); } CaptivePortal::CaptivePortal(web_server_base::WebServerBase *base) : base_(base) { global_captive_portal = this; } diff --git a/esphome/components/captive_portal/captive_portal.h b/esphome/components/captive_portal/captive_portal.h index 382afe92f0..705af8ab45 100644 --- a/esphome/components/captive_portal/captive_portal.h +++ b/esphome/components/captive_portal/captive_portal.h @@ -5,6 +5,9 @@ #ifdef USE_ARDUINO #include #endif +#ifdef USE_ESP_IDF +#include "dns_server_esp32_idf.h" +#endif #include "esphome/core/component.h" #include "esphome/core/helpers.h" #include "esphome/core/preferences.h" @@ -34,26 +37,14 @@ class CaptivePortal : public AsyncWebHandler, public Component { void end() { this->active_ = false; this->base_->deinit(); -#ifdef USE_ARDUINO - this->dns_server_->stop(); - this->dns_server_ = nullptr; -#endif + if (this->dns_server_ != nullptr) { + this->dns_server_->stop(); + this->dns_server_ = nullptr; + } } bool canHandle(AsyncWebServerRequest *request) const override { - if (!this->active_) - return false; - - if (request->method() == HTTP_GET) { - if (request->url() == F("/")) - return true; - if (request->url() == F("/config.json")) - return true; - if (request->url() == F("/wifisave")) - return true; - } - - return false; + return this->active_ && request->method() == HTTP_GET; } void handle_config(AsyncWebServerRequest *request); @@ -66,7 +57,7 @@ class CaptivePortal : public AsyncWebHandler, public Component { web_server_base::WebServerBase *base_; bool initialized_{false}; bool active_{false}; -#ifdef USE_ARDUINO +#if defined(USE_ARDUINO) || defined(USE_ESP_IDF) std::unique_ptr dns_server_{nullptr}; #endif }; diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.cpp b/esphome/components/captive_portal/dns_server_esp32_idf.cpp new file mode 100644 index 0000000000..a32e268f20 --- /dev/null +++ b/esphome/components/captive_portal/dns_server_esp32_idf.cpp @@ -0,0 +1,232 @@ +#include "dns_server_esp32_idf.h" +#ifdef USE_ESP_IDF + +#include "esphome/core/log.h" +#include "esphome/core/hal.h" +#include +#include +#include + +namespace esphome::captive_portal { + +static const char *const TAG = "captive_portal.dns"; + +// DNS constants +static constexpr uint16_t DNS_PORT = 53; +static constexpr uint16_t DNS_MAX_LEN = 256; +static constexpr uint16_t DNS_QR_FLAG = 1 << 15; +static constexpr uint16_t DNS_OPCODE_MASK = 0x7800; +static constexpr uint16_t DNS_QTYPE_A = 0x0001; +static constexpr uint16_t DNS_QCLASS_IN = 0x0001; +static constexpr uint16_t DNS_ANSWER_TTL = 300; +static constexpr size_t DNS_TASK_STACK_SIZE = 3072; + +// DNS Header structure +struct DNSHeader { + uint16_t id; + uint16_t flags; + uint16_t qd_count; + uint16_t an_count; + uint16_t ns_count; + uint16_t ar_count; +} __attribute__((packed)); + +// DNS Question structure +struct DNSQuestion { + uint16_t type; + uint16_t dns_class; +} __attribute__((packed)); + +// DNS Answer structure +struct DNSAnswer { + uint16_t ptr_offset; + uint16_t type; + uint16_t dns_class; + uint32_t ttl; + uint16_t addr_len; + uint32_t ip_addr; +} __attribute__((packed)); + +DNSServer::~DNSServer() { this->stop(); } + +void DNSServer::start(const network::IPAddress &ip) { + this->server_ip_ = ip; + ESP_LOGI(TAG, "Starting DNS server on %s", ip.str().c_str()); + + // Create socket + this->dns_socket_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (this->dns_socket_ < 0) { + ESP_LOGE(TAG, "Socket create failed: %d", errno); + return; + } + ESP_LOGD(TAG, "Socket created: %d", this->dns_socket_); + + // Set socket options + int enable = 1; + if (setsockopt(this->dns_socket_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < 0) { + ESP_LOGW(TAG, "SO_REUSEADDR failed: %d", errno); + } + + // Bind to port 53 + struct sockaddr_in server_addr = {}; + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = htonl(INADDR_ANY); + server_addr.sin_port = htons(DNS_PORT); + + if (bind(this->dns_socket_, (struct sockaddr *) &server_addr, sizeof(server_addr)) < 0) { + ESP_LOGE(TAG, "Bind failed: %d", errno); + close(this->dns_socket_); + this->dns_socket_ = -1; + return; + } + ESP_LOGD(TAG, "Bound to port %d", DNS_PORT); + + // Create task + BaseType_t task_result = + xTaskCreate(&DNSServer::dns_server_task, "dns_server", DNS_TASK_STACK_SIZE, this, 1, &this->dns_task_handle_); + if (task_result != pdPASS) { + ESP_LOGE(TAG, "Task create failed"); + close(this->dns_socket_); + this->dns_socket_ = -1; + return; + } +} + +void DNSServer::stop() { + if (this->dns_task_handle_) { + vTaskDelete(this->dns_task_handle_); + this->dns_task_handle_ = nullptr; + } + + if (this->dns_socket_ >= 0) { + close(this->dns_socket_); + this->dns_socket_ = -1; + } + + ESP_LOGI(TAG, "Stopped"); +} + +void DNSServer::dns_server_task(void *pvParameters) { + DNSServer *server = static_cast(pvParameters); + ESP_LOGV(TAG, "Task started, socket: %d", server->dns_socket_); + + // Set socket timeout to prevent blocking forever + struct timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + if (setsockopt(server->dns_socket_, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { + ESP_LOGW(TAG, "SO_RCVTIMEO failed: %d", errno); + } + + while (true) { + server->process_dns_request(server->dns_socket_); + } +} + +void DNSServer::process_dns_request(int sock) { + struct sockaddr_in client_addr; + socklen_t client_addr_len = sizeof(client_addr); + uint8_t rx_buffer[DNS_MAX_LEN]; + uint8_t tx_buffer[DNS_MAX_LEN]; + + // Receive DNS request + int len = recvfrom(sock, rx_buffer, sizeof(rx_buffer), 0, (struct sockaddr *) &client_addr, &client_addr_len); + + if (len < 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { + ESP_LOGE(TAG, "recvfrom failed: %d", errno); + } + return; + } + + ESP_LOGVV(TAG, "Received %d bytes from %s:%d", len, inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port)); + + if (len < sizeof(DNSHeader) + 1) { + ESP_LOGW(TAG, "Request too short: %d", len); + return; + } + + // Parse DNS header + DNSHeader *header = (DNSHeader *) rx_buffer; + uint16_t flags = ntohs(header->flags); + uint16_t qd_count = ntohs(header->qd_count); + + // Check if it's a standard query + if ((flags & DNS_QR_FLAG) || (flags & DNS_OPCODE_MASK) || qd_count != 1) { + ESP_LOGV(TAG, "Not a standard query: flags=0x%04X, qd_count=%d", flags, qd_count); + return; // Not a standard query + } + + // Parse domain name (we don't actually care about it - redirect everything) + uint8_t *ptr = rx_buffer + sizeof(DNSHeader); + uint8_t *name_start = ptr; + while (*ptr != 0 && ptr < (rx_buffer + len)) { + if (*ptr > 63) { // Check for invalid label length + return; + } + ptr += *ptr + 1; + } + + if (*ptr != 0) { + return; // Name not terminated + } + ptr++; // Skip the null terminator + + // Check we have room for the question + if (ptr + sizeof(DNSQuestion) > rx_buffer + len) { + return; // Request truncated + } + + // Parse DNS question + DNSQuestion *question = (DNSQuestion *) ptr; + uint16_t qtype = ntohs(question->type); + uint16_t qclass = ntohs(question->dns_class); + + // We only handle A queries + if (qtype != DNS_QTYPE_A || qclass != DNS_QCLASS_IN) { + ESP_LOGV(TAG, "Not an A query: type=0x%04X, class=0x%04X", qtype, qclass); + return; // Not an A query + } + + // Build DNS response + memset(tx_buffer, 0, sizeof(tx_buffer)); + + // Copy request header and modify flags + memcpy(tx_buffer, rx_buffer, sizeof(DNSHeader)); + DNSHeader *response_header = (DNSHeader *) tx_buffer; + response_header->flags = htons(DNS_QR_FLAG | 0x8000); // Response + Authoritative + response_header->an_count = htons(1); // One answer + + // Copy the question section + size_t question_len = (ptr + sizeof(DNSQuestion)) - rx_buffer - sizeof(DNSHeader); + memcpy(tx_buffer + sizeof(DNSHeader), rx_buffer + sizeof(DNSHeader), question_len); + + // Add answer section + size_t answer_offset = sizeof(DNSHeader) + question_len; + DNSAnswer *answer = (DNSAnswer *) (tx_buffer + answer_offset); + + // Pointer to name in question (offset from start of packet) + answer->ptr_offset = htons(0xC000 | sizeof(DNSHeader)); + answer->type = htons(DNS_QTYPE_A); + answer->dns_class = htons(DNS_QCLASS_IN); + answer->ttl = htonl(DNS_ANSWER_TTL); + answer->addr_len = htons(4); + + // Get the raw IP address + ip4_addr_t addr = this->server_ip_; + answer->ip_addr = addr.addr; + + size_t response_len = answer_offset + sizeof(DNSAnswer); + + // Send response + int sent = sendto(sock, tx_buffer, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); + if (sent < 0) { + ESP_LOGV(TAG, "Send failed: %d", errno); + } else { + ESP_LOGV(TAG, "Sent %d bytes", sent); + } +} + +} // namespace esphome::captive_portal + +#endif // USE_ESP_IDF diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.h b/esphome/components/captive_portal/dns_server_esp32_idf.h new file mode 100644 index 0000000000..7cbf4490ed --- /dev/null +++ b/esphome/components/captive_portal/dns_server_esp32_idf.h @@ -0,0 +1,29 @@ +#pragma once +#ifdef USE_ESP_IDF + +#include "esphome/core/helpers.h" +#include "esphome/components/network/ip_address.h" +#include +#include + +namespace esphome::captive_portal { + +class DNSServer { + public: + ~DNSServer(); + + void start(const network::IPAddress &ip); + void stop(); + + protected: + static void dns_server_task(void *pvParameters); + void process_dns_request(int sock); + + TaskHandle_t dns_task_handle_{nullptr}; + int dns_socket_{-1}; + network::IPAddress server_ip_; +}; + +} // namespace esphome::captive_portal + +#endif // USE_ESP_IDF From 6b72736d5e3d74cfe8d4b1c258c0c1ba6736c31f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 Sep 2025 12:32:24 -0500 Subject: [PATCH 2/5] wip --- .../captive_portal/captive_portal.cpp | 2 +- .../captive_portal/captive_portal.h | 3 + .../captive_portal/dns_server_esp32_idf.cpp | 63 ++++++++++--------- .../captive_portal/dns_server_esp32_idf.h | 2 - 4 files changed, 37 insertions(+), 33 deletions(-) diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index 6873f8e93c..d5c820930a 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -85,7 +85,7 @@ void CaptivePortal::start() { this->initialized_ = true; this->active_ = true; - ESP_LOGI(TAG, "Captive portal started"); + ESP_LOGV(TAG, "Captive portal started"); } void CaptivePortal::handleRequest(AsyncWebServerRequest *req) { diff --git a/esphome/components/captive_portal/captive_portal.h b/esphome/components/captive_portal/captive_portal.h index 705af8ab45..f3d40ecae8 100644 --- a/esphome/components/captive_portal/captive_portal.h +++ b/esphome/components/captive_portal/captive_portal.h @@ -44,6 +44,9 @@ class CaptivePortal : public AsyncWebHandler, public Component { } bool canHandle(AsyncWebServerRequest *request) const override { + // Handle all GET requests when captive portal is active + // This allows us to respond with the portal page for any URL, + // triggering OS captive portal detection return this->active_ && request->method() == HTTP_GET; } diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.cpp b/esphome/components/captive_portal/dns_server_esp32_idf.cpp index a32e268f20..503990637a 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.cpp +++ b/esphome/components/captive_portal/dns_server_esp32_idf.cpp @@ -47,8 +47,6 @@ struct DNSAnswer { uint32_t ip_addr; } __attribute__((packed)); -DNSServer::~DNSServer() { this->stop(); } - void DNSServer::start(const network::IPAddress &ip) { this->server_ip_ = ip; ESP_LOGI(TAG, "Starting DNS server on %s", ip.str().c_str()); @@ -103,7 +101,7 @@ void DNSServer::stop() { this->dns_socket_ = -1; } - ESP_LOGI(TAG, "Stopped"); + ESP_LOGV(TAG, "Stopped"); } void DNSServer::dns_server_task(void *pvParameters) { @@ -126,11 +124,10 @@ void DNSServer::dns_server_task(void *pvParameters) { void DNSServer::process_dns_request(int sock) { struct sockaddr_in client_addr; socklen_t client_addr_len = sizeof(client_addr); - uint8_t rx_buffer[DNS_MAX_LEN]; - uint8_t tx_buffer[DNS_MAX_LEN]; + uint8_t buffer[DNS_MAX_LEN]; // Receive DNS request - int len = recvfrom(sock, rx_buffer, sizeof(rx_buffer), 0, (struct sockaddr *) &client_addr, &client_addr_len); + int len = recvfrom(sock, buffer, sizeof(buffer), 0, (struct sockaddr *) &client_addr, &client_addr_len); if (len < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { @@ -147,7 +144,7 @@ void DNSServer::process_dns_request(int sock) { } // Parse DNS header - DNSHeader *header = (DNSHeader *) rx_buffer; + DNSHeader *header = (DNSHeader *) buffer; uint16_t flags = ntohs(header->flags); uint16_t qd_count = ntohs(header->qd_count); @@ -158,22 +155,29 @@ void DNSServer::process_dns_request(int sock) { } // Parse domain name (we don't actually care about it - redirect everything) - uint8_t *ptr = rx_buffer + sizeof(DNSHeader); - uint8_t *name_start = ptr; - while (*ptr != 0 && ptr < (rx_buffer + len)) { - if (*ptr > 63) { // Check for invalid label length + uint8_t *ptr = buffer + sizeof(DNSHeader); + uint8_t *end = buffer + len; + + while (ptr < end && *ptr != 0) { + uint8_t label_len = *ptr; + if (label_len > 63) { // Check for invalid label length return; } - ptr += *ptr + 1; + // Check if we have room for this label plus the length byte + if (ptr + label_len + 1 > end) { + return; // Would overflow + } + ptr += label_len + 1; } - if (*ptr != 0) { - return; // Name not terminated + // Check if we reached a proper null terminator + if (ptr >= end || *ptr != 0) { + return; // Name not terminated or truncated } ptr++; // Skip the null terminator // Check we have room for the question - if (ptr + sizeof(DNSQuestion) > rx_buffer + len) { + if (ptr + sizeof(DNSQuestion) > end) { return; // Request truncated } @@ -188,22 +192,21 @@ void DNSServer::process_dns_request(int sock) { return; // Not an A query } - // Build DNS response - memset(tx_buffer, 0, sizeof(tx_buffer)); + // Build DNS response by modifying the request in-place + header->flags = htons(DNS_QR_FLAG | 0x8000); // Response + Authoritative + header->an_count = htons(1); // One answer - // Copy request header and modify flags - memcpy(tx_buffer, rx_buffer, sizeof(DNSHeader)); - DNSHeader *response_header = (DNSHeader *) tx_buffer; - response_header->flags = htons(DNS_QR_FLAG | 0x8000); // Response + Authoritative - response_header->an_count = htons(1); // One answer - - // Copy the question section - size_t question_len = (ptr + sizeof(DNSQuestion)) - rx_buffer - sizeof(DNSHeader); - memcpy(tx_buffer + sizeof(DNSHeader), rx_buffer + sizeof(DNSHeader), question_len); - - // Add answer section + // Add answer section after the question + size_t question_len = (ptr + sizeof(DNSQuestion)) - buffer - sizeof(DNSHeader); size_t answer_offset = sizeof(DNSHeader) + question_len; - DNSAnswer *answer = (DNSAnswer *) (tx_buffer + answer_offset); + + // Check if we have room for the answer + if (answer_offset + sizeof(DNSAnswer) > sizeof(buffer)) { + ESP_LOGW(TAG, "Response too large"); + return; + } + + DNSAnswer *answer = (DNSAnswer *) (buffer + answer_offset); // Pointer to name in question (offset from start of packet) answer->ptr_offset = htons(0xC000 | sizeof(DNSHeader)); @@ -219,7 +222,7 @@ void DNSServer::process_dns_request(int sock) { size_t response_len = answer_offset + sizeof(DNSAnswer); // Send response - int sent = sendto(sock, tx_buffer, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); + int sent = sendto(sock, buffer, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); if (sent < 0) { ESP_LOGV(TAG, "Send failed: %d", errno); } else { diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.h b/esphome/components/captive_portal/dns_server_esp32_idf.h index 7cbf4490ed..87b5c76e44 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.h +++ b/esphome/components/captive_portal/dns_server_esp32_idf.h @@ -10,8 +10,6 @@ namespace esphome::captive_portal { class DNSServer { public: - ~DNSServer(); - void start(const network::IPAddress &ip); void stop(); From 0356081961e53dbd4ab9455f165d6693305278ef Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 Sep 2025 12:47:00 -0500 Subject: [PATCH 3/5] make it captive --- .../captive_portal/captive_portal.cpp | 10 +- .../captive_portal/captive_portal.h | 12 ++- .../captive_portal/dns_server_esp32_idf.cpp | 95 +++++++------------ .../captive_portal/dns_server_esp32_idf.h | 12 +-- 4 files changed, 54 insertions(+), 75 deletions(-) diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index d5c820930a..daf66b6e12 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -56,10 +56,8 @@ void CaptivePortal::handle_wifisave(AsyncWebServerRequest *request) { } void CaptivePortal::setup() { -#ifndef USE_ARDUINO - // Disable loop for non-Arduino frameworks (DNS runs in its own task on ESP-IDF) + // Disable loop by default - will be enabled when captive portal starts this->disable_loop(); -#endif } void CaptivePortal::start() { this->base_->init(); @@ -79,12 +77,14 @@ void CaptivePortal::start() { this->dns_server_ = make_unique(); this->dns_server_->setErrorReplyCode(DNSReplyCode::NoError); this->dns_server_->start(53, F("*"), ip); - // Re-enable loop() when DNS server is started - this->enable_loop(); #endif this->initialized_ = true; this->active_ = true; + + // Enable loop() now that captive portal is active + this->enable_loop(); + ESP_LOGV(TAG, "Captive portal started"); } diff --git a/esphome/components/captive_portal/captive_portal.h b/esphome/components/captive_portal/captive_portal.h index f3d40ecae8..f48c286f0c 100644 --- a/esphome/components/captive_portal/captive_portal.h +++ b/esphome/components/captive_portal/captive_portal.h @@ -22,20 +22,24 @@ class CaptivePortal : public AsyncWebHandler, public Component { CaptivePortal(web_server_base::WebServerBase *base); void setup() override; void dump_config() override; -#ifdef USE_ARDUINO void loop() override { +#ifdef USE_ARDUINO if (this->dns_server_ != nullptr) { this->dns_server_->processNextRequest(); - } else { - this->disable_loop(); } - } #endif +#ifdef USE_ESP_IDF + if (this->dns_server_ != nullptr) { + this->dns_server_->process_next_request(); + } +#endif + } float get_setup_priority() const override; void start(); bool is_active() const { return this->active_; } void end() { this->active_ = false; + this->disable_loop(); // Stop processing DNS requests this->base_->deinit(); if (this->dns_server_ != nullptr) { this->dns_server_->stop(); diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.cpp b/esphome/components/captive_portal/dns_server_esp32_idf.cpp index 503990637a..e60cbc851e 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.cpp +++ b/esphome/components/captive_portal/dns_server_esp32_idf.cpp @@ -3,8 +3,8 @@ #include "esphome/core/log.h" #include "esphome/core/hal.h" +#include "esphome/components/socket/socket.h" #include -#include #include namespace esphome::captive_portal { @@ -51,83 +51,57 @@ void DNSServer::start(const network::IPAddress &ip) { this->server_ip_ = ip; ESP_LOGI(TAG, "Starting DNS server on %s", ip.str().c_str()); - // Create socket - this->dns_socket_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); - if (this->dns_socket_ < 0) { - ESP_LOGE(TAG, "Socket create failed: %d", errno); + // Create loop-monitored UDP socket + this->socket_ = socket::socket_ip_loop_monitored(SOCK_DGRAM, IPPROTO_UDP); + if (this->socket_ == nullptr) { + ESP_LOGE(TAG, "Socket create failed"); return; } - ESP_LOGD(TAG, "Socket created: %d", this->dns_socket_); // Set socket options int enable = 1; - if (setsockopt(this->dns_socket_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < 0) { - ESP_LOGW(TAG, "SO_REUSEADDR failed: %d", errno); - } + this->socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)); // Bind to port 53 - struct sockaddr_in server_addr = {}; - server_addr.sin_family = AF_INET; - server_addr.sin_addr.s_addr = htonl(INADDR_ANY); - server_addr.sin_port = htons(DNS_PORT); + struct sockaddr_storage server_addr = {}; + socklen_t addr_len = socket::set_sockaddr_any((struct sockaddr *) &server_addr, sizeof(server_addr), DNS_PORT); - if (bind(this->dns_socket_, (struct sockaddr *) &server_addr, sizeof(server_addr)) < 0) { + int err = this->socket_->bind((struct sockaddr *) &server_addr, addr_len); + if (err != 0) { ESP_LOGE(TAG, "Bind failed: %d", errno); - close(this->dns_socket_); - this->dns_socket_ = -1; + this->socket_ = nullptr; return; } ESP_LOGD(TAG, "Bound to port %d", DNS_PORT); - - // Create task - BaseType_t task_result = - xTaskCreate(&DNSServer::dns_server_task, "dns_server", DNS_TASK_STACK_SIZE, this, 1, &this->dns_task_handle_); - if (task_result != pdPASS) { - ESP_LOGE(TAG, "Task create failed"); - close(this->dns_socket_); - this->dns_socket_ = -1; - return; - } } void DNSServer::stop() { - if (this->dns_task_handle_) { - vTaskDelete(this->dns_task_handle_); - this->dns_task_handle_ = nullptr; + if (this->socket_ != nullptr) { + this->socket_->close(); + this->socket_ = nullptr; } - - if (this->dns_socket_ >= 0) { - close(this->dns_socket_); - this->dns_socket_ = -1; - } - ESP_LOGV(TAG, "Stopped"); } -void DNSServer::dns_server_task(void *pvParameters) { - DNSServer *server = static_cast(pvParameters); - ESP_LOGV(TAG, "Task started, socket: %d", server->dns_socket_); - - // Set socket timeout to prevent blocking forever - struct timeval timeout; - timeout.tv_sec = 1; - timeout.tv_usec = 0; - if (setsockopt(server->dns_socket_, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { - ESP_LOGW(TAG, "SO_RCVTIMEO failed: %d", errno); - } - - while (true) { - server->process_dns_request(server->dns_socket_); +void DNSServer::process_next_request() { + // Process one request if socket is valid and data is available + if (this->socket_ != nullptr && this->socket_->ready()) { + this->process_dns_request(); } } -void DNSServer::process_dns_request(int sock) { +void DNSServer::process_dns_request() { struct sockaddr_in client_addr; socklen_t client_addr_len = sizeof(client_addr); - uint8_t buffer[DNS_MAX_LEN]; - // Receive DNS request - int len = recvfrom(sock, buffer, sizeof(buffer), 0, (struct sockaddr *) &client_addr, &client_addr_len); + // Receive DNS request using raw fd for recvfrom + int fd = this->socket_->get_fd(); + if (fd < 0) { + return; + } + + ssize_t len = recvfrom(fd, this->buffer_, sizeof(this->buffer_), MSG_DONTWAIT, (struct sockaddr *) &client_addr, + &client_addr_len); if (len < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { @@ -144,7 +118,7 @@ void DNSServer::process_dns_request(int sock) { } // Parse DNS header - DNSHeader *header = (DNSHeader *) buffer; + DNSHeader *header = (DNSHeader *) this->buffer_; uint16_t flags = ntohs(header->flags); uint16_t qd_count = ntohs(header->qd_count); @@ -155,8 +129,8 @@ void DNSServer::process_dns_request(int sock) { } // Parse domain name (we don't actually care about it - redirect everything) - uint8_t *ptr = buffer + sizeof(DNSHeader); - uint8_t *end = buffer + len; + uint8_t *ptr = this->buffer_ + sizeof(DNSHeader); + uint8_t *end = this->buffer_ + len; while (ptr < end && *ptr != 0) { uint8_t label_len = *ptr; @@ -197,16 +171,16 @@ void DNSServer::process_dns_request(int sock) { header->an_count = htons(1); // One answer // Add answer section after the question - size_t question_len = (ptr + sizeof(DNSQuestion)) - buffer - sizeof(DNSHeader); + size_t question_len = (ptr + sizeof(DNSQuestion)) - this->buffer_ - sizeof(DNSHeader); size_t answer_offset = sizeof(DNSHeader) + question_len; // Check if we have room for the answer - if (answer_offset + sizeof(DNSAnswer) > sizeof(buffer)) { + if (answer_offset + sizeof(DNSAnswer) > sizeof(this->buffer_)) { ESP_LOGW(TAG, "Response too large"); return; } - DNSAnswer *answer = (DNSAnswer *) (buffer + answer_offset); + DNSAnswer *answer = (DNSAnswer *) (this->buffer_ + answer_offset); // Pointer to name in question (offset from start of packet) answer->ptr_offset = htons(0xC000 | sizeof(DNSHeader)); @@ -222,7 +196,8 @@ void DNSServer::process_dns_request(int sock) { size_t response_len = answer_offset + sizeof(DNSAnswer); // Send response - int sent = sendto(sock, buffer, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); + ssize_t sent = + this->socket_->sendto(this->buffer_, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); if (sent < 0) { ESP_LOGV(TAG, "Send failed: %d", errno); } else { diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.h b/esphome/components/captive_portal/dns_server_esp32_idf.h index 87b5c76e44..cec039b332 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.h +++ b/esphome/components/captive_portal/dns_server_esp32_idf.h @@ -1,10 +1,10 @@ #pragma once #ifdef USE_ESP_IDF +#include #include "esphome/core/helpers.h" #include "esphome/components/network/ip_address.h" -#include -#include +#include "esphome/components/socket/socket.h" namespace esphome::captive_portal { @@ -12,14 +12,14 @@ class DNSServer { public: void start(const network::IPAddress &ip); void stop(); + void process_next_request(); protected: - static void dns_server_task(void *pvParameters); - void process_dns_request(int sock); + void process_dns_request(); - TaskHandle_t dns_task_handle_{nullptr}; - int dns_socket_{-1}; + std::unique_ptr socket_{nullptr}; network::IPAddress server_ip_; + uint8_t buffer_[256]; // DNS_MAX_LEN }; } // namespace esphome::captive_portal From 29943bfef195fa6a2d033c6f8bd674677430a3ae Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 Sep 2025 12:48:09 -0500 Subject: [PATCH 4/5] preen --- esphome/components/captive_portal/dns_server_esp32_idf.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.h b/esphome/components/captive_portal/dns_server_esp32_idf.h index cec039b332..9837eca1ae 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.h +++ b/esphome/components/captive_portal/dns_server_esp32_idf.h @@ -15,11 +15,13 @@ class DNSServer { void process_next_request(); protected: + static constexpr size_t DNS_BUFFER_SIZE = 256; + void process_dns_request(); std::unique_ptr socket_{nullptr}; network::IPAddress server_ip_; - uint8_t buffer_[256]; // DNS_MAX_LEN + uint8_t buffer_[DNS_BUFFER_SIZE]; }; } // namespace esphome::captive_portal From 72c1830b9b4382f31c4f7d38e794efafd00b26ab Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 27 Sep 2025 12:49:08 -0500 Subject: [PATCH 5/5] preen --- esphome/components/captive_portal/dns_server_esp32_idf.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.cpp b/esphome/components/captive_portal/dns_server_esp32_idf.cpp index e60cbc851e..1e93c0abd0 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.cpp +++ b/esphome/components/captive_portal/dns_server_esp32_idf.cpp @@ -13,13 +13,11 @@ static const char *const TAG = "captive_portal.dns"; // DNS constants static constexpr uint16_t DNS_PORT = 53; -static constexpr uint16_t DNS_MAX_LEN = 256; static constexpr uint16_t DNS_QR_FLAG = 1 << 15; static constexpr uint16_t DNS_OPCODE_MASK = 0x7800; static constexpr uint16_t DNS_QTYPE_A = 0x0001; static constexpr uint16_t DNS_QCLASS_IN = 0x0001; static constexpr uint16_t DNS_ANSWER_TTL = 300; -static constexpr size_t DNS_TASK_STACK_SIZE = 3072; // DNS Header structure struct DNSHeader {