feat: diarisation maybe

2026-01-29 15:34:04 -08:00
parent 9efda8ded6
commit 4ed7462a17
5 changed files with 432 additions and 107 deletions
+256 -47
@@ -1,10 +1,12 @@
-//! Voice Activity Detection for basic speaker separation.
//! Enhanced Voice Activity Detection for speaker separation in group meetings.
//!
-//! This module uses the Silero VAD model to detect speech segments
-//! and provides basic speaker separation based on silence gaps.
//! This module uses the Silero VAD model combined with voice characteristics
//! to provide improved speaker separation for meetings with multiple speakers.
use rustfft::{FftPlanner, num_complex::Complex};
use std::collections::HashMap;
use thiserror::Error;
-use tracing::{debug, info};
use tracing::{debug, info, warn};
use voice_activity_detector::VoiceActivityDetector;
use super::transcriber::TranscriptSegment;
@@ -20,12 +22,41 @@ pub enum VadError {
ProcessingError(String),
}
-/// A detected speech segment with timing information.
/// Voice characteristics for speaker fingerprinting.
#[derive(Debug, Clone)]
pub struct VoiceFingerprint {
/// Average pitch (fundamental frequency) in Hz
pub avg_pitch: f32,
/// Pitch variance
pub pitch_variance: f32,
/// Average energy level
pub avg_energy: f32,
/// Spectral centroid (brightness indicator)
pub spectral_centroid: f32,
}
impl VoiceFingerprint {
/// Calculate similarity between two voice fingerprints (0.0 to 1.0)
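///
/// # Example (illustrative values, not measured from real audio)
/// ```ignore
/// let a = VoiceFingerprint { avg_pitch: 120.0, pitch_variance: 10.0, avg_energy: 0.10, spectral_centroid: 800.0 };
/// let b = VoiceFingerprint { avg_pitch: 130.0, pitch_variance: 12.0, avg_energy: 0.12, spectral_centroid: 850.0 };
/// // weighted diff = 0.4*0.05 + 0.2*0.04 + 0.2*0.04 + 0.2*0.05 = 0.046
/// assert!((a.similarity(&b) - 0.954).abs() < 1e-3);
/// ```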
pub fn similarity(&self, other: &VoiceFingerprint) -> f32 {
// Normalize differences for each feature
let pitch_diff = (self.avg_pitch - other.avg_pitch).abs() / 200.0; // Normalize by typical pitch range
let variance_diff = (self.pitch_variance - other.pitch_variance).abs() / 50.0;
let energy_diff = (self.avg_energy - other.avg_energy).abs() / 0.5;
let centroid_diff = (self.spectral_centroid - other.spectral_centroid).abs() / 1000.0;
// Calculate weighted similarity
let diff = pitch_diff * 0.4 + variance_diff * 0.2 + energy_diff * 0.2 + centroid_diff * 0.2;
(1.0 - diff.min(1.0)).max(0.0)
}
}
/// A detected speech segment with timing information and voice characteristics.
#[derive(Debug, Clone)]
pub struct SpeechSegment {
pub start_sample: usize,
pub end_sample: usize,
pub speaker_id: u32,
pub fingerprint: Option<VoiceFingerprint>,
}
impl SpeechSegment {
@@ -45,19 +76,27 @@ impl SpeechSegment {
}
}
-/// Voice Activity Detector for speaker separation.
/// Voice Activity Detector for speaker separation with enhanced tracking.
pub struct SpeakerSeparator {
vad: VoiceActivityDetector,
/// Minimum silence duration (in seconds) to consider a speaker change
-min_silence_for_speaker_change: f64,
_min_silence_for_speaker_change: f64,
/// Minimum speech duration (in seconds) to consider a valid segment
min_speech_duration: f64,
/// Known speakers and their voice fingerprints
known_speakers: HashMap<u32, VoiceFingerprint>,
/// Next speaker ID to assign
next_speaker_id: u32,
/// Similarity threshold for matching speakers (0.0 to 1.0)
similarity_threshold: f32,
/// FFT planner for frequency analysis
fft_planner: FftPlanner<f32>,
}
impl SpeakerSeparator {
/// Create a new speaker separator with default settings.
pub fn new() -> Result<Self, VadError> {
-Self::with_settings(1.5, 0.3)
Self::with_settings(0.8, 0.3) // Reduced silence threshold for group meetings
}
/// Create a new speaker separator with custom settings.
@@ -75,87 +114,212 @@ impl SpeakerSeparator {
.build()
.map_err(|e: voice_activity_detector::Error| VadError::InitError(e.to_string()))?;
info!("VAD initialized with sample_rate={}, min_silence={}s, min_speech={}s",
info!("Enhanced VAD initialized with sample_rate={}, min_silence={}s, min_speech={}s",
VAD_SAMPLE_RATE, min_silence_for_speaker_change, min_speech_duration);
Ok(Self {
vad,
-min_silence_for_speaker_change,
_min_silence_for_speaker_change: min_silence_for_speaker_change,
min_speech_duration,
known_speakers: HashMap::new(),
next_speaker_id: 0,
similarity_threshold: 0.7,
fft_planner: FftPlanner::new(),
})
}
-/// Detect speech segments and assign speaker IDs based on silence gaps.
/// Extract voice fingerprint from an audio segment.
fn extract_fingerprint(&mut self, audio: &[f32]) -> Option<VoiceFingerprint> {
if audio.len() < 2048 {
return None;
}
// Calculate average energy
let avg_energy = audio.iter().map(|s| s.abs()).sum::<f32>() / audio.len() as f32;
// Estimate pitch via zero-crossing rate (a cheap stand-in for autocorrelation)
let pitch = self.estimate_pitch(audio);
// Calculate spectral centroid
let spectral_centroid = self.calculate_spectral_centroid(audio);
// Calculate pitch variance over windows
let pitch_variance = self.calculate_pitch_variance(audio);
Some(VoiceFingerprint {
avg_pitch: pitch,
pitch_variance,
avg_energy,
spectral_centroid,
})
}
/// Estimate pitch using zero-crossing rate (simplified approach).
fn estimate_pitch(&self, audio: &[f32]) -> f32 {
if audio.is_empty() {
return 0.0; // guard the audio[0] access below
}
let mut zero_crossings = 0;
let mut prev_sign = audio[0] >= 0.0;
for sample in audio.iter().skip(1) {
let current_sign = *sample >= 0.0;
if current_sign != prev_sign {
zero_crossings += 1;
}
prev_sign = current_sign;
}
// Convert zero-crossing rate to approximate frequency
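// (sanity check, assuming VAD_SAMPLE_RATE = 16_000: a pure 200 Hz sine crosses
// zero ~400 times/s, so zcr = 400/16_000 = 0.025 and 0.025 * 16_000 / 2 = 200 Hz;
// harmonic-rich voiced speech biases this rough estimate upward)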
let zcr = zero_crossings as f32 / audio.len() as f32;
zcr * VAD_SAMPLE_RATE as f32 / 2.0
}
/// Calculate spectral centroid (center of mass of spectrum).
fn calculate_spectral_centroid(&mut self, audio: &[f32]) -> f32 {
let fft_size = 2048.min(audio.len());
let mut input: Vec<Complex<f32>> = audio[..fft_size]
.iter()
.map(|&s| Complex::new(s, 0.0))
.collect();
let fft = self.fft_planner.plan_fft_forward(fft_size);
fft.process(&mut input);
let mut weighted_sum = 0.0;
let mut magnitude_sum = 0.0;
for (i, complex) in input.iter().enumerate().take(fft_size / 2) {
let magnitude = complex.norm();
let frequency = i as f32 * VAD_SAMPLE_RATE as f32 / fft_size as f32;
weighted_sum += frequency * magnitude;
magnitude_sum += magnitude;
}
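// Centroid = sum(f * |X(f)|) / sum(|X(f)|) over the first fft_size/2 bins.
// No window is applied before the FFT, so spectral leakage smears the result
// a little; that is tolerable for coarse speaker fingerprinting.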
if magnitude_sum > 0.0 {
weighted_sum / magnitude_sum
} else {
0.0
}
}
/// Calculate pitch variation across windows (returned as a standard deviation).
fn calculate_pitch_variance(&self, audio: &[f32]) -> f32 {
let window_size = 1024;
let num_windows = audio.len() / window_size;
if num_windows < 2 {
return 0.0;
}
let mut pitches = Vec::new();
for i in 0..num_windows {
let start = i * window_size;
let end = (i + 1) * window_size;
let pitch = self.estimate_pitch(&audio[start..end]);
pitches.push(pitch);
}
// Standard deviation of the per-window pitch estimates
let mean = pitches.iter().sum::<f32>() / pitches.len() as f32;
let variance = pitches.iter()
.map(|p| (p - mean).powi(2))
.sum::<f32>() / pitches.len() as f32;
variance.sqrt()
}
/// Find or assign speaker ID based on voice fingerprint.
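///
/// With the default `similarity_threshold` of 0.7: a fingerprint scoring 0.75
/// against speaker 0 and 0.60 against speaker 1 matches speaker 0, while one
/// scoring below 0.7 against every known speaker is registered under a new ID.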
fn find_or_assign_speaker(&mut self, fingerprint: &VoiceFingerprint) -> u32 {
// Try to match with known speakers
let mut best_match = None;
let mut best_similarity = 0.0;
for (speaker_id, known_fp) in &self.known_speakers {
let similarity = fingerprint.similarity(known_fp);
if similarity > best_similarity && similarity >= self.similarity_threshold {
best_similarity = similarity;
best_match = Some(*speaker_id);
}
}
if let Some(speaker_id) = best_match {
debug!("Matched to existing speaker {} with similarity {:.2}", speaker_id, best_similarity);
speaker_id
} else {
// New speaker
let speaker_id = self.next_speaker_id;
self.next_speaker_id += 1;
self.known_speakers.insert(speaker_id, fingerprint.clone());
debug!("New speaker {} detected", speaker_id);
speaker_id
}
}
/// Detect speech segments and assign speaker IDs based on voice characteristics.
///
-/// This is a simple heuristic: if there's a long enough silence gap,
-/// we assume a different speaker might be talking.
/// This enhanced version uses voice fingerprinting to track speakers across
/// overlapping speech and short interruptions, making it suitable for group meetings.
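///
/// # Example (sketch; `samples` is assumed to be 16 kHz mono f32 audio)
/// ```ignore
/// let mut separator = SpeakerSeparator::new()?;
/// for seg in separator.detect_speakers(&samples)? {
///     println!("speaker {}: {:.2}s-{:.2}s",
///         seg.speaker_id, seg.start_seconds(), seg.end_seconds());
/// }
/// ```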
pub fn detect_speakers(&mut self, audio: &[f32]) -> Result<Vec<SpeechSegment>, VadError> {
let chunk_size = 512;
-let mut segments = Vec::new();
let mut raw_segments = Vec::new();
let mut current_segment_start: Option<usize> = None;
-let mut last_speech_end: Option<usize> = None;
-let mut current_speaker_id = 0u32;
-let min_silence_samples = (self.min_silence_for_speaker_change * VAD_SAMPLE_RATE as f64) as usize;
let min_speech_samples = (self.min_speech_duration * VAD_SAMPLE_RATE as f64) as usize;
debug!("Processing {} samples for VAD", audio.len());
debug!("Processing {} samples for enhanced VAD", audio.len());
// First pass: detect raw speech segments
for (chunk_idx, chunk) in audio.chunks(chunk_size).enumerate() {
if chunk.len() < chunk_size {
// Skip incomplete chunks at the end
continue;
}
let sample_offset = chunk_idx * chunk_size;
// Predict voice activity for this chunk
let probability = self.vad.predict(chunk.iter().copied());
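// 0.5 is the conventional cutoff for Silero's per-chunk speech probability;
// raising it trades missed quiet speech for fewer false positives.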
let is_speech = probability > 0.5;
if is_speech {
if current_segment_start.is_none() {
-// Check if we should change speaker
-if let Some(last_end) = last_speech_end {
-let silence_duration = sample_offset - last_end;
-if silence_duration >= min_silence_samples {
-current_speaker_id += 1;
-debug!("Speaker change detected at sample {} (silence: {}ms)",
-sample_offset, silence_duration * 1000 / VAD_SAMPLE_RATE as usize);
-}
-}
current_segment_start = Some(sample_offset);
}
} else if let Some(start) = current_segment_start {
// Speech ended
let segment_duration = sample_offset - start;
if segment_duration >= min_speech_samples {
-segments.push(SpeechSegment {
-start_sample: start,
-end_sample: sample_offset,
-speaker_id: current_speaker_id,
-});
raw_segments.push((start, sample_offset));
}
-last_speech_end = Some(sample_offset);
current_segment_start = None;
}
}
-// Handle case where speech continues to the end
// Handle speech continuing to the end
if let Some(start) = current_segment_start {
let segment_duration = audio.len() - start;
if segment_duration >= min_speech_samples {
-segments.push(SpeechSegment {
-start_sample: start,
-end_sample: audio.len(),
-speaker_id: current_speaker_id,
-});
raw_segments.push((start, audio.len()));
}
}
// Second pass: analyze voice characteristics and assign speakers
let mut segments = Vec::new();
for (start, end) in raw_segments {
let segment_audio = &audio[start..end];
// Extract voice fingerprint
let fingerprint = self.extract_fingerprint(segment_audio);
let speaker_id = if let Some(ref fp) = fingerprint {
self.find_or_assign_speaker(fp)
} else {
// Fallback: reuse the most recently assigned speaker (speaker 0 if none yet);
// saturating_sub avoids a u32 underflow panic on the very first segment
warn!("Could not extract fingerprint for segment at {}s", start as f64 / VAD_SAMPLE_RATE as f64);
self.next_speaker_id.saturating_sub(1)
};
segments.push(SpeechSegment {
start_sample: start,
end_sample: end,
speaker_id,
fingerprint,
});
}
info!("Detected {} speech segments with {} speakers",
segments.len(),
segments.iter().map(|s| s.speaker_id).max().unwrap_or(0) + 1);
@@ -188,6 +352,51 @@ impl SpeakerSeparator {
Ok(transcript)
}
/// Apply speaker labels to a chunk of transcript segments (for real-time processing).
///
/// This method is designed to work with the real-time transcription system,
/// maintaining speaker consistency across chunks.
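///
/// # Example (sketch; `chunk_audio`, `chunk_transcript`, and `chunk_start_secs`
/// are assumed to come from the surrounding real-time pipeline)
/// ```ignore
/// let labeled = separator.apply_speaker_labels_to_chunk(
///     &chunk_audio, chunk_transcript, chunk_start_secs)?;
/// ```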
pub fn apply_speaker_labels_to_chunk(
&mut self,
audio: &[f32],
mut transcript: Vec<TranscriptSegment>,
_chunk_start_time: f64,
) -> Result<Vec<TranscriptSegment>, VadError> {
// Detect speakers in this chunk
let speech_segments = self.detect_speakers(audio)?;
for segment in &mut transcript {
// Use the transcript segment's midpoint (times are already chunk-relative)
let segment_mid = (segment.start + segment.end) / 2.0;
// Find matching speech segment
for speech in &speech_segments {
let speech_start = speech.start_seconds();
let speech_end = speech.end_seconds();
if segment_mid >= speech_start && segment_mid <= speech_end {
// Use consistent speaker numbering across chunks
segment.speaker = format!("Speaker {}", speech.speaker_id + 1);
break;
}
}
}
// Flag runaway speaker counts (no pruning yet; usually a sign of over-segmentation)
if self.known_speakers.len() > 10 {
warn!("Many speakers detected ({}), consider adjusting similarity threshold", self.known_speakers.len());
}
Ok(transcript)
}
/// Reset speaker tracking (useful between different meetings).
pub fn reset_speakers(&mut self) {
self.known_speakers.clear();
self.next_speaker_id = 0;
info!("Speaker tracking reset");
}
}
impl Default for SpeakerSeparator {