feat: support ollama

2026-02-04 13:19:10 -08:00
parent daedbfd865
commit a0804ed32a
13 changed files with 2480 additions and 5 deletions
+148 -2
@@ -592,7 +592,7 @@ dependencies = [
"bitflags 2.10.0",
"core-foundation 0.10.1",
"core-graphics-types",
"foreign-types",
"foreign-types 0.5.0",
"libc",
]
@@ -1080,6 +1080,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared 0.1.1",
]
[[package]]
name = "foreign-types"
version = "0.5.0"
@@ -1087,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [
"foreign-types-macros",
"foreign-types-shared",
"foreign-types-shared 0.3.1",
]
[[package]]
@@ -1101,6 +1110,12 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "foreign-types-shared"
version = "0.3.1"
@@ -1604,8 +1619,11 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
name = "hikari-desktop"
version = "1.1.1"
dependencies = [
"async-trait",
"chrono",
"futures-util",
"parking_lot",
"reqwest",
"semver",
"serde",
"serde_json",
@@ -1716,6 +1734,22 @@ dependencies = [
"webpki-roots",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.19"
@@ -2288,6 +2322,23 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndk"
version = "0.9.0"
@@ -2661,6 +2712,50 @@ dependencies = [
"pathdiff",
]
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "option-ext"
version = "0.2.0"
@@ -3417,10 +3512,12 @@ dependencies = [
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"quinn",
@@ -3431,6 +3528,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tokio-util",
"tower",
@@ -3566,6 +3664,15 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "schemars"
version = "0.8.22"
@@ -3623,6 +3730,29 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "selectors"
version = "0.24.0"
@@ -4746,6 +4876,16 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
@@ -5111,6 +5251,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version-compare"
version = "0.2.1"
+3
@@ -31,6 +31,9 @@ tauri-plugin-fs = "2"
tempfile = "3"
semver = "1"
chrono = { version = "0.4.43", features = ["serde"] }
async-trait = "0.1"
reqwest = { version = "0.12", features = ["json", "stream"] }
futures-util = "0.3"
[target.'cfg(windows)'.dependencies]
windows = { version = "0.62", features = [
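The `stream` feature on reqwest is what gates `Response::bytes_stream()`, which the Ollama bridge added below relies on for NDJSON streaming, and `futures-util` supplies the `StreamExt::next()` adapter for consuming it. A minimal sketch of the pattern (the `drain` helper is illustrative, not part of the commit):

```rust
use futures_util::StreamExt;

// Requires reqwest's "stream" feature, which gates bytes_stream().
async fn drain(url: &str) -> Result<(), reqwest::Error> {
    let mut stream = reqwest::get(url).await?.bytes_stream();
    while let Some(chunk) = stream.next().await {
        let bytes = chunk?;
        println!("received {} bytes", bytes.len());
    }
    Ok(())
}
```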
+15 -3
@@ -4,11 +4,11 @@ use std::sync::Arc;
use tauri::AppHandle;
use crate::config::ClaudeStartOptions;
use crate::provider_bridge::ProviderBridge;
use crate::stats::UsageStats;
use crate::wsl_bridge::WslBridge;
pub struct BridgeManager {
bridges: HashMap<String, WslBridge>,
bridges: HashMap<String, ProviderBridge>,
app_handle: Option<AppHandle>,
}
@@ -45,12 +45,24 @@ impl BridgeManager {
.ok_or_else(|| "App handle not set".to_string())?
.clone();
// Check if existing bridge matches the requested provider type
// If provider type changed, create a new bridge
let should_recreate = self.bridges.get(conversation_id).map_or(false, |bridge| {
bridge.provider_type() != options.provider_type
});
if should_recreate {
// Remove existing bridge if provider type changed
self.bridges.remove(conversation_id);
}
// Otherwise reuse the existing bridge for this conversation (preserving
// stats across reconnects), creating one only if none exists yet
let provider_type = options.provider_type;
let bridge = self
.bridges
.entry(conversation_id.to_string())
.or_insert_with(|| WslBridge::new_with_conversation_id(conversation_id.to_string()));
.or_insert_with(|| ProviderBridge::new(provider_type, conversation_id.to_string()));
// Start the provider process (Claude CLI or Ollama)
bridge.start(app, options)?;
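The recreate-or-reuse logic above is the standard `HashMap::entry` pattern; a self-contained sketch with a stand-in `Bridge` type (not the real `ProviderBridge`):

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq)]
enum Kind { ClaudeCli, Ollama }

struct Bridge { kind: Kind }

fn obtain<'a>(bridges: &'a mut HashMap<String, Bridge>, id: &str, kind: Kind) -> &'a mut Bridge {
    // Drop the old bridge when the requested provider type changed...
    if bridges.get(id).map_or(false, |b| b.kind != kind) {
        bridges.remove(id);
    }
    // ...otherwise reuse it, so per-conversation stats survive reconnects.
    bridges.entry(id.to_string()).or_insert_with(|| Bridge { kind })
}
```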
+34
@@ -1,7 +1,11 @@
use crate::providers::ProviderType;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClaudeStartOptions {
#[serde(default)]
pub provider_type: ProviderType,
#[serde(default)]
pub working_dir: String,
@@ -25,10 +29,20 @@ pub struct ClaudeStartOptions {
#[serde(default)]
pub resume_session_id: Option<String>,
// Ollama-specific options
#[serde(default = "default_ollama_base_url")]
pub ollama_base_url: String,
#[serde(default)]
pub ollama_model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HikariConfig {
#[serde(default)]
pub provider_type: ProviderType,
#[serde(default)]
pub model: Option<String>,
@@ -44,6 +58,13 @@ pub struct HikariConfig {
#[serde(default)]
pub auto_granted_tools: Vec<String>,
// Ollama-specific settings
#[serde(default = "default_ollama_base_url")]
pub ollama_base_url: String,
#[serde(default)]
pub ollama_model: Option<String>,
#[serde(default)]
pub theme: Theme,
@@ -101,11 +122,14 @@ pub struct HikariConfig {
impl Default for HikariConfig {
fn default() -> Self {
Self {
provider_type: ProviderType::default(),
model: None,
api_key: None,
custom_instructions: None,
mcp_servers_json: None,
auto_granted_tools: Vec::new(),
ollama_base_url: default_ollama_base_url(),
ollama_model: None,
theme: Theme::default(),
greeting_enabled: true,
greeting_custom_prompt: None,
@@ -147,6 +171,10 @@ fn default_font_size() -> u32 {
14
}
fn default_ollama_base_url() -> String {
"http://localhost:11434".to_string()
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Theme {
@@ -185,11 +213,14 @@ mod tests {
#[test]
fn test_default_config() {
let config = HikariConfig::default();
assert_eq!(config.provider_type, ProviderType::ClaudeCli);
assert!(config.model.is_none());
assert!(config.api_key.is_none());
assert!(config.custom_instructions.is_none());
assert!(config.mcp_servers_json.is_none());
assert!(config.auto_granted_tools.is_empty());
assert_eq!(config.ollama_base_url, "http://localhost:11434");
assert!(config.ollama_model.is_none());
assert_eq!(config.theme, Theme::Dark);
assert!(config.greeting_enabled);
assert!(config.greeting_custom_prompt.is_none());
@@ -210,11 +241,14 @@ mod tests {
#[test]
fn test_config_serialization() {
let config = HikariConfig {
provider_type: ProviderType::ClaudeCli,
model: Some("claude-sonnet-4-20250514".to_string()),
api_key: None,
custom_instructions: Some("Be helpful".to_string()),
mcp_servers_json: None,
auto_granted_tools: vec!["Read".to_string(), "Glob".to_string()],
ollama_base_url: "http://localhost:11434".to_string(),
ollama_model: None,
theme: Theme::Light,
greeting_enabled: true,
greeting_custom_prompt: Some("Hello!".to_string()),
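Because each new field carries `#[serde(default)]` (with `default_ollama_base_url` supplying the URL), options saved before this commit still deserialize cleanly. A sketch of that guarantee, assuming the fields outside this hunk are defaulted the same way:

```rust
#[test]
fn test_options_backfill_ollama_defaults() {
    // An empty payload exercises every #[serde(default)] attribute.
    let opts: ClaudeStartOptions = serde_json::from_str("{}").unwrap();
    assert_eq!(opts.provider_type, ProviderType::ClaudeCli);
    assert_eq!(opts.ollama_base_url, "http://localhost:11434");
    assert!(opts.ollama_model.is_none());
}
```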
+2
@@ -5,6 +5,8 @@ mod commands;
mod config;
mod git;
mod notifications;
mod provider_bridge;
mod providers;
mod quick_actions;
mod sessions;
mod snippets;
+538
@@ -0,0 +1,538 @@
// Provider bridge that unifies different LLM provider implementations
// for use with the BridgeManager.
use parking_lot::RwLock;
use std::sync::Arc;
use tauri::{AppHandle, Emitter};
use crate::achievements::{get_achievement_info, AchievementUnlockedEvent};
use crate::config::ClaudeStartOptions;
use crate::providers::ProviderType;
use crate::stats::{StatsUpdateEvent, UsageStats};
use crate::types::{
CharacterState, ConnectionEvent, ConnectionStatus, OutputEvent, StateChangeEvent,
};
use crate::wsl_bridge::WslBridge;
/// A unified bridge that can wrap either the Claude CLI or Ollama provider
pub enum ProviderBridge {
ClaudeCli(WslBridge),
Ollama(OllamaBridge),
}
impl ProviderBridge {
pub fn new(provider_type: ProviderType, conversation_id: String) -> Self {
match provider_type {
ProviderType::ClaudeCli => {
ProviderBridge::ClaudeCli(WslBridge::new_with_conversation_id(conversation_id))
}
ProviderType::Ollama => {
ProviderBridge::Ollama(OllamaBridge::new(conversation_id))
}
}
}
pub fn start(&mut self, app: AppHandle, options: ClaudeStartOptions) -> Result<(), String> {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.start(app, options),
ProviderBridge::Ollama(bridge) => bridge.start(app, options),
}
}
pub fn stop(&mut self, app: &AppHandle) {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.stop(app),
ProviderBridge::Ollama(bridge) => bridge.stop(app),
}
}
pub fn interrupt(&mut self, app: &AppHandle) -> Result<(), String> {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.interrupt(app),
ProviderBridge::Ollama(bridge) => bridge.interrupt(app),
}
}
pub fn send_message(&mut self, message: &str) -> Result<(), String> {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.send_message(message),
ProviderBridge::Ollama(bridge) => bridge.send_message(message),
}
}
pub fn send_tool_result(
&mut self,
tool_use_id: &str,
result: serde_json::Value,
) -> Result<(), String> {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.send_tool_result(tool_use_id, result),
ProviderBridge::Ollama(bridge) => bridge.send_tool_result(tool_use_id, result),
}
}
pub fn is_running(&self) -> bool {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.is_running(),
ProviderBridge::Ollama(bridge) => bridge.is_running(),
}
}
pub fn get_working_directory(&self) -> &str {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.get_working_directory(),
ProviderBridge::Ollama(bridge) => bridge.get_working_directory(),
}
}
pub fn get_stats(&self) -> UsageStats {
match self {
ProviderBridge::ClaudeCli(bridge) => bridge.get_stats(),
ProviderBridge::Ollama(bridge) => bridge.get_stats(),
}
}
pub fn provider_type(&self) -> ProviderType {
match self {
ProviderBridge::ClaudeCli(_) => ProviderType::ClaudeCli,
ProviderBridge::Ollama(_) => ProviderType::Ollama,
}
}
}
/// Bridge for Ollama local LLM provider
pub struct OllamaBridge {
conversation_id: String,
working_directory: String,
is_running: bool,
stats: Arc<RwLock<UsageStats>>,
base_url: String,
model: String,
messages: Vec<OllamaMessage>,
cancel_flag: Arc<std::sync::atomic::AtomicBool>,
app_handle: Option<AppHandle>,
}
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct OllamaMessage {
role: String,
content: String,
}
#[derive(serde::Deserialize)]
struct OllamaChatResponse {
message: Option<OllamaResponseMessage>,
done: bool,
#[serde(default)]
eval_count: Option<u64>,
#[serde(default)]
prompt_eval_count: Option<u64>,
}
#[derive(serde::Deserialize)]
struct OllamaResponseMessage {
content: String,
}
impl OllamaBridge {
pub fn new(conversation_id: String) -> Self {
OllamaBridge {
conversation_id,
working_directory: String::new(),
is_running: false,
stats: Arc::new(RwLock::new(UsageStats::new())),
base_url: "http://localhost:11434".to_string(),
model: "llama3.2".to_string(),
messages: Vec::new(),
cancel_flag: Arc::new(std::sync::atomic::AtomicBool::new(false)),
app_handle: None,
}
}
pub fn start(&mut self, app: AppHandle, options: ClaudeStartOptions) -> Result<(), String> {
if self.is_running {
return Err("Ollama bridge already running".to_string());
}
self.working_directory = options.working_dir.clone();
self.base_url = options.ollama_base_url.clone();
self.model = options.ollama_model.clone().unwrap_or_else(|| "llama3.2".to_string());
self.app_handle = Some(app.clone());
self.cancel_flag.store(false, std::sync::atomic::Ordering::SeqCst);
// Reset messages for new conversation
self.messages.clear();
// Add system prompt if custom instructions are provided
if let Some(ref instructions) = options.custom_instructions {
if !instructions.is_empty() {
self.messages.push(OllamaMessage {
role: "system".to_string(),
content: instructions.clone(),
});
}
}
emit_connection_status(&app, ConnectionStatus::Connecting, Some(self.conversation_id.clone()));
// Check if Ollama is running
let base_url = self.base_url.clone();
let app_clone = app.clone();
let conv_id = self.conversation_id.clone();
// Block on the async check. tauri::async_runtime::block_on works from
// sync command threads, where tokio::runtime::Handle::current() would panic.
let check_result = tauri::async_runtime::block_on(check_ollama_running(&base_url));
match check_result {
Ok(true) => {
self.is_running = true;
emit_connection_status(&app, ConnectionStatus::Connected, Some(self.conversation_id.clone()));
emit_state_change(&app, CharacterState::Idle, None, Some(self.conversation_id.clone()));
Ok(())
}
Ok(false) => {
emit_connection_status(&app_clone, ConnectionStatus::Disconnected, Some(conv_id));
Err("Ollama is not running. Please start Ollama first.".to_string())
}
Err(e) => {
emit_connection_status(&app_clone, ConnectionStatus::Disconnected, Some(conv_id));
Err(format!("Failed to connect to Ollama: {}", e))
}
}
}
pub fn stop(&mut self, app: &AppHandle) {
self.cancel_flag.store(true, std::sync::atomic::Ordering::SeqCst);
self.is_running = false;
self.app_handle = None;
// Save stats before stopping
let stats_snapshot = self.stats.read().clone();
let app_clone = app.clone();
tauri::async_runtime::spawn(async move {
if let Err(e) = crate::stats::save_stats(&app_clone, &stats_snapshot).await {
eprintln!("Failed to save stats: {}", e);
}
});
self.stats.write().reset_session();
emit_connection_status(app, ConnectionStatus::Disconnected, Some(self.conversation_id.clone()));
}
pub fn interrupt(&mut self, app: &AppHandle) -> Result<(), String> {
self.cancel_flag.store(true, std::sync::atomic::Ordering::SeqCst);
self.is_running = false;
emit_connection_status(app, ConnectionStatus::Disconnected, Some(self.conversation_id.clone()));
Ok(())
}
pub fn send_message(&mut self, message: &str) -> Result<(), String> {
if !self.is_running {
return Err("Ollama bridge not running".to_string());
}
let app = self.app_handle.clone().ok_or("App handle not available")?;
// Add user message to history
self.messages.push(OllamaMessage {
role: "user".to_string(),
content: message.to_string(),
});
// Increment message count
self.stats.write().increment_messages();
emit_state_change(&app, CharacterState::Thinking, None, Some(self.conversation_id.clone()));
// Prepare request
let base_url = self.base_url.clone();
let model = self.model.clone();
let messages = self.messages.clone();
let stats = self.stats.clone();
let conv_id = self.conversation_id.clone();
let cancel_flag = self.cancel_flag.clone();
// Spawn async task to handle the chat
tauri::async_runtime::spawn(async move {
match stream_ollama_chat(&base_url, &model, &messages, &app, &stats, &conv_id, &cancel_flag).await {
Ok(_response) => {
// Response handling is done in stream_ollama_chat
}
Err(e) => {
let _ = app.emit(
"claude:output",
OutputEvent {
line_type: "error".to_string(),
content: format!("Ollama error: {}", e),
tool_name: None,
conversation_id: Some(conv_id.clone()),
},
);
emit_state_change(&app, CharacterState::Error, None, Some(conv_id));
}
}
});
Ok(())
}
pub fn send_tool_result(
&mut self,
_tool_use_id: &str,
_result: serde_json::Value,
) -> Result<(), String> {
// Ollama doesn't support tools in the same way as Claude
// For now, we just ignore tool results
Err("Tool results are not supported with Ollama provider".to_string())
}
pub fn is_running(&self) -> bool {
self.is_running
}
pub fn get_working_directory(&self) -> &str {
&self.working_directory
}
pub fn get_stats(&self) -> UsageStats {
self.stats.read().clone()
}
}
async fn check_ollama_running(base_url: &str) -> Result<bool, String> {
let client = reqwest::Client::new();
let url = format!("{}/api/tags", base_url);
match client.get(&url).timeout(std::time::Duration::from_secs(5)).send().await {
Ok(response) => Ok(response.status().is_success()),
Err(e) => {
if e.is_connect() {
Ok(false)
} else {
Err(e.to_string())
}
}
}
}
async fn stream_ollama_chat(
base_url: &str,
model: &str,
messages: &[OllamaMessage],
app: &AppHandle,
stats: &Arc<RwLock<UsageStats>>,
conversation_id: &str,
cancel_flag: &Arc<std::sync::atomic::AtomicBool>,
) -> Result<String, String> {
let client = reqwest::Client::new();
let url = format!("{}/api/chat", base_url);
emit_state_change(app, CharacterState::Typing, None, Some(conversation_id.to_string()));
let request_body = serde_json::json!({
"model": model,
"messages": messages,
"stream": true
});
let response = client
.post(&url)
.json(&request_body)
.send()
.await
.map_err(|e| format!("Failed to send request to Ollama: {}", e))?;
if !response.status().is_success() {
return Err(format!("Ollama returned error status: {}", response.status()));
}
let mut full_response = String::new();
let mut total_tokens: u64 = 0;
let mut prompt_tokens: u64 = 0;
// Stream the response
let mut stream = response.bytes_stream();
use futures_util::StreamExt;
while let Some(chunk_result) = stream.next().await {
// Check for cancellation
if cancel_flag.load(std::sync::atomic::Ordering::SeqCst) {
break;
}
match chunk_result {
Ok(chunk) => {
let chunk_str = String::from_utf8_lossy(&chunk);
// Parse each line as JSON (assumes Ollama delivers each NDJSON object
// within a single chunk; a line split across chunks is logged and skipped)
for line in chunk_str.lines() {
if line.trim().is_empty() {
continue;
}
match serde_json::from_str::<OllamaChatResponse>(line) {
Ok(response) => {
if let Some(msg) = response.message {
full_response.push_str(&msg.content);
// Emit stream event
let _ = app.emit("claude:stream", msg.content);
}
// Capture token counts
if let Some(eval) = response.eval_count {
total_tokens = eval;
}
if let Some(prompt) = response.prompt_eval_count {
prompt_tokens = prompt;
}
if response.done {
break;
}
}
Err(e) => {
eprintln!("Failed to parse Ollama response: {} - Line: {}", e, line);
}
}
}
}
Err(e) => {
return Err(format!("Error reading stream: {}", e));
}
}
}
// Update stats with token usage
{
let mut stats_guard = stats.write();
stats_guard.increment_messages();
// Ollama reports output tokens as eval_count, input as prompt_eval_count
stats_guard.add_usage(prompt_tokens, total_tokens, model);
// Count code blocks in the response
let code_blocks = full_response.matches("```").count() / 2;
for _ in 0..code_blocks {
stats_guard.increment_code_blocks();
}
}
// Emit the complete response
let _ = app.emit(
"claude:output",
OutputEvent {
line_type: "assistant".to_string(),
content: full_response.clone(),
tool_name: None,
conversation_id: Some(conversation_id.to_string()),
},
);
// Check achievements
let newly_unlocked = {
let mut stats_guard = stats.write();
stats_guard.check_achievements()
};
// Emit achievement events
for achievement_id in &newly_unlocked {
let info = get_achievement_info(achievement_id);
let _ = app.emit(
"achievement:unlocked",
AchievementUnlockedEvent { achievement: info },
);
}
// Save achievements if any were unlocked
if !newly_unlocked.is_empty() {
let app_clone = app.clone();
let achievements_progress = stats.read().achievements.clone();
tauri::async_runtime::spawn(async move {
if let Err(e) = crate::achievements::save_achievements(&app_clone, &achievements_progress).await {
eprintln!("Failed to save achievements: {}", e);
}
});
}
// Emit stats update
let current_stats = stats.read().clone();
let _ = app.emit("claude:stats", StatsUpdateEvent { stats: current_stats });
emit_state_change(app, CharacterState::Success, None, Some(conversation_id.to_string()));
Ok(full_response)
}
fn emit_state_change(
app: &AppHandle,
state: CharacterState,
tool_name: Option<String>,
conversation_id: Option<String>,
) {
let _ = app.emit(
"claude:state",
StateChangeEvent {
state,
tool_name,
conversation_id,
},
);
}
fn emit_connection_status(
app: &AppHandle,
status: ConnectionStatus,
conversation_id: Option<String>,
) {
let _ = app.emit(
"claude:connection",
ConnectionEvent {
status,
conversation_id,
},
);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_bridge_ollama_new() {
let bridge = OllamaBridge::new("test-conv".to_string());
assert!(!bridge.is_running());
assert_eq!(bridge.get_working_directory(), "");
assert_eq!(bridge.conversation_id, "test-conv");
}
#[test]
fn test_provider_bridge_type() {
let claude = ProviderBridge::new(ProviderType::ClaudeCli, "conv1".to_string());
assert_eq!(claude.provider_type(), ProviderType::ClaudeCli);
let ollama = ProviderBridge::new(ProviderType::Ollama, "conv2".to_string());
assert_eq!(ollama.provider_type(), ProviderType::Ollama);
}
#[test]
fn test_provider_bridge_is_running_initial() {
let bridge = ProviderBridge::new(ProviderType::ClaudeCli, "test".to_string());
assert!(!bridge.is_running());
let bridge = ProviderBridge::new(ProviderType::Ollama, "test".to_string());
assert!(!bridge.is_running());
}
#[test]
fn test_ollama_message_serialization() {
let msg = OllamaMessage {
role: "user".to_string(),
content: "Hello!".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"content\":\"Hello!\""));
}
}
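For reference, the NDJSON lines `stream_ollama_chat` consumes look like the samples below (shape assumed from Ollama's `/api/chat` streaming API; serde ignores fields the structs don't declare). This could sit alongside the tests above:

```rust
#[test]
fn test_ollama_chat_response_parsing() {
    // Mid-stream line: a content delta, not yet done.
    let mid = r#"{"model":"llama3.2","message":{"role":"assistant","content":"Hel"},"done":false}"#;
    let parsed: OllamaChatResponse = serde_json::from_str(mid).unwrap();
    assert!(!parsed.done);
    assert_eq!(parsed.message.unwrap().content, "Hel");

    // Final line: done, carrying the token counts used for stats.
    let last = r#"{"done":true,"eval_count":42,"prompt_eval_count":10}"#;
    let parsed: OllamaChatResponse = serde_json::from_str(last).unwrap();
    assert!(parsed.done);
    assert_eq!(parsed.eval_count, Some(42));
    assert_eq!(parsed.prompt_eval_count, Some(10));
}
```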
+773
@@ -0,0 +1,773 @@
// This provider wraps the Claude CLI subprocess.
// It will be actively used once providers are fully integrated with BridgeManager.
#![allow(dead_code)]
use async_trait::async_trait;
use std::io::{BufRead, BufReader, Write};
use std::process::{Child, ChildStdin, Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use tempfile::NamedTempFile;
use crate::providers::traits::{
LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderStreamEvent,
QuestionOption, StreamCallback,
};
#[cfg(target_os = "windows")]
use std::os::windows::process::CommandExt;
fn detect_wsl() -> bool {
if let Ok(version) = std::fs::read_to_string("/proc/version") {
let version_lower = version.to_lowercase();
if version_lower.contains("microsoft") || version_lower.contains("wsl") {
return true;
}
}
if std::path::Path::new("/proc/sys/fs/binfmt_misc/WSLInterop").exists() {
return true;
}
if std::env::var("WSL_DISTRO_NAME").is_ok() {
return true;
}
false
}
fn find_claude_binary() -> Option<String> {
let home = std::env::var("HOME").ok()?;
let paths_to_check = [
format!("{}/.local/bin/claude", home),
format!("{}/.claude/local/claude", home),
"/usr/local/bin/claude".to_string(),
"/usr/bin/claude".to_string(),
];
for path in &paths_to_check {
if std::path::Path::new(path).exists() {
return Some(path.clone());
}
}
if let Ok(output) = Command::new("which").arg("claude").output() {
if output.status.success() {
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !path.is_empty() {
return Some(path);
}
}
}
None
}
pub struct ClaudeCliProvider {
config: ProviderConfig,
process: Option<Child>,
stdin: Option<ChildStdin>,
session_id: Option<String>,
mcp_config_file: Option<NamedTempFile>,
is_running: Arc<AtomicBool>,
}
impl ClaudeCliProvider {
pub fn new(config: ProviderConfig) -> Self {
Self {
config,
process: None,
stdin: None,
session_id: None,
mcp_config_file: None,
is_running: Arc::new(AtomicBool::new(false)),
}
}
fn build_command(&mut self) -> Result<Command, String> {
let mcp_config_path = if let Some(ref mcp_json) = self.config.mcp_servers_json {
if !mcp_json.trim().is_empty() {
serde_json::from_str::<serde_json::Value>(mcp_json)
.map_err(|e| format!("Invalid MCP servers JSON: {}", e))?;
let mut temp_file = NamedTempFile::new()
.map_err(|e| format!("Failed to create temp file for MCP config: {}", e))?;
temp_file
.write_all(mcp_json.as_bytes())
.map_err(|e| format!("Failed to write MCP config: {}", e))?;
temp_file
.flush()
.map_err(|e| format!("Failed to flush MCP config: {}", e))?;
let path = temp_file.path().to_string_lossy().to_string();
self.mcp_config_file = Some(temp_file);
Some(path)
} else {
None
}
} else {
None
};
let is_wsl = detect_wsl();
let working_dir = &self.config.working_directory;
let command = if is_wsl {
let claude_path = find_claude_binary().ok_or_else(|| {
"Could not find claude binary. Is Claude Code installed?".to_string()
})?;
let mut cmd = Command::new(&claude_path);
cmd.args([
"--output-format",
"stream-json",
"--input-format",
"stream-json",
"--verbose",
]);
if let Some(ref model) = self.config.model {
if !model.is_empty() {
cmd.args(["--model", model]);
}
}
for tool in &self.config.allowed_tools {
cmd.args(["--allowedTools", tool]);
}
if let Some(ref instructions) = self.config.custom_instructions {
if !instructions.is_empty() {
cmd.args(["--system-prompt", instructions]);
}
}
if let Some(ref mcp_path) = mcp_config_path {
cmd.args(["--mcp-config", mcp_path]);
}
cmd.current_dir(working_dir);
if let Some(ref api_key) = self.config.api_key {
if !api_key.is_empty() {
cmd.env("ANTHROPIC_API_KEY", api_key);
}
}
cmd
} else {
let mut cmd = Command::new("wsl");
let mut claude_cmd = format!("cd '{}' && ", working_dir);
if let Some(ref api_key) = self.config.api_key {
if !api_key.is_empty() {
claude_cmd.push_str(&format!("ANTHROPIC_API_KEY='{}' ", api_key));
}
}
claude_cmd.push_str(
"claude --output-format stream-json --input-format stream-json --verbose",
);
if let Some(ref model) = self.config.model {
if !model.is_empty() {
claude_cmd.push_str(&format!(" --model '{}'", model));
}
}
for tool in &self.config.allowed_tools {
claude_cmd.push_str(&format!(" --allowedTools '{}'", tool));
}
if let Some(ref instructions) = self.config.custom_instructions {
if !instructions.is_empty() {
let escaped = instructions.replace('\'', "'\\''");
claude_cmd.push_str(&format!(" --system-prompt '{}'", escaped));
}
}
if let Some(ref mcp_path) = mcp_config_path {
claude_cmd.push_str(&format!(" --mcp-config '{}'", mcp_path));
}
cmd.args(["-e", "bash", "-lc", &claude_cmd]);
#[cfg(target_os = "windows")]
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW
cmd
};
Ok(command)
}
}
#[async_trait]
impl LlmProvider for ClaudeCliProvider {
fn name(&self) -> &'static str {
"Claude CLI"
}
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
supports_streaming: true,
supports_tools: true,
supports_vision: true,
supports_thinking: true,
supports_mcp: true,
supports_resume_session: true,
max_context_tokens: Some(200_000),
available_models: vec![
ModelInfo {
id: "claude-sonnet-4-20250514".to_string(),
name: "Claude Sonnet 4".to_string(),
description: Some("Fast and intelligent".to_string()),
context_window: Some(200_000),
input_cost_per_mtok: Some(3.0),
output_cost_per_mtok: Some(15.0),
},
ModelInfo {
id: "claude-opus-4-20250514".to_string(),
name: "Claude Opus 4".to_string(),
description: Some("Most capable model".to_string()),
context_window: Some(200_000),
input_cost_per_mtok: Some(15.0),
output_cost_per_mtok: Some(75.0),
},
ModelInfo {
id: "claude-3-5-haiku-20241022".to_string(),
name: "Claude 3.5 Haiku".to_string(),
description: Some("Fast and efficient".to_string()),
context_window: Some(200_000),
input_cost_per_mtok: Some(1.0),
output_cost_per_mtok: Some(5.0),
},
],
}
}
async fn start(&mut self, callback: StreamCallback) -> Result<(), String> {
if self.process.is_some() {
return Err("Process already running".to_string());
}
callback(ProviderStreamEvent::Connected { session_id: None });
let mut command = self.build_command()?;
command
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
let mut child = command
.spawn()
.map_err(|e| format!("Failed to spawn process: {}", e))?;
let stdin = child.stdin.take();
let stdout = child.stdout.take();
let stderr = child.stderr.take();
self.stdin = stdin;
self.process = Some(child);
self.is_running.store(true, Ordering::SeqCst);
let is_running = self.is_running.clone();
let callback = Arc::new(callback);
if let Some(stdout) = stdout {
let callback_clone = callback.clone();
let is_running_clone = is_running.clone();
thread::spawn(move || {
handle_stdout(stdout, callback_clone, is_running_clone);
});
}
if let Some(stderr) = stderr {
let callback_clone = callback.clone();
thread::spawn(move || {
handle_stderr(stderr, callback_clone);
});
}
Ok(())
}
async fn stop(&mut self) -> Result<(), String> {
if let Some(mut process) = self.process.take() {
let _ = process.kill();
let _ = process.wait();
}
self.stdin = None;
self.session_id = None;
self.mcp_config_file = None;
self.is_running.store(false, Ordering::SeqCst);
Ok(())
}
fn is_running(&self) -> bool {
self.is_running.load(Ordering::SeqCst)
}
async fn send_message(&mut self, message: &str) -> Result<(), String> {
let stdin = self.stdin.as_mut().ok_or("Process not running")?;
let input = serde_json::json!({
"type": "user",
"message": {
"role": "user",
"content": [{
"type": "text",
"text": message
}]
}
});
let json_line = serde_json::to_string(&input).map_err(|e| e.to_string())?;
stdin
.write_all(format!("{}\n", json_line).as_bytes())
.map_err(|e| format!("Failed to write to stdin: {}", e))?;
stdin
.flush()
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
Ok(())
}
async fn send_tool_result(
&mut self,
tool_use_id: &str,
result: serde_json::Value,
_is_error: bool,
) -> Result<(), String> {
let stdin = self.stdin.as_mut().ok_or("Process not running")?;
let content_str = serde_json::to_string(&result).map_err(|e| e.to_string())?;
let input = serde_json::json!({
"type": "user",
"message": {
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": tool_use_id,
"content": content_str
}]
}
});
let json_line = serde_json::to_string(&input).map_err(|e| e.to_string())?;
stdin
.write_all(format!("{}\n", json_line).as_bytes())
.map_err(|e| format!("Failed to write to stdin: {}", e))?;
stdin
.flush()
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
Ok(())
}
fn get_model(&self) -> Option<String> {
self.config.model.clone()
}
fn get_working_directory(&self) -> &str {
&self.config.working_directory
}
async fn interrupt(&mut self) -> Result<(), String> {
if let Some(mut process) = self.process.take() {
let _ = process.kill();
let _ = process.wait();
self.stdin = None;
self.is_running.store(false, Ordering::SeqCst);
Ok(())
} else {
Err("No active process to interrupt".to_string())
}
}
}
fn handle_stdout(
stdout: std::process::ChildStdout,
callback: Arc<StreamCallback>,
is_running: Arc<AtomicBool>,
) {
let reader = BufReader::new(stdout);
for line in reader.lines() {
match line {
Ok(line) if !line.is_empty() => {
if let Err(e) = process_json_line(&line, &callback) {
eprintln!("Error processing line: {}", e);
}
}
Err(e) => {
eprintln!("Error reading stdout: {}", e);
break;
}
_ => {}
}
}
is_running.store(false, Ordering::SeqCst);
callback(ProviderStreamEvent::Disconnected);
}
fn handle_stderr(stderr: std::process::ChildStderr, callback: Arc<StreamCallback>) {
let reader = BufReader::new(stderr);
for line in reader.lines() {
match line {
Ok(line) if !line.is_empty() => {
callback(ProviderStreamEvent::Error { message: line });
}
Err(_) => break,
_ => {}
}
}
}
fn process_json_line(line: &str, callback: &StreamCallback) -> Result<(), String> {
let message: serde_json::Value = serde_json::from_str(line)
.map_err(|e| format!("Failed to parse JSON: {} - Line: {}", e, line))?;
let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
match msg_type {
"system" => {
let subtype = message
.get("subtype")
.and_then(|s| s.as_str())
.unwrap_or("");
if subtype == "init" {
let session_id = message
.get("session_id")
.and_then(|s| s.as_str())
.map(|s| s.to_string());
callback(ProviderStreamEvent::Connected { session_id });
if let Some(cwd) = message.get("cwd").and_then(|c| c.as_str()) {
callback(ProviderStreamEvent::WorkingDirectory {
path: cwd.to_string(),
});
}
}
}
"assistant" => {
if let Some(msg) = message.get("message") {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for block in content {
let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("");
match block_type {
"text" => {
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
callback(ProviderStreamEvent::TextDelta {
text: text.to_string(),
});
}
}
"thinking" => {
if let Some(thinking) =
block.get("thinking").and_then(|t| t.as_str())
{
callback(ProviderStreamEvent::ThinkingDelta {
text: thinking.to_string(),
});
}
}
"tool_use" => {
let id = block
.get("id")
.and_then(|i| i.as_str())
.unwrap_or("")
.to_string();
let name = block
.get("name")
.and_then(|n| n.as_str())
.unwrap_or("")
.to_string();
let input = block
.get("input")
.cloned()
.unwrap_or(serde_json::Value::Null);
callback(ProviderStreamEvent::ToolUseStart {
id: id.clone(),
name,
});
callback(ProviderStreamEvent::ToolUseEnd { id, input });
}
_ => {}
}
}
}
// Extract usage if available
let usage = msg.get("usage").and_then(|u| {
let input_tokens = u.get("input_tokens").and_then(|t| t.as_u64())?;
let output_tokens = u.get("output_tokens").and_then(|t| t.as_u64())?;
let model = msg
.get("model")
.and_then(|m| m.as_str())
.unwrap_or("unknown")
.to_string();
Some(crate::providers::ProviderUsage {
input_tokens,
output_tokens,
model,
})
});
callback(ProviderStreamEvent::MessageComplete {
content: vec![],
usage,
stop_reason: None,
});
}
}
"result" => {
let subtype = message
.get("subtype")
.and_then(|s| s.as_str())
.unwrap_or("");
// Handle permission denials
if let Some(denials) = message
.get("permission_denials")
.and_then(|d| d.as_array())
{
for denial in denials {
let tool_name = denial
.get("tool_name")
.and_then(|t| t.as_str())
.unwrap_or("")
.to_string();
// Check if this is an AskUserQuestion
if tool_name == "AskUserQuestion" {
if let Some(tool_input) = denial.get("tool_input") {
if let Some(questions) =
tool_input.get("questions").and_then(|q| q.as_array())
{
if let Some(first_question) = questions.first() {
let question = first_question
.get("question")
.and_then(|q| q.as_str())
.unwrap_or("")
.to_string();
let header = first_question
.get("header")
.and_then(|h| h.as_str())
.map(|s| s.to_string());
let multi_select = first_question
.get("multiSelect")
.and_then(|m| m.as_bool())
.unwrap_or(false);
let options: Vec<QuestionOption> = first_question
.get("options")
.and_then(|opts| opts.as_array())
.map(|opts| {
opts.iter()
.filter_map(|opt| {
let label = opt
.get("label")
.and_then(|l| l.as_str())?;
let description = opt
.get("description")
.and_then(|d| d.as_str())
.map(|s| s.to_string());
Some(QuestionOption {
label: label.to_string(),
description,
})
})
.collect()
})
.unwrap_or_default();
let id = denial
.get("tool_use_id")
.and_then(|i| i.as_str())
.unwrap_or("")
.to_string();
callback(ProviderStreamEvent::Question {
id,
question,
header,
options,
multi_select,
});
}
}
}
} else {
let id = denial
.get("tool_use_id")
.and_then(|i| i.as_str())
.unwrap_or("")
.to_string();
let tool_input = denial
.get("tool_input")
.cloned()
.unwrap_or(serde_json::Value::Null);
let description = format_tool_description(&tool_name, &tool_input);
callback(ProviderStreamEvent::PermissionRequest {
id,
tool_name,
tool_input,
description,
});
}
}
}
if subtype != "success" {
if let Some(result) = message.get("result").and_then(|r| r.as_str()) {
callback(ProviderStreamEvent::Error {
message: result.to_string(),
});
}
}
}
_ => {}
}
Ok(())
}
fn format_tool_description(name: &str, input: &serde_json::Value) -> String {
match name {
"Read" => {
if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) {
format!("Reading file: {}", path)
} else {
"Reading file...".to_string()
}
}
"Glob" => {
if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) {
format!("Searching for files: {}", pattern)
} else {
"Searching for files...".to_string()
}
}
"Grep" => {
if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) {
format!("Searching for: {}", pattern)
} else {
"Searching in files...".to_string()
}
}
"Edit" | "Write" => {
if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) {
format!("Editing: {}", path)
} else {
"Editing file...".to_string()
}
}
"Bash" => {
if let Some(cmd) = input.get("command").and_then(|v| v.as_str()) {
let truncated = if cmd.len() > 50 {
format!("{}...", &cmd[..50])
} else {
cmd.to_string()
};
format!("Running: {}", truncated)
} else {
"Running command...".to_string()
}
}
_ => format!("Using tool: {}", name),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_name() {
let config = ProviderConfig::default();
let provider = ClaudeCliProvider::new(config);
assert_eq!(provider.name(), "Claude CLI");
}
#[test]
fn test_capabilities() {
let config = ProviderConfig::default();
let provider = ClaudeCliProvider::new(config);
let caps = provider.capabilities();
assert!(caps.supports_streaming);
assert!(caps.supports_tools);
assert!(caps.supports_vision);
assert!(caps.supports_thinking);
assert!(caps.supports_mcp);
assert!(caps.supports_resume_session);
assert_eq!(caps.max_context_tokens, Some(200_000));
assert!(!caps.available_models.is_empty());
}
#[test]
fn test_is_running_initial() {
let config = ProviderConfig::default();
let provider = ClaudeCliProvider::new(config);
assert!(!provider.is_running());
}
#[test]
fn test_get_working_directory() {
let config = ProviderConfig {
working_directory: "/home/test".to_string(),
..Default::default()
};
let provider = ClaudeCliProvider::new(config);
assert_eq!(provider.get_working_directory(), "/home/test");
}
#[test]
fn test_get_model() {
let config = ProviderConfig {
model: Some("claude-sonnet-4-20250514".to_string()),
..Default::default()
};
let provider = ClaudeCliProvider::new(config);
assert_eq!(
provider.get_model(),
Some("claude-sonnet-4-20250514".to_string())
);
}
#[test]
fn test_format_tool_description() {
let input = serde_json::json!({"file_path": "/test/file.rs"});
assert_eq!(
format_tool_description("Read", &input),
"Reading file: /test/file.rs"
);
let input = serde_json::json!({"pattern": "*.rs"});
assert_eq!(
format_tool_description("Glob", &input),
"Searching for files: *.rs"
);
let input = serde_json::json!({"command": "ls -la"});
assert_eq!(format_tool_description("Bash", &input), "Running: ls -la");
}
}
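The stream-json handling in `process_json_line` can be exercised directly; a sketch that could join the tests above, with the sample line derived from the parsing code rather than captured CLI output:

```rust
#[test]
fn test_process_json_line_assistant_text() {
    use std::sync::Mutex;

    let events: Arc<Mutex<Vec<ProviderStreamEvent>>> = Arc::new(Mutex::new(Vec::new()));
    let sink = events.clone();
    let callback: StreamCallback = Box::new(move |event| sink.lock().unwrap().push(event));

    let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"hi"}]}}"#;
    process_json_line(line, &callback).unwrap();

    let events = events.lock().unwrap();
    // A text block yields a TextDelta, then MessageComplete closes the turn.
    assert!(matches!(events[0], ProviderStreamEvent::TextDelta { .. }));
    assert!(matches!(events[1], ProviderStreamEvent::MessageComplete { .. }));
}
```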
+89
@@ -0,0 +1,89 @@
mod claude_cli;
mod ollama;
mod traits;
// Re-exports for when providers are fully integrated
#[allow(unused_imports)]
pub use claude_cli::ClaudeCliProvider;
#[allow(unused_imports)]
pub use ollama::OllamaProvider;
#[allow(unused_imports)]
pub use traits::{
LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage,
ProviderStreamEvent, ProviderUsage, QuestionOption, StreamCallback,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ProviderType {
#[default]
ClaudeCli,
Ollama,
}
#[allow(dead_code)]
impl ProviderType {
pub fn display_name(&self) -> &'static str {
match self {
ProviderType::ClaudeCli => "Claude CLI",
ProviderType::Ollama => "Ollama (Local)",
}
}
pub fn description(&self) -> &'static str {
match self {
ProviderType::ClaudeCli => "Use Claude Code CLI for AI assistance",
ProviderType::Ollama => "Use locally running Ollama models",
}
}
}
#[allow(dead_code)]
pub fn create_provider(
provider_type: ProviderType,
config: ProviderConfig,
) -> Box<dyn LlmProvider> {
match provider_type {
ProviderType::ClaudeCli => Box::new(ClaudeCliProvider::new(config)),
ProviderType::Ollama => Box::new(OllamaProvider::new(config)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_type_display_name() {
assert_eq!(ProviderType::ClaudeCli.display_name(), "Claude CLI");
assert_eq!(ProviderType::Ollama.display_name(), "Ollama (Local)");
}
#[test]
fn test_provider_type_default() {
let default: ProviderType = Default::default();
assert_eq!(default, ProviderType::ClaudeCli);
}
#[test]
fn test_provider_type_serialization() {
let claude = ProviderType::ClaudeCli;
let json = serde_json::to_string(&claude).unwrap();
assert_eq!(json, "\"claude_cli\"");
let ollama = ProviderType::Ollama;
let json = serde_json::to_string(&ollama).unwrap();
assert_eq!(json, "\"ollama\"");
}
#[test]
fn test_provider_type_deserialization() {
let claude: ProviderType = serde_json::from_str("\"claude_cli\"").unwrap();
assert_eq!(claude, ProviderType::ClaudeCli);
let ollama: ProviderType = serde_json::from_str("\"ollama\"").unwrap();
assert_eq!(ollama, ProviderType::Ollama);
}
}
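A usage sketch for the `create_provider` factory, written as a test that could join the module above (the model name is illustrative):

```rust
#[test]
fn test_create_provider_ollama() {
    let config = ProviderConfig {
        model: Some("llama3.2".to_string()), // illustrative model name
        ..Default::default()
    };
    let provider = create_provider(ProviderType::Ollama, config);
    assert_eq!(provider.name(), "Ollama");
    assert!(!provider.capabilities().supports_tools);
}
```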
+537
@@ -0,0 +1,537 @@
// This provider connects to a local Ollama instance for LLM inference.
// It will be actively used once providers are fully integrated with BridgeManager.
#![allow(dead_code)]
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::providers::traits::{
LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage,
ProviderStreamEvent, ProviderUsage, StreamCallback,
};
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
model: String,
messages: Vec<OllamaMessage>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
#[serde(default)]
message: Option<OllamaResponseMessage>,
#[serde(default)]
done: bool,
#[serde(default)]
eval_count: Option<u64>,
#[serde(default)]
prompt_eval_count: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
#[serde(default)]
content: String,
}
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
models: Vec<OllamaModelInfo>,
}
#[derive(Debug, Deserialize)]
struct OllamaModelInfo {
name: String,
#[serde(default)]
details: Option<OllamaModelDetails>,
}
#[derive(Debug, Deserialize)]
struct OllamaModelDetails {
#[serde(default)]
parameter_size: Option<String>,
#[serde(default)]
family: Option<String>,
}
pub struct OllamaProvider {
config: ProviderConfig,
client: reqwest::Client,
base_url: String,
is_running: Arc<AtomicBool>,
conversation_history: Vec<OllamaMessage>,
cancel_tx: Option<mpsc::Sender<()>>,
}
impl OllamaProvider {
pub fn new(config: ProviderConfig) -> Self {
let base_url = config
.api_base_url
.clone()
.unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string());
Self {
config,
client: reqwest::Client::new(),
base_url,
is_running: Arc::new(AtomicBool::new(false)),
conversation_history: Vec::new(),
cancel_tx: None,
}
}
async fn fetch_available_models(&self) -> Vec<ModelInfo> {
let url = format!("{}/api/tags", self.base_url);
match self.client.get(&url).send().await {
Ok(response) => {
if let Ok(tags) = response.json::<OllamaTagsResponse>().await {
tags.models
.into_iter()
.map(|m| {
let description = m.details.as_ref().map(|d| {
let mut desc_parts = Vec::new();
if let Some(ref family) = d.family {
desc_parts.push(family.clone());
}
if let Some(ref size) = d.parameter_size {
desc_parts.push(format!("{} parameters", size));
}
if desc_parts.is_empty() {
"Local model".to_string()
} else {
desc_parts.join(" - ")
}
});
ModelInfo {
id: m.name.clone(),
name: m.name,
description,
context_window: None,
input_cost_per_mtok: Some(0.0), // Local = free!
output_cost_per_mtok: Some(0.0),
}
})
.collect()
} else {
default_ollama_models()
}
}
Err(_) => default_ollama_models(),
}
}
async fn stream_chat(
&mut self,
callback: Arc<StreamCallback>,
cancel_rx: mpsc::Receiver<()>,
) -> Result<(), String> {
let model = self
.config
.model
.clone()
.unwrap_or_else(|| "llama3.2".to_string());
let request = OllamaChatRequest {
model: model.clone(),
messages: self.conversation_history.clone(),
stream: true,
system: self.config.custom_instructions.clone(),
};
let url = format!("{}/api/chat", self.base_url);
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| format!("Failed to connect to Ollama: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(format!("Ollama error ({}): {}", status, error_text));
}
let mut full_response = String::new();
let mut total_input_tokens: u64 = 0;
let mut total_output_tokens: u64 = 0;
let mut stream = response.bytes_stream();
let mut cancel_rx = cancel_rx;
use futures_util::StreamExt;
loop {
tokio::select! {
_ = cancel_rx.recv() => {
callback(ProviderStreamEvent::Disconnected);
return Ok(());
}
chunk = stream.next() => {
match chunk {
Some(Ok(bytes)) => {
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
if line.trim().is_empty() {
continue;
}
if let Ok(response) = serde_json::from_str::<OllamaChatResponse>(line) {
if let Some(msg) = &response.message {
if !msg.content.is_empty() {
full_response.push_str(&msg.content);
callback(ProviderStreamEvent::TextDelta {
text: msg.content.clone(),
});
}
}
if let Some(tokens) = response.prompt_eval_count {
total_input_tokens = tokens;
}
if let Some(tokens) = response.eval_count {
total_output_tokens = tokens;
}
if response.done {
self.conversation_history.push(OllamaMessage {
role: "assistant".to_string(),
content: full_response.clone(),
});
callback(ProviderStreamEvent::MessageComplete {
content: vec![ProviderMessage::Text {
content: full_response,
}],
usage: Some(ProviderUsage {
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
model,
}),
stop_reason: Some("end_turn".to_string()),
});
return Ok(());
}
}
}
}
Some(Err(e)) => {
return Err(format!("Stream error: {}", e));
}
None => {
break;
}
}
}
}
}
Ok(())
}
}
#[async_trait]
impl LlmProvider for OllamaProvider {
fn name(&self) -> &'static str {
"Ollama"
}
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
supports_streaming: true,
supports_tools: false, // Ollama doesn't support tools natively yet
supports_vision: true, // Some models support vision
supports_thinking: false,
supports_mcp: false,
supports_resume_session: false,
max_context_tokens: None, // Varies by model
available_models: default_ollama_models(),
}
}
async fn start(&mut self, callback: StreamCallback) -> Result<(), String> {
if self.is_running.load(Ordering::SeqCst) {
return Err("Provider already running".to_string());
}
// Check if Ollama is reachable
let url = format!("{}/api/tags", self.base_url);
self.client.get(&url).send().await.map_err(|e| {
format!(
"Cannot connect to Ollama at {}. Is it running? Error: {}",
self.base_url, e
)
})?;
self.is_running.store(true, Ordering::SeqCst);
self.conversation_history.clear();
// Create cancellation channel (the receiver is dropped for now; streaming
// is not yet wired through this provider — see send_message below)
let (tx, _rx) = mpsc::channel(1);
self.cancel_tx = Some(tx);
callback(ProviderStreamEvent::Connected { session_id: None });
// Fetch and report available models
let models = self.fetch_available_models().await;
if !models.is_empty() {
eprintln!(
"[Ollama] Available models: {:?}",
models.iter().map(|m| &m.id).collect::<Vec<_>>()
);
}
Ok(())
}
async fn stop(&mut self) -> Result<(), String> {
self.is_running.store(false, Ordering::SeqCst);
self.conversation_history.clear();
if let Some(tx) = self.cancel_tx.take() {
let _ = tx.send(()).await;
}
Ok(())
}
fn is_running(&self) -> bool {
self.is_running.load(Ordering::SeqCst)
}
async fn send_message(&mut self, message: &str) -> Result<(), String> {
if !self.is_running.load(Ordering::SeqCst) {
return Err("Provider not running".to_string());
}
// Add user message to history
self.conversation_history.push(OllamaMessage {
role: "user".to_string(),
content: message.to_string(),
});
// Create a new cancel channel for this request
let (tx, _rx) = mpsc::channel(1);
self.cancel_tx = Some(tx);
// The callback is only provided to start() and is not stored, so this
// method cannot stream events yet; the trait needs to retain the callback
// (or accept one here) before Ollama streaming can flow through it.
// For the MVP, OllamaBridge in provider_bridge.rs handles Ollama directly.
Err("send_message needs refactoring to work with stored callback".to_string())
}
async fn send_tool_result(
&mut self,
_tool_use_id: &str,
_result: serde_json::Value,
_is_error: bool,
) -> Result<(), String> {
// Ollama doesn't support tool use natively
Err("Ollama does not support tool use".to_string())
}
fn get_model(&self) -> Option<String> {
self.config.model.clone()
}
fn get_working_directory(&self) -> &str {
&self.config.working_directory
}
async fn interrupt(&mut self) -> Result<(), String> {
if let Some(tx) = self.cancel_tx.take() {
let _ = tx.send(()).await;
}
self.is_running.store(false, Ordering::SeqCst);
Ok(())
}
}
fn default_ollama_models() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "llama3.2".to_string(),
name: "Llama 3.2".to_string(),
description: Some("Meta's latest compact model".to_string()),
context_window: Some(128_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "llama3.2:1b".to_string(),
name: "Llama 3.2 1B".to_string(),
description: Some("Smallest Llama 3.2 variant".to_string()),
context_window: Some(128_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "qwen2.5-coder".to_string(),
name: "Qwen 2.5 Coder".to_string(),
description: Some("Alibaba's coding-focused model".to_string()),
context_window: Some(32_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "deepseek-coder-v2".to_string(),
name: "DeepSeek Coder V2".to_string(),
description: Some("DeepSeek's coding model".to_string()),
context_window: Some(128_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "mistral".to_string(),
name: "Mistral 7B".to_string(),
description: Some("Fast and capable".to_string()),
context_window: Some(32_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
ModelInfo {
id: "gemma2".to_string(),
name: "Gemma 2".to_string(),
description: Some("Google's open model".to_string()),
context_window: Some(8_000),
input_cost_per_mtok: Some(0.0),
output_cost_per_mtok: Some(0.0),
},
]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_name() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
assert_eq!(provider.name(), "Ollama");
}
#[test]
fn test_capabilities() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
let caps = provider.capabilities();
assert!(caps.supports_streaming);
assert!(!caps.supports_tools);
assert!(!caps.supports_mcp);
assert!(!caps.supports_resume_session);
assert!(!caps.available_models.is_empty());
}
#[test]
fn test_default_url() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL);
}
#[test]
fn test_custom_url() {
let config = ProviderConfig {
api_base_url: Some("http://custom:8080".to_string()),
..Default::default()
};
let provider = OllamaProvider::new(config);
assert_eq!(provider.base_url, "http://custom:8080");
}
#[test]
fn test_is_running_initial() {
let config = ProviderConfig::default();
let provider = OllamaProvider::new(config);
assert!(!provider.is_running());
}
#[test]
fn test_default_ollama_models() {
let models = default_ollama_models();
assert!(!models.is_empty());
// All models should be free (local)
for model in &models {
assert_eq!(model.input_cost_per_mtok, Some(0.0));
assert_eq!(model.output_cost_per_mtok, Some(0.0));
}
// Should include llama
assert!(models.iter().any(|m| m.id.contains("llama")));
}
#[test]
fn test_ollama_message_serialization() {
let msg = OllamaMessage {
role: "user".to_string(),
content: "Hello!".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"content\":\"Hello!\""));
}
#[test]
fn test_chat_request_serialization() {
let request = OllamaChatRequest {
model: "llama3.2".to_string(),
messages: vec![OllamaMessage {
role: "user".to_string(),
content: "Test".to_string(),
}],
stream: true,
system: Some("You are helpful".to_string()),
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"model\":\"llama3.2\""));
assert!(json.contains("\"stream\":true"));
assert!(json.contains("\"system\":\"You are helpful\""));
}
#[test]
fn test_chat_request_without_system() {
let request = OllamaChatRequest {
model: "llama3.2".to_string(),
messages: vec![],
stream: true,
system: None,
};
let json = serde_json::to_string(&request).unwrap();
// system should be omitted when None
assert!(!json.contains("\"system\""));
}
}
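`fetch_available_models` parses the `/api/tags` payload into `OllamaTagsResponse`; a sketch of the assumed shape that could join the tests above (extra fields in the real payload are simply ignored):

```rust
#[test]
fn test_tags_response_deserialization() {
    let json = r#"{"models":[{"name":"llama3.2","details":{"parameter_size":"3B","family":"llama"}}]}"#;
    let tags: OllamaTagsResponse = serde_json::from_str(json).unwrap();
    assert_eq!(tags.models[0].name, "llama3.2");
    let details = tags.models[0].details.as_ref().unwrap();
    assert_eq!(details.parameter_size.as_deref(), Some("3B"));
    assert_eq!(details.family.as_deref(), Some("llama"));
}
```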
+236
@@ -0,0 +1,236 @@
// These types are all used by the provider abstraction layer.
// They will be actively used once providers are fully integrated with BridgeManager.
#![allow(dead_code)]
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
pub api_key: Option<String>,
pub api_base_url: Option<String>,
pub model: Option<String>,
pub custom_instructions: Option<String>,
pub working_directory: String,
pub mcp_servers_json: Option<String>,
pub allowed_tools: Vec<String>,
#[serde(default)]
pub extra_options: HashMap<String, serde_json::Value>,
}
impl Default for ProviderConfig {
fn default() -> Self {
Self {
api_key: None,
api_base_url: None,
model: None,
custom_instructions: None,
working_directory: String::new(),
mcp_servers_json: None,
allowed_tools: Vec::new(),
extra_options: HashMap::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCapabilities {
pub supports_streaming: bool,
pub supports_tools: bool,
pub supports_vision: bool,
pub supports_thinking: bool,
pub supports_mcp: bool,
pub supports_resume_session: bool,
pub max_context_tokens: Option<u64>,
pub available_models: Vec<ModelInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub id: String,
pub name: String,
pub description: Option<String>,
pub context_window: Option<u64>,
pub input_cost_per_mtok: Option<f64>,
pub output_cost_per_mtok: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ProviderMessage {
#[serde(rename = "text")]
Text { content: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
is_error: bool,
},
#[serde(rename = "thinking")]
Thinking { content: String },
#[serde(rename = "image")]
Image { media_type: String, data: String },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderUsage {
pub input_tokens: u64,
pub output_tokens: u64,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ProviderStreamEvent {
#[serde(rename = "connected")]
Connected { session_id: Option<String> },
#[serde(rename = "text_delta")]
TextDelta { text: String },
#[serde(rename = "thinking_delta")]
ThinkingDelta { text: String },
#[serde(rename = "tool_use_start")]
ToolUseStart { id: String, name: String },
#[serde(rename = "tool_use_delta")]
ToolUseDelta { id: String, input_delta: String },
#[serde(rename = "tool_use_end")]
ToolUseEnd { id: String, input: serde_json::Value },
#[serde(rename = "message_complete")]
MessageComplete {
content: Vec<ProviderMessage>,
usage: Option<ProviderUsage>,
stop_reason: Option<String>,
},
#[serde(rename = "permission_request")]
PermissionRequest {
id: String,
tool_name: String,
tool_input: serde_json::Value,
description: String,
},
#[serde(rename = "question")]
Question {
id: String,
question: String,
header: Option<String>,
options: Vec<QuestionOption>,
multi_select: bool,
},
#[serde(rename = "working_directory")]
WorkingDirectory { path: String },
#[serde(rename = "error")]
Error { message: String },
#[serde(rename = "disconnected")]
Disconnected,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
pub label: String,
pub description: Option<String>,
}
pub type StreamCallback = Box<dyn Fn(ProviderStreamEvent) + Send + Sync>;
#[async_trait]
pub trait LlmProvider: Send + Sync {
fn name(&self) -> &'static str;
fn capabilities(&self) -> ProviderCapabilities;
async fn start(&mut self, callback: StreamCallback) -> Result<(), String>;
async fn stop(&mut self) -> Result<(), String>;
fn is_running(&self) -> bool;
async fn send_message(&mut self, message: &str) -> Result<(), String>;
async fn send_tool_result(
&mut self,
tool_use_id: &str,
result: serde_json::Value,
is_error: bool,
) -> Result<(), String>;
fn get_model(&self) -> Option<String>;
fn get_working_directory(&self) -> &str;
async fn interrupt(&mut self) -> Result<(), String>;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_config_default() {
let config = ProviderConfig::default();
assert!(config.api_key.is_none());
assert!(config.api_base_url.is_none());
assert!(config.model.is_none());
assert!(config.working_directory.is_empty());
assert!(config.allowed_tools.is_empty());
}
#[test]
fn test_provider_message_serialization() {
let msg = ProviderMessage::Text {
content: "Hello!".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"type\":\"text\""));
assert!(json.contains("\"content\":\"Hello!\""));
}
#[test]
fn test_provider_stream_event_serialization() {
let event = ProviderStreamEvent::TextDelta {
text: "chunk".to_string(),
};
let json = serde_json::to_string(&event).unwrap();
assert!(json.contains("\"type\":\"text_delta\""));
let event = ProviderStreamEvent::Connected {
session_id: Some("test-123".to_string()),
};
let json = serde_json::to_string(&event).unwrap();
assert!(json.contains("\"session_id\":\"test-123\""));
}
#[test]
fn test_model_info() {
let model = ModelInfo {
id: "claude-sonnet-4-20250514".to_string(),
name: "Claude Sonnet 4".to_string(),
description: Some("Fast and intelligent".to_string()),
context_window: Some(200000),
input_cost_per_mtok: Some(3.0),
output_cost_per_mtok: Some(15.0),
};
let json = serde_json::to_string(&model).unwrap();
assert!(json.contains("claude-sonnet-4"));
assert!(json.contains("200000"));
}
#[test]
fn test_provider_usage() {
let usage = ProviderUsage {
input_tokens: 100,
output_tokens: 50,
model: "claude-opus-4-20250514".to_string(),
};
let json = serde_json::to_string(&usage).unwrap();
assert!(json.contains("\"input_tokens\":100"));
assert!(json.contains("\"output_tokens\":50"));
}
}
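`StreamCallback` is just a boxed `Fn` that must be `Send + Sync`, so an `Arc<Mutex<Vec<_>>>` collector is enough to consume events in tests; a sketch of that pattern:

```rust
use std::sync::{Arc, Mutex};

#[test]
fn test_collector_callback() {
    let events: Arc<Mutex<Vec<ProviderStreamEvent>>> = Arc::new(Mutex::new(Vec::new()));
    let sink = events.clone();
    let callback: StreamCallback = Box::new(move |event| sink.lock().unwrap().push(event));

    callback(ProviderStreamEvent::TextDelta { text: "chunk".to_string() });
    assert_eq!(events.lock().unwrap().len(), 1);
}
```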