generated from nhcarrigan/template
Compare commits
2 Commits
v1.9.0
...
feat/providers
| Author | SHA1 | Date | |
|---|---|---|---|
|
ae06cc301d
|
|||
|
a0804ed32a
|
Generated
+148
-2
@@ -592,7 +592,7 @@ dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"core-foundation 0.10.1",
|
||||
"core-graphics-types",
|
||||
"foreign-types",
|
||||
"foreign-types 0.5.0",
|
||||
"libc",
|
||||
]
|
||||
|
||||
@@ -1080,6 +1080,15 @@ version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
|
||||
dependencies = [
|
||||
"foreign-types-shared 0.1.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
version = "0.5.0"
|
||||
@@ -1087,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
|
||||
dependencies = [
|
||||
"foreign-types-macros",
|
||||
"foreign-types-shared",
|
||||
"foreign-types-shared 0.3.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1101,6 +1110,12 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types-shared"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types-shared"
|
||||
version = "0.3.1"
|
||||
@@ -1604,8 +1619,11 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
name = "hikari-desktop"
|
||||
version = "1.1.1"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"futures-util",
|
||||
"parking_lot",
|
||||
"reqwest",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -1716,6 +1734,22 @@ dependencies = [
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tls"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.19"
|
||||
@@ -2288,6 +2322,23 @@ dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"openssl",
|
||||
"openssl-probe",
|
||||
"openssl-sys",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk"
|
||||
version = "0.9.0"
|
||||
@@ -2661,6 +2712,50 @@ dependencies = [
|
||||
"pathdiff",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.75"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"cfg-if",
|
||||
"foreign-types 0.3.2",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"openssl-macros",
|
||||
"openssl-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.111"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "option-ext"
|
||||
version = "0.2.0"
|
||||
@@ -3417,10 +3512,12 @@ dependencies = [
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-tls",
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"native-tls",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"quinn",
|
||||
@@ -3431,6 +3528,7 @@ dependencies = [
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
@@ -3566,6 +3664,15 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schemars"
|
||||
version = "0.8.22"
|
||||
@@ -3623,6 +3730,29 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"core-foundation 0.9.4",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "selectors"
|
||||
version = "0.24.0"
|
||||
@@ -4746,6 +4876,16 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-native-tls"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
|
||||
dependencies = [
|
||||
"native-tls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.4"
|
||||
@@ -5111,6 +5251,12 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vcpkg"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "version-compare"
|
||||
version = "0.2.1"
|
||||
|
||||
@@ -31,6 +31,9 @@ tauri-plugin-fs = "2"
|
||||
tempfile = "3"
|
||||
semver = "1"
|
||||
chrono = { version = "0.4.43", features = ["serde"] }
|
||||
async-trait = "0.1"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
futures-util = "0.3"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
windows = { version = "0.62", features = [
|
||||
|
||||
@@ -4,11 +4,11 @@ use std::sync::Arc;
|
||||
use tauri::AppHandle;
|
||||
|
||||
use crate::config::ClaudeStartOptions;
|
||||
use crate::provider_bridge::ProviderBridge;
|
||||
use crate::stats::UsageStats;
|
||||
use crate::wsl_bridge::WslBridge;
|
||||
|
||||
pub struct BridgeManager {
|
||||
bridges: HashMap<String, WslBridge>,
|
||||
bridges: HashMap<String, ProviderBridge>,
|
||||
app_handle: Option<AppHandle>,
|
||||
}
|
||||
|
||||
@@ -45,12 +45,24 @@ impl BridgeManager {
|
||||
.ok_or_else(|| "App handle not set".to_string())?
|
||||
.clone();
|
||||
|
||||
// Check if existing bridge matches the requested provider type
|
||||
// If provider type changed, create a new bridge
|
||||
let should_recreate = self.bridges.get(conversation_id).map_or(false, |bridge| {
|
||||
bridge.provider_type() != options.provider_type
|
||||
});
|
||||
|
||||
if should_recreate {
|
||||
// Remove existing bridge if provider type changed
|
||||
self.bridges.remove(conversation_id);
|
||||
}
|
||||
|
||||
// Reuse existing bridge if it exists (preserves stats across reconnects)
|
||||
// Only create a new bridge if one doesn't exist for this conversation
|
||||
let provider_type = options.provider_type;
|
||||
let bridge = self
|
||||
.bridges
|
||||
.entry(conversation_id.to_string())
|
||||
.or_insert_with(|| WslBridge::new_with_conversation_id(conversation_id.to_string()));
|
||||
.or_insert_with(|| ProviderBridge::new(provider_type, conversation_id.to_string()));
|
||||
|
||||
// Start the Claude process
|
||||
bridge.start(app, options)?;
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
use crate::providers::ProviderType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ClaudeStartOptions {
|
||||
#[serde(default)]
|
||||
pub provider_type: ProviderType,
|
||||
|
||||
#[serde(default)]
|
||||
pub working_dir: String,
|
||||
|
||||
@@ -25,10 +29,47 @@ pub struct ClaudeStartOptions {
|
||||
|
||||
#[serde(default)]
|
||||
pub resume_session_id: Option<String>,
|
||||
|
||||
// Ollama-specific options
|
||||
#[serde(default = "default_ollama_base_url")]
|
||||
pub ollama_base_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub ollama_model: Option<String>,
|
||||
|
||||
// OpenAI-specific options
|
||||
#[serde(default)]
|
||||
pub openai_api_key: Option<String>,
|
||||
|
||||
#[serde(default = "default_openai_base_url")]
|
||||
pub openai_base_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub openai_model: Option<String>,
|
||||
|
||||
// Anthropic-specific options
|
||||
#[serde(default)]
|
||||
pub anthropic_api_key: Option<String>,
|
||||
|
||||
#[serde(default = "default_anthropic_base_url")]
|
||||
pub anthropic_base_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub anthropic_model: Option<String>,
|
||||
|
||||
// Gemini-specific options
|
||||
#[serde(default)]
|
||||
pub gemini_api_key: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub gemini_model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HikariConfig {
|
||||
#[serde(default)]
|
||||
pub provider_type: ProviderType,
|
||||
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
|
||||
@@ -44,6 +85,40 @@ pub struct HikariConfig {
|
||||
#[serde(default)]
|
||||
pub auto_granted_tools: Vec<String>,
|
||||
|
||||
// Ollama-specific settings
|
||||
#[serde(default = "default_ollama_base_url")]
|
||||
pub ollama_base_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub ollama_model: Option<String>,
|
||||
|
||||
// OpenAI-specific settings
|
||||
#[serde(default)]
|
||||
pub openai_api_key: Option<String>,
|
||||
|
||||
#[serde(default = "default_openai_base_url")]
|
||||
pub openai_base_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub openai_model: Option<String>,
|
||||
|
||||
// Anthropic-specific settings
|
||||
#[serde(default)]
|
||||
pub anthropic_api_key: Option<String>,
|
||||
|
||||
#[serde(default = "default_anthropic_base_url")]
|
||||
pub anthropic_base_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub anthropic_model: Option<String>,
|
||||
|
||||
// Gemini-specific settings
|
||||
#[serde(default)]
|
||||
pub gemini_api_key: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub gemini_model: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub theme: Theme,
|
||||
|
||||
@@ -101,11 +176,22 @@ pub struct HikariConfig {
|
||||
impl Default for HikariConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
provider_type: ProviderType::default(),
|
||||
model: None,
|
||||
api_key: None,
|
||||
custom_instructions: None,
|
||||
mcp_servers_json: None,
|
||||
auto_granted_tools: Vec::new(),
|
||||
ollama_base_url: default_ollama_base_url(),
|
||||
ollama_model: None,
|
||||
openai_api_key: None,
|
||||
openai_base_url: default_openai_base_url(),
|
||||
openai_model: None,
|
||||
anthropic_api_key: None,
|
||||
anthropic_base_url: default_anthropic_base_url(),
|
||||
anthropic_model: None,
|
||||
gemini_api_key: None,
|
||||
gemini_model: None,
|
||||
theme: Theme::default(),
|
||||
greeting_enabled: true,
|
||||
greeting_custom_prompt: None,
|
||||
@@ -147,6 +233,18 @@ fn default_font_size() -> u32 {
|
||||
14
|
||||
}
|
||||
|
||||
fn default_ollama_base_url() -> String {
|
||||
"http://localhost:11434".to_string()
|
||||
}
|
||||
|
||||
fn default_openai_base_url() -> String {
|
||||
"https://api.openai.com/v1".to_string()
|
||||
}
|
||||
|
||||
fn default_anthropic_base_url() -> String {
|
||||
"https://api.anthropic.com".to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Theme {
|
||||
@@ -185,11 +283,26 @@ mod tests {
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = HikariConfig::default();
|
||||
assert_eq!(config.provider_type, ProviderType::ClaudeCli);
|
||||
assert!(config.model.is_none());
|
||||
assert!(config.api_key.is_none());
|
||||
assert!(config.custom_instructions.is_none());
|
||||
assert!(config.mcp_servers_json.is_none());
|
||||
assert!(config.auto_granted_tools.is_empty());
|
||||
assert_eq!(config.ollama_base_url, "http://localhost:11434");
|
||||
assert!(config.ollama_model.is_none());
|
||||
// OpenAI defaults
|
||||
assert!(config.openai_api_key.is_none());
|
||||
assert_eq!(config.openai_base_url, "https://api.openai.com/v1");
|
||||
assert!(config.openai_model.is_none());
|
||||
// Anthropic defaults
|
||||
assert!(config.anthropic_api_key.is_none());
|
||||
assert_eq!(config.anthropic_base_url, "https://api.anthropic.com");
|
||||
assert!(config.anthropic_model.is_none());
|
||||
// Gemini defaults
|
||||
assert!(config.gemini_api_key.is_none());
|
||||
assert!(config.gemini_model.is_none());
|
||||
// Other settings
|
||||
assert_eq!(config.theme, Theme::Dark);
|
||||
assert!(config.greeting_enabled);
|
||||
assert!(config.greeting_custom_prompt.is_none());
|
||||
@@ -210,11 +323,22 @@ mod tests {
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let config = HikariConfig {
|
||||
provider_type: ProviderType::ClaudeCli,
|
||||
model: Some("claude-sonnet-4-20250514".to_string()),
|
||||
api_key: None,
|
||||
custom_instructions: Some("Be helpful".to_string()),
|
||||
mcp_servers_json: None,
|
||||
auto_granted_tools: vec!["Read".to_string(), "Glob".to_string()],
|
||||
ollama_base_url: "http://localhost:11434".to_string(),
|
||||
ollama_model: None,
|
||||
openai_api_key: None,
|
||||
openai_base_url: "https://api.openai.com/v1".to_string(),
|
||||
openai_model: None,
|
||||
anthropic_api_key: None,
|
||||
anthropic_base_url: "https://api.anthropic.com".to_string(),
|
||||
anthropic_model: None,
|
||||
gemini_api_key: None,
|
||||
gemini_model: None,
|
||||
theme: Theme::Light,
|
||||
greeting_enabled: true,
|
||||
greeting_custom_prompt: Some("Hello!".to_string()),
|
||||
|
||||
@@ -5,6 +5,8 @@ mod commands;
|
||||
mod config;
|
||||
mod git;
|
||||
mod notifications;
|
||||
mod provider_bridge;
|
||||
mod providers;
|
||||
mod quick_actions;
|
||||
mod sessions;
|
||||
mod snippets;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,773 @@
|
||||
// This provider wraps the Claude CLI subprocess.
|
||||
// It will be actively used once providers are fully integrated with BridgeManager.
|
||||
#![allow(dead_code)]
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::process::{Child, ChildStdin, Command, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
use crate::providers::traits::{
|
||||
LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderStreamEvent,
|
||||
QuestionOption, StreamCallback,
|
||||
};
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
use std::os::windows::process::CommandExt;
|
||||
|
||||
fn detect_wsl() -> bool {
|
||||
if let Ok(version) = std::fs::read_to_string("/proc/version") {
|
||||
let version_lower = version.to_lowercase();
|
||||
if version_lower.contains("microsoft") || version_lower.contains("wsl") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if std::path::Path::new("/proc/sys/fs/binfmt_misc/WSLInterop").exists() {
|
||||
return true;
|
||||
}
|
||||
|
||||
if std::env::var("WSL_DISTRO_NAME").is_ok() {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn find_claude_binary() -> Option<String> {
|
||||
let home = std::env::var("HOME").ok()?;
|
||||
let paths_to_check = [
|
||||
format!("{}/.local/bin/claude", home),
|
||||
format!("{}/.claude/local/claude", home),
|
||||
"/usr/local/bin/claude".to_string(),
|
||||
"/usr/bin/claude".to_string(),
|
||||
];
|
||||
|
||||
for path in &paths_to_check {
|
||||
if std::path::Path::new(path).exists() {
|
||||
return Some(path.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(output) = Command::new("which").arg("claude").output() {
|
||||
if output.status.success() {
|
||||
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !path.is_empty() {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub struct ClaudeCliProvider {
|
||||
config: ProviderConfig,
|
||||
process: Option<Child>,
|
||||
stdin: Option<ChildStdin>,
|
||||
session_id: Option<String>,
|
||||
mcp_config_file: Option<NamedTempFile>,
|
||||
is_running: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl ClaudeCliProvider {
|
||||
pub fn new(config: ProviderConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
process: None,
|
||||
stdin: None,
|
||||
session_id: None,
|
||||
mcp_config_file: None,
|
||||
is_running: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_command(&mut self) -> Result<Command, String> {
|
||||
let mcp_config_path = if let Some(ref mcp_json) = self.config.mcp_servers_json {
|
||||
if !mcp_json.trim().is_empty() {
|
||||
serde_json::from_str::<serde_json::Value>(mcp_json)
|
||||
.map_err(|e| format!("Invalid MCP servers JSON: {}", e))?;
|
||||
|
||||
let mut temp_file = NamedTempFile::new()
|
||||
.map_err(|e| format!("Failed to create temp file for MCP config: {}", e))?;
|
||||
temp_file
|
||||
.write_all(mcp_json.as_bytes())
|
||||
.map_err(|e| format!("Failed to write MCP config: {}", e))?;
|
||||
temp_file
|
||||
.flush()
|
||||
.map_err(|e| format!("Failed to flush MCP config: {}", e))?;
|
||||
|
||||
let path = temp_file.path().to_string_lossy().to_string();
|
||||
self.mcp_config_file = Some(temp_file);
|
||||
Some(path)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let is_wsl = detect_wsl();
|
||||
let working_dir = &self.config.working_directory;
|
||||
|
||||
let command = if is_wsl {
|
||||
let claude_path = find_claude_binary().ok_or_else(|| {
|
||||
"Could not find claude binary. Is Claude Code installed?".to_string()
|
||||
})?;
|
||||
|
||||
let mut cmd = Command::new(&claude_path);
|
||||
cmd.args([
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--verbose",
|
||||
]);
|
||||
|
||||
if let Some(ref model) = self.config.model {
|
||||
if !model.is_empty() {
|
||||
cmd.args(["--model", model]);
|
||||
}
|
||||
}
|
||||
|
||||
for tool in &self.config.allowed_tools {
|
||||
cmd.args(["--allowedTools", tool]);
|
||||
}
|
||||
|
||||
if let Some(ref instructions) = self.config.custom_instructions {
|
||||
if !instructions.is_empty() {
|
||||
cmd.args(["--system-prompt", instructions]);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref mcp_path) = mcp_config_path {
|
||||
cmd.args(["--mcp-config", mcp_path]);
|
||||
}
|
||||
|
||||
cmd.current_dir(working_dir);
|
||||
|
||||
if let Some(ref api_key) = self.config.api_key {
|
||||
if !api_key.is_empty() {
|
||||
cmd.env("ANTHROPIC_API_KEY", api_key);
|
||||
}
|
||||
}
|
||||
|
||||
cmd
|
||||
} else {
|
||||
let mut cmd = Command::new("wsl");
|
||||
|
||||
let mut claude_cmd = format!("cd '{}' && ", working_dir);
|
||||
|
||||
if let Some(ref api_key) = self.config.api_key {
|
||||
if !api_key.is_empty() {
|
||||
claude_cmd.push_str(&format!("ANTHROPIC_API_KEY='{}' ", api_key));
|
||||
}
|
||||
}
|
||||
|
||||
claude_cmd.push_str(
|
||||
"claude --output-format stream-json --input-format stream-json --verbose",
|
||||
);
|
||||
|
||||
if let Some(ref model) = self.config.model {
|
||||
if !model.is_empty() {
|
||||
claude_cmd.push_str(&format!(" --model '{}'", model));
|
||||
}
|
||||
}
|
||||
|
||||
for tool in &self.config.allowed_tools {
|
||||
claude_cmd.push_str(&format!(" --allowedTools '{}'", tool));
|
||||
}
|
||||
|
||||
if let Some(ref instructions) = self.config.custom_instructions {
|
||||
if !instructions.is_empty() {
|
||||
let escaped = instructions.replace('\'', "'\\''");
|
||||
claude_cmd.push_str(&format!(" --system-prompt '{}'", escaped));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref mcp_path) = mcp_config_path {
|
||||
claude_cmd.push_str(&format!(" --mcp-config '{}'", mcp_path));
|
||||
}
|
||||
|
||||
cmd.args(["-e", "bash", "-lc", &claude_cmd]);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW
|
||||
|
||||
cmd
|
||||
};
|
||||
|
||||
Ok(command)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for ClaudeCliProvider {
|
||||
fn name(&self) -> &'static str {
|
||||
"Claude CLI"
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities {
|
||||
supports_streaming: true,
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
supports_thinking: true,
|
||||
supports_mcp: true,
|
||||
supports_resume_session: true,
|
||||
max_context_tokens: Some(200_000),
|
||||
available_models: vec![
|
||||
ModelInfo {
|
||||
id: "claude-sonnet-4-20250514".to_string(),
|
||||
name: "Claude Sonnet 4".to_string(),
|
||||
description: Some("Fast and intelligent".to_string()),
|
||||
context_window: Some(200_000),
|
||||
input_cost_per_mtok: Some(3.0),
|
||||
output_cost_per_mtok: Some(15.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-opus-4-20250514".to_string(),
|
||||
name: "Claude Opus 4".to_string(),
|
||||
description: Some("Most capable model".to_string()),
|
||||
context_window: Some(200_000),
|
||||
input_cost_per_mtok: Some(15.0),
|
||||
output_cost_per_mtok: Some(75.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-3-5-haiku-20241022".to_string(),
|
||||
name: "Claude 3.5 Haiku".to_string(),
|
||||
description: Some("Fast and efficient".to_string()),
|
||||
context_window: Some(200_000),
|
||||
input_cost_per_mtok: Some(1.0),
|
||||
output_cost_per_mtok: Some(5.0),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
async fn start(&mut self, callback: StreamCallback) -> Result<(), String> {
|
||||
if self.process.is_some() {
|
||||
return Err("Process already running".to_string());
|
||||
}
|
||||
|
||||
callback(ProviderStreamEvent::Connected { session_id: None });
|
||||
|
||||
let mut command = self.build_command()?;
|
||||
|
||||
command
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn process: {}", e))?;
|
||||
|
||||
let stdin = child.stdin.take();
|
||||
let stdout = child.stdout.take();
|
||||
let stderr = child.stderr.take();
|
||||
|
||||
self.stdin = stdin;
|
||||
self.process = Some(child);
|
||||
self.is_running.store(true, Ordering::SeqCst);
|
||||
|
||||
let is_running = self.is_running.clone();
|
||||
let callback = Arc::new(callback);
|
||||
|
||||
if let Some(stdout) = stdout {
|
||||
let callback_clone = callback.clone();
|
||||
let is_running_clone = is_running.clone();
|
||||
thread::spawn(move || {
|
||||
handle_stdout(stdout, callback_clone, is_running_clone);
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(stderr) = stderr {
|
||||
let callback_clone = callback.clone();
|
||||
thread::spawn(move || {
|
||||
handle_stderr(stderr, callback_clone);
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&mut self) -> Result<(), String> {
|
||||
if let Some(mut process) = self.process.take() {
|
||||
let _ = process.kill();
|
||||
let _ = process.wait();
|
||||
}
|
||||
self.stdin = None;
|
||||
self.session_id = None;
|
||||
self.mcp_config_file = None;
|
||||
self.is_running.store(false, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_running(&self) -> bool {
|
||||
self.is_running.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
async fn send_message(&mut self, message: &str) -> Result<(), String> {
|
||||
let stdin = self.stdin.as_mut().ok_or("Process not running")?;
|
||||
|
||||
let input = serde_json::json!({
|
||||
"type": "user",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": message
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
let json_line = serde_json::to_string(&input).map_err(|e| e.to_string())?;
|
||||
|
||||
stdin
|
||||
.write_all(format!("{}\n", json_line).as_bytes())
|
||||
.map_err(|e| format!("Failed to write to stdin: {}", e))?;
|
||||
|
||||
stdin
|
||||
.flush()
|
||||
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_tool_result(
|
||||
&mut self,
|
||||
tool_use_id: &str,
|
||||
result: serde_json::Value,
|
||||
_is_error: bool,
|
||||
) -> Result<(), String> {
|
||||
let stdin = self.stdin.as_mut().ok_or("Process not running")?;
|
||||
|
||||
let content_str = serde_json::to_string(&result).map_err(|e| e.to_string())?;
|
||||
|
||||
let input = serde_json::json!({
|
||||
"type": "user",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": content_str
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
let json_line = serde_json::to_string(&input).map_err(|e| e.to_string())?;
|
||||
|
||||
stdin
|
||||
.write_all(format!("{}\n", json_line).as_bytes())
|
||||
.map_err(|e| format!("Failed to write to stdin: {}", e))?;
|
||||
|
||||
stdin
|
||||
.flush()
|
||||
.map_err(|e| format!("Failed to flush stdin: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<String> {
|
||||
self.config.model.clone()
|
||||
}
|
||||
|
||||
fn get_working_directory(&self) -> &str {
|
||||
&self.config.working_directory
|
||||
}
|
||||
|
||||
async fn interrupt(&mut self) -> Result<(), String> {
|
||||
if let Some(mut process) = self.process.take() {
|
||||
let _ = process.kill();
|
||||
let _ = process.wait();
|
||||
self.stdin = None;
|
||||
self.is_running.store(false, Ordering::SeqCst);
|
||||
Ok(())
|
||||
} else {
|
||||
Err("No active process to interrupt".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_stdout(
|
||||
stdout: std::process::ChildStdout,
|
||||
callback: Arc<StreamCallback>,
|
||||
is_running: Arc<AtomicBool>,
|
||||
) {
|
||||
let reader = BufReader::new(stdout);
|
||||
|
||||
for line in reader.lines() {
|
||||
match line {
|
||||
Ok(line) if !line.is_empty() => {
|
||||
if let Err(e) = process_json_line(&line, &callback) {
|
||||
eprintln!("Error processing line: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading stdout: {}", e);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
is_running.store(false, Ordering::SeqCst);
|
||||
callback(ProviderStreamEvent::Disconnected);
|
||||
}
|
||||
|
||||
fn handle_stderr(stderr: std::process::ChildStderr, callback: Arc<StreamCallback>) {
|
||||
let reader = BufReader::new(stderr);
|
||||
|
||||
for line in reader.lines() {
|
||||
match line {
|
||||
Ok(line) if !line.is_empty() => {
|
||||
callback(ProviderStreamEvent::Error { message: line });
|
||||
}
|
||||
Err(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_json_line(line: &str, callback: &StreamCallback) -> Result<(), String> {
|
||||
let message: serde_json::Value = serde_json::from_str(line)
|
||||
.map_err(|e| format!("Failed to parse JSON: {} - Line: {}", e, line))?;
|
||||
|
||||
let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||
|
||||
match msg_type {
|
||||
"system" => {
|
||||
let subtype = message
|
||||
.get("subtype")
|
||||
.and_then(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
if subtype == "init" {
|
||||
let session_id = message
|
||||
.get("session_id")
|
||||
.and_then(|s| s.as_str())
|
||||
.map(|s| s.to_string());
|
||||
callback(ProviderStreamEvent::Connected { session_id });
|
||||
|
||||
if let Some(cwd) = message.get("cwd").and_then(|c| c.as_str()) {
|
||||
callback(ProviderStreamEvent::WorkingDirectory {
|
||||
path: cwd.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"assistant" => {
|
||||
if let Some(msg) = message.get("message") {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||
for block in content {
|
||||
let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
||||
match block_type {
|
||||
"text" => {
|
||||
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
|
||||
callback(ProviderStreamEvent::TextDelta {
|
||||
text: text.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
"thinking" => {
|
||||
if let Some(thinking) =
|
||||
block.get("thinking").and_then(|t| t.as_str())
|
||||
{
|
||||
callback(ProviderStreamEvent::ThinkingDelta {
|
||||
text: thinking.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
"tool_use" => {
|
||||
let id = block
|
||||
.get("id")
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let name = block
|
||||
.get("name")
|
||||
.and_then(|n| n.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let input = block
|
||||
.get("input")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
callback(ProviderStreamEvent::ToolUseStart {
|
||||
id: id.clone(),
|
||||
name,
|
||||
});
|
||||
callback(ProviderStreamEvent::ToolUseEnd { id, input });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract usage if available
|
||||
let usage = msg.get("usage").and_then(|u| {
|
||||
let input_tokens = u.get("input_tokens").and_then(|t| t.as_u64())?;
|
||||
let output_tokens = u.get("output_tokens").and_then(|t| t.as_u64())?;
|
||||
let model = msg
|
||||
.get("model")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Some(crate::providers::ProviderUsage {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
model,
|
||||
})
|
||||
});
|
||||
|
||||
callback(ProviderStreamEvent::MessageComplete {
|
||||
content: vec![],
|
||||
usage,
|
||||
stop_reason: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
"result" => {
|
||||
let subtype = message
|
||||
.get("subtype")
|
||||
.and_then(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Handle permission denials
|
||||
if let Some(denials) = message
|
||||
.get("permission_denials")
|
||||
.and_then(|d| d.as_array())
|
||||
{
|
||||
for denial in denials {
|
||||
let tool_name = denial
|
||||
.get("tool_name")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Check if this is an AskUserQuestion
|
||||
if tool_name == "AskUserQuestion" {
|
||||
if let Some(tool_input) = denial.get("tool_input") {
|
||||
if let Some(questions) =
|
||||
tool_input.get("questions").and_then(|q| q.as_array())
|
||||
{
|
||||
if let Some(first_question) = questions.first() {
|
||||
let question = first_question
|
||||
.get("question")
|
||||
.and_then(|q| q.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let header = first_question
|
||||
.get("header")
|
||||
.and_then(|h| h.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let multi_select = first_question
|
||||
.get("multiSelect")
|
||||
.and_then(|m| m.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
let options: Vec<QuestionOption> = first_question
|
||||
.get("options")
|
||||
.and_then(|opts| opts.as_array())
|
||||
.map(|opts| {
|
||||
opts.iter()
|
||||
.filter_map(|opt| {
|
||||
let label = opt
|
||||
.get("label")
|
||||
.and_then(|l| l.as_str())?;
|
||||
let description = opt
|
||||
.get("description")
|
||||
.and_then(|d| d.as_str())
|
||||
.map(|s| s.to_string());
|
||||
Some(QuestionOption {
|
||||
label: label.to_string(),
|
||||
description,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let id = denial
|
||||
.get("tool_use_id")
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
callback(ProviderStreamEvent::Question {
|
||||
id,
|
||||
question,
|
||||
header,
|
||||
options,
|
||||
multi_select,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let id = denial
|
||||
.get("tool_use_id")
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let tool_input = denial
|
||||
.get("tool_input")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
let description = format_tool_description(&tool_name, &tool_input);
|
||||
|
||||
callback(ProviderStreamEvent::PermissionRequest {
|
||||
id,
|
||||
tool_name,
|
||||
tool_input,
|
||||
description,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if subtype != "success" {
|
||||
if let Some(result) = message.get("result").and_then(|r| r.as_str()) {
|
||||
callback(ProviderStreamEvent::Error {
|
||||
message: result.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn format_tool_description(name: &str, input: &serde_json::Value) -> String {
|
||||
match name {
|
||||
"Read" => {
|
||||
if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) {
|
||||
format!("Reading file: {}", path)
|
||||
} else {
|
||||
"Reading file...".to_string()
|
||||
}
|
||||
}
|
||||
"Glob" => {
|
||||
if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) {
|
||||
format!("Searching for files: {}", pattern)
|
||||
} else {
|
||||
"Searching for files...".to_string()
|
||||
}
|
||||
}
|
||||
"Grep" => {
|
||||
if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) {
|
||||
format!("Searching for: {}", pattern)
|
||||
} else {
|
||||
"Searching in files...".to_string()
|
||||
}
|
||||
}
|
||||
"Edit" | "Write" => {
|
||||
if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) {
|
||||
format!("Editing: {}", path)
|
||||
} else {
|
||||
"Editing file...".to_string()
|
||||
}
|
||||
}
|
||||
"Bash" => {
|
||||
if let Some(cmd) = input.get("command").and_then(|v| v.as_str()) {
|
||||
let truncated = if cmd.len() > 50 {
|
||||
format!("{}...", &cmd[..50])
|
||||
} else {
|
||||
cmd.to_string()
|
||||
};
|
||||
format!("Running: {}", truncated)
|
||||
} else {
|
||||
"Running command...".to_string()
|
||||
}
|
||||
}
|
||||
_ => format!("Using tool: {}", name),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_provider_name() {
|
||||
let config = ProviderConfig::default();
|
||||
let provider = ClaudeCliProvider::new(config);
|
||||
assert_eq!(provider.name(), "Claude CLI");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_capabilities() {
|
||||
let config = ProviderConfig::default();
|
||||
let provider = ClaudeCliProvider::new(config);
|
||||
let caps = provider.capabilities();
|
||||
|
||||
assert!(caps.supports_streaming);
|
||||
assert!(caps.supports_tools);
|
||||
assert!(caps.supports_vision);
|
||||
assert!(caps.supports_thinking);
|
||||
assert!(caps.supports_mcp);
|
||||
assert!(caps.supports_resume_session);
|
||||
assert_eq!(caps.max_context_tokens, Some(200_000));
|
||||
assert!(!caps.available_models.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_running_initial() {
|
||||
let config = ProviderConfig::default();
|
||||
let provider = ClaudeCliProvider::new(config);
|
||||
assert!(!provider.is_running());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_working_directory() {
|
||||
let config = ProviderConfig {
|
||||
working_directory: "/home/test".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
let provider = ClaudeCliProvider::new(config);
|
||||
assert_eq!(provider.get_working_directory(), "/home/test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_model() {
|
||||
let config = ProviderConfig {
|
||||
model: Some("claude-sonnet-4-20250514".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let provider = ClaudeCliProvider::new(config);
|
||||
assert_eq!(
|
||||
provider.get_model(),
|
||||
Some("claude-sonnet-4-20250514".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_description() {
|
||||
let input = serde_json::json!({"file_path": "/test/file.rs"});
|
||||
assert_eq!(
|
||||
format_tool_description("Read", &input),
|
||||
"Reading file: /test/file.rs"
|
||||
);
|
||||
|
||||
let input = serde_json::json!({"pattern": "*.rs"});
|
||||
assert_eq!(
|
||||
format_tool_description("Glob", &input),
|
||||
"Searching for files: *.rs"
|
||||
);
|
||||
|
||||
let input = serde_json::json!({"command": "ls -la"});
|
||||
assert_eq!(format_tool_description("Bash", &input), "Running: ls -la");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
mod claude_cli;
|
||||
mod ollama;
|
||||
mod traits;
|
||||
|
||||
// Re-exports for when providers are fully integrated
|
||||
#[allow(unused_imports)]
|
||||
pub use claude_cli::ClaudeCliProvider;
|
||||
#[allow(unused_imports)]
|
||||
pub use ollama::OllamaProvider;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{
|
||||
LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage,
|
||||
ProviderStreamEvent, ProviderUsage, QuestionOption, StreamCallback,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ProviderType {
|
||||
#[default]
|
||||
ClaudeCli,
|
||||
Ollama,
|
||||
OpenAi,
|
||||
Anthropic,
|
||||
Gemini,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl ProviderType {
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
ProviderType::ClaudeCli => "Claude CLI",
|
||||
ProviderType::Ollama => "Ollama (Local)",
|
||||
ProviderType::OpenAi => "OpenAI API",
|
||||
ProviderType::Anthropic => "Anthropic API",
|
||||
ProviderType::Gemini => "Google Gemini",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn description(&self) -> &'static str {
|
||||
match self {
|
||||
ProviderType::ClaudeCli => "Use Claude Code CLI for AI assistance",
|
||||
ProviderType::Ollama => "Use locally running Ollama models",
|
||||
ProviderType::OpenAi => "Direct OpenAI API access (GPT-4, etc.)",
|
||||
ProviderType::Anthropic => "Direct Anthropic API access (Claude models)",
|
||||
ProviderType::Gemini => "Direct Google Gemini API access",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn requires_api_key(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
ProviderType::OpenAi | ProviderType::Anthropic | ProviderType::Gemini
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: The new providers (OpenAI, Anthropic, Gemini) are implemented directly
|
||||
// in provider_bridge.rs using the Bridge pattern rather than the LlmProvider trait.
|
||||
// This simplifies the architecture while still providing full functionality.
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn create_provider(
|
||||
provider_type: ProviderType,
|
||||
config: ProviderConfig,
|
||||
) -> Box<dyn LlmProvider> {
|
||||
match provider_type {
|
||||
ProviderType::ClaudeCli => Box::new(ClaudeCliProvider::new(config)),
|
||||
ProviderType::Ollama => Box::new(OllamaProvider::new(config)),
|
||||
// The new API-based providers are handled in provider_bridge.rs
|
||||
ProviderType::OpenAi | ProviderType::Anthropic | ProviderType::Gemini => {
|
||||
// These providers use the Bridge pattern in provider_bridge.rs
|
||||
// Fall back to Claude CLI for trait-based usage
|
||||
Box::new(ClaudeCliProvider::new(config))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_provider_type_display_name() {
|
||||
assert_eq!(ProviderType::ClaudeCli.display_name(), "Claude CLI");
|
||||
assert_eq!(ProviderType::Ollama.display_name(), "Ollama (Local)");
|
||||
assert_eq!(ProviderType::OpenAi.display_name(), "OpenAI API");
|
||||
assert_eq!(ProviderType::Anthropic.display_name(), "Anthropic API");
|
||||
assert_eq!(ProviderType::Gemini.display_name(), "Google Gemini");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_type_default() {
|
||||
let default: ProviderType = Default::default();
|
||||
assert_eq!(default, ProviderType::ClaudeCli);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_type_serialization() {
|
||||
let claude = ProviderType::ClaudeCli;
|
||||
let json = serde_json::to_string(&claude).unwrap();
|
||||
assert_eq!(json, "\"claude_cli\"");
|
||||
|
||||
let ollama = ProviderType::Ollama;
|
||||
let json = serde_json::to_string(&ollama).unwrap();
|
||||
assert_eq!(json, "\"ollama\"");
|
||||
|
||||
let openai = ProviderType::OpenAi;
|
||||
let json = serde_json::to_string(&openai).unwrap();
|
||||
assert_eq!(json, "\"open_ai\"");
|
||||
|
||||
let anthropic = ProviderType::Anthropic;
|
||||
let json = serde_json::to_string(&anthropic).unwrap();
|
||||
assert_eq!(json, "\"anthropic\"");
|
||||
|
||||
let gemini = ProviderType::Gemini;
|
||||
let json = serde_json::to_string(&gemini).unwrap();
|
||||
assert_eq!(json, "\"gemini\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_type_deserialization() {
|
||||
let claude: ProviderType = serde_json::from_str("\"claude_cli\"").unwrap();
|
||||
assert_eq!(claude, ProviderType::ClaudeCli);
|
||||
|
||||
let ollama: ProviderType = serde_json::from_str("\"ollama\"").unwrap();
|
||||
assert_eq!(ollama, ProviderType::Ollama);
|
||||
|
||||
let openai: ProviderType = serde_json::from_str("\"open_ai\"").unwrap();
|
||||
assert_eq!(openai, ProviderType::OpenAi);
|
||||
|
||||
let anthropic: ProviderType = serde_json::from_str("\"anthropic\"").unwrap();
|
||||
assert_eq!(anthropic, ProviderType::Anthropic);
|
||||
|
||||
let gemini: ProviderType = serde_json::from_str("\"gemini\"").unwrap();
|
||||
assert_eq!(gemini, ProviderType::Gemini);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_type_requires_api_key() {
|
||||
assert!(!ProviderType::ClaudeCli.requires_api_key());
|
||||
assert!(!ProviderType::Ollama.requires_api_key());
|
||||
assert!(ProviderType::OpenAi.requires_api_key());
|
||||
assert!(ProviderType::Anthropic.requires_api_key());
|
||||
assert!(ProviderType::Gemini.requires_api_key());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,537 @@
|
||||
// This provider connects to a local Ollama instance for LLM inference.
|
||||
// It will be actively used once providers are fully integrated with BridgeManager.
|
||||
#![allow(dead_code)]
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::providers::traits::{
|
||||
LlmProvider, ModelInfo, ProviderCapabilities, ProviderConfig, ProviderMessage,
|
||||
ProviderStreamEvent, ProviderUsage, StreamCallback,
|
||||
};
|
||||
|
||||
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OllamaChatRequest {
|
||||
model: String,
|
||||
messages: Vec<OllamaMessage>,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct OllamaMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaChatResponse {
|
||||
#[serde(default)]
|
||||
message: Option<OllamaResponseMessage>,
|
||||
#[serde(default)]
|
||||
done: bool,
|
||||
#[serde(default)]
|
||||
eval_count: Option<u64>,
|
||||
#[serde(default)]
|
||||
prompt_eval_count: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaResponseMessage {
|
||||
#[serde(default)]
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaTagsResponse {
|
||||
models: Vec<OllamaModelInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaModelInfo {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
details: Option<OllamaModelDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaModelDetails {
|
||||
#[serde(default)]
|
||||
parameter_size: Option<String>,
|
||||
#[serde(default)]
|
||||
family: Option<String>,
|
||||
}
|
||||
|
||||
pub struct OllamaProvider {
|
||||
config: ProviderConfig,
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
is_running: Arc<AtomicBool>,
|
||||
conversation_history: Vec<OllamaMessage>,
|
||||
cancel_tx: Option<mpsc::Sender<()>>,
|
||||
}
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(config: ProviderConfig) -> Self {
|
||||
let base_url = config
|
||||
.api_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| DEFAULT_OLLAMA_URL.to_string());
|
||||
|
||||
Self {
|
||||
config,
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
is_running: Arc::new(AtomicBool::new(false)),
|
||||
conversation_history: Vec::new(),
|
||||
cancel_tx: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_available_models(&self) -> Vec<ModelInfo> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
|
||||
match self.client.get(&url).send().await {
|
||||
Ok(response) => {
|
||||
if let Ok(tags) = response.json::<OllamaTagsResponse>().await {
|
||||
tags.models
|
||||
.into_iter()
|
||||
.map(|m| {
|
||||
let description = m.details.as_ref().map(|d| {
|
||||
let mut desc_parts = Vec::new();
|
||||
if let Some(ref family) = d.family {
|
||||
desc_parts.push(family.clone());
|
||||
}
|
||||
if let Some(ref size) = d.parameter_size {
|
||||
desc_parts.push(format!("{} parameters", size));
|
||||
}
|
||||
if desc_parts.is_empty() {
|
||||
"Local model".to_string()
|
||||
} else {
|
||||
desc_parts.join(" - ")
|
||||
}
|
||||
});
|
||||
|
||||
ModelInfo {
|
||||
id: m.name.clone(),
|
||||
name: m.name,
|
||||
description,
|
||||
context_window: None,
|
||||
input_cost_per_mtok: Some(0.0), // Local = free!
|
||||
output_cost_per_mtok: Some(0.0),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
default_ollama_models()
|
||||
}
|
||||
}
|
||||
Err(_) => default_ollama_models(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream_chat(
|
||||
&mut self,
|
||||
callback: Arc<StreamCallback>,
|
||||
cancel_rx: mpsc::Receiver<()>,
|
||||
) -> Result<(), String> {
|
||||
let model = self
|
||||
.config
|
||||
.model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "llama3.2".to_string());
|
||||
|
||||
let request = OllamaChatRequest {
|
||||
model: model.clone(),
|
||||
messages: self.conversation_history.clone(),
|
||||
stream: true,
|
||||
system: self.config.custom_instructions.clone(),
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to connect to Ollama: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(format!("Ollama error ({}): {}", status, error_text));
|
||||
}
|
||||
|
||||
let mut full_response = String::new();
|
||||
let mut total_input_tokens: u64 = 0;
|
||||
let mut total_output_tokens: u64 = 0;
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut cancel_rx = cancel_rx;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel_rx.recv() => {
|
||||
callback(ProviderStreamEvent::Disconnected);
|
||||
return Ok(());
|
||||
}
|
||||
chunk = stream.next() => {
|
||||
match chunk {
|
||||
Some(Ok(bytes)) => {
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(response) = serde_json::from_str::<OllamaChatResponse>(line) {
|
||||
if let Some(msg) = &response.message {
|
||||
if !msg.content.is_empty() {
|
||||
full_response.push_str(&msg.content);
|
||||
callback(ProviderStreamEvent::TextDelta {
|
||||
text: msg.content.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tokens) = response.prompt_eval_count {
|
||||
total_input_tokens = tokens;
|
||||
}
|
||||
if let Some(tokens) = response.eval_count {
|
||||
total_output_tokens = tokens;
|
||||
}
|
||||
|
||||
if response.done {
|
||||
self.conversation_history.push(OllamaMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: full_response.clone(),
|
||||
});
|
||||
|
||||
callback(ProviderStreamEvent::MessageComplete {
|
||||
content: vec![ProviderMessage::Text {
|
||||
content: full_response,
|
||||
}],
|
||||
usage: Some(ProviderUsage {
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
model,
|
||||
}),
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
});
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
return Err(format!("Stream error: {}", e));
|
||||
}
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OllamaProvider {
|
||||
fn name(&self) -> &'static str {
|
||||
"Ollama"
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities {
|
||||
supports_streaming: true,
|
||||
supports_tools: false, // Ollama doesn't support tools natively yet
|
||||
supports_vision: true, // Some models support vision
|
||||
supports_thinking: false,
|
||||
supports_mcp: false,
|
||||
supports_resume_session: false,
|
||||
max_context_tokens: None, // Varies by model
|
||||
available_models: default_ollama_models(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn start(&mut self, callback: StreamCallback) -> Result<(), String> {
|
||||
if self.is_running.load(Ordering::SeqCst) {
|
||||
return Err("Provider already running".to_string());
|
||||
}
|
||||
|
||||
// Check if Ollama is reachable
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
self.client.get(&url).send().await.map_err(|e| {
|
||||
format!(
|
||||
"Cannot connect to Ollama at {}. Is it running? Error: {}",
|
||||
self.base_url, e
|
||||
)
|
||||
})?;
|
||||
|
||||
self.is_running.store(true, Ordering::SeqCst);
|
||||
self.conversation_history.clear();
|
||||
|
||||
// Create cancellation channel
|
||||
let (tx, _rx) = mpsc::channel(1);
|
||||
self.cancel_tx = Some(tx);
|
||||
|
||||
callback(ProviderStreamEvent::Connected { session_id: None });
|
||||
|
||||
// Fetch and report available models
|
||||
let models = self.fetch_available_models().await;
|
||||
if !models.is_empty() {
|
||||
eprintln!(
|
||||
"[Ollama] Available models: {:?}",
|
||||
models.iter().map(|m| &m.id).collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&mut self) -> Result<(), String> {
|
||||
self.is_running.store(false, Ordering::SeqCst);
|
||||
self.conversation_history.clear();
|
||||
|
||||
if let Some(tx) = self.cancel_tx.take() {
|
||||
let _ = tx.send(()).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_running(&self) -> bool {
|
||||
self.is_running.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
async fn send_message(&mut self, message: &str) -> Result<(), String> {
|
||||
if !self.is_running.load(Ordering::SeqCst) {
|
||||
return Err("Provider not running".to_string());
|
||||
}
|
||||
|
||||
// Add user message to history
|
||||
self.conversation_history.push(OllamaMessage {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
// Create a new cancel channel for this request
|
||||
let (tx, _rx) = mpsc::channel(1);
|
||||
self.cancel_tx = Some(tx);
|
||||
|
||||
// We need a callback here, but we don't have access to it in send_message
|
||||
// This is a limitation of the current trait design
|
||||
// For now, we'll need to refactor to handle streaming properly
|
||||
// The callback should be stored from the start() call
|
||||
|
||||
// For the MVP, we'll emit events directly
|
||||
// In a real implementation, we'd need to restructure this
|
||||
|
||||
Err("send_message needs refactoring to work with stored callback".to_string())
|
||||
}
|
||||
|
||||
async fn send_tool_result(
|
||||
&mut self,
|
||||
_tool_use_id: &str,
|
||||
_result: serde_json::Value,
|
||||
_is_error: bool,
|
||||
) -> Result<(), String> {
|
||||
// Ollama doesn't support tool use natively
|
||||
Err("Ollama does not support tool use".to_string())
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<String> {
|
||||
self.config.model.clone()
|
||||
}
|
||||
|
||||
fn get_working_directory(&self) -> &str {
|
||||
&self.config.working_directory
|
||||
}
|
||||
|
||||
async fn interrupt(&mut self) -> Result<(), String> {
|
||||
if let Some(tx) = self.cancel_tx.take() {
|
||||
let _ = tx.send(()).await;
|
||||
}
|
||||
self.is_running.store(false, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_ollama_models() -> Vec<ModelInfo> {
|
||||
vec![
|
||||
ModelInfo {
|
||||
id: "llama3.2".to_string(),
|
||||
name: "Llama 3.2".to_string(),
|
||||
description: Some("Meta's latest compact model".to_string()),
|
||||
context_window: Some(128_000),
|
||||
input_cost_per_mtok: Some(0.0),
|
||||
output_cost_per_mtok: Some(0.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "llama3.2:1b".to_string(),
|
||||
name: "Llama 3.2 1B".to_string(),
|
||||
description: Some("Smallest Llama 3.2 variant".to_string()),
|
||||
context_window: Some(128_000),
|
||||
input_cost_per_mtok: Some(0.0),
|
||||
output_cost_per_mtok: Some(0.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "qwen2.5-coder".to_string(),
|
||||
name: "Qwen 2.5 Coder".to_string(),
|
||||
description: Some("Alibaba's coding-focused model".to_string()),
|
||||
context_window: Some(32_000),
|
||||
input_cost_per_mtok: Some(0.0),
|
||||
output_cost_per_mtok: Some(0.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "deepseek-coder-v2".to_string(),
|
||||
name: "DeepSeek Coder V2".to_string(),
|
||||
description: Some("DeepSeek's coding model".to_string()),
|
||||
context_window: Some(128_000),
|
||||
input_cost_per_mtok: Some(0.0),
|
||||
output_cost_per_mtok: Some(0.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "mistral".to_string(),
|
||||
name: "Mistral 7B".to_string(),
|
||||
description: Some("Fast and capable".to_string()),
|
||||
context_window: Some(32_000),
|
||||
input_cost_per_mtok: Some(0.0),
|
||||
output_cost_per_mtok: Some(0.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "gemma2".to_string(),
|
||||
name: "Gemma 2".to_string(),
|
||||
description: Some("Google's open model".to_string()),
|
||||
context_window: Some(8_000),
|
||||
input_cost_per_mtok: Some(0.0),
|
||||
output_cost_per_mtok: Some(0.0),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_provider_name() {
|
||||
let config = ProviderConfig::default();
|
||||
let provider = OllamaProvider::new(config);
|
||||
assert_eq!(provider.name(), "Ollama");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_capabilities() {
|
||||
let config = ProviderConfig::default();
|
||||
let provider = OllamaProvider::new(config);
|
||||
let caps = provider.capabilities();
|
||||
|
||||
assert!(caps.supports_streaming);
|
||||
assert!(!caps.supports_tools);
|
||||
assert!(!caps.supports_mcp);
|
||||
assert!(!caps.supports_resume_session);
|
||||
assert!(!caps.available_models.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_url() {
|
||||
let config = ProviderConfig::default();
|
||||
let provider = OllamaProvider::new(config);
|
||||
assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_url() {
|
||||
let config = ProviderConfig {
|
||||
api_base_url: Some("http://custom:8080".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let provider = OllamaProvider::new(config);
|
||||
assert_eq!(provider.base_url, "http://custom:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_running_initial() {
|
||||
let config = ProviderConfig::default();
|
||||
let provider = OllamaProvider::new(config);
|
||||
assert!(!provider.is_running());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_ollama_models() {
|
||||
let models = default_ollama_models();
|
||||
assert!(!models.is_empty());
|
||||
|
||||
// All models should be free (local)
|
||||
for model in &models {
|
||||
assert_eq!(model.input_cost_per_mtok, Some(0.0));
|
||||
assert_eq!(model.output_cost_per_mtok, Some(0.0));
|
||||
}
|
||||
|
||||
// Should include llama
|
||||
assert!(models.iter().any(|m| m.id.contains("llama")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ollama_message_serialization() {
|
||||
let msg = OllamaMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Hello!".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains("\"role\":\"user\""));
|
||||
assert!(json.contains("\"content\":\"Hello!\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_request_serialization() {
|
||||
let request = OllamaChatRequest {
|
||||
model: "llama3.2".to_string(),
|
||||
messages: vec![OllamaMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Test".to_string(),
|
||||
}],
|
||||
stream: true,
|
||||
system: Some("You are helpful".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&request).unwrap();
|
||||
assert!(json.contains("\"model\":\"llama3.2\""));
|
||||
assert!(json.contains("\"stream\":true"));
|
||||
assert!(json.contains("\"system\":\"You are helpful\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_request_without_system() {
|
||||
let request = OllamaChatRequest {
|
||||
model: "llama3.2".to_string(),
|
||||
messages: vec![],
|
||||
stream: true,
|
||||
system: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&request).unwrap();
|
||||
// system should be omitted when None
|
||||
assert!(!json.contains("\"system\""));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
// These types are all used by the provider abstraction layer.
|
||||
// They will be actively used once providers are fully integrated with BridgeManager.
|
||||
#![allow(dead_code)]
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderConfig {
|
||||
pub api_key: Option<String>,
|
||||
pub api_base_url: Option<String>,
|
||||
pub model: Option<String>,
|
||||
pub custom_instructions: Option<String>,
|
||||
pub working_directory: String,
|
||||
pub mcp_servers_json: Option<String>,
|
||||
pub allowed_tools: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub extra_options: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl Default for ProviderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_base_url: None,
|
||||
model: None,
|
||||
custom_instructions: None,
|
||||
working_directory: String::new(),
|
||||
mcp_servers_json: None,
|
||||
allowed_tools: Vec::new(),
|
||||
extra_options: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderCapabilities {
|
||||
pub supports_streaming: bool,
|
||||
pub supports_tools: bool,
|
||||
pub supports_vision: bool,
|
||||
pub supports_thinking: bool,
|
||||
pub supports_mcp: bool,
|
||||
pub supports_resume_session: bool,
|
||||
pub max_context_tokens: Option<u64>,
|
||||
pub available_models: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub context_window: Option<u64>,
|
||||
pub input_cost_per_mtok: Option<f64>,
|
||||
pub output_cost_per_mtok: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ProviderMessage {
|
||||
#[serde(rename = "text")]
|
||||
Text { content: String },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
is_error: bool,
|
||||
},
|
||||
#[serde(rename = "thinking")]
|
||||
Thinking { content: String },
|
||||
#[serde(rename = "image")]
|
||||
Image { media_type: String, data: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderUsage {
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ProviderStreamEvent {
|
||||
#[serde(rename = "connected")]
|
||||
Connected { session_id: Option<String> },
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
#[serde(rename = "thinking_delta")]
|
||||
ThinkingDelta { text: String },
|
||||
#[serde(rename = "tool_use_start")]
|
||||
ToolUseStart { id: String, name: String },
|
||||
#[serde(rename = "tool_use_delta")]
|
||||
ToolUseDelta { id: String, input_delta: String },
|
||||
#[serde(rename = "tool_use_end")]
|
||||
ToolUseEnd { id: String, input: serde_json::Value },
|
||||
#[serde(rename = "message_complete")]
|
||||
MessageComplete {
|
||||
content: Vec<ProviderMessage>,
|
||||
usage: Option<ProviderUsage>,
|
||||
stop_reason: Option<String>,
|
||||
},
|
||||
#[serde(rename = "permission_request")]
|
||||
PermissionRequest {
|
||||
id: String,
|
||||
tool_name: String,
|
||||
tool_input: serde_json::Value,
|
||||
description: String,
|
||||
},
|
||||
#[serde(rename = "question")]
|
||||
Question {
|
||||
id: String,
|
||||
question: String,
|
||||
header: Option<String>,
|
||||
options: Vec<QuestionOption>,
|
||||
multi_select: bool,
|
||||
},
|
||||
#[serde(rename = "working_directory")]
|
||||
WorkingDirectory { path: String },
|
||||
#[serde(rename = "error")]
|
||||
Error { message: String },
|
||||
#[serde(rename = "disconnected")]
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct QuestionOption {
|
||||
pub label: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
pub type StreamCallback = Box<dyn Fn(ProviderStreamEvent) + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
fn capabilities(&self) -> ProviderCapabilities;
|
||||
|
||||
async fn start(&mut self, callback: StreamCallback) -> Result<(), String>;
|
||||
|
||||
async fn stop(&mut self) -> Result<(), String>;
|
||||
|
||||
fn is_running(&self) -> bool;
|
||||
|
||||
async fn send_message(&mut self, message: &str) -> Result<(), String>;
|
||||
|
||||
async fn send_tool_result(
|
||||
&mut self,
|
||||
tool_use_id: &str,
|
||||
result: serde_json::Value,
|
||||
is_error: bool,
|
||||
) -> Result<(), String>;
|
||||
|
||||
fn get_model(&self) -> Option<String>;
|
||||
|
||||
fn get_working_directory(&self) -> &str;
|
||||
|
||||
async fn interrupt(&mut self) -> Result<(), String>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_provider_config_default() {
|
||||
let config = ProviderConfig::default();
|
||||
assert!(config.api_key.is_none());
|
||||
assert!(config.api_base_url.is_none());
|
||||
assert!(config.model.is_none());
|
||||
assert!(config.working_directory.is_empty());
|
||||
assert!(config.allowed_tools.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_message_serialization() {
|
||||
let msg = ProviderMessage::Text {
|
||||
content: "Hello!".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains("\"type\":\"text\""));
|
||||
assert!(json.contains("\"content\":\"Hello!\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_stream_event_serialization() {
|
||||
let event = ProviderStreamEvent::TextDelta {
|
||||
text: "chunk".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("\"type\":\"text_delta\""));
|
||||
|
||||
let event = ProviderStreamEvent::Connected {
|
||||
session_id: Some("test-123".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("\"session_id\":\"test-123\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_info() {
|
||||
let model = ModelInfo {
|
||||
id: "claude-sonnet-4-20250514".to_string(),
|
||||
name: "Claude Sonnet 4".to_string(),
|
||||
description: Some("Fast and intelligent".to_string()),
|
||||
context_window: Some(200000),
|
||||
input_cost_per_mtok: Some(3.0),
|
||||
output_cost_per_mtok: Some(15.0),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&model).unwrap();
|
||||
assert!(json.contains("claude-sonnet-4"));
|
||||
assert!(json.contains("200000"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_usage() {
|
||||
let usage = ProviderUsage {
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
model: "claude-opus-4-20250514".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&usage).unwrap();
|
||||
assert!(json.contains("\"input_tokens\":100"));
|
||||
assert!(json.contains("\"output_tokens\":50"));
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
configStore,
|
||||
type HikariConfig,
|
||||
type Theme,
|
||||
type ProviderType,
|
||||
type CustomThemeColors,
|
||||
applyFontSize,
|
||||
applyCustomThemeColors,
|
||||
@@ -14,11 +15,22 @@
|
||||
import { getCurrentWindow } from "@tauri-apps/api/window";
|
||||
|
||||
let config: HikariConfig = $state({
|
||||
provider_type: "claude_cli",
|
||||
model: null,
|
||||
api_key: null,
|
||||
custom_instructions: null,
|
||||
mcp_servers_json: null,
|
||||
auto_granted_tools: [],
|
||||
ollama_base_url: "http://localhost:11434",
|
||||
ollama_model: null,
|
||||
openai_api_key: null,
|
||||
openai_base_url: "https://api.openai.com/v1",
|
||||
openai_model: null,
|
||||
anthropic_api_key: null,
|
||||
anthropic_base_url: "https://api.anthropic.com",
|
||||
anthropic_model: null,
|
||||
gemini_api_key: null,
|
||||
gemini_model: null,
|
||||
theme: "dark",
|
||||
greeting_enabled: true,
|
||||
greeting_custom_prompt: null,
|
||||
@@ -72,12 +84,53 @@
|
||||
grantedTools = Array.from(tools);
|
||||
});
|
||||
|
||||
const availableProviders: { value: ProviderType; label: string; description: string }[] = [
|
||||
{ value: "claude_cli", label: "Claude CLI", description: "Use Claude Code CLI for AI assistance" },
|
||||
{ value: "ollama", label: "Ollama (Local)", description: "Use locally running Ollama models" },
|
||||
{ value: "open_ai", label: "OpenAI API", description: "Direct OpenAI API access (GPT-4o, etc.)" },
|
||||
{ value: "anthropic", label: "Anthropic API", description: "Direct Anthropic API access (Claude models)" },
|
||||
{ value: "gemini", label: "Google Gemini", description: "Direct Google Gemini API access" },
|
||||
];
|
||||
|
||||
const availableModels = [
|
||||
{ value: "", label: "Default (from ~/.claude)" },
|
||||
{ value: "claude-sonnet-4-20250514", label: "Claude Sonnet 4" },
|
||||
{ value: "claude-opus-4-20250514", label: "Claude Opus 4" },
|
||||
];
|
||||
|
||||
const ollamaModels = [
|
||||
{ value: "", label: "Default (llama3.2)" },
|
||||
{ value: "llama3.2", label: "Llama 3.2" },
|
||||
{ value: "llama3.2:1b", label: "Llama 3.2 1B" },
|
||||
{ value: "qwen2.5-coder", label: "Qwen 2.5 Coder" },
|
||||
{ value: "deepseek-coder-v2", label: "DeepSeek Coder V2" },
|
||||
{ value: "mistral", label: "Mistral 7B" },
|
||||
{ value: "gemma2", label: "Gemma 2" },
|
||||
];
|
||||
|
||||
const openaiModels = [
|
||||
{ value: "", label: "Default (gpt-4o)" },
|
||||
{ value: "gpt-4o", label: "GPT-4o" },
|
||||
{ value: "gpt-4o-mini", label: "GPT-4o Mini" },
|
||||
{ value: "gpt-4-turbo", label: "GPT-4 Turbo" },
|
||||
{ value: "gpt-3.5-turbo", label: "GPT-3.5 Turbo" },
|
||||
];
|
||||
|
||||
const anthropicModels = [
|
||||
{ value: "", label: "Default (Claude Sonnet 4.5)" },
|
||||
{ value: "claude-sonnet-4-5-20250514", label: "Claude Sonnet 4.5" },
|
||||
{ value: "claude-opus-4-20250514", label: "Claude Opus 4" },
|
||||
{ value: "claude-3-5-sonnet-20241022", label: "Claude 3.5 Sonnet" },
|
||||
{ value: "claude-3-5-haiku-20241022", label: "Claude 3.5 Haiku" },
|
||||
];
|
||||
|
||||
const geminiModels = [
|
||||
{ value: "", label: "Default (gemini-2.0-flash)" },
|
||||
{ value: "gemini-2.0-flash", label: "Gemini 2.0 Flash" },
|
||||
{ value: "gemini-1.5-pro", label: "Gemini 1.5 Pro" },
|
||||
{ value: "gemini-1.5-flash", label: "Gemini 1.5 Flash" },
|
||||
];
|
||||
|
||||
const commonTools = [
|
||||
"Read",
|
||||
"Write",
|
||||
@@ -207,6 +260,235 @@
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Provider Selection Section -->
|
||||
<section class="mb-6">
|
||||
<h3 class="text-sm font-medium text-[var(--accent-primary)] uppercase tracking-wider mb-3">
|
||||
AI Provider
|
||||
</h3>
|
||||
|
||||
<div class="space-y-2">
|
||||
{#each availableProviders as provider (provider.value)}
|
||||
<label class="flex items-start gap-3 p-3 rounded-lg border cursor-pointer transition-colors {config.provider_type === provider.value
|
||||
? 'border-[var(--accent-primary)] bg-[var(--accent-primary)]/10'
|
||||
: 'border-[var(--border-color)] bg-[var(--bg-primary)] hover:border-[var(--accent-primary)]/50'}">
|
||||
<input
|
||||
type="radio"
|
||||
name="provider"
|
||||
value={provider.value}
|
||||
checked={config.provider_type === provider.value}
|
||||
onchange={() => config.provider_type = provider.value}
|
||||
class="mt-1 w-4 h-4 text-[var(--accent-primary)] bg-[var(--bg-primary)] border-[var(--border-color)] focus:ring-[var(--accent-primary)]"
|
||||
/>
|
||||
<div class="flex-1">
|
||||
<div class="text-sm font-medium text-[var(--text-primary)]">{provider.label}</div>
|
||||
<div class="text-xs text-[var(--text-tertiary)]">{provider.description}</div>
|
||||
</div>
|
||||
</label>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Ollama-specific settings -->
|
||||
{#if config.provider_type === "ollama"}
|
||||
<div class="mt-4 p-3 bg-[var(--bg-primary)] rounded-lg border border-[var(--border-color)]">
|
||||
<h4 class="text-sm font-medium text-[var(--text-primary)] mb-3">Ollama Settings</h4>
|
||||
|
||||
<!-- Ollama Base URL -->
|
||||
<div class="mb-3">
|
||||
<label for="ollama-url" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
Base URL
|
||||
</label>
|
||||
<input
|
||||
id="ollama-url"
|
||||
type="text"
|
||||
bind:value={config.ollama_base_url}
|
||||
placeholder="http://localhost:11434"
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Ollama Model Selection -->
|
||||
<div class="mb-3">
|
||||
<label for="ollama-model" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
Model
|
||||
</label>
|
||||
<select
|
||||
id="ollama-model"
|
||||
bind:value={config.ollama_model}
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
>
|
||||
{#each ollamaModels as model (model.value)}
|
||||
<option value={model.value}>{model.label}</option>
|
||||
{/each}
|
||||
</select>
|
||||
<p class="text-xs text-[var(--text-tertiary)] mt-1">
|
||||
Make sure the model is downloaded via <code class="text-[var(--accent-secondary)]">ollama pull</code>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="p-2 bg-yellow-500/10 border border-yellow-500/30 rounded text-xs text-yellow-400">
|
||||
<strong>Note:</strong> Ollama doesn't support tools, MCP servers, or thinking blocks.
|
||||
For full Claude Code features, use the Claude CLI provider.
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- OpenAI-specific settings -->
|
||||
{#if config.provider_type === "open_ai"}
|
||||
<div class="mt-4 p-3 bg-[var(--bg-primary)] rounded-lg border border-[var(--border-color)]">
|
||||
<h4 class="text-sm font-medium text-[var(--text-primary)] mb-3">OpenAI Settings</h4>
|
||||
|
||||
<!-- OpenAI API Key -->
|
||||
<div class="mb-3">
|
||||
<label for="openai-api-key" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
API Key
|
||||
</label>
|
||||
<input
|
||||
id="openai-api-key"
|
||||
type="password"
|
||||
bind:value={config.openai_api_key}
|
||||
placeholder="sk-..."
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI Base URL -->
|
||||
<div class="mb-3">
|
||||
<label for="openai-url" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
Base URL <span class="text-[var(--text-tertiary)]">(for OpenAI-compatible APIs)</span>
|
||||
</label>
|
||||
<input
|
||||
id="openai-url"
|
||||
type="text"
|
||||
bind:value={config.openai_base_url}
|
||||
placeholder="https://api.openai.com/v1"
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI Model Selection -->
|
||||
<div class="mb-3">
|
||||
<label for="openai-model" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
Model
|
||||
</label>
|
||||
<select
|
||||
id="openai-model"
|
||||
bind:value={config.openai_model}
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
>
|
||||
{#each openaiModels as model (model.value)}
|
||||
<option value={model.value}>{model.label}</option>
|
||||
{/each}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="p-2 bg-blue-500/10 border border-blue-500/30 rounded text-xs text-blue-400">
|
||||
<strong>Tip:</strong> You can use this with any OpenAI-compatible API (Groq, Together AI, etc.)
|
||||
by changing the Base URL.
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Anthropic-specific settings -->
|
||||
{#if config.provider_type === "anthropic"}
|
||||
<div class="mt-4 p-3 bg-[var(--bg-primary)] rounded-lg border border-[var(--border-color)]">
|
||||
<h4 class="text-sm font-medium text-[var(--text-primary)] mb-3">Anthropic Settings</h4>
|
||||
|
||||
<!-- Anthropic API Key -->
|
||||
<div class="mb-3">
|
||||
<label for="anthropic-api-key" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
API Key
|
||||
</label>
|
||||
<input
|
||||
id="anthropic-api-key"
|
||||
type="password"
|
||||
bind:value={config.anthropic_api_key}
|
||||
placeholder="sk-ant-..."
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Anthropic Base URL -->
|
||||
<div class="mb-3">
|
||||
<label for="anthropic-url" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
Base URL <span class="text-[var(--text-tertiary)]">(optional)</span>
|
||||
</label>
|
||||
<input
|
||||
id="anthropic-url"
|
||||
type="text"
|
||||
bind:value={config.anthropic_base_url}
|
||||
placeholder="https://api.anthropic.com"
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Anthropic Model Selection -->
|
||||
<div class="mb-3">
|
||||
<label for="anthropic-model" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
Model
|
||||
</label>
|
||||
<select
|
||||
id="anthropic-model"
|
||||
bind:value={config.anthropic_model}
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
>
|
||||
{#each anthropicModels as model (model.value)}
|
||||
<option value={model.value}>{model.label}</option>
|
||||
{/each}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="p-2 bg-purple-500/10 border border-purple-500/30 rounded text-xs text-purple-400">
|
||||
<strong>Note:</strong> This uses the Anthropic API directly without Claude Code CLI features
|
||||
like tools, MCP, or thinking blocks.
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Gemini-specific settings -->
|
||||
{#if config.provider_type === "gemini"}
|
||||
<div class="mt-4 p-3 bg-[var(--bg-primary)] rounded-lg border border-[var(--border-color)]">
|
||||
<h4 class="text-sm font-medium text-[var(--text-primary)] mb-3">Gemini Settings</h4>
|
||||
|
||||
<!-- Gemini API Key -->
|
||||
<div class="mb-3">
|
||||
<label for="gemini-api-key" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
API Key
|
||||
</label>
|
||||
<input
|
||||
id="gemini-api-key"
|
||||
type="password"
|
||||
bind:value={config.gemini_api_key}
|
||||
placeholder="AIza..."
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
/>
|
||||
<p class="text-xs text-[var(--text-tertiary)] mt-1">
|
||||
Get your API key from <a href="https://aistudio.google.com/apikey" target="_blank" class="text-[var(--accent-secondary)] hover:underline">Google AI Studio</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Gemini Model Selection -->
|
||||
<div class="mb-3">
|
||||
<label for="gemini-model" class="block text-xs text-[var(--text-secondary)] mb-1">
|
||||
Model
|
||||
</label>
|
||||
<select
|
||||
id="gemini-model"
|
||||
bind:value={config.gemini_model}
|
||||
class="w-full px-3 py-2 text-sm bg-[var(--bg-secondary)] border border-[var(--border-color)] rounded-lg text-[var(--text-primary)] focus:outline-none focus:border-[var(--accent-primary)]"
|
||||
>
|
||||
{#each geminiModels as model (model.value)}
|
||||
<option value={model.value}>{model.label}</option>
|
||||
{/each}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="p-2 bg-green-500/10 border border-green-500/30 rounded text-xs text-green-400">
|
||||
<strong>Note:</strong> Gemini has a generous free tier! Great for experimenting.
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</section>
|
||||
|
||||
<!-- Agent Settings Section -->
|
||||
<section class="mb-6">
|
||||
<h3 class="text-sm font-medium text-[var(--accent-primary)] uppercase tracking-wider mb-3">
|
||||
|
||||
@@ -43,11 +43,22 @@
|
||||
let showProfile = $state(false);
|
||||
const progress = $derived($achievementProgress);
|
||||
let currentConfig: HikariConfig = $state({
|
||||
provider_type: "claude_cli",
|
||||
model: null,
|
||||
api_key: null,
|
||||
custom_instructions: null,
|
||||
mcp_servers_json: null,
|
||||
auto_granted_tools: [],
|
||||
ollama_base_url: "http://localhost:11434",
|
||||
ollama_model: null,
|
||||
openai_api_key: null,
|
||||
openai_base_url: "https://api.openai.com/v1",
|
||||
openai_model: null,
|
||||
anthropic_api_key: null,
|
||||
anthropic_base_url: "https://api.anthropic.com",
|
||||
anthropic_model: null,
|
||||
gemini_api_key: null,
|
||||
gemini_model: null,
|
||||
theme: "dark",
|
||||
greeting_enabled: true,
|
||||
greeting_custom_prompt: null,
|
||||
@@ -145,12 +156,23 @@
|
||||
await invoke("start_claude", {
|
||||
conversationId,
|
||||
options: {
|
||||
provider_type: currentConfig.provider_type || "claude_cli",
|
||||
working_dir: targetDir,
|
||||
model: currentConfig.model || null,
|
||||
api_key: currentConfig.api_key || null,
|
||||
custom_instructions: currentConfig.custom_instructions || null,
|
||||
mcp_servers_json: currentConfig.mcp_servers_json || null,
|
||||
allowed_tools: allAllowedTools,
|
||||
ollama_base_url: currentConfig.ollama_base_url || "http://localhost:11434",
|
||||
ollama_model: currentConfig.ollama_model || null,
|
||||
openai_api_key: currentConfig.openai_api_key || null,
|
||||
openai_base_url: currentConfig.openai_base_url || "https://api.openai.com/v1",
|
||||
openai_model: currentConfig.openai_model || null,
|
||||
anthropic_api_key: currentConfig.anthropic_api_key || null,
|
||||
anthropic_base_url: currentConfig.anthropic_base_url || "https://api.anthropic.com",
|
||||
anthropic_model: currentConfig.anthropic_model || null,
|
||||
gemini_api_key: currentConfig.gemini_api_key || null,
|
||||
gemini_model: currentConfig.gemini_model || null,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { writable, derived } from "svelte/store";
|
||||
import { invoke } from "@tauri-apps/api/core";
|
||||
|
||||
export type Theme = "dark" | "light" | "high-contrast" | "custom";
|
||||
export type ProviderType = "claude_cli" | "ollama" | "open_ai" | "anthropic" | "gemini";
|
||||
|
||||
export interface CustomThemeColors {
|
||||
bg_primary: string | null;
|
||||
@@ -15,11 +16,26 @@ export interface CustomThemeColors {
|
||||
}
|
||||
|
||||
export interface HikariConfig {
|
||||
provider_type: ProviderType;
|
||||
model: string | null;
|
||||
api_key: string | null;
|
||||
custom_instructions: string | null;
|
||||
mcp_servers_json: string | null;
|
||||
auto_granted_tools: string[];
|
||||
// Ollama settings
|
||||
ollama_base_url: string;
|
||||
ollama_model: string | null;
|
||||
// OpenAI settings
|
||||
openai_api_key: string | null;
|
||||
openai_base_url: string;
|
||||
openai_model: string | null;
|
||||
// Anthropic settings
|
||||
anthropic_api_key: string | null;
|
||||
anthropic_base_url: string;
|
||||
anthropic_model: string | null;
|
||||
// Gemini settings
|
||||
gemini_api_key: string | null;
|
||||
gemini_model: string | null;
|
||||
theme: Theme;
|
||||
greeting_enabled: boolean;
|
||||
greeting_custom_prompt: string | null;
|
||||
@@ -40,11 +56,22 @@ export interface HikariConfig {
|
||||
}
|
||||
|
||||
const defaultConfig: HikariConfig = {
|
||||
provider_type: "claude_cli",
|
||||
model: null,
|
||||
api_key: null,
|
||||
custom_instructions: null,
|
||||
mcp_servers_json: null,
|
||||
auto_granted_tools: [],
|
||||
ollama_base_url: "http://localhost:11434",
|
||||
ollama_model: null,
|
||||
openai_api_key: null,
|
||||
openai_base_url: "https://api.openai.com/v1",
|
||||
openai_model: null,
|
||||
anthropic_api_key: null,
|
||||
anthropic_base_url: "https://api.anthropic.com",
|
||||
anthropic_model: null,
|
||||
gemini_api_key: null,
|
||||
gemini_model: null,
|
||||
theme: "dark",
|
||||
greeting_enabled: true,
|
||||
greeting_custom_prompt: null,
|
||||
|
||||
Reference in New Issue
Block a user