feat: display progress for finalising audio and summarising
@@ -71,6 +71,14 @@ impl LlamaSummarizer {
     /// Generate a meeting summary from the transcript.
     pub fn summarize(&self, transcript: &str) -> Result<String, SummarizerError> {
+        self.summarize_with_progress(transcript, |_| {})
+    }
+
+    /// Generate a meeting summary with progress callback.
+    pub fn summarize_with_progress<F>(&self, transcript: &str, mut progress_fn: F) -> Result<String, SummarizerError>
+    where
+        F: FnMut(f64),
+    {
         let model = self
             .model
             .as_ref()
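
A note on the new API shape: summarize keeps its old signature and simply delegates to summarize_with_progress with the no-op closure |_| {}, so existing callers are unaffected. The F: FnMut(f64) bound matters because a progress callback typically mutates captured state between calls. Below is a standalone sketch of that contract; it is illustration only, not code from this commit, and the step values merely mirror the checkpoints reported in the hunks that follow.

    // Standalone sketch of the FnMut(f64) contract; not part of this commit.
    fn report_steps<F: FnMut(f64)>(mut progress_fn: F) {
        for step in [0.0, 0.1, 0.2, 0.3, 1.0] {
            progress_fn(step);
        }
    }

    fn main() {
        // The closure mutates `last_pct`, which is why FnMut (not Fn) is required.
        let mut last_pct = -1;
        report_steps(|p| {
            let pct = (p * 100.0) as i32;
            if pct != last_pct {
                println!("progress: {pct}%");
                last_pct = pct;
            }
        });
    }
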
@@ -94,6 +102,9 @@ impl LlamaSummarizer {
         debug!("Generating summary for {} character transcript", transcript.len());
 
+        // Report initial progress
+        progress_fn(0.0);
+
         // Create context for inference
         let ctx_params = LlamaContextParams::default()
             .with_n_ctx(NonZeroU32::new(8192));
@@ -102,6 +113,9 @@ impl LlamaSummarizer {
             .new_context(self.backend.as_ref().unwrap(), ctx_params)
             .map_err(|e| SummarizerError::ContextError(e.to_string()))?;
 
+        // Report progress after context creation
+        progress_fn(0.1);
+
         // Tokenize the prompt
         let tokens = model
             .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
@@ -109,6 +123,9 @@ impl LlamaSummarizer {
         debug!("Prompt tokenized to {} tokens", tokens.len());
 
+        // Report progress after tokenization
+        progress_fn(0.2);
+
         // Create batch and add tokens
         let mut batch = LlamaBatch::new(tokens.len(), 1);
@@ -122,6 +139,9 @@ impl LlamaSummarizer {
         ctx.decode(&mut batch)
             .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
 
+        // Report progress after prompt decode
+        progress_fn(0.3);
+
         // Generate response tokens
         let mut output_tokens = Vec::new();
         let max_tokens = 1024;
@@ -134,7 +154,7 @@ impl LlamaSummarizer {
             LlamaSampler::dist(42),
         ]);
 
-        for _ in 0..max_tokens {
+        for i in 0..max_tokens {
             // Sample the next token using the sampler
             // The sampler.sample() takes context and the index of the last token in the batch
             let token = sampler.sample(&ctx, (batch.n_tokens() - 1) as i32);
@@ -146,6 +166,10 @@ impl LlamaSummarizer {
             output_tokens.push(token);
 
+            // Report progress during generation (from 30% to 90%)
+            let generation_progress = i as f64 / max_tokens as f64;
+            progress_fn(0.3 + (generation_progress * 0.6));
+
             // Prepare next batch
             batch.clear();
             batch.add(token, n_cur as i32, &[0], true)
@@ -164,6 +188,9 @@ impl LlamaSummarizer {
             .filter_map(|t| model.token_to_str(*t, llama_cpp_2::model::Special::Tokenize).ok())
             .collect::<String>();
 
+        // Report completion
+        progress_fn(1.0);
+
         info!("Summary generated: {} characters", output_text.len());
         Ok(output_text.trim().to_string())
     }
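
For completeness, a hedged caller-side sketch; the LlamaSummarizer constructor is outside this diff, so `summarizer` and `transcript` here are assumed to exist. The callback receives a fraction in [0.0, 1.0]: 0.0 to 0.3 covers setup, 0.3 to 0.9 covers token generation via 0.3 + (i / max_tokens) * 0.6 (so token 512 of 1024 reports 0.3 + 0.5 * 0.6 = 0.6), and the final progress_fn(1.0) guarantees the bar completes however generation ends.

    // Hypothetical caller; `summarizer` and `transcript` are assumed, not shown in this diff.
    let summary = summarizer.summarize_with_progress(transcript, |p| {
        // p is a fraction in [0.0, 1.0]; scale to a percentage for display.
        eprintln!("summarising: {:.0}%", p * 100.0);
    })?;
    println!("{summary}");
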