mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-29 22:24:26 +00:00 
			
		
		
		
	Add native API user-defined services (#453)
This commit is contained in:
		| @@ -1,13 +1,15 @@ | ||||
| import voluptuous as vol | ||||
|  | ||||
| from esphome import automation | ||||
| from esphome.automation import ACTION_REGISTRY | ||||
| import esphome.config_validation as cv | ||||
| from esphome.const import CONF_DATA, CONF_DATA_TEMPLATE, CONF_ID, CONF_PASSWORD, CONF_PORT, \ | ||||
|     CONF_REBOOT_TIMEOUT, CONF_SERVICE, CONF_VARIABLES | ||||
|     CONF_REBOOT_TIMEOUT, CONF_SERVICE, CONF_VARIABLES, CONF_SERVICES, CONF_TRIGGER_ID | ||||
| from esphome.core import CORE | ||||
| from esphome.cpp_generator import Pvariable, add, get_variable, process_lambda | ||||
| from esphome.cpp_helpers import setup_component | ||||
| from esphome.cpp_types import Action, App, Component, StoringController, esphome_ns | ||||
| from esphome.cpp_types import Action, App, Component, StoringController, esphome_ns, Trigger, bool_, \ | ||||
|     int32, float_, std_string | ||||
|  | ||||
| api_ns = esphome_ns.namespace('api') | ||||
| APIServer = api_ns.class_('APIServer', Component, StoringController) | ||||
| @@ -15,11 +17,35 @@ HomeAssistantServiceCallAction = api_ns.class_('HomeAssistantServiceCallAction', | ||||
| KeyValuePair = api_ns.class_('KeyValuePair') | ||||
| TemplatableKeyValuePair = api_ns.class_('TemplatableKeyValuePair') | ||||
|  | ||||
| UserService = api_ns.class_('UserService', Trigger) | ||||
| ServiceTypeArgument = api_ns.class_('ServiceTypeArgument') | ||||
| ServiceArgType = api_ns.enum('ServiceArgType') | ||||
| SERVICE_ARG_TYPES = { | ||||
|     'bool': ServiceArgType.SERVICE_ARG_TYPE_BOOL, | ||||
|     'int': ServiceArgType.SERVICE_ARG_TYPE_INT, | ||||
|     'float': ServiceArgType.SERVICE_ARG_TYPE_FLOAT, | ||||
|     'string': ServiceArgType.SERVICE_ARG_TYPE_STRING, | ||||
| } | ||||
| SERVICE_ARG_NATIVE_TYPES = { | ||||
|     'bool': bool_, | ||||
|     'int': int32, | ||||
|     'float': float_, | ||||
|     'string': std_string, | ||||
| } | ||||
|  | ||||
|  | ||||
| CONFIG_SCHEMA = cv.Schema({ | ||||
|     cv.GenerateID(): cv.declare_variable_id(APIServer), | ||||
|     vol.Optional(CONF_PORT, default=6053): cv.port, | ||||
|     vol.Optional(CONF_PASSWORD, default=''): cv.string_strict, | ||||
|     vol.Optional(CONF_REBOOT_TIMEOUT): cv.positive_time_period_milliseconds, | ||||
|     vol.Optional(CONF_SERVICES): automation.validate_automation({ | ||||
|         cv.GenerateID(CONF_TRIGGER_ID): cv.declare_variable_id(UserService), | ||||
|         vol.Required(CONF_SERVICE): cv.valid_name, | ||||
|         vol.Optional(CONF_VARIABLES, default={}): cv.Schema({ | ||||
|             cv.validate_id_name: cv.one_of(*SERVICE_ARG_TYPES, lower=True), | ||||
|         }), | ||||
|     }), | ||||
| }).extend(cv.COMPONENT_SCHEMA.schema) | ||||
|  | ||||
|  | ||||
| @@ -34,6 +60,21 @@ def to_code(config): | ||||
|     if CONF_REBOOT_TIMEOUT in config: | ||||
|         add(api.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT])) | ||||
|  | ||||
|     for conf in config.get(CONF_SERVICES, []): | ||||
|         template_args = [] | ||||
|         func_args = [] | ||||
|         service_type_args = [] | ||||
|         for name, var_ in conf[CONF_VARIABLES].items(): | ||||
|             native = SERVICE_ARG_NATIVE_TYPES[var_] | ||||
|             template_args.append(native) | ||||
|             func_args.append((native, name)) | ||||
|             service_type_args.append(ServiceTypeArgument(name, SERVICE_ARG_TYPES[var_])) | ||||
|         func = api.make_user_service_trigger.template(*template_args) | ||||
|         rhs = func(conf[CONF_SERVICE], service_type_args) | ||||
|         type_ = UserService.template(*template_args) | ||||
|         trigger = Pvariable(conf[CONF_TRIGGER_ID], rhs, type=type_) | ||||
|         automation.build_automations(trigger, func_args, conf) | ||||
|  | ||||
|     setup_component(api, config) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -162,7 +162,7 @@ def int_(value): | ||||
| hex_int = vol.Coerce(hex_int_) | ||||
|  | ||||
|  | ||||
| def variable_id_str_(value): | ||||
| def validate_id_name(value): | ||||
|     value = string(value) | ||||
|     if not value: | ||||
|         raise vol.Invalid("ID must not be empty") | ||||
| @@ -185,7 +185,7 @@ def use_variable_id(type): | ||||
|         if value is None: | ||||
|             return core.ID(None, is_declaration=False, type=type) | ||||
|  | ||||
|         return core.ID(variable_id_str_(value), is_declaration=False, type=type) | ||||
|         return core.ID(validate_id_name(value), is_declaration=False, type=type) | ||||
|  | ||||
|     return validator | ||||
|  | ||||
| @@ -195,7 +195,7 @@ def declare_variable_id(type): | ||||
|         if value is None: | ||||
|             return core.ID(None, is_declaration=True, type=type) | ||||
|  | ||||
|         return core.ID(variable_id_str_(value), is_declaration=True, type=type) | ||||
|         return core.ID(validate_id_name(value), is_declaration=True, type=type) | ||||
|  | ||||
|     return validator | ||||
|  | ||||
|   | ||||
| @@ -23,6 +23,7 @@ CONF_ARDUINO_VERSION = 'arduino_version' | ||||
| CONF_LOCAL = 'local' | ||||
| CONF_REPOSITORY = 'repository' | ||||
| CONF_COMMIT = 'commit' | ||||
| CONF_SERVICES = 'services' | ||||
| CONF_TAG = 'tag' | ||||
| CONF_BRANCH = 'branch' | ||||
| CONF_LOGGER = 'logger' | ||||
|   | ||||
		Reference in New Issue
	
	Block a user