feat: we successfully have the installer working for Windows!
Models are downloaded at runtime instead of at build time.
+1129 -22
File diff suppressed because it is too large
+29 -1
@@ -18,11 +18,35 @@ crate-type = ["staticlib", "cdylib", "rlib"]
tauri-build = { version = "2", features = [] }

[dependencies]
tauri = { version = "2", features = [] }
tauri = { version = "2", features = ["custom-protocol"] }
tauri-plugin-opener = "2"
serde = { version = "1", features = ["derive"] }
serde_json = "1"

# ML libraries for transcription and summarization
whisper-rs = "0.14"
llama-cpp-2 = "0.1"

# Audio capture and processing
cpal = "0.15"
hound = "3.5" # WAV file handling

# Voice activity detection
voice_activity_detector = "0.2"

# Async runtime
tokio = { version = "1", features = ["full"] }

# Utilities
parking_lot = "0.12" # Better mutexes
thiserror = "2" # Error handling
tracing = "0.1" # Logging
tracing-subscriber = "0.3"

# HTTP client for model downloads
reqwest = { version = "0.12", features = ["stream", "rustls-tls"], default-features = false }
futures-util = "0.3" # For StreamExt

[target.'cfg(windows)'.dependencies]
windows = { version = "0.62", features = [
    "Win32_System_Com",
@@ -32,3 +56,7 @@ windows = { version = "0.62", features = [
    "Win32_System_Threading",
] }

[patch.crates-io]
# Fix cross-compilation: use CARGO_CFG_TARGET_OS instead of cfg!(windows)
llama-cpp-sys-2 = { path = "../patches/llama-cpp-sys-2" }
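
A hedged sketch of the kind of build-script fix the patched crate applies (the patch file itself is not shown in this diff): inside a build.rs, cfg!(windows) describes the host machine, while the CARGO_CFG_TARGET_OS environment variable that Cargo sets for build scripts describes the compilation target, which is what cross-compilation must key on.

// build.rs (illustrative sketch, not the patched source)
fn main() {
    // Wrong when cross-compiling: cfg!(windows) is compiled into the build
    // script itself, so it reflects the HOST, not the target.
    let _host_is_windows = cfg!(windows);

    // Correct: Cargo exports CARGO_CFG_TARGET_OS to build scripts, and it
    // names the TARGET operating system.
    let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
    if target_os == "windows" {
        // e.g. emit Windows-specific link flags here
    }
}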

+434 -44
@@ -1,68 +1,458 @@
use std::process::{Child, Command};
use std::sync::Mutex;
use tauri::{Manager, State};
//! Chronara - Meeting transcription and summarization using local AI models.
//!
//! This is a pure Rust backend using:
//! - whisper-rs for transcription
//! - llama-cpp-2 for summarization
//! - voice_activity_detector for speaker separation

struct PythonBackend {
    process: Mutex<Option<Child>>,
use parking_lot::Mutex;
use std::path::PathBuf;
use std::sync::Arc;
use tauri::{Emitter, Manager, State};
use tracing::{debug, info};

pub mod ml;

use ml::summarizer::{get_model_filename, LlamaSummarizer};
use ml::transcriber::{TranscriptSegment, WhisperTranscriber};
use ml::vad::SpeakerSeparator;
use ml::audio::AudioCapture;

/// Application state containing the ML models and audio capture.
struct AppState {
    transcriber: Mutex<WhisperTranscriber>,
    summarizer: Mutex<Option<LlamaSummarizer>>,
    speaker_separator: Mutex<Option<SpeakerSeparator>>,
    audio_capture: Mutex<Option<AudioCapture>>,
    logs: Arc<Mutex<Vec<String>>>,
}

#[tauri::command]
fn start_backend(backend: State<PythonBackend>) -> Result<String, String> {
    let mut process_lock = backend.process.lock().map_err(|e| e.to_string())?;
impl AppState {
    fn new() -> Self {
        Self {
            transcriber: Mutex::new(WhisperTranscriber::new()),
            summarizer: Mutex::new(None),
            speaker_separator: Mutex::new(None),
            audio_capture: Mutex::new(None),
            logs: Arc::new(Mutex::new(Vec::new())),
        }
    }
}
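
A brief design note with a minimal sketch (the snippet is illustrative, not part of this commit): parking_lot::Mutex::lock() returns the guard directly rather than a Result, because parking_lot has no lock poisoning. That is why the removed std::sync::Mutex code in this diff needs .map_err(|e| e.to_string())? while the new code can simply call .lock().

// Sketch: locking a parking_lot mutex (no PoisonError, no Result to unwrap)
fn read_counter(counter: &parking_lot::Mutex<u32>) -> u32 {
    *counter.lock()
}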

    if process_lock.is_some() {
        return Ok("Backend already running".to_string());
/// Emit a log message to the frontend.
fn emit_log(app_handle: &tauri::AppHandle, logs: &Arc<Mutex<Vec<String>>>, message: &str) {
    {
        let mut logs_guard = logs.lock();
        logs_guard.push(message.to_string());
        if logs_guard.len() > 100 {
            logs_guard.remove(0);
        }
    }
    let _ = app_handle.emit("backend-log", message);
    info!("{}", message);
}
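
One possible refinement, sketched here as a hypothetical alternative (not part of this commit): capping the log buffer with Vec::remove(0) shifts every remaining element, so a VecDeque would make the trim O(1).

// Hypothetical O(1) capped log buffer (illustrative sketch)
use std::collections::VecDeque;

fn push_capped(logs: &mut VecDeque<String>, message: &str, cap: usize) {
    if logs.len() >= cap {
        logs.pop_front(); // drop the oldest entry in O(1)
    }
    logs.push_back(message.to_string());
}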

/// Get the models directory based on environment.
fn get_models_dir(app_handle: &tauri::AppHandle) -> PathBuf {
    // Production mode - use app data directory (user-writable)
    // On Windows: %APPDATA%\com.chronara.app\models
    // On macOS: ~/Library/Application Support/com.chronara.app/models
    // On Linux: ~/.local/share/com.chronara.app/models
    if let Ok(app_data_dir) = app_handle.path().app_data_dir() {
        return app_data_dir.join("models");
    }

    // Get the resource path for the bundled Python executable
    let python_cmd = if cfg!(windows) {
        "python"
    // Fallback: Development mode - use project models directory
    let current_dir = std::env::current_dir().unwrap_or_default();
    let project_root = if current_dir.ends_with("src-tauri") {
        current_dir.parent().unwrap().to_path_buf()
    } else {
        "python3"
        current_dir
    };
    project_root.join("models")
}

/// Check if the required models exist.
#[tauri::command]
fn check_models(app_handle: tauri::AppHandle) -> Result<bool, String> {
    let models_dir = get_models_dir(&app_handle);

    // Check for LLaMA model
    let llama_model = models_dir.join(get_model_filename("3B"));
    let llama_exists = llama_model.exists();

    // Check for Whisper model
    let whisper_model = models_dir.join("whisper").join("ggml-base.bin");
    let whisper_exists = whisper_model.exists();

    debug!(
        "Models check: llama={} ({}), whisper={} ({})",
        llama_exists,
        llama_model.display(),
        whisper_exists,
        whisper_model.display()
    );

    // Both models are required
    Ok(llama_exists && whisper_exists)
}

/// Download a file from a URL with progress tracking.
async fn download_file(
    app_handle: &tauri::AppHandle,
    logs: &Arc<Mutex<Vec<String>>>,
    url: &str,
    dest_path: &std::path::Path,
    model_name: &str,
) -> Result<(), String> {
    use futures_util::StreamExt;
    use std::io::Write;

    emit_log(app_handle, logs, &format!("[Models] Downloading {}...", model_name));

    let client = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::limited(10))
        .timeout(std::time::Duration::from_secs(3600))
        .connect_timeout(std::time::Duration::from_secs(30))
        .build()
        .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

    let response = client
        .get(url)
        .send()
        .await
        .map_err(|e| format!("Failed to start download: {}", e))?;

    if !response.status().is_success() {
        return Err(format!("Download failed with status: {}", response.status()));
    }

    let total_size = response.content_length().unwrap_or(0);
    let total_mb = total_size as f64 / 1_048_576.0;

    emit_log(app_handle, logs, &format!("[Models] {} size: {:.1} MB", model_name, total_mb));

    // Ensure parent directory exists
    if let Some(parent) = dest_path.parent() {
        std::fs::create_dir_all(parent)
            .map_err(|e| format!("Failed to create directory: {}", e))?;
    }

    // Download to a temp file first
    let temp_path = dest_path.with_extension("downloading");
    let mut file = std::fs::File::create(&temp_path)
        .map_err(|e| format!("Failed to create temp file: {}", e))?;

    let mut downloaded: u64 = 0;
    let mut last_progress_percent: u64 = 0;
    let mut stream = response.bytes_stream();

    while let Some(chunk_result) = stream.next().await {
        let chunk = chunk_result.map_err(|e| format!("Download error: {}", e))?;
        file.write_all(&chunk)
            .map_err(|e| format!("Failed to write to file: {}", e))?;

        downloaded += chunk.len() as u64;

        if total_size > 0 {
            let progress_percent = (downloaded * 100) / total_size;
            if progress_percent >= last_progress_percent + 10 {
                last_progress_percent = progress_percent;
                let downloaded_mb = downloaded as f64 / 1_048_576.0;
                emit_log(
                    app_handle,
                    logs,
                    &format!("[Models] {}: {:.1} MB / {:.1} MB ({}%)", model_name, downloaded_mb, total_mb, progress_percent),
                );
            }
        }
    }

    file.flush().map_err(|e| format!("Failed to flush file: {}", e))?;
    drop(file);

    std::fs::rename(&temp_path, dest_path)
        .map_err(|e| format!("Failed to finalize download: {}", e))?;

    emit_log(app_handle, logs, &format!("[Models] {} download complete!", model_name));
    Ok(())
}
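
The 10%-step progress throttle above is easy to get wrong, so here is the same logic pulled out as a pure function, a hypothetical helper (not in this commit) that would make the behavior unit-testable.

// Hypothetical extraction of the progress-throttle logic (illustrative sketch)
fn should_emit_progress(downloaded: u64, total: u64, last_percent: &mut u64) -> bool {
    if total == 0 {
        return false; // unknown size: never emit percentage updates
    }
    let percent = downloaded * 100 / total;
    if percent >= *last_percent + 10 {
        *last_percent = percent;
        return true; // crossed another 10% boundary
    }
    false
}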

/// Download the required models.
#[tauri::command]
async fn download_models(
    state: State<'_, AppState>,
    app_handle: tauri::AppHandle,
) -> Result<String, String> {
    let logs = Arc::clone(&state.logs);
    let models_dir = get_models_dir(&app_handle);

    std::fs::create_dir_all(&models_dir)
        .map_err(|e| format!("Failed to create models directory: {}", e))?;

    // Download LLaMA model if needed
    let llama_model = models_dir.join(get_model_filename("3B"));
    if !llama_model.exists() {
        let llama_url = "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_K_M.gguf";
        download_file(&app_handle, &logs, llama_url, &llama_model, "LLaMA 3.2 3B (~2GB)").await?;
    } else {
        emit_log(&app_handle, &logs, "[Models] LLaMA model already present");
    }

    // Download Whisper model if needed
    let whisper_dir = models_dir.join("whisper");
    let whisper_model = whisper_dir.join("ggml-base.bin");
    if !whisper_model.exists() {
        let whisper_url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin";
        download_file(&app_handle, &logs, whisper_url, &whisper_model, "Whisper Base (~142MB)").await?;
    } else {
        emit_log(&app_handle, &logs, "[Models] Whisper model already present");
    }

    emit_log(&app_handle, &logs, "[Models] All models ready!");
    Ok("Models downloaded successfully".to_string())
}

/// Initialize the ML models.
#[tauri::command]
async fn initialize_models(
    state: State<'_, AppState>,
    app_handle: tauri::AppHandle,
) -> Result<String, String> {
    let logs = Arc::clone(&state.logs);
    let models_dir = get_models_dir(&app_handle);

    emit_log(&app_handle, &logs, "[Init] Initializing ML models...");

    // Initialize LLaMA summarizer
    let llama_model_path = models_dir.join(get_model_filename("3B"));
    if llama_model_path.exists() {
        emit_log(&app_handle, &logs, "[Init] Loading LLaMA model...");

        match LlamaSummarizer::new() {
            Ok(mut summarizer) => {
                if let Err(e) = summarizer.load_model(&llama_model_path) {
                    emit_log(&app_handle, &logs, &format!("[Init ERROR] Failed to load LLaMA: {}", e));
                } else {
                    *state.summarizer.lock() = Some(summarizer);
                    emit_log(&app_handle, &logs, "[Init] LLaMA model loaded successfully");
                }
            }
            Err(e) => {
                emit_log(&app_handle, &logs, &format!("[Init ERROR] Failed to create summarizer: {}", e));
            }
        }
    } else {
        emit_log(&app_handle, &logs, "[Init WARNING] LLaMA model not found, summarization disabled");
    }

    // Initialize Whisper transcriber (lazy load on first use)
    let whisper_model_path = models_dir.join("whisper").join("ggml-base.bin");
    if whisper_model_path.exists() {
        emit_log(&app_handle, &logs, "[Init] Loading Whisper model...");

        let mut transcriber = state.transcriber.lock();
        if let Err(e) = transcriber.load_model(&whisper_model_path) {
            emit_log(&app_handle, &logs, &format!("[Init ERROR] Failed to load Whisper: {}", e));
        } else {
            emit_log(&app_handle, &logs, "[Init] Whisper model loaded successfully");
        }
    } else {
        emit_log(&app_handle, &logs, "[Init] Whisper model not found, will download on first transcription");
    }

    // Initialize VAD for speaker separation
    emit_log(&app_handle, &logs, "[Init] Initializing voice activity detector...");
    match SpeakerSeparator::new() {
        Ok(separator) => {
            *state.speaker_separator.lock() = Some(separator);
            emit_log(&app_handle, &logs, "[Init] VAD initialized successfully");
        }
        Err(e) => {
            emit_log(&app_handle, &logs, &format!("[Init WARNING] VAD initialization failed: {}", e));
        }
    }

    emit_log(&app_handle, &logs, "[Init] Model initialization complete");
    Ok("Models initialized".to_string())
}

/// Start recording audio.
#[tauri::command]
fn start_recording(
    state: State<'_, AppState>,
    app_handle: tauri::AppHandle,
) -> Result<String, String> {
    let logs = Arc::clone(&state.logs);

    emit_log(&app_handle, &logs, "[Audio] Starting recording...");

    let mut audio_guard = state.audio_capture.lock();

    // Create audio capture if not exists
    if audio_guard.is_none() {
        match AudioCapture::new() {
            Ok(capture) => {
                *audio_guard = Some(capture);
            }
            Err(e) => {
                let msg = format!("[Audio ERROR] Failed to create audio capture: {}", e);
                emit_log(&app_handle, &logs, &msg);
                return Err(msg);
            }
        }
    }

    // Start recording
    if let Some(ref mut capture) = *audio_guard {
        if let Err(e) = capture.start_recording() {
            let msg = format!("[Audio ERROR] Failed to start recording: {}", e);
            emit_log(&app_handle, &logs, &msg);
            return Err(msg);
        }
    }

    emit_log(&app_handle, &logs, "[Audio] Recording started");
    Ok("Recording started".to_string())
}

/// Stop recording and return the transcript.
#[tauri::command]
async fn stop_recording(
    state: State<'_, AppState>,
    app_handle: tauri::AppHandle,
) -> Result<Vec<TranscriptSegment>, String> {
    let logs = Arc::clone(&state.logs);

    emit_log(&app_handle, &logs, "[Audio] Stopping recording...");

    // Get the audio samples
    let audio_samples = {
        let mut audio_guard = state.audio_capture.lock();
        if let Some(ref mut capture) = *audio_guard {
            capture.stop_recording()
        } else {
            return Err("No active recording".to_string());
        }
    };

    // Start the Python backend
    let child = Command::new(python_cmd)
        .args(["-m", "uvicorn", "backend.main:app", "--host", "127.0.0.1", "--port", "8000"])
        .spawn()
        .map_err(|e| format!("Failed to start backend: {}", e))?;
    let duration = audio_samples.len() as f32 / 16000.0;
    emit_log(&app_handle, &logs, &format!("[Audio] Captured {:.1}s of audio", duration));

    *process_lock = Some(child);
    Ok("Backend started successfully".to_string())
    if audio_samples.is_empty() {
        return Err("No audio captured".to_string());
    }

    // Transcribe the audio
    emit_log(&app_handle, &logs, "[Transcribe] Starting transcription...");

    let mut segments = {
        let transcriber = state.transcriber.lock();
        if !transcriber.is_loaded() {
            emit_log(&app_handle, &logs, "[Transcribe ERROR] Whisper model not loaded");
            return Err("Whisper model not loaded. Please ensure the model is downloaded.".to_string());
        }

        transcriber.transcribe(&audio_samples)
            .map_err(|e| format!("Transcription failed: {}", e))?
    };

    emit_log(&app_handle, &logs, &format!("[Transcribe] Got {} segments", segments.len()));

    // Apply speaker labels using VAD
    if let Some(ref mut separator) = *state.speaker_separator.lock() {
        emit_log(&app_handle, &logs, "[Speaker] Applying speaker labels...");
        segments = separator.apply_speaker_labels(&audio_samples, segments)
            .map_err(|e| format!("Speaker separation failed: {}", e))?;
    }

    Ok(segments)
}

/// Transcribe a chunk of audio (for real-time transcription).
#[tauri::command]
fn stop_backend(backend: State<PythonBackend>) -> Result<String, String> {
    let mut process_lock = backend.process.lock().map_err(|e| e.to_string())?;
async fn transcribe_chunk(
    state: State<'_, AppState>,
    audio_data: Vec<f32>,
) -> Result<Vec<TranscriptSegment>, String> {
    let transcriber = state.transcriber.lock();

    if let Some(mut child) = process_lock.take() {
        child.kill().map_err(|e| format!("Failed to stop backend: {}", e))?;
        Ok("Backend stopped".to_string())
    } else {
        Ok("Backend not running".to_string())
    if !transcriber.is_loaded() {
        return Err("Whisper model not loaded".to_string());
    }

    let segments = transcriber.transcribe(&audio_data)
        .map_err(|e| format!("Transcription failed: {}", e))?;

    Ok(segments)
}

/// Generate a summary from a transcript.
#[tauri::command]
async fn summarize(
    state: State<'_, AppState>,
    app_handle: tauri::AppHandle,
    transcript: String,
) -> Result<String, String> {
    let logs = Arc::clone(&state.logs);

    emit_log(&app_handle, &logs, "[Summary] Generating summary...");

    let summarizer_guard = state.summarizer.lock();
    let summarizer = summarizer_guard.as_ref()
        .ok_or("Summarizer not initialized")?;

    if !summarizer.is_loaded() {
        return Err("LLaMA model not loaded".to_string());
    }

    let summary = summarizer.summarize(&transcript)
        .map_err(|e| format!("Summarization failed: {}", e))?;

    emit_log(&app_handle, &logs, &format!("[Summary] Generated {} character summary", summary.len()));

    Ok(summary)
}

/// Get backend logs.
#[tauri::command]
fn get_backend_logs(state: State<'_, AppState>) -> Result<Vec<String>, String> {
    Ok(state.logs.lock().clone())
}

/// Check if models are loaded and ready.
#[tauri::command]
fn check_ready(state: State<'_, AppState>) -> Result<bool, String> {
    let summarizer = state.summarizer.lock();

    // At minimum, we need the summarizer loaded
    // Whisper can be loaded on first use
    let ready = summarizer.as_ref().map_or(false, |s| s.is_loaded());

    Ok(ready)
}

#[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() {
    // Initialize tracing
    tracing_subscriber::fmt::init();

    info!("Starting Chronara with native Rust backend");

    tauri::Builder::default()
        .plugin(tauri_plugin_opener::init())
        .manage(PythonBackend {
            process: Mutex::new(None),
        })
        .invoke_handler(tauri::generate_handler![start_backend, stop_backend])
        .on_window_event(|window, event| {
            // Stop backend when window closes
            if let tauri::WindowEvent::CloseRequested { .. } = event {
                if let Some(backend) = window.try_state::<PythonBackend>() {
                    if let Ok(mut process_lock) = backend.process.lock() {
                        if let Some(mut child) = process_lock.take() {
                            let _ = child.kill();
                        }
                    }
                }
            }
        })
        .manage(AppState::new())
        .invoke_handler(tauri::generate_handler![
            check_models,
            download_models,
            initialize_models,
            start_recording,
            stop_recording,
            transcribe_chunk,
            summarize,
            get_backend_logs,
            check_ready,
        ])
        .run(tauri::generate_context!())
        .expect("error while running tauri application");
}

@@ -0,0 +1,262 @@
//! Audio capture and processing using cpal.
//!
//! This module handles audio recording in a thread-safe way by using
//! a shared buffer that can be accessed from the Tauri state.

use parking_lot::Mutex;
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use thiserror::Error;
use tracing::{debug, error, info, warn};

/// Sample rate expected by Whisper (16kHz mono)
pub const WHISPER_SAMPLE_RATE: u32 = 16000;

#[derive(Error, Debug)]
pub enum AudioError {
    #[error("No input device available")]
    NoInputDevice,
    #[error("Failed to get default input config: {0}")]
    ConfigError(String),
    #[error("Failed to build input stream: {0}")]
    StreamError(String),
    #[error("Stream playback error: {0}")]
    PlaybackError(String),
    #[error("Recording thread error: {0}")]
    ThreadError(String),
}

/// Shared audio state that can be safely stored in Tauri state.
/// The actual audio capture happens in a separate thread.
pub struct AudioCapture {
    buffer: Arc<Mutex<Vec<f32>>>,
    is_recording: Arc<Mutex<bool>>,
    should_stop: Arc<Mutex<bool>>,
    recording_thread: Option<JoinHandle<()>>,
}

// Implement Send + Sync manually since our struct only contains thread-safe types
|
||||
unsafe impl Send for AudioCapture {}
|
||||
unsafe impl Sync for AudioCapture {}
|
||||
|
||||
impl AudioCapture {
    /// Create a new audio capture instance.
    pub fn new() -> Result<Self, AudioError> {
        Ok(Self {
            buffer: Arc::new(Mutex::new(Vec::new())),
            is_recording: Arc::new(Mutex::new(false)),
            should_stop: Arc::new(Mutex::new(false)),
            recording_thread: None,
        })
    }

    /// Start recording audio in a background thread.
    pub fn start_recording(&mut self) -> Result<(), AudioError> {
        if *self.is_recording.lock() {
            warn!("Already recording");
            return Ok(());
        }

        // Clear the buffer and reset stop flag
        self.buffer.lock().clear();
        *self.should_stop.lock() = false;
        *self.is_recording.lock() = true;

        let buffer = Arc::clone(&self.buffer);
        let is_recording = Arc::clone(&self.is_recording);
        let should_stop = Arc::clone(&self.should_stop);

        // Spawn a thread to handle audio capture
        let handle = thread::spawn(move || {
            use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
            use cpal::{SampleRate, StreamConfig};

            let host = cpal::default_host();

            let device = match host.default_input_device() {
                Some(d) => d,
                None => {
                    error!("No input device available");
                    *is_recording.lock() = false;
                    return;
                }
            };

            info!("Using audio input device: {}", device.name().unwrap_or_default());

            // Get supported config
            let supported_config = match device.default_input_config() {
                Ok(c) => c,
                Err(e) => {
                    error!("Failed to get input config: {}", e);
                    *is_recording.lock() = false;
                    return;
                }
            };

            debug!("Supported config: {:?}", supported_config);

            // We want 16kHz mono for Whisper
            let config = StreamConfig {
                channels: 1,
                sample_rate: SampleRate(WHISPER_SAMPLE_RATE),
                buffer_size: cpal::BufferSize::Default,
            };

            let buffer_clone = Arc::clone(&buffer);
            let err_fn = |err| error!("Audio stream error: {}", err);

            // Build the input stream
            let stream = match device.build_input_stream(
                &config,
                move |data: &[f32], _: &cpal::InputCallbackInfo| {
                    buffer_clone.lock().extend_from_slice(data);
                },
                err_fn,
                None,
            ) {
                Ok(s) => s,
                Err(e) => {
                    error!("Failed to build input stream: {}", e);
                    *is_recording.lock() = false;
                    return;
                }
            };

            if let Err(e) = stream.play() {
                error!("Failed to start stream: {}", e);
                *is_recording.lock() = false;
                return;
            }

            info!("Audio recording started");

            // Keep the stream alive until stop is requested
            while !*should_stop.lock() {
                thread::sleep(std::time::Duration::from_millis(100));
            }

            // Stream is automatically stopped when dropped
            drop(stream);
            *is_recording.lock() = false;
            info!("Audio recording stopped");
        });

        self.recording_thread = Some(handle);
        Ok(())
    }
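
    // Design note (sketch, with assumptions): cpal's Stream type does not
    // implement Send, which is why the stream is created, polled, and dropped
    // entirely on the spawned thread instead of being stored in AudioCapture.
    // The Arc<Mutex<bool>> flag is the only control channel; an mpsc channel
    // is an equivalent alternative shape, illustrated below (hypothetical,
    // not part of this commit).
    //
    //     let (stop_tx, stop_rx) = std::sync::mpsc::channel::<()>();
    //     let worker = std::thread::spawn(move || {
    //         // ... build and play the cpal stream here ...
    //         while stop_rx.try_recv().is_err() {
    //             std::thread::sleep(std::time::Duration::from_millis(100));
    //         }
    //         // stream dropped here, ending capture
    //     });
    //     let _ = stop_tx.send(()); // request stop
    //     let _ = worker.join();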

    /// Stop recording and return the captured audio samples.
    pub fn stop_recording(&mut self) -> Vec<f32> {
        // Signal the thread to stop
        *self.should_stop.lock() = true;

        // Wait for the thread to finish
        if let Some(handle) = self.recording_thread.take() {
            let _ = handle.join();
        }

        let samples = std::mem::take(&mut *self.buffer.lock());
        info!("Captured {} samples ({:.1}s)", samples.len(), samples.len() as f32 / WHISPER_SAMPLE_RATE as f32);
        samples
    }

    /// Get a copy of the current audio buffer without stopping recording.
    pub fn get_current_buffer(&self) -> Vec<f32> {
        self.buffer.lock().clone()
    }

    /// Clear the audio buffer.
    pub fn clear_buffer(&self) {
        self.buffer.lock().clear();
    }

    /// Check if currently recording.
    pub fn is_recording(&self) -> bool {
        *self.is_recording.lock()
    }

    /// Get the duration of captured audio in seconds.
    pub fn get_duration_seconds(&self) -> f32 {
        let sample_count = self.buffer.lock().len();
        sample_count as f32 / WHISPER_SAMPLE_RATE as f32
    }
}

impl Default for AudioCapture {
    fn default() -> Self {
        Self::new().expect("Failed to create audio capture")
    }
}

impl Drop for AudioCapture {
    fn drop(&mut self) {
        // Ensure recording is stopped when dropped
        if *self.is_recording.lock() {
            *self.should_stop.lock() = true;
            if let Some(handle) = self.recording_thread.take() {
                let _ = handle.join();
            }
        }
    }
}

/// Convert audio samples from i16 to f32 format.
pub fn i16_to_f32(samples: &[i16]) -> Vec<f32> {
    samples.iter().map(|&s| s as f32 / 32768.0).collect()
}

/// Convert audio samples from f32 to i16 format.
pub fn f32_to_i16(samples: &[f32]) -> Vec<i16> {
    samples
        .iter()
        .map(|&s| (s * 32767.0).clamp(-32768.0, 32767.0) as i16)
        .collect()
}

/// Resample audio to the target sample rate (simple linear interpolation).
/// For production, consider using a proper resampling library like rubato.
pub fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate {
        return samples.to_vec();
    }

    let ratio = to_rate as f64 / from_rate as f64;
    let new_len = (samples.len() as f64 * ratio) as usize;
    let mut output = Vec::with_capacity(new_len);

    for i in 0..new_len {
        let src_idx = i as f64 / ratio;
        let src_idx_floor = src_idx.floor() as usize;
        let src_idx_ceil = (src_idx_floor + 1).min(samples.len() - 1);
        let frac = src_idx - src_idx_floor as f64;

        let sample = samples[src_idx_floor] as f64 * (1.0 - frac)
            + samples[src_idx_ceil] as f64 * frac;
        output.push(sample as f32);
    }

    output
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_i16_to_f32_conversion() {
        let i16_samples = vec![0i16, 16384, -16384, 32767, -32768];
        let f32_samples = i16_to_f32(&i16_samples);

        assert!((f32_samples[0] - 0.0).abs() < 0.001);
        assert!((f32_samples[1] - 0.5).abs() < 0.001);
        assert!((f32_samples[2] - (-0.5)).abs() < 0.001);
    }

    #[test]
    fn test_resample_same_rate() {
        let samples = vec![1.0, 2.0, 3.0, 4.0];
        let resampled = resample(&samples, 16000, 16000);
        assert_eq!(samples, resampled);
    }
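
    // An added sketch of a test (not in the original commit) covering the
    // linear-interpolation path: doubling the rate doubles the length, and
    // the midpoint between 0.0 and 1.0 interpolates to 0.5.
    #[test]
    fn test_resample_upsample_interpolates() {
        let samples = vec![0.0f32, 1.0];
        let resampled = resample(&samples, 8000, 16000);
        assert_eq!(resampled.len(), 4);
        assert!((resampled[1] - 0.5).abs() < 0.001);
    }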
}
@@ -0,0 +1,17 @@
//! Machine learning modules for Chronara.
//!
//! This module contains the Rust implementations for:
//! - Audio capture and processing
//! - Whisper transcription
//! - LLaMA summarization
//! - Voice activity detection for basic speaker separation

pub mod audio;
pub mod transcriber;
pub mod summarizer;
pub mod vad;

pub use audio::AudioCapture;
pub use transcriber::WhisperTranscriber;
pub use summarizer::LlamaSummarizer;
pub use vad::SpeakerSeparator;
@@ -0,0 +1,198 @@
//! LLaMA summarization using llama-cpp-2.

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::sampling::LlamaSampler;
use std::num::NonZeroU32;
use std::path::Path;
use thiserror::Error;
use tracing::{debug, info};

#[derive(Error, Debug)]
pub enum SummarizerError {
    #[error("Failed to initialize LLaMA backend: {0}")]
    BackendError(String),
    #[error("Failed to load LLaMA model: {0}")]
    ModelLoadError(String),
    #[error("Failed to create context: {0}")]
    ContextError(String),
    #[error("Summarization failed: {0}")]
    SummarizationError(String),
    #[error("Model not loaded")]
    ModelNotLoaded,
}

/// LLaMA summarizer using llama-cpp-2 bindings.
pub struct LlamaSummarizer {
    backend: Option<LlamaBackend>,
    model: Option<LlamaModel>,
    model_path: Option<String>,
}

impl LlamaSummarizer {
    /// Create a new summarizer instance (without loading a model).
    pub fn new() -> Result<Self, SummarizerError> {
        let backend = LlamaBackend::init()
            .map_err(|e| SummarizerError::BackendError(e.to_string()))?;

        Ok(Self {
            backend: Some(backend),
            model: None,
            model_path: None,
        })
    }

    /// Load a LLaMA model from the given path.
    pub fn load_model(&mut self, model_path: &Path) -> Result<(), SummarizerError> {
        info!("Loading LLaMA model from: {}", model_path.display());

        let backend = self.backend.as_ref()
            .ok_or(SummarizerError::BackendError("Backend not initialized".to_string()))?;

        let model_params = LlamaModelParams::default();

        let model = LlamaModel::load_from_file(backend, model_path, &model_params)
            .map_err(|e| SummarizerError::ModelLoadError(e.to_string()))?;

        self.model = Some(model);
        self.model_path = Some(model_path.to_string_lossy().to_string());

        info!("LLaMA model loaded successfully");
        Ok(())
    }

    /// Check if a model is loaded.
    pub fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    /// Generate a meeting summary from the transcript.
    pub fn summarize(&self, transcript: &str) -> Result<String, SummarizerError> {
        let model = self
            .model
            .as_ref()
            .ok_or(SummarizerError::ModelNotLoaded)?;

        let prompt = format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n\
             You are a helpful assistant that creates concise meeting summaries. Focus on:\n\
             - Key decisions made\n\
             - Action items and who owns them\n\
             - Important discussions and their outcomes\n\
             - Next steps\n\n\
             Keep the summary structured and easy to scan.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n\
             Please summarize this meeting transcript:\n\n\
             {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\
             Meeting Summary:\n",
            transcript
        );

        debug!("Generating summary for {} character transcript", transcript.len());

        // Create context for inference
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(8192));

        let mut ctx = model
            .new_context(self.backend.as_ref().unwrap(), ctx_params)
            .map_err(|e| SummarizerError::ContextError(e.to_string()))?;

        // Tokenize the prompt
        let tokens = model
            .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

        debug!("Prompt tokenized to {} tokens", tokens.len());

        // Create batch and add tokens
        let mut batch = LlamaBatch::new(tokens.len(), 1);

        for (i, token) in tokens.iter().enumerate() {
            let is_last = i == tokens.len() - 1;
            batch.add(*token, i as i32, &[0], is_last)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Decode the prompt
        ctx.decode(&mut batch)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

        // Generate response tokens
        let mut output_tokens = Vec::new();
        let max_tokens = 1024;
        let mut n_cur = tokens.len();

        // Create sampler for token selection
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::temp(0.7),
            LlamaSampler::top_p(0.9, 1),
            LlamaSampler::dist(42),
        ]);

        for _ in 0..max_tokens {
            // Sample the next token using the sampler
            // The sampler.sample() takes context and the index of the last token in the batch
            let token = sampler.sample(&ctx, (batch.n_tokens() - 1) as i32);

            // Check for end of generation
            if model.is_eog_token(token) {
                break;
            }

            output_tokens.push(token);

            // Prepare next batch
            batch.clear();
            batch.add(token, n_cur as i32, &[0], true)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

            n_cur += 1;

            // Decode
            ctx.decode(&mut batch)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Convert tokens to string
        let output_text = output_tokens
            .iter()
            .filter_map(|t| model.token_to_str(*t, llama_cpp_2::model::Special::Tokenize).ok())
            .collect::<String>();

        info!("Summary generated: {} characters", output_text.len());
        Ok(output_text.trim().to_string())
    }
}

impl Default for LlamaSummarizer {
    fn default() -> Self {
        Self::new().expect("Failed to create summarizer")
    }
}

/// Get the expected model filename for a given size.
pub fn get_model_filename(model_size: &str) -> String {
    format!("Llama-3.2-{}-Instruct-Q4_K_M.gguf", model_size)
}

/// Check if a LLaMA model exists at the given path.
pub fn model_exists(models_dir: &Path, model_size: &str) -> bool {
    let model_path = models_dir.join(get_model_filename(model_size));
    model_path.exists()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_model_filename() {
        assert_eq!(
            get_model_filename("3B"),
            "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
        );
    }
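
    // An added sketch of a test (not in the original commit): model_exists()
    // returns false for a directory that is absent. The path below is a
    // hypothetical non-existent location.
    #[test]
    fn test_model_exists_missing_dir() {
        let missing = std::path::Path::new("/nonexistent/chronara-models");
        assert!(!model_exists(missing, "3B"));
    }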
}
@@ -0,0 +1,180 @@
//! Whisper transcription using whisper-rs.

use std::path::Path;
use thiserror::Error;
use tracing::{debug, info};
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};

#[derive(Error, Debug)]
pub enum TranscriberError {
    #[error("Failed to load Whisper model: {0}")]
    ModelLoadError(String),
    #[error("Transcription failed: {0}")]
    TranscriptionError(String),
    #[error("Model not loaded")]
    ModelNotLoaded,
}

/// A transcription segment with timing information.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TranscriptSegment {
    pub start: f64,
    pub end: f64,
    pub text: String,
    pub speaker: String,
}

/// Whisper transcriber using whisper-rs bindings.
pub struct WhisperTranscriber {
    context: Option<WhisperContext>,
    model_path: Option<String>,
}

impl WhisperTranscriber {
    /// Create a new transcriber instance (without loading a model).
    pub fn new() -> Self {
        Self {
            context: None,
            model_path: None,
        }
    }

    /// Load a Whisper model from the given path.
    pub fn load_model(&mut self, model_path: &Path) -> Result<(), TranscriberError> {
        info!("Loading Whisper model from: {}", model_path.display());

        let params = WhisperContextParameters::default();

        let context = WhisperContext::new_with_params(
            model_path.to_str().unwrap_or_default(),
            params,
        )
        .map_err(|e| TranscriberError::ModelLoadError(e.to_string()))?;

        self.context = Some(context);
        self.model_path = Some(model_path.to_string_lossy().to_string());

        info!("Whisper model loaded successfully");
        Ok(())
    }

    /// Check if a model is loaded.
    pub fn is_loaded(&self) -> bool {
        self.context.is_some()
    }

    /// Transcribe audio samples (expected to be 16kHz mono f32).
    pub fn transcribe(&self, audio: &[f32]) -> Result<Vec<TranscriptSegment>, TranscriberError> {
        let context = self
            .context
            .as_ref()
            .ok_or(TranscriberError::ModelNotLoaded)?;

        debug!("Transcribing {} samples ({:.2}s of audio)",
            audio.len(),
            audio.len() as f64 / 16000.0);

        // Create a state for this transcription
        let mut state = context
            .create_state()
            .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;

        // Set up transcription parameters
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        // Configure for English transcription
        params.set_language(Some("en"));
        params.set_print_special(false);
        params.set_print_progress(false);
        params.set_print_realtime(false);
        params.set_print_timestamps(false);
        params.set_token_timestamps(true);
        params.set_translate(false);

        // Run transcription
        state
            .full(params, audio)
            .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;

        // Extract segments
        let num_segments = state.full_n_segments()
            .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;

        let mut segments = Vec::new();

        for i in 0..num_segments {
            let start = state
                .full_get_segment_t0(i)
                .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;
            let end = state
                .full_get_segment_t1(i)
                .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;
            let text = state
                .full_get_segment_text(i)
                .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;

            // Convert centiseconds to seconds
            let start_secs = start as f64 / 100.0;
            let end_secs = end as f64 / 100.0;

            segments.push(TranscriptSegment {
                start: start_secs,
                end: end_secs,
                text: text.trim().to_string(),
                speaker: "Unknown".to_string(), // Will be filled by VAD
            });
        }

        info!("Transcription complete: {} segments", segments.len());
        Ok(segments)
    }
}

impl Default for WhisperTranscriber {
    fn default() -> Self {
        Self::new()
    }
}

/// Download the Whisper model if not present.
/// Returns the path to the model file.
pub async fn ensure_whisper_model(models_dir: &Path, model_size: &str) -> Result<std::path::PathBuf, TranscriberError> {
    let model_filename = format!("ggml-{}.bin", model_size);
    let model_path = models_dir.join(&model_filename);

    if model_path.exists() {
        info!("Whisper model already exists: {}", model_path.display());
        return Ok(model_path);
    }

    // Model URLs from Hugging Face
    let model_url = format!(
        "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/{}",
        model_filename
    );

    info!("Downloading Whisper model from: {}", model_url);

    // Create models directory if it doesn't exist
    std::fs::create_dir_all(models_dir)
        .map_err(|e| TranscriberError::ModelLoadError(format!("Failed to create models directory: {}", e)))?;

    // Download the model using reqwest or similar
    // For now, we'll return an error indicating manual download is needed
    Err(TranscriberError::ModelLoadError(format!(
        "Model not found. Please download {} to {}",
        model_filename,
        model_path.display()
    )))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transcriber_not_loaded() {
        let transcriber = WhisperTranscriber::new();
        assert!(!transcriber.is_loaded());
    }
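
    // An added sketch of a test (not in the original commit): since
    // TranscriptSegment derives serde::Serialize, it can cross the Tauri IPC
    // boundary as JSON; serde_json is already a crate dependency.
    #[test]
    fn test_segment_serializes_to_json() {
        let seg = TranscriptSegment {
            start: 0.0,
            end: 1.5,
            text: "hello".to_string(),
            speaker: "Speaker 1".to_string(),
        };
        let json = serde_json::to_string(&seg).expect("segment should serialize");
        assert!(json.contains("\"speaker\":\"Speaker 1\""));
    }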
}
@@ -0,0 +1,215 @@
//! Voice Activity Detection for basic speaker separation.
//!
//! This module uses the Silero VAD model to detect speech segments
//! and provides basic speaker separation based on silence gaps.

use thiserror::Error;
use tracing::{debug, info};
use voice_activity_detector::VoiceActivityDetector;

use super::transcriber::TranscriptSegment;

/// Sample rate expected by the VAD model (16kHz)
pub const VAD_SAMPLE_RATE: u32 = 16000;

#[derive(Error, Debug)]
pub enum VadError {
    #[error("Failed to create VAD: {0}")]
    InitError(String),
    #[error("VAD processing failed: {0}")]
    ProcessingError(String),
}

/// A detected speech segment with timing information.
#[derive(Debug, Clone)]
pub struct SpeechSegment {
    pub start_sample: usize,
    pub end_sample: usize,
    pub speaker_id: u32,
}

impl SpeechSegment {
    /// Get the start time in seconds.
    pub fn start_seconds(&self) -> f64 {
        self.start_sample as f64 / VAD_SAMPLE_RATE as f64
    }

    /// Get the end time in seconds.
    pub fn end_seconds(&self) -> f64 {
        self.end_sample as f64 / VAD_SAMPLE_RATE as f64
    }

    /// Get the duration in seconds.
    pub fn duration_seconds(&self) -> f64 {
        self.end_seconds() - self.start_seconds()
    }
}

/// Voice Activity Detector for speaker separation.
pub struct SpeakerSeparator {
    vad: VoiceActivityDetector,
    /// Minimum silence duration (in seconds) to consider a speaker change
    min_silence_for_speaker_change: f64,
    /// Minimum speech duration (in seconds) to consider a valid segment
    min_speech_duration: f64,
}

impl SpeakerSeparator {
    /// Create a new speaker separator with default settings.
    pub fn new() -> Result<Self, VadError> {
        Self::with_settings(1.5, 0.3)
    }

    /// Create a new speaker separator with custom settings.
    ///
    /// # Arguments
    /// * `min_silence_for_speaker_change` - Minimum silence duration (seconds) to consider a speaker change
    /// * `min_speech_duration` - Minimum speech duration (seconds) to consider a valid segment
    pub fn with_settings(
        min_silence_for_speaker_change: f64,
        min_speech_duration: f64,
    ) -> Result<Self, VadError> {
        let vad = VoiceActivityDetector::builder()
            .sample_rate(VAD_SAMPLE_RATE)
            .chunk_size(512usize) // 512 samples for 16kHz
            .build()
            .map_err(|e: voice_activity_detector::Error| VadError::InitError(e.to_string()))?;

        info!("VAD initialized with sample_rate={}, min_silence={}s, min_speech={}s",
            VAD_SAMPLE_RATE, min_silence_for_speaker_change, min_speech_duration);

        Ok(Self {
            vad,
            min_silence_for_speaker_change,
            min_speech_duration,
        })
    }

    /// Detect speech segments and assign speaker IDs based on silence gaps.
    ///
    /// This is a simple heuristic: if there's a long enough silence gap,
    /// we assume a different speaker might be talking.
    pub fn detect_speakers(&mut self, audio: &[f32]) -> Result<Vec<SpeechSegment>, VadError> {
        let chunk_size = 512;
        let mut segments = Vec::new();
        let mut current_segment_start: Option<usize> = None;
        let mut last_speech_end: Option<usize> = None;
        let mut current_speaker_id = 0u32;

        let min_silence_samples = (self.min_silence_for_speaker_change * VAD_SAMPLE_RATE as f64) as usize;
        let min_speech_samples = (self.min_speech_duration * VAD_SAMPLE_RATE as f64) as usize;

        debug!("Processing {} samples for VAD", audio.len());

        for (chunk_idx, chunk) in audio.chunks(chunk_size).enumerate() {
            if chunk.len() < chunk_size {
                // Skip incomplete chunks at the end
                continue;
            }

            let sample_offset = chunk_idx * chunk_size;

            // Predict voice activity for this chunk
            let probability = self.vad.predict(chunk.iter().copied());

            let is_speech = probability > 0.5;

            if is_speech {
                if current_segment_start.is_none() {
                    // Check if we should change speaker
                    if let Some(last_end) = last_speech_end {
                        let silence_duration = sample_offset - last_end;
                        if silence_duration >= min_silence_samples {
                            current_speaker_id += 1;
                            debug!("Speaker change detected at sample {} (silence: {}ms)",
                                sample_offset, silence_duration * 1000 / VAD_SAMPLE_RATE as usize);
                        }
                    }
                    current_segment_start = Some(sample_offset);
                }
            } else if let Some(start) = current_segment_start {
                // Speech ended
                let segment_duration = sample_offset - start;

                if segment_duration >= min_speech_samples {
                    segments.push(SpeechSegment {
                        start_sample: start,
                        end_sample: sample_offset,
                        speaker_id: current_speaker_id,
                    });
                }

                last_speech_end = Some(sample_offset);
                current_segment_start = None;
            }
        }

        // Handle case where speech continues to the end
        if let Some(start) = current_segment_start {
            let segment_duration = audio.len() - start;
            if segment_duration >= min_speech_samples {
                segments.push(SpeechSegment {
                    start_sample: start,
                    end_sample: audio.len(),
                    speaker_id: current_speaker_id,
                });
            }
        }

        info!("Detected {} speech segments with {} speakers",
            segments.len(),
            segments.iter().map(|s| s.speaker_id).max().unwrap_or(0) + 1);

        Ok(segments)
    }

    /// Apply speaker labels to transcript segments based on VAD results.
    pub fn apply_speaker_labels(
        &mut self,
        audio: &[f32],
        mut transcript: Vec<TranscriptSegment>,
    ) -> Result<Vec<TranscriptSegment>, VadError> {
        let speech_segments = self.detect_speakers(audio)?;

        for segment in &mut transcript {
            // Find the speech segment that overlaps with this transcript segment
            let segment_mid = (segment.start + segment.end) / 2.0;

            for speech in &speech_segments {
                let speech_start = speech.start_seconds();
                let speech_end = speech.end_seconds();

                if segment_mid >= speech_start && segment_mid <= speech_end {
                    segment.speaker = format!("Speaker {}", speech.speaker_id + 1);
                    break;
                }
            }
        }

        Ok(transcript)
    }
}

impl Default for SpeakerSeparator {
    fn default() -> Self {
        Self::new().expect("Failed to create speaker separator")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_speech_segment_timing() {
        let segment = SpeechSegment {
            start_sample: 16000,
            end_sample: 32000,
            speaker_id: 0,
        };

        assert!((segment.start_seconds() - 1.0).abs() < 0.001);
        assert!((segment.end_seconds() - 2.0).abs() < 0.001);
        assert!((segment.duration_seconds() - 1.0).abs() < 0.001);
    }
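
    // An added sketch of a test (not in the original commit) pinning down the
    // silence-gap arithmetic used by detect_speakers(): the default 1.5 s
    // speaker-change threshold corresponds to 24,000 samples at 16 kHz.
    #[test]
    fn test_speaker_change_threshold_in_samples() {
        let min_silence_samples = (1.5 * VAD_SAMPLE_RATE as f64) as usize;
        assert_eq!(min_silence_samples, 24000);
    }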
}
@@ -32,8 +32,14 @@
    "shortDescription": "Meeting transcription and summarization tool using local AI models",
    "longDescription": "Chronara provides real-time meeting transcription with speaker diarization and AI-powered summaries, all processed locally for maximum privacy.",
    "windows": {
      "nsis": {}
      "nsis": {
        "displayLanguageSelector": false,
        "installerIcon": "icons/icon.ico"
      }
    },
    "resources": [
      "resources/"
    ],
    "icon": [
      "icons/32x32.png",
      "icons/128x128.png",