diff --git a/esphome/components/api/custom_api_device.h b/esphome/components/api/custom_api_device.h index d34ccfa0ce..43ea644f0c 100644 --- a/esphome/components/api/custom_api_device.h +++ b/esphome/components/api/custom_api_device.h @@ -9,11 +9,11 @@ namespace esphome::api { #ifdef USE_API_SERVICES -template class CustomAPIDeviceService : public UserServiceBase { +template class CustomAPIDeviceService : public UserServiceDynamic { public: CustomAPIDeviceService(const std::string &name, const std::array &arg_names, T *obj, void (T::*callback)(Ts...)) - : UserServiceBase(name, arg_names), obj_(obj), callback_(callback) {} + : UserServiceDynamic(name, arg_names), obj_(obj), callback_(callback) {} protected: void execute(Ts... x) override { (this->obj_->*this->callback_)(x...); } // NOLINT diff --git a/esphome/components/api/user_services.h b/esphome/components/api/user_services.h index 9ca5e1093e..2a887fc52d 100644 --- a/esphome/components/api/user_services.h +++ b/esphome/components/api/user_services.h @@ -23,11 +23,13 @@ template T get_execute_arg_value(const ExecuteServiceArgument &arg); template enums::ServiceArgType to_service_arg_type(); +// Base class for YAML-defined services (most common case) +// Stores only pointers to string literals in flash - no heap allocation template class UserServiceBase : public UserServiceDescriptor { public: - UserServiceBase(std::string name, const std::array &arg_names) - : name_(std::move(name)), arg_names_(arg_names) { - this->key_ = fnv1_hash(this->name_); + UserServiceBase(const char *name, const std::array &arg_names) + : name_(name), arg_names_(arg_names) { + this->key_ = fnv1_hash(name); } ListEntitiesServicesResponse encode_list_service_response() override { @@ -47,7 +49,7 @@ template class UserServiceBase : public UserServiceDescriptor { bool execute_service(const ExecuteServiceRequest &req) override { if (req.key != this->key_) return false; - if (req.args.size() != this->arg_names_.size()) + if (req.args.size() != sizeof...(Ts)) return false; this->execute_(req.args, typename gens::type()); return true; @@ -59,14 +61,60 @@ template class UserServiceBase : public UserServiceDescriptor { this->execute((get_execute_arg_value(args[S]))...); } - std::string name_; + // Pointers to string literals in flash - no heap allocation + const char *name_; + std::array arg_names_; uint32_t key_{0}; +}; + +// Separate class for custom_api_device services (rare case) +// Stores copies of runtime-generated names +template class UserServiceDynamic : public UserServiceDescriptor { + public: + UserServiceDynamic(std::string name, const std::array &arg_names) + : name_(std::move(name)), arg_names_(arg_names) { + this->key_ = fnv1_hash(this->name_.c_str()); + } + + ListEntitiesServicesResponse encode_list_service_response() override { + ListEntitiesServicesResponse msg; + msg.set_name(StringRef(this->name_)); + msg.key = this->key_; + std::array arg_types = {to_service_arg_type()...}; + msg.args.init(sizeof...(Ts)); + for (size_t i = 0; i < sizeof...(Ts); i++) { + auto &arg = msg.args.emplace_back(); + arg.type = arg_types[i]; + arg.set_name(StringRef(this->arg_names_[i])); + } + return msg; + } + + bool execute_service(const ExecuteServiceRequest &req) override { + if (req.key != this->key_) + return false; + if (req.args.size() != sizeof...(Ts)) + return false; + this->execute_(req.args, typename gens::type()); + return true; + } + + protected: + virtual void execute(Ts... x) = 0; + template void execute_(const ArgsContainer &args, seq type) { + this->execute((get_execute_arg_value(args[S]))...); + } + + // Heap-allocated strings for runtime-generated names + std::string name_; std::array arg_names_; + uint32_t key_{0}; }; template class UserServiceTrigger : public UserServiceBase, public Trigger { public: - UserServiceTrigger(const std::string &name, const std::array &arg_names) + // Constructor for static names (YAML-defined services - used by code generator) + UserServiceTrigger(const char *name, const std::array &arg_names) : UserServiceBase(name, arg_names) {} protected: diff --git a/tests/integration/fixtures/api_custom_services.yaml b/tests/integration/fixtures/api_custom_services.yaml index 41efc95b85..a597c74126 100644 --- a/tests/integration/fixtures/api_custom_services.yaml +++ b/tests/integration/fixtures/api_custom_services.yaml @@ -11,6 +11,28 @@ api: then: - logger.log: "YAML service called" + # Test YAML service with arguments (tests UserServiceBase with const char* array) + - action: test_yaml_service_with_args + variables: + my_int: int + my_string: string + then: + - logger.log: + format: "YAML service with args: %d, %s" + args: [my_int, my_string.c_str()] + + # Test YAML service with multiple arguments + - action: test_yaml_service_many_args + variables: + arg1: int + arg2: float + arg3: bool + arg4: string + then: + - logger.log: + format: "YAML service many args: %d, %.2f, %d, %s" + args: [arg1, arg2, arg3, arg4.c_str()] + logger: level: DEBUG diff --git a/tests/integration/test_api_custom_services.py b/tests/integration/test_api_custom_services.py index 9ae4cdcb5d..967c504112 100644 --- a/tests/integration/test_api_custom_services.py +++ b/tests/integration/test_api_custom_services.py @@ -33,12 +33,16 @@ async def test_api_custom_services( # Track log messages yaml_service_future = loop.create_future() + yaml_args_future = loop.create_future() + yaml_many_args_future = loop.create_future() custom_service_future = loop.create_future() custom_args_future = loop.create_future() custom_arrays_future = loop.create_future() # Patterns to match in logs yaml_service_pattern = re.compile(r"YAML service called") + yaml_args_pattern = re.compile(r"YAML service with args: 123, test_value") + yaml_many_args_pattern = re.compile(r"YAML service many args: 42, 3\.14, 1, hello") custom_service_pattern = re.compile(r"Custom test service called!") custom_args_pattern = re.compile( r"Custom service called with: test_string, 456, 1, 78\.90" @@ -51,6 +55,10 @@ async def test_api_custom_services( """Check log output for expected messages.""" if not yaml_service_future.done() and yaml_service_pattern.search(line): yaml_service_future.set_result(True) + elif not yaml_args_future.done() and yaml_args_pattern.search(line): + yaml_args_future.set_result(True) + elif not yaml_many_args_future.done() and yaml_many_args_pattern.search(line): + yaml_many_args_future.set_result(True) elif not custom_service_future.done() and custom_service_pattern.search(line): custom_service_future.set_result(True) elif not custom_args_future.done() and custom_args_pattern.search(line): @@ -71,11 +79,13 @@ async def test_api_custom_services( # List services _, services = await client.list_entities_services() - # Should have 4 services: 1 YAML + 3 CustomAPIDevice - assert len(services) == 4, f"Expected 4 services, found {len(services)}" + # Should have 6 services: 3 YAML + 3 CustomAPIDevice + assert len(services) == 6, f"Expected 6 services, found {len(services)}" # Find our services yaml_service: UserService | None = None + yaml_args_service: UserService | None = None + yaml_many_args_service: UserService | None = None custom_service: UserService | None = None custom_args_service: UserService | None = None custom_arrays_service: UserService | None = None @@ -83,6 +93,10 @@ async def test_api_custom_services( for service in services: if service.name == "test_yaml_service": yaml_service = service + elif service.name == "test_yaml_service_with_args": + yaml_args_service = service + elif service.name == "test_yaml_service_many_args": + yaml_many_args_service = service elif service.name == "custom_test_service": custom_service = service elif service.name == "custom_service_with_args": @@ -91,6 +105,10 @@ async def test_api_custom_services( custom_arrays_service = service assert yaml_service is not None, "test_yaml_service not found" + assert yaml_args_service is not None, "test_yaml_service_with_args not found" + assert yaml_many_args_service is not None, ( + "test_yaml_service_many_args not found" + ) assert custom_service is not None, "custom_test_service not found" assert custom_args_service is not None, "custom_service_with_args not found" assert custom_arrays_service is not None, "custom_service_with_arrays not found" @@ -99,6 +117,44 @@ async def test_api_custom_services( client.execute_service(yaml_service, {}) await asyncio.wait_for(yaml_service_future, timeout=5.0) + # Verify YAML service with args arguments + assert len(yaml_args_service.args) == 2 + yaml_args_types = {arg.name: arg.type for arg in yaml_args_service.args} + assert yaml_args_types["my_int"] == UserServiceArgType.INT + assert yaml_args_types["my_string"] == UserServiceArgType.STRING + + # Test YAML service with arguments + client.execute_service( + yaml_args_service, + { + "my_int": 123, + "my_string": "test_value", + }, + ) + await asyncio.wait_for(yaml_args_future, timeout=5.0) + + # Verify YAML service with many args arguments + assert len(yaml_many_args_service.args) == 4 + yaml_many_args_types = { + arg.name: arg.type for arg in yaml_many_args_service.args + } + assert yaml_many_args_types["arg1"] == UserServiceArgType.INT + assert yaml_many_args_types["arg2"] == UserServiceArgType.FLOAT + assert yaml_many_args_types["arg3"] == UserServiceArgType.BOOL + assert yaml_many_args_types["arg4"] == UserServiceArgType.STRING + + # Test YAML service with many arguments + client.execute_service( + yaml_many_args_service, + { + "arg1": 42, + "arg2": 3.14, + "arg3": True, + "arg4": "hello", + }, + ) + await asyncio.wait_for(yaml_many_args_future, timeout=5.0) + # Test simple CustomAPIDevice service client.execute_service(custom_service, {}) await asyncio.wait_for(custom_service_future, timeout=5.0)