1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-29 16:42:19 +01:00

Merge branch 'make_captive_portal_captive' into integration

This commit is contained in:
J. Nick Koston
2025-09-27 13:04:14 -05:00
5 changed files with 293 additions and 53 deletions

View File

@@ -1,6 +1,7 @@
import esphome.codegen as cg
from esphome.components import web_server_base
from esphome.components.web_server_base import CONF_WEB_SERVER_BASE_ID
from esphome.config_helpers import filter_source_files_from_platform
import esphome.config_validation as cv
from esphome.const import (
CONF_ID,
@@ -9,6 +10,7 @@ from esphome.const import (
PLATFORM_ESP8266,
PLATFORM_LN882X,
PLATFORM_RTL87XX,
PlatformFramework,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.coroutine import CoroPriority
@@ -58,3 +60,11 @@ async def to_code(config):
cg.add_library("DNSServer", None)
if CORE.is_libretiny:
cg.add_library("DNSServer", None)
# Only compile the ESP-IDF DNS server when using ESP-IDF framework
FILTER_SOURCE_FILES = filter_source_files_from_platform(
{
"dns_server_esp32_idf.cpp": {PlatformFramework.ESP32_IDF},
}
)

View File

@@ -56,10 +56,8 @@ void CaptivePortal::handle_wifisave(AsyncWebServerRequest *request) {
}
void CaptivePortal::setup() {
#ifndef USE_ARDUINO
// No DNS server needed for non-Arduino frameworks
// Disable loop by default - will be enabled when captive portal starts
this->disable_loop();
#endif
}
void CaptivePortal::start() {
this->base_->init();
@@ -67,51 +65,48 @@ void CaptivePortal::start() {
this->base_->add_handler(this);
}
network::IPAddress ip = wifi::global_wifi_component->wifi_soft_ap_ip();
ESP_LOGI(TAG, "Starting captive portal on IP: %s", ip.str().c_str());
#ifdef USE_ESP_IDF
// Create DNS server instance for ESP-IDF
this->dns_server_ = make_unique<DNSServer>();
this->dns_server_->start(ip);
#endif
#ifdef USE_ARDUINO
this->dns_server_ = make_unique<DNSServer>();
this->dns_server_->setErrorReplyCode(DNSReplyCode::NoError);
network::IPAddress ip = wifi::global_wifi_component->wifi_soft_ap_ip();
this->dns_server_->start(53, F("*"), ip);
// Re-enable loop() when DNS server is started
this->enable_loop();
#endif
this->base_->get_server()->onNotFound([this](AsyncWebServerRequest *req) {
if (!this->active_ || req->host().c_str() == wifi::global_wifi_component->wifi_soft_ap_ip().str()) {
req->send(404, F("text/html"), F("File not found"));
return;
}
#ifdef USE_ESP8266
String url = F("http://");
url += wifi::global_wifi_component->wifi_soft_ap_ip().str().c_str();
#else
auto url = "http://" + wifi::global_wifi_component->wifi_soft_ap_ip().str();
#endif
req->redirect(url.c_str());
});
this->initialized_ = true;
this->active_ = true;
// Enable loop() now that captive portal is active
this->enable_loop();
ESP_LOGV(TAG, "Captive portal started");
}
void CaptivePortal::handleRequest(AsyncWebServerRequest *req) {
if (req->url() == F("/")) {
#ifndef USE_ESP8266
auto *response = req->beginResponse(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ));
#else
auto *response = req->beginResponse_P(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ));
#endif
response->addHeader(F("Content-Encoding"), F("gzip"));
req->send(response);
return;
} else if (req->url() == F("/config.json")) {
if (req->url() == F("/config.json")) {
this->handle_config(req);
return;
} else if (req->url() == F("/wifisave")) {
this->handle_wifisave(req);
return;
}
// All other requests get the captive portal page
// This includes OS captive portal detection endpoints which will trigger
// the captive portal when they don't receive their expected responses
#ifndef USE_ESP8266
auto *response = req->beginResponse(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ));
#else
auto *response = req->beginResponse_P(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ));
#endif
response->addHeader(F("Content-Encoding"), F("gzip"));
req->send(response);
}
CaptivePortal::CaptivePortal(web_server_base::WebServerBase *base) : base_(base) { global_captive_portal = this; }

View File

@@ -5,6 +5,9 @@
#ifdef USE_ARDUINO
#include <DNSServer.h>
#endif
#ifdef USE_ESP_IDF
#include "dns_server_esp32_idf.h"
#endif
#include "esphome/core/component.h"
#include "esphome/core/helpers.h"
#include "esphome/core/preferences.h"
@@ -19,41 +22,36 @@ class CaptivePortal : public AsyncWebHandler, public Component {
CaptivePortal(web_server_base::WebServerBase *base);
void setup() override;
void dump_config() override;
#ifdef USE_ARDUINO
void loop() override {
#ifdef USE_ARDUINO
if (this->dns_server_ != nullptr) {
this->dns_server_->processNextRequest();
} else {
this->disable_loop();
}
}
#endif
#ifdef USE_ESP_IDF
if (this->dns_server_ != nullptr) {
this->dns_server_->process_next_request();
}
#endif
}
float get_setup_priority() const override;
void start();
bool is_active() const { return this->active_; }
void end() {
this->active_ = false;
this->disable_loop(); // Stop processing DNS requests
this->base_->deinit();
#ifdef USE_ARDUINO
this->dns_server_->stop();
this->dns_server_ = nullptr;
#endif
if (this->dns_server_ != nullptr) {
this->dns_server_->stop();
this->dns_server_ = nullptr;
}
}
bool canHandle(AsyncWebServerRequest *request) const override {
if (!this->active_)
return false;
if (request->method() == HTTP_GET) {
if (request->url() == F("/"))
return true;
if (request->url() == F("/config.json"))
return true;
if (request->url() == F("/wifisave"))
return true;
}
return false;
// Handle all GET requests when captive portal is active
// This allows us to respond with the portal page for any URL,
// triggering OS captive portal detection
return this->active_ && request->method() == HTTP_GET;
}
void handle_config(AsyncWebServerRequest *request);
@@ -66,7 +64,7 @@ class CaptivePortal : public AsyncWebHandler, public Component {
web_server_base::WebServerBase *base_;
bool initialized_{false};
bool active_{false};
#ifdef USE_ARDUINO
#if defined(USE_ARDUINO) || defined(USE_ESP_IDF)
std::unique_ptr<DNSServer> dns_server_{nullptr};
#endif
};

View File

@@ -0,0 +1,208 @@
#include "dns_server_esp32_idf.h"
#ifdef USE_ESP_IDF
#include "esphome/core/log.h"
#include "esphome/core/hal.h"
#include "esphome/components/socket/socket.h"
#include <lwip/sockets.h>
#include <lwip/inet.h>
namespace esphome::captive_portal {
static const char *const TAG = "captive_portal.dns";
// DNS constants
static constexpr uint16_t DNS_PORT = 53;
static constexpr uint16_t DNS_QR_FLAG = 1 << 15;
static constexpr uint16_t DNS_OPCODE_MASK = 0x7800;
static constexpr uint16_t DNS_QTYPE_A = 0x0001;
static constexpr uint16_t DNS_QCLASS_IN = 0x0001;
static constexpr uint16_t DNS_ANSWER_TTL = 300;
// DNS Header structure
struct DNSHeader {
uint16_t id;
uint16_t flags;
uint16_t qd_count;
uint16_t an_count;
uint16_t ns_count;
uint16_t ar_count;
} __attribute__((packed));
// DNS Question structure
struct DNSQuestion {
uint16_t type;
uint16_t dns_class;
} __attribute__((packed));
// DNS Answer structure
struct DNSAnswer {
uint16_t ptr_offset;
uint16_t type;
uint16_t dns_class;
uint32_t ttl;
uint16_t addr_len;
uint32_t ip_addr;
} __attribute__((packed));
void DNSServer::start(const network::IPAddress &ip) {
this->server_ip_ = ip;
ESP_LOGI(TAG, "Starting DNS server on %s", ip.str().c_str());
// Create loop-monitored UDP socket
this->socket_ = socket::socket_ip_loop_monitored(SOCK_DGRAM, IPPROTO_UDP);
if (this->socket_ == nullptr) {
ESP_LOGE(TAG, "Socket create failed");
return;
}
// Set socket options
int enable = 1;
this->socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
// Bind to port 53
struct sockaddr_storage server_addr = {};
socklen_t addr_len = socket::set_sockaddr_any((struct sockaddr *) &server_addr, sizeof(server_addr), DNS_PORT);
int err = this->socket_->bind((struct sockaddr *) &server_addr, addr_len);
if (err != 0) {
ESP_LOGE(TAG, "Bind failed: %d", errno);
this->socket_ = nullptr;
return;
}
ESP_LOGD(TAG, "Bound to port %d", DNS_PORT);
}
void DNSServer::stop() {
if (this->socket_ != nullptr) {
this->socket_->close();
this->socket_ = nullptr;
}
ESP_LOGV(TAG, "Stopped");
}
void DNSServer::process_next_request() {
// Process one request if socket is valid and data is available
if (this->socket_ != nullptr && this->socket_->ready()) {
this->process_dns_request();
}
}
void DNSServer::process_dns_request() {
struct sockaddr_in client_addr;
socklen_t client_addr_len = sizeof(client_addr);
// Receive DNS request using raw fd for recvfrom
int fd = this->socket_->get_fd();
if (fd < 0) {
return;
}
ssize_t len = recvfrom(fd, this->buffer_, sizeof(this->buffer_), MSG_DONTWAIT, (struct sockaddr *) &client_addr,
&client_addr_len);
if (len < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
ESP_LOGE(TAG, "recvfrom failed: %d", errno);
}
return;
}
ESP_LOGVV(TAG, "Received %d bytes from %s:%d", len, inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));
if (len < sizeof(DNSHeader) + 1) {
ESP_LOGW(TAG, "Request too short: %d", len);
return;
}
// Parse DNS header
DNSHeader *header = (DNSHeader *) this->buffer_;
uint16_t flags = ntohs(header->flags);
uint16_t qd_count = ntohs(header->qd_count);
// Check if it's a standard query
if ((flags & DNS_QR_FLAG) || (flags & DNS_OPCODE_MASK) || qd_count != 1) {
ESP_LOGV(TAG, "Not a standard query: flags=0x%04X, qd_count=%d", flags, qd_count);
return; // Not a standard query
}
// Parse domain name (we don't actually care about it - redirect everything)
uint8_t *ptr = this->buffer_ + sizeof(DNSHeader);
uint8_t *end = this->buffer_ + len;
while (ptr < end && *ptr != 0) {
uint8_t label_len = *ptr;
if (label_len > 63) { // Check for invalid label length
return;
}
// Check if we have room for this label plus the length byte
if (ptr + label_len + 1 > end) {
return; // Would overflow
}
ptr += label_len + 1;
}
// Check if we reached a proper null terminator
if (ptr >= end || *ptr != 0) {
return; // Name not terminated or truncated
}
ptr++; // Skip the null terminator
// Check we have room for the question
if (ptr + sizeof(DNSQuestion) > end) {
return; // Request truncated
}
// Parse DNS question
DNSQuestion *question = (DNSQuestion *) ptr;
uint16_t qtype = ntohs(question->type);
uint16_t qclass = ntohs(question->dns_class);
// We only handle A queries
if (qtype != DNS_QTYPE_A || qclass != DNS_QCLASS_IN) {
ESP_LOGV(TAG, "Not an A query: type=0x%04X, class=0x%04X", qtype, qclass);
return; // Not an A query
}
// Build DNS response by modifying the request in-place
header->flags = htons(DNS_QR_FLAG | 0x8000); // Response + Authoritative
header->an_count = htons(1); // One answer
// Add answer section after the question
size_t question_len = (ptr + sizeof(DNSQuestion)) - this->buffer_ - sizeof(DNSHeader);
size_t answer_offset = sizeof(DNSHeader) + question_len;
// Check if we have room for the answer
if (answer_offset + sizeof(DNSAnswer) > sizeof(this->buffer_)) {
ESP_LOGW(TAG, "Response too large");
return;
}
DNSAnswer *answer = (DNSAnswer *) (this->buffer_ + answer_offset);
// Pointer to name in question (offset from start of packet)
answer->ptr_offset = htons(0xC000 | sizeof(DNSHeader));
answer->type = htons(DNS_QTYPE_A);
answer->dns_class = htons(DNS_QCLASS_IN);
answer->ttl = htonl(DNS_ANSWER_TTL);
answer->addr_len = htons(4);
// Get the raw IP address
ip4_addr_t addr = this->server_ip_;
answer->ip_addr = addr.addr;
size_t response_len = answer_offset + sizeof(DNSAnswer);
// Send response
ssize_t sent =
this->socket_->sendto(this->buffer_, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len);
if (sent < 0) {
ESP_LOGV(TAG, "Send failed: %d", errno);
} else {
ESP_LOGV(TAG, "Sent %d bytes", sent);
}
}
} // namespace esphome::captive_portal
#endif // USE_ESP_IDF

View File

@@ -0,0 +1,29 @@
#pragma once
#ifdef USE_ESP_IDF
#include <memory>
#include "esphome/core/helpers.h"
#include "esphome/components/network/ip_address.h"
#include "esphome/components/socket/socket.h"
namespace esphome::captive_portal {
class DNSServer {
public:
void start(const network::IPAddress &ip);
void stop();
void process_next_request();
protected:
static constexpr size_t DNS_BUFFER_SIZE = 256;
void process_dns_request();
std::unique_ptr<socket::Socket> socket_{nullptr};
network::IPAddress server_ip_;
uint8_t buffer_[DNS_BUFFER_SIZE];
};
} // namespace esphome::captive_portal
#endif // USE_ESP_IDF