// This provider connects to a local Ollama instance for LLM inference.
// It will be actively used once providers are fully integrated with BridgeManager.
#![allow(dead_code)]

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;

use crate::providers::traits::{
    LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage,
    ProviderStreamEvent, ProviderUsage, StreamCallback,
};

const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";

/// Request body for Ollama's `/api/chat` endpoint.
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
}

/// A single chat message in Ollama's wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
    role: String,
    content: String,
}

/// One NDJSON line of the streaming `/api/chat` response.
#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    #[serde(default)]
    message: Option<OllamaResponseMessage>,
    #[serde(default)]
    done: bool,
    #[serde(default)]
    eval_count: Option<u64>,
    #[serde(default)]
    prompt_eval_count: Option<u64>,
}

#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    #[serde(default)]
    content: String,
}

/// Response from `/api/tags`, listing locally installed models.
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    models: Vec<OllamaModelInfo>,
}

#[derive(Debug, Deserialize)]
struct OllamaModelInfo {
    name: String,
    #[serde(default)]
    details: Option<OllamaModelDetails>,
}

#[derive(Debug, Deserialize)]
struct OllamaModelDetails {
    #[serde(default)]
    parameter_size: Option<String>,
    #[serde(default)]
    family: Option<String>,
}

/// LLM provider backed by a locally running Ollama server.
pub struct OllamaProvider {
    config: ProviderConfig,
    client: reqwest::Client,
    base_url: String,
    is_running: Arc<AtomicBool>,
    conversation_history: Vec<OllamaMessage>,
    cancel_tx: Option<mpsc::Sender<()>>,
}

impl OllamaProvider {
    pub fn new(config: ProviderConfig) -> Self {
        let base_url = config
            .api_base_url
            .clone()
            .unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string());

        Self {
            config,
            client: reqwest::Client::new(),
            base_url,
            is_running: Arc::new(AtomicBool::new(false)),
            conversation_history: Vec::new(),
            cancel_tx: None,
        }
    }

    /// Queries `/api/tags` for installed models, falling back to a static
    /// list if the local API is unreachable or returns an unexpected payload.
    async fn fetch_available_models(&self) -> Vec<ModelInfo> {
        let url = format!("{}/api/tags", self.base_url);

        match self.client.get(&url).send().await {
            Ok(response) => {
                if let Ok(tags) = response.json::<OllamaTagsResponse>().await {
                    tags.models
                        .into_iter()
                        .map(|m| {
                            // Build a short human-readable description from the
                            // reported model family and parameter size.
                            let description = m.details.as_ref().map(|d| {
                                let mut desc_parts = Vec::new();
                                if let Some(ref family) = d.family {
                                    desc_parts.push(family.clone());
                                }
                                if let Some(ref size) = d.parameter_size {
                                    desc_parts.push(format!("{} parameters", size));
                                }
                                if desc_parts.is_empty() {
                                    "Local model".to_string()
                                } else {
                                    desc_parts.join(" - ")
                                }
                            });

                            ModelInfo {
                                id: m.name.clone(),
                                name: m.name,
                                description,
                                context_window: None,
                                input_cost_per_mtok: Some(0.0), // Local inference is free.
                                output_cost_per_mtok: Some(0.0),
                            }
                        })
                        .collect()
                } else {
                    default_ollama_models()
                }
            }
            Err(_) => default_ollama_models(),
        }
    }

    /// Streams a chat completion for the current conversation history,
    /// forwarding text deltas to `callback` until the response is done or
    /// a cancellation signal arrives on `cancel_rx`.
    async fn stream_chat(
        &mut self,
        callback: Arc<StreamCallback>,
        cancel_rx: mpsc::Receiver<()>,
    ) -> Result<(), String> {
        let model = self
            .config
            .model
            .clone()
            .unwrap_or_else(|| "llama3.2".to_string());

        let request = OllamaChatRequest {
            model: model.clone(),
            messages: self.conversation_history.clone(),
            stream: true,
            system: self.config.custom_instructions.clone(),
        };

        let url = format!("{}/api/chat", self.base_url);

        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| format!("Failed to connect to Ollama: {}", e))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(format!("Ollama error ({}): {}", status, error_text));
        }

        let mut full_response = String::new();
        let mut total_input_tokens: u64 = 0;
        let mut total_output_tokens: u64 = 0;

        let mut stream = response.bytes_stream();
        let mut cancel_rx = cancel_rx;

        use futures_util::StreamExt;

        loop {
            tokio::select! {
                _ = cancel_rx.recv() => {
                    callback(ProviderStreamEvent::Disconnected);
                    return Ok(());
                }
                chunk = stream.next() => {
                    match chunk {
                        Some(Ok(bytes)) => {
                            let text = String::from_utf8_lossy(&bytes);

                            // Ollama streams newline-delimited JSON; parse each complete
                            // line in this chunk. A JSON object split across two network
                            // chunks will fail to parse here and is silently skipped.
                            for line in text.lines() {
                                if line.trim().is_empty() {
                                    continue;
                                }

                                if let Ok(response) = serde_json::from_str::<OllamaChatResponse>(line) {
                                    if let Some(msg) = &response.message {
                                        if !msg.content.is_empty() {
                                            full_response.push_str(&msg.content);
                                            callback(ProviderStreamEvent::TextDelta {
                                                text: msg.content.clone(),
                                            });
                                        }
                                    }

                                    if let Some(tokens) = response.prompt_eval_count {
                                        total_input_tokens = tokens;
                                    }
                                    if let Some(tokens) = response.eval_count {
                                        total_output_tokens = tokens;
                                    }

                                    if response.done {
                                        self.conversation_history.push(OllamaMessage {
                                            role: "assistant".to_string(),
                                            content: full_response.clone(),
                                        });

                                        callback(ProviderStreamEvent::MessageComplete {
                                            content: vec![ProviderMessage::Text {
                                                content: full_response,
                                            }],
                                            usage: Some(ProviderUsage {
                                                input_tokens: total_input_tokens,
                                                output_tokens: total_output_tokens,
                                                model,
                                            }),
                                            stop_reason: Some("end_turn".to_string()),
                                        });

                                        return Ok(());
                                    }
                                }
                            }
                        }
                        Some(Err(e)) => {
                            return Err(format!("Stream error: {}", e));
                        }
                        None => {
                            break;
                        }
                    }
                }
            }
        }

        // The stream ended without a `done` chunk; no completion event is emitted.
        Ok(())
    }
}

#[async_trait]
impl LlmProvider for OllamaProvider {
    fn name(&self) -> &'static str {
        "Ollama"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supports_streaming: true,
            supports_tools: false, // Tool calling is not wired up in this provider.
            supports_vision: true, // Some local models accept images.
            supports_thinking: false,
            supports_mcp: false,
            supports_resume_session: false,
            max_context_tokens: None, // Varies by model.
            available_models: default_ollama_models(),
        }
    }

    async fn start(&mut self, callback: StreamCallback) -> Result<(), String> {
        if self.is_running.load(Ordering::SeqCst) {
            return Err("Provider already running".to_string());
        }

        // Check that Ollama is reachable before marking the provider as running.
        let url = format!("{}/api/tags", self.base_url);
        self.client.get(&url).send().await.map_err(|e| {
            format!(
                "Cannot connect to Ollama at {}. Is it running? Error: {}",
                self.base_url, e
            )
        })?;

        self.is_running.store(true, Ordering::SeqCst);
        self.conversation_history.clear();

        // Create a cancellation channel. The receiver is dropped here;
        // send_message creates a fresh channel for each request.
        let (tx, _rx) = mpsc::channel(1);
        self.cancel_tx = Some(tx);

        callback(ProviderStreamEvent::Connected { session_id: None });

        // Fetch and report available models.
        let models = self.fetch_available_models().await;
        if !models.is_empty() {
            eprintln!(
                "[Ollama] Available models: {:?}",
                models.iter().map(|m| &m.id).collect::<Vec<_>>()
            );
        }

        Ok(())
    }

    async fn stop(&mut self) -> Result<(), String> {
        self.is_running.store(false, Ordering::SeqCst);
        self.conversation_history.clear();

        if let Some(tx) = self.cancel_tx.take() {
            let _ = tx.send(()).await;
        }

        Ok(())
    }

    fn is_running(&self) -> bool {
        self.is_running.load(Ordering::SeqCst)
    }

    async fn send_message(&mut self, message: &str) -> Result<(), String> {
        if !self.is_running.load(Ordering::SeqCst) {
            return Err("Provider not running".to_string());
        }

        // Add the user message to the conversation history.
        self.conversation_history.push(OllamaMessage {
            role: "user".to_string(),
            content: message.to_string(),
        });

        // Create a new cancel channel for this request.
        let (tx, _rx) = mpsc::channel(1);
        self.cancel_tx = Some(tx);

        // Streaming the reply requires the callback passed to start(), which the
        // current trait design does not make available here. Until the provider
        // stores that callback, this method cannot drive stream_chat.
        Err("send_message needs refactoring to work with stored callback".to_string())
    }

    async fn send_tool_result(
        &mut self,
        _tool_use_id: &str,
        _result: serde_json::Value,
        _is_error: bool,
    ) -> Result<(), String> {
        // Tool use is not supported by this provider.
        Err("Ollama does not support tool use".to_string())
    }

    fn get_model(&self) -> Option<String> {
        self.config.model.clone()
    }

    fn get_working_directory(&self) -> &str {
        &self.config.working_directory
    }

    async fn interrupt(&mut self) -> Result<(), String> {
        if let Some(tx) = self.cancel_tx.take() {
            let _ = tx.send(()).await;
        }
        self.is_running.store(false, Ordering::SeqCst);
        Ok(())
    }
}

/// Default model list, used for `capabilities()` and as a fallback when
/// `/api/tags` cannot be queried.
fn default_ollama_models() -> Vec<ModelInfo> {
    vec![
        ModelInfo {
            id: "llama3.2".to_string(),
            name: "Llama 3.2".to_string(),
            description: Some("Meta's latest compact model".to_string()),
            context_window: Some(128_000),
            input_cost_per_mtok: Some(0.0),
            output_cost_per_mtok: Some(0.0),
        },
        ModelInfo {
            id: "llama3.2:1b".to_string(),
            name: "Llama 3.2 1B".to_string(),
            description: Some("Smallest Llama 3.2 variant".to_string()),
            context_window: Some(128_000),
            input_cost_per_mtok: Some(0.0),
            output_cost_per_mtok: Some(0.0),
        },
        ModelInfo {
            id: "qwen2.5-coder".to_string(),
            name: "Qwen 2.5 Coder".to_string(),
            description: Some("Alibaba's coding-focused model".to_string()),
            context_window: Some(32_000),
            input_cost_per_mtok: Some(0.0),
            output_cost_per_mtok: Some(0.0),
        },
        ModelInfo {
            id: "deepseek-coder-v2".to_string(),
            name: "DeepSeek Coder V2".to_string(),
            description: Some("DeepSeek's coding model".to_string()),
            context_window: Some(128_000),
            input_cost_per_mtok: Some(0.0),
            output_cost_per_mtok: Some(0.0),
        },
        ModelInfo {
            id: "mistral".to_string(),
            name: "Mistral 7B".to_string(),
            description: Some("Fast and capable".to_string()),
            context_window: Some(32_000),
            input_cost_per_mtok: Some(0.0),
            output_cost_per_mtok: Some(0.0),
        },
        ModelInfo {
            id: "gemma2".to_string(),
            name: "Gemma 2".to_string(),
            description: Some("Google's open model".to_string()),
            context_window: Some(8_000),
            input_cost_per_mtok: Some(0.0),
            output_cost_per_mtok: Some(0.0),
        },
    ]
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_name() {
        let config = ProviderConfig::default();
        let provider = OllamaProvider::new(config);
        assert_eq!(provider.name(), "Ollama");
    }

    #[test]
    fn test_capabilities() {
        let config = ProviderConfig::default();
        let provider = OllamaProvider::new(config);
        let caps = provider.capabilities();

        assert!(caps.supports_streaming);
        assert!(!caps.supports_tools);
        assert!(!caps.supports_mcp);
        assert!(!caps.supports_resume_session);
        assert!(!caps.available_models.is_empty());
    }

    #[test]
    fn test_default_url() {
        let config = ProviderConfig::default();
        let provider = OllamaProvider::new(config);
        assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL);
    }

    #[test]
    fn test_custom_url() {
        let config = ProviderConfig {
            api_base_url: Some("http://custom:8080".to_string()),
            ..Default::default()
        };
        let provider = OllamaProvider::new(config);
        assert_eq!(provider.base_url, "http://custom:8080");
    }

    #[test]
    fn test_is_running_initial() {
        let config = ProviderConfig::default();
        let provider = OllamaProvider::new(config);
        assert!(!provider.is_running());
    }

    #[test]
    fn test_default_ollama_models() {
        let models = default_ollama_models();
        assert!(!models.is_empty());

        // All models should be free (local)
        for model in &models {
            assert_eq!(model.input_cost_per_mtok, Some(0.0));
            assert_eq!(model.output_cost_per_mtok, Some(0.0));
        }

        // Should include llama
        assert!(models.iter().any(|m| m.id.contains("llama")));
    }
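
    // A minimal sketch of the provider's initial state right after construction:
    // no conversation history and no cancellation channel yet.
    #[test]
    fn test_new_provider_state() {
        let provider = OllamaProvider::new(ProviderConfig::default());
        assert!(provider.conversation_history.is_empty());
        assert!(provider.cancel_tx.is_none());
    }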

    #[test]
    fn test_ollama_message_serialization() {
        let msg = OllamaMessage {
            role: "user".to_string(),
            content: "Hello!".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello!\""));
    }
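
    // Illustrative sketch of the NDJSON lines that `stream_chat` parses: a
    // mid-stream delta and a final `done` chunk with token counts. The sample
    // payloads are hypothetical but use only fields this struct deserializes.
    #[test]
    fn test_chat_response_deserialization() {
        let delta: OllamaChatResponse =
            serde_json::from_str(r#"{"message":{"content":"Hello"},"done":false}"#).unwrap();
        assert_eq!(
            delta.message.as_ref().map(|m| m.content.as_str()),
            Some("Hello")
        );
        assert!(!delta.done);
        assert!(delta.eval_count.is_none());

        let last: OllamaChatResponse = serde_json::from_str(
            r#"{"message":{"content":""},"done":true,"eval_count":42,"prompt_eval_count":7}"#,
        )
        .unwrap();
        assert!(last.done);
        assert_eq!(last.eval_count, Some(42));
        assert_eq!(last.prompt_eval_count, Some(7));
    }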

    #[test]
    fn test_chat_request_serialization() {
        let request = OllamaChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![OllamaMessage {
                role: "user".to_string(),
                content: "Test".to_string(),
            }],
            stream: true,
            system: Some("You are helpful".to_string()),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"llama3.2\""));
        assert!(json.contains("\"stream\":true"));
        assert!(json.contains("\"system\":\"You are helpful\""));
    }
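
    // Illustrative sketch of the `/api/tags` payload shape that
    // `fetch_available_models` expects; the model entry below is hypothetical.
    #[test]
    fn test_tags_response_deserialization() {
        let json = r#"{"models":[{"name":"llama3.2","details":{"parameter_size":"3B","family":"llama"}}]}"#;
        let tags: OllamaTagsResponse = serde_json::from_str(json).unwrap();

        assert_eq!(tags.models.len(), 1);
        assert_eq!(tags.models[0].name, "llama3.2");

        let details = tags.models[0].details.as_ref().unwrap();
        assert_eq!(details.parameter_size.as_deref(), Some("3B"));
        assert_eq!(details.family.as_deref(), Some("llama"));
    }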

    #[test]
    fn test_chat_request_without_system() {
        let request = OllamaChatRequest {
            model: "llama3.2".to_string(),
            messages: vec![],
            stream: true,
            system: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        // system should be omitted when None
        assert!(!json.contains("\"system\""));
    }
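
    // Minimal usage sketch: `get_model` simply echoes the model named in the
    // config passed to `new`. The model id here is only an example.
    #[test]
    fn test_get_model_from_config() {
        let config = ProviderConfig {
            model: Some("llama3.2".to_string()),
            ..Default::default()
        };
        let provider = OllamaProvider::new(config);
        assert_eq!(provider.get_model(), Some("llama3.2".to_string()));
    }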
}