mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-31 15:12:06 +00:00 
			
		
		
		
	microWakeWord - add new ops and small improvements (#6360)
This commit is contained in:
		
				
					committed by
					
						Jesse Hills
					
				
			
			
				
	
			
			
			
						parent
						
							d121fa5d05
						
					
				
				
					commit
					9e378189c3
				
			| @@ -93,11 +93,18 @@ int MicroWakeWord::read_microphone_() { | |||||||
|     return 0; |     return 0; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   size_t bytes_written = this->ring_buffer_->write((void *) this->input_buffer_, bytes_read); |   size_t bytes_free = this->ring_buffer_->free(); | ||||||
|   if (bytes_written != bytes_read) { |  | ||||||
|     ESP_LOGW(TAG, "Failed to write some data to ring buffer (written=%d, expected=%d)", bytes_written, bytes_read); |   if (bytes_free < bytes_read) { | ||||||
|  |     ESP_LOGW(TAG, | ||||||
|  |              "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " | ||||||
|  |              "Resetting the ring buffer. Wake word detection accuracy will be reduced.", | ||||||
|  |              bytes_free, bytes_read); | ||||||
|  |  | ||||||
|  |     this->ring_buffer_->reset(); | ||||||
|   } |   } | ||||||
|   return bytes_written; |  | ||||||
|  |   return this->ring_buffer_->write((void *) this->input_buffer_, bytes_read); | ||||||
| } | } | ||||||
|  |  | ||||||
| void MicroWakeWord::loop() { | void MicroWakeWord::loop() { | ||||||
| @@ -206,12 +213,6 @@ bool MicroWakeWord::initialize_models() { | |||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   this->preprocessor_stride_buffer_ = audio_samples_allocator.allocate(HISTORY_SAMPLES_TO_KEEP); |  | ||||||
|   if (this->preprocessor_stride_buffer_ == nullptr) { |  | ||||||
|     ESP_LOGE(TAG, "Could not allocate the audio preprocessor's stride buffer."); |  | ||||||
|     return false; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   this->preprocessor_model_ = tflite::GetModel(G_AUDIO_PREPROCESSOR_INT8_TFLITE); |   this->preprocessor_model_ = tflite::GetModel(G_AUDIO_PREPROCESSOR_INT8_TFLITE); | ||||||
|   if (this->preprocessor_model_->version() != TFLITE_SCHEMA_VERSION) { |   if (this->preprocessor_model_->version() != TFLITE_SCHEMA_VERSION) { | ||||||
|     ESP_LOGE(TAG, "Wake word's audio preprocessor model's schema is not supported"); |     ESP_LOGE(TAG, "Wake word's audio preprocessor model's schema is not supported"); | ||||||
| @@ -225,7 +226,7 @@ bool MicroWakeWord::initialize_models() { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   static tflite::MicroMutableOpResolver<18> preprocessor_op_resolver; |   static tflite::MicroMutableOpResolver<18> preprocessor_op_resolver; | ||||||
|   static tflite::MicroMutableOpResolver<14> streaming_op_resolver; |   static tflite::MicroMutableOpResolver<17> streaming_op_resolver; | ||||||
|  |  | ||||||
|   if (!this->register_preprocessor_ops_(preprocessor_op_resolver)) |   if (!this->register_preprocessor_ops_(preprocessor_op_resolver)) | ||||||
|     return false; |     return false; | ||||||
| @@ -329,7 +330,6 @@ bool MicroWakeWord::detect_wake_word_() { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Perform inference |   // Perform inference | ||||||
|   uint32_t streaming_size = micros(); |  | ||||||
|   float streaming_prob = this->perform_streaming_inference_(); |   float streaming_prob = this->perform_streaming_inference_(); | ||||||
|  |  | ||||||
|   // Add the most recent probability to the sliding window |   // Add the most recent probability to the sliding window | ||||||
| @@ -357,6 +357,9 @@ bool MicroWakeWord::detect_wake_word_() { | |||||||
|     for (auto &prob : this->recent_streaming_probabilities_) { |     for (auto &prob : this->recent_streaming_probabilities_) { | ||||||
|       prob = 0; |       prob = 0; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     ESP_LOGD(TAG, "Wake word sliding average probability is %.3f and most recent probability is %.3f", | ||||||
|  |              sliding_window_average, streaming_prob); | ||||||
|     return true; |     return true; | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @@ -371,23 +374,6 @@ void MicroWakeWord::set_sliding_window_average_size(size_t size) { | |||||||
| bool MicroWakeWord::slice_available_() { | bool MicroWakeWord::slice_available_() { | ||||||
|   size_t available = this->ring_buffer_->available(); |   size_t available = this->ring_buffer_->available(); | ||||||
|  |  | ||||||
|   size_t free = this->ring_buffer_->free(); |  | ||||||
|  |  | ||||||
|   if (free < NEW_SAMPLES_TO_GET * sizeof(int16_t)) { |  | ||||||
|     // If the ring buffer is within one audio slice of being full, then wake word detection will have issues. |  | ||||||
|     // If this is constantly occuring, then some possibilities why are |  | ||||||
|     //  1) there are too many other slow components configured |  | ||||||
|     //  2) the ESP32 isn't fast enough; e.g., an ESP32 is much slower than an ESP32-S3 at inferences. |  | ||||||
|     //  3) the model is too large |  | ||||||
|     //  4) the model uses operations that are not optimized |  | ||||||
|     ESP_LOGW(TAG, |  | ||||||
|              "Audio buffer is nearly full. Wake word detection may be less accurate and have slower reponse times. " |  | ||||||
| #if !defined(USE_ESP32_VARIANT_ESP32S3) |  | ||||||
|              "microWakeWord is designed for the ESP32-S3. The current platform is too slow for this model." |  | ||||||
| #endif |  | ||||||
|     ); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   return available > (NEW_SAMPLES_TO_GET * sizeof(int16_t)); |   return available > (NEW_SAMPLES_TO_GET * sizeof(int16_t)); | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -396,13 +382,12 @@ bool MicroWakeWord::stride_audio_samples_(int16_t **audio_samples) { | |||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Copy 320 bytes (160 samples over 10 ms) into preprocessor_audio_buffer_ from history in |   // Copy the last 320 bytes (160 samples over 10 ms) from the audio buffer to the start of the audio buffer | ||||||
|   // preprocessor_stride_buffer_ |   memcpy((void *) (this->preprocessor_audio_buffer_), (void *) (this->preprocessor_audio_buffer_ + NEW_SAMPLES_TO_GET), | ||||||
|   memcpy((void *) (this->preprocessor_audio_buffer_), (void *) (this->preprocessor_stride_buffer_), |  | ||||||
|          HISTORY_SAMPLES_TO_KEEP * sizeof(int16_t)); |          HISTORY_SAMPLES_TO_KEEP * sizeof(int16_t)); | ||||||
|  |  | ||||||
|   // Copy 640 bytes (320 samples over 20 ms) from the ring buffer |   // Copy 640 bytes (320 samples over 20 ms) from the ring buffer into the audio buffer offset 320 bytes (160 samples | ||||||
|   // The first 320 bytes (160 samples over 10 ms) will be from history |   // over 10 ms) | ||||||
|   size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_ + HISTORY_SAMPLES_TO_KEEP), |   size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_ + HISTORY_SAMPLES_TO_KEEP), | ||||||
|                                                NEW_SAMPLES_TO_GET * sizeof(int16_t), pdMS_TO_TICKS(200)); |                                                NEW_SAMPLES_TO_GET * sizeof(int16_t), pdMS_TO_TICKS(200)); | ||||||
|  |  | ||||||
| @@ -415,11 +400,6 @@ bool MicroWakeWord::stride_audio_samples_(int16_t **audio_samples) { | |||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Copy the last 320 bytes (160 samples over 10 ms) from the audio buffer into history stride buffer for the next |  | ||||||
|   // iteration |  | ||||||
|   memcpy((void *) (this->preprocessor_stride_buffer_), (void *) (this->preprocessor_audio_buffer_ + NEW_SAMPLES_TO_GET), |  | ||||||
|          HISTORY_SAMPLES_TO_KEEP * sizeof(int16_t)); |  | ||||||
|  |  | ||||||
|   *audio_samples = this->preprocessor_audio_buffer_; |   *audio_samples = this->preprocessor_audio_buffer_; | ||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
| @@ -480,7 +460,7 @@ bool MicroWakeWord::register_preprocessor_ops_(tflite::MicroMutableOpResolver<18 | |||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<14> &op_resolver) { | bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<17> &op_resolver) { | ||||||
|   if (op_resolver.AddCallOnce() != kTfLiteOk) |   if (op_resolver.AddCallOnce() != kTfLiteOk) | ||||||
|     return false; |     return false; | ||||||
|   if (op_resolver.AddVarHandle() != kTfLiteOk) |   if (op_resolver.AddVarHandle() != kTfLiteOk) | ||||||
| @@ -509,6 +489,12 @@ bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<14> & | |||||||
|     return false; |     return false; | ||||||
|   if (op_resolver.AddQuantize() != kTfLiteOk) |   if (op_resolver.AddQuantize() != kTfLiteOk) | ||||||
|     return false; |     return false; | ||||||
|  |   if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddAveragePool2D() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |   if (op_resolver.AddMaxPool2D() != kTfLiteOk) | ||||||
|  |     return false; | ||||||
|  |  | ||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -128,7 +128,6 @@ class MicroWakeWord : public Component { | |||||||
|  |  | ||||||
|   // Stores audio fed into feature generator preprocessor |   // Stores audio fed into feature generator preprocessor | ||||||
|   int16_t *preprocessor_audio_buffer_; |   int16_t *preprocessor_audio_buffer_; | ||||||
|   int16_t *preprocessor_stride_buffer_; |  | ||||||
|  |  | ||||||
|   bool detected_{false}; |   bool detected_{false}; | ||||||
|  |  | ||||||
| @@ -181,7 +180,7 @@ class MicroWakeWord : public Component { | |||||||
|   bool register_preprocessor_ops_(tflite::MicroMutableOpResolver<18> &op_resolver); |   bool register_preprocessor_ops_(tflite::MicroMutableOpResolver<18> &op_resolver); | ||||||
|  |  | ||||||
|   /// @brief Returns true if successfully registered the streaming model's TensorFlow operations |   /// @brief Returns true if successfully registered the streaming model's TensorFlow operations | ||||||
|   bool register_streaming_ops_(tflite::MicroMutableOpResolver<14> &op_resolver); |   bool register_streaming_ops_(tflite::MicroMutableOpResolver<17> &op_resolver); | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> { | template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user