//! LLaMA summarization using llama-cpp-2.

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::sampling::LlamaSampler;
use std::num::NonZeroU32;
use std::path::Path;
use thiserror::Error;
use tracing::{debug, info};

#[derive(Error, Debug)]
pub enum SummarizerError {
    #[error("Failed to initialize LLaMA backend: {0}")]
    BackendError(String),
    #[error("Failed to load LLaMA model: {0}")]
    ModelLoadError(String),
    #[error("Failed to create context: {0}")]
    ContextError(String),
    #[error("Summarization failed: {0}")]
    SummarizationError(String),
    #[error("Model not loaded")]
    ModelNotLoaded,
}

/// LLaMA summarizer using llama-cpp-2 bindings.
pub struct LlamaSummarizer {
    backend: Option<LlamaBackend>,
    model: Option<LlamaModel>,
    model_path: Option<String>,
}

impl LlamaSummarizer {
    /// Create a new summarizer instance (without loading a model).
    pub fn new() -> Result<Self, SummarizerError> {
        let backend = LlamaBackend::init()
            .map_err(|e| SummarizerError::BackendError(e.to_string()))?;

        Ok(Self {
            backend: Some(backend),
            model: None,
            model_path: None,
        })
    }

    /// Load a LLaMA model from the given path.
    pub fn load_model(&mut self, model_path: &Path) -> Result<(), SummarizerError> {
        info!("Loading LLaMA model from: {}", model_path.display());

        let backend = self
            .backend
            .as_ref()
            .ok_or(SummarizerError::BackendError(
                "Backend not initialized".to_string(),
            ))?;

        let model_params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(backend, model_path, &model_params)
            .map_err(|e| SummarizerError::ModelLoadError(e.to_string()))?;

        self.model = Some(model);
        self.model_path = Some(model_path.to_string_lossy().to_string());

        info!("LLaMA model loaded successfully");
        Ok(())
    }

    /// Check if a model is loaded.
    pub fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    /// Generate a meeting summary from the transcript.
    pub fn summarize(&self, transcript: &str) -> Result<String, SummarizerError> {
        self.summarize_with_progress(transcript, |_| {})
    }

    /// Generate a meeting summary with a progress callback.
    pub fn summarize_with_progress<F>(
        &self,
        transcript: &str,
        mut progress_fn: F,
    ) -> Result<String, SummarizerError>
    where
        F: FnMut(f64),
    {
        let model = self
            .model
            .as_ref()
            .ok_or(SummarizerError::ModelNotLoaded)?;

        // Llama 3 instruct chat template: a system turn describing the summary
        // structure, followed by the transcript as the user turn.
        let prompt = format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n\
            You are a helpful assistant that creates structured meeting summaries. \
            Format your response using the following template:\n\n\
            **Title:** A concise, descriptive title for the meeting (5-10 words max).\n\n\
            **Summary:** A high-level overview of the meeting.\n\n\
            **Key Decisions:** Any important resolutions that the meeting reached.\n\n\
            **Action Items:**\n\
            - a list of things that must be addressed\n\n\
            **Discussion Points:**\n\
            - a list of each topic/argument/counterpoint brought up in the meeting\n\n\
            Be concise but comprehensive. \
            Focus on capturing all important information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n\
            Please summarize this meeting transcript:\n\n\
            {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            transcript
        );

        debug!("Generating summary for {}-character transcript", transcript.len());

        // Report initial progress
        progress_fn(0.0);

        // Create context for inference
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(8192));
        let mut ctx = model
            .new_context(self.backend.as_ref().unwrap(), ctx_params)
            .map_err(|e| SummarizerError::ContextError(e.to_string()))?;

        // Report progress after context creation
        progress_fn(0.1);

        // Tokenize the prompt
        let tokens = model
            .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

        debug!("Prompt tokenized to {} tokens", tokens.len());

        // Report progress after tokenization
        progress_fn(0.2);

        // Create batch and add tokens; only the last prompt token needs logits
        let mut batch = LlamaBatch::new(tokens.len(), 1);
        for (i, token) in tokens.iter().enumerate() {
            let is_last = i == tokens.len() - 1;
            batch
                .add(*token, i as i32, &[0], is_last)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Decode the prompt
        ctx.decode(&mut batch)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

        // Report progress after prompt decode
        progress_fn(0.3);

        // Generate response tokens
        let mut output_tokens = Vec::new();
        let max_tokens = 1024;
        let mut n_cur = tokens.len();

        // Create sampler chain for token selection
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::temp(0.7),
            LlamaSampler::top_p(0.9, 1),
            LlamaSampler::dist(42),
        ]);

        for i in 0..max_tokens {
            // Sample the next token using the sampler.
            // sampler.sample() takes the context and the index of the last token in the batch.
            let token = sampler.sample(&ctx, batch.n_tokens() - 1);

            // Check for end of generation
            if model.is_eog_token(token) {
                break;
            }

            output_tokens.push(token);

            // Report progress during generation (from 30% to 90%)
            let generation_progress = i as f64 / max_tokens as f64;
            progress_fn(0.3 + (generation_progress * 0.6));

            // Prepare the next batch with the newly sampled token
            batch.clear();
            batch
                .add(token, n_cur as i32, &[0], true)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
            n_cur += 1;

            // Decode
            ctx.decode(&mut batch)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Convert tokens to string
        let output_text = output_tokens
            .iter()
            .filter_map(|t| {
                model
                    .token_to_str(*t, llama_cpp_2::model::Special::Tokenize)
                    .ok()
            })
            .collect::<String>();

        // Report completion
        progress_fn(1.0);

        info!("Summary generated: {} characters", output_text.len());
        Ok(output_text.trim().to_string())
    }
}

impl Default for LlamaSummarizer {
    fn default() -> Self {
        Self::new().expect("Failed to create summarizer")
    }
}

/// Get the expected model filename for a given size.
pub fn get_model_filename(model_size: &str) -> String {
    format!("Llama-3.2-{}-Instruct-Q4_K_M.gguf", model_size)
}

/// Check if a LLaMA model exists at the given path.
pub fn model_exists(models_dir: &Path, model_size: &str) -> bool {
    let model_path = models_dir.join(get_model_filename(model_size));
    model_path.exists()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_model_filename() {
        assert_eq!(
            get_model_filename("3B"),
            "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
        );
    }
}
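// A minimal end-to-end usage sketch (not part of the original module): it loads a
// GGUF model from a hypothetical `./models` directory and summarizes a tiny sample
// transcript. The directory, the "3B" size, and the sample transcript are
// assumptions for illustration; the test is ignored because it needs a local model.
#[cfg(test)]
mod usage_sketch {
    use super::*;
    use std::path::PathBuf;

    #[test]
    #[ignore = "requires a local Llama 3.2 GGUF model in ./models"]
    fn summarize_example() {
        // Hypothetical models directory; adjust to wherever the GGUF file actually lives.
        let models_dir = PathBuf::from("./models");
        let model_path = models_dir.join(get_model_filename("3B"));

        let mut summarizer = LlamaSummarizer::new().expect("backend init");
        summarizer.load_model(&model_path).expect("model load");
        assert!(summarizer.is_loaded());

        // Progress values arrive in the 0.0..=1.0 range.
        let summary = summarizer
            .summarize_with_progress(
                "Alice: we should ship the beta on Friday. Bob: agreed, I'll prepare the release notes.",
                |p| println!("progress: {:.0}%", p * 100.0),
            )
            .expect("summarization");
        assert!(!summary.is_empty());
    }
}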