// This provider connects to a local Ollama instance for LLM inference. // It will be actively used once providers are fully integrated with BridgeManager. #![allow(dead_code)] use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; use crate::providers::traits::{ LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage, ProviderStreamEvent, ProviderUsage, StreamCallback, }; const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434"; #[derive(Debug, Serialize)] struct OllamaChatRequest { model: String, messages: Vec, stream: bool, #[serde(skip_serializing_if = "Option::is_none")] system: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct OllamaMessage { role: String, content: String, } #[derive(Debug, Deserialize)] struct OllamaChatResponse { #[serde(default)] message: Option, #[serde(default)] done: bool, #[serde(default)] eval_count: Option, #[serde(default)] prompt_eval_count: Option, } #[derive(Debug, Deserialize)] struct OllamaResponseMessage { #[serde(default)] content: String, } #[derive(Debug, Deserialize)] struct OllamaTagsResponse { models: Vec, } #[derive(Debug, Deserialize)] struct OllamaModelInfo { name: String, #[serde(default)] details: Option, } #[derive(Debug, Deserialize)] struct OllamaModelDetails { #[serde(default)] parameter_size: Option, #[serde(default)] family: Option, } pub struct OllamaProvider { config: ProviderConfig, client: reqwest::Client, base_url: String, is_running: Arc, conversation_history: Vec, cancel_tx: Option>, } impl OllamaProvider { pub fn new(config: ProviderConfig) -> Self { let base_url = config .api_base_url .clone() .unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string()); Self { config, client: reqwest::Client::new(), base_url, is_running: Arc::new(AtomicBool::new(false)), conversation_history: Vec::new(), cancel_tx: None, } } async fn fetch_available_models(&self) -> Vec { let url = format!("{}/api/tags", self.base_url); match self.client.get(&url).send().await { Ok(response) => { if let Ok(tags) = response.json::().await { tags.models .into_iter() .map(|m| { let description = m.details.as_ref().map(|d| { let mut desc_parts = Vec::new(); if let Some(ref family) = d.family { desc_parts.push(family.clone()); } if let Some(ref size) = d.parameter_size { desc_parts.push(format!("{} parameters", size)); } if desc_parts.is_empty() { "Local model".to_string() } else { desc_parts.join(" - ") } }); ModelInfo { id: m.name.clone(), name: m.name, description, context_window: None, input_cost_per_mtok: Some(0.0), // Local = free! output_cost_per_mtok: Some(0.0), } }) .collect() } else { default_ollama_models() } } Err(_) => default_ollama_models(), } } async fn stream_chat( &mut self, callback: Arc, cancel_rx: mpsc::Receiver<()>, ) -> Result<(), String> { let model = self .config .model .clone() .unwrap_or_else(|| "llama3.2".to_string()); let request = OllamaChatRequest { model: model.clone(), messages: self.conversation_history.clone(), stream: true, system: self.config.custom_instructions.clone(), }; let url = format!("{}/api/chat", self.base_url); let response = self .client .post(&url) .json(&request) .send() .await .map_err(|e| format!("Failed to connect to Ollama: {}", e))?; if !response.status().is_success() { let status = response.status(); let error_text = response .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); return Err(format!("Ollama error ({}): {}", status, error_text)); } let mut full_response = String::new(); let mut total_input_tokens: u64 = 0; let mut total_output_tokens: u64 = 0; let mut stream = response.bytes_stream(); let mut cancel_rx = cancel_rx; use futures_util::StreamExt; loop { tokio::select! { _ = cancel_rx.recv() => { callback(ProviderStreamEvent::Disconnected); return Ok(()); } chunk = stream.next() => { match chunk { Some(Ok(bytes)) => { let text = String::from_utf8_lossy(&bytes); for line in text.lines() { if line.trim().is_empty() { continue; } if let Ok(response) = serde_json::from_str::(line) { if let Some(msg) = &response.message { if !msg.content.is_empty() { full_response.push_str(&msg.content); callback(ProviderStreamEvent::TextDelta { text: msg.content.clone(), }); } } if let Some(tokens) = response.prompt_eval_count { total_input_tokens = tokens; } if let Some(tokens) = response.eval_count { total_output_tokens = tokens; } if response.done { self.conversation_history.push(OllamaMessage { role: "assistant".to_string(), content: full_response.clone(), }); callback(ProviderStreamEvent::MessageComplete { content: vec![ProviderMessage::Text { content: full_response, }], usage: Some(ProviderUsage { input_tokens: total_input_tokens, output_tokens: total_output_tokens, model, }), stop_reason: Some("end_turn".to_string()), }); return Ok(()); } } } } Some(Err(e)) => { return Err(format!("Stream error: {}", e)); } None => { break; } } } } } Ok(()) } } #[async_trait] impl LlmProvider for OllamaProvider { fn name(&self) -> &'static str { "Ollama" } fn capabilities(&self) -> ProviderCapabilities { ProviderCapabilities { supports_streaming: true, supports_tools: false, // Ollama doesn't support tools natively yet supports_vision: true, // Some models support vision supports_thinking: false, supports_mcp: false, supports_resume_session: false, max_context_tokens: None, // Varies by model available_models: default_ollama_models(), } } async fn start(&mut self, callback: StreamCallback) -> Result<(), String> { if self.is_running.load(Ordering::SeqCst) { return Err("Provider already running".to_string()); } // Check if Ollama is reachable let url = format!("{}/api/tags", self.base_url); self.client.get(&url).send().await.map_err(|e| { format!( "Cannot connect to Ollama at {}. Is it running? Error: {}", self.base_url, e ) })?; self.is_running.store(true, Ordering::SeqCst); self.conversation_history.clear(); // Create cancellation channel let (tx, _rx) = mpsc::channel(1); self.cancel_tx = Some(tx); callback(ProviderStreamEvent::Connected { session_id: None }); // Fetch and report available models let models = self.fetch_available_models().await; if !models.is_empty() { eprintln!( "[Ollama] Available models: {:?}", models.iter().map(|m| &m.id).collect::>() ); } Ok(()) } async fn stop(&mut self) -> Result<(), String> { self.is_running.store(false, Ordering::SeqCst); self.conversation_history.clear(); if let Some(tx) = self.cancel_tx.take() { let _ = tx.send(()).await; } Ok(()) } fn is_running(&self) -> bool { self.is_running.load(Ordering::SeqCst) } async fn send_message(&mut self, message: &str) -> Result<(), String> { if !self.is_running.load(Ordering::SeqCst) { return Err("Provider not running".to_string()); } // Add user message to history self.conversation_history.push(OllamaMessage { role: "user".to_string(), content: message.to_string(), }); // Create a new cancel channel for this request let (tx, _rx) = mpsc::channel(1); self.cancel_tx = Some(tx); // We need a callback here, but we don't have access to it in send_message // This is a limitation of the current trait design // For now, we'll need to refactor to handle streaming properly // The callback should be stored from the start() call // For the MVP, we'll emit events directly // In a real implementation, we'd need to restructure this Err("send_message needs refactoring to work with stored callback".to_string()) } async fn send_tool_result( &mut self, _tool_use_id: &str, _result: serde_json::Value, _is_error: bool, ) -> Result<(), String> { // Ollama doesn't support tool use natively Err("Ollama does not support tool use".to_string()) } fn get_model(&self) -> Option { self.config.model.clone() } fn get_working_directory(&self) -> &str { &self.config.working_directory } async fn interrupt(&mut self) -> Result<(), String> { if let Some(tx) = self.cancel_tx.take() { let _ = tx.send(()).await; } self.is_running.store(false, Ordering::SeqCst); Ok(()) } } fn default_ollama_models() -> Vec { vec![ ModelInfo { id: "llama3.2".to_string(), name: "Llama 3.2".to_string(), description: Some("Meta's latest compact model".to_string()), context_window: Some(128_000), input_cost_per_mtok: Some(0.0), output_cost_per_mtok: Some(0.0), }, ModelInfo { id: "llama3.2:1b".to_string(), name: "Llama 3.2 1B".to_string(), description: Some("Smallest Llama 3.2 variant".to_string()), context_window: Some(128_000), input_cost_per_mtok: Some(0.0), output_cost_per_mtok: Some(0.0), }, ModelInfo { id: "qwen2.5-coder".to_string(), name: "Qwen 2.5 Coder".to_string(), description: Some("Alibaba's coding-focused model".to_string()), context_window: Some(32_000), input_cost_per_mtok: Some(0.0), output_cost_per_mtok: Some(0.0), }, ModelInfo { id: "deepseek-coder-v2".to_string(), name: "DeepSeek Coder V2".to_string(), description: Some("DeepSeek's coding model".to_string()), context_window: Some(128_000), input_cost_per_mtok: Some(0.0), output_cost_per_mtok: Some(0.0), }, ModelInfo { id: "mistral".to_string(), name: "Mistral 7B".to_string(), description: Some("Fast and capable".to_string()), context_window: Some(32_000), input_cost_per_mtok: Some(0.0), output_cost_per_mtok: Some(0.0), }, ModelInfo { id: "gemma2".to_string(), name: "Gemma 2".to_string(), description: Some("Google's open model".to_string()), context_window: Some(8_000), input_cost_per_mtok: Some(0.0), output_cost_per_mtok: Some(0.0), }, ] } #[cfg(test)] mod tests { use super::*; #[test] fn test_provider_name() { let config = ProviderConfig::default(); let provider = OllamaProvider::new(config); assert_eq!(provider.name(), "Ollama"); } #[test] fn test_capabilities() { let config = ProviderConfig::default(); let provider = OllamaProvider::new(config); let caps = provider.capabilities(); assert!(caps.supports_streaming); assert!(!caps.supports_tools); assert!(!caps.supports_mcp); assert!(!caps.supports_resume_session); assert!(!caps.available_models.is_empty()); } #[test] fn test_default_url() { let config = ProviderConfig::default(); let provider = OllamaProvider::new(config); assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL); } #[test] fn test_custom_url() { let config = ProviderConfig { api_base_url: Some("http://custom:8080".to_string()), ..Default::default() }; let provider = OllamaProvider::new(config); assert_eq!(provider.base_url, "http://custom:8080"); } #[test] fn test_is_running_initial() { let config = ProviderConfig::default(); let provider = OllamaProvider::new(config); assert!(!provider.is_running()); } #[test] fn test_default_ollama_models() { let models = default_ollama_models(); assert!(!models.is_empty()); // All models should be free (local) for model in &models { assert_eq!(model.input_cost_per_mtok, Some(0.0)); assert_eq!(model.output_cost_per_mtok, Some(0.0)); } // Should include llama assert!(models.iter().any(|m| m.id.contains("llama"))); } #[test] fn test_ollama_message_serialization() { let msg = OllamaMessage { role: "user".to_string(), content: "Hello!".to_string(), }; let json = serde_json::to_string(&msg).unwrap(); assert!(json.contains("\"role\":\"user\"")); assert!(json.contains("\"content\":\"Hello!\"")); } #[test] fn test_chat_request_serialization() { let request = OllamaChatRequest { model: "llama3.2".to_string(), messages: vec![OllamaMessage { role: "user".to_string(), content: "Test".to_string(), }], stream: true, system: Some("You are helpful".to_string()), }; let json = serde_json::to_string(&request).unwrap(); assert!(json.contains("\"model\":\"llama3.2\"")); assert!(json.contains("\"stream\":true")); assert!(json.contains("\"system\":\"You are helpful\"")); } #[test] fn test_chat_request_without_system() { let request = OllamaChatRequest { model: "llama3.2".to_string(), messages: vec![], stream: true, system: None, }; let json = serde_json::to_string(&request).unwrap(); // system should be omitted when None assert!(!json.contains("\"system\"")); } }