feat: display progress for finalising audio and summarising
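The transcriber and summarizer now accept a progress callback, and the Tauri commands forward each reported fraction to the frontend as an event (`transcription-progress` and `summary-progress`). A minimal standalone sketch of the callback shape used throughout, with illustrative names (`run_with_progress`, `steps`) that are not part of this commit:

```rust
// Long-running work takes an `FnMut(f64)` and reports fractions in 0.0..=1.0.
fn run_with_progress<F>(steps: usize, mut progress_fn: F)
where
    F: FnMut(f64),
{
    progress_fn(0.0); // report before any work starts
    for i in 0..steps {
        // ... one unit of work ...
        progress_fn((i + 1) as f64 / steps as f64);
    }
}

fn main() {
    run_with_progress(4, |p| println!("progress: {:.0}%", p * 100.0));
}
```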
```diff
@@ -371,6 +371,7 @@ async fn stop_recording_batch(
     // Transcribe the audio
     emit_log(&app_handle, &logs, "[Transcribe] Starting transcription...");
 
+    let app_handle_clone = app_handle.clone();
     let mut segments = {
         let transcriber = state.transcriber.lock();
         if !transcriber.is_loaded() {
@@ -378,7 +379,10 @@ async fn stop_recording_batch(
             return Err("Whisper model not loaded. Please ensure the model is downloaded.".to_string());
         }
 
-        transcriber.transcribe(&audio_samples)
+        transcriber.transcribe_with_progress(&audio_samples, move |progress| {
+            // Emit progress event to frontend
+            let _ = app_handle_clone.emit("transcription-progress", progress);
+        })
             .map_err(|e| format!("Transcription failed: {}", e))?
     };
 
```
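The `move` closure takes ownership of everything it captures, so the handle is cloned first and the original `app_handle` stays usable for later calls such as `emit_log`. A standalone sketch of that pattern, with a stand-in `Handle` type rather than Tauri's:

```rust
#[derive(Clone)]
struct Handle(String);

fn main() {
    let handle = Handle("app".into());
    let handle_clone = handle.clone();

    // The closure owns `handle_clone`; the original binding is untouched.
    let callback = move |progress: f64| {
        println!("{}: {:.0}%", handle_clone.0, progress * 100.0);
    };
    callback(0.5);

    // Still valid: only the clone was moved into the closure.
    println!("still have {}", handle.0);
}
```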
```diff
@@ -399,6 +403,7 @@ async fn stop_recording_batch(
 async fn transcribe_chunk(
     state: State<'_, AppState>,
     audio_data: Vec<f32>,
+    app_handle: tauri::AppHandle,
 ) -> Result<Vec<TranscriptSegment>, String> {
     let transcriber = state.transcriber.lock();
 
@@ -406,7 +411,13 @@ async fn transcribe_chunk(
         return Err("Whisper model not loaded".to_string());
    }
 
-    let segments = transcriber.transcribe(&audio_data)
+    // Clone the app handle for the closure
+    let app_handle_clone = app_handle.clone();
+
+    let segments = transcriber.transcribe_with_progress(&audio_data, move |progress| {
+        // Emit progress event to frontend
+        let _ = app_handle_clone.emit("transcription-progress", progress);
+    })
         .map_err(|e| format!("Transcription failed: {}", e))?;
 
     Ok(segments)
```
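`transcribe_chunk` gains a `tauri::AppHandle` parameter so the command can emit events; Tauri injects the handle when it appears in a command signature. A hedged sketch of that shape, assuming Tauri v2 (where `emit` comes from the `tauri::Emitter` trait); the `do_work` command and `work-progress` event are illustrative, not from this commit:

```rust
use tauri::{AppHandle, Emitter};

// Tauri supplies `app_handle` automatically; callers do not pass it.
#[tauri::command]
async fn do_work(app_handle: AppHandle) -> Result<(), String> {
    for step in 0..=10 {
        // Emit errors are ignored: a missing listener should not fail the work.
        let _ = app_handle.emit("work-progress", step as f64 / 10.0);
    }
    Ok(())
}
```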
```diff
@@ -460,7 +471,11 @@ async fn summarize(
         return Err("LLaMA model not loaded".to_string());
     }
 
-    let summary = summarizer.summarize(&transcript)
+    let app_handle_clone = app_handle.clone();
+    let summary = summarizer.summarize_with_progress(&transcript, move |progress| {
+        // Emit progress event to frontend
+        let _ = app_handle_clone.emit("summary-progress", progress);
+    })
         .map_err(|e| format!("Summarization failed: {}", e))?;
 
     emit_log(&app_handle, &logs, &format!("[Summary] Generated {} character summary", summary.len()));
```
```diff
@@ -71,6 +71,14 @@ impl LlamaSummarizer {
 
     /// Generate a meeting summary from the transcript.
     pub fn summarize(&self, transcript: &str) -> Result<String, SummarizerError> {
+        self.summarize_with_progress(transcript, |_| {})
+    }
+
+    /// Generate a meeting summary with progress callback.
+    pub fn summarize_with_progress<F>(&self, transcript: &str, mut progress_fn: F) -> Result<String, SummarizerError>
+    where
+        F: FnMut(f64),
+    {
         let model = self
             .model
             .as_ref()
```
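`summarize` keeps its original signature by delegating to the new method with a no-op callback, so existing callers compile unchanged; `transcribe` below uses the same trick. A standalone sketch of the delegation pattern, with illustrative `Engine`/`process` names:

```rust
struct Engine;

impl Engine {
    // Old entry point: unchanged signature, progress discarded via `|_| {}`.
    fn process(&self, input: &str) -> String {
        self.process_with_progress(input, |_| {})
    }

    fn process_with_progress<F>(&self, input: &str, mut progress_fn: F) -> String
    where
        F: FnMut(f64),
    {
        progress_fn(0.0);
        let out = input.to_uppercase(); // stand-in for the real work
        progress_fn(1.0);
        out
    }
}

fn main() {
    let engine = Engine;
    assert_eq!(engine.process("hi"), "HI");
    engine.process_with_progress("hi", |p| println!("progress: {p}"));
}
```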
```diff
@@ -94,6 +102,9 @@ impl LlamaSummarizer {
 
         debug!("Generating summary for {} character transcript", transcript.len());
 
+        // Report initial progress
+        progress_fn(0.0);
+
         // Create context for inference
         let ctx_params = LlamaContextParams::default()
             .with_n_ctx(NonZeroU32::new(8192));
@@ -102,6 +113,9 @@ impl LlamaSummarizer {
             .new_context(self.backend.as_ref().unwrap(), ctx_params)
             .map_err(|e| SummarizerError::ContextError(e.to_string()))?;
 
+        // Report progress after context creation
+        progress_fn(0.1);
+
         // Tokenize the prompt
         let tokens = model
             .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
@@ -109,6 +123,9 @@ impl LlamaSummarizer {
 
         debug!("Prompt tokenized to {} tokens", tokens.len());
 
+        // Report progress after tokenization
+        progress_fn(0.2);
+
         // Create batch and add tokens
         let mut batch = LlamaBatch::new(tokens.len(), 1);
 
@@ -122,6 +139,9 @@ impl LlamaSummarizer {
         ctx.decode(&mut batch)
             .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
 
+        // Report progress after prompt decode
+        progress_fn(0.3);
+
         // Generate response tokens
         let mut output_tokens = Vec::new();
         let max_tokens = 1024;
@@ -134,7 +154,7 @@ impl LlamaSummarizer {
             LlamaSampler::dist(42),
         ]);
 
-        for _ in 0..max_tokens {
+        for i in 0..max_tokens {
             // Sample the next token using the sampler
             // The sampler.sample() takes context and the index of the last token in the batch
             let token = sampler.sample(&ctx, (batch.n_tokens() - 1) as i32);
@@ -146,6 +166,10 @@ impl LlamaSummarizer {
 
             output_tokens.push(token);
 
+            // Report progress during generation (from 30% to 90%)
+            let generation_progress = i as f64 / max_tokens as f64;
+            progress_fn(0.3 + (generation_progress * 0.6));
+
             // Prepare next batch
             batch.clear();
             batch.add(token, n_cur as i32, &[0], true)
```
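Generation reports `0.3 + fraction * 0.6`, so the token loop fills the 30% to 90% band while the fixed milestones (0.0, 0.1, 0.2, 0.3, 1.0) cover setup and completion. A small sketch of the band mapping; `scale_progress` is an illustrative helper, not in the commit:

```rust
// Map a phase-local fraction (0.0..=1.0) into an overall progress window.
fn scale_progress(start: f64, end: f64, fraction: f64) -> f64 {
    start + fraction * (end - start)
}

fn main() {
    let max_tokens = 1024;
    for i in [0, 256, 512, 1024] {
        let fraction = i as f64 / max_tokens as f64;
        // Prints 0.30, 0.45, 0.60, 0.90 for the sampled token counts.
        println!("token {i:4}: overall {:.2}", scale_progress(0.3, 0.9, fraction));
    }
}
```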
```diff
@@ -164,6 +188,9 @@ impl LlamaSummarizer {
             .filter_map(|t| model.token_to_str(*t, llama_cpp_2::model::Special::Tokenize).ok())
             .collect::<String>();
 
+        // Report completion
+        progress_fn(1.0);
+
         info!("Summary generated: {} characters", output_text.len());
         Ok(output_text.trim().to_string())
     }
```
```diff
@@ -65,6 +65,14 @@ impl WhisperTranscriber {
 
     /// Transcribe audio samples (expected to be 16kHz mono f32).
     pub fn transcribe(&self, audio: &[f32]) -> Result<Vec<TranscriptSegment>, TranscriberError> {
+        self.transcribe_with_progress(audio, |_| {})
+    }
+
+    /// Transcribe audio samples with progress callback.
+    pub fn transcribe_with_progress<F>(&self, audio: &[f32], mut progress_fn: F) -> Result<Vec<TranscriptSegment>, TranscriberError>
+    where
+        F: FnMut(f64),
+    {
         let context = self
             .context
             .as_ref()
@@ -74,11 +82,17 @@ impl WhisperTranscriber {
             audio.len(),
             audio.len() as f64 / 16000.0);
 
+        // Report initial progress
+        progress_fn(0.0);
+
         // Create a state for this transcription
         let mut state = context
             .create_state()
             .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;
 
+        // Report progress after state creation
+        progress_fn(0.2);
+
         // Set up transcription parameters
         let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
 
@@ -96,6 +110,9 @@ impl WhisperTranscriber {
             .full(params, audio)
             .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;
 
+        // Report progress after transcription
+        progress_fn(0.8);
+
         // Extract segments
         let num_segments = state.full_n_segments()
             .map_err(|e| TranscriberError::TranscriptionError(e.to_string()))?;
@@ -125,6 +142,9 @@ impl WhisperTranscriber {
             });
         }
 
+        // Report completion
+        progress_fn(1.0);
+
         info!("Transcription complete: {} segments", segments.len());
         Ok(segments)
     }
```
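Any part of the app can subscribe to these events; the frontend would typically listen for `transcription-progress` and `summary-progress` to drive a progress bar. A hedged Rust-side sketch, assuming Tauri v2's `tauri::Listener` trait and that an emitted `f64` arrives as its JSON string (e.g. "0.35"):

```rust
use tauri::Listener;

fn main() {
    tauri::Builder::default()
        .setup(|app| {
            app.listen("transcription-progress", |event| {
                // Payload is the JSON encoding of the emitted value.
                if let Ok(progress) = event.payload().parse::<f64>() {
                    println!("transcription at {:.0}%", progress * 100.0);
                }
            });
            Ok(())
        })
        .run(tauri::generate_context!())
        .expect("error while running tauri application");
}
```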