diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index cfcea77..1084e1a 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -592,7 +592,7 @@ dependencies = [ "bitflags 2.10.0", "core-foundation 0.10.1", "core-graphics-types", - "foreign-types", + "foreign-types 0.5.0", "libc", ] @@ -1080,6 +1080,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] + [[package]] name = "foreign-types" version = "0.5.0" @@ -1087,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared", + "foreign-types-shared 0.3.1", ] [[package]] @@ -1101,6 +1110,12 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -1604,8 +1619,11 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" name = "hikari-desktop" version = "1.1.1" dependencies = [ + "async-trait", "chrono", + "futures-util", "parking_lot", + "reqwest", "semver", "serde", "serde_json", @@ -1716,6 +1734,22 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.19" @@ -2288,6 +2322,23 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndk" version = "0.9.0" @@ -2661,6 +2712,50 @@ dependencies = [ "pathdiff", ] +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -3417,10 +3512,12 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", + "hyper-tls", "hyper-util", "js-sys", "log", "mime", + "native-tls", "percent-encoding", "pin-project-lite", "quinn", @@ -3431,6 +3528,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tokio-rustls", "tokio-util", "tower", @@ -3566,6 +3664,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "schemars" version = "0.8.22" @@ -3623,6 +3730,29 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "selectors" version = "0.24.0" @@ -4746,6 +4876,16 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -5111,6 +5251,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version-compare" version = "0.2.1" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index f2c1517..c70b8f1 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -31,6 +31,9 @@ tauri-plugin-fs = "2" tempfile = "3" semver = "1" chrono = { version = "0.4.43", features = ["serde"] } +async-trait = "0.1" +reqwest = { version = "0.12", features = ["json", "stream"] } +futures-util = "0.3" [target.'cfg(windows)'.dependencies] windows = { version = "0.62", features = [ diff --git a/src-tauri/src/bridge_manager.rs b/src-tauri/src/bridge_manager.rs index 4a2370b..e08aacb 100644 --- a/src-tauri/src/bridge_manager.rs +++ b/src-tauri/src/bridge_manager.rs @@ -4,11 +4,11 @@ use std::sync::Arc; use tauri::AppHandle; use crate::config::ClaudeStartOptions; +use crate::provider_bridge::ProviderBridge; use crate::stats::UsageStats; -use crate::wsl_bridge::WslBridge; pub struct BridgeManager { - bridges: HashMap, + bridges: HashMap, app_handle: Option, } @@ -45,12 +45,24 @@ impl BridgeManager { .ok_or_else(|| "App handle not set".to_string())? .clone(); + // Check if existing bridge matches the requested provider type + // If provider type changed, create a new bridge + let should_recreate = self.bridges.get(conversation_id).map_or(false, |bridge| { + bridge.provider_type() != options.provider_type + }); + + if should_recreate { + // Remove existing bridge if provider type changed + self.bridges.remove(conversation_id); + } + // Reuse existing bridge if it exists (preserves stats across reconnects) // Only create a new bridge if one doesn't exist for this conversation + let provider_type = options.provider_type; let bridge = self .bridges .entry(conversation_id.to_string()) - .or_insert_with(|| WslBridge::new_with_conversation_id(conversation_id.to_string())); + .or_insert_with(|| ProviderBridge::new(provider_type, conversation_id.to_string())); // Start the Claude process bridge.start(app, options)?; diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index 6ca733e..baae2dd 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -1,7 +1,11 @@ +use crate::providers::ProviderType; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct ClaudeStartOptions { + #[serde(default)] + pub provider_type: ProviderType, + #[serde(default)] pub working_dir: String, @@ -25,10 +29,20 @@ pub struct ClaudeStartOptions { #[serde(default)] pub resume_session_id: Option, + + // Ollama-specific options + #[serde(default = "default_ollama_base_url")] + pub ollama_base_url: String, + + #[serde(default)] + pub ollama_model: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HikariConfig { + #[serde(default)] + pub provider_type: ProviderType, + #[serde(default)] pub model: Option, @@ -44,6 +58,13 @@ pub struct HikariConfig { #[serde(default)] pub auto_granted_tools: Vec, + // Ollama-specific settings + #[serde(default = "default_ollama_base_url")] + pub ollama_base_url: String, + + #[serde(default)] + pub ollama_model: Option, + #[serde(default)] pub theme: Theme, @@ -101,11 +122,14 @@ pub struct HikariConfig { impl Default for HikariConfig { fn default() -> Self { Self { + provider_type: ProviderType::default(), model: None, api_key: None, custom_instructions: None, mcp_servers_json: None, auto_granted_tools: Vec::new(), + ollama_base_url: default_ollama_base_url(), + ollama_model: None, theme: Theme::default(), greeting_enabled: true, greeting_custom_prompt: None, @@ -147,6 +171,10 @@ fn default_font_size() -> u32 { 14 } +fn default_ollama_base_url() -> String { + "http://localhost:11434".to_string() +} + #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Theme { @@ -185,11 +213,14 @@ mod tests { #[test] fn test_default_config() { let config = HikariConfig::default(); + assert_eq!(config.provider_type, ProviderType::ClaudeCli); assert!(config.model.is_none()); assert!(config.api_key.is_none()); assert!(config.custom_instructions.is_none()); assert!(config.mcp_servers_json.is_none()); assert!(config.auto_granted_tools.is_empty()); + assert_eq!(config.ollama_base_url, "http://localhost:11434"); + assert!(config.ollama_model.is_none()); assert_eq!(config.theme, Theme::Dark); assert!(config.greeting_enabled); assert!(config.greeting_custom_prompt.is_none()); @@ -210,11 +241,14 @@ mod tests { #[test] fn test_config_serialization() { let config = HikariConfig { + provider_type: ProviderType::ClaudeCli, model: Some("claude-sonnet-4-20250514".to_string()), api_key: None, custom_instructions: Some("Be helpful".to_string()), mcp_servers_json: None, auto_granted_tools: vec!["Read".to_string(), "Glob".to_string()], + ollama_base_url: "http://localhost:11434".to_string(), + ollama_model: None, theme: Theme::Light, greeting_enabled: true, greeting_custom_prompt: Some("Hello!".to_string()), diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 6db2820..793cda9 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -5,6 +5,8 @@ mod commands; mod config; mod git; mod notifications; +mod provider_bridge; +mod providers; mod quick_actions; mod sessions; mod snippets; diff --git a/src-tauri/src/provider_bridge.rs b/src-tauri/src/provider_bridge.rs new file mode 100644 index 0000000..c71e529 --- /dev/null +++ b/src-tauri/src/provider_bridge.rs @@ -0,0 +1,538 @@ +// Provider bridge that unifies different LLM provider implementations +// for use with the BridgeManager. + +use parking_lot::RwLock; +use std::sync::Arc; +use tauri::{AppHandle, Emitter}; + +use crate::achievements::{get_achievement_info, AchievementUnlockedEvent}; +use crate::config::ClaudeStartOptions; +use crate::providers::ProviderType; +use crate::stats::{StatsUpdateEvent, UsageStats}; +use crate::types::{ + CharacterState, ConnectionEvent, ConnectionStatus, OutputEvent, StateChangeEvent, +}; +use crate::wsl_bridge::WslBridge; + +/// A unified bridge that can wrap either the Claude CLI or Ollama provider +pub enum ProviderBridge { + ClaudeCli(WslBridge), + Ollama(OllamaBridge), +} + +impl ProviderBridge { + pub fn new(provider_type: ProviderType, conversation_id: String) -> Self { + match provider_type { + ProviderType::ClaudeCli => { + ProviderBridge::ClaudeCli(WslBridge::new_with_conversation_id(conversation_id)) + } + ProviderType::Ollama => { + ProviderBridge::Ollama(OllamaBridge::new(conversation_id)) + } + } + } + + pub fn start(&mut self, app: AppHandle, options: ClaudeStartOptions) -> Result<(), String> { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.start(app, options), + ProviderBridge::Ollama(bridge) => bridge.start(app, options), + } + } + + pub fn stop(&mut self, app: &AppHandle) { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.stop(app), + ProviderBridge::Ollama(bridge) => bridge.stop(app), + } + } + + pub fn interrupt(&mut self, app: &AppHandle) -> Result<(), String> { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.interrupt(app), + ProviderBridge::Ollama(bridge) => bridge.interrupt(app), + } + } + + pub fn send_message(&mut self, message: &str) -> Result<(), String> { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.send_message(message), + ProviderBridge::Ollama(bridge) => bridge.send_message(message), + } + } + + pub fn send_tool_result( + &mut self, + tool_use_id: &str, + result: serde_json::Value, + ) -> Result<(), String> { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.send_tool_result(tool_use_id, result), + ProviderBridge::Ollama(bridge) => bridge.send_tool_result(tool_use_id, result), + } + } + + pub fn is_running(&self) -> bool { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.is_running(), + ProviderBridge::Ollama(bridge) => bridge.is_running(), + } + } + + pub fn get_working_directory(&self) -> &str { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.get_working_directory(), + ProviderBridge::Ollama(bridge) => bridge.get_working_directory(), + } + } + + pub fn get_stats(&self) -> UsageStats { + match self { + ProviderBridge::ClaudeCli(bridge) => bridge.get_stats(), + ProviderBridge::Ollama(bridge) => bridge.get_stats(), + } + } + + pub fn provider_type(&self) -> ProviderType { + match self { + ProviderBridge::ClaudeCli(_) => ProviderType::ClaudeCli, + ProviderBridge::Ollama(_) => ProviderType::Ollama, + } + } +} + +/// Bridge for Ollama local LLM provider +pub struct OllamaBridge { + conversation_id: String, + working_directory: String, + is_running: bool, + stats: Arc>, + base_url: String, + model: String, + messages: Vec, + cancel_flag: Arc, + app_handle: Option, +} + +#[derive(Clone, serde::Serialize, serde::Deserialize)] +struct OllamaMessage { + role: String, + content: String, +} + +#[derive(serde::Deserialize)] +struct OllamaChatResponse { + message: Option, + done: bool, + #[serde(default)] + eval_count: Option, + #[serde(default)] + prompt_eval_count: Option, +} + +#[derive(serde::Deserialize)] +struct OllamaResponseMessage { + content: String, +} + +impl OllamaBridge { + pub fn new(conversation_id: String) -> Self { + OllamaBridge { + conversation_id, + working_directory: String::new(), + is_running: false, + stats: Arc::new(RwLock::new(UsageStats::new())), + base_url: "http://localhost:11434".to_string(), + model: "llama3.2".to_string(), + messages: Vec::new(), + cancel_flag: Arc::new(std::sync::atomic::AtomicBool::new(false)), + app_handle: None, + } + } + + pub fn start(&mut self, app: AppHandle, options: ClaudeStartOptions) -> Result<(), String> { + if self.is_running { + return Err("Ollama bridge already running".to_string()); + } + + self.working_directory = options.working_dir.clone(); + self.base_url = options.ollama_base_url.clone(); + self.model = options.ollama_model.clone().unwrap_or_else(|| "llama3.2".to_string()); + self.app_handle = Some(app.clone()); + self.cancel_flag.store(false, std::sync::atomic::Ordering::SeqCst); + + // Reset messages for new conversation + self.messages.clear(); + + // Add system prompt if custom instructions are provided + if let Some(ref instructions) = options.custom_instructions { + if !instructions.is_empty() { + self.messages.push(OllamaMessage { + role: "system".to_string(), + content: instructions.clone(), + }); + } + } + + emit_connection_status(&app, ConnectionStatus::Connecting, Some(self.conversation_id.clone())); + + // Check if Ollama is running + let base_url = self.base_url.clone(); + let app_clone = app.clone(); + let conv_id = self.conversation_id.clone(); + + let rt = tokio::runtime::Handle::current(); + let check_result = rt.block_on(async { + check_ollama_running(&base_url).await + }); + + match check_result { + Ok(true) => { + self.is_running = true; + emit_connection_status(&app, ConnectionStatus::Connected, Some(self.conversation_id.clone())); + emit_state_change(&app, CharacterState::Idle, None, Some(self.conversation_id.clone())); + Ok(()) + } + Ok(false) => { + emit_connection_status(&app_clone, ConnectionStatus::Disconnected, Some(conv_id)); + Err("Ollama is not running. Please start Ollama first.".to_string()) + } + Err(e) => { + emit_connection_status(&app_clone, ConnectionStatus::Disconnected, Some(conv_id)); + Err(format!("Failed to connect to Ollama: {}", e)) + } + } + } + + pub fn stop(&mut self, app: &AppHandle) { + self.cancel_flag.store(true, std::sync::atomic::Ordering::SeqCst); + self.is_running = false; + self.app_handle = None; + + // Save stats before stopping + let stats_snapshot = self.stats.read().clone(); + let app_clone = app.clone(); + tauri::async_runtime::spawn(async move { + if let Err(e) = crate::stats::save_stats(&app_clone, &stats_snapshot).await { + eprintln!("Failed to save stats: {}", e); + } + }); + + self.stats.write().reset_session(); + emit_connection_status(app, ConnectionStatus::Disconnected, Some(self.conversation_id.clone())); + } + + pub fn interrupt(&mut self, app: &AppHandle) -> Result<(), String> { + self.cancel_flag.store(true, std::sync::atomic::Ordering::SeqCst); + self.is_running = false; + emit_connection_status(app, ConnectionStatus::Disconnected, Some(self.conversation_id.clone())); + Ok(()) + } + + pub fn send_message(&mut self, message: &str) -> Result<(), String> { + if !self.is_running { + return Err("Ollama bridge not running".to_string()); + } + + let app = self.app_handle.clone().ok_or("App handle not available")?; + + // Add user message to history + self.messages.push(OllamaMessage { + role: "user".to_string(), + content: message.to_string(), + }); + + // Increment message count + self.stats.write().increment_messages(); + + emit_state_change(&app, CharacterState::Thinking, None, Some(self.conversation_id.clone())); + + // Prepare request + let base_url = self.base_url.clone(); + let model = self.model.clone(); + let messages = self.messages.clone(); + let stats = self.stats.clone(); + let conv_id = self.conversation_id.clone(); + let cancel_flag = self.cancel_flag.clone(); + + // Spawn async task to handle the chat + tauri::async_runtime::spawn(async move { + match stream_ollama_chat(&base_url, &model, &messages, &app, &stats, &conv_id, &cancel_flag).await { + Ok(_response) => { + // Response handling is done in stream_ollama_chat + } + Err(e) => { + let _ = app.emit( + "claude:output", + OutputEvent { + line_type: "error".to_string(), + content: format!("Ollama error: {}", e), + tool_name: None, + conversation_id: Some(conv_id.clone()), + }, + ); + emit_state_change(&app, CharacterState::Error, None, Some(conv_id)); + } + } + }); + + Ok(()) + } + + pub fn send_tool_result( + &mut self, + _tool_use_id: &str, + _result: serde_json::Value, + ) -> Result<(), String> { + // Ollama doesn't support tools in the same way as Claude + // For now, we just ignore tool results + Err("Tool results are not supported with Ollama provider".to_string()) + } + + pub fn is_running(&self) -> bool { + self.is_running + } + + pub fn get_working_directory(&self) -> &str { + &self.working_directory + } + + pub fn get_stats(&self) -> UsageStats { + self.stats.read().clone() + } +} + +async fn check_ollama_running(base_url: &str) -> Result { + let client = reqwest::Client::new(); + let url = format!("{}/api/tags", base_url); + + match client.get(&url).timeout(std::time::Duration::from_secs(5)).send().await { + Ok(response) => Ok(response.status().is_success()), + Err(e) => { + if e.is_connect() { + Ok(false) + } else { + Err(e.to_string()) + } + } + } +} + +async fn stream_ollama_chat( + base_url: &str, + model: &str, + messages: &[OllamaMessage], + app: &AppHandle, + stats: &Arc>, + conversation_id: &str, + cancel_flag: &Arc, +) -> Result { + let client = reqwest::Client::new(); + let url = format!("{}/api/chat", base_url); + + emit_state_change(app, CharacterState::Typing, None, Some(conversation_id.to_string())); + + let request_body = serde_json::json!({ + "model": model, + "messages": messages, + "stream": true + }); + + let response = client + .post(&url) + .json(&request_body) + .send() + .await + .map_err(|e| format!("Failed to send request to Ollama: {}", e))?; + + if !response.status().is_success() { + return Err(format!("Ollama returned error status: {}", response.status())); + } + + let mut full_response = String::new(); + let mut total_tokens: u64 = 0; + let mut prompt_tokens: u64 = 0; + + // Stream the response + let mut stream = response.bytes_stream(); + use futures_util::StreamExt; + + while let Some(chunk_result) = stream.next().await { + // Check for cancellation + if cancel_flag.load(std::sync::atomic::Ordering::SeqCst) { + break; + } + + match chunk_result { + Ok(chunk) => { + let chunk_str = String::from_utf8_lossy(&chunk); + + // Parse each line as JSON + for line in chunk_str.lines() { + if line.trim().is_empty() { + continue; + } + + match serde_json::from_str::(line) { + Ok(response) => { + if let Some(msg) = response.message { + full_response.push_str(&msg.content); + + // Emit stream event + let _ = app.emit("claude:stream", msg.content); + } + + // Capture token counts + if let Some(eval) = response.eval_count { + total_tokens = eval; + } + if let Some(prompt) = response.prompt_eval_count { + prompt_tokens = prompt; + } + + if response.done { + break; + } + } + Err(e) => { + eprintln!("Failed to parse Ollama response: {} - Line: {}", e, line); + } + } + } + } + Err(e) => { + return Err(format!("Error reading stream: {}", e)); + } + } + } + + // Update stats with token usage + { + let mut stats_guard = stats.write(); + stats_guard.increment_messages(); + // Ollama reports output tokens as eval_count, input as prompt_eval_count + stats_guard.add_usage(prompt_tokens, total_tokens, model); + + // Count code blocks in the response + let code_blocks = full_response.matches("```").count() / 2; + for _ in 0..code_blocks { + stats_guard.increment_code_blocks(); + } + } + + // Emit the complete response + let _ = app.emit( + "claude:output", + OutputEvent { + line_type: "assistant".to_string(), + content: full_response.clone(), + tool_name: None, + conversation_id: Some(conversation_id.to_string()), + }, + ); + + // Check achievements + let newly_unlocked = { + let mut stats_guard = stats.write(); + stats_guard.check_achievements() + }; + + // Emit achievement events + for achievement_id in &newly_unlocked { + let info = get_achievement_info(achievement_id); + let _ = app.emit( + "achievement:unlocked", + AchievementUnlockedEvent { achievement: info }, + ); + } + + // Save achievements if any were unlocked + if !newly_unlocked.is_empty() { + let app_clone = app.clone(); + let achievements_progress = stats.read().achievements.clone(); + tauri::async_runtime::spawn(async move { + if let Err(e) = crate::achievements::save_achievements(&app_clone, &achievements_progress).await { + eprintln!("Failed to save achievements: {}", e); + } + }); + } + + // Emit stats update + let current_stats = stats.read().clone(); + let _ = app.emit("claude:stats", StatsUpdateEvent { stats: current_stats }); + + emit_state_change(app, CharacterState::Success, None, Some(conversation_id.to_string())); + + Ok(full_response) +} + +fn emit_state_change( + app: &AppHandle, + state: CharacterState, + tool_name: Option, + conversation_id: Option, +) { + let _ = app.emit( + "claude:state", + StateChangeEvent { + state, + tool_name, + conversation_id, + }, + ); +} + +fn emit_connection_status( + app: &AppHandle, + status: ConnectionStatus, + conversation_id: Option, +) { + let _ = app.emit( + "claude:connection", + ConnectionEvent { + status, + conversation_id, + }, + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_bridge_ollama_new() { + let bridge = OllamaBridge::new("test-conv".to_string()); + assert!(!bridge.is_running()); + assert_eq!(bridge.get_working_directory(), ""); + assert_eq!(bridge.conversation_id, "test-conv"); + } + + #[test] + fn test_provider_bridge_type() { + let claude = ProviderBridge::new(ProviderType::ClaudeCli, "conv1".to_string()); + assert_eq!(claude.provider_type(), ProviderType::ClaudeCli); + + let ollama = ProviderBridge::new(ProviderType::Ollama, "conv2".to_string()); + assert_eq!(ollama.provider_type(), ProviderType::Ollama); + } + + #[test] + fn test_provider_bridge_is_running_initial() { + let bridge = ProviderBridge::new(ProviderType::ClaudeCli, "test".to_string()); + assert!(!bridge.is_running()); + + let bridge = ProviderBridge::new(ProviderType::Ollama, "test".to_string()); + assert!(!bridge.is_running()); + } + + #[test] + fn test_ollama_message_serialization() { + let msg = OllamaMessage { + role: "user".to_string(), + content: "Hello!".to_string(), + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("\"content\":\"Hello!\"")); + } +} diff --git a/src-tauri/src/providers/claude_cli.rs b/src-tauri/src/providers/claude_cli.rs new file mode 100644 index 0000000..3eaa15f --- /dev/null +++ b/src-tauri/src/providers/claude_cli.rs @@ -0,0 +1,773 @@ +// This provider wraps the Claude CLI subprocess. +// It will be actively used once providers are fully integrated with BridgeManager. +#![allow(dead_code)] + +use async_trait::async_trait; +use std::io::{BufRead, BufReader, Write}; +use std::process::{Child, ChildStdin, Command, Stdio}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; +use tempfile::NamedTempFile; + +use crate::providers::traits::{ + LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderStreamEvent, + QuestionOption, StreamCallback, +}; + +#[cfg(target_os = "windows")] +use std::os::windows::process::CommandExt; + +fn detect_wsl() -> bool { + if let Ok(version) = std::fs::read_to_string("/proc/version") { + let version_lower = version.to_lowercase(); + if version_lower.contains("microsoft") || version_lower.contains("wsl") { + return true; + } + } + + if std::path::Path::new("/proc/sys/fs/binfmt_misc/WSLInterop").exists() { + return true; + } + + if std::env::var("WSL_DISTRO_NAME").is_ok() { + return true; + } + + false +} + +fn find_claude_binary() -> Option { + let home = std::env::var("HOME").ok()?; + let paths_to_check = [ + format!("{}/.local/bin/claude", home), + format!("{}/.claude/local/claude", home), + "/usr/local/bin/claude".to_string(), + "/usr/bin/claude".to_string(), + ]; + + for path in &paths_to_check { + if std::path::Path::new(path).exists() { + return Some(path.clone()); + } + } + + if let Ok(output) = Command::new("which").arg("claude").output() { + if output.status.success() { + let path = String::from_utf8_lossy(&output.stdout).trim().to_string(); + if !path.is_empty() { + return Some(path); + } + } + } + + None +} + +pub struct ClaudeCliProvider { + config: ProviderConfig, + process: Option, + stdin: Option, + session_id: Option, + mcp_config_file: Option, + is_running: Arc, +} + +impl ClaudeCliProvider { + pub fn new(config: ProviderConfig) -> Self { + Self { + config, + process: None, + stdin: None, + session_id: None, + mcp_config_file: None, + is_running: Arc::new(AtomicBool::new(false)), + } + } + + fn build_command(&mut self) -> Result { + let mcp_config_path = if let Some(ref mcp_json) = self.config.mcp_servers_json { + if !mcp_json.trim().is_empty() { + serde_json::from_str::(mcp_json) + .map_err(|e| format!("Invalid MCP servers JSON: {}", e))?; + + let mut temp_file = NamedTempFile::new() + .map_err(|e| format!("Failed to create temp file for MCP config: {}", e))?; + temp_file + .write_all(mcp_json.as_bytes()) + .map_err(|e| format!("Failed to write MCP config: {}", e))?; + temp_file + .flush() + .map_err(|e| format!("Failed to flush MCP config: {}", e))?; + + let path = temp_file.path().to_string_lossy().to_string(); + self.mcp_config_file = Some(temp_file); + Some(path) + } else { + None + } + } else { + None + }; + + let is_wsl = detect_wsl(); + let working_dir = &self.config.working_directory; + + let command = if is_wsl { + let claude_path = find_claude_binary().ok_or_else(|| { + "Could not find claude binary. Is Claude Code installed?".to_string() + })?; + + let mut cmd = Command::new(&claude_path); + cmd.args([ + "--output-format", + "stream-json", + "--input-format", + "stream-json", + "--verbose", + ]); + + if let Some(ref model) = self.config.model { + if !model.is_empty() { + cmd.args(["--model", model]); + } + } + + for tool in &self.config.allowed_tools { + cmd.args(["--allowedTools", tool]); + } + + if let Some(ref instructions) = self.config.custom_instructions { + if !instructions.is_empty() { + cmd.args(["--system-prompt", instructions]); + } + } + + if let Some(ref mcp_path) = mcp_config_path { + cmd.args(["--mcp-config", mcp_path]); + } + + cmd.current_dir(working_dir); + + if let Some(ref api_key) = self.config.api_key { + if !api_key.is_empty() { + cmd.env("ANTHROPIC_API_KEY", api_key); + } + } + + cmd + } else { + let mut cmd = Command::new("wsl"); + + let mut claude_cmd = format!("cd '{}' && ", working_dir); + + if let Some(ref api_key) = self.config.api_key { + if !api_key.is_empty() { + claude_cmd.push_str(&format!("ANTHROPIC_API_KEY='{}' ", api_key)); + } + } + + claude_cmd.push_str( + "claude --output-format stream-json --input-format stream-json --verbose", + ); + + if let Some(ref model) = self.config.model { + if !model.is_empty() { + claude_cmd.push_str(&format!(" --model '{}'", model)); + } + } + + for tool in &self.config.allowed_tools { + claude_cmd.push_str(&format!(" --allowedTools '{}'", tool)); + } + + if let Some(ref instructions) = self.config.custom_instructions { + if !instructions.is_empty() { + let escaped = instructions.replace('\'', "'\\''"); + claude_cmd.push_str(&format!(" --system-prompt '{}'", escaped)); + } + } + + if let Some(ref mcp_path) = mcp_config_path { + claude_cmd.push_str(&format!(" --mcp-config '{}'", mcp_path)); + } + + cmd.args(["-e", "bash", "-lc", &claude_cmd]); + + #[cfg(target_os = "windows")] + cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW + + cmd + }; + + Ok(command) + } +} + +#[async_trait] +impl LlmProvider for ClaudeCliProvider { + fn name(&self) -> &'static str { + "Claude CLI" + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + supports_streaming: true, + supports_tools: true, + supports_vision: true, + supports_thinking: true, + supports_mcp: true, + supports_resume_session: true, + max_context_tokens: Some(200_000), + available_models: vec![ + ModelInfo { + id: "claude-sonnet-4-20250514".to_string(), + name: "Claude Sonnet 4".to_string(), + description: Some("Fast and intelligent".to_string()), + context_window: Some(200_000), + input_cost_per_mtok: Some(3.0), + output_cost_per_mtok: Some(15.0), + }, + ModelInfo { + id: "claude-opus-4-20250514".to_string(), + name: "Claude Opus 4".to_string(), + description: Some("Most capable model".to_string()), + context_window: Some(200_000), + input_cost_per_mtok: Some(15.0), + output_cost_per_mtok: Some(75.0), + }, + ModelInfo { + id: "claude-3-5-haiku-20241022".to_string(), + name: "Claude 3.5 Haiku".to_string(), + description: Some("Fast and efficient".to_string()), + context_window: Some(200_000), + input_cost_per_mtok: Some(1.0), + output_cost_per_mtok: Some(5.0), + }, + ], + } + } + + async fn start(&mut self, callback: StreamCallback) -> Result<(), String> { + if self.process.is_some() { + return Err("Process already running".to_string()); + } + + callback(ProviderStreamEvent::Connected { session_id: None }); + + let mut command = self.build_command()?; + + command + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let mut child = command + .spawn() + .map_err(|e| format!("Failed to spawn process: {}", e))?; + + let stdin = child.stdin.take(); + let stdout = child.stdout.take(); + let stderr = child.stderr.take(); + + self.stdin = stdin; + self.process = Some(child); + self.is_running.store(true, Ordering::SeqCst); + + let is_running = self.is_running.clone(); + let callback = Arc::new(callback); + + if let Some(stdout) = stdout { + let callback_clone = callback.clone(); + let is_running_clone = is_running.clone(); + thread::spawn(move || { + handle_stdout(stdout, callback_clone, is_running_clone); + }); + } + + if let Some(stderr) = stderr { + let callback_clone = callback.clone(); + thread::spawn(move || { + handle_stderr(stderr, callback_clone); + }); + } + + Ok(()) + } + + async fn stop(&mut self) -> Result<(), String> { + if let Some(mut process) = self.process.take() { + let _ = process.kill(); + let _ = process.wait(); + } + self.stdin = None; + self.session_id = None; + self.mcp_config_file = None; + self.is_running.store(false, Ordering::SeqCst); + Ok(()) + } + + fn is_running(&self) -> bool { + self.is_running.load(Ordering::SeqCst) + } + + async fn send_message(&mut self, message: &str) -> Result<(), String> { + let stdin = self.stdin.as_mut().ok_or("Process not running")?; + + let input = serde_json::json!({ + "type": "user", + "message": { + "role": "user", + "content": [{ + "type": "text", + "text": message + }] + } + }); + + let json_line = serde_json::to_string(&input).map_err(|e| e.to_string())?; + + stdin + .write_all(format!("{}\n", json_line).as_bytes()) + .map_err(|e| format!("Failed to write to stdin: {}", e))?; + + stdin + .flush() + .map_err(|e| format!("Failed to flush stdin: {}", e))?; + + Ok(()) + } + + async fn send_tool_result( + &mut self, + tool_use_id: &str, + result: serde_json::Value, + _is_error: bool, + ) -> Result<(), String> { + let stdin = self.stdin.as_mut().ok_or("Process not running")?; + + let content_str = serde_json::to_string(&result).map_err(|e| e.to_string())?; + + let input = serde_json::json!({ + "type": "user", + "message": { + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content_str + }] + } + }); + + let json_line = serde_json::to_string(&input).map_err(|e| e.to_string())?; + + stdin + .write_all(format!("{}\n", json_line).as_bytes()) + .map_err(|e| format!("Failed to write to stdin: {}", e))?; + + stdin + .flush() + .map_err(|e| format!("Failed to flush stdin: {}", e))?; + + Ok(()) + } + + fn get_model(&self) -> Option { + self.config.model.clone() + } + + fn get_working_directory(&self) -> &str { + &self.config.working_directory + } + + async fn interrupt(&mut self) -> Result<(), String> { + if let Some(mut process) = self.process.take() { + let _ = process.kill(); + let _ = process.wait(); + self.stdin = None; + self.is_running.store(false, Ordering::SeqCst); + Ok(()) + } else { + Err("No active process to interrupt".to_string()) + } + } +} + +fn handle_stdout( + stdout: std::process::ChildStdout, + callback: Arc, + is_running: Arc, +) { + let reader = BufReader::new(stdout); + + for line in reader.lines() { + match line { + Ok(line) if !line.is_empty() => { + if let Err(e) = process_json_line(&line, &callback) { + eprintln!("Error processing line: {}", e); + } + } + Err(e) => { + eprintln!("Error reading stdout: {}", e); + break; + } + _ => {} + } + } + + is_running.store(false, Ordering::SeqCst); + callback(ProviderStreamEvent::Disconnected); +} + +fn handle_stderr(stderr: std::process::ChildStderr, callback: Arc) { + let reader = BufReader::new(stderr); + + for line in reader.lines() { + match line { + Ok(line) if !line.is_empty() => { + callback(ProviderStreamEvent::Error { message: line }); + } + Err(_) => break, + _ => {} + } + } +} + +fn process_json_line(line: &str, callback: &StreamCallback) -> Result<(), String> { + let message: serde_json::Value = serde_json::from_str(line) + .map_err(|e| format!("Failed to parse JSON: {} - Line: {}", e, line))?; + + let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or(""); + + match msg_type { + "system" => { + let subtype = message + .get("subtype") + .and_then(|s| s.as_str()) + .unwrap_or(""); + if subtype == "init" { + let session_id = message + .get("session_id") + .and_then(|s| s.as_str()) + .map(|s| s.to_string()); + callback(ProviderStreamEvent::Connected { session_id }); + + if let Some(cwd) = message.get("cwd").and_then(|c| c.as_str()) { + callback(ProviderStreamEvent::WorkingDirectory { + path: cwd.to_string(), + }); + } + } + } + + "assistant" => { + if let Some(msg) = message.get("message") { + if let Some(content) = msg.get("content").and_then(|c| c.as_array()) { + for block in content { + let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or(""); + match block_type { + "text" => { + if let Some(text) = block.get("text").and_then(|t| t.as_str()) { + callback(ProviderStreamEvent::TextDelta { + text: text.to_string(), + }); + } + } + "thinking" => { + if let Some(thinking) = + block.get("thinking").and_then(|t| t.as_str()) + { + callback(ProviderStreamEvent::ThinkingDelta { + text: thinking.to_string(), + }); + } + } + "tool_use" => { + let id = block + .get("id") + .and_then(|i| i.as_str()) + .unwrap_or("") + .to_string(); + let name = block + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or("") + .to_string(); + let input = block + .get("input") + .cloned() + .unwrap_or(serde_json::Value::Null); + + callback(ProviderStreamEvent::ToolUseStart { + id: id.clone(), + name, + }); + callback(ProviderStreamEvent::ToolUseEnd { id, input }); + } + _ => {} + } + } + } + + // Extract usage if available + let usage = msg.get("usage").and_then(|u| { + let input_tokens = u.get("input_tokens").and_then(|t| t.as_u64())?; + let output_tokens = u.get("output_tokens").and_then(|t| t.as_u64())?; + let model = msg + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("unknown") + .to_string(); + Some(crate::providers::ProviderUsage { + input_tokens, + output_tokens, + model, + }) + }); + + callback(ProviderStreamEvent::MessageComplete { + content: vec![], + usage, + stop_reason: None, + }); + } + } + + "result" => { + let subtype = message + .get("subtype") + .and_then(|s| s.as_str()) + .unwrap_or(""); + + // Handle permission denials + if let Some(denials) = message + .get("permission_denials") + .and_then(|d| d.as_array()) + { + for denial in denials { + let tool_name = denial + .get("tool_name") + .and_then(|t| t.as_str()) + .unwrap_or("") + .to_string(); + + // Check if this is an AskUserQuestion + if tool_name == "AskUserQuestion" { + if let Some(tool_input) = denial.get("tool_input") { + if let Some(questions) = + tool_input.get("questions").and_then(|q| q.as_array()) + { + if let Some(first_question) = questions.first() { + let question = first_question + .get("question") + .and_then(|q| q.as_str()) + .unwrap_or("") + .to_string(); + + let header = first_question + .get("header") + .and_then(|h| h.as_str()) + .map(|s| s.to_string()); + + let multi_select = first_question + .get("multiSelect") + .and_then(|m| m.as_bool()) + .unwrap_or(false); + + let options: Vec = first_question + .get("options") + .and_then(|opts| opts.as_array()) + .map(|opts| { + opts.iter() + .filter_map(|opt| { + let label = opt + .get("label") + .and_then(|l| l.as_str())?; + let description = opt + .get("description") + .and_then(|d| d.as_str()) + .map(|s| s.to_string()); + Some(QuestionOption { + label: label.to_string(), + description, + }) + }) + .collect() + }) + .unwrap_or_default(); + + let id = denial + .get("tool_use_id") + .and_then(|i| i.as_str()) + .unwrap_or("") + .to_string(); + + callback(ProviderStreamEvent::Question { + id, + question, + header, + options, + multi_select, + }); + } + } + } + } else { + let id = denial + .get("tool_use_id") + .and_then(|i| i.as_str()) + .unwrap_or("") + .to_string(); + let tool_input = denial + .get("tool_input") + .cloned() + .unwrap_or(serde_json::Value::Null); + let description = format_tool_description(&tool_name, &tool_input); + + callback(ProviderStreamEvent::PermissionRequest { + id, + tool_name, + tool_input, + description, + }); + } + } + } + + if subtype != "success" { + if let Some(result) = message.get("result").and_then(|r| r.as_str()) { + callback(ProviderStreamEvent::Error { + message: result.to_string(), + }); + } + } + } + + _ => {} + } + + Ok(()) +} + +fn format_tool_description(name: &str, input: &serde_json::Value) -> String { + match name { + "Read" => { + if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) { + format!("Reading file: {}", path) + } else { + "Reading file...".to_string() + } + } + "Glob" => { + if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) { + format!("Searching for files: {}", pattern) + } else { + "Searching for files...".to_string() + } + } + "Grep" => { + if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) { + format!("Searching for: {}", pattern) + } else { + "Searching in files...".to_string() + } + } + "Edit" | "Write" => { + if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) { + format!("Editing: {}", path) + } else { + "Editing file...".to_string() + } + } + "Bash" => { + if let Some(cmd) = input.get("command").and_then(|v| v.as_str()) { + let truncated = if cmd.len() > 50 { + format!("{}...", &cmd[..50]) + } else { + cmd.to_string() + }; + format!("Running: {}", truncated) + } else { + "Running command...".to_string() + } + } + _ => format!("Using tool: {}", name), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_name() { + let config = ProviderConfig::default(); + let provider = ClaudeCliProvider::new(config); + assert_eq!(provider.name(), "Claude CLI"); + } + + #[test] + fn test_capabilities() { + let config = ProviderConfig::default(); + let provider = ClaudeCliProvider::new(config); + let caps = provider.capabilities(); + + assert!(caps.supports_streaming); + assert!(caps.supports_tools); + assert!(caps.supports_vision); + assert!(caps.supports_thinking); + assert!(caps.supports_mcp); + assert!(caps.supports_resume_session); + assert_eq!(caps.max_context_tokens, Some(200_000)); + assert!(!caps.available_models.is_empty()); + } + + #[test] + fn test_is_running_initial() { + let config = ProviderConfig::default(); + let provider = ClaudeCliProvider::new(config); + assert!(!provider.is_running()); + } + + #[test] + fn test_get_working_directory() { + let config = ProviderConfig { + working_directory: "/home/test".to_string(), + ..Default::default() + }; + let provider = ClaudeCliProvider::new(config); + assert_eq!(provider.get_working_directory(), "/home/test"); + } + + #[test] + fn test_get_model() { + let config = ProviderConfig { + model: Some("claude-sonnet-4-20250514".to_string()), + ..Default::default() + }; + let provider = ClaudeCliProvider::new(config); + assert_eq!( + provider.get_model(), + Some("claude-sonnet-4-20250514".to_string()) + ); + } + + #[test] + fn test_format_tool_description() { + let input = serde_json::json!({"file_path": "/test/file.rs"}); + assert_eq!( + format_tool_description("Read", &input), + "Reading file: /test/file.rs" + ); + + let input = serde_json::json!({"pattern": "*.rs"}); + assert_eq!( + format_tool_description("Glob", &input), + "Searching for files: *.rs" + ); + + let input = serde_json::json!({"command": "ls -la"}); + assert_eq!(format_tool_description("Bash", &input), "Running: ls -la"); + } +} diff --git a/src-tauri/src/providers/mod.rs b/src-tauri/src/providers/mod.rs new file mode 100644 index 0000000..48e75f9 --- /dev/null +++ b/src-tauri/src/providers/mod.rs @@ -0,0 +1,89 @@ +mod claude_cli; +mod ollama; +mod traits; + +// Re-exports for when providers are fully integrated +#[allow(unused_imports)] +pub use claude_cli::ClaudeCliProvider; +#[allow(unused_imports)] +pub use ollama::OllamaProvider; +#[allow(unused_imports)] +pub use traits::{ + LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage, + ProviderStreamEvent, ProviderUsage, QuestionOption, StreamCallback, +}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ProviderType { + #[default] + ClaudeCli, + Ollama, +} + +#[allow(dead_code)] +impl ProviderType { + pub fn display_name(&self) -> &'static str { + match self { + ProviderType::ClaudeCli => "Claude CLI", + ProviderType::Ollama => "Ollama (Local)", + } + } + + pub fn description(&self) -> &'static str { + match self { + ProviderType::ClaudeCli => "Use Claude Code CLI for AI assistance", + ProviderType::Ollama => "Use locally running Ollama models", + } + } +} + +#[allow(dead_code)] +pub fn create_provider( + provider_type: ProviderType, + config: ProviderConfig, +) -> Box { + match provider_type { + ProviderType::ClaudeCli => Box::new(ClaudeCliProvider::new(config)), + ProviderType::Ollama => Box::new(OllamaProvider::new(config)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_type_display_name() { + assert_eq!(ProviderType::ClaudeCli.display_name(), "Claude CLI"); + assert_eq!(ProviderType::Ollama.display_name(), "Ollama (Local)"); + } + + #[test] + fn test_provider_type_default() { + let default: ProviderType = Default::default(); + assert_eq!(default, ProviderType::ClaudeCli); + } + + #[test] + fn test_provider_type_serialization() { + let claude = ProviderType::ClaudeCli; + let json = serde_json::to_string(&claude).unwrap(); + assert_eq!(json, "\"claude_cli\""); + + let ollama = ProviderType::Ollama; + let json = serde_json::to_string(&ollama).unwrap(); + assert_eq!(json, "\"ollama\""); + } + + #[test] + fn test_provider_type_deserialization() { + let claude: ProviderType = serde_json::from_str("\"claude_cli\"").unwrap(); + assert_eq!(claude, ProviderType::ClaudeCli); + + let ollama: ProviderType = serde_json::from_str("\"ollama\"").unwrap(); + assert_eq!(ollama, ProviderType::Ollama); + } +} diff --git a/src-tauri/src/providers/ollama.rs b/src-tauri/src/providers/ollama.rs new file mode 100644 index 0000000..d366f18 --- /dev/null +++ b/src-tauri/src/providers/ollama.rs @@ -0,0 +1,537 @@ +// This provider connects to a local Ollama instance for LLM inference. +// It will be actively used once providers are fully integrated with BridgeManager. +#![allow(dead_code)] + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc; + +use crate::providers::traits::{ + LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage, + ProviderStreamEvent, ProviderUsage, StreamCallback, +}; + +const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434"; + +#[derive(Debug, Serialize)] +struct OllamaChatRequest { + model: String, + messages: Vec, + stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct OllamaMessage { + role: String, + content: String, +} + +#[derive(Debug, Deserialize)] +struct OllamaChatResponse { + #[serde(default)] + message: Option, + #[serde(default)] + done: bool, + #[serde(default)] + eval_count: Option, + #[serde(default)] + prompt_eval_count: Option, +} + +#[derive(Debug, Deserialize)] +struct OllamaResponseMessage { + #[serde(default)] + content: String, +} + +#[derive(Debug, Deserialize)] +struct OllamaTagsResponse { + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct OllamaModelInfo { + name: String, + #[serde(default)] + details: Option, +} + +#[derive(Debug, Deserialize)] +struct OllamaModelDetails { + #[serde(default)] + parameter_size: Option, + #[serde(default)] + family: Option, +} + +pub struct OllamaProvider { + config: ProviderConfig, + client: reqwest::Client, + base_url: String, + is_running: Arc, + conversation_history: Vec, + cancel_tx: Option>, +} + +impl OllamaProvider { + pub fn new(config: ProviderConfig) -> Self { + let base_url = config + .api_base_url + .clone() + .unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string()); + + Self { + config, + client: reqwest::Client::new(), + base_url, + is_running: Arc::new(AtomicBool::new(false)), + conversation_history: Vec::new(), + cancel_tx: None, + } + } + + async fn fetch_available_models(&self) -> Vec { + let url = format!("{}/api/tags", self.base_url); + + match self.client.get(&url).send().await { + Ok(response) => { + if let Ok(tags) = response.json::().await { + tags.models + .into_iter() + .map(|m| { + let description = m.details.as_ref().map(|d| { + let mut desc_parts = Vec::new(); + if let Some(ref family) = d.family { + desc_parts.push(family.clone()); + } + if let Some(ref size) = d.parameter_size { + desc_parts.push(format!("{} parameters", size)); + } + if desc_parts.is_empty() { + "Local model".to_string() + } else { + desc_parts.join(" - ") + } + }); + + ModelInfo { + id: m.name.clone(), + name: m.name, + description, + context_window: None, + input_cost_per_mtok: Some(0.0), // Local = free! + output_cost_per_mtok: Some(0.0), + } + }) + .collect() + } else { + default_ollama_models() + } + } + Err(_) => default_ollama_models(), + } + } + + async fn stream_chat( + &mut self, + callback: Arc, + cancel_rx: mpsc::Receiver<()>, + ) -> Result<(), String> { + let model = self + .config + .model + .clone() + .unwrap_or_else(|| "llama3.2".to_string()); + + let request = OllamaChatRequest { + model: model.clone(), + messages: self.conversation_history.clone(), + stream: true, + system: self.config.custom_instructions.clone(), + }; + + let url = format!("{}/api/chat", self.base_url); + + let response = self + .client + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| format!("Failed to connect to Ollama: {}", e))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(format!("Ollama error ({}): {}", status, error_text)); + } + + let mut full_response = String::new(); + let mut total_input_tokens: u64 = 0; + let mut total_output_tokens: u64 = 0; + + let mut stream = response.bytes_stream(); + let mut cancel_rx = cancel_rx; + + use futures_util::StreamExt; + + loop { + tokio::select! { + _ = cancel_rx.recv() => { + callback(ProviderStreamEvent::Disconnected); + return Ok(()); + } + chunk = stream.next() => { + match chunk { + Some(Ok(bytes)) => { + let text = String::from_utf8_lossy(&bytes); + + for line in text.lines() { + if line.trim().is_empty() { + continue; + } + + if let Ok(response) = serde_json::from_str::(line) { + if let Some(msg) = &response.message { + if !msg.content.is_empty() { + full_response.push_str(&msg.content); + callback(ProviderStreamEvent::TextDelta { + text: msg.content.clone(), + }); + } + } + + if let Some(tokens) = response.prompt_eval_count { + total_input_tokens = tokens; + } + if let Some(tokens) = response.eval_count { + total_output_tokens = tokens; + } + + if response.done { + self.conversation_history.push(OllamaMessage { + role: "assistant".to_string(), + content: full_response.clone(), + }); + + callback(ProviderStreamEvent::MessageComplete { + content: vec![ProviderMessage::Text { + content: full_response, + }], + usage: Some(ProviderUsage { + input_tokens: total_input_tokens, + output_tokens: total_output_tokens, + model, + }), + stop_reason: Some("end_turn".to_string()), + }); + + return Ok(()); + } + } + } + } + Some(Err(e)) => { + return Err(format!("Stream error: {}", e)); + } + None => { + break; + } + } + } + } + } + + Ok(()) + } +} + +#[async_trait] +impl LlmProvider for OllamaProvider { + fn name(&self) -> &'static str { + "Ollama" + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + supports_streaming: true, + supports_tools: false, // Ollama doesn't support tools natively yet + supports_vision: true, // Some models support vision + supports_thinking: false, + supports_mcp: false, + supports_resume_session: false, + max_context_tokens: None, // Varies by model + available_models: default_ollama_models(), + } + } + + async fn start(&mut self, callback: StreamCallback) -> Result<(), String> { + if self.is_running.load(Ordering::SeqCst) { + return Err("Provider already running".to_string()); + } + + // Check if Ollama is reachable + let url = format!("{}/api/tags", self.base_url); + self.client.get(&url).send().await.map_err(|e| { + format!( + "Cannot connect to Ollama at {}. Is it running? Error: {}", + self.base_url, e + ) + })?; + + self.is_running.store(true, Ordering::SeqCst); + self.conversation_history.clear(); + + // Create cancellation channel + let (tx, _rx) = mpsc::channel(1); + self.cancel_tx = Some(tx); + + callback(ProviderStreamEvent::Connected { session_id: None }); + + // Fetch and report available models + let models = self.fetch_available_models().await; + if !models.is_empty() { + eprintln!( + "[Ollama] Available models: {:?}", + models.iter().map(|m| &m.id).collect::>() + ); + } + + Ok(()) + } + + async fn stop(&mut self) -> Result<(), String> { + self.is_running.store(false, Ordering::SeqCst); + self.conversation_history.clear(); + + if let Some(tx) = self.cancel_tx.take() { + let _ = tx.send(()).await; + } + + Ok(()) + } + + fn is_running(&self) -> bool { + self.is_running.load(Ordering::SeqCst) + } + + async fn send_message(&mut self, message: &str) -> Result<(), String> { + if !self.is_running.load(Ordering::SeqCst) { + return Err("Provider not running".to_string()); + } + + // Add user message to history + self.conversation_history.push(OllamaMessage { + role: "user".to_string(), + content: message.to_string(), + }); + + // Create a new cancel channel for this request + let (tx, _rx) = mpsc::channel(1); + self.cancel_tx = Some(tx); + + // We need a callback here, but we don't have access to it in send_message + // This is a limitation of the current trait design + // For now, we'll need to refactor to handle streaming properly + // The callback should be stored from the start() call + + // For the MVP, we'll emit events directly + // In a real implementation, we'd need to restructure this + + Err("send_message needs refactoring to work with stored callback".to_string()) + } + + async fn send_tool_result( + &mut self, + _tool_use_id: &str, + _result: serde_json::Value, + _is_error: bool, + ) -> Result<(), String> { + // Ollama doesn't support tool use natively + Err("Ollama does not support tool use".to_string()) + } + + fn get_model(&self) -> Option { + self.config.model.clone() + } + + fn get_working_directory(&self) -> &str { + &self.config.working_directory + } + + async fn interrupt(&mut self) -> Result<(), String> { + if let Some(tx) = self.cancel_tx.take() { + let _ = tx.send(()).await; + } + self.is_running.store(false, Ordering::SeqCst); + Ok(()) + } +} + +fn default_ollama_models() -> Vec { + vec![ + ModelInfo { + id: "llama3.2".to_string(), + name: "Llama 3.2".to_string(), + description: Some("Meta's latest compact model".to_string()), + context_window: Some(128_000), + input_cost_per_mtok: Some(0.0), + output_cost_per_mtok: Some(0.0), + }, + ModelInfo { + id: "llama3.2:1b".to_string(), + name: "Llama 3.2 1B".to_string(), + description: Some("Smallest Llama 3.2 variant".to_string()), + context_window: Some(128_000), + input_cost_per_mtok: Some(0.0), + output_cost_per_mtok: Some(0.0), + }, + ModelInfo { + id: "qwen2.5-coder".to_string(), + name: "Qwen 2.5 Coder".to_string(), + description: Some("Alibaba's coding-focused model".to_string()), + context_window: Some(32_000), + input_cost_per_mtok: Some(0.0), + output_cost_per_mtok: Some(0.0), + }, + ModelInfo { + id: "deepseek-coder-v2".to_string(), + name: "DeepSeek Coder V2".to_string(), + description: Some("DeepSeek's coding model".to_string()), + context_window: Some(128_000), + input_cost_per_mtok: Some(0.0), + output_cost_per_mtok: Some(0.0), + }, + ModelInfo { + id: "mistral".to_string(), + name: "Mistral 7B".to_string(), + description: Some("Fast and capable".to_string()), + context_window: Some(32_000), + input_cost_per_mtok: Some(0.0), + output_cost_per_mtok: Some(0.0), + }, + ModelInfo { + id: "gemma2".to_string(), + name: "Gemma 2".to_string(), + description: Some("Google's open model".to_string()), + context_window: Some(8_000), + input_cost_per_mtok: Some(0.0), + output_cost_per_mtok: Some(0.0), + }, + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_name() { + let config = ProviderConfig::default(); + let provider = OllamaProvider::new(config); + assert_eq!(provider.name(), "Ollama"); + } + + #[test] + fn test_capabilities() { + let config = ProviderConfig::default(); + let provider = OllamaProvider::new(config); + let caps = provider.capabilities(); + + assert!(caps.supports_streaming); + assert!(!caps.supports_tools); + assert!(!caps.supports_mcp); + assert!(!caps.supports_resume_session); + assert!(!caps.available_models.is_empty()); + } + + #[test] + fn test_default_url() { + let config = ProviderConfig::default(); + let provider = OllamaProvider::new(config); + assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL); + } + + #[test] + fn test_custom_url() { + let config = ProviderConfig { + api_base_url: Some("http://custom:8080".to_string()), + ..Default::default() + }; + let provider = OllamaProvider::new(config); + assert_eq!(provider.base_url, "http://custom:8080"); + } + + #[test] + fn test_is_running_initial() { + let config = ProviderConfig::default(); + let provider = OllamaProvider::new(config); + assert!(!provider.is_running()); + } + + #[test] + fn test_default_ollama_models() { + let models = default_ollama_models(); + assert!(!models.is_empty()); + + // All models should be free (local) + for model in &models { + assert_eq!(model.input_cost_per_mtok, Some(0.0)); + assert_eq!(model.output_cost_per_mtok, Some(0.0)); + } + + // Should include llama + assert!(models.iter().any(|m| m.id.contains("llama"))); + } + + #[test] + fn test_ollama_message_serialization() { + let msg = OllamaMessage { + role: "user".to_string(), + content: "Hello!".to_string(), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("\"content\":\"Hello!\"")); + } + + #[test] + fn test_chat_request_serialization() { + let request = OllamaChatRequest { + model: "llama3.2".to_string(), + messages: vec![OllamaMessage { + role: "user".to_string(), + content: "Test".to_string(), + }], + stream: true, + system: Some("You are helpful".to_string()), + }; + + let json = serde_json::to_string(&request).unwrap(); + assert!(json.contains("\"model\":\"llama3.2\"")); + assert!(json.contains("\"stream\":true")); + assert!(json.contains("\"system\":\"You are helpful\"")); + } + + #[test] + fn test_chat_request_without_system() { + let request = OllamaChatRequest { + model: "llama3.2".to_string(), + messages: vec![], + stream: true, + system: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + // system should be omitted when None + assert!(!json.contains("\"system\"")); + } +} diff --git a/src-tauri/src/providers/traits.rs b/src-tauri/src/providers/traits.rs new file mode 100644 index 0000000..dff07ea --- /dev/null +++ b/src-tauri/src/providers/traits.rs @@ -0,0 +1,236 @@ +// These types are all used by the provider abstraction layer. +// They will be actively used once providers are fully integrated with BridgeManager. +#![allow(dead_code)] + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderConfig { + pub api_key: Option, + pub api_base_url: Option, + pub model: Option, + pub custom_instructions: Option, + pub working_directory: String, + pub mcp_servers_json: Option, + pub allowed_tools: Vec, + #[serde(default)] + pub extra_options: HashMap, +} + +impl Default for ProviderConfig { + fn default() -> Self { + Self { + api_key: None, + api_base_url: None, + model: None, + custom_instructions: None, + working_directory: String::new(), + mcp_servers_json: None, + allowed_tools: Vec::new(), + extra_options: HashMap::new(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderCapabilities { + pub supports_streaming: bool, + pub supports_tools: bool, + pub supports_vision: bool, + pub supports_thinking: bool, + pub supports_mcp: bool, + pub supports_resume_session: bool, + pub max_context_tokens: Option, + pub available_models: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelInfo { + pub id: String, + pub name: String, + pub description: Option, + pub context_window: Option, + pub input_cost_per_mtok: Option, + pub output_cost_per_mtok: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ProviderMessage { + #[serde(rename = "text")] + Text { content: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, + #[serde(rename = "thinking")] + Thinking { content: String }, + #[serde(rename = "image")] + Image { media_type: String, data: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderUsage { + pub input_tokens: u64, + pub output_tokens: u64, + pub model: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ProviderStreamEvent { + #[serde(rename = "connected")] + Connected { session_id: Option }, + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "thinking_delta")] + ThinkingDelta { text: String }, + #[serde(rename = "tool_use_start")] + ToolUseStart { id: String, name: String }, + #[serde(rename = "tool_use_delta")] + ToolUseDelta { id: String, input_delta: String }, + #[serde(rename = "tool_use_end")] + ToolUseEnd { id: String, input: serde_json::Value }, + #[serde(rename = "message_complete")] + MessageComplete { + content: Vec, + usage: Option, + stop_reason: Option, + }, + #[serde(rename = "permission_request")] + PermissionRequest { + id: String, + tool_name: String, + tool_input: serde_json::Value, + description: String, + }, + #[serde(rename = "question")] + Question { + id: String, + question: String, + header: Option, + options: Vec, + multi_select: bool, + }, + #[serde(rename = "working_directory")] + WorkingDirectory { path: String }, + #[serde(rename = "error")] + Error { message: String }, + #[serde(rename = "disconnected")] + Disconnected, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QuestionOption { + pub label: String, + pub description: Option, +} + +pub type StreamCallback = Box; + +#[async_trait] +pub trait LlmProvider: Send + Sync { + fn name(&self) -> &'static str; + + fn capabilities(&self) -> ProviderCapabilities; + + async fn start(&mut self, callback: StreamCallback) -> Result<(), String>; + + async fn stop(&mut self) -> Result<(), String>; + + fn is_running(&self) -> bool; + + async fn send_message(&mut self, message: &str) -> Result<(), String>; + + async fn send_tool_result( + &mut self, + tool_use_id: &str, + result: serde_json::Value, + is_error: bool, + ) -> Result<(), String>; + + fn get_model(&self) -> Option; + + fn get_working_directory(&self) -> &str; + + async fn interrupt(&mut self) -> Result<(), String>; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_config_default() { + let config = ProviderConfig::default(); + assert!(config.api_key.is_none()); + assert!(config.api_base_url.is_none()); + assert!(config.model.is_none()); + assert!(config.working_directory.is_empty()); + assert!(config.allowed_tools.is_empty()); + } + + #[test] + fn test_provider_message_serialization() { + let msg = ProviderMessage::Text { + content: "Hello!".to_string(), + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("\"type\":\"text\"")); + assert!(json.contains("\"content\":\"Hello!\"")); + } + + #[test] + fn test_provider_stream_event_serialization() { + let event = ProviderStreamEvent::TextDelta { + text: "chunk".to_string(), + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"text_delta\"")); + + let event = ProviderStreamEvent::Connected { + session_id: Some("test-123".to_string()), + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"session_id\":\"test-123\"")); + } + + #[test] + fn test_model_info() { + let model = ModelInfo { + id: "claude-sonnet-4-20250514".to_string(), + name: "Claude Sonnet 4".to_string(), + description: Some("Fast and intelligent".to_string()), + context_window: Some(200000), + input_cost_per_mtok: Some(3.0), + output_cost_per_mtok: Some(15.0), + }; + + let json = serde_json::to_string(&model).unwrap(); + assert!(json.contains("claude-sonnet-4")); + assert!(json.contains("200000")); + } + + #[test] + fn test_provider_usage() { + let usage = ProviderUsage { + input_tokens: 100, + output_tokens: 50, + model: "claude-opus-4-20250514".to_string(), + }; + + let json = serde_json::to_string(&usage).unwrap(); + assert!(json.contains("\"input_tokens\":100")); + assert!(json.contains("\"output_tokens\":50")); + } +} diff --git a/src/lib/components/ConfigSidebar.svelte b/src/lib/components/ConfigSidebar.svelte index c664009..0b91735 100644 --- a/src/lib/components/ConfigSidebar.svelte +++ b/src/lib/components/ConfigSidebar.svelte @@ -3,6 +3,7 @@ configStore, type HikariConfig, type Theme, + type ProviderType, type CustomThemeColors, applyFontSize, applyCustomThemeColors, @@ -14,11 +15,14 @@ import { getCurrentWindow } from "@tauri-apps/api/window"; let config: HikariConfig = $state({ + provider_type: "claude_cli", model: null, api_key: null, custom_instructions: null, mcp_servers_json: null, auto_granted_tools: [], + ollama_base_url: "http://localhost:11434", + ollama_model: null, theme: "dark", greeting_enabled: true, greeting_custom_prompt: null, @@ -72,12 +76,27 @@ grantedTools = Array.from(tools); }); + const availableProviders: { value: ProviderType; label: string; description: string }[] = [ + { value: "claude_cli", label: "Claude CLI", description: "Use Claude Code CLI for AI assistance" }, + { value: "ollama", label: "Ollama (Local)", description: "Use locally running Ollama models" }, + ]; + const availableModels = [ { value: "", label: "Default (from ~/.claude)" }, { value: "claude-sonnet-4-20250514", label: "Claude Sonnet 4" }, { value: "claude-opus-4-20250514", label: "Claude Opus 4" }, ]; + const ollamaModels = [ + { value: "", label: "Default (llama3.2)" }, + { value: "llama3.2", label: "Llama 3.2" }, + { value: "llama3.2:1b", label: "Llama 3.2 1B" }, + { value: "qwen2.5-coder", label: "Qwen 2.5 Coder" }, + { value: "deepseek-coder-v2", label: "DeepSeek Coder V2" }, + { value: "mistral", label: "Mistral 7B" }, + { value: "gemma2", label: "Gemma 2" }, + ]; + const commonTools = [ "Read", "Write", @@ -207,6 +226,79 @@ {/if} + +
+

+ AI Provider +

+ +
+ {#each availableProviders as provider (provider.value)} + + {/each} +
+ + + {#if config.provider_type === "ollama"} +
+

Ollama Settings

+ + +
+ + +
+ + +
+ + +

+ Make sure the model is downloaded via ollama pull +

+
+ +
+ Note: Ollama doesn't support tools, MCP servers, or thinking blocks. + For full Claude Code features, use the Claude CLI provider. +
+
+ {/if} +
+

diff --git a/src/lib/components/StatusBar.svelte b/src/lib/components/StatusBar.svelte index be2942e..840e2e8 100644 --- a/src/lib/components/StatusBar.svelte +++ b/src/lib/components/StatusBar.svelte @@ -43,11 +43,14 @@ let showProfile = $state(false); const progress = $derived($achievementProgress); let currentConfig: HikariConfig = $state({ + provider_type: "claude_cli", model: null, api_key: null, custom_instructions: null, mcp_servers_json: null, auto_granted_tools: [], + ollama_base_url: "http://localhost:11434", + ollama_model: null, theme: "dark", greeting_enabled: true, greeting_custom_prompt: null, @@ -145,12 +148,15 @@ await invoke("start_claude", { conversationId, options: { + provider_type: currentConfig.provider_type || "claude_cli", working_dir: targetDir, model: currentConfig.model || null, api_key: currentConfig.api_key || null, custom_instructions: currentConfig.custom_instructions || null, mcp_servers_json: currentConfig.mcp_servers_json || null, allowed_tools: allAllowedTools, + ollama_base_url: currentConfig.ollama_base_url || "http://localhost:11434", + ollama_model: currentConfig.ollama_model || null, }, }); } catch (error) { diff --git a/src/lib/stores/config.ts b/src/lib/stores/config.ts index 0d5de7c..855fda9 100644 --- a/src/lib/stores/config.ts +++ b/src/lib/stores/config.ts @@ -2,6 +2,7 @@ import { writable, derived } from "svelte/store"; import { invoke } from "@tauri-apps/api/core"; export type Theme = "dark" | "light" | "high-contrast" | "custom"; +export type ProviderType = "claude_cli" | "ollama"; export interface CustomThemeColors { bg_primary: string | null; @@ -15,11 +16,14 @@ export interface CustomThemeColors { } export interface HikariConfig { + provider_type: ProviderType; model: string | null; api_key: string | null; custom_instructions: string | null; mcp_servers_json: string | null; auto_granted_tools: string[]; + ollama_base_url: string; + ollama_model: string | null; theme: Theme; greeting_enabled: boolean; greeting_custom_prompt: string | null; @@ -40,11 +44,14 @@ export interface HikariConfig { } const defaultConfig: HikariConfig = { + provider_type: "claude_cli", model: null, api_key: null, custom_instructions: null, mcp_servers_json: null, auto_granted_tools: [], + ollama_base_url: "http://localhost:11434", + ollama_model: null, theme: "dark", greeting_enabled: true, greeting_custom_prompt: null,