//! LLaMA summarization using llama-cpp-2.
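//!
//! A minimal usage sketch (assumptions: a GGUF model file is already on disk,
//! and the `models/` directory shown below is a placeholder path, not something
//! this module manages):
//!
//! ```ignore
//! let mut summarizer = LlamaSummarizer::new()?;
//! summarizer.load_model(Path::new("models/Llama-3.2-3B-Instruct-Q4_K_M.gguf"))?;
//! let summary = summarizer.summarize("...meeting transcript text...")?;
//! println!("{summary}");
//! ```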

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::sampling::LlamaSampler;
use std::num::NonZeroU32;
use std::path::Path;
use thiserror::Error;
use tracing::{debug, info};

#[derive(Error, Debug)]
pub enum SummarizerError {
    #[error("Failed to initialize LLaMA backend: {0}")]
    BackendError(String),
    #[error("Failed to load LLaMA model: {0}")]
    ModelLoadError(String),
    #[error("Failed to create context: {0}")]
    ContextError(String),
    #[error("Summarization failed: {0}")]
    SummarizationError(String),
    #[error("Model not loaded")]
    ModelNotLoaded,
}

/// LLaMA summarizer using llama-cpp-2 bindings.
pub struct LlamaSummarizer {
    backend: Option<LlamaBackend>,
    model: Option<LlamaModel>,
    model_path: Option<String>,
}

impl LlamaSummarizer {
    /// Create a new summarizer instance (without loading a model).
    pub fn new() -> Result<Self, SummarizerError> {
        let backend = LlamaBackend::init()
            .map_err(|e| SummarizerError::BackendError(e.to_string()))?;

        Ok(Self {
            backend: Some(backend),
            model: None,
            model_path: None,
        })
    }

    /// Load a LLaMA model from the given path.
    pub fn load_model(&mut self, model_path: &Path) -> Result<(), SummarizerError> {
        info!("Loading LLaMA model from: {}", model_path.display());

        let backend = self.backend.as_ref()
            .ok_or(SummarizerError::BackendError("Backend not initialized".to_string()))?;

        let model_params = LlamaModelParams::default();

        let model = LlamaModel::load_from_file(backend, model_path, &model_params)
            .map_err(|e| SummarizerError::ModelLoadError(e.to_string()))?;

        self.model = Some(model);
        self.model_path = Some(model_path.to_string_lossy().to_string());

        info!("LLaMA model loaded successfully");
        Ok(())
    }

    /// Check if a model is loaded.
    pub fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    /// Generate a meeting summary from the transcript.
    pub fn summarize(&self, transcript: &str) -> Result<String, SummarizerError> {
        self.summarize_with_progress(transcript, |_| {})
    }

    /// Generate a meeting summary with progress callback.
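    ///
    /// A short sketch of consuming the callback: `progress_fn` receives values
    /// in the 0.0–1.0 range reported by this method. The logging call below is
    /// illustrative, not required by the API.
    ///
    /// ```ignore
    /// let summary = summarizer.summarize_with_progress(&transcript, |p| {
    ///     debug!("summarization progress: {:.0}%", p * 100.0);
    /// })?;
    /// ```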
    pub fn summarize_with_progress<F>(&self, transcript: &str, mut progress_fn: F) -> Result<String, SummarizerError>
    where
        F: FnMut(f64),
    {
        let model = self
            .model
            .as_ref()
            .ok_or(SummarizerError::ModelNotLoaded)?;

        let prompt = format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n\
            You are a helpful assistant that creates structured meeting summaries. \
            Format your response using the following template:\n\n\
            **Title:** A concise, descriptive title for the meeting (5-10 words max).\n\n\
            **Summary:** A high level overview of the meeting.\n\n\
            **Key decisions:** Any important resolutions that the meeting reached.\n\n\
            **Action Items:**\n\
            - list of things that must be addressed\n\n\
            **Discussion points:**\n\
            - a list of each topic/argument/counterpoint brought up in the meeting.\n\n\
            Be concise but comprehensive. Focus on capturing all important information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n\
            Please summarize this meeting transcript:\n\n\
            {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            transcript
        );

        debug!("Generating summary for {} character transcript", transcript.len());

        // Report initial progress
        progress_fn(0.0);

        // Create context for inference
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(8192));

        let mut ctx = model
            .new_context(self.backend.as_ref().unwrap(), ctx_params)
            .map_err(|e| SummarizerError::ContextError(e.to_string()))?;

        // Report progress after context creation
        progress_fn(0.1);

        // Tokenize the prompt
        let tokens = model
            .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

        debug!("Prompt tokenized to {} tokens", tokens.len());

        // Report progress after tokenization
        progress_fn(0.2);

        // Create batch and add tokens
        let mut batch = LlamaBatch::new(tokens.len(), 1);

        for (i, token) in tokens.iter().enumerate() {
            let is_last = i == tokens.len() - 1;
            batch.add(*token, i as i32, &[0], is_last)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Decode the prompt
        ctx.decode(&mut batch)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

        // Report progress after prompt decode
        progress_fn(0.3);

        // Generate response tokens
        let mut output_tokens = Vec::new();
        let max_tokens = 1024;
        let mut n_cur = tokens.len();

        // Create sampler for token selection
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::temp(0.7),
            LlamaSampler::top_p(0.9, 1),
            LlamaSampler::dist(42),
        ]);

        for i in 0..max_tokens {
            // Sample the next token; sample() takes the context and the index
            // of the last token in the batch (the one with logits enabled)
            let token = sampler.sample(&ctx, (batch.n_tokens() - 1) as i32);

            // Check for end of generation
            if model.is_eog_token(token) {
                break;
            }

            output_tokens.push(token);

            // Report progress during generation (from 30% to 90%)
            let generation_progress = i as f64 / max_tokens as f64;
            progress_fn(0.3 + (generation_progress * 0.6));

            // Prepare next batch
            batch.clear();
            batch.add(token, n_cur as i32, &[0], true)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

            n_cur += 1;

            // Decode
            ctx.decode(&mut batch)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Convert tokens to string
        let output_text = output_tokens
            .iter()
            .filter_map(|t| model.token_to_str(*t, llama_cpp_2::model::Special::Tokenize).ok())
            .collect::<String>();

        // Report completion
        progress_fn(1.0);

        info!("Summary generated: {} characters", output_text.len());
        Ok(output_text.trim().to_string())
    }
}

impl Default for LlamaSummarizer {
    fn default() -> Self {
        Self::new().expect("Failed to create summarizer")
    }
}

/// Get the expected model filename for a given size.
pub fn get_model_filename(model_size: &str) -> String {
    format!("Llama-3.2-{}-Instruct-Q4_K_M.gguf", model_size)
}

/// Check if a LLaMA model exists at the given path.
pub fn model_exists(models_dir: &Path, model_size: &str) -> bool {
    let model_path = models_dir.join(get_model_filename(model_size));
    model_path.exists()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_model_filename() {
        assert_eq!(
            get_model_filename("3B"),
            "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
        );
    }
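
    // A small negative-path check, added as an illustrative sketch: a directory
    // that does not exist cannot contain the expected GGUF file, so model_exists
    // should return false. The directory name below is a deliberately
    // nonexistent placeholder.
    #[test]
    fn test_model_exists_returns_false_for_missing_dir() {
        let missing_dir = std::path::Path::new("nonexistent-models-dir");
        assert!(!model_exists(missing_dir, "3B"));
    }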
}