mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 07:03:55 +00:00 
			
		
		
		
	[micro_wake_word] add new VPE features (#8655)
This commit is contained in:
		| @@ -12,6 +12,7 @@ import esphome.config_validation as cv | ||||
| from esphome.const import ( | ||||
|     CONF_FILE, | ||||
|     CONF_ID, | ||||
|     CONF_INTERNAL, | ||||
|     CONF_MICROPHONE, | ||||
|     CONF_MODEL, | ||||
|     CONF_PASSWORD, | ||||
| @@ -40,6 +41,7 @@ CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected" | ||||
| CONF_PROBABILITY_CUTOFF = "probability_cutoff" | ||||
| CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size" | ||||
| CONF_SLIDING_WINDOW_SIZE = "sliding_window_size" | ||||
| CONF_STOP_AFTER_DETECTION = "stop_after_detection" | ||||
| CONF_TENSOR_ARENA_SIZE = "tensor_arena_size" | ||||
| CONF_VAD = "vad" | ||||
|  | ||||
| @@ -49,13 +51,20 @@ micro_wake_word_ns = cg.esphome_ns.namespace("micro_wake_word") | ||||
|  | ||||
| MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component) | ||||
|  | ||||
| DisableModelAction = micro_wake_word_ns.class_("DisableModelAction", automation.Action) | ||||
| EnableModelAction = micro_wake_word_ns.class_("EnableModelAction", automation.Action) | ||||
| StartAction = micro_wake_word_ns.class_("StartAction", automation.Action) | ||||
| StopAction = micro_wake_word_ns.class_("StopAction", automation.Action) | ||||
|  | ||||
| ModelIsEnabledCondition = micro_wake_word_ns.class_( | ||||
|     "ModelIsEnabledCondition", automation.Condition | ||||
| ) | ||||
| IsRunningCondition = micro_wake_word_ns.class_( | ||||
|     "IsRunningCondition", automation.Condition | ||||
| ) | ||||
|  | ||||
| WakeWordModel = micro_wake_word_ns.class_("WakeWordModel") | ||||
|  | ||||
|  | ||||
| def _validate_json_filename(value): | ||||
|     value = cv.string(value) | ||||
| @@ -169,9 +178,10 @@ def _convert_manifest_v1_to_v2(v1_manifest): | ||||
|  | ||||
|     # Original Inception-based V1 manifest models require a minimum of 45672 bytes | ||||
|     v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672 | ||||
|  | ||||
|     # Original Inception-based V1 manifest models use a 20 ms feature step size | ||||
|     v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20 | ||||
|     # Original Inception-based V1 manifest models were trained only on TTS English samples | ||||
|     v2_manifest[KEY_TRAINED_LANGUAGES] = ["en"] | ||||
|  | ||||
|     return v2_manifest | ||||
|  | ||||
| @@ -296,14 +306,16 @@ MODEL_SOURCE_SCHEMA = cv.Any( | ||||
|  | ||||
| MODEL_SCHEMA = cv.Schema( | ||||
|     { | ||||
|         cv.GenerateID(CONF_ID): cv.declare_id(WakeWordModel), | ||||
|         cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA, | ||||
|         cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage, | ||||
|         cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int, | ||||
|         cv.Optional(CONF_INTERNAL, default=False): cv.boolean, | ||||
|         cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8), | ||||
|     } | ||||
| ) | ||||
|  | ||||
| # Provide a default VAD model that could be overridden | ||||
| # Provides a default VAD model that could be overridden | ||||
| VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend( | ||||
|     cv.Schema( | ||||
|         { | ||||
| @@ -343,6 +355,7 @@ CONFIG_SCHEMA = cv.All( | ||||
|                 single=True | ||||
|             ), | ||||
|             cv.Optional(CONF_VAD): _maybe_empty_vad_schema, | ||||
|             cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean, | ||||
|             cv.Optional(CONF_MODEL): cv.invalid( | ||||
|                 f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter." | ||||
|             ), | ||||
| @@ -433,29 +446,20 @@ async def to_code(config): | ||||
|     mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) | ||||
|     cg.add(var.set_microphone_source(mic_source)) | ||||
|  | ||||
|     cg.add_define("USE_MICRO_WAKE_WORD") | ||||
|     cg.add_define("USE_OTA_STATE_CALLBACK") | ||||
|  | ||||
|     esp32.add_idf_component( | ||||
|         name="esp-tflite-micro", | ||||
|         repo="https://github.com/espressif/esp-tflite-micro", | ||||
|         ref="v1.3.1", | ||||
|     ) | ||||
|     # add esp-nn dependency for tflite-micro to work around https://github.com/espressif/esp-nn/issues/17 | ||||
|     # ...remove after switching to IDF 5.1.4+ | ||||
|     esp32.add_idf_component( | ||||
|         name="esp-nn", | ||||
|         repo="https://github.com/espressif/esp-nn", | ||||
|         ref="v1.1.0", | ||||
|         ref="v1.3.3.1", | ||||
|     ) | ||||
|  | ||||
|     cg.add_build_flag("-DTF_LITE_STATIC_MEMORY") | ||||
|     cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON") | ||||
|     cg.add_build_flag("-DESP_NN") | ||||
|  | ||||
|     if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): | ||||
|         await automation.build_automation( | ||||
|             var.get_wake_word_detected_trigger(), | ||||
|             [(cg.std_string, "wake_word")], | ||||
|             on_wake_word_detection_config, | ||||
|         ) | ||||
|     cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") | ||||
|  | ||||
|     if vad_model := config.get(CONF_VAD): | ||||
|         cg.add_define("USE_MICRO_WAKE_WORD_VAD") | ||||
| @@ -463,7 +467,7 @@ async def to_code(config): | ||||
|         # Use the general model loading code for the VAD codegen | ||||
|         config[CONF_MODELS].append(vad_model) | ||||
|  | ||||
|     for model_parameters in config[CONF_MODELS]: | ||||
|     for i, model_parameters in enumerate(config[CONF_MODELS]): | ||||
|         model_config = model_parameters.get(CONF_MODEL) | ||||
|         data = [] | ||||
|         manifest, data = _model_config_to_manifest_data(model_config) | ||||
| @@ -474,6 +478,8 @@ async def to_code(config): | ||||
|         probability_cutoff = model_parameters.get( | ||||
|             CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF] | ||||
|         ) | ||||
|         quantized_probability_cutoff = int(probability_cutoff * 255) | ||||
|  | ||||
|         sliding_window_size = model_parameters.get( | ||||
|             CONF_SLIDING_WINDOW_SIZE, | ||||
|             manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE], | ||||
| @@ -483,24 +489,40 @@ async def to_code(config): | ||||
|             cg.add( | ||||
|                 var.add_vad_model( | ||||
|                     prog_arr, | ||||
|                     probability_cutoff, | ||||
|                     quantized_probability_cutoff, | ||||
|                     sliding_window_size, | ||||
|                     manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], | ||||
|                 ) | ||||
|             ) | ||||
|         else: | ||||
|             cg.add( | ||||
|                 var.add_wake_word_model( | ||||
|             # Only enable the first wake word by default. After first boot, the enable state is saved/loaded to the flash | ||||
|             default_enabled = i == 0 | ||||
|             wake_word_model = cg.new_Pvariable( | ||||
|                 model_parameters[CONF_ID], | ||||
|                 str(model_parameters[CONF_ID]), | ||||
|                 prog_arr, | ||||
|                     probability_cutoff, | ||||
|                 quantized_probability_cutoff, | ||||
|                 sliding_window_size, | ||||
|                 manifest[KEY_WAKE_WORD], | ||||
|                 manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], | ||||
|                 ) | ||||
|                 default_enabled, | ||||
|                 model_parameters[CONF_INTERNAL], | ||||
|             ) | ||||
|  | ||||
|             for lang in manifest[KEY_TRAINED_LANGUAGES]: | ||||
|                 cg.add(wake_word_model.add_trained_language(lang)) | ||||
|  | ||||
|             cg.add(var.add_wake_word_model(wake_word_model)) | ||||
|  | ||||
|     cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE])) | ||||
|     cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") | ||||
|     cg.add(var.set_stop_after_detection(config[CONF_STOP_AFTER_DETECTION])) | ||||
|  | ||||
|     if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): | ||||
|         await automation.build_automation( | ||||
|             var.get_wake_word_detected_trigger(), | ||||
|             [(cg.std_string, "wake_word")], | ||||
|             on_wake_word_detection_config, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)}) | ||||
| @@ -515,3 +537,30 @@ async def micro_wake_word_action_to_code(config, action_id, template_arg, args): | ||||
|     var = cg.new_Pvariable(action_id, template_arg) | ||||
|     await cg.register_parented(var, config[CONF_ID]) | ||||
|     return var | ||||
|  | ||||
|  | ||||
| MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA = automation.maybe_simple_id( | ||||
|     { | ||||
|         cv.Required(CONF_ID): cv.use_id(WakeWordModel), | ||||
|     } | ||||
| ) | ||||
|  | ||||
|  | ||||
| @register_action( | ||||
|     "micro_wake_word.enable_model", | ||||
|     EnableModelAction, | ||||
|     MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, | ||||
| ) | ||||
| @register_action( | ||||
|     "micro_wake_word.disable_model", | ||||
|     DisableModelAction, | ||||
|     MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, | ||||
| ) | ||||
| @register_condition( | ||||
|     "micro_wake_word.model_is_enabled", | ||||
|     ModelIsEnabledCondition, | ||||
|     MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, | ||||
| ) | ||||
| async def model_action(config, action_id, template_arg, args): | ||||
|     parent = await cg.get_variable(config[CONF_ID]) | ||||
|     return cg.new_Pvariable(action_id, template_arg, parent) | ||||
|   | ||||
							
								
								
									
										54
									
								
								esphome/components/micro_wake_word/automation.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								esphome/components/micro_wake_word/automation.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "micro_wake_word.h" | ||||
| #include "streaming_model.h" | ||||
|  | ||||
| #ifdef USE_ESP_IDF | ||||
| namespace esphome { | ||||
| namespace micro_wake_word { | ||||
|  | ||||
| template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> { | ||||
|  public: | ||||
|   void play(Ts... x) override { this->parent_->start(); } | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> { | ||||
|  public: | ||||
|   void play(Ts... x) override { this->parent_->stop(); } | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> { | ||||
|  public: | ||||
|   bool check(Ts... x) override { return this->parent_->is_running(); } | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class EnableModelAction : public Action<Ts...> { | ||||
|  public: | ||||
|   explicit EnableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} | ||||
|   void play(Ts... x) override { this->wake_word_model_->enable(); } | ||||
|  | ||||
|  protected: | ||||
|   WakeWordModel *wake_word_model_; | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class DisableModelAction : public Action<Ts...> { | ||||
|  public: | ||||
|   explicit DisableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} | ||||
|   void play(Ts... x) override { this->wake_word_model_->disable(); } | ||||
|  | ||||
|  protected: | ||||
|   WakeWordModel *wake_word_model_; | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class ModelIsEnabledCondition : public Condition<Ts...> { | ||||
|  public: | ||||
|   explicit ModelIsEnabledCondition(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} | ||||
|   bool check(Ts... x) override { return this->wake_word_model_->is_enabled(); } | ||||
|  | ||||
|  protected: | ||||
|   WakeWordModel *wake_word_model_; | ||||
| }; | ||||
|  | ||||
| }  // namespace micro_wake_word | ||||
| }  // namespace esphome | ||||
| #endif | ||||
| @@ -1,5 +1,4 @@ | ||||
| #include "micro_wake_word.h" | ||||
| #include "streaming_model.h" | ||||
|  | ||||
| #ifdef USE_ESP_IDF | ||||
|  | ||||
| @@ -7,41 +6,57 @@ | ||||
| #include "esphome/core/helpers.h" | ||||
| #include "esphome/core/log.h" | ||||
|  | ||||
| #include <frontend.h> | ||||
| #include <frontend_util.h> | ||||
| #include "esphome/components/audio/audio_transfer_buffer.h" | ||||
|  | ||||
| #include <tensorflow/lite/core/c/common.h> | ||||
| #include <tensorflow/lite/micro/micro_interpreter.h> | ||||
| #include <tensorflow/lite/micro/micro_mutable_op_resolver.h> | ||||
|  | ||||
| #include <cmath> | ||||
| #ifdef USE_OTA | ||||
| #include "esphome/components/ota/ota_backend.h" | ||||
| #endif | ||||
|  | ||||
| namespace esphome { | ||||
| namespace micro_wake_word { | ||||
|  | ||||
| static const char *const TAG = "micro_wake_word"; | ||||
|  | ||||
| static const size_t SAMPLE_RATE_HZ = 16000;  // 16 kHz | ||||
| static const size_t BUFFER_LENGTH = 64;      // 0.064 seconds | ||||
| static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH; | ||||
| static const size_t INPUT_BUFFER_SIZE = 16 * SAMPLE_RATE_HZ / 1000;  // 16ms * 16kHz / 1000ms | ||||
| static const ssize_t DETECTION_QUEUE_LENGTH = 5; | ||||
|  | ||||
| static const size_t DATA_TIMEOUT_MS = 50; | ||||
|  | ||||
| static const uint32_t RING_BUFFER_DURATION_MS = 120; | ||||
| static const uint32_t RING_BUFFER_SAMPLES = RING_BUFFER_DURATION_MS * (AUDIO_SAMPLE_FREQUENCY / 1000); | ||||
| static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t); | ||||
|  | ||||
| static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072; | ||||
| static const UBaseType_t INFERENCE_TASK_PRIORITY = 3; | ||||
|  | ||||
| enum EventGroupBits : uint32_t { | ||||
|   COMMAND_STOP = (1 << 0),  // Signals the inference task should stop | ||||
|  | ||||
|   TASK_STARTING = (1 << 3), | ||||
|   TASK_RUNNING = (1 << 4), | ||||
|   TASK_STOPPING = (1 << 5), | ||||
|   TASK_STOPPED = (1 << 6), | ||||
|  | ||||
|   ERROR_MEMORY = (1 << 9), | ||||
|   ERROR_INFERENCE = (1 << 10), | ||||
|  | ||||
|   WARNING_FULL_RING_BUFFER = (1 << 13), | ||||
|  | ||||
|   ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE, | ||||
|   ALL_BITS = 0xfffff,  // 24 total bits available in an event group | ||||
| }; | ||||
|  | ||||
| float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } | ||||
|  | ||||
| static const LogString *micro_wake_word_state_to_string(State state) { | ||||
|   switch (state) { | ||||
|     case State::IDLE: | ||||
|       return LOG_STR("IDLE"); | ||||
|     case State::START_MICROPHONE: | ||||
|       return LOG_STR("START_MICROPHONE"); | ||||
|     case State::STARTING_MICROPHONE: | ||||
|       return LOG_STR("STARTING_MICROPHONE"); | ||||
|     case State::STARTING: | ||||
|       return LOG_STR("STARTING"); | ||||
|     case State::DETECTING_WAKE_WORD: | ||||
|       return LOG_STR("DETECTING_WAKE_WORD"); | ||||
|     case State::STOP_MICROPHONE: | ||||
|       return LOG_STR("STOP_MICROPHONE"); | ||||
|     case State::STOPPING_MICROPHONE: | ||||
|       return LOG_STR("STOPPING_MICROPHONE"); | ||||
|     case State::STOPPING: | ||||
|       return LOG_STR("STOPPING"); | ||||
|     case State::STOPPED: | ||||
|       return LOG_STR("STOPPED"); | ||||
|     default: | ||||
|       return LOG_STR("UNKNOWN"); | ||||
|   } | ||||
| @@ -51,7 +66,7 @@ void MicroWakeWord::dump_config() { | ||||
|   ESP_LOGCONFIG(TAG, "microWakeWord:"); | ||||
|   ESP_LOGCONFIG(TAG, "  models:"); | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     model.log_model_config(); | ||||
|     model->log_model_config(); | ||||
|   } | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   this->vad_model_->log_model_config(); | ||||
| @@ -61,107 +76,265 @@ void MicroWakeWord::dump_config() { | ||||
| void MicroWakeWord::setup() { | ||||
|   ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); | ||||
|  | ||||
|   this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) { | ||||
|     if (this->state_ != State::DETECTING_WAKE_WORD) { | ||||
|   this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; | ||||
|   this->frontend_config_.window.step_size_ms = this->features_step_size_; | ||||
|   this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; | ||||
|   this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT; | ||||
|   this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT; | ||||
|   this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS; | ||||
|   this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING; | ||||
|   this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING; | ||||
|   this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING; | ||||
|   this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN; | ||||
|   this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH; | ||||
|   this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET; | ||||
|   this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS; | ||||
|   this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG; | ||||
|   this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT; | ||||
|  | ||||
|   this->event_group_ = xEventGroupCreate(); | ||||
|   if (this->event_group_ == nullptr) { | ||||
|     ESP_LOGE(TAG, "Failed to create event group"); | ||||
|     this->mark_failed(); | ||||
|     return; | ||||
|   } | ||||
|     std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_; | ||||
|     if (this->ring_buffer_.use_count() == 2) { | ||||
|       // mWW still owns the ring buffer and temp_ring_buffer does as well, proceed to copy audio into ring buffer | ||||
|  | ||||
|   this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent)); | ||||
|   if (this->detection_queue_ == nullptr) { | ||||
|     ESP_LOGE(TAG, "Failed to create detection event queue"); | ||||
|     this->mark_failed(); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) { | ||||
|     if (this->state_ == State::STOPPED) { | ||||
|       return; | ||||
|     } | ||||
|     std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock(); | ||||
|     if (this->ring_buffer_.use_count() > 1) { | ||||
|       size_t bytes_free = temp_ring_buffer->free(); | ||||
|  | ||||
|       if (bytes_free < data.size()) { | ||||
|         ESP_LOGW( | ||||
|             TAG, | ||||
|             "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " | ||||
|             "Resetting the ring buffer. Wake word detection accuracy will be reduced.", | ||||
|             bytes_free, data.size()); | ||||
|  | ||||
|         xEventGroupSetBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); | ||||
|         temp_ring_buffer->reset(); | ||||
|       } | ||||
|       temp_ring_buffer->write((void *) data.data(), data.size()); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   if (!this->register_streaming_ops_(this->streaming_op_resolver_)) { | ||||
|     this->mark_failed(); | ||||
|     return; | ||||
| #ifdef USE_OTA | ||||
|   ota::get_global_ota_callback()->add_on_state_callback( | ||||
|       [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { | ||||
|         if (state == ota::OTA_STARTED) { | ||||
|           this->suspend_task_(); | ||||
|         } else if (state == ota::OTA_ERROR) { | ||||
|           this->resume_task_(); | ||||
|         } | ||||
|  | ||||
|       }); | ||||
| #endif | ||||
|   ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); | ||||
|  | ||||
|   this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; | ||||
|   this->frontend_config_.window.step_size_ms = this->features_step_size_; | ||||
|   this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; | ||||
|   this->frontend_config_.filterbank.lower_band_limit = 125.0; | ||||
|   this->frontend_config_.filterbank.upper_band_limit = 7500.0; | ||||
|   this->frontend_config_.noise_reduction.smoothing_bits = 10; | ||||
|   this->frontend_config_.noise_reduction.even_smoothing = 0.025; | ||||
|   this->frontend_config_.noise_reduction.odd_smoothing = 0.06; | ||||
|   this->frontend_config_.noise_reduction.min_signal_remaining = 0.05; | ||||
|   this->frontend_config_.pcan_gain_control.enable_pcan = 1; | ||||
|   this->frontend_config_.pcan_gain_control.strength = 0.95; | ||||
|   this->frontend_config_.pcan_gain_control.offset = 80.0; | ||||
|   this->frontend_config_.pcan_gain_control.gain_bits = 21; | ||||
|   this->frontend_config_.log_scale.enable_log = 1; | ||||
|   this->frontend_config_.log_scale.scale_shift = 6; | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff, | ||||
|                                         size_t sliding_window_average_size, const std::string &wake_word, | ||||
|                                         size_t tensor_arena_size) { | ||||
|   this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word, | ||||
|                                        tensor_arena_size); | ||||
| void MicroWakeWord::inference_task(void *params) { | ||||
|   MicroWakeWord *this_mww = (MicroWakeWord *) params; | ||||
|  | ||||
|   xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING); | ||||
|  | ||||
|   {  // Ensures any C++ objects fall out of scope to deallocate before deleting the task | ||||
|     const size_t new_samples_to_read = this_mww->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000); | ||||
|     std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer; | ||||
|     int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]; | ||||
|  | ||||
|     if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { | ||||
|       // Allocate audio transfer buffer | ||||
|       audio_buffer = audio::AudioSourceTransferBuffer::create(new_samples_to_read * sizeof(int16_t)); | ||||
|  | ||||
|       if (audio_buffer == nullptr) { | ||||
|         xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { | ||||
|       // Allocate ring buffer | ||||
|       std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(RING_BUFFER_SIZE); | ||||
|       if (temp_ring_buffer.use_count() == 0) { | ||||
|         xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); | ||||
|       } | ||||
|       audio_buffer->set_source(temp_ring_buffer); | ||||
|       this_mww->ring_buffer_ = temp_ring_buffer; | ||||
|     } | ||||
|  | ||||
|     if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { | ||||
|       this_mww->microphone_source_->start(); | ||||
|       xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING); | ||||
|  | ||||
|       while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) { | ||||
|         audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS)); | ||||
|  | ||||
|         if (audio_buffer->available() < new_samples_to_read * sizeof(int16_t)) { | ||||
|           // Insufficient data to generate new spectrogram features, read more next iteration | ||||
|           continue; | ||||
|         } | ||||
|  | ||||
|         // Generate new spectrogram features | ||||
|         size_t processed_samples = this_mww->generate_features_( | ||||
|             (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer); | ||||
|         audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t)); | ||||
|  | ||||
|         // Run inference using the new spectorgram features | ||||
|         if (!this_mww->update_model_probabilities_(features_buffer)) { | ||||
|           xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE); | ||||
|           break; | ||||
|         } | ||||
|  | ||||
|         // Process each model's probabilities and possibly send a Detection Event to the queue | ||||
|         this_mww->process_probabilities_(); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING); | ||||
|  | ||||
|   this_mww->unload_models_(); | ||||
|   this_mww->microphone_source_->stop(); | ||||
|   FrontendFreeStateContents(&this_mww->frontend_state_); | ||||
|  | ||||
|   xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED); | ||||
|   while (true) { | ||||
|     // Continuously delay until the main loop deletes the task | ||||
|     delay(10); | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() { | ||||
|   std::vector<WakeWordModel *> external_wake_word_models; | ||||
|   for (auto *model : this->wake_word_models_) { | ||||
|     if (!model->get_internal_only()) { | ||||
|       external_wake_word_models.push_back(model); | ||||
|     } | ||||
|   } | ||||
|   return external_wake_word_models; | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::add_wake_word_model(WakeWordModel *model) { this->wake_word_models_.push_back(model); } | ||||
|  | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
| void MicroWakeWord::add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, | ||||
| void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||
|                                   size_t tensor_arena_size) { | ||||
|   this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| void MicroWakeWord::suspend_task_() { | ||||
|   if (this->inference_task_handle_ != nullptr) { | ||||
|     vTaskSuspend(this->inference_task_handle_); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::resume_task_() { | ||||
|   if (this->inference_task_handle_ != nullptr) { | ||||
|     vTaskResume(this->inference_task_handle_); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::loop() { | ||||
|   switch (this->state_) { | ||||
|     case State::IDLE: | ||||
|       break; | ||||
|     case State::START_MICROPHONE: | ||||
|       ESP_LOGD(TAG, "Starting Microphone"); | ||||
|       this->microphone_source_->start(); | ||||
|       this->set_state_(State::STARTING_MICROPHONE); | ||||
|       break; | ||||
|     case State::STARTING_MICROPHONE: | ||||
|       if (this->microphone_source_->is_running()) { | ||||
|   uint32_t event_group_bits = xEventGroupGetBits(this->event_group_); | ||||
|  | ||||
|   if (event_group_bits & EventGroupBits::ERROR_MEMORY) { | ||||
|     xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY); | ||||
|     ESP_LOGE(TAG, "Encountered an error allocating buffers"); | ||||
|   } | ||||
|  | ||||
|   if (event_group_bits & EventGroupBits::ERROR_INFERENCE) { | ||||
|     xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE); | ||||
|     ESP_LOGE(TAG, "Encountered an error while performing an inference"); | ||||
|   } | ||||
|  | ||||
|   if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) { | ||||
|     xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); | ||||
|     ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake " | ||||
|                   "word detection accuracy will temporarily be reduced."); | ||||
|   } | ||||
|  | ||||
|   if (event_group_bits & EventGroupBits::TASK_STARTING) { | ||||
|     ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers"); | ||||
|     xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING); | ||||
|   } | ||||
|  | ||||
|   if (event_group_bits & EventGroupBits::TASK_RUNNING) { | ||||
|     ESP_LOGD(TAG, "Inference task is running"); | ||||
|  | ||||
|     xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING); | ||||
|     this->set_state_(State::DETECTING_WAKE_WORD); | ||||
|   } | ||||
|       break; | ||||
|     case State::DETECTING_WAKE_WORD: | ||||
|       while (this->has_enough_samples_()) { | ||||
|         this->update_model_probabilities_(); | ||||
|         if (this->detect_wake_words_()) { | ||||
|           ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str()); | ||||
|           this->detected_ = true; | ||||
|           this->set_state_(State::STOP_MICROPHONE); | ||||
|  | ||||
|   if (event_group_bits & EventGroupBits::TASK_STOPPING) { | ||||
|     ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers"); | ||||
|     xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING); | ||||
|   } | ||||
|  | ||||
|   if ((event_group_bits & EventGroupBits::TASK_STOPPED)) { | ||||
|     ESP_LOGD(TAG, "Inference task is finished, freeing task resources"); | ||||
|     vTaskDelete(this->inference_task_handle_); | ||||
|     this->inference_task_handle_ = nullptr; | ||||
|     xEventGroupClearBits(this->event_group_, ALL_BITS); | ||||
|     xQueueReset(this->detection_queue_); | ||||
|     this->set_state_(State::STOPPED); | ||||
|   } | ||||
|  | ||||
|   if ((this->pending_start_) && (this->state_ == State::STOPPED)) { | ||||
|     this->set_state_(State::STARTING); | ||||
|     this->pending_start_ = false; | ||||
|   } | ||||
|  | ||||
|   if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) { | ||||
|     this->set_state_(State::STOPPING); | ||||
|     this->pending_stop_ = false; | ||||
|   } | ||||
|  | ||||
|   switch (this->state_) { | ||||
|     case State::STARTING: | ||||
|       if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) { | ||||
|         // Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it | ||||
|         // uses floating point operations. | ||||
|         if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { | ||||
|           this->status_momentary_error( | ||||
|               "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000); | ||||
|           return; | ||||
|         } | ||||
|  | ||||
|         xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this, | ||||
|                     INFERENCE_TASK_PRIORITY, &this->inference_task_handle_); | ||||
|  | ||||
|         if (this->inference_task_handle_ == nullptr) { | ||||
|           FrontendFreeStateContents(&this->frontend_state_);  // Deallocate frontend state | ||||
|           this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000); | ||||
|         } | ||||
|       } | ||||
|       break; | ||||
|     case State::STOP_MICROPHONE: | ||||
|       ESP_LOGD(TAG, "Stopping Microphone"); | ||||
|       this->microphone_source_->stop(); | ||||
|       this->set_state_(State::STOPPING_MICROPHONE); | ||||
|       this->unload_models_(); | ||||
|       this->deallocate_buffers_(); | ||||
|     case State::DETECTING_WAKE_WORD: { | ||||
|       DetectionEvent detection_event; | ||||
|       while (xQueueReceive(this->detection_queue_, &detection_event, 0)) { | ||||
|         if (detection_event.blocked_by_vad) { | ||||
|           ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str()); | ||||
|         } else { | ||||
|           constexpr float uint8_to_float_divisor = | ||||
|               255.0f;  // Converting a quantized uint8 probability to floating point | ||||
|           ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f", | ||||
|                    detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor), | ||||
|                    (detection_event.max_probability / uint8_to_float_divisor)); | ||||
|           this->wake_word_detected_trigger_->trigger(*detection_event.wake_word); | ||||
|           if (this->stop_after_detection_) { | ||||
|             this->stop(); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       break; | ||||
|     case State::STOPPING_MICROPHONE: | ||||
|       if (this->microphone_source_->is_stopped()) { | ||||
|         this->set_state_(State::IDLE); | ||||
|         if (this->detected_) { | ||||
|           this->wake_word_detected_trigger_->trigger(this->detected_wake_word_); | ||||
|           this->detected_ = false; | ||||
|           this->detected_wake_word_ = ""; | ||||
|         } | ||||
|     } | ||||
|     case State::STOPPING: | ||||
|       xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP); | ||||
|       break; | ||||
|     case State::STOPPED: | ||||
|       break; | ||||
|   } | ||||
| } | ||||
| @@ -177,199 +350,40 @@ void MicroWakeWord::start() { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   if (this->state_ != State::IDLE) { | ||||
|     ESP_LOGW(TAG, "Wake word is already running"); | ||||
|   if (this->is_running()) { | ||||
|     ESP_LOGW(TAG, "Wake word detection is already running"); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   if (!this->load_models_() || !this->allocate_buffers_()) { | ||||
|     ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers"); | ||||
|     this->status_set_error(); | ||||
|   } else { | ||||
|     this->status_clear_error(); | ||||
|   } | ||||
|   ESP_LOGD(TAG, "Starting wake word detection"); | ||||
|  | ||||
|   if (this->status_has_error()) { | ||||
|     ESP_LOGW(TAG, "Wake word component has an error. Please check logs"); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   this->reset_states_(); | ||||
|   this->set_state_(State::START_MICROPHONE); | ||||
|   this->pending_start_ = true; | ||||
|   this->pending_stop_ = false; | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::stop() { | ||||
|   if (this->state_ == State::IDLE) { | ||||
|     ESP_LOGW(TAG, "Wake word is already stopped"); | ||||
|   if (this->state_ == STOPPED) | ||||
|     return; | ||||
|   } | ||||
|   if (this->state_ == State::STOPPING_MICROPHONE) { | ||||
|     ESP_LOGW(TAG, "Wake word is already stopping"); | ||||
|     return; | ||||
|   } | ||||
|   this->set_state_(State::STOP_MICROPHONE); | ||||
|  | ||||
|   ESP_LOGD(TAG, "Stopping wake word detection"); | ||||
|  | ||||
|   this->pending_start_ = false; | ||||
|   this->pending_stop_ = true; | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::set_state_(State state) { | ||||
|   if (this->state_ != state) { | ||||
|     ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), | ||||
|              LOG_STR_ARG(micro_wake_word_state_to_string(state))); | ||||
|     this->state_ = state; | ||||
|   } | ||||
|  | ||||
| bool MicroWakeWord::allocate_buffers_() { | ||||
|   ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE); | ||||
|  | ||||
|   if (this->input_buffer_ == nullptr) { | ||||
|     this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t)); | ||||
|     if (this->input_buffer_ == nullptr) { | ||||
|       ESP_LOGE(TAG, "Could not allocate input buffer"); | ||||
|       return false; | ||||
|     } | ||||
| } | ||||
|  | ||||
|   if (this->preprocessor_audio_buffer_ == nullptr) { | ||||
|     this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(this->new_samples_to_get_()); | ||||
|     if (this->preprocessor_audio_buffer_ == nullptr) { | ||||
|       ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer."); | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (this->ring_buffer_.use_count() == 0) { | ||||
|     this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); | ||||
|     if (this->ring_buffer_.use_count() == 0) { | ||||
|       ESP_LOGE(TAG, "Could not allocate ring buffer"); | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::deallocate_buffers_() { | ||||
|   ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE); | ||||
|   if (this->input_buffer_ != nullptr) { | ||||
|     audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); | ||||
|     this->input_buffer_ = nullptr; | ||||
|   } | ||||
|  | ||||
|   if (this->preprocessor_audio_buffer_ != nullptr) { | ||||
|     audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_()); | ||||
|     this->preprocessor_audio_buffer_ = nullptr; | ||||
|   } | ||||
|  | ||||
|   this->ring_buffer_.reset(); | ||||
| } | ||||
|  | ||||
| bool MicroWakeWord::load_models_() { | ||||
|   // Setup preprocesor feature generator | ||||
|   if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { | ||||
|     ESP_LOGD(TAG, "Failed to populate frontend state"); | ||||
|     FrontendFreeStateContents(&this->frontend_state_); | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   // Setup streaming models | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     if (!model.load_model(this->streaming_op_resolver_)) { | ||||
|       ESP_LOGE(TAG, "Failed to initialize a wake word model."); | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   if (!this->vad_model_->load_model(this->streaming_op_resolver_)) { | ||||
|     ESP_LOGE(TAG, "Failed to initialize VAD model."); | ||||
|     return false; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::unload_models_() { | ||||
|   FrontendFreeStateContents(&this->frontend_state_); | ||||
|  | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     model.unload_model(); | ||||
|   } | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   this->vad_model_->unload_model(); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::update_model_probabilities_() { | ||||
|   int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]; | ||||
|  | ||||
|   if (!this->generate_features_for_window_(audio_features)) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   // Increase the counter since the last positive detection | ||||
|   this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); | ||||
|  | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     // Perform inference | ||||
|     model.perform_streaming_inference(audio_features); | ||||
|   } | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   this->vad_model_->perform_streaming_inference(audio_features); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| bool MicroWakeWord::detect_wake_words_() { | ||||
|   // Verify we have processed samples since the last positive detection | ||||
|   if (this->ignore_windows_ < 0) { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   bool vad_state = this->vad_model_->determine_detected(); | ||||
| #endif | ||||
|  | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     if (model.determine_detected()) { | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|       if (vad_state) { | ||||
| #endif | ||||
|         this->detected_wake_word_ = model.get_wake_word(); | ||||
|         return true; | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|       } else { | ||||
|         ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str()); | ||||
|       } | ||||
| #endif | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| bool MicroWakeWord::has_enough_samples_() { | ||||
|   return this->ring_buffer_->available() >= | ||||
|          (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)) * sizeof(int16_t); | ||||
| } | ||||
|  | ||||
| bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) { | ||||
|   // Ensure we have enough new audio samples in the ring buffer for a full window | ||||
|   if (!this->has_enough_samples_()) { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_), | ||||
|                                                this->new_samples_to_get_() * sizeof(int16_t), pdMS_TO_TICKS(200)); | ||||
|  | ||||
|   if (bytes_read == 0) { | ||||
|     ESP_LOGE(TAG, "Could not read data from Ring Buffer"); | ||||
|   } else if (bytes_read < this->new_samples_to_get_() * sizeof(int16_t)) { | ||||
|     ESP_LOGD(TAG, "Partial Read of Data by Model"); | ||||
|     ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read, | ||||
|              (int) (this->new_samples_to_get_() * sizeof(int16_t))); | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   size_t num_samples_read; | ||||
|   struct FrontendOutput frontend_output = FrontendProcessSamples( | ||||
|       &this->frontend_state_, this->preprocessor_audio_buffer_, this->new_samples_to_get_(), &num_samples_read); | ||||
| size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available, | ||||
|                                          int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) { | ||||
|   size_t processed_samples = 0; | ||||
|   struct FrontendOutput frontend_output = | ||||
|       FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples); | ||||
|  | ||||
|   for (size_t i = 0; i < frontend_output.size; ++i) { | ||||
|     // These scaling values are set to match the TFLite audio frontend int8 output. | ||||
| @@ -379,8 +393,8 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F | ||||
|     // for historical reasons, to match up with the output of other feature | ||||
|     // generators. | ||||
|     // The process is then further complicated when we quantize the model. This | ||||
|     // means we have to scale the 0.0 to 26.0 real values to the -128 to 127 | ||||
|     // signed integer numbers. | ||||
|     // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN) | ||||
|     // to 127 (INT8_MAX) signed integer numbers. | ||||
|     // All this means that to get matching values from our integer feature | ||||
|     // output into the tensor input, we have to perform: | ||||
|     // input = (((feature / 25.6) / 26.0) * 256) - 128 | ||||
| @@ -389,74 +403,63 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F | ||||
|     constexpr int32_t value_scale = 256; | ||||
|     constexpr int32_t value_div = 666;  // 666 = 25.6 * 26.0 after rounding | ||||
|     int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; | ||||
|     value -= 128; | ||||
|     if (value < -128) { | ||||
|       value = -128; | ||||
|     } | ||||
|     if (value > 127) { | ||||
|       value = 127; | ||||
|     } | ||||
|     features[i] = value; | ||||
|  | ||||
|     value -= INT8_MIN; | ||||
|     features_buffer[i] = clamp<int8_t>(value, INT8_MIN, INT8_MAX); | ||||
|   } | ||||
|  | ||||
|   return true; | ||||
|   return processed_samples; | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::reset_states_() { | ||||
|   ESP_LOGD(TAG, "Resetting buffers and probabilities"); | ||||
|   this->ring_buffer_->reset(); | ||||
|   this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; | ||||
| void MicroWakeWord::process_probabilities_() { | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   DetectionEvent vad_state = this->vad_model_->determine_detected(); | ||||
|  | ||||
|   this->vad_state_ = vad_state.detected;  // atomic write, so thread safe | ||||
| #endif | ||||
|  | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     model.reset_probabilities(); | ||||
|     if (model->get_unprocessed_probability_status()) { | ||||
|       // Only detect wake words if there is a new probability since the last check | ||||
|       DetectionEvent wake_word_state = model->determine_detected(); | ||||
|       if (wake_word_state.detected) { | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|         if (vad_state.detected) { | ||||
| #endif | ||||
|           xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); | ||||
|           model->reset_probabilities(); | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|         } else { | ||||
|           wake_word_state.blocked_by_vad = true; | ||||
|           xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); | ||||
|         } | ||||
| #endif | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| void MicroWakeWord::unload_models_() { | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     model->unload_model(); | ||||
|   } | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   this->vad_model_->reset_probabilities(); | ||||
|   this->vad_model_->unload_model(); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { | ||||
|   if (op_resolver.AddCallOnce() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddVarHandle() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddReshape() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddReadVariable() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddStridedSlice() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddConcatenation() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddAssignVariable() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddConv2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddMul() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddAdd() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddMean() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddFullyConnected() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddLogistic() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddQuantize() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddAveragePool2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddMaxPool2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddPad() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddPack() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddSplitV() != kTfLiteOk) | ||||
|     return false; | ||||
| bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) { | ||||
|   bool success = true; | ||||
|  | ||||
|   return true; | ||||
|   for (auto &model : this->wake_word_models_) { | ||||
|     // Perform inference | ||||
|     success = success & model->perform_streaming_inference(audio_features); | ||||
|   } | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   success = success & this->vad_model_->perform_streaming_inference(audio_features); | ||||
| #endif | ||||
|  | ||||
|   return success; | ||||
| } | ||||
|  | ||||
| }  // namespace micro_wake_word | ||||
|   | ||||
| @@ -5,33 +5,27 @@ | ||||
| #include "preprocessor_settings.h" | ||||
| #include "streaming_model.h" | ||||
|  | ||||
| #include "esphome/components/microphone/microphone_source.h" | ||||
|  | ||||
| #include "esphome/core/automation.h" | ||||
| #include "esphome/core/component.h" | ||||
| #include "esphome/core/ring_buffer.h" | ||||
|  | ||||
| #include "esphome/components/microphone/microphone_source.h" | ||||
| #include <freertos/event_groups.h> | ||||
|  | ||||
| #include <frontend.h> | ||||
| #include <frontend_util.h> | ||||
|  | ||||
| #include <tensorflow/lite/core/c/common.h> | ||||
| #include <tensorflow/lite/micro/micro_interpreter.h> | ||||
| #include <tensorflow/lite/micro/micro_mutable_op_resolver.h> | ||||
|  | ||||
| namespace esphome { | ||||
| namespace micro_wake_word { | ||||
|  | ||||
| enum State { | ||||
|   IDLE, | ||||
|   START_MICROPHONE, | ||||
|   STARTING_MICROPHONE, | ||||
|   STARTING, | ||||
|   DETECTING_WAKE_WORD, | ||||
|   STOP_MICROPHONE, | ||||
|   STOPPING_MICROPHONE, | ||||
|   STOPPING, | ||||
|   STOPPED, | ||||
| }; | ||||
|  | ||||
| // The number of audio slices to process before accepting a positive detection | ||||
| static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74; | ||||
|  | ||||
| class MicroWakeWord : public Component { | ||||
|  public: | ||||
|   void setup() override; | ||||
| @@ -42,7 +36,7 @@ class MicroWakeWord : public Component { | ||||
|   void start(); | ||||
|   void stop(); | ||||
|  | ||||
|   bool is_running() const { return this->state_ != State::IDLE; } | ||||
|   bool is_running() const { return this->state_ != State::STOPPED; } | ||||
|  | ||||
|   void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; } | ||||
|  | ||||
| @@ -50,118 +44,87 @@ class MicroWakeWord : public Component { | ||||
|     this->microphone_source_ = microphone_source; | ||||
|   } | ||||
|  | ||||
|   void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; } | ||||
|  | ||||
|   Trigger<std::string> *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; } | ||||
|  | ||||
|   void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, | ||||
|                            const std::string &wake_word, size_t tensor_arena_size); | ||||
|   void add_wake_word_model(WakeWordModel *model); | ||||
|  | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   void add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, | ||||
|   void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||
|                      size_t tensor_arena_size); | ||||
|  | ||||
|   // Intended for the voice assistant component to fetch VAD status | ||||
|   bool get_vad_state() { return this->vad_state_; } | ||||
| #endif | ||||
|  | ||||
|   // Intended for the voice assistant component to access which wake words are available | ||||
|   // Since these are pointers to the WakeWordModel objects, the voice assistant component can enable or disable them | ||||
|   std::vector<WakeWordModel *> get_wake_words(); | ||||
|  | ||||
|  protected: | ||||
|   microphone::MicrophoneSource *microphone_source_{nullptr}; | ||||
|   Trigger<std::string> *wake_word_detected_trigger_ = new Trigger<std::string>(); | ||||
|   State state_{State::IDLE}; | ||||
|   State state_{State::STOPPED}; | ||||
|  | ||||
|   std::shared_ptr<RingBuffer> ring_buffer_; | ||||
|  | ||||
|   std::vector<WakeWordModel> wake_word_models_; | ||||
|   std::weak_ptr<RingBuffer> ring_buffer_; | ||||
|   std::vector<WakeWordModel *> wake_word_models_; | ||||
|  | ||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | ||||
|   std::unique_ptr<VADModel> vad_model_; | ||||
|   bool vad_state_{false}; | ||||
| #endif | ||||
|  | ||||
|   tflite::MicroMutableOpResolver<20> streaming_op_resolver_; | ||||
|   bool pending_start_{false}; | ||||
|   bool pending_stop_{false}; | ||||
|  | ||||
|   bool stop_after_detection_; | ||||
|  | ||||
|   uint8_t features_step_size_; | ||||
|  | ||||
|   // Audio frontend handles generating spectrogram features | ||||
|   struct FrontendConfig frontend_config_; | ||||
|   struct FrontendState frontend_state_; | ||||
|  | ||||
|   // When the wake word detection first starts, we ignore this many audio | ||||
|   // feature slices before accepting a positive detection | ||||
|   int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; | ||||
|   // Handles managing the stop/state of the inference task | ||||
|   EventGroupHandle_t event_group_; | ||||
|  | ||||
|   uint8_t features_step_size_; | ||||
|   // Used to send messages about the models' states to the main loop | ||||
|   QueueHandle_t detection_queue_; | ||||
|  | ||||
|   // Stores audio read from the microphone before being added to the ring buffer. | ||||
|   int16_t *input_buffer_{nullptr}; | ||||
|   // Stores audio to be fed into the audio frontend for generating features. | ||||
|   int16_t *preprocessor_audio_buffer_{nullptr}; | ||||
|   static void inference_task(void *params); | ||||
|   TaskHandle_t inference_task_handle_{nullptr}; | ||||
|  | ||||
|   bool detected_{false}; | ||||
|   std::string detected_wake_word_{""}; | ||||
|   /// @brief Suspends the inference task | ||||
|   void suspend_task_(); | ||||
|   /// @brief Resumes the inference task | ||||
|   void resume_task_(); | ||||
|  | ||||
|   void set_state_(State state); | ||||
|  | ||||
|   /// @brief Tests if there are enough samples in the ring buffer to generate new features. | ||||
|   /// @return True if enough samples, false otherwise. | ||||
|   bool has_enough_samples_(); | ||||
|   /// @brief Generates spectrogram features from an input buffer of audio samples | ||||
|   /// @param audio_buffer (int16_t *) Buffer containing input audio samples | ||||
|   /// @param samples_available (size_t) Number of samples avaiable in the input buffer | ||||
|   /// @param features_buffer (int8_t *) Buffer to store generated features | ||||
|   /// @return (size_t) Number of samples processed from the input buffer | ||||
|   size_t generate_features_(int16_t *audio_buffer, size_t samples_available, | ||||
|                             int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]); | ||||
|  | ||||
|   /// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_ | ||||
|   /// @return True if successful, false otherwise | ||||
|   bool allocate_buffers_(); | ||||
|   /// @brief Processes any new probabilities for each model. If any wake word is detected, it will send a DetectionEvent | ||||
|   /// to the detection_queue_. | ||||
|   void process_probabilities_(); | ||||
|  | ||||
|   /// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_ | ||||
|   void deallocate_buffers_(); | ||||
|  | ||||
|   /// @brief Loads streaming models and prepares the feature generation frontend | ||||
|   /// @return True if successful, false otherwise | ||||
|   bool load_models_(); | ||||
|  | ||||
|   /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. Frees memory used by the feature | ||||
|   /// generation frontend. | ||||
|   /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. | ||||
|   void unload_models_(); | ||||
|  | ||||
|   /** Performs inference with each configured model | ||||
|    * | ||||
|    * If enough audio samples are available, it will generate one slice of new features. | ||||
|    * It then loops through and performs inference with each of the loaded models. | ||||
|    */ | ||||
|   void update_model_probabilities_(); | ||||
|  | ||||
|   /** Checks every model's recent probabilities to determine if the wake word has been predicted | ||||
|    * | ||||
|    * Verifies the models have processed enough new samples for accurate predictions. | ||||
|    * Sets detected_wake_word_ to the wake word, if one is detected. | ||||
|    * @return True if a wake word is predicted, false otherwise | ||||
|    */ | ||||
|   bool detect_wake_words_(); | ||||
|  | ||||
|   /** Generates features for a window of audio samples | ||||
|    * | ||||
|    * Reads samples from the ring buffer and feeds them into the preprocessor frontend. | ||||
|    * Adapted from TFLite microspeech frontend. | ||||
|    * @param features int8_t array to store the audio features | ||||
|    * @return True if successful, false otherwise. | ||||
|    */ | ||||
|   bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]); | ||||
|  | ||||
|   /// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities | ||||
|   void reset_states_(); | ||||
|  | ||||
|   /// @brief Returns true if successfully registered the streaming model's TensorFlow operations | ||||
|   bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); | ||||
|   /// @brief Runs an inference with each model using the new spectrogram features | ||||
|   /// @param audio_features (int8_t *) Buffer containing new spectrogram features | ||||
|   /// @return True if successful, false if any errors were encountered | ||||
|   bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]); | ||||
|  | ||||
|   inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); } | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> { | ||||
|  public: | ||||
|   void play(Ts... x) override { this->parent_->start(); } | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> { | ||||
|  public: | ||||
|   void play(Ts... x) override { this->parent_->stop(); } | ||||
| }; | ||||
|  | ||||
| template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> { | ||||
|  public: | ||||
|   bool check(Ts... x) override { return this->parent_->is_running(); } | ||||
| }; | ||||
|  | ||||
| }  // namespace micro_wake_word | ||||
| }  // namespace esphome | ||||
|  | ||||
|   | ||||
| @@ -7,6 +7,10 @@ | ||||
| namespace esphome { | ||||
| namespace micro_wake_word { | ||||
|  | ||||
| // Settings for controlling the spectrogram feature generation by the preprocessor. | ||||
| // These must match the settings used when training a particular model. | ||||
| // All microWakeWord models have been trained with these specific paramters. | ||||
|  | ||||
| // The number of features the audio preprocessor generates per slice | ||||
| static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40; | ||||
| // Duration of each slice used as input into the preprocessor | ||||
| @@ -14,6 +18,21 @@ static const uint8_t FEATURE_DURATION_MS = 30; | ||||
| // Audio sample frequency in hertz | ||||
| static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000; | ||||
|  | ||||
| static const float FILTERBANK_LOWER_BAND_LIMIT = 125.0; | ||||
| static const float FILTERBANK_UPPER_BAND_LIMIT = 7500.0; | ||||
|  | ||||
| static const uint8_t NOISE_REDUCTION_SMOOTHING_BITS = 10; | ||||
| static const float NOISE_REDUCTION_EVEN_SMOOTHING = 0.025; | ||||
| static const float NOISE_REDUCTION_ODD_SMOOTHING = 0.06; | ||||
| static const float NOISE_REDUCTION_MIN_SIGNAL_REMAINING = 0.05; | ||||
|  | ||||
| static const bool PCAN_GAIN_CONTROL_ENABLE_PCAN = true; | ||||
| static const float PCAN_GAIN_CONTROL_STRENGTH = 0.95; | ||||
| static const float PCAN_GAIN_CONTROL_OFFSET = 80.0; | ||||
| static const uint8_t PCAN_GAIN_CONTROL_GAIN_BITS = 21; | ||||
|  | ||||
| static const bool LOG_SCALE_ENABLE_LOG = true; | ||||
| static const uint8_t LOG_SCALE_SCALE_SHIFT = 6; | ||||
| }  // namespace micro_wake_word | ||||
| }  // namespace esphome | ||||
|  | ||||
|   | ||||
| @@ -1,8 +1,7 @@ | ||||
| #ifdef USE_ESP_IDF | ||||
|  | ||||
| #include "streaming_model.h" | ||||
|  | ||||
| #include "esphome/core/hal.h" | ||||
| #ifdef USE_ESP_IDF | ||||
|  | ||||
| #include "esphome/core/helpers.h" | ||||
| #include "esphome/core/log.h" | ||||
|  | ||||
| @@ -13,18 +12,18 @@ namespace micro_wake_word { | ||||
|  | ||||
| void WakeWordModel::log_model_config() { | ||||
|   ESP_LOGCONFIG(TAG, "    - Wake Word: %s", this->wake_word_.c_str()); | ||||
|   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.3f", this->probability_cutoff_); | ||||
|   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); | ||||
|   ESP_LOGCONFIG(TAG, "      Sliding window size: %d", this->sliding_window_size_); | ||||
| } | ||||
|  | ||||
| void VADModel::log_model_config() { | ||||
|   ESP_LOGCONFIG(TAG, "    - VAD Model"); | ||||
|   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.3f", this->probability_cutoff_); | ||||
|   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); | ||||
|   ESP_LOGCONFIG(TAG, "      Sliding window size: %d", this->sliding_window_size_); | ||||
| } | ||||
|  | ||||
| bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) { | ||||
|   ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE); | ||||
| bool StreamingModel::load_model_() { | ||||
|   RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE); | ||||
|  | ||||
|   if (this->tensor_arena_ == nullptr) { | ||||
|     this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); | ||||
| @@ -51,8 +50,9 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) | ||||
|   } | ||||
|  | ||||
|   if (this->interpreter_ == nullptr) { | ||||
|     this->interpreter_ = make_unique<tflite::MicroInterpreter>( | ||||
|         tflite::GetModel(this->model_start_), op_resolver, this->tensor_arena_, this->tensor_arena_size_, this->mrv_); | ||||
|     this->interpreter_ = | ||||
|         make_unique<tflite::MicroInterpreter>(tflite::GetModel(this->model_start_), this->streaming_op_resolver_, | ||||
|                                               this->tensor_arena_, this->tensor_arena_size_, this->mrv_); | ||||
|     if (this->interpreter_->AllocateTensors() != kTfLiteOk) { | ||||
|       ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model"); | ||||
|       return false; | ||||
| @@ -84,34 +84,55 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   this->loaded_ = true; | ||||
|   this->reset_probabilities(); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| void StreamingModel::unload_model() { | ||||
|   this->interpreter_.reset(); | ||||
|  | ||||
|   ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE); | ||||
|   RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE); | ||||
|  | ||||
|   if (this->tensor_arena_ != nullptr) { | ||||
|     arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); | ||||
|     this->tensor_arena_ = nullptr; | ||||
|   } | ||||
|  | ||||
|   if (this->var_arena_ != nullptr) { | ||||
|     arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); | ||||
|     this->var_arena_ = nullptr; | ||||
|   } | ||||
|  | ||||
|   this->loaded_ = false; | ||||
| } | ||||
|  | ||||
| bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) { | ||||
|   if (this->interpreter_ != nullptr) { | ||||
|   if (this->enabled_ && !this->loaded_) { | ||||
|     // Model is enabled but isn't loaded | ||||
|     if (!this->load_model_()) { | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (!this->enabled_ && this->loaded_) { | ||||
|     // Model is disabled but still loaded | ||||
|     this->unload_model(); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   if (this->loaded_) { | ||||
|     TfLiteTensor *input = this->interpreter_->input(0); | ||||
|  | ||||
|     uint8_t stride = this->interpreter_->input(0)->dims->data[1]; | ||||
|     this->current_stride_step_ = this->current_stride_step_ % stride; | ||||
|  | ||||
|     std::memmove( | ||||
|         (int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_, | ||||
|         features, PREPROCESSOR_FEATURE_SIZE); | ||||
|     ++this->current_stride_step_; | ||||
|  | ||||
|     uint8_t stride = this->interpreter_->input(0)->dims->data[1]; | ||||
|  | ||||
|     if (this->current_stride_step_ >= stride) { | ||||
|       this->current_stride_step_ = 0; | ||||
|  | ||||
|       TfLiteStatus invoke_status = this->interpreter_->Invoke(); | ||||
|       if (invoke_status != kTfLiteOk) { | ||||
|         ESP_LOGW(TAG, "Streaming interpreter invoke failed"); | ||||
| @@ -124,65 +145,159 @@ bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCES | ||||
|       if (this->last_n_index_ == this->sliding_window_size_) | ||||
|         this->last_n_index_ = 0; | ||||
|       this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0];  // probability; | ||||
|       this->unprocessed_probability_status_ = true; | ||||
|     } | ||||
|     this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
|   ESP_LOGE(TAG, "Streaming interpreter is not initialized."); | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| void StreamingModel::reset_probabilities() { | ||||
|   for (auto &prob : this->recent_streaming_probabilities_) { | ||||
|     prob = 0; | ||||
|   } | ||||
|   this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; | ||||
| } | ||||
|  | ||||
| WakeWordModel::WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, | ||||
|                              const std::string &wake_word, size_t tensor_arena_size) { | ||||
| WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, | ||||
|                              size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, | ||||
|                              bool default_enabled, bool internal_only) { | ||||
|   this->id_ = id; | ||||
|   this->model_start_ = model_start; | ||||
|   this->probability_cutoff_ = probability_cutoff; | ||||
|   this->sliding_window_size_ = sliding_window_average_size; | ||||
|   this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0); | ||||
|   this->wake_word_ = wake_word; | ||||
|   this->tensor_arena_size_ = tensor_arena_size; | ||||
|   this->register_streaming_ops_(this->streaming_op_resolver_); | ||||
|   this->current_stride_step_ = 0; | ||||
|   this->internal_only_ = internal_only; | ||||
|  | ||||
|   this->pref_ = global_preferences->make_preference<bool>(fnv1_hash(id)); | ||||
|   bool enabled; | ||||
|   if (this->pref_.load(&enabled)) { | ||||
|     // Use the enabled state loaded from flash | ||||
|     this->enabled_ = enabled; | ||||
|   } else { | ||||
|     // If no state saved, then use the default | ||||
|     this->enabled_ = default_enabled; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| bool WakeWordModel::determine_detected() { | ||||
| void WakeWordModel::enable() { | ||||
|   this->enabled_ = true; | ||||
|   if (!this->internal_only_) { | ||||
|     this->pref_.save(&this->enabled_); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void WakeWordModel::disable() { | ||||
|   this->enabled_ = false; | ||||
|   if (!this->internal_only_) { | ||||
|     this->pref_.save(&this->enabled_); | ||||
|   } | ||||
| } | ||||
|  | ||||
| DetectionEvent WakeWordModel::determine_detected() { | ||||
|   DetectionEvent detection_event; | ||||
|   detection_event.wake_word = &this->wake_word_; | ||||
|   detection_event.max_probability = 0; | ||||
|   detection_event.average_probability = 0; | ||||
|  | ||||
|   if ((this->ignore_windows_ < 0) || !this->enabled_) { | ||||
|     detection_event.detected = false; | ||||
|     return detection_event; | ||||
|   } | ||||
|  | ||||
|   uint32_t sum = 0; | ||||
|   for (auto &prob : this->recent_streaming_probabilities_) { | ||||
|     detection_event.max_probability = std::max(detection_event.max_probability, prob); | ||||
|     sum += prob; | ||||
|   } | ||||
|  | ||||
|   float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_); | ||||
|   detection_event.average_probability = sum / this->sliding_window_size_; | ||||
|   detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_; | ||||
|  | ||||
|   // Detect the wake word if the sliding window average is above the cutoff | ||||
|   if (sliding_window_average > this->probability_cutoff_) { | ||||
|     ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f", | ||||
|              this->wake_word_.c_str(), sliding_window_average, | ||||
|              this->recent_streaming_probabilities_[this->last_n_index_] / (255.0)); | ||||
|     return true; | ||||
|   } | ||||
|   return false; | ||||
|   this->unprocessed_probability_status_ = false; | ||||
|   return detection_event; | ||||
| } | ||||
|  | ||||
| VADModel::VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, | ||||
| VADModel::VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||
|                    size_t tensor_arena_size) { | ||||
|   this->model_start_ = model_start; | ||||
|   this->probability_cutoff_ = probability_cutoff; | ||||
|   this->sliding_window_size_ = sliding_window_size; | ||||
|   this->recent_streaming_probabilities_.resize(sliding_window_size, 0); | ||||
|   this->tensor_arena_size_ = tensor_arena_size; | ||||
| }; | ||||
|   this->register_streaming_ops_(this->streaming_op_resolver_); | ||||
| } | ||||
|  | ||||
| DetectionEvent VADModel::determine_detected() { | ||||
|   DetectionEvent detection_event; | ||||
|   detection_event.max_probability = 0; | ||||
|   detection_event.average_probability = 0; | ||||
|  | ||||
|   if (!this->enabled_) { | ||||
|     // We disabled the VAD model for some reason... so we shouldn't block wake words from being detected | ||||
|     detection_event.detected = true; | ||||
|     return detection_event; | ||||
|   } | ||||
|  | ||||
| bool VADModel::determine_detected() { | ||||
|   uint32_t sum = 0; | ||||
|   for (auto &prob : this->recent_streaming_probabilities_) { | ||||
|     detection_event.max_probability = std::max(detection_event.max_probability, prob); | ||||
|     sum += prob; | ||||
|   } | ||||
|  | ||||
|   float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_); | ||||
|   detection_event.average_probability = sum / this->sliding_window_size_; | ||||
|   detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_); | ||||
|  | ||||
|   return sliding_window_average > this->probability_cutoff_; | ||||
|   return detection_event; | ||||
| } | ||||
|  | ||||
| bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { | ||||
|   if (op_resolver.AddCallOnce() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddVarHandle() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddReshape() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddReadVariable() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddStridedSlice() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddConcatenation() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddAssignVariable() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddConv2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddMul() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddAdd() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddMean() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddFullyConnected() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddLogistic() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddQuantize() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddAveragePool2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddMaxPool2D() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddPad() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddPack() != kTfLiteOk) | ||||
|     return false; | ||||
|   if (op_resolver.AddSplitV() != kTfLiteOk) | ||||
|     return false; | ||||
|  | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| }  // namespace micro_wake_word | ||||
|   | ||||
| @@ -4,6 +4,8 @@ | ||||
|  | ||||
| #include "preprocessor_settings.h" | ||||
|  | ||||
| #include "esphome/core/preferences.h" | ||||
|  | ||||
| #include <tensorflow/lite/core/c/common.h> | ||||
| #include <tensorflow/lite/micro/micro_interpreter.h> | ||||
| #include <tensorflow/lite/micro/micro_mutable_op_resolver.h> | ||||
| @@ -11,30 +13,63 @@ | ||||
| namespace esphome { | ||||
| namespace micro_wake_word { | ||||
|  | ||||
| static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100; | ||||
| static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024; | ||||
|  | ||||
| struct DetectionEvent { | ||||
|   std::string *wake_word; | ||||
|   bool detected; | ||||
|   bool partially_detection;  // Set if the most recent probability exceed the threshold, but the sliding window average | ||||
|                              // hasn't yet | ||||
|   uint8_t max_probability; | ||||
|   uint8_t average_probability; | ||||
|   bool blocked_by_vad = false; | ||||
| }; | ||||
|  | ||||
| class StreamingModel { | ||||
|  public: | ||||
|   virtual void log_model_config() = 0; | ||||
|   virtual bool determine_detected() = 0; | ||||
|   virtual DetectionEvent determine_detected() = 0; | ||||
|  | ||||
|   // Performs inference on the given features. | ||||
|   //  - If the model is enabled but not loaded, it will load it | ||||
|   //  - If the model is disabled but loaded, it will unload it | ||||
|   // Returns true if sucessful or false if there is an error | ||||
|   bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]); | ||||
|  | ||||
|   /// @brief Sets all recent_streaming_probabilities to 0 | ||||
|   /// @brief Sets all recent_streaming_probabilities to 0 and resets the ignore window count | ||||
|   void reset_probabilities(); | ||||
|  | ||||
|   /// @brief Allocates tensor and variable arenas and sets up the model interpreter | ||||
|   /// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded | ||||
|   /// @return True if successful, false otherwise | ||||
|   bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver); | ||||
|  | ||||
|   /// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory | ||||
|   void unload_model(); | ||||
|  | ||||
|  protected: | ||||
|   uint8_t current_stride_step_{0}; | ||||
|   /// @brief Enable the model. The next performing_streaming_inference call will load it. | ||||
|   virtual void enable() { this->enabled_ = true; } | ||||
|  | ||||
|   float probability_cutoff_; | ||||
|   /// @brief Disable the model. The next performing_streaming_inference call will unload it. | ||||
|   virtual void disable() { this->enabled_ = false; } | ||||
|  | ||||
|   /// @brief Return true if the model is enabled. | ||||
|   bool is_enabled() { return this->enabled_; } | ||||
|  | ||||
|   bool get_unprocessed_probability_status() { return this->unprocessed_probability_status_; } | ||||
|  | ||||
|  protected: | ||||
|   /// @brief Allocates tensor and variable arenas and sets up the model interpreter | ||||
|   /// @return True if successful, false otherwise | ||||
|   bool load_model_(); | ||||
|   /// @brief Returns true if successfully registered the streaming model's TensorFlow operations | ||||
|   bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); | ||||
|  | ||||
|   tflite::MicroMutableOpResolver<20> streaming_op_resolver_; | ||||
|  | ||||
|   bool loaded_{false}; | ||||
|   bool enabled_{true}; | ||||
|   bool unprocessed_probability_status_{false}; | ||||
|   uint8_t current_stride_step_{0}; | ||||
|   int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; | ||||
|  | ||||
|   uint8_t probability_cutoff_;  // Quantized probability cutoff mapping 0.0 - 1.0 to 0 - 255 | ||||
|   size_t sliding_window_size_; | ||||
|   size_t last_n_index_{0}; | ||||
|   size_t tensor_arena_size_; | ||||
| @@ -50,32 +85,62 @@ class StreamingModel { | ||||
|  | ||||
| class WakeWordModel final : public StreamingModel { | ||||
|  public: | ||||
|   WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, | ||||
|                 const std::string &wake_word, size_t tensor_arena_size); | ||||
|   /// @brief Constructs a wake word model object | ||||
|   /// @param id (std::string) identifier for this model | ||||
|   /// @param model_start (const uint8_t *) pointer to the start of the model's TFLite FlatBuffer | ||||
|   /// @param probability_cutoff (uint8_t) probability cutoff for acceping the wake word has been said | ||||
|   /// @param sliding_window_average_size (size_t) the length of the sliding window computing the mean rolling | ||||
|   ///                                    probability | ||||
|   /// @param wake_word (std::string) Friendly name of the wake word | ||||
|   /// @param tensor_arena_size (size_t) Size in bytes for allocating the tensor arena | ||||
|   /// @param default_enabled (bool) If true, it will be enabled by default on first boot | ||||
|   /// @param internal_only (bool) If true, the model will not be exposed to HomeAssistant as an available model | ||||
|   WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, | ||||
|                 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, | ||||
|                 bool default_enabled, bool internal_only); | ||||
|  | ||||
|   void log_model_config() override; | ||||
|  | ||||
|   /// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability | ||||
|   /// cutoff | ||||
|   /// @return True if wake word is detected, false otherwise | ||||
|   bool determine_detected() override; | ||||
|   DetectionEvent determine_detected() override; | ||||
|  | ||||
|   const std::string &get_id() const { return this->id_; } | ||||
|   const std::string &get_wake_word() const { return this->wake_word_; } | ||||
|  | ||||
|   void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); } | ||||
|   const std::vector<std::string> &get_trained_languages() const { return this->trained_languages_; } | ||||
|  | ||||
|   /// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it. | ||||
|   void enable() override; | ||||
|  | ||||
|   /// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it. | ||||
|   void disable() override; | ||||
|  | ||||
|   bool get_internal_only() { return this->internal_only_; } | ||||
|  | ||||
|  protected: | ||||
|   std::string id_; | ||||
|   std::string wake_word_; | ||||
|   std::vector<std::string> trained_languages_; | ||||
|  | ||||
|   bool internal_only_; | ||||
|  | ||||
|   ESPPreferenceObject pref_; | ||||
| }; | ||||
|  | ||||
| class VADModel final : public StreamingModel { | ||||
|  public: | ||||
|   VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); | ||||
|   VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||
|            size_t tensor_arena_size); | ||||
|  | ||||
|   void log_model_config() override; | ||||
|  | ||||
|   /// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability | ||||
|   /// cutoff | ||||
|   /// @return True if voice activity is detected, false otherwise | ||||
|   bool determine_detected() override; | ||||
|   DetectionEvent determine_detected() override; | ||||
| }; | ||||
|  | ||||
| }  // namespace micro_wake_word | ||||
|   | ||||
| @@ -79,6 +79,7 @@ | ||||
| #define USE_LVGL_TEXTAREA | ||||
| #define USE_LVGL_TILEVIEW | ||||
| #define USE_LVGL_TOUCHSCREEN | ||||
| #define USE_MICRO_WAKE_WORD | ||||
| #define USE_MD5 | ||||
| #define USE_MDNS | ||||
| #define USE_MEDIA_PLAYER | ||||
|   | ||||
| @@ -14,8 +14,24 @@ micro_wake_word: | ||||
|   microphone: echo_microphone | ||||
|   on_wake_word_detected: | ||||
|     - logger.log: "Wake word detected" | ||||
|     - micro_wake_word.stop: | ||||
|     - if: | ||||
|         condition: | ||||
|           - micro_wake_word.model_is_enabled: hey_jarvis_model | ||||
|         then: | ||||
|           - micro_wake_word.disable_model: hey_jarvis_model | ||||
|         else: | ||||
|           - micro_wake_word.enable_model: hey_jarvis_model | ||||
|     - if: | ||||
|         condition: | ||||
|           - not: | ||||
|               - micro_wake_word.is_running: | ||||
|         then: | ||||
|           micro_wake_word.start: | ||||
|   stop_after_detection: false | ||||
|   models: | ||||
|     - model: hey_jarvis | ||||
|       probability_cutoff: 0.7 | ||||
|       id: hey_jarvis_model | ||||
|     - model: okay_nabu | ||||
|       sliding_window_size: 5 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user