mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-30 22:53:59 +00:00 
			
		
		
		
	[micro_wake_word] add new VPE features (#8655)
This commit is contained in:
		| @@ -12,6 +12,7 @@ import esphome.config_validation as cv | |||||||
| from esphome.const import ( | from esphome.const import ( | ||||||
|     CONF_FILE, |     CONF_FILE, | ||||||
|     CONF_ID, |     CONF_ID, | ||||||
|  |     CONF_INTERNAL, | ||||||
|     CONF_MICROPHONE, |     CONF_MICROPHONE, | ||||||
|     CONF_MODEL, |     CONF_MODEL, | ||||||
|     CONF_PASSWORD, |     CONF_PASSWORD, | ||||||
| @@ -40,6 +41,7 @@ CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected" | |||||||
| CONF_PROBABILITY_CUTOFF = "probability_cutoff" | CONF_PROBABILITY_CUTOFF = "probability_cutoff" | ||||||
| CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size" | CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size" | ||||||
| CONF_SLIDING_WINDOW_SIZE = "sliding_window_size" | CONF_SLIDING_WINDOW_SIZE = "sliding_window_size" | ||||||
|  | CONF_STOP_AFTER_DETECTION = "stop_after_detection" | ||||||
| CONF_TENSOR_ARENA_SIZE = "tensor_arena_size" | CONF_TENSOR_ARENA_SIZE = "tensor_arena_size" | ||||||
| CONF_VAD = "vad" | CONF_VAD = "vad" | ||||||
|  |  | ||||||
| @@ -49,13 +51,20 @@ micro_wake_word_ns = cg.esphome_ns.namespace("micro_wake_word") | |||||||
|  |  | ||||||
| MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component) | MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component) | ||||||
|  |  | ||||||
|  | DisableModelAction = micro_wake_word_ns.class_("DisableModelAction", automation.Action) | ||||||
|  | EnableModelAction = micro_wake_word_ns.class_("EnableModelAction", automation.Action) | ||||||
| StartAction = micro_wake_word_ns.class_("StartAction", automation.Action) | StartAction = micro_wake_word_ns.class_("StartAction", automation.Action) | ||||||
| StopAction = micro_wake_word_ns.class_("StopAction", automation.Action) | StopAction = micro_wake_word_ns.class_("StopAction", automation.Action) | ||||||
|  |  | ||||||
|  | ModelIsEnabledCondition = micro_wake_word_ns.class_( | ||||||
|  |     "ModelIsEnabledCondition", automation.Condition | ||||||
|  | ) | ||||||
| IsRunningCondition = micro_wake_word_ns.class_( | IsRunningCondition = micro_wake_word_ns.class_( | ||||||
|     "IsRunningCondition", automation.Condition |     "IsRunningCondition", automation.Condition | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | WakeWordModel = micro_wake_word_ns.class_("WakeWordModel") | ||||||
|  |  | ||||||
|  |  | ||||||
| def _validate_json_filename(value): | def _validate_json_filename(value): | ||||||
|     value = cv.string(value) |     value = cv.string(value) | ||||||
| @@ -169,9 +178,10 @@ def _convert_manifest_v1_to_v2(v1_manifest): | |||||||
|  |  | ||||||
|     # Original Inception-based V1 manifest models require a minimum of 45672 bytes |     # Original Inception-based V1 manifest models require a minimum of 45672 bytes | ||||||
|     v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672 |     v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672 | ||||||
|  |  | ||||||
|     # Original Inception-based V1 manifest models use a 20 ms feature step size |     # Original Inception-based V1 manifest models use a 20 ms feature step size | ||||||
|     v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20 |     v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20 | ||||||
|  |     # Original Inception-based V1 manifest models were trained only on TTS English samples | ||||||
|  |     v2_manifest[KEY_TRAINED_LANGUAGES] = ["en"] | ||||||
|  |  | ||||||
|     return v2_manifest |     return v2_manifest | ||||||
|  |  | ||||||
| @@ -296,14 +306,16 @@ MODEL_SOURCE_SCHEMA = cv.Any( | |||||||
|  |  | ||||||
| MODEL_SCHEMA = cv.Schema( | MODEL_SCHEMA = cv.Schema( | ||||||
|     { |     { | ||||||
|  |         cv.GenerateID(CONF_ID): cv.declare_id(WakeWordModel), | ||||||
|         cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA, |         cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA, | ||||||
|         cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage, |         cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage, | ||||||
|         cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int, |         cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int, | ||||||
|  |         cv.Optional(CONF_INTERNAL, default=False): cv.boolean, | ||||||
|         cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8), |         cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8), | ||||||
|     } |     } | ||||||
| ) | ) | ||||||
|  |  | ||||||
| # Provide a default VAD model that could be overridden | # Provides a default VAD model that could be overridden | ||||||
| VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend( | VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend( | ||||||
|     cv.Schema( |     cv.Schema( | ||||||
|         { |         { | ||||||
| @@ -343,6 +355,7 @@ CONFIG_SCHEMA = cv.All( | |||||||
|                 single=True |                 single=True | ||||||
|             ), |             ), | ||||||
|             cv.Optional(CONF_VAD): _maybe_empty_vad_schema, |             cv.Optional(CONF_VAD): _maybe_empty_vad_schema, | ||||||
|  |             cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean, | ||||||
|             cv.Optional(CONF_MODEL): cv.invalid( |             cv.Optional(CONF_MODEL): cv.invalid( | ||||||
|                 f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter." |                 f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter." | ||||||
|             ), |             ), | ||||||
| @@ -433,29 +446,20 @@ async def to_code(config): | |||||||
|     mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) |     mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) | ||||||
|     cg.add(var.set_microphone_source(mic_source)) |     cg.add(var.set_microphone_source(mic_source)) | ||||||
|  |  | ||||||
|  |     cg.add_define("USE_MICRO_WAKE_WORD") | ||||||
|  |     cg.add_define("USE_OTA_STATE_CALLBACK") | ||||||
|  |  | ||||||
|     esp32.add_idf_component( |     esp32.add_idf_component( | ||||||
|         name="esp-tflite-micro", |         name="esp-tflite-micro", | ||||||
|         repo="https://github.com/espressif/esp-tflite-micro", |         repo="https://github.com/espressif/esp-tflite-micro", | ||||||
|         ref="v1.3.1", |         ref="v1.3.3.1", | ||||||
|     ) |  | ||||||
|     # add esp-nn dependency for tflite-micro to work around https://github.com/espressif/esp-nn/issues/17 |  | ||||||
|     # ...remove after switching to IDF 5.1.4+ |  | ||||||
|     esp32.add_idf_component( |  | ||||||
|         name="esp-nn", |  | ||||||
|         repo="https://github.com/espressif/esp-nn", |  | ||||||
|         ref="v1.1.0", |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     cg.add_build_flag("-DTF_LITE_STATIC_MEMORY") |     cg.add_build_flag("-DTF_LITE_STATIC_MEMORY") | ||||||
|     cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON") |     cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON") | ||||||
|     cg.add_build_flag("-DESP_NN") |     cg.add_build_flag("-DESP_NN") | ||||||
|  |  | ||||||
|     if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): |     cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") | ||||||
|         await automation.build_automation( |  | ||||||
|             var.get_wake_word_detected_trigger(), |  | ||||||
|             [(cg.std_string, "wake_word")], |  | ||||||
|             on_wake_word_detection_config, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     if vad_model := config.get(CONF_VAD): |     if vad_model := config.get(CONF_VAD): | ||||||
|         cg.add_define("USE_MICRO_WAKE_WORD_VAD") |         cg.add_define("USE_MICRO_WAKE_WORD_VAD") | ||||||
| @@ -463,7 +467,7 @@ async def to_code(config): | |||||||
|         # Use the general model loading code for the VAD codegen |         # Use the general model loading code for the VAD codegen | ||||||
|         config[CONF_MODELS].append(vad_model) |         config[CONF_MODELS].append(vad_model) | ||||||
|  |  | ||||||
|     for model_parameters in config[CONF_MODELS]: |     for i, model_parameters in enumerate(config[CONF_MODELS]): | ||||||
|         model_config = model_parameters.get(CONF_MODEL) |         model_config = model_parameters.get(CONF_MODEL) | ||||||
|         data = [] |         data = [] | ||||||
|         manifest, data = _model_config_to_manifest_data(model_config) |         manifest, data = _model_config_to_manifest_data(model_config) | ||||||
| @@ -474,6 +478,8 @@ async def to_code(config): | |||||||
|         probability_cutoff = model_parameters.get( |         probability_cutoff = model_parameters.get( | ||||||
|             CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF] |             CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF] | ||||||
|         ) |         ) | ||||||
|  |         quantized_probability_cutoff = int(probability_cutoff * 255) | ||||||
|  |  | ||||||
|         sliding_window_size = model_parameters.get( |         sliding_window_size = model_parameters.get( | ||||||
|             CONF_SLIDING_WINDOW_SIZE, |             CONF_SLIDING_WINDOW_SIZE, | ||||||
|             manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE], |             manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE], | ||||||
| @@ -483,24 +489,40 @@ async def to_code(config): | |||||||
|             cg.add( |             cg.add( | ||||||
|                 var.add_vad_model( |                 var.add_vad_model( | ||||||
|                     prog_arr, |                     prog_arr, | ||||||
|                     probability_cutoff, |                     quantized_probability_cutoff, | ||||||
|                     sliding_window_size, |                     sliding_window_size, | ||||||
|                     manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], |                     manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             cg.add( |             # Only enable the first wake word by default. After first boot, the enable state is saved to and loaded from flash | ||||||
|                 var.add_wake_word_model( |             default_enabled = i == 0 | ||||||
|                     prog_arr, |             wake_word_model = cg.new_Pvariable( | ||||||
|                     probability_cutoff, |                 model_parameters[CONF_ID], | ||||||
|                     sliding_window_size, |                 str(model_parameters[CONF_ID]), | ||||||
|                     manifest[KEY_WAKE_WORD], |                 prog_arr, | ||||||
|                     manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], |                 quantized_probability_cutoff, | ||||||
|                 ) |                 sliding_window_size, | ||||||
|  |                 manifest[KEY_WAKE_WORD], | ||||||
|  |                 manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], | ||||||
|  |                 default_enabled, | ||||||
|  |                 model_parameters[CONF_INTERNAL], | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |             for lang in manifest[KEY_TRAINED_LANGUAGES]: | ||||||
|  |                 cg.add(wake_word_model.add_trained_language(lang)) | ||||||
|  |  | ||||||
|  |             cg.add(var.add_wake_word_model(wake_word_model)) | ||||||
|  |  | ||||||
|     cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE])) |     cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE])) | ||||||
|     cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") |     cg.add(var.set_stop_after_detection(config[CONF_STOP_AFTER_DETECTION])) | ||||||
|  |  | ||||||
|  |     if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): | ||||||
|  |         await automation.build_automation( | ||||||
|  |             var.get_wake_word_detected_trigger(), | ||||||
|  |             [(cg.std_string, "wake_word")], | ||||||
|  |             on_wake_word_detection_config, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)}) | MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)}) | ||||||
| @@ -515,3 +537,30 @@ async def micro_wake_word_action_to_code(config, action_id, template_arg, args): | |||||||
|     var = cg.new_Pvariable(action_id, template_arg) |     var = cg.new_Pvariable(action_id, template_arg) | ||||||
|     await cg.register_parented(var, config[CONF_ID]) |     await cg.register_parented(var, config[CONF_ID]) | ||||||
|     return var |     return var | ||||||
|  |  | ||||||
|  |  | ||||||
|  | MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA = automation.maybe_simple_id( | ||||||
|  |     { | ||||||
|  |         cv.Required(CONF_ID): cv.use_id(WakeWordModel), | ||||||
|  |     } | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @register_action( | ||||||
|  |     "micro_wake_word.enable_model", | ||||||
|  |     EnableModelAction, | ||||||
|  |     MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, | ||||||
|  | ) | ||||||
|  | @register_action( | ||||||
|  |     "micro_wake_word.disable_model", | ||||||
|  |     DisableModelAction, | ||||||
|  |     MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, | ||||||
|  | ) | ||||||
|  | @register_condition( | ||||||
|  |     "micro_wake_word.model_is_enabled", | ||||||
|  |     ModelIsEnabledCondition, | ||||||
|  |     MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, | ||||||
|  | ) | ||||||
|  | async def model_action(config, action_id, template_arg, args): | ||||||
|  |     parent = await cg.get_variable(config[CONF_ID]) | ||||||
|  |     return cg.new_Pvariable(action_id, template_arg, parent) | ||||||
|   | |||||||
							
								
								
									
										54
									
								
								esphome/components/micro_wake_word/automation.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								esphome/components/micro_wake_word/automation.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "micro_wake_word.h" | ||||||
|  | #include "streaming_model.h" | ||||||
|  |  | ||||||
|  | #ifdef USE_ESP_IDF | ||||||
|  | namespace esphome { | ||||||
|  | namespace micro_wake_word { | ||||||
|  |  | ||||||
|  | template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> { | ||||||
|  |  public: | ||||||
|  |   void play(Ts... x) override { this->parent_->start(); } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> { | ||||||
|  |  public: | ||||||
|  |   void play(Ts... x) override { this->parent_->stop(); } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> { | ||||||
|  |  public: | ||||||
|  |   bool check(Ts... x) override { return this->parent_->is_running(); } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<typename... Ts> class EnableModelAction : public Action<Ts...> { | ||||||
|  |  public: | ||||||
|  |   explicit EnableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} | ||||||
|  |   void play(Ts... x) override { this->wake_word_model_->enable(); } | ||||||
|  |  | ||||||
|  |  protected: | ||||||
|  |   WakeWordModel *wake_word_model_; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<typename... Ts> class DisableModelAction : public Action<Ts...> { | ||||||
|  |  public: | ||||||
|  |   explicit DisableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} | ||||||
|  |   void play(Ts... x) override { this->wake_word_model_->disable(); } | ||||||
|  |  | ||||||
|  |  protected: | ||||||
|  |   WakeWordModel *wake_word_model_; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<typename... Ts> class ModelIsEnabledCondition : public Condition<Ts...> { | ||||||
|  |  public: | ||||||
|  |   explicit ModelIsEnabledCondition(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} | ||||||
|  |   bool check(Ts... x) override { return this->wake_word_model_->is_enabled(); } | ||||||
|  |  | ||||||
|  |  protected: | ||||||
|  |   WakeWordModel *wake_word_model_; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | }  // namespace micro_wake_word | ||||||
|  | }  // namespace esphome | ||||||
|  | #endif | ||||||
| @@ -1,5 +1,4 @@ | |||||||
| #include "micro_wake_word.h" | #include "micro_wake_word.h" | ||||||
| #include "streaming_model.h" |  | ||||||
|  |  | ||||||
| #ifdef USE_ESP_IDF | #ifdef USE_ESP_IDF | ||||||
|  |  | ||||||
| @@ -7,41 +6,57 @@ | |||||||
| #include "esphome/core/helpers.h" | #include "esphome/core/helpers.h" | ||||||
| #include "esphome/core/log.h" | #include "esphome/core/log.h" | ||||||
|  |  | ||||||
| #include <frontend.h> | #include "esphome/components/audio/audio_transfer_buffer.h" | ||||||
| #include <frontend_util.h> |  | ||||||
|  |  | ||||||
| #include <tensorflow/lite/core/c/common.h> | #ifdef USE_OTA | ||||||
| #include <tensorflow/lite/micro/micro_interpreter.h> | #include "esphome/components/ota/ota_backend.h" | ||||||
| #include <tensorflow/lite/micro/micro_mutable_op_resolver.h> | #endif | ||||||
|  |  | ||||||
| #include <cmath> |  | ||||||
|  |  | ||||||
| namespace esphome { | namespace esphome { | ||||||
| namespace micro_wake_word { | namespace micro_wake_word { | ||||||
|  |  | ||||||
| static const char *const TAG = "micro_wake_word"; | static const char *const TAG = "micro_wake_word"; | ||||||
|  |  | ||||||
| static const size_t SAMPLE_RATE_HZ = 16000;  // 16 kHz | static const ssize_t DETECTION_QUEUE_LENGTH = 5; | ||||||
| static const size_t BUFFER_LENGTH = 64;      // 0.064 seconds |  | ||||||
| static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH; | static const size_t DATA_TIMEOUT_MS = 50; | ||||||
| static const size_t INPUT_BUFFER_SIZE = 16 * SAMPLE_RATE_HZ / 1000;  // 16ms * 16kHz / 1000ms |  | ||||||
|  | static const uint32_t RING_BUFFER_DURATION_MS = 120; | ||||||
|  | static const uint32_t RING_BUFFER_SAMPLES = RING_BUFFER_DURATION_MS * (AUDIO_SAMPLE_FREQUENCY / 1000); | ||||||
|  | static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t); | ||||||
|  |  | ||||||
|  | static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072; | ||||||
|  | static const UBaseType_t INFERENCE_TASK_PRIORITY = 3; | ||||||
|  |  | ||||||
|  | enum EventGroupBits : uint32_t { | ||||||
|  |   COMMAND_STOP = (1 << 0),  // Signals the inference task should stop | ||||||
|  |  | ||||||
|  |   TASK_STARTING = (1 << 3), | ||||||
|  |   TASK_RUNNING = (1 << 4), | ||||||
|  |   TASK_STOPPING = (1 << 5), | ||||||
|  |   TASK_STOPPED = (1 << 6), | ||||||
|  |  | ||||||
|  |   ERROR_MEMORY = (1 << 9), | ||||||
|  |   ERROR_INFERENCE = (1 << 10), | ||||||
|  |  | ||||||
|  |   WARNING_FULL_RING_BUFFER = (1 << 13), | ||||||
|  |  | ||||||
|  |   ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE, | ||||||
|  |   ALL_BITS = 0xfffff,  // Mask of the low 20 bits (an event group provides 24 usable bits in total) | ||||||
|  | }; | ||||||
|  |  | ||||||
| float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } | float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } | ||||||
|  |  | ||||||
| static const LogString *micro_wake_word_state_to_string(State state) { | static const LogString *micro_wake_word_state_to_string(State state) { | ||||||
|   switch (state) { |   switch (state) { | ||||||
|     case State::IDLE: |     case State::STARTING: | ||||||
|       return LOG_STR("IDLE"); |       return LOG_STR("STARTING"); | ||||||
|     case State::START_MICROPHONE: |  | ||||||
|       return LOG_STR("START_MICROPHONE"); |  | ||||||
|     case State::STARTING_MICROPHONE: |  | ||||||
|       return LOG_STR("STARTING_MICROPHONE"); |  | ||||||
|     case State::DETECTING_WAKE_WORD: |     case State::DETECTING_WAKE_WORD: | ||||||
|       return LOG_STR("DETECTING_WAKE_WORD"); |       return LOG_STR("DETECTING_WAKE_WORD"); | ||||||
|     case State::STOP_MICROPHONE: |     case State::STOPPING: | ||||||
|       return LOG_STR("STOP_MICROPHONE"); |       return LOG_STR("STOPPING"); | ||||||
|     case State::STOPPING_MICROPHONE: |     case State::STOPPED: | ||||||
|       return LOG_STR("STOPPING_MICROPHONE"); |       return LOG_STR("STOPPED"); | ||||||
|     default: |     default: | ||||||
|       return LOG_STR("UNKNOWN"); |       return LOG_STR("UNKNOWN"); | ||||||
|   } |   } | ||||||
| @@ -51,7 +66,7 @@ void MicroWakeWord::dump_config() { | |||||||
|   ESP_LOGCONFIG(TAG, "microWakeWord:"); |   ESP_LOGCONFIG(TAG, "microWakeWord:"); | ||||||
|   ESP_LOGCONFIG(TAG, "  models:"); |   ESP_LOGCONFIG(TAG, "  models:"); | ||||||
|   for (auto &model : this->wake_word_models_) { |   for (auto &model : this->wake_word_models_) { | ||||||
|     model.log_model_config(); |     model->log_model_config(); | ||||||
|   } |   } | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|   this->vad_model_->log_model_config(); |   this->vad_model_->log_model_config(); | ||||||
| @@ -61,108 +76,266 @@ void MicroWakeWord::dump_config() { | |||||||
| void MicroWakeWord::setup() { | void MicroWakeWord::setup() { | ||||||
|   ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); |   ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); | ||||||
|  |  | ||||||
|  |   this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; | ||||||
|  |   this->frontend_config_.window.step_size_ms = this->features_step_size_; | ||||||
|  |   this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; | ||||||
|  |   this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT; | ||||||
|  |   this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT; | ||||||
|  |   this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS; | ||||||
|  |   this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING; | ||||||
|  |   this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING; | ||||||
|  |   this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING; | ||||||
|  |   this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN; | ||||||
|  |   this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH; | ||||||
|  |   this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET; | ||||||
|  |   this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS; | ||||||
|  |   this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG; | ||||||
|  |   this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT; | ||||||
|  |  | ||||||
|  |   this->event_group_ = xEventGroupCreate(); | ||||||
|  |   if (this->event_group_ == nullptr) { | ||||||
|  |     ESP_LOGE(TAG, "Failed to create event group"); | ||||||
|  |     this->mark_failed(); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent)); | ||||||
|  |   if (this->detection_queue_ == nullptr) { | ||||||
|  |     ESP_LOGE(TAG, "Failed to create detection event queue"); | ||||||
|  |     this->mark_failed(); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) { |   this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) { | ||||||
|     if (this->state_ != State::DETECTING_WAKE_WORD) { |     if (this->state_ == State::STOPPED) { | ||||||
|       return; |       return; | ||||||
|     } |     } | ||||||
|     std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_; |     std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock(); | ||||||
|     if (this->ring_buffer_.use_count() == 2) { |     if (this->ring_buffer_.use_count() > 1) { | ||||||
|       // mWW still owns the ring buffer and temp_ring_buffer does as well, proceed to copy audio into ring buffer |  | ||||||
|  |  | ||||||
|       size_t bytes_free = temp_ring_buffer->free(); |       size_t bytes_free = temp_ring_buffer->free(); | ||||||
|  |  | ||||||
|       if (bytes_free < data.size()) { |       if (bytes_free < data.size()) { | ||||||
|         ESP_LOGW( |         xEventGroupSetBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); | ||||||
|             TAG, |  | ||||||
|             "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " |  | ||||||
|             "Resetting the ring buffer. Wake word detection accuracy will be reduced.", |  | ||||||
|             bytes_free, data.size()); |  | ||||||
|  |  | ||||||
|         temp_ring_buffer->reset(); |         temp_ring_buffer->reset(); | ||||||
|       } |       } | ||||||
|       temp_ring_buffer->write((void *) data.data(), data.size()); |       temp_ring_buffer->write((void *) data.data(), data.size()); | ||||||
|     } |     } | ||||||
|   }); |   }); | ||||||
|  |  | ||||||
|   if (!this->register_streaming_ops_(this->streaming_op_resolver_)) { | #ifdef USE_OTA | ||||||
|     this->mark_failed(); |   ota::get_global_ota_callback()->add_on_state_callback( | ||||||
|     return; |       [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { | ||||||
|  |         if (state == ota::OTA_STARTED) { | ||||||
|  |           this->suspend_task_(); | ||||||
|  |         } else if (state == ota::OTA_ERROR) { | ||||||
|  |           this->resume_task_(); | ||||||
|  |         } | ||||||
|  |       }); | ||||||
|  | #endif | ||||||
|  |   ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void MicroWakeWord::inference_task(void *params) { | ||||||
|  |   MicroWakeWord *this_mww = (MicroWakeWord *) params; | ||||||
|  |  | ||||||
|  |   xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING); | ||||||
|  |  | ||||||
|  |   {  // Ensures any C++ objects fall out of scope to deallocate before deleting the task | ||||||
|  |     const size_t new_samples_to_read = this_mww->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000); | ||||||
|  |     std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer; | ||||||
|  |     int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]; | ||||||
|  |  | ||||||
|  |     if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { | ||||||
|  |       // Allocate audio transfer buffer | ||||||
|  |       audio_buffer = audio::AudioSourceTransferBuffer::create(new_samples_to_read * sizeof(int16_t)); | ||||||
|  |  | ||||||
|  |       if (audio_buffer == nullptr) { | ||||||
|  |         xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { | ||||||
|  |       // Allocate ring buffer | ||||||
|  |       std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(RING_BUFFER_SIZE); | ||||||
|  |       if (temp_ring_buffer.use_count() == 0) { | ||||||
|  |         xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); | ||||||
|  |       } | ||||||
|  |       audio_buffer->set_source(temp_ring_buffer); | ||||||
|  |       this_mww->ring_buffer_ = temp_ring_buffer; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { | ||||||
|  |       this_mww->microphone_source_->start(); | ||||||
|  |       xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING); | ||||||
|  |  | ||||||
|  |       while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) { | ||||||
|  |         audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS)); | ||||||
|  |  | ||||||
|  |         if (audio_buffer->available() < new_samples_to_read * sizeof(int16_t)) { | ||||||
|  |           // Insufficient data to generate new spectrogram features, read more next iteration | ||||||
|  |           continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         // Generate new spectrogram features | ||||||
|  |         size_t processed_samples = this_mww->generate_features_( | ||||||
|  |             (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer); | ||||||
|  |         audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t)); | ||||||
|  |  | ||||||
|  |         // Run inference using the new spectrogram features | ||||||
|  |         if (!this_mww->update_model_probabilities_(features_buffer)) { | ||||||
|  |           xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE); | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         // Process each model's probabilities and possibly send a Detection Event to the queue | ||||||
|  |         this_mww->process_probabilities_(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); |   xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING); | ||||||
|  |  | ||||||
|   this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; |   this_mww->unload_models_(); | ||||||
|   this->frontend_config_.window.step_size_ms = this->features_step_size_; |   this_mww->microphone_source_->stop(); | ||||||
|   this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; |   FrontendFreeStateContents(&this_mww->frontend_state_); | ||||||
|   this->frontend_config_.filterbank.lower_band_limit = 125.0; |  | ||||||
|   this->frontend_config_.filterbank.upper_band_limit = 7500.0; |   xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED); | ||||||
|   this->frontend_config_.noise_reduction.smoothing_bits = 10; |   while (true) { | ||||||
|   this->frontend_config_.noise_reduction.even_smoothing = 0.025; |     // Continuously delay until the main loop deletes the task | ||||||
|   this->frontend_config_.noise_reduction.odd_smoothing = 0.06; |     delay(10); | ||||||
|   this->frontend_config_.noise_reduction.min_signal_remaining = 0.05; |   } | ||||||
|   this->frontend_config_.pcan_gain_control.enable_pcan = 1; |  | ||||||
|   this->frontend_config_.pcan_gain_control.strength = 0.95; |  | ||||||
|   this->frontend_config_.pcan_gain_control.offset = 80.0; |  | ||||||
|   this->frontend_config_.pcan_gain_control.gain_bits = 21; |  | ||||||
|   this->frontend_config_.log_scale.enable_log = 1; |  | ||||||
|   this->frontend_config_.log_scale.scale_shift = 6; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff, | std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() { | ||||||
|                                         size_t sliding_window_average_size, const std::string &wake_word, |   std::vector<WakeWordModel *> external_wake_word_models; | ||||||
|                                         size_t tensor_arena_size) { |   for (auto *model : this->wake_word_models_) { | ||||||
|   this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word, |     if (!model->get_internal_only()) { | ||||||
|                                        tensor_arena_size); |       external_wake_word_models.push_back(model); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   return external_wake_word_models; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void MicroWakeWord::add_wake_word_model(WakeWordModel *model) { this->wake_word_models_.push_back(model); } | ||||||
|  |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
| void MicroWakeWord::add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, | void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||||
|                                   size_t tensor_arena_size) { |                                   size_t tensor_arena_size) { | ||||||
|   this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size); |   this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size); | ||||||
| } | } | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | void MicroWakeWord::suspend_task_() { | ||||||
|  |   if (this->inference_task_handle_ != nullptr) { | ||||||
|  |     vTaskSuspend(this->inference_task_handle_); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void MicroWakeWord::resume_task_() { | ||||||
|  |   if (this->inference_task_handle_ != nullptr) { | ||||||
|  |     vTaskResume(this->inference_task_handle_); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| void MicroWakeWord::loop() { | void MicroWakeWord::loop() { | ||||||
|  |   uint32_t event_group_bits = xEventGroupGetBits(this->event_group_); | ||||||
|  |  | ||||||
|  |   if (event_group_bits & EventGroupBits::ERROR_MEMORY) { | ||||||
|  |     xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY); | ||||||
|  |     ESP_LOGE(TAG, "Encountered an error allocating buffers"); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (event_group_bits & EventGroupBits::ERROR_INFERENCE) { | ||||||
|  |     xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE); | ||||||
|  |     ESP_LOGE(TAG, "Encountered an error while performing an inference"); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) { | ||||||
|  |     xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); | ||||||
|  |     ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake " | ||||||
|  |                   "word detection accuracy will temporarily be reduced."); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (event_group_bits & EventGroupBits::TASK_STARTING) { | ||||||
|  |     ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers"); | ||||||
|  |     xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (event_group_bits & EventGroupBits::TASK_RUNNING) { | ||||||
|  |     ESP_LOGD(TAG, "Inference task is running"); | ||||||
|  |  | ||||||
|  |     xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING); | ||||||
|  |     this->set_state_(State::DETECTING_WAKE_WORD); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (event_group_bits & EventGroupBits::TASK_STOPPING) { | ||||||
|  |     ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers"); | ||||||
|  |     xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if ((event_group_bits & EventGroupBits::TASK_STOPPED)) { | ||||||
|  |     ESP_LOGD(TAG, "Inference task is finished, freeing task resources"); | ||||||
|  |     vTaskDelete(this->inference_task_handle_); | ||||||
|  |     this->inference_task_handle_ = nullptr; | ||||||
|  |     xEventGroupClearBits(this->event_group_, ALL_BITS); | ||||||
|  |     xQueueReset(this->detection_queue_); | ||||||
|  |     this->set_state_(State::STOPPED); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if ((this->pending_start_) && (this->state_ == State::STOPPED)) { | ||||||
|  |     this->set_state_(State::STARTING); | ||||||
|  |     this->pending_start_ = false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) { | ||||||
|  |     this->set_state_(State::STOPPING); | ||||||
|  |     this->pending_stop_ = false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   switch (this->state_) { |   switch (this->state_) { | ||||||
|     case State::IDLE: |     case State::STARTING: | ||||||
|       break; |       if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) { | ||||||
|     case State::START_MICROPHONE: |         // Setup preprocessor feature generator. If done in the task, it would lock the task to its initial core, as it | ||||||
|       ESP_LOGD(TAG, "Starting Microphone"); |         // uses floating point operations. | ||||||
|       this->microphone_source_->start(); |         if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { | ||||||
|       this->set_state_(State::STARTING_MICROPHONE); |           this->status_momentary_error( | ||||||
|       break; |               "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000); | ||||||
|     case State::STARTING_MICROPHONE: |           return; | ||||||
|       if (this->microphone_source_->is_running()) { |         } | ||||||
|         this->set_state_(State::DETECTING_WAKE_WORD); |  | ||||||
|       } |         xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this, | ||||||
|       break; |                     INFERENCE_TASK_PRIORITY, &this->inference_task_handle_); | ||||||
|     case State::DETECTING_WAKE_WORD: |  | ||||||
|       while (this->has_enough_samples_()) { |         if (this->inference_task_handle_ == nullptr) { | ||||||
|         this->update_model_probabilities_(); |           FrontendFreeStateContents(&this->frontend_state_);  // Deallocate frontend state | ||||||
|         if (this->detect_wake_words_()) { |           this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000); | ||||||
|           ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str()); |  | ||||||
|           this->detected_ = true; |  | ||||||
|           this->set_state_(State::STOP_MICROPHONE); |  | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|       break; |       break; | ||||||
|     case State::STOP_MICROPHONE: |     case State::DETECTING_WAKE_WORD: { | ||||||
|       ESP_LOGD(TAG, "Stopping Microphone"); |       DetectionEvent detection_event; | ||||||
|       this->microphone_source_->stop(); |       while (xQueueReceive(this->detection_queue_, &detection_event, 0)) { | ||||||
|       this->set_state_(State::STOPPING_MICROPHONE); |         if (detection_event.blocked_by_vad) { | ||||||
|       this->unload_models_(); |           ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str()); | ||||||
|       this->deallocate_buffers_(); |         } else { | ||||||
|       break; |           constexpr float uint8_to_float_divisor = | ||||||
|     case State::STOPPING_MICROPHONE: |               255.0f;  // Converting a quantized uint8 probability to floating point | ||||||
|       if (this->microphone_source_->is_stopped()) { |           ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f", | ||||||
|         this->set_state_(State::IDLE); |                    detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor), | ||||||
|         if (this->detected_) { |                    (detection_event.max_probability / uint8_to_float_divisor)); | ||||||
|           this->wake_word_detected_trigger_->trigger(this->detected_wake_word_); |           this->wake_word_detected_trigger_->trigger(*detection_event.wake_word); | ||||||
|           this->detected_ = false; |           if (this->stop_after_detection_) { | ||||||
|           this->detected_wake_word_ = ""; |             this->stop(); | ||||||
|  |           } | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|       break; |       break; | ||||||
|  |     } | ||||||
|  |     case State::STOPPING: | ||||||
|  |       xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP); | ||||||
|  |       break; | ||||||
|  |     case State::STOPPED: | ||||||
|  |       break; | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -177,199 +350,40 @@ void MicroWakeWord::start() { | |||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (this->state_ != State::IDLE) { |   if (this->is_running()) { | ||||||
|     ESP_LOGW(TAG, "Wake word is already running"); |     ESP_LOGW(TAG, "Wake word detection is already running"); | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (!this->load_models_() || !this->allocate_buffers_()) { |   ESP_LOGD(TAG, "Starting wake word detection"); | ||||||
|     ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers"); |  | ||||||
|     this->status_set_error(); |  | ||||||
|   } else { |  | ||||||
|     this->status_clear_error(); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   if (this->status_has_error()) { |   this->pending_start_ = true; | ||||||
|     ESP_LOGW(TAG, "Wake word component has an error. Please check logs"); |   this->pending_stop_ = false; | ||||||
|     return; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   this->reset_states_(); |  | ||||||
|   this->set_state_(State::START_MICROPHONE); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void MicroWakeWord::stop() { | void MicroWakeWord::stop() { | ||||||
|   if (this->state_ == State::IDLE) { |   if (this->state_ == STOPPED) | ||||||
|     ESP_LOGW(TAG, "Wake word is already stopped"); |  | ||||||
|     return; |     return; | ||||||
|   } |  | ||||||
|   if (this->state_ == State::STOPPING_MICROPHONE) { |   ESP_LOGD(TAG, "Stopping wake word detection"); | ||||||
|     ESP_LOGW(TAG, "Wake word is already stopping"); |  | ||||||
|     return; |   this->pending_start_ = false; | ||||||
|   } |   this->pending_stop_ = true; | ||||||
|   this->set_state_(State::STOP_MICROPHONE); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void MicroWakeWord::set_state_(State state) { | void MicroWakeWord::set_state_(State state) { | ||||||
|   ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), |   if (this->state_ != state) { | ||||||
|            LOG_STR_ARG(micro_wake_word_state_to_string(state))); |     ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), | ||||||
|   this->state_ = state; |              LOG_STR_ARG(micro_wake_word_state_to_string(state))); | ||||||
|  |     this->state_ = state; | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| bool MicroWakeWord::allocate_buffers_() { | size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available, | ||||||
|   ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE); |                                          int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) { | ||||||
|  |   size_t processed_samples = 0; | ||||||
|   if (this->input_buffer_ == nullptr) { |   struct FrontendOutput frontend_output = | ||||||
|     this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t)); |       FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples); | ||||||
|     if (this->input_buffer_ == nullptr) { |  | ||||||
|       ESP_LOGE(TAG, "Could not allocate input buffer"); |  | ||||||
|       return false; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   if (this->preprocessor_audio_buffer_ == nullptr) { |  | ||||||
|     this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(this->new_samples_to_get_()); |  | ||||||
|     if (this->preprocessor_audio_buffer_ == nullptr) { |  | ||||||
|       ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer."); |  | ||||||
|       return false; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   if (this->ring_buffer_.use_count() == 0) { |  | ||||||
|     this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); |  | ||||||
|     if (this->ring_buffer_.use_count() == 0) { |  | ||||||
|       ESP_LOGE(TAG, "Could not allocate ring buffer"); |  | ||||||
|       return false; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   return true; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void MicroWakeWord::deallocate_buffers_() { |  | ||||||
|   ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE); |  | ||||||
|   if (this->input_buffer_ != nullptr) { |  | ||||||
|     audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); |  | ||||||
|     this->input_buffer_ = nullptr; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   if (this->preprocessor_audio_buffer_ != nullptr) { |  | ||||||
|     audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_()); |  | ||||||
|     this->preprocessor_audio_buffer_ = nullptr; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   this->ring_buffer_.reset(); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| bool MicroWakeWord::load_models_() { |  | ||||||
|   // Setup preprocessor feature generator |  | ||||||
|   if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { |  | ||||||
|     ESP_LOGD(TAG, "Failed to populate frontend state"); |  | ||||||
|     FrontendFreeStateContents(&this->frontend_state_); |  | ||||||
|     return false; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Setup streaming models |  | ||||||
|   for (auto &model : this->wake_word_models_) { |  | ||||||
|     if (!model.load_model(this->streaming_op_resolver_)) { |  | ||||||
|       ESP_LOGE(TAG, "Failed to initialize a wake word model."); |  | ||||||
|       return false; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD |  | ||||||
|   if (!this->vad_model_->load_model(this->streaming_op_resolver_)) { |  | ||||||
|     ESP_LOGE(TAG, "Failed to initialize VAD model."); |  | ||||||
|     return false; |  | ||||||
|   } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
|   return true; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void MicroWakeWord::unload_models_() { |  | ||||||
|   FrontendFreeStateContents(&this->frontend_state_); |  | ||||||
|  |  | ||||||
|   for (auto &model : this->wake_word_models_) { |  | ||||||
|     model.unload_model(); |  | ||||||
|   } |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD |  | ||||||
|   this->vad_model_->unload_model(); |  | ||||||
| #endif |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void MicroWakeWord::update_model_probabilities_() { |  | ||||||
|   int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]; |  | ||||||
|  |  | ||||||
|   if (!this->generate_features_for_window_(audio_features)) { |  | ||||||
|     return; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Increase the counter since the last positive detection |  | ||||||
|   this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); |  | ||||||
|  |  | ||||||
|   for (auto &model : this->wake_word_models_) { |  | ||||||
|     // Perform inference |  | ||||||
|     model.perform_streaming_inference(audio_features); |  | ||||||
|   } |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD |  | ||||||
|   this->vad_model_->perform_streaming_inference(audio_features); |  | ||||||
| #endif |  | ||||||
| } |  | ||||||
|  |  | ||||||
| bool MicroWakeWord::detect_wake_words_() { |  | ||||||
|   // Verify we have processed samples since the last positive detection |  | ||||||
|   if (this->ignore_windows_ < 0) { |  | ||||||
|     return false; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD |  | ||||||
|   bool vad_state = this->vad_model_->determine_detected(); |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
|   for (auto &model : this->wake_word_models_) { |  | ||||||
|     if (model.determine_detected()) { |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD |  | ||||||
|       if (vad_state) { |  | ||||||
| #endif |  | ||||||
|         this->detected_wake_word_ = model.get_wake_word(); |  | ||||||
|         return true; |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD |  | ||||||
|       } else { |  | ||||||
|         ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str()); |  | ||||||
|       } |  | ||||||
| #endif |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   return false; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| bool MicroWakeWord::has_enough_samples_() { |  | ||||||
|   return this->ring_buffer_->available() >= |  | ||||||
|          (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)) * sizeof(int16_t); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) { |  | ||||||
|   // Ensure we have enough new audio samples in the ring buffer for a full window |  | ||||||
|   if (!this->has_enough_samples_()) { |  | ||||||
|     return false; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_), |  | ||||||
|                                                this->new_samples_to_get_() * sizeof(int16_t), pdMS_TO_TICKS(200)); |  | ||||||
|  |  | ||||||
|   if (bytes_read == 0) { |  | ||||||
|     ESP_LOGE(TAG, "Could not read data from Ring Buffer"); |  | ||||||
|   } else if (bytes_read < this->new_samples_to_get_() * sizeof(int16_t)) { |  | ||||||
|     ESP_LOGD(TAG, "Partial Read of Data by Model"); |  | ||||||
|     ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read, |  | ||||||
|              (int) (this->new_samples_to_get_() * sizeof(int16_t))); |  | ||||||
|     return false; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   size_t num_samples_read; |  | ||||||
|   struct FrontendOutput frontend_output = FrontendProcessSamples( |  | ||||||
|       &this->frontend_state_, this->preprocessor_audio_buffer_, this->new_samples_to_get_(), &num_samples_read); |  | ||||||
|  |  | ||||||
|   for (size_t i = 0; i < frontend_output.size; ++i) { |   for (size_t i = 0; i < frontend_output.size; ++i) { | ||||||
|     // These scaling values are set to match the TFLite audio frontend int8 output. |     // These scaling values are set to match the TFLite audio frontend int8 output. | ||||||
| @@ -379,8 +393,8 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F | |||||||
|     // for historical reasons, to match up with the output of other feature |     // for historical reasons, to match up with the output of other feature | ||||||
|     // generators. |     // generators. | ||||||
|     // The process is then further complicated when we quantize the model. This |     // The process is then further complicated when we quantize the model. This | ||||||
|     // means we have to scale the 0.0 to 26.0 real values to the -128 to 127 |     // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN) | ||||||
|     // signed integer numbers. |     // to 127 (INT8_MAX) signed integer numbers. | ||||||
|     // All this means that to get matching values from our integer feature |     // All this means that to get matching values from our integer feature | ||||||
|     // output into the tensor input, we have to perform: |     // output into the tensor input, we have to perform: | ||||||
|     // input = (((feature / 25.6) / 26.0) * 256) - 128 |     // input = (((feature / 25.6) / 26.0) * 256) - 128 | ||||||
| @@ -389,74 +403,63 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F | |||||||
|     constexpr int32_t value_scale = 256; |     constexpr int32_t value_scale = 256; | ||||||
|     constexpr int32_t value_div = 666;  // 666 = 25.6 * 26.0 after rounding |     constexpr int32_t value_div = 666;  // 666 = 25.6 * 26.0 after rounding | ||||||
|     int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; |     int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; | ||||||
|     value -= 128; |  | ||||||
|     if (value < -128) { |     value -= INT8_MIN; | ||||||
|       value = -128; |     features_buffer[i] = clamp<int8_t>(value, INT8_MIN, INT8_MAX); | ||||||
|     } |  | ||||||
|     if (value > 127) { |  | ||||||
|       value = 127; |  | ||||||
|     } |  | ||||||
|     features[i] = value; |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   return true; |   return processed_samples; | ||||||
| } | } | ||||||
|  |  | ||||||
| void MicroWakeWord::reset_states_() { | void MicroWakeWord::process_probabilities_() { | ||||||
|   ESP_LOGD(TAG, "Resetting buffers and probabilities"); | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|   this->ring_buffer_->reset(); |   DetectionEvent vad_state = this->vad_model_->determine_detected(); | ||||||
|   this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; |  | ||||||
|  |   this->vad_state_ = vad_state.detected;  // atomic write, so thread safe | ||||||
|  | #endif | ||||||
|  |  | ||||||
|   for (auto &model : this->wake_word_models_) { |   for (auto &model : this->wake_word_models_) { | ||||||
|     model.reset_probabilities(); |     if (model->get_unprocessed_probability_status()) { | ||||||
|  |       // Only detect wake words if there is a new probability since the last check | ||||||
|  |       DetectionEvent wake_word_state = model->determine_detected(); | ||||||
|  |       if (wake_word_state.detected) { | ||||||
|  | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|  |         if (vad_state.detected) { | ||||||
|  | #endif | ||||||
|  |           xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); | ||||||
|  |           model->reset_probabilities(); | ||||||
|  | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|  |         } else { | ||||||
|  |           wake_word_state.blocked_by_vad = true; | ||||||
|  |           xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); | ||||||
|  |         } | ||||||
|  | #endif | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void MicroWakeWord::unload_models_() { | ||||||
|  |   for (auto &model : this->wake_word_models_) { | ||||||
|  |     model->unload_model(); | ||||||
|   } |   } | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|   this->vad_model_->reset_probabilities(); |   this->vad_model_->unload_model(); | ||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { | bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) { | ||||||
|   if (op_resolver.AddCallOnce() != kTfLiteOk) |   bool success = true; | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddVarHandle() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddReshape() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddReadVariable() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddStridedSlice() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddConcatenation() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddAssignVariable() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddConv2D() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddMul() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddAdd() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddMean() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddFullyConnected() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddLogistic() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddQuantize() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddAveragePool2D() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddMaxPool2D() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddPad() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddPack() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|   if (op_resolver.AddSplitV() != kTfLiteOk) |  | ||||||
|     return false; |  | ||||||
|  |  | ||||||
|   return true; |   for (auto &model : this->wake_word_models_) { | ||||||
|  |     // Perform inference | ||||||
|  |     success = success & model->perform_streaming_inference(audio_features); | ||||||
|  |   } | ||||||
|  | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|  |   success = success & this->vad_model_->perform_streaming_inference(audio_features); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |   return success; | ||||||
| } | } | ||||||
|  |  | ||||||
| }  // namespace micro_wake_word | }  // namespace micro_wake_word | ||||||
|   | |||||||
| @@ -5,33 +5,27 @@ | |||||||
| #include "preprocessor_settings.h" | #include "preprocessor_settings.h" | ||||||
| #include "streaming_model.h" | #include "streaming_model.h" | ||||||
|  |  | ||||||
|  | #include "esphome/components/microphone/microphone_source.h" | ||||||
|  |  | ||||||
| #include "esphome/core/automation.h" | #include "esphome/core/automation.h" | ||||||
| #include "esphome/core/component.h" | #include "esphome/core/component.h" | ||||||
| #include "esphome/core/ring_buffer.h" | #include "esphome/core/ring_buffer.h" | ||||||
|  |  | ||||||
| #include "esphome/components/microphone/microphone_source.h" | #include <freertos/event_groups.h> | ||||||
|  |  | ||||||
|  | #include <frontend.h> | ||||||
| #include <frontend_util.h> | #include <frontend_util.h> | ||||||
|  |  | ||||||
| #include <tensorflow/lite/core/c/common.h> |  | ||||||
| #include <tensorflow/lite/micro/micro_interpreter.h> |  | ||||||
| #include <tensorflow/lite/micro/micro_mutable_op_resolver.h> |  | ||||||
|  |  | ||||||
| namespace esphome { | namespace esphome { | ||||||
| namespace micro_wake_word { | namespace micro_wake_word { | ||||||
|  |  | ||||||
| enum State { | enum State { | ||||||
|   IDLE, |   STARTING, | ||||||
|   START_MICROPHONE, |  | ||||||
|   STARTING_MICROPHONE, |  | ||||||
|   DETECTING_WAKE_WORD, |   DETECTING_WAKE_WORD, | ||||||
|   STOP_MICROPHONE, |   STOPPING, | ||||||
|   STOPPING_MICROPHONE, |   STOPPED, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| // The number of audio slices to process before accepting a positive detection |  | ||||||
| static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74; |  | ||||||
|  |  | ||||||
| class MicroWakeWord : public Component { | class MicroWakeWord : public Component { | ||||||
|  public: |  public: | ||||||
|   void setup() override; |   void setup() override; | ||||||
| @@ -42,7 +36,7 @@ class MicroWakeWord : public Component { | |||||||
|   void start(); |   void start(); | ||||||
|   void stop(); |   void stop(); | ||||||
|  |  | ||||||
|   bool is_running() const { return this->state_ != State::IDLE; } |   bool is_running() const { return this->state_ != State::STOPPED; } | ||||||
|  |  | ||||||
|   void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; } |   void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; } | ||||||
|  |  | ||||||
| @@ -50,118 +44,87 @@ class MicroWakeWord : public Component { | |||||||
|     this->microphone_source_ = microphone_source; |     this->microphone_source_ = microphone_source; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; } | ||||||
|  |  | ||||||
|   Trigger<std::string> *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; } |   Trigger<std::string> *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; } | ||||||
|  |  | ||||||
|   void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, |   void add_wake_word_model(WakeWordModel *model); | ||||||
|                            const std::string &wake_word, size_t tensor_arena_size); |  | ||||||
|  |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|   void add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, |   void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||||
|                      size_t tensor_arena_size); |                      size_t tensor_arena_size); | ||||||
|  |  | ||||||
|  |   // Intended for the voice assistant component to fetch VAD status | ||||||
|  |   bool get_vad_state() { return this->vad_state_; } | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  |   // Intended for the voice assistant component to access which wake words are available | ||||||
|  |   // Since these are pointers to the WakeWordModel objects, the voice assistant component can enable or disable them | ||||||
|  |   std::vector<WakeWordModel *> get_wake_words(); | ||||||
|  |  | ||||||
|  protected: |  protected: | ||||||
|   microphone::MicrophoneSource *microphone_source_{nullptr}; |   microphone::MicrophoneSource *microphone_source_{nullptr}; | ||||||
|   Trigger<std::string> *wake_word_detected_trigger_ = new Trigger<std::string>(); |   Trigger<std::string> *wake_word_detected_trigger_ = new Trigger<std::string>(); | ||||||
|   State state_{State::IDLE}; |   State state_{State::STOPPED}; | ||||||
|  |  | ||||||
|   std::shared_ptr<RingBuffer> ring_buffer_; |   std::weak_ptr<RingBuffer> ring_buffer_; | ||||||
|  |   std::vector<WakeWordModel *> wake_word_models_; | ||||||
|   std::vector<WakeWordModel> wake_word_models_; |  | ||||||
|  |  | ||||||
| #ifdef USE_MICRO_WAKE_WORD_VAD | #ifdef USE_MICRO_WAKE_WORD_VAD | ||||||
|   std::unique_ptr<VADModel> vad_model_; |   std::unique_ptr<VADModel> vad_model_; | ||||||
|  |   bool vad_state_{false}; | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|   tflite::MicroMutableOpResolver<20> streaming_op_resolver_; |   bool pending_start_{false}; | ||||||
|  |   bool pending_stop_{false}; | ||||||
|  |  | ||||||
|  |   bool stop_after_detection_; | ||||||
|  |  | ||||||
|  |   uint8_t features_step_size_; | ||||||
|  |  | ||||||
|   // Audio frontend handles generating spectrogram features |   // Audio frontend handles generating spectrogram features | ||||||
|   struct FrontendConfig frontend_config_; |   struct FrontendConfig frontend_config_; | ||||||
|   struct FrontendState frontend_state_; |   struct FrontendState frontend_state_; | ||||||
|  |  | ||||||
|   // When the wake word detection first starts, we ignore this many audio |   // Handles managing the stop/state of the inference task | ||||||
|   // feature slices before accepting a positive detection |   EventGroupHandle_t event_group_; | ||||||
|   int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; |  | ||||||
|  |  | ||||||
|   uint8_t features_step_size_; |   // Used to send messages about the models' states to the main loop | ||||||
|  |   QueueHandle_t detection_queue_; | ||||||
|  |  | ||||||
|   // Stores audio read from the microphone before being added to the ring buffer. |   static void inference_task(void *params); | ||||||
|   int16_t *input_buffer_{nullptr}; |   TaskHandle_t inference_task_handle_{nullptr}; | ||||||
|   // Stores audio to be fed into the audio frontend for generating features. |  | ||||||
|   int16_t *preprocessor_audio_buffer_{nullptr}; |  | ||||||
|  |  | ||||||
|   bool detected_{false}; |   /// @brief Suspends the inference task | ||||||
|   std::string detected_wake_word_{""}; |   void suspend_task_(); | ||||||
|  |   /// @brief Resumes the inference task | ||||||
|  |   void resume_task_(); | ||||||
|  |  | ||||||
|   void set_state_(State state); |   void set_state_(State state); | ||||||
|  |  | ||||||
|   /// @brief Tests if there are enough samples in the ring buffer to generate new features. |   /// @brief Generates spectrogram features from an input buffer of audio samples | ||||||
|   /// @return True if enough samples, false otherwise. |   /// @param audio_buffer (int16_t *) Buffer containing input audio samples | ||||||
|   bool has_enough_samples_(); |   /// @param samples_available (size_t) Number of samples available in the input buffer | ||||||
|  |   /// @param features_buffer (int8_t *) Buffer to store generated features | ||||||
|  |   /// @return (size_t) Number of samples processed from the input buffer | ||||||
|  |   size_t generate_features_(int16_t *audio_buffer, size_t samples_available, | ||||||
|  |                             int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]); | ||||||
|  |  | ||||||
|   /// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_ |   /// @brief Processes any new probabilities for each model. If any wake word is detected, it will send a DetectionEvent | ||||||
|   /// @return True if successful, false otherwise |   /// to the detection_queue_. | ||||||
|   bool allocate_buffers_(); |   void process_probabilities_(); | ||||||
|  |  | ||||||
|   /// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_ |   /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. | ||||||
|   void deallocate_buffers_(); |  | ||||||
|  |  | ||||||
|   /// @brief Loads streaming models and prepares the feature generation frontend |  | ||||||
|   /// @return True if successful, false otherwise |  | ||||||
|   bool load_models_(); |  | ||||||
|  |  | ||||||
|   /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. Frees memory used by the feature |  | ||||||
|   /// generation frontend. |  | ||||||
|   void unload_models_(); |   void unload_models_(); | ||||||
|  |  | ||||||
|   /** Performs inference with each configured model |   /// @brief Runs an inference with each model using the new spectrogram features | ||||||
|    * |   /// @param audio_features (int8_t *) Buffer containing new spectrogram features | ||||||
|    * If enough audio samples are available, it will generate one slice of new features. |   /// @return True if successful, false if any errors were encountered | ||||||
|    * It then loops through and performs inference with each of the loaded models. |   bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]); | ||||||
|    */ |  | ||||||
|   void update_model_probabilities_(); |  | ||||||
|  |  | ||||||
|   /** Checks every model's recent probabilities to determine if the wake word has been predicted |  | ||||||
|    * |  | ||||||
|    * Verifies the models have processed enough new samples for accurate predictions. |  | ||||||
|    * Sets detected_wake_word_ to the wake word, if one is detected. |  | ||||||
|    * @return True if a wake word is predicted, false otherwise |  | ||||||
|    */ |  | ||||||
|   bool detect_wake_words_(); |  | ||||||
|  |  | ||||||
|   /** Generates features for a window of audio samples |  | ||||||
|    * |  | ||||||
|    * Reads samples from the ring buffer and feeds them into the preprocessor frontend. |  | ||||||
|    * Adapted from TFLite microspeech frontend. |  | ||||||
|    * @param features int8_t array to store the audio features |  | ||||||
|    * @return True if successful, false otherwise. |  | ||||||
|    */ |  | ||||||
|   bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]); |  | ||||||
|  |  | ||||||
|   /// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities |  | ||||||
|   void reset_states_(); |  | ||||||
|  |  | ||||||
|   /// @brief Returns true if successfully registered the streaming model's TensorFlow operations |  | ||||||
|   bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); |  | ||||||
|  |  | ||||||
|   inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); } |   inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> { |  | ||||||
|  public: |  | ||||||
|   void play(Ts... x) override { this->parent_->start(); } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> { |  | ||||||
|  public: |  | ||||||
|   void play(Ts... x) override { this->parent_->stop(); } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> { |  | ||||||
|  public: |  | ||||||
|   bool check(Ts... x) override { return this->parent_->is_running(); } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| }  // namespace micro_wake_word | }  // namespace micro_wake_word | ||||||
| }  // namespace esphome | }  // namespace esphome | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,6 +7,10 @@ | |||||||
| namespace esphome { | namespace esphome { | ||||||
| namespace micro_wake_word { | namespace micro_wake_word { | ||||||
|  |  | ||||||
|  | // Settings for controlling the spectrogram feature generation by the preprocessor. | ||||||
|  | // These must match the settings used when training a particular model. | ||||||
|  | // All microWakeWord models have been trained with these specific parameters. | ||||||
|  |  | ||||||
| // The number of features the audio preprocessor generates per slice | // The number of features the audio preprocessor generates per slice | ||||||
| static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40; | static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40; | ||||||
| // Duration of each slice used as input into the preprocessor | // Duration of each slice used as input into the preprocessor | ||||||
| @@ -14,6 +18,21 @@ static const uint8_t FEATURE_DURATION_MS = 30; | |||||||
| // Audio sample frequency in hertz | // Audio sample frequency in hertz | ||||||
| static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000; | static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000; | ||||||
|  |  | ||||||
|  | static const float FILTERBANK_LOWER_BAND_LIMIT = 125.0; | ||||||
|  | static const float FILTERBANK_UPPER_BAND_LIMIT = 7500.0; | ||||||
|  |  | ||||||
|  | static const uint8_t NOISE_REDUCTION_SMOOTHING_BITS = 10; | ||||||
|  | static const float NOISE_REDUCTION_EVEN_SMOOTHING = 0.025; | ||||||
|  | static const float NOISE_REDUCTION_ODD_SMOOTHING = 0.06; | ||||||
|  | static const float NOISE_REDUCTION_MIN_SIGNAL_REMAINING = 0.05; | ||||||
|  |  | ||||||
|  | static const bool PCAN_GAIN_CONTROL_ENABLE_PCAN = true; | ||||||
|  | static const float PCAN_GAIN_CONTROL_STRENGTH = 0.95; | ||||||
|  | static const float PCAN_GAIN_CONTROL_OFFSET = 80.0; | ||||||
|  | static const uint8_t PCAN_GAIN_CONTROL_GAIN_BITS = 21; | ||||||
|  |  | ||||||
|  | static const bool LOG_SCALE_ENABLE_LOG = true; | ||||||
|  | static const uint8_t LOG_SCALE_SCALE_SHIFT = 6; | ||||||
| }  // namespace micro_wake_word | }  // namespace micro_wake_word | ||||||
| }  // namespace esphome | }  // namespace esphome | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,8 +1,7 @@ | |||||||
| #ifdef USE_ESP_IDF |  | ||||||
|  |  | ||||||
| #include "streaming_model.h" | #include "streaming_model.h" | ||||||
|  |  | ||||||
| #include "esphome/core/hal.h" | #ifdef USE_ESP_IDF | ||||||
|  |  | ||||||
| #include "esphome/core/helpers.h" | #include "esphome/core/helpers.h" | ||||||
| #include "esphome/core/log.h" | #include "esphome/core/log.h" | ||||||
|  |  | ||||||
| @@ -13,18 +12,18 @@ namespace micro_wake_word { | |||||||
|  |  | ||||||
| void WakeWordModel::log_model_config() { | void WakeWordModel::log_model_config() { | ||||||
|   ESP_LOGCONFIG(TAG, "    - Wake Word: %s", this->wake_word_.c_str()); |   ESP_LOGCONFIG(TAG, "    - Wake Word: %s", this->wake_word_.c_str()); | ||||||
|   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.3f", this->probability_cutoff_); |   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); | ||||||
|   ESP_LOGCONFIG(TAG, "      Sliding window size: %d", this->sliding_window_size_); |   ESP_LOGCONFIG(TAG, "      Sliding window size: %d", this->sliding_window_size_); | ||||||
| } | } | ||||||
|  |  | ||||||
| void VADModel::log_model_config() { | void VADModel::log_model_config() { | ||||||
|   ESP_LOGCONFIG(TAG, "    - VAD Model"); |   ESP_LOGCONFIG(TAG, "    - VAD Model"); | ||||||
|   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.3f", this->probability_cutoff_); |   ESP_LOGCONFIG(TAG, "      Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); | ||||||
|   ESP_LOGCONFIG(TAG, "      Sliding window size: %d", this->sliding_window_size_); |   ESP_LOGCONFIG(TAG, "      Sliding window size: %d", this->sliding_window_size_); | ||||||
| } | } | ||||||
|  |  | ||||||
| bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) { | bool StreamingModel::load_model_() { | ||||||
|   ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE); |   RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE); | ||||||
|  |  | ||||||
|   if (this->tensor_arena_ == nullptr) { |   if (this->tensor_arena_ == nullptr) { | ||||||
|     this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); |     this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); | ||||||
| @@ -51,8 +50,9 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (this->interpreter_ == nullptr) { |   if (this->interpreter_ == nullptr) { | ||||||
|     this->interpreter_ = make_unique<tflite::MicroInterpreter>( |     this->interpreter_ = | ||||||
|         tflite::GetModel(this->model_start_), op_resolver, this->tensor_arena_, this->tensor_arena_size_, this->mrv_); |         make_unique<tflite::MicroInterpreter>(tflite::GetModel(this->model_start_), this->streaming_op_resolver_, | ||||||
|  |                                               this->tensor_arena_, this->tensor_arena_size_, this->mrv_); | ||||||
|     if (this->interpreter_->AllocateTensors() != kTfLiteOk) { |     if (this->interpreter_->AllocateTensors() != kTfLiteOk) { | ||||||
|       ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model"); |       ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model"); | ||||||
|       return false; |       return false; | ||||||
| @@ -84,34 +84,55 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   this->loaded_ = true; | ||||||
|  |   this->reset_probabilities(); | ||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| void StreamingModel::unload_model() { | void StreamingModel::unload_model() { | ||||||
|   this->interpreter_.reset(); |   this->interpreter_.reset(); | ||||||
|  |  | ||||||
|   ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE); |   RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE); | ||||||
|  |  | ||||||
|   arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); |   if (this->tensor_arena_ != nullptr) { | ||||||
|   this->tensor_arena_ = nullptr; |     arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); | ||||||
|   arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); |     this->tensor_arena_ = nullptr; | ||||||
|   this->var_arena_ = nullptr; |   } | ||||||
|  |  | ||||||
|  |   if (this->var_arena_ != nullptr) { | ||||||
|  |     arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); | ||||||
|  |     this->var_arena_ = nullptr; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   this->loaded_ = false; | ||||||
| } | } | ||||||
|  |  | ||||||
| bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) { | bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) { | ||||||
|   if (this->interpreter_ != nullptr) { |   if (this->enabled_ && !this->loaded_) { | ||||||
|  |     // Model is enabled but isn't loaded | ||||||
|  |     if (!this->load_model_()) { | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (!this->enabled_ && this->loaded_) { | ||||||
|  |     // Model is disabled but still loaded | ||||||
|  |     this->unload_model(); | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (this->loaded_) { | ||||||
|     TfLiteTensor *input = this->interpreter_->input(0); |     TfLiteTensor *input = this->interpreter_->input(0); | ||||||
|  |  | ||||||
|  |     uint8_t stride = this->interpreter_->input(0)->dims->data[1]; | ||||||
|  |     this->current_stride_step_ = this->current_stride_step_ % stride; | ||||||
|  |  | ||||||
|     std::memmove( |     std::memmove( | ||||||
|         (int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_, |         (int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_, | ||||||
|         features, PREPROCESSOR_FEATURE_SIZE); |         features, PREPROCESSOR_FEATURE_SIZE); | ||||||
|     ++this->current_stride_step_; |     ++this->current_stride_step_; | ||||||
|  |  | ||||||
|     uint8_t stride = this->interpreter_->input(0)->dims->data[1]; |  | ||||||
|  |  | ||||||
|     if (this->current_stride_step_ >= stride) { |     if (this->current_stride_step_ >= stride) { | ||||||
|       this->current_stride_step_ = 0; |  | ||||||
|  |  | ||||||
|       TfLiteStatus invoke_status = this->interpreter_->Invoke(); |       TfLiteStatus invoke_status = this->interpreter_->Invoke(); | ||||||
|       if (invoke_status != kTfLiteOk) { |       if (invoke_status != kTfLiteOk) { | ||||||
|         ESP_LOGW(TAG, "Streaming interpreter invoke failed"); |         ESP_LOGW(TAG, "Streaming interpreter invoke failed"); | ||||||
| @@ -124,65 +145,159 @@ bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCES | |||||||
|       if (this->last_n_index_ == this->sliding_window_size_) |       if (this->last_n_index_ == this->sliding_window_size_) | ||||||
|         this->last_n_index_ = 0; |         this->last_n_index_ = 0; | ||||||
|       this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0];  // probability; |       this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0];  // probability; | ||||||
|  |       this->unprocessed_probability_status_ = true; | ||||||
|     } |     } | ||||||
|     return true; |     this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); | ||||||
|   } |   } | ||||||
|   ESP_LOGE(TAG, "Streaming interpreter is not initialized."); |   return true; | ||||||
|   return false; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void StreamingModel::reset_probabilities() { | void StreamingModel::reset_probabilities() { | ||||||
|   for (auto &prob : this->recent_streaming_probabilities_) { |   for (auto &prob : this->recent_streaming_probabilities_) { | ||||||
|     prob = 0; |     prob = 0; | ||||||
|   } |   } | ||||||
|  |   this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; | ||||||
| } | } | ||||||
|  |  | ||||||
| WakeWordModel::WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, | WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, | ||||||
|                              const std::string &wake_word, size_t tensor_arena_size) { |                              size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, | ||||||
|  |                              bool default_enabled, bool internal_only) { | ||||||
|  |   this->id_ = id; | ||||||
|   this->model_start_ = model_start; |   this->model_start_ = model_start; | ||||||
|   this->probability_cutoff_ = probability_cutoff; |   this->probability_cutoff_ = probability_cutoff; | ||||||
|   this->sliding_window_size_ = sliding_window_average_size; |   this->sliding_window_size_ = sliding_window_average_size; | ||||||
|   this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0); |   this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0); | ||||||
|   this->wake_word_ = wake_word; |   this->wake_word_ = wake_word; | ||||||
|   this->tensor_arena_size_ = tensor_arena_size; |   this->tensor_arena_size_ = tensor_arena_size; | ||||||
|  |   this->register_streaming_ops_(this->streaming_op_resolver_); | ||||||
|  |   this->current_stride_step_ = 0; | ||||||
|  |   this->internal_only_ = internal_only; | ||||||
|  |  | ||||||
|  |   this->pref_ = global_preferences->make_preference<bool>(fnv1_hash(id)); | ||||||
|  |   bool enabled; | ||||||
|  |   if (this->pref_.load(&enabled)) { | ||||||
|  |     // Use the enabled state loaded from flash | ||||||
|  |     this->enabled_ = enabled; | ||||||
|  |   } else { | ||||||
|  |     // If no state saved, then use the default | ||||||
|  |     this->enabled_ = default_enabled; | ||||||
|  |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| bool WakeWordModel::determine_detected() { | void WakeWordModel::enable() { | ||||||
|  |   this->enabled_ = true; | ||||||
|  |   if (!this->internal_only_) { | ||||||
|  |     this->pref_.save(&this->enabled_); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void WakeWordModel::disable() { | ||||||
|  |   this->enabled_ = false; | ||||||
|  |   if (!this->internal_only_) { | ||||||
|  |     this->pref_.save(&this->enabled_); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | DetectionEvent WakeWordModel::determine_detected() { | ||||||
|  |   DetectionEvent detection_event; | ||||||
|  |   detection_event.wake_word = &this->wake_word_; | ||||||
|  |   detection_event.max_probability = 0; | ||||||
|  |   detection_event.average_probability = 0; | ||||||
|  |  | ||||||
|  |   if ((this->ignore_windows_ < 0) || !this->enabled_) { | ||||||
|  |     detection_event.detected = false; | ||||||
|  |     return detection_event; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   uint32_t sum = 0; |   uint32_t sum = 0; | ||||||
|   for (auto &prob : this->recent_streaming_probabilities_) { |   for (auto &prob : this->recent_streaming_probabilities_) { | ||||||
|  |     detection_event.max_probability = std::max(detection_event.max_probability, prob); | ||||||
|     sum += prob; |     sum += prob; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_); |   detection_event.average_probability = sum / this->sliding_window_size_; | ||||||
|  |   detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_; | ||||||
|  |  | ||||||
|   // Detect the wake word if the sliding window average is above the cutoff |   this->unprocessed_probability_status_ = false; | ||||||
|   if (sliding_window_average > this->probability_cutoff_) { |   return detection_event; | ||||||
|     ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f", |  | ||||||
|              this->wake_word_.c_str(), sliding_window_average, |  | ||||||
|              this->recent_streaming_probabilities_[this->last_n_index_] / (255.0)); |  | ||||||
|     return true; |  | ||||||
|   } |  | ||||||
|   return false; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| VADModel::VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, | VADModel::VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||||
|                    size_t tensor_arena_size) { |                    size_t tensor_arena_size) { | ||||||
|   this->model_start_ = model_start; |   this->model_start_ = model_start; | ||||||
|   this->probability_cutoff_ = probability_cutoff; |   this->probability_cutoff_ = probability_cutoff; | ||||||
|   this->sliding_window_size_ = sliding_window_size; |   this->sliding_window_size_ = sliding_window_size; | ||||||
|   this->recent_streaming_probabilities_.resize(sliding_window_size, 0); |   this->recent_streaming_probabilities_.resize(sliding_window_size, 0); | ||||||
|   this->tensor_arena_size_ = tensor_arena_size; |   this->tensor_arena_size_ = tensor_arena_size; | ||||||
| }; |   this->register_streaming_ops_(this->streaming_op_resolver_); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | DetectionEvent VADModel::determine_detected() { | ||||||
|  |   DetectionEvent detection_event; | ||||||
|  |   detection_event.max_probability = 0; | ||||||
|  |   detection_event.average_probability = 0; | ||||||
|  |  | ||||||
|  |   if (!this->enabled_) { | ||||||
|  |     // We disabled the VAD model for some reason... so we shouldn't block wake words from being detected | ||||||
|  |     detection_event.detected = true; | ||||||
|  |     return detection_event; | ||||||
|  |   } | ||||||
|  |  | ||||||
| bool VADModel::determine_detected() { |  | ||||||
|   uint32_t sum = 0; |   uint32_t sum = 0; | ||||||
|   for (auto &prob : this->recent_streaming_probabilities_) { |   for (auto &prob : this->recent_streaming_probabilities_) { | ||||||
|  |     detection_event.max_probability = std::max(detection_event.max_probability, prob); | ||||||
|     sum += prob; |     sum += prob; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_); |   detection_event.average_probability = sum / this->sliding_window_size_; | ||||||
|  |   detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_); | ||||||
|  |  | ||||||
|   return sliding_window_average > this->probability_cutoff_; |   return detection_event; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { | ||||||
|  |   if (op_resolver.AddCallOnce() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddVarHandle() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddReshape() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddReadVariable() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddStridedSlice() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddConcatenation() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddAssignVariable() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddConv2D() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddMul() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddAdd() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddMean() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddFullyConnected() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddLogistic() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddQuantize() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddAveragePool2D() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddMaxPool2D() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddPad() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddPack() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddSplitV() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |  | ||||||
|  |   return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| }  // namespace micro_wake_word | }  // namespace micro_wake_word | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ | |||||||
|  |  | ||||||
| #include "preprocessor_settings.h" | #include "preprocessor_settings.h" | ||||||
|  |  | ||||||
|  | #include "esphome/core/preferences.h" | ||||||
|  |  | ||||||
| #include <tensorflow/lite/core/c/common.h> | #include <tensorflow/lite/core/c/common.h> | ||||||
| #include <tensorflow/lite/micro/micro_interpreter.h> | #include <tensorflow/lite/micro/micro_interpreter.h> | ||||||
| #include <tensorflow/lite/micro/micro_mutable_op_resolver.h> | #include <tensorflow/lite/micro/micro_mutable_op_resolver.h> | ||||||
| @@ -11,30 +13,63 @@ | |||||||
| namespace esphome { | namespace esphome { | ||||||
| namespace micro_wake_word { | namespace micro_wake_word { | ||||||
|  |  | ||||||
|  | static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100; | ||||||
| static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024; | static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024; | ||||||
|  |  | ||||||
|  | struct DetectionEvent { | ||||||
|  |   std::string *wake_word; | ||||||
|  |   bool detected; | ||||||
|  |   bool partially_detection;  // Set if the most recent probability exceeds the threshold, but the sliding window average | ||||||
|  |                              // hasn't yet | ||||||
|  |   uint8_t max_probability; | ||||||
|  |   uint8_t average_probability; | ||||||
|  |   bool blocked_by_vad = false; | ||||||
|  | }; | ||||||
|  |  | ||||||
| class StreamingModel { | class StreamingModel { | ||||||
|  public: |  public: | ||||||
|   virtual void log_model_config() = 0; |   virtual void log_model_config() = 0; | ||||||
|   virtual bool determine_detected() = 0; |   virtual DetectionEvent determine_detected() = 0; | ||||||
|  |  | ||||||
|  |   // Performs inference on the given features. | ||||||
|  |   //  - If the model is enabled but not loaded, it will load it | ||||||
|  |   //  - If the model is disabled but loaded, it will unload it | ||||||
|  |   // Returns true if successful or false if there is an error | ||||||
|   bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]); |   bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]); | ||||||
|  |  | ||||||
|   /// @brief Sets all recent_streaming_probabilities to 0 |   /// @brief Sets all recent_streaming_probabilities to 0 and resets the ignore window count | ||||||
|   void reset_probabilities(); |   void reset_probabilities(); | ||||||
|  |  | ||||||
|   /// @brief Allocates tensor and variable arenas and sets up the model interpreter |  | ||||||
|   /// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded |  | ||||||
|   /// @return True if successful, false otherwise |  | ||||||
|   bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver); |  | ||||||
|  |  | ||||||
|   /// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory |   /// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory | ||||||
|   void unload_model(); |   void unload_model(); | ||||||
|  |  | ||||||
|  protected: |   /// @brief Enable the model. The next perform_streaming_inference call will load it. | ||||||
|   uint8_t current_stride_step_{0}; |   virtual void enable() { this->enabled_ = true; } | ||||||
|  |  | ||||||
|   float probability_cutoff_; |   /// @brief Disable the model. The next perform_streaming_inference call will unload it. | ||||||
|  |   virtual void disable() { this->enabled_ = false; } | ||||||
|  |  | ||||||
|  |   /// @brief Return true if the model is enabled. | ||||||
|  |   bool is_enabled() { return this->enabled_; } | ||||||
|  |  | ||||||
|  |   bool get_unprocessed_probability_status() { return this->unprocessed_probability_status_; } | ||||||
|  |  | ||||||
|  |  protected: | ||||||
|  |   /// @brief Allocates tensor and variable arenas and sets up the model interpreter | ||||||
|  |   /// @return True if successful, false otherwise | ||||||
|  |   bool load_model_(); | ||||||
|  |   /// @brief Returns true if successfully registered the streaming model's TensorFlow operations | ||||||
|  |   bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); | ||||||
|  |  | ||||||
|  |   tflite::MicroMutableOpResolver<20> streaming_op_resolver_; | ||||||
|  |  | ||||||
|  |   bool loaded_{false}; | ||||||
|  |   bool enabled_{true}; | ||||||
|  |   bool unprocessed_probability_status_{false}; | ||||||
|  |   uint8_t current_stride_step_{0}; | ||||||
|  |   int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; | ||||||
|  |  | ||||||
|  |   uint8_t probability_cutoff_;  // Quantized probability cutoff mapping 0.0 - 1.0 to 0 - 255 | ||||||
|   size_t sliding_window_size_; |   size_t sliding_window_size_; | ||||||
|   size_t last_n_index_{0}; |   size_t last_n_index_{0}; | ||||||
|   size_t tensor_arena_size_; |   size_t tensor_arena_size_; | ||||||
| @@ -50,32 +85,62 @@ class StreamingModel { | |||||||
|  |  | ||||||
| class WakeWordModel final : public StreamingModel { | class WakeWordModel final : public StreamingModel { | ||||||
|  public: |  public: | ||||||
|   WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, |   /// @brief Constructs a wake word model object | ||||||
|                 const std::string &wake_word, size_t tensor_arena_size); |   /// @param id (std::string) identifier for this model | ||||||
|  |   /// @param model_start (const uint8_t *) pointer to the start of the model's TFLite FlatBuffer | ||||||
|  |   /// @param probability_cutoff (uint8_t) probability cutoff for accepting that the wake word has been said | ||||||
|  |   /// @param sliding_window_average_size (size_t) the length of the sliding window computing the mean rolling | ||||||
|  |   ///                                    probability | ||||||
|  |   /// @param wake_word (std::string) Friendly name of the wake word | ||||||
|  |   /// @param tensor_arena_size (size_t) Size in bytes for allocating the tensor arena | ||||||
|  |   /// @param default_enabled (bool) If true, it will be enabled by default on first boot | ||||||
|  |   /// @param internal_only (bool) If true, the model will not be exposed to HomeAssistant as an available model | ||||||
|  |   WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, | ||||||
|  |                 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, | ||||||
|  |                 bool default_enabled, bool internal_only); | ||||||
|  |  | ||||||
|   void log_model_config() override; |   void log_model_config() override; | ||||||
|  |  | ||||||
|   /// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability |   /// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability | ||||||
|   /// cutoff |   /// cutoff | ||||||
|   /// @return True if wake word is detected, false otherwise |   /// @return True if wake word is detected, false otherwise | ||||||
|   bool determine_detected() override; |   DetectionEvent determine_detected() override; | ||||||
|  |  | ||||||
|  |   const std::string &get_id() const { return this->id_; } | ||||||
|   const std::string &get_wake_word() const { return this->wake_word_; } |   const std::string &get_wake_word() const { return this->wake_word_; } | ||||||
|  |  | ||||||
|  |   void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); } | ||||||
|  |   const std::vector<std::string> &get_trained_languages() const { return this->trained_languages_; } | ||||||
|  |  | ||||||
|  |   /// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it. | ||||||
|  |   void enable() override; | ||||||
|  |  | ||||||
|  |   /// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it. | ||||||
|  |   void disable() override; | ||||||
|  |  | ||||||
|  |   bool get_internal_only() { return this->internal_only_; } | ||||||
|  |  | ||||||
|  protected: |  protected: | ||||||
|  |   std::string id_; | ||||||
|   std::string wake_word_; |   std::string wake_word_; | ||||||
|  |   std::vector<std::string> trained_languages_; | ||||||
|  |  | ||||||
|  |   bool internal_only_; | ||||||
|  |  | ||||||
|  |   ESPPreferenceObject pref_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| class VADModel final : public StreamingModel { | class VADModel final : public StreamingModel { | ||||||
|  public: |  public: | ||||||
|   VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); |   VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, | ||||||
|  |            size_t tensor_arena_size); | ||||||
|  |  | ||||||
|   void log_model_config() override; |   void log_model_config() override; | ||||||
|  |  | ||||||
|   /// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability |   /// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability | ||||||
|   /// cutoff |   /// cutoff | ||||||
|   /// @return True if voice activity is detected, false otherwise |   /// @return True if voice activity is detected, false otherwise | ||||||
|   bool determine_detected() override; |   DetectionEvent determine_detected() override; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| }  // namespace micro_wake_word | }  // namespace micro_wake_word | ||||||
|   | |||||||
| @@ -79,6 +79,7 @@ | |||||||
| #define USE_LVGL_TEXTAREA | #define USE_LVGL_TEXTAREA | ||||||
| #define USE_LVGL_TILEVIEW | #define USE_LVGL_TILEVIEW | ||||||
| #define USE_LVGL_TOUCHSCREEN | #define USE_LVGL_TOUCHSCREEN | ||||||
|  | #define USE_MICRO_WAKE_WORD | ||||||
| #define USE_MD5 | #define USE_MD5 | ||||||
| #define USE_MDNS | #define USE_MDNS | ||||||
| #define USE_MEDIA_PLAYER | #define USE_MEDIA_PLAYER | ||||||
|   | |||||||
| @@ -14,8 +14,24 @@ micro_wake_word: | |||||||
|   microphone: echo_microphone |   microphone: echo_microphone | ||||||
|   on_wake_word_detected: |   on_wake_word_detected: | ||||||
|     - logger.log: "Wake word detected" |     - logger.log: "Wake word detected" | ||||||
|  |     - micro_wake_word.stop: | ||||||
|  |     - if: | ||||||
|  |         condition: | ||||||
|  |           - micro_wake_word.model_is_enabled: hey_jarvis_model | ||||||
|  |         then: | ||||||
|  |           - micro_wake_word.disable_model: hey_jarvis_model | ||||||
|  |         else: | ||||||
|  |           - micro_wake_word.enable_model: hey_jarvis_model | ||||||
|  |     - if: | ||||||
|  |         condition: | ||||||
|  |           - not: | ||||||
|  |               - micro_wake_word.is_running: | ||||||
|  |         then: | ||||||
|  |           micro_wake_word.start: | ||||||
|  |   stop_after_detection: false | ||||||
|   models: |   models: | ||||||
|     - model: hey_jarvis |     - model: hey_jarvis | ||||||
|       probability_cutoff: 0.7 |       probability_cutoff: 0.7 | ||||||
|  |       id: hey_jarvis_model | ||||||
|     - model: okay_nabu |     - model: okay_nabu | ||||||
|       sliding_window_size: 5 |       sliding_window_size: 5 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user