feat: the installer now works on Windows!
Models are downloaded at runtime instead of at build time.
@@ -0,0 +1,198 @@
//! LLaMA summarization using llama-cpp-2.

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::sampling::LlamaSampler;
use std::num::NonZeroU32;
use std::path::Path;
use thiserror::Error;
use tracing::{debug, info};

#[derive(Error, Debug)]
pub enum SummarizerError {
    #[error("Failed to initialize LLaMA backend: {0}")]
    BackendError(String),
    #[error("Failed to load LLaMA model: {0}")]
    ModelLoadError(String),
    #[error("Failed to create context: {0}")]
    ContextError(String),
    #[error("Summarization failed: {0}")]
    SummarizationError(String),
    #[error("Model not loaded")]
    ModelNotLoaded,
}

/// LLaMA summarizer using llama-cpp-2 bindings.
pub struct LlamaSummarizer {
    backend: Option<LlamaBackend>,
    model: Option<LlamaModel>,
    model_path: Option<String>,
}

impl LlamaSummarizer {
    /// Create a new summarizer instance (without loading a model).
    pub fn new() -> Result<Self, SummarizerError> {
        let backend = LlamaBackend::init()
            .map_err(|e| SummarizerError::BackendError(e.to_string()))?;

        Ok(Self {
            backend: Some(backend),
            model: None,
            model_path: None,
        })
    }

    /// Load a LLaMA model from the given path.
    pub fn load_model(&mut self, model_path: &Path) -> Result<(), SummarizerError> {
        info!("Loading LLaMA model from: {}", model_path.display());

        let backend = self.backend.as_ref()
            .ok_or(SummarizerError::BackendError("Backend not initialized".to_string()))?;

        let model_params = LlamaModelParams::default();

        let model = LlamaModel::load_from_file(backend, model_path, &model_params)
            .map_err(|e| SummarizerError::ModelLoadError(e.to_string()))?;

        self.model = Some(model);
        self.model_path = Some(model_path.to_string_lossy().to_string());

        info!("LLaMA model loaded successfully");
        Ok(())
    }

    /// Check if a model is loaded.
    pub fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    /// Generate a meeting summary from the transcript.
    pub fn summarize(&self, transcript: &str) -> Result<String, SummarizerError> {
        let model = self
            .model
            .as_ref()
            .ok_or(SummarizerError::ModelNotLoaded)?;

        let prompt = format!(
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n\
             You are a helpful assistant that creates concise meeting summaries. Focus on:\n\
             - Key decisions made\n\
             - Action items and who owns them\n\
             - Important discussions and their outcomes\n\
             - Next steps\n\n\
             Keep the summary structured and easy to scan.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n\
             Please summarize this meeting transcript:\n\n\
             {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\
             Meeting Summary:\n",
            transcript
        );
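        // The prompt above hard-codes the Llama 3 Instruct chat template
        // (<|begin_of_text|>, <|start_header_id|>role<|end_header_id|>, <|eot_id|>),
        // so formatting is only correct for Llama-3-family instruct models.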

        debug!("Generating summary for {} character transcript", transcript.len());

        // Create context for inference
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(8192));
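        // n_ctx = 8192 is the total budget for prompt plus generated tokens;
        // a transcript that tokenizes past that will not fit, and decoding
        // below will fail.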

        let mut ctx = model
            .new_context(self.backend.as_ref().unwrap(), ctx_params)
            .map_err(|e| SummarizerError::ContextError(e.to_string()))?;

        // Tokenize the prompt
        let tokens = model
            .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

        debug!("Prompt tokenized to {} tokens", tokens.len());

        // Create batch and add tokens
        let mut batch = LlamaBatch::new(tokens.len(), 1);
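        // Only the final prompt token requests logits (is_last == true):
        // that is the position the first generated token is sampled from.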

        for (i, token) in tokens.iter().enumerate() {
            let is_last = i == tokens.len() - 1;
            batch.add(*token, i as i32, &[0], is_last)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Decode the prompt
        ctx.decode(&mut batch)
            .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
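        // The full prompt is decoded once to populate the KV cache; the
        // generation loop below then feeds one token per step and reuses it.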

        // Generate response tokens
        let mut output_tokens = Vec::new();
        let max_tokens = 1024;
        let mut n_cur = tokens.len();

        // Create sampler for token selection
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::temp(0.7),
            LlamaSampler::top_p(0.9, 1),
            LlamaSampler::dist(42),
        ]);
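        // Sampler chain: temperature 0.7 softens the logits, top_p 0.9 keeps
        // the smallest token set covering 90% of the probability mass (with a
        // min_keep of 1), and dist(42) draws from what remains using a fixed
        // seed, so output is reproducible for a given model and prompt.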

        for _ in 0..max_tokens {
            // Sample the next token using the sampler:
            // sample() takes the context and the index of the last token in the batch
            let token = sampler.sample(&ctx, (batch.n_tokens() - 1) as i32);

            // Check for end of generation
            if model.is_eog_token(token) {
                break;
            }

            output_tokens.push(token);

            // Prepare next batch
            batch.clear();
            batch.add(token, n_cur as i32, &[0], true)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;

            n_cur += 1;

            // Decode
            ctx.decode(&mut batch)
                .map_err(|e| SummarizerError::SummarizationError(e.to_string()))?;
        }

        // Convert tokens to string
        let output_text = output_tokens
            .iter()
            .filter_map(|t| model.token_to_str(*t, llama_cpp_2::model::Special::Tokenize).ok())
            .collect::<String>();

        info!("Summary generated: {} characters", output_text.len());
        Ok(output_text.trim().to_string())
    }
}
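
// Example usage (illustrative; the model path is hypothetical):
//
//     let mut summarizer = LlamaSummarizer::new()?;
//     summarizer.load_model(Path::new("models/Llama-3.2-3B-Instruct-Q4_K_M.gguf"))?;
//     let summary = summarizer.summarize(&transcript)?;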

impl Default for LlamaSummarizer {
    fn default() -> Self {
        Self::new().expect("Failed to create summarizer")
    }
}

/// Get the expected model filename for a given size.
pub fn get_model_filename(model_size: &str) -> String {
    format!("Llama-3.2-{}-Instruct-Q4_K_M.gguf", model_size)
}

/// Check if a LLaMA model exists at the given path.
pub fn model_exists(models_dir: &Path, model_size: &str) -> bool {
    let model_path = models_dir.join(get_model_filename(model_size));
    model_path.exists()
}
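
// Sketch of the runtime download the commit message refers to. Illustrative
// only, hence commented out: it assumes blocking `reqwest` and `anyhow`
// dependencies and a placeholder URL; the installer may fetch the GGUF
// from elsewhere.
//
//     pub fn download_model(models_dir: &Path, model_size: &str) -> anyhow::Result<()> {
//         let filename = get_model_filename(model_size);
//         let url = format!("https://example.com/models/{filename}"); // placeholder host
//         let bytes = reqwest::blocking::get(&url)?.bytes()?;
//         std::fs::create_dir_all(models_dir)?;
//         std::fs::write(models_dir.join(&filename), &bytes)?;
//         Ok(())
//     }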

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_model_filename() {
        assert_eq!(
            get_model_filename("3B"),
            "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
        );
    }
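
    #[test]
    fn test_model_exists_missing_file() {
        // Illustrative sanity check: a directory with no model file in it
        // should report the model as absent.
        let dir = Path::new("/definitely/not/a/real/models/dir");
        assert!(!model_exists(dir, "3B"));
    }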
}