feat: support ollama

This commit is contained in:
2026-02-04 13:19:10 -08:00
parent daedbfd865
commit a0804ed32a
13 changed files with 2480 additions and 5 deletions
+537
View File
@@ -0,0 +1,537 @@
// This provider connects to a local Ollama instance for LLM inference.
// It will be actively used once providers are fully integrated with BridgeManager.
#![allow(dead_code)]
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::providers::traits::{
LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage,
ProviderStreamEvent, ProviderUsage, StreamCallback,
};
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
model: String,
messages: Vec<OllamaMessage>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
#[serde(default)]
message: Option<OllamaResponseMessage>,
#[serde(default)]
done: bool,
#[serde(default)]
eval_count: Option<u64>,
#[serde(default)]
prompt_eval_count: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
#[serde(default)]
content: String,
}
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
models: Vec<OllamaModelInfo>,
}
#[derive(Debug, Deserialize)]
struct OllamaModelInfo {
name: String,
#[serde(default)]
details: Option<OllamaModelDetails>,
}
#[derive(Debug, Deserialize)]
struct OllamaModelDetails {
#[serde(default)]
parameter_size: Option<String>,
#[serde(default)]
family: Option<String>,
}
pub struct OllamaProvider {
config: ProviderConfig,
client: reqwest::Client,
base_url: String,
is_running: Arc<AtomicBool>,
conversation_history: Vec<OllamaMessage>,
cancel_tx: Option<mpsc::Sender<()>>,
}
impl OllamaProvider {
pub fn new(config: ProviderConfig) -> Self {
let base_url = config
.api_base_url
.clone()
.unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string());
Self {
config,
client: reqwest::Client::new(),
base_url,
is_running: Arc::new(AtomicBool::new(false)),
conversation_history: Vec::new(),
cancel_tx: None,
}
}
async fn fetch_available_models(&self) -> Vec<ModelInfo> {
let url = format!("{}/api/tags", self.base_url);
match self.client.get(&url).send().await {
Ok(response) => {
if let Ok(tags) = response.json::<OllamaTagsResponse>().await {
tags.models
.into_iter()
.map(|m| {
let description = m.details.as_ref().map(|d| {
let mut desc_parts = Vec::new();
if let Some(ref family) = d.family {
desc_parts.push(family.clone());
}
if let Some(ref size) = d.parameter_size {
desc_parts.push(format!("{} parameters", size));
}
if desc_parts.is_empty() {
"Local model".to_string()
} else {
desc_parts.join(" - ")
}
});
ModelInfo {
id: m.name.clone(),
name: m.name,
description,
context_window: None,
input_cost_per_mtok: Some(0.0), // Local = free!
output_cost_per_mtok: Some(0.0),
}
})
.collect()
} else {
default_ollama_models()
}
}
Err(_) => default_ollama_models(),
}
}
async fn stream_chat(
&mut self,
callback: Arc<StreamCallback>,
cancel_rx: mpsc::Receiver<()>,
) -> Result<(), String> {
let model = self
.config
.model
.clone()
.unwrap_or_else(|| "llama3.2".to_string());
let request = OllamaChatRequest {
model: model.clone(),
messages: self.conversation_history.clone(),
stream: true,
system: self.config.custom_instructions.clone(),
};
let url = format!("{}/api/chat", self.base_url);
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| format!("Failed to connect to Ollama: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(format!("Ollama error ({}): {}", status, error_text));
}
let mut full_response = String::new();
let mut total_input_tokens: u64 = 0;
let mut total_output_tokens: u64 = 0;
let mut stream = response.bytes_stream();
let mut cancel_rx = cancel_rx;
use futures_util::StreamExt;
loop {
tokio::select! {
_ = cancel_rx.recv() => {
callback(ProviderStreamEvent::Disconnected);
return Ok(());
}
chunk = stream.next() => {
match chunk {
Some(Ok(bytes)) => {
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
if line.trim().is_empty() {
continue;
}
if let Ok(response) = serde_json::from_str::<OllamaChatResponse>(line) {
if let Some(msg) = &response.message {
if !msg.content.is_empty() {
full_response.push_str(&msg.content);
callback(ProviderStreamEvent::TextDelta {
text: msg.content.clone(),
});
}
}
if let Some(tokens) = response.prompt_eval_count {
total_input_tokens = tokens;
}
if let Some(tokens) = response.eval_count {
total_output_tokens = tokens;
}
if response.done {
self.conversation_history.push(OllamaMessage {
role: "assistant".to_string(),
content: full_response.clone(),
});
callback(ProviderStreamEvent::MessageComplete {
content: vec![ProviderMessage::Text {
content: full_response,
}],
usage: Some(ProviderUsage {
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
model,
}),
stop_reason: Some("end_turn".to_string()),
});
return Ok(());
}
}
}
}
Some(Err(e)) => {
return Err(format!("Stream error: {}", e));
}
None => {
break;
}
}
}
}
}
Ok(())
}
}
#[async_trait]
impl LlmProvider for OllamaProvider {
fn name(&self) -> &'static str {
"Ollama"
}
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
supports_streaming: true,
supports_tools: false, // Ollama doesn't support tools natively yet
supports_vision: true, // Some models support vision
supports_thinking: false,
supports_mcp: false,
supports_resume_session: false,
max_context_tokens: None, // Varies by model
available_models: default_ollama_models(),
}
}
async fn start(&mut self, callback: StreamCallback) -> Result<(), String> {
if self.is_running.load(Ordering::SeqCst) {
return Err("Provider already running".to_string());
}
// Check if Ollama is reachable
let url = format!("{}/api/tags", self.base_url);
self.client.get(&url).send().await.map_err(|e| {
format!(
"Cannot connect to Ollama at {}. Is it running? Error: {}",
self.base_url, e
)
})?;
self.is_running.store(true, Ordering::SeqCst);
self.conversation_history.clear();
// Create cancellation channel
let (tx, _rx) = mpsc::channel(1);
self.cancel_tx = Some(tx);
callback(ProviderStreamEvent::Connected { session_id: None });
// Fetch and report available models
let models = self.fetch_available_models().await;
if !models.is_empty() {
eprintln!(
"[Ollama] Available models: {:?}",
models.iter().map(|m| &m.id).collect::<Vec<_>>()
);
}
Ok(())
}
async fn stop(&mut self) -> Result<(), String> {
self.is_running.store(false, Ordering::SeqCst);
self.conversation_history.clear();
if let Some(tx) = self.cancel_tx.take() {
let _ = tx.send(()).await;
}
Ok(())
}
fn is_running(&self) -> bool {
self.is_running.load(Ordering::SeqCst)
}
async fn send_message(&mut self, message: &str) -> Result<(), String> {
if !self.is_running.load(Ordering::SeqCst) {
return Err("Provider not running".to_string());
}
// Add user message to history
self.conversation_history.push(OllamaMessage {
role: "user".to_string(),
content: message.to_string(),
});
// Create a new cancel channel for this request
let (tx, _rx) = mpsc::channel(1);
self.cancel_tx = Some(tx);
// We need a callback here, but we don't have access to it in send_message
// This is a limitation of the current trait design
// For now, we'll need to refactor to handle streaming properly
// The callback should be stored from the start() call
// For the MVP, we'll emit events directly
// In a real implementation, we'd need to restructure this
Err("send_message needs refactoring to work with stored callback".to_string())
}
async fn send_tool_result(
&mut self,
_tool_use_id: &str,
_result: serde_json::Value,
_is_error: bool,
) -> Result<(), String> {
// Ollama doesn't support tool use natively
Err("Ollama does not support tool use".to_string())
}
fn get_model(&self) -> Option<String> {
self.config.model.clone()
}
fn get_working_directory(&self) -> &str {
&self.config.working_directory
}
async fn interrupt(&mut self) -> Result<(), String> {
if let Some(tx) = self.cancel_tx.take() {
let _ = tx.send(()).await;
}
self.is_running.store(false, Ordering::SeqCst);
Ok(())
}
}
fn default_ollama_models() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "llama3.2".to_string(),
name: "Llama 3.2".to_string(),
description: Some("Meta's latest compact model".to_string()),
context_window: Some(128_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "llama3.2:1b".to_string(),
name: "Llama 3.2 1B".to_string(),
description: Some("Smallest Llama 3.2 variant".to_string()),
context_window: Some(128_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "qwen2.5-coder".to_string(),
name: "Qwen 2.5 Coder".to_string(),
description: Some("Alibaba's coding-focused model".to_string()),
context_window: Some(32_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "deepseek-coder-v2".to_string(),
name: "DeepSeek Coder V2".to_string(),
description: Some("DeepSeek's coding model".to_string()),
context_window: Some(128_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "mistral".to_string(),
name: "Mistral 7B".to_string(),
description: Some("Fast and capable".to_string()),
context_window: Some(32_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "gemma2".to_string(),
name: "Gemma 2".to_string(),
description: Some("Google's open model".to_string()),
context_window: Some(8_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_name() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
assert_eq!(provider.name(), "Ollama");
}
#[test]
fn test_capabilities() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
let caps = provider.capabilities();
assert!(caps.supports_streaming);
assert!(!caps.supports_tools);
assert!(!caps.supports_mcp);
assert!(!caps.supports_resume_session);
assert!(!caps.available_models.is_empty());
}
#[test]
fn test_default_url() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL);
}
#[test]
fn test_custom_url() {
let config = ProviderConfig {
api_base_url: Some("http://custom:8080".to_string()),
..Default::default()
};
let provider = OllamaProvider::new(config);
assert_eq!(provider.base_url, "http://custom:8080");
}
#[test]
fn test_is_running_initial() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
assert!(!provider.is_running());
}
#[test]
fn test_default_ollama_models() {
let models = default_ollama_models();
assert!(!models.is_empty());
// All models should be free (local)
for model in &models {
assert_eq!(model.input_cost_per_mtok, Some(0.0));
assert_eq!(model.output_cost_per_mtok, Some(0.0));
}
// Should include llama
assert!(models.iter().any(|m| m.id.contains("llama")));
}
#[test]
fn test_ollama_message_serialization() {
let msg = OllamaMessage {
role: "user".to_string(),
content: "Hello!".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"content\":\"Hello!\""));
}
#[test]
fn test_chat_request_serialization() {
let request = OllamaChatRequest {
model: "llama3.2".to_string(),
messages: vec![OllamaMessage {
role: "user".to_string(),
content: "Test".to_string(),
}],
stream: true,
system: Some("You are helpful".to_string()),
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"model\":\"llama3.2\""));
assert!(json.contains("\"stream\":true"));
assert!(json.contains("\"system\":\"You are helpful\""));
}
#[test]
fn test_chat_request_without_system() {
let request = OllamaChatRequest {
model: "llama3.2".to_string(),
messages: vec![],
stream: true,
system: None,
};
let json = serde_json::to_string(&request).unwrap();
// system should be omitted when None
assert!(!json.contains("\"system\""));
}
}