mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-21 19:23:45 +01:00 
			
		
		
		
	Add on_client_connected and disconnected to voice assistant (#5629)
This commit is contained in:
		| @@ -18,6 +18,8 @@ from esphome.const import ( | ||||
|     CONF_TRIGGER_ID, | ||||
|     CONF_EVENT, | ||||
|     CONF_TAG, | ||||
|     CONF_ON_CLIENT_CONNECTED, | ||||
|     CONF_ON_CLIENT_DISCONNECTED, | ||||
| ) | ||||
| from esphome.core import coroutine_with_priority | ||||
|  | ||||
| @@ -45,8 +47,6 @@ SERVICE_ARG_NATIVE_TYPES = { | ||||
|     "string[]": cg.std_vector.template(cg.std_string), | ||||
| } | ||||
| CONF_ENCRYPTION = "encryption" | ||||
| CONF_ON_CLIENT_CONNECTED = "on_client_connected" | ||||
| CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected" | ||||
|  | ||||
|  | ||||
| def validate_encryption_key(value): | ||||
|   | ||||
| @@ -60,6 +60,11 @@ APIConnection::~APIConnection() { | ||||
|     bluetooth_proxy::global_bluetooth_proxy->unsubscribe_api_connection(this); | ||||
|   } | ||||
| #endif | ||||
| #ifdef USE_VOICE_ASSISTANT | ||||
|   if (voice_assistant::global_voice_assistant->get_api_connection() == this) { | ||||
|     voice_assistant::global_voice_assistant->client_subscription(this, false); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| void APIConnection::loop() { | ||||
| @@ -950,14 +955,17 @@ BluetoothConnectionsFreeResponse APIConnection::subscribe_bluetooth_connections_ | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_VOICE_ASSISTANT | ||||
| bool APIConnection::request_voice_assistant(const VoiceAssistantRequest &msg) { | ||||
|   if (!this->voice_assistant_subscription_) | ||||
|     return false; | ||||
|  | ||||
|   return this->send_voice_assistant_request(msg); | ||||
| void APIConnection::subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) { | ||||
|   if (voice_assistant::global_voice_assistant != nullptr) { | ||||
|     voice_assistant::global_voice_assistant->client_subscription(this, msg.subscribe); | ||||
|   } | ||||
| } | ||||
| void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &msg) { | ||||
|   if (voice_assistant::global_voice_assistant != nullptr) { | ||||
|     if (voice_assistant::global_voice_assistant->get_api_connection() != this) { | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     if (msg.error) { | ||||
|       voice_assistant::global_voice_assistant->failed_to_start(); | ||||
|       return; | ||||
| @@ -970,6 +978,10 @@ void APIConnection::on_voice_assistant_response(const VoiceAssistantResponse &ms | ||||
| }; | ||||
| void APIConnection::on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) { | ||||
|   if (voice_assistant::global_voice_assistant != nullptr) { | ||||
|     if (voice_assistant::global_voice_assistant->get_api_connection() != this) { | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     voice_assistant::global_voice_assistant->on_event(msg); | ||||
|   } | ||||
| } | ||||
|   | ||||
| @@ -126,10 +126,7 @@ class APIConnection : public APIServerConnection { | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_VOICE_ASSISTANT | ||||
|   void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override { | ||||
|     this->voice_assistant_subscription_ = msg.subscribe; | ||||
|   } | ||||
|   bool request_voice_assistant(const VoiceAssistantRequest &msg); | ||||
|   void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) override; | ||||
|   void on_voice_assistant_response(const VoiceAssistantResponse &msg) override; | ||||
|   void on_voice_assistant_event_response(const VoiceAssistantEventResponse &msg) override; | ||||
| #endif | ||||
| @@ -188,6 +185,8 @@ class APIConnection : public APIServerConnection { | ||||
|   } | ||||
|   bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override; | ||||
|  | ||||
|   std::string get_client_combined_info() const { return this->client_combined_info_; } | ||||
|  | ||||
|  protected: | ||||
|   friend APIServer; | ||||
|  | ||||
| @@ -220,9 +219,6 @@ class APIConnection : public APIServerConnection { | ||||
|   uint32_t last_traffic_; | ||||
|   bool sent_ping_{false}; | ||||
|   bool service_call_subscription_{false}; | ||||
| #ifdef USE_VOICE_ASSISTANT | ||||
|   bool voice_assistant_subscription_{false}; | ||||
| #endif | ||||
|   bool next_close_ = false; | ||||
|   APIServer *parent_; | ||||
|   InitialStateIterator initial_state_iterator_; | ||||
|   | ||||
| @@ -332,30 +332,6 @@ void APIServer::on_shutdown() { | ||||
|   delay(10); | ||||
| } | ||||
|  | ||||
| #ifdef USE_VOICE_ASSISTANT | ||||
| bool APIServer::start_voice_assistant(const std::string &conversation_id, uint32_t flags, | ||||
|                                       const api::VoiceAssistantAudioSettings &audio_settings) { | ||||
|   VoiceAssistantRequest msg; | ||||
|   msg.start = true; | ||||
|   msg.conversation_id = conversation_id; | ||||
|   msg.flags = flags; | ||||
|   msg.audio_settings = audio_settings; | ||||
|   for (auto &c : this->clients_) { | ||||
|     if (c->request_voice_assistant(msg)) | ||||
|       return true; | ||||
|   } | ||||
|   return false; | ||||
| } | ||||
| void APIServer::stop_voice_assistant() { | ||||
|   VoiceAssistantRequest msg; | ||||
|   msg.start = false; | ||||
|   for (auto &c : this->clients_) { | ||||
|     if (c->request_voice_assistant(msg)) | ||||
|       return; | ||||
|   } | ||||
| } | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_ALARM_CONTROL_PANEL | ||||
| void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) { | ||||
|   if (obj->is_internal()) | ||||
|   | ||||
| @@ -84,12 +84,6 @@ class APIServer : public Component, public Controller { | ||||
|   void request_time(); | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_VOICE_ASSISTANT | ||||
|   bool start_voice_assistant(const std::string &conversation_id, uint32_t flags, | ||||
|                              const api::VoiceAssistantAudioSettings &audio_settings); | ||||
|   void stop_voice_assistant(); | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_ALARM_CONTROL_PANEL | ||||
|   void on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) override; | ||||
| #endif | ||||
|   | ||||
| @@ -6,6 +6,8 @@ from esphome.const import ( | ||||
|     CONF_MICROPHONE, | ||||
|     CONF_SPEAKER, | ||||
|     CONF_MEDIA_PLAYER, | ||||
|     CONF_ON_CLIENT_CONNECTED, | ||||
|     CONF_ON_CLIENT_DISCONNECTED, | ||||
| ) | ||||
| from esphome import automation | ||||
| from esphome.automation import register_action, register_condition | ||||
| @@ -80,6 +82,12 @@ CONFIG_SCHEMA = cv.All( | ||||
|             cv.Optional(CONF_ON_TTS_END): automation.validate_automation(single=True), | ||||
|             cv.Optional(CONF_ON_END): automation.validate_automation(single=True), | ||||
|             cv.Optional(CONF_ON_ERROR): automation.validate_automation(single=True), | ||||
|             cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation( | ||||
|                 single=True | ||||
|             ), | ||||
|             cv.Optional(CONF_ON_CLIENT_DISCONNECTED): automation.validate_automation( | ||||
|                 single=True | ||||
|             ), | ||||
|         } | ||||
|     ).extend(cv.COMPONENT_SCHEMA), | ||||
| ) | ||||
| @@ -155,6 +163,20 @@ async def to_code(config): | ||||
|             config[CONF_ON_ERROR], | ||||
|         ) | ||||
|  | ||||
|     if CONF_ON_CLIENT_CONNECTED in config: | ||||
|         await automation.build_automation( | ||||
|             var.get_client_connected_trigger(), | ||||
|             [], | ||||
|             config[CONF_ON_CLIENT_CONNECTED], | ||||
|         ) | ||||
|  | ||||
|     if CONF_ON_CLIENT_DISCONNECTED in config: | ||||
|         await automation.build_automation( | ||||
|             var.get_client_disconnected_trigger(), | ||||
|             [], | ||||
|             config[CONF_ON_CLIENT_DISCONNECTED], | ||||
|         ) | ||||
|  | ||||
|     cg.add_define("USE_VOICE_ASSISTANT") | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -127,8 +127,8 @@ int VoiceAssistant::read_microphone_() { | ||||
| } | ||||
|  | ||||
| void VoiceAssistant::loop() { | ||||
|   if (this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE && | ||||
|       this->state_ != State::STOPPING_MICROPHONE && !api::global_api_server->is_connected()) { | ||||
|   if (this->api_client_ == nullptr && this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE && | ||||
|       this->state_ != State::STOPPING_MICROPHONE) { | ||||
|     if (this->mic_->is_running() || this->state_ == State::STARTING_MICROPHONE) { | ||||
|       this->set_state_(State::STOP_MICROPHONE, State::IDLE); | ||||
|     } else { | ||||
| @@ -213,7 +213,14 @@ void VoiceAssistant::loop() { | ||||
|       audio_settings.noise_suppression_level = this->noise_suppression_level_; | ||||
|       audio_settings.auto_gain = this->auto_gain_; | ||||
|       audio_settings.volume_multiplier = this->volume_multiplier_; | ||||
|       if (!api::global_api_server->start_voice_assistant(this->conversation_id_, flags, audio_settings)) { | ||||
|  | ||||
|       api::VoiceAssistantRequest msg; | ||||
|       msg.start = true; | ||||
|       msg.conversation_id = this->conversation_id_; | ||||
|       msg.flags = flags; | ||||
|       msg.audio_settings = audio_settings; | ||||
|  | ||||
|       if (this->api_client_ == nullptr || !this->api_client_->send_voice_assistant_request(msg)) { | ||||
|         ESP_LOGW(TAG, "Could not request start."); | ||||
|         this->error_trigger_->trigger("not-connected", "Could not request start."); | ||||
|         this->continuous_ = false; | ||||
| @@ -326,6 +333,28 @@ void VoiceAssistant::loop() { | ||||
|   } | ||||
| } | ||||
|  | ||||
| void VoiceAssistant::client_subscription(api::APIConnection *client, bool subscribe) { | ||||
|   if (!subscribe) { | ||||
|     if (this->api_client_ == nullptr || client != this->api_client_) { | ||||
|       ESP_LOGE(TAG, "Client attempting to unsubscribe that is not the current API Client"); | ||||
|       return; | ||||
|     } | ||||
|     this->api_client_ = nullptr; | ||||
|     this->client_disconnected_trigger_->trigger(); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   if (this->api_client_ != nullptr) { | ||||
|     ESP_LOGE(TAG, "Multiple API Clients attempting to connect to Voice Assistant"); | ||||
|     ESP_LOGE(TAG, "Current client: %s", this->api_client_->get_client_combined_info().c_str()); | ||||
|     ESP_LOGE(TAG, "New client: %s", client->get_client_combined_info().c_str()); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   this->api_client_ = client; | ||||
|   this->client_connected_trigger_->trigger(); | ||||
| } | ||||
|  | ||||
| static const LogString *voice_assistant_state_to_string(State state) { | ||||
|   switch (state) { | ||||
|     case State::IDLE: | ||||
| @@ -408,7 +437,7 @@ void VoiceAssistant::start_streaming(struct sockaddr_storage *addr, uint16_t por | ||||
| } | ||||
|  | ||||
| void VoiceAssistant::request_start(bool continuous, bool silence_detection) { | ||||
|   if (!api::global_api_server->is_connected()) { | ||||
|   if (this->api_client_ == nullptr) { | ||||
|     ESP_LOGE(TAG, "No API client connected"); | ||||
|     this->set_state_(State::IDLE, State::IDLE); | ||||
|     this->continuous_ = false; | ||||
| @@ -459,9 +488,14 @@ void VoiceAssistant::request_stop() { | ||||
| } | ||||
|  | ||||
| void VoiceAssistant::signal_stop_() { | ||||
|   ESP_LOGD(TAG, "Signaling stop..."); | ||||
|   api::global_api_server->stop_voice_assistant(); | ||||
|   memset(&this->dest_addr_, 0, sizeof(this->dest_addr_)); | ||||
|   if (this->api_client_ == nullptr) { | ||||
|     return; | ||||
|   } | ||||
|   ESP_LOGD(TAG, "Signaling stop..."); | ||||
|   api::VoiceAssistantRequest msg; | ||||
|   msg.start = false; | ||||
|   this->api_client_->send_voice_assistant_request(msg); | ||||
| } | ||||
|  | ||||
| void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { | ||||
|   | ||||
| @@ -8,8 +8,8 @@ | ||||
| #include "esphome/core/component.h" | ||||
| #include "esphome/core/helpers.h" | ||||
|  | ||||
| #include "esphome/components/api/api_connection.h" | ||||
| #include "esphome/components/api/api_pb2.h" | ||||
| #include "esphome/components/api/api_server.h" | ||||
| #include "esphome/components/microphone/microphone.h" | ||||
| #ifdef USE_SPEAKER | ||||
| #include "esphome/components/speaker/speaker.h" | ||||
| @@ -109,6 +109,12 @@ class VoiceAssistant : public Component { | ||||
|   Trigger<> *get_end_trigger() const { return this->end_trigger_; } | ||||
|   Trigger<std::string, std::string> *get_error_trigger() const { return this->error_trigger_; } | ||||
|  | ||||
|   Trigger<> *get_client_connected_trigger() const { return this->client_connected_trigger_; } | ||||
|   Trigger<> *get_client_disconnected_trigger() const { return this->client_disconnected_trigger_; } | ||||
|  | ||||
|   void client_subscription(api::APIConnection *client, bool subscribe); | ||||
|   api::APIConnection *get_api_connection() const { return this->api_client_; } | ||||
|  | ||||
|  protected: | ||||
|   int read_microphone_(); | ||||
|   void set_state_(State state); | ||||
| @@ -127,6 +133,11 @@ class VoiceAssistant : public Component { | ||||
|   Trigger<> *end_trigger_ = new Trigger<>(); | ||||
|   Trigger<std::string, std::string> *error_trigger_ = new Trigger<std::string, std::string>(); | ||||
|  | ||||
|   Trigger<> *client_connected_trigger_ = new Trigger<>(); | ||||
|   Trigger<> *client_disconnected_trigger_ = new Trigger<>(); | ||||
|  | ||||
|   api::APIConnection *api_client_{nullptr}; | ||||
|  | ||||
|   microphone::Microphone *mic_{nullptr}; | ||||
| #ifdef USE_SPEAKER | ||||
|   speaker::Speaker *speaker_{nullptr}; | ||||
|   | ||||
| @@ -485,6 +485,8 @@ CONF_ON_BLE_MANUFACTURER_DATA_ADVERTISE = "on_ble_manufacturer_data_advertise" | ||||
| CONF_ON_BLE_SERVICE_DATA_ADVERTISE = "on_ble_service_data_advertise" | ||||
| CONF_ON_BOOT = "on_boot" | ||||
| CONF_ON_CLICK = "on_click" | ||||
| CONF_ON_CLIENT_CONNECTED = "on_client_connected" | ||||
| CONF_ON_CLIENT_DISCONNECTED = "on_client_disconnected" | ||||
| CONF_ON_CONNECT = "on_connect" | ||||
| CONF_ON_CONTROL = "on_control" | ||||
| CONF_ON_DISCONNECT = "on_disconnect" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user