From 9579423b24545c2921e83897a96ca2370dff44d7 Mon Sep 17 00:00:00 2001
From: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
Date: Thu, 19 Oct 2023 11:42:52 +1300
Subject: [PATCH] esp32_improv add timeout (#5556)

---
 esphome/components/esp32_improv/__init__.py        |  5 +++++
 .../esp32_improv/esp32_improv_component.h          |  5 +++++
 esphome/components/wifi/wifi_component.cpp         | 14 +++++++-------
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/esphome/components/esp32_improv/__init__.py b/esphome/components/esp32_improv/__init__.py
index fba2e55ae8..49d95d89e5 100644
--- a/esphome/components/esp32_improv/__init__.py
+++ b/esphome/components/esp32_improv/__init__.py
@@ -36,6 +36,9 @@ CONFIG_SCHEMA = cv.Schema(
         cv.Optional(
             CONF_AUTHORIZED_DURATION, default="1min"
         ): cv.positive_time_period_milliseconds,
+        cv.Optional(
+            CONF_WIFI_TIMEOUT, default="1min"
+        ): cv.positive_time_period_milliseconds,
     }
 ).extend(cv.COMPONENT_SCHEMA)
 
@@ -53,6 +56,8 @@ async def to_code(config):
     cg.add(var.set_identify_duration(config[CONF_IDENTIFY_DURATION]))
     cg.add(var.set_authorized_duration(config[CONF_AUTHORIZED_DURATION]))
 
+    cg.add(var.set_wifi_timeout(config[CONF_WIFI_TIMEOUT]))
+
     if CONF_AUTHORIZER in config and config[CONF_AUTHORIZER] is not None:
         activator = await cg.get_variable(config[CONF_AUTHORIZER])
         cg.add(var.set_authorizer(activator))
diff --git a/esphome/components/esp32_improv/esp32_improv_component.h b/esphome/components/esp32_improv/esp32_improv_component.h
index ba9892d6a5..00c6cf885a 100644
--- a/esphome/components/esp32_improv/esp32_improv_component.h
+++ b/esphome/components/esp32_improv/esp32_improv_component.h
@@ -51,6 +51,9 @@ class ESP32ImprovComponent : public Component, public BLEServiceComponent {
   void set_identify_duration(uint32_t identify_duration) { this->identify_duration_ = identify_duration; }
   void set_authorized_duration(uint32_t authorized_duration) { this->authorized_duration_ = authorized_duration; }
 
+  void set_wifi_timeout(uint32_t wifi_timeout) { this->wifi_timeout_ = wifi_timeout; }
+  uint32_t get_wifi_timeout() const { return this->wifi_timeout_; }
+
  protected:
   bool should_start_{false};
   bool setup_complete_{false};
@@ -60,6 +63,8 @@ class ESP32ImprovComponent : public Component, public BLEServiceComponent {
   uint32_t authorized_start_{0};
   uint32_t authorized_duration_;
 
+  uint32_t wifi_timeout_{};
+
   std::vector<uint8_t> incoming_data_;
   wifi::WiFiAP connecting_sta_;
 
diff --git a/esphome/components/wifi/wifi_component.cpp b/esphome/components/wifi/wifi_component.cpp
index 2cb36fe8ea..b08f20de21 100644
--- a/esphome/components/wifi/wifi_component.cpp
+++ b/esphome/components/wifi/wifi_component.cpp
@@ -8,16 +8,16 @@
 #include <user_interface.h>
 #endif
 
-#include <utility>
 #include <algorithm>
-#include "lwip/err.h"
+#include <utility>
 #include "lwip/dns.h"
+#include "lwip/err.h"
 
+#include "esphome/core/application.h"
+#include "esphome/core/hal.h"
 #include "esphome/core/helpers.h"
 #include "esphome/core/log.h"
-#include "esphome/core/hal.h"
 #include "esphome/core/util.h"
-#include "esphome/core/application.h"
 
 #ifdef USE_CAPTIVE_PORTAL
 #include "esphome/components/captive_portal/captive_portal.h"
@@ -96,7 +96,7 @@ void WiFiComponent::start() {
 #endif
   }
 #ifdef USE_IMPROV
-  if (esp32_improv::global_improv_component != nullptr) {
+  if (!this->has_sta() && esp32_improv::global_improv_component != nullptr) {
     if (this->wifi_mode_(true, {}))
       esp32_improv::global_improv_component->start();
   }
@@ -163,8 +163,8 @@ void WiFiComponent::loop() {
     }
 
 #ifdef USE_IMPROV
-    if (esp32_improv::global_improv_component != nullptr) {
-      if (!this->is_connected()) {
+    if (esp32_improv::global_improv_component != nullptr && !esp32_improv::global_improv_component->is_active()) {
+      if (now - this->last_connected_ > esp32_improv::global_improv_component->get_wifi_timeout()) {
         if (this->wifi_mode_(true, {}))
           esp32_improv::global_improv_component->start();
       }