diff --git a/esphome/components/script/script.h b/esphome/components/script/script.h index 58fb67a3ea..870a623f32 100644 --- a/esphome/components/script/script.h +++ b/esphome/components/script/script.h @@ -2,6 +2,7 @@ #include #include +#include #include "esphome/core/automation.h" #include "esphome/core/component.h" #include "esphome/core/helpers.h" @@ -264,10 +265,22 @@ template class IsRunningCondition : public Condition class ScriptWaitAction : public Action, public Component { public: ScriptWaitAction(C *script) : script_(script) {} + void setup() override { + // Start with loop disabled - only enable when there's work to do + this->disable_loop(); + } + void play_complex(Ts... x) override { this->num_running_++; // Check if we can continue immediately. @@ -275,7 +288,11 @@ template class ScriptWaitAction : public Action, this->play_next_(x...); return; } - this->var_ = std::make_tuple(x...); + + // Store parameters for later execution + this->param_queue_.emplace_front(x...); + // Enable loop now that we have work to do + this->enable_loop(); this->loop(); } @@ -286,15 +303,30 @@ template class ScriptWaitAction : public Action, if (this->script_->is_running()) return; - this->play_next_tuple_(this->var_); + while (!this->param_queue_.empty()) { + auto ¶ms = this->param_queue_.front(); + this->play_next_tuple_(params, typename gens::type()); + this->param_queue_.pop_front(); + } + // Queue is now empty - disable loop until next play_complex + this->disable_loop(); } void play(Ts... x) override { /* ignore - see play_complex */ } + void stop() override { + this->param_queue_.clear(); + this->disable_loop(); + } + protected: + template void play_next_tuple_(const std::tuple &tuple, seq /*unused*/) { + this->play_next_(std::get(tuple)...); + } + C *script_; - std::tuple var_{}; + std::forward_list> param_queue_; }; } // namespace script diff --git a/esphome/core/base_automation.h b/esphome/core/base_automation.h index 1c60dd1c7a..e668a1782a 100644 --- a/esphome/core/base_automation.h +++ b/esphome/core/base_automation.h @@ -10,6 +10,7 @@ #include "esphome/core/helpers.h" #include +#include namespace esphome { @@ -268,32 +269,28 @@ template class WhileAction : public Action { void add_then(const std::initializer_list *> &actions) { this->then_.add_actions(actions); this->then_.add_action(new LambdaAction([this](Ts... x) { - if (this->num_running_ > 0 && this->condition_->check_tuple(this->var_)) { + if (this->num_running_ > 0 && this->condition_->check(x...)) { // play again - if (this->num_running_ > 0) { - this->then_.play_tuple(this->var_); - } + this->then_.play(x...); } else { // condition false, play next - this->play_next_tuple_(this->var_); + this->play_next_(x...); } })); } void play_complex(Ts... x) override { this->num_running_++; - // Store loop parameters - this->var_ = std::make_tuple(x...); // Initial condition check - if (!this->condition_->check_tuple(this->var_)) { + if (!this->condition_->check(x...)) { // If new condition check failed, stop loop if running this->then_.stop(); - this->play_next_tuple_(this->var_); + this->play_next_(x...); return; } if (this->num_running_ > 0) { - this->then_.play_tuple(this->var_); + this->then_.play(x...); } } @@ -305,7 +302,6 @@ template class WhileAction : public Action { protected: Condition *condition_; ActionList then_; - std::tuple var_{}; }; template class RepeatAction : public Action { @@ -317,7 +313,7 @@ template class RepeatAction : public Action { this->then_.add_action(new LambdaAction([this](uint32_t iteration, Ts... x) { iteration++; if (iteration >= this->count_.value(x...)) { - this->play_next_tuple_(this->var_); + this->play_next_(x...); } else { this->then_.play(iteration, x...); } @@ -326,11 +322,10 @@ template class RepeatAction : public Action { void play_complex(Ts... x) override { this->num_running_++; - this->var_ = std::make_tuple(x...); if (this->count_.value(x...) > 0) { this->then_.play(0, x...); } else { - this->play_next_tuple_(this->var_); + this->play_next_(x...); } } @@ -341,15 +336,26 @@ template class RepeatAction : public Action { protected: ActionList then_; - std::tuple var_; }; +/** Wait until a condition is true to continue execution. + * + * Uses queue-based storage to safely handle concurrent executions. + * While concurrent execution from the same trigger is uncommon, it's possible + * (e.g., rapid button presses, high-frequency sensor updates), so we use + * queue-based storage for correctness. + */ template class WaitUntilAction : public Action, public Component { public: WaitUntilAction(Condition *condition) : condition_(condition) {} TEMPLATABLE_VALUE(uint32_t, timeout_value) + void setup() override { + // Start with loop disabled - only enable when there's work to do + this->disable_loop(); + } + void play_complex(Ts... x) override { this->num_running_++; // Check if we can continue immediately. @@ -359,13 +365,14 @@ template class WaitUntilAction : public Action, public Co } return; } - this->var_ = std::make_tuple(x...); - if (this->timeout_value_.has_value()) { - auto f = std::bind(&WaitUntilAction::play_next_, this, x...); - this->set_timeout("timeout", this->timeout_value_.value(x...), f); - } + // Store for later processing + auto now = millis(); + auto timeout = this->timeout_value_.optional_value(x...); + this->var_queue_.emplace_front(now, timeout, std::make_tuple(x...)); + // Enable loop now that we have work to do + this->enable_loop(); this->loop(); } @@ -373,13 +380,32 @@ template class WaitUntilAction : public Action, public Co if (this->num_running_ == 0) return; - if (!this->condition_->check_tuple(this->var_)) { - return; + auto now = millis(); + + this->var_queue_.remove_if([&](auto &queued) { + auto start = std::get(queued); + auto timeout = std::get>(queued); + auto &var = std::get>(queued); + + auto expired = timeout && (now - start) >= *timeout; + + if (!expired && !this->condition_->check_tuple(var)) { + return false; + } + + this->play_next_tuple_(var); + return true; + }); + + // If queue is now empty, disable loop until next play_complex + if (this->var_queue_.empty()) { + this->disable_loop(); } + } - this->cancel_timeout("timeout"); - - this->play_next_tuple_(this->var_); + void stop() override { + this->var_queue_.clear(); + this->disable_loop(); } float get_setup_priority() const override { return setup_priority::DATA; } @@ -387,11 +413,9 @@ template class WaitUntilAction : public Action, public Co void play(Ts... x) override { /* ignore - see play_complex */ } - void stop() override { this->cancel_timeout("timeout"); } - protected: Condition *condition_; - std::tuple var_{}; + std::forward_list, std::tuple>> var_queue_{}; }; template class UpdateComponentAction : public Action { diff --git a/tests/integration/fixtures/automation_wait_actions.yaml b/tests/integration/fixtures/automation_wait_actions.yaml new file mode 100644 index 0000000000..65a61be14f --- /dev/null +++ b/tests/integration/fixtures/automation_wait_actions.yaml @@ -0,0 +1,130 @@ +esphome: + name: test-automation-wait-actions + +host: + +api: + actions: + # Test 1: Trigger wait_until automation 5 times rapidly + - action: test_wait_until + then: + - logger.log: "=== TEST 1: Triggering wait_until automation 5 times ===" + # Publish 5 different values to trigger the on_value automation 5 times + - sensor.template.publish: + id: wait_until_sensor + state: 1 + - sensor.template.publish: + id: wait_until_sensor + state: 2 + - sensor.template.publish: + id: wait_until_sensor + state: 3 + - sensor.template.publish: + id: wait_until_sensor + state: 4 + - sensor.template.publish: + id: wait_until_sensor + state: 5 + # Wait then satisfy the condition so all 5 waiting actions complete + - delay: 100ms + - globals.set: + id: test_flag + value: 'true' + + # Test 2: Trigger script.wait automation 5 times rapidly + - action: test_script_wait + then: + - logger.log: "=== TEST 2: Triggering script.wait automation 5 times ===" + # Start a long-running script + - script.execute: blocking_script + # Publish 5 different values to trigger the on_value automation 5 times + - sensor.template.publish: + id: script_wait_sensor + state: 1 + - sensor.template.publish: + id: script_wait_sensor + state: 2 + - sensor.template.publish: + id: script_wait_sensor + state: 3 + - sensor.template.publish: + id: script_wait_sensor + state: 4 + - sensor.template.publish: + id: script_wait_sensor + state: 5 + + # Test 3: Trigger wait_until timeout automation 5 times rapidly + - action: test_wait_timeout + then: + - logger.log: "=== TEST 3: Triggering timeout automation 5 times ===" + # Publish 5 different values (condition will never be true, all will timeout) + - sensor.template.publish: + id: timeout_sensor + state: 1 + - sensor.template.publish: + id: timeout_sensor + state: 2 + - sensor.template.publish: + id: timeout_sensor + state: 3 + - sensor.template.publish: + id: timeout_sensor + state: 4 + - sensor.template.publish: + id: timeout_sensor + state: 5 + +logger: + level: DEBUG + +globals: + - id: test_flag + type: bool + restore_value: false + initial_value: 'false' + + - id: timeout_flag + type: bool + restore_value: false + initial_value: 'false' + +# Sensors with wait_until/script.wait in their on_value automations +sensor: + # Test 1: on_value automation with wait_until + - platform: template + id: wait_until_sensor + on_value: + # This wait_until will be hit 5 times before any complete + - wait_until: + condition: + lambda: return id(test_flag); + - logger.log: "wait_until automation completed" + + # Test 2: on_value automation with script.wait + - platform: template + id: script_wait_sensor + on_value: + # This script.wait will be hit 5 times before any complete + - script.wait: blocking_script + - logger.log: "script.wait automation completed" + + # Test 3: on_value automation with wait_until timeout + - platform: template + id: timeout_sensor + on_value: + # This wait_until will be hit 5 times, all will timeout + - wait_until: + condition: + lambda: return id(timeout_flag); + timeout: 200ms + - logger.log: "timeout automation completed" + +script: + # Blocking script for script.wait test + - id: blocking_script + mode: single + then: + - logger.log: "Blocking script: START" + - delay: 200ms + - logger.log: "Blocking script: END" diff --git a/tests/integration/test_action_concurrent_reentry.py b/tests/integration/test_action_concurrent_reentry.py index ba67e4c798..aa5801ca2b 100644 --- a/tests/integration/test_action_concurrent_reentry.py +++ b/tests/integration/test_action_concurrent_reentry.py @@ -11,7 +11,6 @@ import pytest from .types import APIClientConnectedFactory, RunCompiledFunction -@pytest.mark.xfail(reason="https://github.com/esphome/issues/issues/6534") @pytest.mark.asyncio async def test_action_concurrent_reentry( yaml_config: str, diff --git a/tests/integration/test_automation_wait_actions.py b/tests/integration/test_automation_wait_actions.py new file mode 100644 index 0000000000..adcb8ba487 --- /dev/null +++ b/tests/integration/test_automation_wait_actions.py @@ -0,0 +1,104 @@ +"""Test concurrent execution of wait_until and script.wait in direct automation actions.""" + +from __future__ import annotations + +import asyncio +import re + +import pytest + +from .types import APIClientConnectedFactory, RunCompiledFunction + + +@pytest.mark.asyncio +async def test_automation_wait_actions( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected: APIClientConnectedFactory, +) -> None: + """ + Test that wait_until and script.wait correctly handle concurrent executions + when automation actions (not scripts) are triggered multiple times rapidly. + + This tests sensor.on_value automations being triggered 5 times before any complete. + """ + loop = asyncio.get_running_loop() + + # Track completion counts + test_results = { + "wait_until": 0, + "script_wait": 0, + "wait_until_timeout": 0, + } + + # Patterns for log messages + wait_until_complete = re.compile(r"wait_until automation completed") + script_wait_complete = re.compile(r"script\.wait automation completed") + timeout_complete = re.compile(r"timeout automation completed") + + # Test completion futures + test1_complete = loop.create_future() + test2_complete = loop.create_future() + test3_complete = loop.create_future() + + def check_output(line: str) -> None: + """Check log output for completion messages.""" + # Test 1: wait_until concurrent execution + if wait_until_complete.search(line): + test_results["wait_until"] += 1 + if test_results["wait_until"] == 5 and not test1_complete.done(): + test1_complete.set_result(True) + + # Test 2: script.wait concurrent execution + if script_wait_complete.search(line): + test_results["script_wait"] += 1 + if test_results["script_wait"] == 5 and not test2_complete.done(): + test2_complete.set_result(True) + + # Test 3: wait_until with timeout + if timeout_complete.search(line): + test_results["wait_until_timeout"] += 1 + if test_results["wait_until_timeout"] == 5 and not test3_complete.done(): + test3_complete.set_result(True) + + async with ( + run_compiled(yaml_config, line_callback=check_output), + api_client_connected() as client, + ): + # Get services + _, services = await client.list_entities_services() + + # Test 1: wait_until in automation - trigger 5 times rapidly + test_service = next((s for s in services if s.name == "test_wait_until"), None) + assert test_service is not None, "test_wait_until service not found" + client.execute_service(test_service, {}) + await asyncio.wait_for(test1_complete, timeout=3.0) + + # Verify Test 1: All 5 triggers should complete + assert test_results["wait_until"] == 5, ( + f"Test 1: Expected 5 wait_until completions, got {test_results['wait_until']}" + ) + + # Test 2: script.wait in automation - trigger 5 times rapidly + test_service = next((s for s in services if s.name == "test_script_wait"), None) + assert test_service is not None, "test_script_wait service not found" + client.execute_service(test_service, {}) + await asyncio.wait_for(test2_complete, timeout=3.0) + + # Verify Test 2: All 5 triggers should complete + assert test_results["script_wait"] == 5, ( + f"Test 2: Expected 5 script.wait completions, got {test_results['script_wait']}" + ) + + # Test 3: wait_until with timeout in automation - trigger 5 times rapidly + test_service = next( + (s for s in services if s.name == "test_wait_timeout"), None + ) + assert test_service is not None, "test_wait_timeout service not found" + client.execute_service(test_service, {}) + await asyncio.wait_for(test3_complete, timeout=3.0) + + # Verify Test 3: All 5 triggers should timeout and complete + assert test_results["wait_until_timeout"] == 5, ( + f"Test 3: Expected 5 timeout completions, got {test_results['wait_until_timeout']}" + )