// chronara/src-tauri/src/ml/summarizer.rs (Rust, 229 lines, 7.6 KiB)
//! LLaMA summarization using llama-cpp-2.
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::sampling::LlamaSampler;
use std::num::NonZeroU32;
use std::path::Path;
use thiserror::Error;
use tracing::{debug, info};
/// Errors produced by [`LlamaSummarizer`] operations.
#[derive(Error, Debug)]
pub enum SummarizerError {
    /// The llama.cpp backend failed to initialize, or was missing when needed.
    #[error("Failed to initialize LLaMA backend: {0}")]
    BackendError(String),
    /// The GGUF model file could not be loaded from disk.
    #[error("Failed to load LLaMA model: {0}")]
    ModelLoadError(String),
    /// An inference context could not be created from the loaded model.
    #[error("Failed to create context: {0}")]
    ContextError(String),
    /// Tokenization, batch construction, or decoding failed during inference.
    #[error("Summarization failed: {0}")]
    SummarizationError(String),
    /// `summarize*` was called before `load_model` succeeded.
    #[error("Model not loaded")]
    ModelNotLoaded,
}
/// LLaMA summarizer using llama-cpp-2 bindings.
pub struct LlamaSummarizer {
    // Backend handle; `Some` after a successful `new()`. Stored in an Option
    // so the struct can exist (e.g. via partial teardown) without a backend.
    backend: Option<LlamaBackend>,
    // Loaded model weights; `None` until `load_model` succeeds.
    model: Option<LlamaModel>,
    // Lossy-UTF-8 path of the loaded model. NOTE(review): written by
    // `load_model` but never read within this file — presumably kept for
    // diagnostics; confirm against callers before removing.
    model_path: Option<String>,
}
impl LlamaSummarizer {
    /// Create a new summarizer instance (without loading a model).
    ///
    /// Initializes the llama.cpp backend eagerly so later calls can borrow
    /// it; no model weights are read here.
    ///
    /// # Errors
    /// Returns [`SummarizerError::BackendError`] if backend initialization
    /// fails.
    pub fn new() -> Result<Self, SummarizerError> {
        let backend = LlamaBackend::init()
            .map_err(|e| SummarizerError::BackendError(e.to_string()))?;
        Ok(Self {
            backend: Some(backend),
            model: None,
            model_path: None,
        })
    }

    /// Load a LLaMA model from the given path.
    ///
    /// Replaces any previously loaded model and records the path (lossy
    /// UTF-8) in `model_path`.
    ///
    /// # Errors
    /// Returns [`SummarizerError::BackendError`] if the backend is not
    /// initialized, or [`SummarizerError::ModelLoadError`] if the model file
    /// cannot be loaded.
    pub fn load_model(&mut self, model_path: &Path) -> Result<(), SummarizerError> {
        info!("Loading LLaMA model from: {}", model_path.display());
        let backend = self
            .backend
            .as_ref()
            .ok_or(SummarizerError::BackendError("Backend not initialized".to_string()))?;
        let model_params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(backend, model_path, &model_params)
            .map_err(|e| SummarizerError::ModelLoadError(e.to_string()))?;
        self.model = Some(model);
        self.model_path = Some(model_path.to_string_lossy().to_string());
        info!("LLaMA model loaded successfully");
        Ok(())
    }

    /// Check if a model is loaded.
    pub fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    /// Generate a meeting summary from the transcript.
    ///
    /// Convenience wrapper around [`Self::summarize_with_progress`] that
    /// discards progress updates.
    ///
    /// # Errors
    /// See [`Self::summarize_with_progress`].
    pub fn summarize(&self, transcript: &str) -> Result<String, SummarizerError> {
        self.summarize_with_progress(transcript, |_| {})
    }

    /// Generate a meeting summary with progress callback.
    ///
    /// `progress_fn` receives values in `[0.0, 1.0]`: 0.0 at start, 0.1
    /// after context creation, 0.2 after tokenization, 0.3 after the prompt
    /// is decoded, 0.3–0.9 while tokens are generated, 1.0 on completion.
    ///
    /// # Errors
    /// Returns [`SummarizerError::ModelNotLoaded`] if no model is loaded,
    /// [`SummarizerError::BackendError`] if the backend is missing,
    /// [`SummarizerError::ContextError`] if the inference context cannot be
    /// created, and [`SummarizerError::SummarizationError`] for
    /// tokenization/batching/decoding failures.
    pub fn summarize_with_progress<F>(&self, transcript: &str, mut progress_fn: F) -> Result<String, SummarizerError>
    where
        F: FnMut(f64),
    {
        let model = self
            .model
            .as_ref()
            .ok_or(SummarizerError::ModelNotLoaded)?;
        // Hand-built Llama 3 instruct chat template (system + user turns,
        // ending with an open assistant header so generation continues it).
        let prompt = format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n\
            You are a helpful assistant that creates structured meeting summaries. \
            Format your response using the following template:\n\n\
            **Title:** A concise, descriptive title for the meeting (5-10 words max).\n\n\
            **Summary:** A high level overview of the meeting.\n\n\
            **Key decisions:** Any important resolutions that the meeting reached.\n\n\
            **Action Items:**\n\
            - list of things that must be addressed\n\n\
            **Discussion points:**\n\
            - a list of each topic/argument/counterpoint brought up in the meeting.\n\n\
            Be concise but comprehensive. Focus on capturing all important information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n\
            Please summarize this meeting transcript:\n\n\
            {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            transcript
        );
        debug!("Generating summary for {} character transcript", transcript.len());
        // Report initial progress
        progress_fn(0.0);
        // Create an inference context with an 8K-token window.
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(8192));
        // BUGFIX: this previously called `self.backend.as_ref().unwrap()`,
        // panicking instead of returning an error when the backend was
        // absent; now handled like in `load_model`.
        let backend = self
            .backend
            .as_ref()
            .ok_or(SummarizerError::BackendError("Backend not initialized".to_string()))?;
        let mut ctx = model
            .new_context(backend, ctx_params)
            .map_err(|e| SummarizerError::ContextError(e.to_string()))?;
        // Report progress after context creation
        progress_fn(0.1);
        // Tokenize the prompt (BOS token prepended).
        let tokens = model
            .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        debug!("Prompt tokenized to {} tokens", tokens.len());
        // Report progress after tokenization
        progress_fn(0.2);
        // Batch the whole prompt; only the final token requests logits.
        let mut batch = LlamaBatch::new(tokens.len(), 1);
        for (i, token) in tokens.iter().enumerate() {
            let is_last = i == tokens.len() - 1;
            batch.add(*token, i as i32, &[0], is_last)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }
        // Decode the prompt
        ctx.decode(&mut batch)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        // Report progress after prompt decode
        progress_fn(0.3);
        // Autoregressive generation: temperature + nucleus sampling with a
        // fixed seed (42) so summaries are reproducible for the same input.
        let max_tokens = 1024;
        let mut output_tokens = Vec::with_capacity(max_tokens);
        let mut n_cur = tokens.len();
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::temp(0.7),
            LlamaSampler::top_p(0.9, 1),
            LlamaSampler::dist(42),
        ]);
        for i in 0..max_tokens {
            // Sample from the logits of the last token in the current batch.
            let token = sampler.sample(&ctx, (batch.n_tokens() - 1) as i32);
            // Stop on any end-of-generation token.
            if model.is_eog_token(token) {
                break;
            }
            output_tokens.push(token);
            // Map generation onto the 30%–90% band of overall progress.
            let generation_progress = i as f64 / max_tokens as f64;
            progress_fn(0.3 + (generation_progress * 0.6));
            // Feed the sampled token back in as the next single-token batch.
            batch.clear();
            batch.add(token, n_cur as i32, &[0], true)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
            n_cur += 1;
            // Decode
            ctx.decode(&mut batch)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }
        // Detokenize. NOTE(review): tokens that fail `token_to_str` are
        // silently dropped, which can lose partial multi-byte pieces —
        // acceptable for summary text, but worth confirming.
        let output_text = output_tokens
            .iter()
            .filter_map(|t| model.token_to_str(*t, llama_cpp_2::model::Special::Tokenize).ok())
            .collect::<String>();
        // Report completion
        progress_fn(1.0);
        info!("Summary generated: {} characters", output_text.len());
        Ok(output_text.trim().to_string())
    }
}
impl Default for LlamaSummarizer {
    /// Construct a summarizer via [`LlamaSummarizer::new`].
    ///
    /// # Panics
    /// Panics if llama.cpp backend initialization fails; use
    /// `LlamaSummarizer::new()` directly when the failure must be
    /// recoverable.
    fn default() -> Self {
        Self::new().expect("Failed to create summarizer")
    }
}
/// Get the expected model filename for a given size.
///
/// E.g. `"3B"` yields `"Llama-3.2-3B-Instruct-Q4_K_M.gguf"`.
pub fn get_model_filename(model_size: &str) -> String {
    ["Llama-3.2-", model_size, "-Instruct-Q4_K_M.gguf"].concat()
}
/// Check if a LLaMA model exists at the given path.
///
/// Builds the expected filename for `model_size` and tests whether that
/// file is present under `models_dir`.
pub fn model_exists(models_dir: &Path, model_size: &str) -> bool {
    models_dir
        .join(get_model_filename(model_size))
        .exists()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_model_filename() {
        // Filename must match the published GGUF artifact naming scheme.
        let expected = "Llama-3.2-3B-Instruct-Q4_K_M.gguf";
        assert_eq!(get_model_filename("3B"), expected);
    }
}