//! Enhanced Voice Activity Detection for speaker separation in group meetings.
//!
//! This module uses the Silero VAD model combined with voice characteristics
//! to provide improved speaker separation for meetings with multiple speakers.

use rustfft::{num_complex::Complex, FftPlanner};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{debug, info, warn};
use voice_activity_detector::VoiceActivityDetector;

use super::transcriber::TranscriptSegment;

/// Sample rate expected by the VAD model (16 kHz).
pub const VAD_SAMPLE_RATE: u32 = 16000;

#[derive(Error, Debug)]
pub enum VadError {
    #[error("Failed to create VAD: {0}")]
    InitError(String),
    #[error("VAD processing failed: {0}")]
    ProcessingError(String),
}

/// Voice characteristics for speaker fingerprinting.
#[derive(Debug, Clone)]
pub struct VoiceFingerprint {
    /// Average pitch (fundamental frequency) in Hz
    pub avg_pitch: f32,
    /// Pitch variability (standard deviation in Hz)
    pub pitch_variance: f32,
    /// Average energy level
    pub avg_energy: f32,
    /// Spectral centroid in Hz (brightness indicator)
    pub spectral_centroid: f32,
}

impl VoiceFingerprint {
    /// Calculate similarity between two voice fingerprints (0.0 to 1.0).
    pub fn similarity(&self, other: &VoiceFingerprint) -> f32 {
        // Normalize each feature difference to roughly [0, 1]
        let pitch_diff = (self.avg_pitch - other.avg_pitch).abs() / 200.0; // typical pitch range
        let variance_diff = (self.pitch_variance - other.pitch_variance).abs() / 50.0;
        let energy_diff = (self.avg_energy - other.avg_energy).abs() / 0.5;
        let centroid_diff = (self.spectral_centroid - other.spectral_centroid).abs() / 1000.0;

        // Weighted distance: pitch dominates, the remaining features share the rest
        let diff =
            pitch_diff * 0.4 + variance_diff * 0.2 + energy_diff * 0.2 + centroid_diff * 0.2;
        (1.0 - diff.min(1.0)).max(0.0)
    }
}
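// A quick sanity check of the similarity metric above. The fingerprint values
// are made up for illustration: two voices roughly 20 Hz apart in average
// pitch with otherwise close features should score well above the default
// 0.7 matching threshold.
#[cfg(test)]
mod fingerprint_tests {
    use super::VoiceFingerprint;

    #[test]
    fn similar_voices_score_high() {
        let a = VoiceFingerprint {
            avg_pitch: 120.0,
            pitch_variance: 10.0,
            avg_energy: 0.10,
            spectral_centroid: 1500.0,
        };
        let b = VoiceFingerprint {
            avg_pitch: 140.0,
            pitch_variance: 12.0,
            avg_energy: 0.12,
            spectral_centroid: 1600.0,
        };

        // diff = 0.1*0.4 + 0.04*0.2 + 0.04*0.2 + 0.1*0.2 = 0.076 -> similarity ~0.924
        assert!(a.similarity(&b) > 0.9);
        // A fingerprint is always maximally similar to itself.
        assert!((a.similarity(&a) - 1.0).abs() < f32::EPSILON);
    }
}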
/// A detected speech segment with timing information and voice characteristics.
#[derive(Debug, Clone)]
pub struct SpeechSegment {
    pub start_sample: usize,
    pub end_sample: usize,
    pub speaker_id: u32,
    pub fingerprint: Option<VoiceFingerprint>,
}

impl SpeechSegment {
    /// Get the start time in seconds.
    pub fn start_seconds(&self) -> f64 {
        self.start_sample as f64 / VAD_SAMPLE_RATE as f64
    }

    /// Get the end time in seconds.
    pub fn end_seconds(&self) -> f64 {
        self.end_sample as f64 / VAD_SAMPLE_RATE as f64
    }

    /// Get the duration in seconds.
    pub fn duration_seconds(&self) -> f64 {
        self.end_seconds() - self.start_seconds()
    }
}

/// Voice Activity Detector for speaker separation with enhanced tracking.
pub struct SpeakerSeparator {
    vad: VoiceActivityDetector,
    /// Minimum silence duration (in seconds) to consider a speaker change
    _min_silence_for_speaker_change: f64,
    /// Minimum speech duration (in seconds) to consider a valid segment
    min_speech_duration: f64,
    /// Known speakers and their voice fingerprints
    known_speakers: HashMap<u32, VoiceFingerprint>,
    /// Next speaker ID to assign
    next_speaker_id: u32,
    /// Similarity threshold for matching speakers (0.0 to 1.0)
    similarity_threshold: f32,
    /// FFT planner for frequency analysis
    fft_planner: FftPlanner<f32>,
}

impl SpeakerSeparator {
    /// Create a new speaker separator with default settings.
    pub fn new() -> Result<Self, VadError> {
        Self::with_settings(0.8, 0.3) // Reduced silence threshold for group meetings
    }

    /// Create a new speaker separator with custom settings.
    ///
    /// # Arguments
    /// * `min_silence_for_speaker_change` - Minimum silence duration (seconds) to consider a speaker change
    /// * `min_speech_duration` - Minimum speech duration (seconds) to consider a valid segment
    pub fn with_settings(
        min_silence_for_speaker_change: f64,
        min_speech_duration: f64,
    ) -> Result<Self, VadError> {
        let vad = VoiceActivityDetector::builder()
            .sample_rate(VAD_SAMPLE_RATE)
            .chunk_size(512usize) // 512 samples = 32 ms per chunk at 16 kHz
            .build()
            .map_err(|e: voice_activity_detector::Error| VadError::InitError(e.to_string()))?;

        info!(
            "Enhanced VAD initialized with sample_rate={}, min_silence={}s, min_speech={}s",
            VAD_SAMPLE_RATE, min_silence_for_speaker_change, min_speech_duration
        );

        Ok(Self {
            vad,
            _min_silence_for_speaker_change: min_silence_for_speaker_change,
            min_speech_duration,
            known_speakers: HashMap::new(),
            next_speaker_id: 0,
            similarity_threshold: 0.7,
            fft_planner: FftPlanner::new(),
        })
    }

    /// Extract a voice fingerprint from an audio segment.
    fn extract_fingerprint(&mut self, audio: &[f32]) -> Option<VoiceFingerprint> {
        // Need at least one full FFT window's worth of samples
        if audio.len() < 2048 {
            return None;
        }

        // Average energy (mean absolute amplitude)
        let avg_energy = audio.iter().map(|s| s.abs()).sum::<f32>() / audio.len() as f32;

        // Estimate pitch from the zero-crossing rate (simplified)
        let pitch = self.estimate_pitch(audio);

        // Calculate spectral centroid
        let spectral_centroid = self.calculate_spectral_centroid(audio);

        // Calculate pitch variability over windows
        let pitch_variance = self.calculate_pitch_variance(audio);

        Some(VoiceFingerprint {
            avg_pitch: pitch,
            pitch_variance,
            avg_energy,
            spectral_centroid,
        })
    }

    /// Estimate pitch using the zero-crossing rate (simplified approach).
    fn estimate_pitch(&self, audio: &[f32]) -> f32 {
        let mut zero_crossings = 0;
        let mut prev_sign = audio[0] >= 0.0;
        for sample in audio.iter().skip(1) {
            let current_sign = *sample >= 0.0;
            if current_sign != prev_sign {
                zero_crossings += 1;
            }
            prev_sign = current_sign;
        }

        // A periodic signal crosses zero twice per cycle, so half the
        // crossing rate approximates the fundamental frequency.
        let zcr = zero_crossings as f32 / audio.len() as f32;
        zcr * VAD_SAMPLE_RATE as f32 / 2.0
    }

    /// Calculate the spectral centroid (center of mass of the spectrum).
    fn calculate_spectral_centroid(&mut self, audio: &[f32]) -> f32 {
        let fft_size = 2048.min(audio.len());
        let mut input: Vec<Complex<f32>> = audio[..fft_size]
            .iter()
            .map(|&s| Complex::new(s, 0.0))
            .collect();

        let fft = self.fft_planner.plan_fft_forward(fft_size);
        fft.process(&mut input);

        // Magnitude-weighted mean frequency over the first half of the spectrum
        let mut weighted_sum = 0.0;
        let mut magnitude_sum = 0.0;
        for (i, complex) in input.iter().enumerate().take(fft_size / 2) {
            let magnitude = complex.norm();
            let frequency = i as f32 * VAD_SAMPLE_RATE as f32 / fft_size as f32;
            weighted_sum += frequency * magnitude;
            magnitude_sum += magnitude;
        }

        if magnitude_sum > 0.0 {
            weighted_sum / magnitude_sum
        } else {
            0.0
        }
    }

    /// Calculate pitch variability across multiple windows.
    fn calculate_pitch_variance(&self, audio: &[f32]) -> f32 {
        let window_size = 1024;
        let num_windows = audio.len() / window_size;
        if num_windows < 2 {
            return 0.0;
        }

        let mut pitches = Vec::new();
        for i in 0..num_windows {
            let start = i * window_size;
            let end = (i + 1) * window_size;
            pitches.push(self.estimate_pitch(&audio[start..end]));
        }

        // Return the standard deviation (square root of the variance) so the
        // result stays in Hz, consistent with the other pitch features
        let mean = pitches.iter().sum::<f32>() / pitches.len() as f32;
        let variance = pitches
            .iter()
            .map(|p| (p - mean).powi(2))
            .sum::<f32>()
            / pitches.len() as f32;
        variance.sqrt()
    }

    /// Find or assign a speaker ID based on a voice fingerprint.
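    ///
    /// Matching is greedy: the most similar known speaker wins, but only if it
    /// clears `similarity_threshold`; otherwise a new ID is minted. A sketch of
    /// the intended behavior (`low_pitch_fp` and `high_pitch_fp` are
    /// hypothetical fingerprints, roughly 100 Hz apart in average pitch):
    ///
    /// ```ignore
    /// let a = separator.find_or_assign_speaker(&low_pitch_fp);  // new speaker -> 0
    /// let b = separator.find_or_assign_speaker(&high_pitch_fp); // new speaker -> 1
    /// let c = separator.find_or_assign_speaker(&low_pitch_fp);  // matched     -> 0
    /// assert_eq!(a, c);
    /// ```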
    fn find_or_assign_speaker(&mut self, fingerprint: &VoiceFingerprint) -> u32 {
        // Try to match against known speakers
        let mut best_match = None;
        let mut best_similarity = 0.0;
        for (speaker_id, known_fp) in &self.known_speakers {
            let similarity = fingerprint.similarity(known_fp);
            if similarity > best_similarity && similarity >= self.similarity_threshold {
                best_similarity = similarity;
                best_match = Some(*speaker_id);
            }
        }

        if let Some(speaker_id) = best_match {
            debug!(
                "Matched to existing speaker {} with similarity {:.2}",
                speaker_id, best_similarity
            );
            speaker_id
        } else {
            // New speaker
            let speaker_id = self.next_speaker_id;
            self.next_speaker_id += 1;
            self.known_speakers.insert(speaker_id, fingerprint.clone());
            debug!("New speaker {} detected", speaker_id);
            speaker_id
        }
    }

    /// Detect speech segments and assign speaker IDs based on voice characteristics.
    ///
    /// This enhanced version uses voice fingerprinting to track speakers across
    /// overlapping speech and short interruptions, making it suitable for group meetings.
    pub fn detect_speakers(&mut self, audio: &[f32]) -> Result<Vec<SpeechSegment>, VadError> {
        let chunk_size = 512;
        let mut raw_segments = Vec::new();
        let mut current_segment_start: Option<usize> = None;
        let min_speech_samples = (self.min_speech_duration * VAD_SAMPLE_RATE as f64) as usize;

        debug!("Processing {} samples for enhanced VAD", audio.len());

        // First pass: detect raw speech segments
        for (chunk_idx, chunk) in audio.chunks(chunk_size).enumerate() {
            // Skip the trailing partial chunk; the VAD expects full chunks
            if chunk.len() < chunk_size {
                continue;
            }

            let sample_offset = chunk_idx * chunk_size;
            let probability = self.vad.predict(chunk.iter().copied());
            let is_speech = probability > 0.5;

            if is_speech {
                if current_segment_start.is_none() {
                    current_segment_start = Some(sample_offset);
                }
            } else if let Some(start) = current_segment_start {
                let segment_duration = sample_offset - start;
                if segment_duration >= min_speech_samples {
                    raw_segments.push((start, sample_offset));
                }
                current_segment_start = None;
            }
        }

        // Handle speech continuing to the end of the buffer
        if let Some(start) = current_segment_start {
            let segment_duration = audio.len() - start;
            if segment_duration >= min_speech_samples {
                raw_segments.push((start, audio.len()));
            }
        }

        // Second pass: analyze voice characteristics and assign speakers
        let mut segments = Vec::new();
        for (start, end) in raw_segments {
            let segment_audio = &audio[start..end];

            // Extract voice fingerprint
            let fingerprint = self.extract_fingerprint(segment_audio);
            let speaker_id = if let Some(ref fp) = fingerprint {
                self.find_or_assign_speaker(fp)
            } else {
                // Fallback: reuse the most recently assigned speaker
                // (saturating_sub keeps this safe when no speaker exists yet)
                warn!(
                    "Could not extract fingerprint for segment at {}s",
                    start as f64 / VAD_SAMPLE_RATE as f64
                );
                self.next_speaker_id.saturating_sub(1)
            };

            segments.push(SpeechSegment {
                start_sample: start,
                end_sample: end,
                speaker_id,
                fingerprint,
            });
        }

        info!(
            "Detected {} speech segments with {} speakers",
            segments.len(),
            segments.iter().map(|s| s.speaker_id).max().map_or(0, |id| id + 1)
        );

        Ok(segments)
    }

    /// Apply speaker labels to transcript segments based on VAD results.
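    ///
    /// A minimal sketch of the intended call pattern (assumes `audio` is mono
    /// 16 kHz `f32` samples and `transcript` came from the transcriber):
    ///
    /// ```ignore
    /// let mut separator = SpeakerSeparator::new()?;
    /// let labeled = separator.apply_speaker_labels(&audio, transcript)?;
    /// for seg in &labeled {
    ///     println!("{}: {:.1}s-{:.1}s", seg.speaker, seg.start, seg.end);
    /// }
    /// ```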
    pub fn apply_speaker_labels(
        &mut self,
        audio: &[f32],
        mut transcript: Vec<TranscriptSegment>,
    ) -> Result<Vec<TranscriptSegment>, VadError> {
        let speech_segments = self.detect_speakers(audio)?;

        for segment in &mut transcript {
            // Find the speech segment containing this transcript segment's midpoint
            let segment_mid = (segment.start + segment.end) / 2.0;
            for speech in &speech_segments {
                let speech_start = speech.start_seconds();
                let speech_end = speech.end_seconds();
                if segment_mid >= speech_start && segment_mid <= speech_end {
                    segment.speaker = format!("Speaker {}", speech.speaker_id + 1);
                    break;
                }
            }
        }

        Ok(transcript)
    }

    /// Apply speaker labels to a chunk of transcript segments (for real-time processing).
    ///
    /// This method is designed to work with the real-time transcription system,
    /// maintaining speaker consistency across chunks.
    pub fn apply_speaker_labels_to_chunk(
        &mut self,
        audio: &[f32],
        mut transcript: Vec<TranscriptSegment>,
        _chunk_start_time: f64,
    ) -> Result<Vec<TranscriptSegment>, VadError> {
        // Detect speakers in this chunk
        let speech_segments = self.detect_speakers(audio)?;

        for segment in &mut transcript {
            // Segment times are already relative to the chunk start
            let segment_mid = (segment.start + segment.end) / 2.0;

            // Find the matching speech segment
            for speech in &speech_segments {
                let speech_start = speech.start_seconds();
                let speech_end = speech.end_seconds();
                if segment_mid >= speech_start && segment_mid <= speech_end {
                    // Speaker IDs persist in `known_speakers`, keeping the
                    // numbering consistent across chunks
                    segment.speaker = format!("Speaker {}", speech.speaker_id + 1);
                    break;
                }
            }
        }

        // Warn when too many speakers accumulate in long meetings; this
        // usually means the similarity threshold is splitting one voice
        // into several IDs
        if self.known_speakers.len() > 10 {
            warn!(
                "Many speakers detected ({}), consider adjusting similarity threshold",
                self.known_speakers.len()
            );
        }

        Ok(transcript)
    }

    /// Reset speaker tracking (useful between different meetings).
    pub fn reset_speakers(&mut self) {
        self.known_speakers.clear();
        self.next_speaker_id = 0;
        info!("Speaker tracking reset");
    }
}

impl Default for SpeakerSeparator {
    fn default() -> Self {
        Self::new().expect("Failed to create speaker separator")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_speech_segment_timing() {
        let segment = SpeechSegment {
            start_sample: 16000,
            end_sample: 32000,
            speaker_id: 0,
            fingerprint: None,
        };

        assert!((segment.start_seconds() - 1.0).abs() < 0.001);
        assert!((segment.end_seconds() - 2.0).abs() < 0.001);
        assert!((segment.duration_seconds() - 1.0).abs() < 0.001);
    }
}
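// An end-to-end sketch of how this module is meant to be driven. This is an
// assumption-laden smoke test rather than part of the original suite: it
// requires the embedded Silero model to load and assumes pure silence scores
// below the 0.5 speech threshold, so it is #[ignore]d by default (run
// manually with `cargo test -- --ignored`).
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    #[ignore]
    fn silence_yields_no_segments() {
        let mut separator = SpeakerSeparator::new().expect("VAD init");

        // One second of 16 kHz silence: no speech segments should be detected.
        let silence = vec![0.0f32; VAD_SAMPLE_RATE as usize];
        let segments = separator.detect_speakers(&silence).expect("VAD processing");
        assert!(segments.is_empty());
    }
}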