From 86a1b4cf694026ea77f2c37d406b55e84203b122 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jan 2026 19:51:11 -1000 Subject: [PATCH] [select][fan] Use StringRef for on_value/on_preset_set triggers to avoid heap allocation (#13324) --- esphome/codegen.py | 1 + esphome/components/fan/__init__.py | 4 +- esphome/components/fan/automation.h | 4 +- esphome/components/select/__init__.py | 4 +- esphome/components/select/automation.h | 4 +- esphome/core/string_ref.h | 64 ++++++++ esphome/cpp_types.py | 1 + .../fixtures/select_stringref_trigger.yaml | 85 +++++++++++ .../test_select_stringref_trigger.py | 143 ++++++++++++++++++ 9 files changed, 302 insertions(+), 8 deletions(-) create mode 100644 tests/integration/fixtures/select_stringref_trigger.yaml create mode 100644 tests/integration/test_select_stringref_trigger.py diff --git a/esphome/codegen.py b/esphome/codegen.py index 6d55c6023d..4a2a5975c6 100644 --- a/esphome/codegen.py +++ b/esphome/codegen.py @@ -69,6 +69,7 @@ from esphome.cpp_types import ( # noqa: F401 JsonObjectConst, Parented, PollingComponent, + StringRef, arduino_json_ns, bool_, const_char_ptr, diff --git a/esphome/components/fan/__init__.py b/esphome/components/fan/__init__.py index 35a351e8f1..6010aa8ed4 100644 --- a/esphome/components/fan/__init__.py +++ b/esphome/components/fan/__init__.py @@ -77,7 +77,7 @@ FanSpeedSetTrigger = fan_ns.class_( "FanSpeedSetTrigger", automation.Trigger.template(cg.int_) ) FanPresetSetTrigger = fan_ns.class_( - "FanPresetSetTrigger", automation.Trigger.template(cg.std_string) + "FanPresetSetTrigger", automation.Trigger.template(cg.StringRef) ) FanIsOnCondition = fan_ns.class_("FanIsOnCondition", automation.Condition.template()) @@ -287,7 +287,7 @@ async def setup_fan_core_(var, config): await automation.build_automation(trigger, [(cg.int_, "x")], conf) for conf in config.get(CONF_ON_PRESET_SET, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) - await automation.build_automation(trigger, [(cg.std_string, "x")], conf) + await automation.build_automation(trigger, [(cg.StringRef, "x")], conf) async def register_fan(var, config): diff --git a/esphome/components/fan/automation.h b/esphome/components/fan/automation.h index 77abc2f13f..3c3b0ce519 100644 --- a/esphome/components/fan/automation.h +++ b/esphome/components/fan/automation.h @@ -208,7 +208,7 @@ class FanSpeedSetTrigger : public Trigger { int last_speed_; }; -class FanPresetSetTrigger : public Trigger { +class FanPresetSetTrigger : public Trigger { public: FanPresetSetTrigger(Fan *state) { state->add_on_state_callback([this, state]() { @@ -216,7 +216,7 @@ class FanPresetSetTrigger : public Trigger { auto should_trigger = preset_mode != this->last_preset_mode_; this->last_preset_mode_ = preset_mode; if (should_trigger) { - this->trigger(std::string(preset_mode)); + this->trigger(preset_mode); } }); this->last_preset_mode_ = state->get_preset_mode(); diff --git a/esphome/components/select/__init__.py b/esphome/components/select/__init__.py index c51131a292..84ad591ba1 100644 --- a/esphome/components/select/__init__.py +++ b/esphome/components/select/__init__.py @@ -33,7 +33,7 @@ SelectPtr = Select.operator("ptr") # Triggers SelectStateTrigger = select_ns.class_( "SelectStateTrigger", - automation.Trigger.template(cg.std_string, cg.size_t), + automation.Trigger.template(cg.StringRef, cg.size_t), ) # Actions @@ -100,7 +100,7 @@ async def setup_select_core_(var, config, *, options: list[str]): for conf in config.get(CONF_ON_VALUE, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation( - trigger, [(cg.std_string, "x"), (cg.size_t, "i")], conf + trigger, [(cg.StringRef, "x"), (cg.size_t, "i")], conf ) if (mqtt_id := config.get(CONF_MQTT_ID)) is not None: diff --git a/esphome/components/select/automation.h b/esphome/components/select/automation.h index 81e8a3561d..ffdabd5f7c 100644 --- a/esphome/components/select/automation.h +++ b/esphome/components/select/automation.h @@ -6,11 +6,11 @@ namespace esphome::select { -class SelectStateTrigger : public Trigger { +class SelectStateTrigger : public Trigger { public: explicit SelectStateTrigger(Select *parent) : parent_(parent) { parent->add_on_state_callback( - [this](size_t index) { this->trigger(std::string(this->parent_->option_at(index)), index); }); + [this](size_t index) { this->trigger(StringRef(this->parent_->option_at(index)), index); }); } protected: diff --git a/esphome/core/string_ref.h b/esphome/core/string_ref.h index 44ca79c81b..d502c4d27f 100644 --- a/esphome/core/string_ref.h +++ b/esphome/core/string_ref.h @@ -72,6 +72,7 @@ class StringRef { constexpr const char *c_str() const { return base_; } constexpr size_type size() const { return len_; } + constexpr size_type length() const { return len_; } constexpr bool empty() const { return len_ == 0; } constexpr const_reference operator[](size_type pos) const { return *(base_ + pos); } @@ -80,6 +81,32 @@ class StringRef { operator std::string() const { return str(); } + /// Find first occurrence of substring, returns std::string::npos if not found. + /// Note: Requires the underlying string to be null-terminated. + size_type find(const char *s, size_type pos = 0) const { + if (pos >= len_) + return std::string::npos; + const char *result = std::strstr(base_ + pos, s); + // Verify entire match is within bounds (strstr searches to null terminator) + if (result && result + std::strlen(s) <= base_ + len_) + return static_cast(result - base_); + return std::string::npos; + } + size_type find(char c, size_type pos = 0) const { + if (pos >= len_) + return std::string::npos; + const void *result = std::memchr(base_ + pos, static_cast(c), len_ - pos); + return result ? static_cast(static_cast(result) - base_) : std::string::npos; + } + + /// Return substring as std::string + std::string substr(size_type pos = 0, size_type count = std::string::npos) const { + if (pos >= len_) + return std::string(); + size_type actual_count = (count == std::string::npos || pos + count > len_) ? len_ - pos : count; + return std::string(base_ + pos, actual_count); + } + private: const char *base_; size_type len_; @@ -160,6 +187,43 @@ inline std::string operator+(const std::string &lhs, const StringRef &rhs) { str.append(rhs.c_str(), rhs.size()); return str; } +// String conversion functions for ADL compatibility (allows stoi(x) where x is StringRef) +// Must be in esphome namespace for ADL to find them. Uses strtol/strtod directly to avoid heap allocation. +namespace internal { +// NOLINTBEGIN(google-runtime-int) +template inline R parse_number(const StringRef &str, size_t *pos, F conv) { + char *end; + R result = conv(str.c_str(), &end); + // Set pos to 0 on conversion failure (when no characters consumed), otherwise index after number + if (pos) + *pos = (end == str.c_str()) ? 0 : static_cast(end - str.c_str()); + return result; +} +template inline R parse_number(const StringRef &str, size_t *pos, int base, F conv) { + char *end; + R result = conv(str.c_str(), &end, base); + // Set pos to 0 on conversion failure (when no characters consumed), otherwise index after number + if (pos) + *pos = (end == str.c_str()) ? 0 : static_cast(end - str.c_str()); + return result; +} +// NOLINTEND(google-runtime-int) +} // namespace internal +// NOLINTBEGIN(readability-identifier-naming,google-runtime-int) +inline int stoi(const StringRef &str, size_t *pos = nullptr, int base = 10) { + return static_cast(internal::parse_number(str, pos, base, std::strtol)); +} +inline long stol(const StringRef &str, size_t *pos = nullptr, int base = 10) { + return internal::parse_number(str, pos, base, std::strtol); +} +inline float stof(const StringRef &str, size_t *pos = nullptr) { + return internal::parse_number(str, pos, std::strtof); +} +inline double stod(const StringRef &str, size_t *pos = nullptr) { + return internal::parse_number(str, pos, std::strtod); +} +// NOLINTEND(readability-identifier-naming,google-runtime-int) + #ifdef USE_JSON // NOLINTNEXTLINE(readability-identifier-naming) inline void convertToJson(const StringRef &src, JsonVariant dst) { dst.set(src.c_str()); } diff --git a/esphome/cpp_types.py b/esphome/cpp_types.py index 0d1813f63b..7001c38857 100644 --- a/esphome/cpp_types.py +++ b/esphome/cpp_types.py @@ -44,3 +44,4 @@ gpio_Flags = gpio_ns.enum("Flags", is_class=True) EntityCategory = esphome_ns.enum("EntityCategory") Parented = esphome_ns.class_("Parented") ESPTime = esphome_ns.struct("ESPTime") +StringRef = esphome_ns.class_("StringRef") diff --git a/tests/integration/fixtures/select_stringref_trigger.yaml b/tests/integration/fixtures/select_stringref_trigger.yaml new file mode 100644 index 0000000000..bb1e1fd843 --- /dev/null +++ b/tests/integration/fixtures/select_stringref_trigger.yaml @@ -0,0 +1,85 @@ +esphome: + name: select-stringref-test + friendly_name: Select StringRef Test + +host: + +logger: + level: DEBUG + +api: + +select: + - platform: template + name: "Test Select" + id: test_select + optimistic: true + options: + - "Option A" + - "Option B" + - "Option C" + initial_option: "Option A" + on_value: + then: + # Test 1: Log the value directly (StringRef -> const char* via c_str()) + - logger.log: + format: "Select value: %s" + args: ['x.c_str()'] + # Test 2: String concatenation (StringRef + const char* -> std::string) + - lambda: |- + std::string with_suffix = x + " selected"; + ESP_LOGI("test", "Concatenated: %s", with_suffix.c_str()); + # Test 3: Comparison (StringRef == const char*) + - lambda: |- + if (x == "Option B") { + ESP_LOGI("test", "Option B was selected"); + } + # Test 4: Use index parameter (variable name is 'i') + - lambda: |- + ESP_LOGI("test", "Select index: %d", (int)i); + # Test 5: StringRef.length() method + - lambda: |- + ESP_LOGI("test", "Length: %d", (int)x.length()); + # Test 6: StringRef.find() method with substring + - lambda: |- + if (x.find("Option") != std::string::npos) { + ESP_LOGI("test", "Found 'Option' in value"); + } + # Test 7: StringRef.find() method with character + - lambda: |- + size_t space_pos = x.find(' '); + if (space_pos != std::string::npos) { + ESP_LOGI("test", "Space at position: %d", (int)space_pos); + } + # Test 8: StringRef.substr() method + - lambda: |- + std::string prefix = x.substr(0, 6); + ESP_LOGI("test", "Substr prefix: %s", prefix.c_str()); + + # Second select with numeric options to test ADL functions + - platform: template + name: "Baud Rate" + id: baud_select + optimistic: true + options: + - "9600" + - "115200" + initial_option: "9600" + on_value: + then: + # Test 9: stoi via ADL + - lambda: |- + int baud = stoi(x); + ESP_LOGI("test", "stoi result: %d", baud); + # Test 10: stol via ADL + - lambda: |- + long baud_long = stol(x); + ESP_LOGI("test", "stol result: %ld", baud_long); + # Test 11: stof via ADL + - lambda: |- + float baud_float = stof(x); + ESP_LOGI("test", "stof result: %.0f", baud_float); + # Test 12: stod via ADL + - lambda: |- + double baud_double = stod(x); + ESP_LOGI("test", "stod result: %.0f", baud_double); diff --git a/tests/integration/test_select_stringref_trigger.py b/tests/integration/test_select_stringref_trigger.py new file mode 100644 index 0000000000..7fc72a2290 --- /dev/null +++ b/tests/integration/test_select_stringref_trigger.py @@ -0,0 +1,143 @@ +"""Integration test for select on_value trigger with StringRef parameter.""" + +from __future__ import annotations + +import asyncio +import re + +import pytest + +from .types import APIClientConnectedFactory, RunCompiledFunction + + +@pytest.mark.asyncio +async def test_select_stringref_trigger( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected: APIClientConnectedFactory, +) -> None: + """Test select on_value trigger passes StringRef that works with string operations.""" + loop = asyncio.get_running_loop() + + # Track log messages to verify StringRef operations work + value_logged_future = loop.create_future() + concatenated_future = loop.create_future() + comparison_future = loop.create_future() + index_logged_future = loop.create_future() + length_future = loop.create_future() + find_substr_future = loop.create_future() + find_char_future = loop.create_future() + substr_future = loop.create_future() + # ADL functions + stoi_future = loop.create_future() + stol_future = loop.create_future() + stof_future = loop.create_future() + stod_future = loop.create_future() + + # Patterns to match in logs + value_pattern = re.compile(r"Select value: Option B") + concatenated_pattern = re.compile(r"Concatenated: Option B selected") + comparison_pattern = re.compile(r"Option B was selected") + index_pattern = re.compile(r"Select index: 1") + length_pattern = re.compile(r"Length: 8") # "Option B" is 8 chars + find_substr_pattern = re.compile(r"Found 'Option' in value") + find_char_pattern = re.compile(r"Space at position: 6") # space at index 6 + substr_pattern = re.compile(r"Substr prefix: Option") + # ADL function patterns (115200 from baud rate select) + stoi_pattern = re.compile(r"stoi result: 115200") + stol_pattern = re.compile(r"stol result: 115200") + stof_pattern = re.compile(r"stof result: 115200") + stod_pattern = re.compile(r"stod result: 115200") + + def check_output(line: str) -> None: + """Check log output for expected messages.""" + if not value_logged_future.done() and value_pattern.search(line): + value_logged_future.set_result(True) + if not concatenated_future.done() and concatenated_pattern.search(line): + concatenated_future.set_result(True) + if not comparison_future.done() and comparison_pattern.search(line): + comparison_future.set_result(True) + if not index_logged_future.done() and index_pattern.search(line): + index_logged_future.set_result(True) + if not length_future.done() and length_pattern.search(line): + length_future.set_result(True) + if not find_substr_future.done() and find_substr_pattern.search(line): + find_substr_future.set_result(True) + if not find_char_future.done() and find_char_pattern.search(line): + find_char_future.set_result(True) + if not substr_future.done() and substr_pattern.search(line): + substr_future.set_result(True) + # ADL functions + if not stoi_future.done() and stoi_pattern.search(line): + stoi_future.set_result(True) + if not stol_future.done() and stol_pattern.search(line): + stol_future.set_result(True) + if not stof_future.done() and stof_pattern.search(line): + stof_future.set_result(True) + if not stod_future.done() and stod_pattern.search(line): + stod_future.set_result(True) + + async with ( + run_compiled(yaml_config, line_callback=check_output), + api_client_connected() as client, + ): + # Verify device info + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "select-stringref-test" + + # List entities to find our select + entities, _ = await client.list_entities_services() + + select_entity = next( + (e for e in entities if hasattr(e, "options") and e.name == "Test Select"), + None, + ) + assert select_entity is not None, "Test Select entity not found" + + baud_entity = next( + (e for e in entities if hasattr(e, "options") and e.name == "Baud Rate"), + None, + ) + assert baud_entity is not None, "Baud Rate entity not found" + + # Change select to Option B - this should trigger on_value with StringRef + client.select_command(select_entity.key, "Option B") + # Change baud to 115200 - this tests ADL functions (stoi, stol, stof, stod) + client.select_command(baud_entity.key, "115200") + + # Wait for all log messages confirming StringRef operations work + try: + await asyncio.wait_for( + asyncio.gather( + value_logged_future, + concatenated_future, + comparison_future, + index_logged_future, + length_future, + find_substr_future, + find_char_future, + substr_future, + stoi_future, + stol_future, + stof_future, + stod_future, + ), + timeout=5.0, + ) + except TimeoutError: + results = { + "value_logged": value_logged_future.done(), + "concatenated": concatenated_future.done(), + "comparison": comparison_future.done(), + "index_logged": index_logged_future.done(), + "length": length_future.done(), + "find_substr": find_substr_future.done(), + "find_char": find_char_future.done(), + "substr": substr_future.done(), + "stoi": stoi_future.done(), + "stol": stol_future.done(), + "stof": stof_future.done(), + "stod": stod_future.done(), + } + pytest.fail(f"StringRef operations failed - received: {results}")