From a39b2c4ac708e439b26781a0f3f11c9f24ad10f8 Mon Sep 17 00:00:00 2001 From: Otto winter Date: Wed, 2 Feb 2022 22:38:27 +0100 Subject: [PATCH] Add support for socket client mode and getaddrinfo --- esphome/components/socket/__init__.py | 4 + .../components/socket/bsd_sockets_impl.cpp | 37 ++++ esphome/components/socket/getaddrinfo.h | 33 +++ esphome/components/socket/headers.h | 30 +++ .../socket/lwip_getaddrinfo_impl.cpp | 189 ++++++++++++++++++ .../components/socket/lwip_raw_tcp_impl.cpp | 104 +++++++++- esphome/components/socket/socket.cpp | 4 +- esphome/components/socket/socket.h | 19 +- .../socket/thread_getaddrinfo_impl.cpp | 99 +++++++++ esphome/core/defines.h | 2 + 10 files changed, 511 insertions(+), 10 deletions(-) create mode 100644 esphome/components/socket/getaddrinfo.h create mode 100644 esphome/components/socket/lwip_getaddrinfo_impl.cpp create mode 100644 esphome/components/socket/thread_getaddrinfo_impl.cpp diff --git a/esphome/components/socket/__init__.py b/esphome/components/socket/__init__.py index 8e9502be6d..7c11738455 100644 --- a/esphome/components/socket/__init__.py +++ b/esphome/components/socket/__init__.py @@ -1,5 +1,6 @@ import esphome.config_validation as cv import esphome.codegen as cg +from esphome.core import CORE CODEOWNERS = ["@esphome/core"] @@ -26,3 +27,6 @@ async def to_code(config): cg.add_define("USE_SOCKET_IMPL_LWIP_TCP") elif impl == IMPLEMENTATION_BSD_SOCKETS: cg.add_define("USE_SOCKET_IMPL_BSD_SOCKETS") + + if CORE.target_platform in ["esp8266", "esp32"]: + cg.add_define("USE_SOCKET_HAS_LWIP") diff --git a/esphome/components/socket/bsd_sockets_impl.cpp b/esphome/components/socket/bsd_sockets_impl.cpp index 6636bcb3eb..01703e633d 100644 --- a/esphome/components/socket/bsd_sockets_impl.cpp +++ b/esphome/components/socket/bsd_sockets_impl.cpp @@ -48,6 +48,43 @@ class BSDSocketImpl : public Socket { return make_unique(fd); } int bind(const struct sockaddr *addr, socklen_t addrlen) override { return ::bind(fd_, addr, addrlen); } + int connect(const struct sockaddr *addr, socklen_t addrlen) override { return ::connect(fd_, addr, addrlen); } + int connect_finished() override { + fd_set wfds; + struct timeval tv; + FD_ZERO(&wfds); + FD_SET(fd_, &wfds); + tv.tv_sec = 0; + tv.tv_usec = 0; + int retval = ::select(fd_ + 1, nullptr, &wfds, nullptr, &tv); + if (retval == -1) { + // reuse errno + return -1; + } + if (retval == 0) { + // timeout, not writable yet + errno = EINPROGRESS; + return -1; + } + if (!FD_ISSET(fd_, &wfds)) { + errno = ECONNREFUSED; + return -1; + } + + int so_error; + socklen_t len = sizeof(so_error); + int ret = this->getsockopt(SOL_SOCKET, SO_ERROR, &so_error, &len); + if (ret == -1) { + // reuse errno + return -1; + } + if (so_error == 0) { + return 0; + } + errno = ECONNREFUSED; + return -1; + } + int close() override { int ret = ::close(fd_); closed_ = true; diff --git a/esphome/components/socket/getaddrinfo.h b/esphome/components/socket/getaddrinfo.h new file mode 100644 index 0000000000..dc18991cdc --- /dev/null +++ b/esphome/components/socket/getaddrinfo.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include "headers.h" + +namespace esphome { +namespace socket { + +struct GetaddrinfoFuture { + public: + virtual ~GetaddrinfoFuture() = default; + // returns true when the request has completed (successfully or with an error) + virtual bool completed() = 0; + /** + * @brief Fetch the completed result into res. + * + * Should only be called after completed() returned true. + * Make sure to call freeaddrinfo() to free the addrinfo storage + * when it's no longer needed. + * + * @return See posix getaddrinfo() return values. + */ + virtual int fetch_result(struct addrinfo **res) = 0; +}; + +std::unique_ptr getaddrinfo_async(const char *node, const char *service, + const struct addrinfo *hints); + +} // namespace socket +} // namespace esphome + +#ifdef USE_ESP8266 +void freeaddrinfo(struct addrinfo *ai); +#endif diff --git a/esphome/components/socket/headers.h b/esphome/components/socket/headers.h index a383c0071d..42f4ef6e1c 100644 --- a/esphome/components/socket/headers.h +++ b/esphome/components/socket/headers.h @@ -8,6 +8,7 @@ #define LWIP_INTERNAL #include "lwip/inet.h" +#include "lwip/netdb.h" #include #include #include @@ -107,6 +108,34 @@ struct iovec { #define ESPHOME_INADDR_NONE INADDR_NONE #endif +#ifndef EAI_FAIL +#define EAI_BADFLAGS (-1) +#define EAI_NONAME (-2) +#define EAI_AGAIN (-3) +#define EAI_FAIL (-4) +#define EAI_FAMILY (-6) +#define EAI_SOCKTYPE (-7) +#define EAI_SERVICE (-8) +#define EAI_MEMORY (-10) +#define EAI_SYSTEM (-11) +#define EAI_OVERFLOW (-12) +#endif // !EAI_FAIL + +#ifndef IPPROTO_UDP +#define IPPROTO_UDP 17 +#endif + +struct addrinfo { // NOLINT(readability-identifier-naming) + int ai_flags; + int ai_family; + int ai_socktype; + int ai_protocol; + socklen_t ai_addrlen; + struct sockaddr *ai_addr; + char *ai_canonname; + struct addrinfo *ai_next; +}; + #endif // USE_SOCKET_IMPL_LWIP_TCP #ifdef USE_SOCKET_IMPL_BSD_SOCKETS @@ -118,6 +147,7 @@ struct iovec { #include #include #include +#include #ifdef USE_ARDUINO // arduino-esp32 declares a global var called INADDR_NONE which is replaced diff --git a/esphome/components/socket/lwip_getaddrinfo_impl.cpp b/esphome/components/socket/lwip_getaddrinfo_impl.cpp new file mode 100644 index 0000000000..d4fb1fdc7d --- /dev/null +++ b/esphome/components/socket/lwip_getaddrinfo_impl.cpp @@ -0,0 +1,189 @@ +#include "getaddrinfo.h" +#include "esphome/core/defines.h" + +#ifdef USE_SOCKET_HAS_LWIP + +#include +#include "lwip/dns.h" +#include "lwip/ip_addr.h" +#include "lwip/netdb.h" + +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace socket { + +static const char *const TAG = "socket.lwipgetaddrinfo"; + +struct LwipDNSResult { + bool completed; + bool error; + ip_addr_t ipaddr; +}; + +struct LwipDNSCallbackArg { + std::weak_ptr res; +}; + +void lwip_dns_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg) { + LwipDNSCallbackArg *arg = reinterpret_cast(callback_arg); + { + std::shared_ptr result = arg->res.lock(); + if (result) { + if (ipaddr == nullptr) { + result->error = true; + } else { + result->error = false; + ip_addr_copy(result->ipaddr, *ipaddr); + } + result->completed = true; + } + } + delete arg; // NOLINT(cppcoreguidelines-owning-memory) +} + +class LwipGetaddrinfoFuture : public GetaddrinfoFuture { + public: + LwipGetaddrinfoFuture(std::shared_ptr result, int hint_ai_socktype, int hint_ai_protocol, + uint16_t portno) + : result_(std::move(result)), + hint_ai_socktype_(hint_ai_socktype), + hint_ai_protocol_(hint_ai_protocol), + portno_(portno) {} + ~LwipGetaddrinfoFuture() override = default; + + bool completed() override { return result_->completed; } + int fetch_result(struct addrinfo **res) override { + if (res == nullptr) + return EAI_FAIL; + *res = nullptr; + if (!result_->completed) + return EAI_FAIL; + if (result_->error) + return EAI_FAIL; + + size_t alloc_size = sizeof(struct addrinfo) + sizeof(struct sockaddr_storage); + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory,cppcoreguidelines-no-malloc) + void *storage = malloc(alloc_size); + memset(storage, 0, alloc_size); + struct addrinfo *ai = reinterpret_cast(storage); + struct sockaddr_storage *sa = reinterpret_cast(ai + 1); + +#if LWIP_IPV4 && LWIP_IPV6 + bool isipv6 = result_->ipaddr.type == IPADDR_TYPE_V6; +#elif LWIP_IPV4 + bool isipv6 = false; +#elif LWIP_IPV6 + bool isipv6 = true; +#endif + + bool istcp = true; + if ((hint_ai_socktype_ != 0 && hint_ai_socktype_ == SOCK_DGRAM) || + (hint_ai_protocol_ != 0 && hint_ai_protocol_ == IPPROTO_UDP)) { + istcp = false; + } + + ai->ai_family = isipv6 ? AF_INET6 : AF_INET; + ai->ai_socktype = istcp ? SOCK_STREAM : SOCK_DGRAM; + ai->ai_protocol = istcp ? IPPROTO_TCP : IPPROTO_UDP; + + if (isipv6) { +#if LWIP_IPV6 + struct sockaddr_in6 *sa6 = reinterpret_cast(sa); + inet6_addr_from_ip6addr(&sa6->sin6_addr, ip_2_ip6(&result_->ipaddr)) sa6->sin6_family = AF_INET6; + sa6->sin6_len = sizeof(struct sockaddr_in6); + sa6->sin6_port = htons(portno_); +#endif // LWIP_IPV6 + } else { + struct sockaddr_in *sa4 = reinterpret_cast(sa); + inet_addr_from_ip4addr(&sa4->sin_addr, ip_2_ip4(&result_->ipaddr)); + sa4->sin_family = AF_INET; + sa4->sin_len = sizeof(struct sockaddr_in); + sa4->sin_port = htons(portno_); + } + + ai->ai_addrlen = sizeof(struct sockaddr_storage); + ai->ai_addr = reinterpret_cast(sa); + *res = ai; + return 0; + } + + protected: + std::shared_ptr result_; + int hint_ai_socktype_; + int hint_ai_protocol_; + uint16_t portno_; +}; + +std::unique_ptr getaddrinfo_async(const char *node, const char *service, + const struct addrinfo *hints) { + std::shared_ptr result = std::make_shared(); + result->completed = false; + + uint16_t portno = 0; + if (service != nullptr) { + optional i = parse_number(service); + if (!i.has_value()) { + result->completed = true; + result->error = true; + return std::unique_ptr{new LwipGetaddrinfoFuture(result, 0, 0, 0)}; + } + portno = *i; + } + + int hint_ai_socktype = 0, hint_ai_protocol = 0; + uint8_t dns_addrtype = LWIP_DNS_ADDRTYPE_DEFAULT; + if (hints != nullptr) { + hint_ai_socktype = hints->ai_socktype; + hint_ai_protocol = hints->ai_protocol; + if (hints->ai_family == AF_INET) { + dns_addrtype = LWIP_DNS_ADDRTYPE_IPV4; + } else if (hints->ai_family == AF_INET6) { + dns_addrtype = LWIP_DNS_ADDRTYPE_IPV6; + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + LwipDNSCallbackArg *callback_arg = new LwipDNSCallbackArg; + callback_arg->res = result; + + ip_addr_t immediate_result; + err_t err = dns_gethostbyname_addrtype(node, &immediate_result, lwip_dns_callback, callback_arg, dns_addrtype); + if (err == ERR_OK) { + // immediate result + result->completed = true; + result->error = false; + ip_addr_copy(result->ipaddr, immediate_result); + + // callback won't be called + delete callback_arg; // NOLINT(cppcoreguidelines-owning-memory) + } else if (err == ERR_INPROGRESS) { + // result notified via callback + } else { + // error + result->completed = true; + result->error = true; + + // callback won't be called + delete callback_arg; // NOLINT(cppcoreguidelines-owning-memory) + } + + return std::unique_ptr{ + new LwipGetaddrinfoFuture(result, hint_ai_socktype, hint_ai_protocol, portno)}; +} + +} // namespace socket +} // namespace esphome + +#ifdef USE_ESP8266 +void freeaddrinfo(struct addrinfo *ai) { + while (ai != NULL) { + struct addrinfo *next = ai->ai_next; + delete ai; + ai = next; + } +} +#endif + +#endif // USE_SOCKET_HAS_LWIP diff --git a/esphome/components/socket/lwip_raw_tcp_impl.cpp b/esphome/components/socket/lwip_raw_tcp_impl.cpp index f5bb57bb93..6376be8761 100644 --- a/esphome/components/socket/lwip_raw_tcp_impl.cpp +++ b/esphome/components/socket/lwip_raw_tcp_impl.cpp @@ -69,7 +69,7 @@ class LWIPRawImpl : public Socket { } if (name == nullptr) { errno = EINVAL; - return 0; + return -1; } ip_addr_t ip; in_port_t port; @@ -126,6 +126,76 @@ class LWIPRawImpl : public Socket { } return 0; } + int connect(const struct sockaddr *addr, socklen_t addrlen) override { + if (pcb_ == nullptr) { + errno = EBADF; + return -1; + } + if (addr == nullptr) { + errno = EINVAL; + return -1; + } + if (connecting_) { + errno = EALREADY; + return -1; + } + + ip_addr_t ipaddr; + uint16_t port; + + if (addr->sa_family == AF_INET) { + const struct sockaddr_in *sa4 = reinterpret_cast(addr); + inet_addr_to_ip4addr(ip_2_ip4(&ipaddr), &sa4->sin_addr); +#if LWIP_IPV4 && LWIP_IPV6 + ipaddr.type = IPADDR_TYPE_V4; +#endif + port = ntohs(sa4->sin_port); +#if LWIP_IPV6 + } else if (addr->sa_family == AF_INET6) { + const struct sockaddr_in6 *sa6 = reinterpret_cast(addr); + inet6_addr_to_ip6addr(ip_2_ip6(&ipaddr), &sa6->sin_addr); + ipaddr.type = IPADDR_TYPE_V6; + port = ntohs(sa6->sin_port); +#endif // LWIP_IPV6 + } else { + errno = EAFNOSUPPORT; + return -1; + } + + connecting_ = true; + connected_ = false; + connect_error_ = false; + LWIP_LOG("tcp_connect(%u)", port); + err_t err = tcp_connect(pcb_, &ipaddr, port, LWIPRawImpl::s_connected_fn); + if (err == ERR_VAL) { + errno = EINVAL; + return -1; + } + if (err != ERR_OK) { + errno = EIO; + return -1; + } + + errno = EINPROGRESS; + return -1; + } + int connect_finished() override { + if (connected_) { + return 0; + } + if (connect_error_) { + errno = ECONNREFUSED; + return -1; + } + if (connecting_) { + errno = EINPROGRESS; + return -1; + } + // no connect started + errno = EALREADY; + return -1; + } + int close() override { if (pcb_ == nullptr) { errno = ECONNRESET; @@ -369,9 +439,10 @@ class LWIPRawImpl : public Socket { for (int i = 0; i < iovcnt; i++) { ssize_t err = read(reinterpret_cast(iov[i].iov_base), iov[i].iov_len); if (err == -1) { - if (ret != 0) + if (ret != 0) { // if we already read some don't return an error break; + } return err; } ret += err; @@ -433,9 +504,10 @@ class LWIPRawImpl : public Socket { ssize_t written = internal_write(buf, len); if (written == -1) return -1; - if (written == 0) + if (written == 0) { // no need to output if nothing written return 0; + } if (nodelay_) { int err = internal_output(); if (err == -1) @@ -448,18 +520,20 @@ class LWIPRawImpl : public Socket { for (int i = 0; i < iovcnt; i++) { ssize_t err = internal_write(reinterpret_cast(iov[i].iov_base), iov[i].iov_len); if (err == -1) { - if (written != 0) + if (written != 0) { // if we already read some don't return an error break; + } return err; } written += err; if ((size_t) err != iov[i].iov_len) break; } - if (written == 0) + if (written == 0) { // no need to output if nothing written return 0; + } if (nodelay_) { int err = internal_output(); if (err == -1) @@ -524,6 +598,18 @@ class LWIPRawImpl : public Socket { } return ERR_OK; } + err_t connected_fn(err_t err) { + LWIP_LOG("connected(err=%d)", err); + if (err != ERR_OK) { + connected_ = false; + connect_error_ = false; + } else { + connected_ = true; + connect_error_ = true; + } + connecting_ = false; + return ERR_OK; + } static err_t s_accept_fn(void *arg, struct tcp_pcb *newpcb, err_t err) { LWIPRawImpl *arg_this = reinterpret_cast(arg); @@ -540,6 +626,11 @@ class LWIPRawImpl : public Socket { return arg_this->recv_fn(pb, err); } + static err_t s_connected_fn(void *arg, struct tcp_pcb *pcb, err_t err) { + LWIPRawImpl *arg_this = reinterpret_cast(arg); + return arg_this->connected_fn(err); + } + protected: int ip2sockaddr_(ip_addr_t *ip, uint16_t port, struct sockaddr *name, socklen_t *addrlen) { if (family_ == AF_INET) { @@ -590,6 +681,9 @@ class LWIPRawImpl : public Socket { // instead use it for determining whether to call lwip_output bool nodelay_ = false; sa_family_t family_ = 0; + bool connecting_ = false; + bool connected_ = false; + bool connect_error_ = false; }; std::unique_ptr socket(int domain, int type, int protocol) { diff --git a/esphome/components/socket/socket.cpp b/esphome/components/socket/socket.cpp index 22a4c11df8..9f1fc8b032 100644 --- a/esphome/components/socket/socket.cpp +++ b/esphome/components/socket/socket.cpp @@ -7,7 +7,7 @@ namespace esphome { namespace socket { std::unique_ptr socket_ip(int type, int protocol) { -#if LWIP_IPV6 +#ifdef USE_SOCKET_IPV6 return socket(AF_INET6, type, protocol); #else return socket(AF_INET, type, protocol); @@ -15,7 +15,7 @@ std::unique_ptr socket_ip(int type, int protocol) { } socklen_t set_sockaddr_any(struct sockaddr *addr, socklen_t addrlen, uint16_t port) { -#if LWIP_IPV6 +#if USE_SOCKET_IPV6 if (addrlen < sizeof(sockaddr_in6)) { errno = EINVAL; return 0; diff --git a/esphome/components/socket/socket.h b/esphome/components/socket/socket.h index ecf117deeb..847f33224f 100644 --- a/esphome/components/socket/socket.h +++ b/esphome/components/socket/socket.h @@ -5,6 +5,12 @@ #include "headers.h" #include "esphome/core/optional.h" +#ifdef USE_SOCKET_IMPL_LWIP_TCP +#if LWIP_IPV6 +#define USE_SOCKET_IPV6 +#endif +#endif + namespace esphome { namespace socket { @@ -17,10 +23,17 @@ class Socket { virtual std::unique_ptr accept(struct sockaddr *addr, socklen_t *addrlen) = 0; virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0; + virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0; + /** + * @brief Helper to check if a socket connect() that was EINPROGRESS is now finished. + * + * If the connect finnished successfully, returns 0. + * If it's still in progress, returns -1 and sets errno to EINPROGRESS. + * Other errors result in return code -1 and errno like in blocking connect(). + */ + virtual int connect_finished() = 0; + virtual int close() = 0; - // not supported yet: - // virtual int connect(const std::string &address) = 0; - // virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0; virtual int shutdown(int how) = 0; virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0; diff --git a/esphome/components/socket/thread_getaddrinfo_impl.cpp b/esphome/components/socket/thread_getaddrinfo_impl.cpp new file mode 100644 index 0000000000..92df347566 --- /dev/null +++ b/esphome/components/socket/thread_getaddrinfo_impl.cpp @@ -0,0 +1,99 @@ +#include "getaddrinfo.h" +#include "esphome/core/defines.h" + +#ifndef USE_SOCKET_HAS_LWIP + +#include +#include +#include +#include + +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace socket { + +static const char *const TAG = "socket.threadgetaddrinfo"; + +struct ThreadGetaddrinfoResult { + bool completed; + int return_code; + struct addrinfo *res; +}; + +class ThreadGetaddrinfoFuture : public GetaddrinfoFuture { + public: + ThreadGetaddrinfoFuture(std::shared_ptr result) : result_(result) {} + ~ThreadGetaddrinfoFuture() override = default; + + bool completed() override { return result_->completed; } + int fetch_result(struct addrinfo **res) { + if (res == nullptr) + return EAI_FAIL; + *res = nullptr; + if (!result_->completed) + return EAI_FAIL; + if (result_->return_code != 0) + return result_->return_code; + + *res = result_->res; + return 0; + } + + protected: + std::shared_ptr result_; +}; + +void worker(std::shared_ptr result, const char *node, const char *service, + const struct addrinfo *hints) { + result->return_code = getaddrinfo(node, service, hints, &result->res); + result->completed = true; + if (hints != nullptr) { + delete hints->ai_addr; + delete hints->ai_canonname; + delete hints; + } + delete node; + delete service; +} + +std::unique_ptr getaddrinfo_async(const char *node, const char *service, + const struct addrinfo *hints) { + std::shared_ptr result = std::make_shared(); + result->completed = false; + + struct addrinfo *hints_copy = nullptr; + if (hints != nullptr) { + hints_copy = new struct addrinfo; + hints_copy->ai_flags = hints->ai_flags; + hints_copy->ai_family = hints->ai_family; + hints_copy->ai_socktype = hints->ai_socktype; + hints_copy->ai_protocol = hints->ai_protocol; + hints_copy->ai_addrlen = hints->ai_addrlen; + if (ai->ai_addr != nullptr) { + hints_copy->ai_addr = malloc(hints->ai_addrlen); + memcpy(hints_copy->ai_addr, hints->ai_addr, hints->ai_addrlen); + } + if (ai->ai_canonname != nullptr) { + hints_copy->ai_canonname = strdup(hints->ai_canonname); + } + hints_copy->ai_next = nullptr; + } + + const char *node_copy = nullptr, *service_copy = nullptr; + if (node != nullptr) + node_copy = strdup(node); + if (service != nullptr) + service_copy = strdup(service); + + std::thread thread(worker, result, node_copy, service_copy, hints_copy); + thread.detach(); + + return std::unique_ptr{new ThreadGetaddrinfoFuture(result)}; +} + +} // namespace socket +} // namespace esphome + +#endif // !USE_SOCKET_HAS_LWIP diff --git a/esphome/core/defines.h b/esphome/core/defines.h index acdc5df815..4fad2d88e1 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -70,6 +70,7 @@ #ifdef USE_ESP_IDF #define USE_ARDUINO_VERSION_CODE VERSION_CODE(4, 3, 0) #endif +#define USE_SOCKET_HAS_LWIP #endif // ESP8266-specific feature flags @@ -79,6 +80,7 @@ #define USE_ESP8266_PREFERENCES_FLASH #define USE_HTTP_REQUEST_ESP8266_HTTPS #define USE_SOCKET_IMPL_LWIP_TCP +#define USE_SOCKET_HAS_LWIP #endif // Disabled feature flags