diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index 6873f8e93c..d5c820930a 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -85,7 +85,7 @@ void CaptivePortal::start() { this->initialized_ = true; this->active_ = true; - ESP_LOGI(TAG, "Captive portal started"); + ESP_LOGV(TAG, "Captive portal started"); } void CaptivePortal::handleRequest(AsyncWebServerRequest *req) { diff --git a/esphome/components/captive_portal/captive_portal.h b/esphome/components/captive_portal/captive_portal.h index 705af8ab45..f3d40ecae8 100644 --- a/esphome/components/captive_portal/captive_portal.h +++ b/esphome/components/captive_portal/captive_portal.h @@ -44,6 +44,9 @@ class CaptivePortal : public AsyncWebHandler, public Component { } bool canHandle(AsyncWebServerRequest *request) const override { + // Handle all GET requests when captive portal is active + // This allows us to respond with the portal page for any URL, + // triggering OS captive portal detection return this->active_ && request->method() == HTTP_GET; } diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.cpp b/esphome/components/captive_portal/dns_server_esp32_idf.cpp index a32e268f20..503990637a 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.cpp +++ b/esphome/components/captive_portal/dns_server_esp32_idf.cpp @@ -47,8 +47,6 @@ struct DNSAnswer { uint32_t ip_addr; } __attribute__((packed)); -DNSServer::~DNSServer() { this->stop(); } - void DNSServer::start(const network::IPAddress &ip) { this->server_ip_ = ip; ESP_LOGI(TAG, "Starting DNS server on %s", ip.str().c_str()); @@ -103,7 +101,7 @@ void DNSServer::stop() { this->dns_socket_ = -1; } - ESP_LOGI(TAG, "Stopped"); + ESP_LOGV(TAG, "Stopped"); } void DNSServer::dns_server_task(void *pvParameters) { @@ -126,11 +124,10 @@ void DNSServer::dns_server_task(void *pvParameters) { void DNSServer::process_dns_request(int sock) { struct sockaddr_in client_addr; socklen_t client_addr_len = sizeof(client_addr); - uint8_t rx_buffer[DNS_MAX_LEN]; - uint8_t tx_buffer[DNS_MAX_LEN]; + uint8_t buffer[DNS_MAX_LEN]; // Receive DNS request - int len = recvfrom(sock, rx_buffer, sizeof(rx_buffer), 0, (struct sockaddr *) &client_addr, &client_addr_len); + int len = recvfrom(sock, buffer, sizeof(buffer), 0, (struct sockaddr *) &client_addr, &client_addr_len); if (len < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { @@ -147,7 +144,7 @@ void DNSServer::process_dns_request(int sock) { } // Parse DNS header - DNSHeader *header = (DNSHeader *) rx_buffer; + DNSHeader *header = (DNSHeader *) buffer; uint16_t flags = ntohs(header->flags); uint16_t qd_count = ntohs(header->qd_count); @@ -158,22 +155,29 @@ void DNSServer::process_dns_request(int sock) { } // Parse domain name (we don't actually care about it - redirect everything) - uint8_t *ptr = rx_buffer + sizeof(DNSHeader); - uint8_t *name_start = ptr; - while (*ptr != 0 && ptr < (rx_buffer + len)) { - if (*ptr > 63) { // Check for invalid label length + uint8_t *ptr = buffer + sizeof(DNSHeader); + uint8_t *end = buffer + len; + + while (ptr < end && *ptr != 0) { + uint8_t label_len = *ptr; + if (label_len > 63) { // Check for invalid label length return; } - ptr += *ptr + 1; + // Check if we have room for this label plus the length byte + if (ptr + label_len + 1 > end) { + return; // Would overflow + } + ptr += label_len + 1; } - if (*ptr != 0) { - return; // Name not terminated + // Check if we reached a proper null terminator + if (ptr >= end || *ptr != 0) { + return; // Name not terminated or truncated } ptr++; // Skip the null terminator // Check we have room for the question - if (ptr + sizeof(DNSQuestion) > rx_buffer + len) { + if (ptr + sizeof(DNSQuestion) > end) { return; // Request truncated } @@ -188,22 +192,21 @@ void DNSServer::process_dns_request(int sock) { return; // Not an A query } - // Build DNS response - memset(tx_buffer, 0, sizeof(tx_buffer)); + // Build DNS response by modifying the request in-place + header->flags = htons(DNS_QR_FLAG | 0x8000); // Response + Authoritative + header->an_count = htons(1); // One answer - // Copy request header and modify flags - memcpy(tx_buffer, rx_buffer, sizeof(DNSHeader)); - DNSHeader *response_header = (DNSHeader *) tx_buffer; - response_header->flags = htons(DNS_QR_FLAG | 0x8000); // Response + Authoritative - response_header->an_count = htons(1); // One answer - - // Copy the question section - size_t question_len = (ptr + sizeof(DNSQuestion)) - rx_buffer - sizeof(DNSHeader); - memcpy(tx_buffer + sizeof(DNSHeader), rx_buffer + sizeof(DNSHeader), question_len); - - // Add answer section + // Add answer section after the question + size_t question_len = (ptr + sizeof(DNSQuestion)) - buffer - sizeof(DNSHeader); size_t answer_offset = sizeof(DNSHeader) + question_len; - DNSAnswer *answer = (DNSAnswer *) (tx_buffer + answer_offset); + + // Check if we have room for the answer + if (answer_offset + sizeof(DNSAnswer) > sizeof(buffer)) { + ESP_LOGW(TAG, "Response too large"); + return; + } + + DNSAnswer *answer = (DNSAnswer *) (buffer + answer_offset); // Pointer to name in question (offset from start of packet) answer->ptr_offset = htons(0xC000 | sizeof(DNSHeader)); @@ -219,7 +222,7 @@ void DNSServer::process_dns_request(int sock) { size_t response_len = answer_offset + sizeof(DNSAnswer); // Send response - int sent = sendto(sock, tx_buffer, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); + int sent = sendto(sock, buffer, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); if (sent < 0) { ESP_LOGV(TAG, "Send failed: %d", errno); } else { diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.h b/esphome/components/captive_portal/dns_server_esp32_idf.h index 7cbf4490ed..87b5c76e44 100644 --- a/esphome/components/captive_portal/dns_server_esp32_idf.h +++ b/esphome/components/captive_portal/dns_server_esp32_idf.h @@ -10,8 +10,6 @@ namespace esphome::captive_portal { class DNSServer { public: - ~DNSServer(); - void start(const network::IPAddress &ip); void stop();