generated from nhcarrigan/template
425 lines
15 KiB
Rust
425 lines
15 KiB
Rust
//! Enhanced Voice Activity Detection for speaker separation in group meetings.
//!
//! This module uses the Silero VAD model combined with voice characteristics
//! to provide improved speaker separation for meetings with multiple speakers.
|
|
|
|
use rustfft::{FftPlanner, num_complex::Complex};
|
|
use std::collections::HashMap;
|
|
use thiserror::Error;
|
|
use tracing::{debug, info, warn};
|
|
use voice_activity_detector::VoiceActivityDetector;
|
|
|
|
use super::transcriber::TranscriptSegment;
|
|
|
|
/// Sampling rate, in Hz, that the VAD model expects its input audio to use (16 kHz).
pub const VAD_SAMPLE_RATE: u32 = 16_000;
|
|
|
|
#[derive(Error, Debug)]
|
|
pub enum VadError {
|
|
#[error("Failed to create VAD: {0}")]
|
|
InitError(String),
|
|
#[error("VAD processing failed: {0}")]
|
|
ProcessingError(String),
|
|
}
|
|
|
|
/// A small set of acoustic measurements used to tell speakers apart.
#[derive(Debug, Clone)]
pub struct VoiceFingerprint {
    /// Average pitch (fundamental frequency) in Hz.
    pub avg_pitch: f32,
    /// Spread of the pitch across the analysed audio.
    pub pitch_variance: f32,
    /// Average energy level of the audio.
    pub avg_energy: f32,
    /// Spectral centroid in Hz (an indicator of how "bright" the voice sounds).
    pub spectral_centroid: f32,
}

impl VoiceFingerprint {
    /// Scores how alike two fingerprints are, from `0.0` (unrelated) to `1.0` (identical).
    ///
    /// Each feature's absolute difference is divided by a fixed scale, the scaled
    /// differences are combined with fixed weights (pitch counts twice as much as
    /// any other feature), and the combined distance is capped at `1.0` before
    /// being turned into a similarity.
    pub fn similarity(&self, other: &VoiceFingerprint) -> f32 {
        // Scales that bring each feature's difference into a comparable range.
        const PITCH_SCALE: f32 = 200.0; // typical pitch range
        const VARIANCE_SCALE: f32 = 50.0;
        const ENERGY_SCALE: f32 = 0.5;
        const CENTROID_SCALE: f32 = 1000.0;

        // Relative importance of each feature; the weights add up to 1.0.
        const PITCH_WEIGHT: f32 = 0.4;
        const VARIANCE_WEIGHT: f32 = 0.2;
        const ENERGY_WEIGHT: f32 = 0.2;
        const CENTROID_WEIGHT: f32 = 0.2;

        let scaled_gap = |mine: f32, theirs: f32, scale: f32| (mine - theirs).abs() / scale;

        let distance = scaled_gap(self.avg_pitch, other.avg_pitch, PITCH_SCALE) * PITCH_WEIGHT
            + scaled_gap(self.pitch_variance, other.pitch_variance, VARIANCE_SCALE)
                * VARIANCE_WEIGHT
            + scaled_gap(self.avg_energy, other.avg_energy, ENERGY_SCALE) * ENERGY_WEIGHT
            + scaled_gap(self.spectral_centroid, other.spectral_centroid, CENTROID_SCALE)
                * CENTROID_WEIGHT;

        // Cap the distance first so the result can never drop below zero.
        let capped = distance.min(1.0);
        (1.0 - capped).max(0.0)
    }
}
|
|
|
|
/// A detected speech segment with timing information and voice characteristics.
|
|
#[derive(Debug, Clone)]
|
|
pub struct SpeechSegment {
|
|
pub start_sample: usize,
|
|
pub end_sample: usize,
|
|
pub speaker_id: u32,
|
|
pub fingerprint: Option<VoiceFingerprint>,
|
|
}
|
|
|
|
impl SpeechSegment {
|
|
/// Get the start time in seconds.
|
|
pub fn start_seconds(&self) -> f64 {
|
|
self.start_sample as f64 / VAD_SAMPLE_RATE as f64
|
|
}
|
|
|
|
/// Get the end time in seconds.
|
|
pub fn end_seconds(&self) -> f64 {
|
|
self.end_sample as f64 / VAD_SAMPLE_RATE as f64
|
|
}
|
|
|
|
/// Get the duration in seconds.
|
|
pub fn duration_seconds(&self) -> f64 {
|
|
self.end_seconds() - self.start_seconds()
|
|
}
|
|
}
|
|
|
|
/// Voice Activity Detector for speaker separation with enhanced tracking.
|
|
pub struct SpeakerSeparator {
|
|
vad: VoiceActivityDetector,
|
|
/// Minimum silence duration (in seconds) to consider a speaker change
|
|
_min_silence_for_speaker_change: f64,
|
|
/// Minimum speech duration (in seconds) to consider a valid segment
|
|
min_speech_duration: f64,
|
|
/// Known speakers and their voice fingerprints
|
|
known_speakers: HashMap<u32, VoiceFingerprint>,
|
|
/// Next speaker ID to assign
|
|
next_speaker_id: u32,
|
|
/// Similarity threshold for matching speakers (0.0 to 1.0)
|
|
similarity_threshold: f32,
|
|
/// FFT planner for frequency analysis
|
|
fft_planner: FftPlanner<f32>,
|
|
}
|
|
|
|
impl SpeakerSeparator {
|
|
/// Create a new speaker separator with default settings.
|
|
pub fn new() -> Result<Self, VadError> {
|
|
Self::with_settings(0.8, 0.3) // Reduced silence threshold for group meetings
|
|
}
|
|
|
|
/// Create a new speaker separator with custom settings.
|
|
///
|
|
/// # Arguments
|
|
/// * `min_silence_for_speaker_change` - Minimum silence duration (seconds) to consider a speaker change
|
|
/// * `min_speech_duration` - Minimum speech duration (seconds) to consider a valid segment
|
|
pub fn with_settings(
|
|
min_silence_for_speaker_change: f64,
|
|
min_speech_duration: f64,
|
|
) -> Result<Self, VadError> {
|
|
let vad = VoiceActivityDetector::builder()
|
|
.sample_rate(VAD_SAMPLE_RATE)
|
|
.chunk_size(512usize) // 512 samples for 16kHz
|
|
.build()
|
|
.map_err(|e: voice_activity_detector::Error| VadError::InitError(e.to_string()))?;
|
|
|
|
info!("Enhanced VAD initialized with sample_rate={}, min_silence={}s, min_speech={}s",
|
|
VAD_SAMPLE_RATE, min_silence_for_speaker_change, min_speech_duration);
|
|
|
|
Ok(Self {
|
|
vad,
|
|
_min_silence_for_speaker_change: min_silence_for_speaker_change,
|
|
min_speech_duration,
|
|
known_speakers: HashMap::new(),
|
|
next_speaker_id: 0,
|
|
similarity_threshold: 0.7,
|
|
fft_planner: FftPlanner::new(),
|
|
})
|
|
}
|
|
|
|
/// Extract voice fingerprint from an audio segment.
|
|
fn extract_fingerprint(&mut self, audio: &[f32]) -> Option<VoiceFingerprint> {
|
|
if audio.len() < 2048 {
|
|
return None;
|
|
}
|
|
|
|
// Calculate average energy
|
|
let avg_energy = audio.iter().map(|s| s.abs()).sum::<f32>() / audio.len() as f32;
|
|
|
|
// Estimate pitch using autocorrelation (simplified)
|
|
let pitch = self.estimate_pitch(audio);
|
|
|
|
// Calculate spectral centroid
|
|
let spectral_centroid = self.calculate_spectral_centroid(audio);
|
|
|
|
// Calculate pitch variance over windows
|
|
let pitch_variance = self.calculate_pitch_variance(audio);
|
|
|
|
Some(VoiceFingerprint {
|
|
avg_pitch: pitch,
|
|
pitch_variance,
|
|
avg_energy,
|
|
spectral_centroid,
|
|
})
|
|
}
|
|
|
|
/// Estimate pitch using zero-crossing rate (simplified approach).
|
|
fn estimate_pitch(&self, audio: &[f32]) -> f32 {
|
|
let mut zero_crossings = 0;
|
|
let mut prev_sign = audio[0] >= 0.0;
|
|
|
|
for sample in audio.iter().skip(1) {
|
|
let current_sign = *sample >= 0.0;
|
|
if current_sign != prev_sign {
|
|
zero_crossings += 1;
|
|
}
|
|
prev_sign = current_sign;
|
|
}
|
|
|
|
// Convert zero-crossing rate to approximate frequency
|
|
let zcr = zero_crossings as f32 / audio.len() as f32;
|
|
zcr * VAD_SAMPLE_RATE as f32 / 2.0
|
|
}
|
|
|
|
/// Calculate spectral centroid (center of mass of spectrum).
|
|
fn calculate_spectral_centroid(&mut self, audio: &[f32]) -> f32 {
|
|
let fft_size = 2048.min(audio.len());
|
|
let mut input: Vec<Complex<f32>> = audio[..fft_size]
|
|
.iter()
|
|
.map(|&s| Complex::new(s, 0.0))
|
|
.collect();
|
|
|
|
let fft = self.fft_planner.plan_fft_forward(fft_size);
|
|
fft.process(&mut input);
|
|
|
|
let mut weighted_sum = 0.0;
|
|
let mut magnitude_sum = 0.0;
|
|
|
|
for (i, complex) in input.iter().enumerate().take(fft_size / 2) {
|
|
let magnitude = complex.norm();
|
|
let frequency = i as f32 * VAD_SAMPLE_RATE as f32 / fft_size as f32;
|
|
weighted_sum += frequency * magnitude;
|
|
magnitude_sum += magnitude;
|
|
}
|
|
|
|
if magnitude_sum > 0.0 {
|
|
weighted_sum / magnitude_sum
|
|
} else {
|
|
0.0
|
|
}
|
|
}
|
|
|
|
/// Calculate pitch variance across multiple windows.
|
|
fn calculate_pitch_variance(&self, audio: &[f32]) -> f32 {
|
|
let window_size = 1024;
|
|
let num_windows = audio.len() / window_size;
|
|
|
|
if num_windows < 2 {
|
|
return 0.0;
|
|
}
|
|
|
|
let mut pitches = Vec::new();
|
|
for i in 0..num_windows {
|
|
let start = i * window_size;
|
|
let end = (i + 1) * window_size;
|
|
let pitch = self.estimate_pitch(&audio[start..end]);
|
|
pitches.push(pitch);
|
|
}
|
|
|
|
// Calculate variance
|
|
let mean = pitches.iter().sum::<f32>() / pitches.len() as f32;
|
|
let variance = pitches.iter()
|
|
.map(|p| (p - mean).powi(2))
|
|
.sum::<f32>() / pitches.len() as f32;
|
|
|
|
variance.sqrt()
|
|
}
|
|
|
|
/// Find or assign speaker ID based on voice fingerprint.
|
|
fn find_or_assign_speaker(&mut self, fingerprint: &VoiceFingerprint) -> u32 {
|
|
// Try to match with known speakers
|
|
let mut best_match = None;
|
|
let mut best_similarity = 0.0;
|
|
|
|
for (speaker_id, known_fp) in &self.known_speakers {
|
|
let similarity = fingerprint.similarity(known_fp);
|
|
if similarity > best_similarity && similarity >= self.similarity_threshold {
|
|
best_similarity = similarity;
|
|
best_match = Some(*speaker_id);
|
|
}
|
|
}
|
|
|
|
if let Some(speaker_id) = best_match {
|
|
debug!("Matched to existing speaker {} with similarity {:.2}", speaker_id, best_similarity);
|
|
speaker_id
|
|
} else {
|
|
// New speaker
|
|
let speaker_id = self.next_speaker_id;
|
|
self.next_speaker_id += 1;
|
|
self.known_speakers.insert(speaker_id, fingerprint.clone());
|
|
debug!("New speaker {} detected", speaker_id);
|
|
speaker_id
|
|
}
|
|
}
|
|
|
|
/// Detect speech segments and assign speaker IDs based on voice characteristics.
|
|
///
|
|
/// This enhanced version uses voice fingerprinting to track speakers across
|
|
/// overlapping speech and short interruptions, making it suitable for group meetings.
|
|
pub fn detect_speakers(&mut self, audio: &[f32]) -> Result<Vec<SpeechSegment>, VadError> {
|
|
let chunk_size = 512;
|
|
let mut raw_segments = Vec::new();
|
|
let mut current_segment_start: Option<usize> = None;
|
|
let min_speech_samples = (self.min_speech_duration * VAD_SAMPLE_RATE as f64) as usize;
|
|
|
|
debug!("Processing {} samples for enhanced VAD", audio.len());
|
|
|
|
// First pass: detect raw speech segments
|
|
for (chunk_idx, chunk) in audio.chunks(chunk_size).enumerate() {
|
|
if chunk.len() < chunk_size {
|
|
continue;
|
|
}
|
|
|
|
let sample_offset = chunk_idx * chunk_size;
|
|
let probability = self.vad.predict(chunk.iter().copied());
|
|
let is_speech = probability > 0.5;
|
|
|
|
if is_speech {
|
|
if current_segment_start.is_none() {
|
|
current_segment_start = Some(sample_offset);
|
|
}
|
|
} else if let Some(start) = current_segment_start {
|
|
let segment_duration = sample_offset - start;
|
|
if segment_duration >= min_speech_samples {
|
|
raw_segments.push((start, sample_offset));
|
|
}
|
|
current_segment_start = None;
|
|
}
|
|
}
|
|
|
|
// Handle speech continuing to the end
|
|
if let Some(start) = current_segment_start {
|
|
let segment_duration = audio.len() - start;
|
|
if segment_duration >= min_speech_samples {
|
|
raw_segments.push((start, audio.len()));
|
|
}
|
|
}
|
|
|
|
// Second pass: analyze voice characteristics and assign speakers
|
|
let mut segments = Vec::new();
|
|
for (start, end) in raw_segments {
|
|
let segment_audio = &audio[start..end];
|
|
|
|
// Extract voice fingerprint
|
|
let fingerprint = self.extract_fingerprint(segment_audio);
|
|
|
|
let speaker_id = if let Some(ref fp) = fingerprint {
|
|
self.find_or_assign_speaker(fp)
|
|
} else {
|
|
// Fallback: assign based on position
|
|
warn!("Could not extract fingerprint for segment at {}s", start as f64 / VAD_SAMPLE_RATE as f64);
|
|
self.next_speaker_id - 1 // Use last assigned speaker
|
|
};
|
|
|
|
segments.push(SpeechSegment {
|
|
start_sample: start,
|
|
end_sample: end,
|
|
speaker_id,
|
|
fingerprint,
|
|
});
|
|
}
|
|
|
|
info!("Detected {} speech segments with {} speakers",
|
|
segments.len(),
|
|
segments.iter().map(|s| s.speaker_id).max().unwrap_or(0) + 1);
|
|
|
|
Ok(segments)
|
|
}
|
|
|
|
/// Apply speaker labels to transcript segments based on VAD results.
|
|
pub fn apply_speaker_labels(
|
|
&mut self,
|
|
audio: &[f32],
|
|
mut transcript: Vec<TranscriptSegment>,
|
|
) -> Result<Vec<TranscriptSegment>, VadError> {
|
|
let speech_segments = self.detect_speakers(audio)?;
|
|
|
|
for segment in &mut transcript {
|
|
// Find the speech segment that overlaps with this transcript segment
|
|
let segment_mid = (segment.start + segment.end) / 2.0;
|
|
|
|
for speech in &speech_segments {
|
|
let speech_start = speech.start_seconds();
|
|
let speech_end = speech.end_seconds();
|
|
|
|
if segment_mid >= speech_start && segment_mid <= speech_end {
|
|
segment.speaker = format!("Speaker {}", speech.speaker_id + 1);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(transcript)
|
|
}
|
|
|
|
/// Apply speaker labels to a chunk of transcript segments (for real-time processing).
|
|
///
|
|
/// This method is designed to work with the real-time transcription system,
|
|
/// maintaining speaker consistency across chunks.
|
|
pub fn apply_speaker_labels_to_chunk(
|
|
&mut self,
|
|
audio: &[f32],
|
|
mut transcript: Vec<TranscriptSegment>,
|
|
_chunk_start_time: f64,
|
|
) -> Result<Vec<TranscriptSegment>, VadError> {
|
|
// Detect speakers in this chunk
|
|
let speech_segments = self.detect_speakers(audio)?;
|
|
|
|
for segment in &mut transcript {
|
|
// Adjust times relative to chunk start
|
|
let segment_mid = (segment.start + segment.end) / 2.0;
|
|
|
|
// Find matching speech segment
|
|
for speech in &speech_segments {
|
|
let speech_start = speech.start_seconds();
|
|
let speech_end = speech.end_seconds();
|
|
|
|
if segment_mid >= speech_start && segment_mid <= speech_end {
|
|
// Use consistent speaker numbering across chunks
|
|
segment.speaker = format!("Speaker {}", speech.speaker_id + 1);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Clean up old speakers if we have too many (helps with memory in long meetings)
|
|
if self.known_speakers.len() > 10 {
|
|
warn!("Many speakers detected ({}), consider adjusting similarity threshold", self.known_speakers.len());
|
|
}
|
|
|
|
Ok(transcript)
|
|
}
|
|
|
|
/// Reset speaker tracking (useful between different meetings).
|
|
pub fn reset_speakers(&mut self) {
|
|
self.known_speakers.clear();
|
|
self.next_speaker_id = 0;
|
|
info!("Speaker tracking reset");
|
|
}
|
|
}
|
|
|
|
impl Default for SpeakerSeparator {
|
|
fn default() -> Self {
|
|
Self::new().expect("Failed to create speaker separator")
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample indices convert to seconds at 16 kHz: 16000 samples is one second.
    #[test]
    fn test_speech_segment_timing() {
        let segment = SpeechSegment {
            start_sample: 16000,
            end_sample: 32000,
            speaker_id: 0,
            // Required field; the timing helpers do not read it.
            fingerprint: None,
        };

        assert!((segment.start_seconds() - 1.0).abs() < 0.001);
        assert!((segment.end_seconds() - 2.0).abs() < 0.001);
        assert!((segment.duration_seconds() - 1.0).abs() < 0.001);
    }
}
|